1//===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares common op traits that are not core to MLIR but can be
10// shared by multiple dialects.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_DIALECT_TRAITS_H
15#define MLIR_DIALECT_TRAITS_H
16
17#include "mlir/IR/OpDefinition.h"
18
19namespace mlir {
20namespace OpTrait {
21
22// These functions are the out-of-line implementations of the methods in the
23// corresponding trait classes. Keeping the bodies out of the class templates
24// avoids them being template instantiated/duplicated.
25namespace impl {
/// Backs ResultsBroadcastableShape::verifyTrait, which forwards `op` here and
/// returns this function's result unchanged.
26LogicalResult verifyCompatibleOperandBroadcast(Operation *op);
27} // namespace impl
28
29namespace util {
30/// Returns true and sets `resultShape` to the broadcasted shape from the two
31/// given shapes if they are broadcast compatible. Returns false and clears
32/// `resultShape` otherwise.
33///
34/// The rules for determining the result shape are:
35///
36/// Zip together the dimensions in the two given shapes by prepending the shape
37/// with fewer dimensions with 1s. For each dimension pair, deduce the result
38/// dimension according to the following order:
39/// - If there are unknown dimensions, follows the TensorFlow behavior:
40///   - If either dimension is greater than 1, we assume that the program is
41///     correct, and the other dimension will be broadcast to match it.
42///   - If either dimension is 1, the other dimension is the result.
43///   - Otherwise, the result dimension is an unknown dimension.
44/// - If one of the dimensions is 1, the other dimension is the result.
45/// - If two dimensions are the same, that's the result.
46/// - Otherwise, incompatible shape.
47bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
48 SmallVectorImpl<int64_t> &resultShape);
49
50/// Returns true if a broadcast between n shapes is guaranteed to be
51/// successful and not result in an error. A false result does not by itself
52/// mean that the shapes are not broadcastable: either they are known not to
53/// be broadcastable, or this function does not have enough information to
54/// know.
55///
56/// Conceptually, this returns true if getBroadcastedShape would have returned
57/// true and vice versa, with one exception. If a dimension is unknown in both
58/// shapes, getBroadcastedShape would return true and have a result with unknown
59/// dimension, while this function will return false because it's possible for
60/// both shapes to have a dimension greater than 1 and different which would
61/// fail to broadcast.
62bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
/// Two-shape form of the check above. NOTE(review): presumably equivalent to
/// the n-shape version applied to {shape1, shape2} -- confirm against the
/// implementation.
63bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
64 ArrayRef<int64_t> shape2);
65
66/// Returns the result broadcast composition type from the two given types by
67/// following NumPy broadcast semantics. The returned type may have a dynamic
68/// shape if either of the input types has a dynamic shape. Returns a null type
69/// if the two given types are not broadcast-compatible.
70///
71/// elementType, if specified, will be used as the element type of the
72/// broadcasted result type. Otherwise it is required that the element type of
73/// type1 and type2 is the same and this element type will be used as the
74/// resultant element type.
75Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr);
76
77} // namespace util
78
79/// Trait for ops that are known to have broadcast compatible operands and
80/// result types. Specifically, starting from the most varying dimension, each
81/// dimension pair of the operands' shapes should either be the same or one
82/// of them is one. Also, the results's shapes should have the corresponding
83/// dimension equal to the larger one, if known. Shapes are checked partially if
84/// ranks or dimensions are not known. For example, an op with tensor<?x2xf32>
85/// and tensor<2xf32> as operand types and tensor<5x3x2xi16> as the result
86/// type has broadcast compatible operands ns result types.
87template <typename ConcreteType>
88class ResultsBroadcastableShape
89 : public TraitBase<ConcreteType, ResultsBroadcastableShape> {
90public:
91 static LogicalResult verifyTrait(Operation *op) {
92 return impl::verifyCompatibleOperandBroadcast(op);
93 }
94};
95
96} // namespace OpTrait
97} // namespace mlir
98
99#endif // MLIR_DIALECT_TRAITS_H
100

// Source: mlir/include/mlir/Dialect/Traits.h