1 | //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_BUILTINTYPES_H |
10 | #define MLIR_IR_BUILTINTYPES_H |
11 | |
12 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
13 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
14 | #include "mlir/Support/ADTExtras.h" |
15 | |
16 | namespace llvm { |
17 | class BitVector; |
18 | struct fltSemantics; |
19 | } // namespace llvm |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Tablegen Interface Declarations |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | namespace mlir { |
26 | class AffineExpr; |
27 | class AffineMap; |
28 | class FloatType; |
29 | class IndexType; |
30 | class IntegerType; |
31 | class MemRefType; |
32 | class RankedTensorType; |
33 | class StringAttr; |
34 | class TypeRange; |
35 | |
36 | namespace detail { |
37 | struct FunctionTypeStorage; |
38 | struct IntegerTypeStorage; |
39 | struct TupleTypeStorage; |
40 | } // namespace detail |
41 | |
42 | //===----------------------------------------------------------------------===// |
43 | // FloatType |
44 | //===----------------------------------------------------------------------===// |
45 | |
46 | class FloatType : public Type { |
47 | public: |
48 | using Type::Type; |
49 | |
50 | // Convenience factories. |
51 | static FloatType getBF16(MLIRContext *ctx); |
52 | static FloatType getF16(MLIRContext *ctx); |
53 | static FloatType getF32(MLIRContext *ctx); |
54 | static FloatType getTF32(MLIRContext *ctx); |
55 | static FloatType getF64(MLIRContext *ctx); |
56 | static FloatType getF80(MLIRContext *ctx); |
57 | static FloatType getF128(MLIRContext *ctx); |
58 | static FloatType getFloat8E5M2(MLIRContext *ctx); |
59 | static FloatType getFloat8E4M3FN(MLIRContext *ctx); |
60 | static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx); |
61 | static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx); |
62 | static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx); |
63 | |
64 | /// Methods for support type inquiry through isa, cast, and dyn_cast. |
65 | static bool classof(Type type); |
66 | |
67 | /// Return the bitwidth of this float type. |
68 | unsigned getWidth(); |
69 | |
70 | /// Return the width of the mantissa of this type. |
71 | /// The width includes the integer bit. |
72 | unsigned getFPMantissaWidth(); |
73 | |
74 | /// Get or create a new FloatType with bitwidth scaled by `scale`. |
75 | /// Return null if the scaled element type cannot be represented. |
76 | FloatType scaleElementBitwidth(unsigned scale); |
77 | |
78 | /// Return the floating semantics of this float type. |
79 | const llvm::fltSemantics &getFloatSemantics(); |
80 | }; |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // TensorType |
84 | //===----------------------------------------------------------------------===// |
85 | |
86 | /// Tensor types represent multi-dimensional arrays, and have two variants: |
87 | /// RankedTensorType and UnrankedTensorType. |
88 | /// Note: This class attaches the ShapedType trait to act as a mixin to |
89 | /// provide many useful utility functions. This inheritance has no effect |
90 | /// on derived tensor types. |
91 | class TensorType : public Type, public ShapedType::Trait<TensorType> { |
92 | public: |
93 | using Type::Type; |
94 | |
95 | /// Returns the element type of this tensor type. |
96 | Type getElementType() const; |
97 | |
98 | /// Returns if this type is ranked, i.e. it has a known number of dimensions. |
99 | bool hasRank() const; |
100 | |
101 | /// Returns the shape of this tensor type. |
102 | ArrayRef<int64_t> getShape() const; |
103 | |
104 | /// Clone this type with the given shape and element type. If the |
105 | /// provided shape is `std::nullopt`, the current shape of the type is used. |
106 | TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape, |
107 | Type elementType) const; |
108 | |
109 | // Make sure that base class overloads are visible. |
110 | using ShapedType::Trait<TensorType>::clone; |
111 | |
112 | /// Return a clone of this type with the given new shape and element type. |
113 | /// The returned type is ranked, even if this type is unranked. |
114 | RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const; |
115 | |
116 | /// Return a clone of this type with the given new shape. The returned type |
117 | /// is ranked, even if this type is unranked. |
118 | RankedTensorType clone(ArrayRef<int64_t> shape) const; |
119 | |
120 | /// Return true if the specified element type is ok in a tensor. |
121 | static bool isValidElementType(Type type); |
122 | |
123 | /// Methods for support type inquiry through isa, cast, and dyn_cast. |
124 | static bool classof(Type type); |
125 | |
126 | /// Allow implicit conversion to ShapedType. |
127 | operator ShapedType() const { return llvm::cast<ShapedType>(*this); } |
128 | }; |
129 | |
130 | //===----------------------------------------------------------------------===// |
131 | // BaseMemRefType |
132 | //===----------------------------------------------------------------------===// |
133 | |
134 | /// This class provides a shared interface for ranked and unranked memref types. |
135 | /// Note: This class attaches the ShapedType trait to act as a mixin to |
136 | /// provide many useful utility functions. This inheritance has no effect |
137 | /// on derived memref types. |
138 | class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> { |
139 | public: |
140 | using Type::Type; |
141 | |
142 | /// Returns the element type of this memref type. |
143 | Type getElementType() const; |
144 | |
145 | /// Returns if this type is ranked, i.e. it has a known number of dimensions. |
146 | bool hasRank() const; |
147 | |
148 | /// Returns the shape of this memref type. |
149 | ArrayRef<int64_t> getShape() const; |
150 | |
151 | /// Clone this type with the given shape and element type. If the |
152 | /// provided shape is `std::nullopt`, the current shape of the type is used. |
153 | BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape, |
154 | Type elementType) const; |
155 | |
156 | // Make sure that base class overloads are visible. |
157 | using ShapedType::Trait<BaseMemRefType>::clone; |
158 | |
159 | /// Return a clone of this type with the given new shape and element type. |
160 | /// The returned type is ranked, even if this type is unranked. |
161 | MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const; |
162 | |
163 | /// Return a clone of this type with the given new shape. The returned type |
164 | /// is ranked, even if this type is unranked. |
165 | MemRefType clone(ArrayRef<int64_t> shape) const; |
166 | |
167 | /// Return true if the specified element type is ok in a memref. |
168 | static bool isValidElementType(Type type); |
169 | |
170 | /// Methods for support type inquiry through isa, cast, and dyn_cast. |
171 | static bool classof(Type type); |
172 | |
173 | /// Returns the memory space in which data referred to by this memref resides. |
174 | Attribute getMemorySpace() const; |
175 | |
176 | /// [deprecated] Returns the memory space in old raw integer representation. |
177 | /// New `Attribute getMemorySpace()` method should be used instead. |
178 | unsigned getMemorySpaceAsInt() const; |
179 | |
180 | /// Allow implicit conversion to ShapedType. |
181 | operator ShapedType() const { return llvm::cast<ShapedType>(*this); } |
182 | }; |
183 | |
184 | } // namespace mlir |
185 | |
186 | //===----------------------------------------------------------------------===// |
187 | // Tablegen Type Declarations |
188 | //===----------------------------------------------------------------------===// |
189 | |
190 | #define GET_TYPEDEF_CLASSES |
191 | #include "mlir/IR/BuiltinTypes.h.inc" |
192 | |
193 | namespace mlir { |
194 | |
195 | //===----------------------------------------------------------------------===// |
196 | // MemRefType |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | /// This is a builder type that keeps local references to arguments. Arguments |
200 | /// that are passed into the builder must outlive the builder. |
201 | class MemRefType::Builder { |
202 | public: |
203 | // Build from another MemRefType. |
204 | explicit Builder(MemRefType other) |
205 | : shape(other.getShape()), elementType(other.getElementType()), |
206 | layout(other.getLayout()), memorySpace(other.getMemorySpace()) {} |
207 | |
208 | // Build from scratch. |
209 | Builder(ArrayRef<int64_t> shape, Type elementType) |
210 | : shape(shape), elementType(elementType) {} |
211 | |
212 | Builder &setShape(ArrayRef<int64_t> newShape) { |
213 | shape = newShape; |
214 | return *this; |
215 | } |
216 | |
217 | Builder &setElementType(Type newElementType) { |
218 | elementType = newElementType; |
219 | return *this; |
220 | } |
221 | |
222 | Builder &setLayout(MemRefLayoutAttrInterface newLayout) { |
223 | layout = newLayout; |
224 | return *this; |
225 | } |
226 | |
227 | Builder &setMemorySpace(Attribute newMemorySpace) { |
228 | memorySpace = newMemorySpace; |
229 | return *this; |
230 | } |
231 | |
232 | operator MemRefType() { |
233 | return MemRefType::get(shape, elementType, layout, memorySpace); |
234 | } |
235 | |
236 | private: |
237 | ArrayRef<int64_t> shape; |
238 | Type elementType; |
239 | MemRefLayoutAttrInterface layout; |
240 | Attribute memorySpace; |
241 | }; |
242 | |
243 | //===----------------------------------------------------------------------===// |
244 | // RankedTensorType |
245 | //===----------------------------------------------------------------------===// |
246 | |
247 | /// This is a builder type that keeps local references to arguments. Arguments |
248 | /// that are passed into the builder must outlive the builder. |
249 | class RankedTensorType::Builder { |
250 | public: |
251 | /// Build from another RankedTensorType. |
252 | explicit Builder(RankedTensorType other) |
253 | : shape(other.getShape()), elementType(other.getElementType()), |
254 | encoding(other.getEncoding()) {} |
255 | |
256 | /// Build from scratch. |
257 | Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding) |
258 | : shape(shape), elementType(elementType), encoding(encoding) {} |
259 | |
260 | Builder &setShape(ArrayRef<int64_t> newShape) { |
261 | shape = newShape; |
262 | return *this; |
263 | } |
264 | |
265 | Builder &setElementType(Type newElementType) { |
266 | elementType = newElementType; |
267 | return *this; |
268 | } |
269 | |
270 | Builder &setEncoding(Attribute newEncoding) { |
271 | encoding = newEncoding; |
272 | return *this; |
273 | } |
274 | |
275 | /// Erase a dim from shape @pos. |
276 | Builder &dropDim(unsigned pos) { |
277 | assert(pos < shape.size() && "overflow" ); |
278 | shape.erase(pos); |
279 | return *this; |
280 | } |
281 | |
282 | /// Insert a val into shape @pos. |
283 | Builder &insertDim(int64_t val, unsigned pos) { |
284 | assert(pos <= shape.size() && "overflow" ); |
285 | shape.insert(pos, val); |
286 | return *this; |
287 | } |
288 | |
289 | operator RankedTensorType() { |
290 | return RankedTensorType::get(shape, elementType, encoding); |
291 | } |
292 | |
293 | private: |
294 | CopyOnWriteArrayRef<int64_t> shape; |
295 | Type elementType; |
296 | Attribute encoding; |
297 | }; |
298 | |
299 | //===----------------------------------------------------------------------===// |
300 | // VectorType |
301 | //===----------------------------------------------------------------------===// |
302 | |
303 | /// This is a builder type that keeps local references to arguments. Arguments |
304 | /// that are passed into the builder must outlive the builder. |
305 | class VectorType::Builder { |
306 | public: |
307 | /// Build from another VectorType. |
308 | explicit Builder(VectorType other) |
309 | : elementType(other.getElementType()), shape(other.getShape()), |
310 | scalableDims(other.getScalableDims()) {} |
311 | |
312 | /// Build from scratch. |
313 | Builder(ArrayRef<int64_t> shape, Type elementType, |
314 | ArrayRef<bool> scalableDims = {}) |
315 | : elementType(elementType), shape(shape), scalableDims(scalableDims) {} |
316 | |
317 | Builder &setShape(ArrayRef<int64_t> newShape, |
318 | ArrayRef<bool> newIsScalableDim = {}) { |
319 | shape = newShape; |
320 | scalableDims = newIsScalableDim; |
321 | return *this; |
322 | } |
323 | |
324 | Builder &setElementType(Type newElementType) { |
325 | elementType = newElementType; |
326 | return *this; |
327 | } |
328 | |
329 | /// Erase a dim from shape @pos. |
330 | Builder &dropDim(unsigned pos) { |
331 | assert(pos < shape.size() && "overflow" ); |
332 | shape.erase(pos); |
333 | if (!scalableDims.empty()) |
334 | scalableDims.erase(pos); |
335 | return *this; |
336 | } |
337 | |
338 | /// Set a dim in shape @pos to val. |
339 | Builder &setDim(unsigned pos, int64_t val) { |
340 | assert(pos < shape.size() && "overflow" ); |
341 | shape.set(pos, val); |
342 | return *this; |
343 | } |
344 | |
345 | operator VectorType() { |
346 | return VectorType::get(shape, elementType, scalableDims); |
347 | } |
348 | |
349 | private: |
350 | Type elementType; |
351 | CopyOnWriteArrayRef<int64_t> shape; |
352 | CopyOnWriteArrayRef<bool> scalableDims; |
353 | }; |
354 | |
355 | /// Given an `originalShape` and a `reducedShape` assumed to be a subset of |
356 | /// `originalShape` with some `1` entries erased, return the set of indices |
357 | /// that specifies which of the entries of `originalShape` are dropped to obtain |
358 | /// `reducedShape`. The returned mask can be applied as a projection to |
359 | /// `originalShape` to obtain the `reducedShape`. This mask is useful to track |
360 | /// which dimensions must be kept when e.g. compute MemRef strides under |
361 | /// rank-reducing operations. Return std::nullopt if reducedShape cannot be |
362 | /// obtained by dropping only `1` entries in `originalShape`. |
363 | std::optional<llvm::SmallDenseSet<unsigned>> |
364 | computeRankReductionMask(ArrayRef<int64_t> originalShape, |
365 | ArrayRef<int64_t> reducedShape); |
366 | |
367 | /// Enum that captures information related to verifier error conditions on |
368 | /// slice insert/extract type of ops. |
369 | enum class SliceVerificationResult { |
370 | Success, |
371 | RankTooLarge, |
372 | SizeMismatch, |
373 | ElemTypeMismatch, |
374 | // Error codes to ops with a memory space and a layout annotation. |
375 | MemSpaceMismatch, |
376 | LayoutMismatch |
377 | }; |
378 | |
379 | /// Check if `originalType` can be rank reduced to `candidateReducedType` type |
380 | /// by dropping some dimensions with static size `1`. |
381 | /// Return `SliceVerificationResult::Success` on success or an appropriate error |
382 | /// code. |
383 | SliceVerificationResult isRankReducedType(ShapedType originalType, |
384 | ShapedType candidateReducedType); |
385 | |
386 | //===----------------------------------------------------------------------===// |
387 | // Deferred Method Definitions |
388 | //===----------------------------------------------------------------------===// |
389 | |
390 | inline bool BaseMemRefType::classof(Type type) { |
391 | return llvm::isa<MemRefType, UnrankedMemRefType>(type); |
392 | } |
393 | |
394 | inline bool BaseMemRefType::isValidElementType(Type type) { |
395 | return type.isIntOrIndexOrFloat() || |
396 | llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>( |
397 | type) || |
398 | llvm::isa<MemRefElementTypeInterface>(type); |
399 | } |
400 | |
401 | inline bool FloatType::classof(Type type) { |
402 | return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType, |
403 | Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, |
404 | Float16Type, FloatTF32Type, Float32Type, Float64Type, |
405 | Float80Type, Float128Type>(type); |
406 | } |
407 | |
408 | inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { |
409 | return Float8E5M2Type::get(ctx); |
410 | } |
411 | |
412 | inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) { |
413 | return Float8E4M3FNType::get(ctx); |
414 | } |
415 | |
416 | inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) { |
417 | return Float8E5M2FNUZType::get(ctx); |
418 | } |
419 | |
420 | inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) { |
421 | return Float8E4M3FNUZType::get(ctx); |
422 | } |
423 | |
424 | inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) { |
425 | return Float8E4M3B11FNUZType::get(ctx); |
426 | } |
427 | |
428 | inline FloatType FloatType::getBF16(MLIRContext *ctx) { |
429 | return BFloat16Type::get(ctx); |
430 | } |
431 | |
432 | inline FloatType FloatType::getF16(MLIRContext *ctx) { |
433 | return Float16Type::get(ctx); |
434 | } |
435 | |
436 | inline FloatType FloatType::getTF32(MLIRContext *ctx) { |
437 | return FloatTF32Type::get(ctx); |
438 | } |
439 | |
440 | inline FloatType FloatType::getF32(MLIRContext *ctx) { |
441 | return Float32Type::get(ctx); |
442 | } |
443 | |
444 | inline FloatType FloatType::getF64(MLIRContext *ctx) { |
445 | return Float64Type::get(ctx); |
446 | } |
447 | |
448 | inline FloatType FloatType::getF80(MLIRContext *ctx) { |
449 | return Float80Type::get(ctx); |
450 | } |
451 | |
452 | inline FloatType FloatType::getF128(MLIRContext *ctx) { |
453 | return Float128Type::get(ctx); |
454 | } |
455 | |
456 | inline bool TensorType::classof(Type type) { |
457 | return llvm::isa<RankedTensorType, UnrankedTensorType>(type); |
458 | } |
459 | |
460 | //===----------------------------------------------------------------------===// |
461 | // Type Utilities |
462 | //===----------------------------------------------------------------------===// |
463 | |
464 | /// Returns the strides of the MemRef if the layout map is in strided form. |
465 | /// MemRefs with a layout map in strided form include: |
466 | /// 1. empty or identity layout map, in which case the stride information is |
467 | /// the canonical form computed from sizes; |
468 | /// 2. a StridedLayoutAttr layout; |
469 | /// 3. any other layout that be converted into a single affine map layout of |
470 | /// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or |
471 | /// symbols. |
472 | /// |
473 | /// A stride specification is a list of integer values that are either static |
474 | /// or dynamic (encoded with ShapedType::kDynamic). Strides encode |
475 | /// the distance in the number of elements between successive entries along a |
476 | /// particular dimension. |
477 | LogicalResult getStridesAndOffset(MemRefType t, |
478 | SmallVectorImpl<int64_t> &strides, |
479 | int64_t &offset); |
480 | |
481 | /// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>, |
482 | /// int64_t) that will assert if the logical result is not succeeded. |
483 | std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t); |
484 | |
485 | /// Return a version of `t` with identity layout if it can be determined |
486 | /// statically that the layout is the canonical contiguous strided layout. |
487 | /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of |
488 | /// `t` with simplified layout. |
489 | MemRefType canonicalizeStridedLayout(MemRefType t); |
490 | |
491 | /// Given MemRef `sizes` that are either static or dynamic, returns the |
492 | /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and |
493 | /// once a dynamic dimension is encountered, all canonical strides become |
494 | /// dynamic and need to be encoded with a different symbol. |
495 | /// For canonical strides expressions, the offset is always 0 and the fastest |
496 | /// varying stride is always `1`. |
497 | /// |
498 | /// Examples: |
499 | /// - memref<3x4x5xf32> has canonical stride expression |
500 | /// `20*exprs[0] + 5*exprs[1] + exprs[2]`. |
501 | /// - memref<3x?x5xf32> has canonical stride expression |
502 | /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`. |
503 | /// - memref<3x4x?xf32> has canonical stride expression |
504 | /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`. |
505 | AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
506 | ArrayRef<AffineExpr> exprs, |
507 | MLIRContext *context); |
508 | |
509 | /// Return the result of makeCanonicalStrudedLayoutExpr for the common case |
510 | /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)} |
511 | AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, |
512 | MLIRContext *context); |
513 | |
514 | /// Return "true" if the layout for `t` is compatible with strided semantics. |
515 | bool isStrided(MemRefType t); |
516 | |
517 | /// Return "true" if the last dimension of the given type has a static unit |
518 | /// stride. Also return "true" for types with no strides. |
519 | bool isLastMemrefDimUnitStride(MemRefType type); |
520 | |
521 | /// Return "true" if the last N dimensions of the given type are contiguous. |
522 | /// |
523 | /// Examples: |
524 | /// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when |
525 | /// considering both _all_ and _only_ the trailing 3 dims, |
526 | /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when |
527 | /// considering the trailing 3 dims. |
528 | /// |
529 | bool trailingNDimsContiguous(MemRefType type, int64_t n); |
530 | |
531 | } // namespace mlir |
532 | |
533 | #endif // MLIR_IR_BUILTINTYPES_H |
534 | |