1//===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines generic type utilities.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/IR/TypeUtilities.h"
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/Types.h"
17#include "mlir/IR/Value.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include <numeric>
20
21using namespace mlir;
22
23Type mlir::getElementTypeOrSelf(Type type) {
24 if (auto st = llvm::dyn_cast<ShapedType>(type))
25 return st.getElementType();
26 return type;
27}
28
29Type mlir::getElementTypeOrSelf(Value val) {
30 return getElementTypeOrSelf(type: val.getType());
31}
32
33Type mlir::getElementTypeOrSelf(Attribute attr) {
34 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
35 return getElementTypeOrSelf(typedAttr.getType());
36 return {};
37}
38
39SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
40 SmallVector<Type, 10> fTypes;
41 t.getFlattenedTypes(fTypes);
42 return fTypes;
43}
44
45/// Return true if the specified type is an opaque type with the specified
46/// dialect and typeData.
47bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
48 StringRef typeData) {
49 if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type))
50 return opaque.getDialectNamespace() == dialect &&
51 opaque.getTypeData() == typeData;
52 return false;
53}
54
55/// Returns success if the given two shapes are compatible. That is, they have
56/// the same size and each pair of the elements are equal or one of them is
57/// dynamic.
58LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
59 ArrayRef<int64_t> shape2) {
60 if (shape1.size() != shape2.size())
61 return failure();
62 for (auto dims : llvm::zip(t&: shape1, u&: shape2)) {
63 int64_t dim1 = std::get<0>(t&: dims);
64 int64_t dim2 = std::get<1>(t&: dims);
65 if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
66 dim1 != dim2)
67 return failure();
68 }
69 return success();
70}
71
72/// Returns success if the given two types have compatible shape. That is,
73/// they are both scalars (not shaped), or they are both shaped types and at
74/// least one is unranked or they have compatible dimensions. Dimensions are
75/// compatible if at least one is dynamic or both are equal. The element type
76/// does not matter.
77LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
78 auto sType1 = llvm::dyn_cast<ShapedType>(type1);
79 auto sType2 = llvm::dyn_cast<ShapedType>(type2);
80
81 // Either both or neither type should be shaped.
82 if (!sType1)
83 return success(!sType2);
84 if (!sType2)
85 return failure();
86
87 if (!sType1.hasRank() || !sType2.hasRank())
88 return success();
89
90 return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
91}
92
93/// Returns success if the given two arrays have the same number of elements and
94/// each pair wise entries have compatible shape.
95LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
96 if (types1.size() != types2.size())
97 return failure();
98 for (auto it : llvm::zip_first(t&: types1, u&: types2))
99 if (failed(result: verifyCompatibleShape(type1: std::get<0>(t&: it), type2: std::get<1>(t&: it))))
100 return failure();
101 return success();
102}
103
104LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
105 if (dims.empty())
106 return success();
107 auto staticDim = std::accumulate(
108 first: dims.begin(), last: dims.end(), init: dims.front(), binary_op: [](auto fold, auto dim) {
109 return ShapedType::isDynamic(dim) ? fold : dim;
110 });
111 return success(isSuccess: llvm::all_of(Range&: dims, P: [&](auto dim) {
112 return ShapedType::isDynamic(dim) || dim == staticDim;
113 }));
114}
115
116/// Returns success if all given types have compatible shapes. That is, they are
117/// all scalars (not shaped), or they are all shaped types and any ranked shapes
118/// have compatible dimensions. Dimensions are compatible if all non-dynamic
119/// dims are equal. The element type does not matter.
120LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
121 auto shapedTypes = llvm::map_to_vector<8>(
122 types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
123 // Return failure if some, but not all are not shaped. Return early if none
124 // are shaped also.
125 if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
126 return success();
127 if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
128 return failure();
129
130 // Return failure if some, but not all, are scalable vectors.
131 bool hasScalableVecTypes = false;
132 bool hasNonScalableVecTypes = false;
133 for (Type t : types) {
134 auto vType = llvm::dyn_cast<VectorType>(t);
135 if (vType && vType.isScalable())
136 hasScalableVecTypes = true;
137 else
138 hasNonScalableVecTypes = true;
139 if (hasScalableVecTypes && hasNonScalableVecTypes)
140 return failure();
141 }
142
143 // Remove all unranked shapes
144 auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
145 shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
146 if (shapes.empty())
147 return success();
148
149 // All ranks should be equal
150 auto firstRank = shapes.front().getRank();
151 if (llvm::any_of(shapes,
152 [&](auto shape) { return firstRank != shape.getRank(); }))
153 return failure();
154
155 for (unsigned i = 0; i < firstRank; ++i) {
156 // Retrieve all ranked dimensions
157 auto dims = llvm::map_to_vector<8>(
158 llvm::make_filter_range(
159 shapes, [&](auto shape) { return shape.getRank() >= i; }),
160 [&](auto shape) { return shape.getDimSize(i); });
161 if (verifyCompatibleDims(dims).failed())
162 return failure();
163 }
164
165 return success();
166}
167
168Type OperandElementTypeIterator::mapElement(Value value) const {
169 return llvm::cast<ShapedType>(value.getType()).getElementType();
170}
171
172Type ResultElementTypeIterator::mapElement(Value value) const {
173 return llvm::cast<ShapedType>(value.getType()).getElementType();
174}
175
176TypeRange mlir::insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
177 TypeRange newTypes,
178 SmallVectorImpl<Type> &storage) {
179 assert(indices.size() == newTypes.size() &&
180 "mismatch between indice and type count");
181 if (indices.empty())
182 return oldTypes;
183
184 auto fromIt = oldTypes.begin();
185 for (auto it : llvm::zip(t&: indices, u&: newTypes)) {
186 const auto toIt = oldTypes.begin() + std::get<0>(t&: it);
187 storage.append(in_start: fromIt, in_end: toIt);
188 storage.push_back(Elt: std::get<1>(t&: it));
189 fromIt = toIt;
190 }
191 storage.append(in_start: fromIt, in_end: oldTypes.end());
192 return storage;
193}
194
195TypeRange mlir::filterTypesOut(TypeRange types, const BitVector &indices,
196 SmallVectorImpl<Type> &storage) {
197 if (indices.none())
198 return types;
199
200 for (unsigned i = 0, e = types.size(); i < e; ++i)
201 if (!indices[i])
202 storage.emplace_back(Args: types[i]);
203 return storage;
204}
205

source code of mlir/lib/IR/TypeUtilities.cpp