//===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_BUILTINTYPES_H
#define MLIR_IR_BUILTINTYPES_H

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"

namespace llvm {
class BitVector;
struct fltSemantics;
} // namespace llvm

//===----------------------------------------------------------------------===//
// Tablegen Interface Declarations
//===----------------------------------------------------------------------===//

namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
class RankedTensorType;
class StringAttr;
class TypeRange;

namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct TupleTypeStorage;
} // namespace detail

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

class FloatType : public Type {
public:
  using Type::Type;

  // Convenience factories.
  static FloatType getBF16(MLIRContext *ctx);
  static FloatType getF16(MLIRContext *ctx);
  static FloatType getF32(MLIRContext *ctx);
  static FloatType getTF32(MLIRContext *ctx);
  static FloatType getF64(MLIRContext *ctx);
  static FloatType getF80(MLIRContext *ctx);
  static FloatType getF128(MLIRContext *ctx);
  static FloatType getFloat8E5M2(MLIRContext *ctx);
  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
  static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
  static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
  static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Return the bitwidth of this float type.
  unsigned getWidth();

  /// Return the width of the mantissa of this type.
  /// The width includes the integer bit.
  unsigned getFPMantissaWidth();

  /// Get or create a new FloatType with bitwidth scaled by `scale`.
  /// Return null if the scaled element type cannot be represented.
  FloatType scaleElementBitwidth(unsigned scale);

  /// Return the floating semantics of this float type.
  const llvm::fltSemantics &getFloatSemantics();
};
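
// Example (illustrative sketch, not part of the API): construct and query
// float types via the convenience factories above. Assumes an
// `MLIRContext *ctx` is in scope.
//
//   FloatType f32 = FloatType::getF32(ctx);
//   unsigned width = f32.getWidth();                         // 32
//   const llvm::fltSemantics &sem = f32.getFloatSemantics();
//   FloatType f64 = f32.scaleElementBitwidth(2);             // f64, or null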

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived tensor types.
class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
  using Type::Type;

  /// Returns the element type of this tensor type.
  Type getElementType() const;

  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
  bool hasRank() const;

  /// Returns the shape of this tensor type.
  ArrayRef<int64_t> getShape() const;

  /// Clone this type with the given shape and element type. If the
  /// provided shape is `std::nullopt`, the current shape of the type is used.
  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                       Type elementType) const;

  // Make sure that base class overloads are visible.
  using ShapedType::Trait<TensorType>::clone;

  /// Return a clone of this type with the given new shape and element type.
  /// The returned type is ranked, even if this type is unranked.
  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;

  /// Return a clone of this type with the given new shape. The returned type
  /// is ranked, even if this type is unranked.
  RankedTensorType clone(ArrayRef<int64_t> shape) const;

  /// Return true if the specified element type is ok in a tensor.
  static bool isValidElementType(Type type);

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Allow implicit conversion to ShapedType.
  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
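
// Example (illustrative sketch, not part of the API): `someTensorType` is a
// hypothetical Type known to be a tensor.
//
//   auto tensorTy = llvm::cast<TensorType>(someTensorType);
//   Type elemTy = tensorTy.getElementType();
//   // The clone is always ranked, even if `tensorTy` is unranked.
//   RankedTensorType ranked = tensorTy.clone({2, 4}, elemTy);
//   ShapedType shaped = tensorTy; // implicit conversion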

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

/// This class provides a shared interface for ranked and unranked memref
/// types.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
public:
  using Type::Type;

  /// Returns the element type of this memref type.
  Type getElementType() const;

  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
  bool hasRank() const;

  /// Returns the shape of this memref type.
  ArrayRef<int64_t> getShape() const;

  /// Clone this type with the given shape and element type. If the
  /// provided shape is `std::nullopt`, the current shape of the type is used.
  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                           Type elementType) const;

  // Make sure that base class overloads are visible.
  using ShapedType::Trait<BaseMemRefType>::clone;

  /// Return a clone of this type with the given new shape and element type.
  /// The returned type is ranked, even if this type is unranked.
  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;

  /// Return a clone of this type with the given new shape. The returned type
  /// is ranked, even if this type is unranked.
  MemRefType clone(ArrayRef<int64_t> shape) const;

  /// Return true if the specified element type is ok in a memref.
  static bool isValidElementType(Type type);

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Returns the memory space in which data referred to by this memref
  /// resides.
  Attribute getMemorySpace() const;

  /// [deprecated] Returns the memory space in old raw integer representation.
  /// New `Attribute getMemorySpace()` method should be used instead.
  unsigned getMemorySpaceAsInt() const;

  /// Allow implicit conversion to ShapedType.
  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
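
// Example (illustrative sketch, not part of the API): `someMemRefType` is a
// hypothetical Type known to be a memref.
//
//   auto memrefTy = llvm::cast<BaseMemRefType>(someMemRefType);
//   Attribute memorySpace = memrefTy.getMemorySpace();
//   // The clone is always ranked, even if `memrefTy` is unranked.
//   MemRefType ranked = memrefTy.clone({8, 16});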

} // namespace mlir

//===----------------------------------------------------------------------===//
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"

namespace mlir {

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
class MemRefType::Builder {
public:
  // Build from another MemRefType.
  explicit Builder(MemRefType other)
      : shape(other.getShape()), elementType(other.getElementType()),
        layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}

  // Build from scratch.
  Builder(ArrayRef<int64_t> shape, Type elementType)
      : shape(shape), elementType(elementType) {}

  Builder &setShape(ArrayRef<int64_t> newShape) {
    shape = newShape;
    return *this;
  }

  Builder &setElementType(Type newElementType) {
    elementType = newElementType;
    return *this;
  }

  Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
    layout = newLayout;
    return *this;
  }

  Builder &setMemorySpace(Attribute newMemorySpace) {
    memorySpace = newMemorySpace;
    return *this;
  }

  operator MemRefType() {
    return MemRefType::get(shape, elementType, layout, memorySpace);
  }

private:
  ArrayRef<int64_t> shape;
  Type elementType;
  MemRefLayoutAttrInterface layout;
  Attribute memorySpace;
};
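
// Example (illustrative sketch, not part of the API): derive a new MemRefType
// from an existing one, changing only the shape. Assumes an `MLIRContext *ctx`
// is in scope.
//
//   MemRefType src = MemRefType::get({4, 8}, Float32Type::get(ctx));
//   MemRefType dst = MemRefType::Builder(src).setShape({4, 4});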

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
class RankedTensorType::Builder {
public:
  /// Build from another RankedTensorType.
  explicit Builder(RankedTensorType other)
      : shape(other.getShape()), elementType(other.getElementType()),
        encoding(other.getEncoding()) {}

  /// Build from scratch.
  Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
      : shape(shape), elementType(elementType), encoding(encoding) {}

  Builder &setShape(ArrayRef<int64_t> newShape) {
    shape = newShape;
    return *this;
  }

  Builder &setElementType(Type newElementType) {
    elementType = newElementType;
    return *this;
  }

  Builder &setEncoding(Attribute newEncoding) {
    encoding = newEncoding;
    return *this;
  }

  /// Erase a dim from shape @pos.
  Builder &dropDim(unsigned pos) {
    assert(pos < shape.size() && "overflow");
    shape.erase(pos);
    return *this;
  }

  /// Insert a val into shape @pos.
  Builder &insertDim(int64_t val, unsigned pos) {
    assert(pos <= shape.size() && "overflow");
    shape.insert(pos, val);
    return *this;
  }

  operator RankedTensorType() {
    return RankedTensorType::get(shape, elementType, encoding);
  }

private:
  CopyOnWriteArrayRef<int64_t> shape;
  Type elementType;
  Attribute encoding;
};
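
// Example (illustrative sketch, not part of the API): drop the unit dimension
// of tensor<2x1x4xf32> to obtain tensor<2x4xf32>. Assumes an `MLIRContext *ctx`
// is in scope.
//
//   auto src = RankedTensorType::get({2, 1, 4}, Float32Type::get(ctx));
//   RankedTensorType reduced = RankedTensorType::Builder(src).dropDim(1);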

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
class VectorType::Builder {
public:
  /// Build from another VectorType.
  explicit Builder(VectorType other)
      : elementType(other.getElementType()), shape(other.getShape()),
        scalableDims(other.getScalableDims()) {}

  /// Build from scratch.
  Builder(ArrayRef<int64_t> shape, Type elementType,
          ArrayRef<bool> scalableDims = {})
      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}

  Builder &setShape(ArrayRef<int64_t> newShape,
                    ArrayRef<bool> newIsScalableDim = {}) {
    shape = newShape;
    scalableDims = newIsScalableDim;
    return *this;
  }

  Builder &setElementType(Type newElementType) {
    elementType = newElementType;
    return *this;
  }

  /// Erase a dim from shape @pos.
  Builder &dropDim(unsigned pos) {
    assert(pos < shape.size() && "overflow");
    shape.erase(pos);
    if (!scalableDims.empty())
      scalableDims.erase(pos);
    return *this;
  }

  /// Set a dim in shape @pos to val.
  Builder &setDim(unsigned pos, int64_t val) {
    assert(pos < shape.size() && "overflow");
    shape.set(pos, val);
    return *this;
  }

  operator VectorType() {
    return VectorType::get(shape, elementType, scalableDims);
  }

private:
  Type elementType;
  CopyOnWriteArrayRef<int64_t> shape;
  CopyOnWriteArrayRef<bool> scalableDims;
};
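
// Example (illustrative sketch, not part of the API): rewrite the fixed
// dimension of vector<4x[8]xf32> while keeping the scalable flags. Assumes an
// `MLIRContext *ctx` is in scope.
//
//   bool scalable[] = {false, true};
//   VectorType src = VectorType::get({4, 8}, Float32Type::get(ctx), scalable);
//   VectorType dst = VectorType::Builder(src).setDim(0, 2); // vector<2x[8]xf32>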

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when, e.g., computing MemRef strides under
/// rank-reducing operations. Return std::nullopt if `reducedShape` cannot be
/// obtained by dropping only `1` entries in `originalShape`.
std::optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape);
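
// Example (illustrative sketch, not part of the API): reducing 4x1x8x1 to 4x8
// drops the unit dimensions at positions 1 and 3.
//
//   std::optional<llvm::SmallDenseSet<unsigned>> mask =
//       computeRankReductionMask({4, 1, 8, 1}, {4, 8});
//   // mask contains {1, 3}; std::nullopt would signal an invalid reduction.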

/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
enum class SliceVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  // Error codes for ops with a memory space and a layout annotation.
  MemSpaceMismatch,
  LayoutMismatch
};

/// Check if `originalType` can be rank reduced to `candidateReducedType` by
/// dropping some dimensions with static size `1`.
/// Return `SliceVerificationResult::Success` on success or an appropriate
/// error code.
SliceVerificationResult isRankReducedType(ShapedType originalType,
                                          ShapedType candidateReducedType);
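
// Example (illustrative sketch, not part of the API): verify that a candidate
// type drops only static unit dimensions. Assumes an `MLIRContext *ctx` is in
// scope.
//
//   auto original = RankedTensorType::get({4, 1, 8}, Float32Type::get(ctx));
//   auto reduced = RankedTensorType::get({4, 8}, Float32Type::get(ctx));
//   bool ok = isRankReducedType(original, reduced) ==
//             SliceVerificationResult::Success; // true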

//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//

inline bool BaseMemRefType::classof(Type type) {
  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}

inline bool BaseMemRefType::isValidElementType(Type type) {
  return type.isIntOrIndexOrFloat() ||
         llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
             type) ||
         llvm::isa<MemRefElementTypeInterface>(type);
}

inline bool FloatType::classof(Type type) {
  return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
                   Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
                   Float80Type, Float128Type>(type);
}

inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
  return Float8E5M2Type::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
  return Float8E4M3FNType::get(ctx);
}

inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {
  return Float8E5M2FNUZType::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
  return Float8E4M3FNUZType::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {
  return Float8E4M3B11FNUZType::get(ctx);
}

inline FloatType FloatType::getBF16(MLIRContext *ctx) {
  return BFloat16Type::get(ctx);
}

inline FloatType FloatType::getF16(MLIRContext *ctx) {
  return Float16Type::get(ctx);
}

inline FloatType FloatType::getTF32(MLIRContext *ctx) {
  return FloatTF32Type::get(ctx);
}

inline FloatType FloatType::getF32(MLIRContext *ctx) {
  return Float32Type::get(ctx);
}

inline FloatType FloatType::getF64(MLIRContext *ctx) {
  return Float64Type::get(ctx);
}

inline FloatType FloatType::getF80(MLIRContext *ctx) {
  return Float80Type::get(ctx);
}

inline FloatType FloatType::getF128(MLIRContext *ctx) {
  return Float128Type::get(ctx);
}

inline bool TensorType::classof(Type type) {
  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
///   1. empty or identity layout map, in which case the stride information is
///      the canonical form computed from sizes;
///   2. a StridedLayoutAttr layout;
///   3. any other layout that can be converted into a single affine map layout
///      of the form `K + k0 * d0 + ... + kn * dn`, where K and the ki's are
///      constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode the distance
/// in the number of elements between successive entries along a particular
/// dimension.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);

/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
/// int64_t) that asserts if the strides and offset cannot be computed.
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
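
// Example (illustrative sketch, not part of the API): an identity-layout
// memref<3x4x5xf32> has canonical strides [20, 5, 1] and offset 0. Assumes an
// `MLIRContext *ctx` is in scope.
//
//   MemRefType t = MemRefType::get({3, 4, 5}, Float32Type::get(ctx));
//   SmallVector<int64_t> strides;
//   int64_t offset;
//   if (succeeded(getStridesAndOffset(t, strides, offset))) {
//     // strides == {20, 5, 1}, offset == 0
//   }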

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);

/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
/// dynamic and need to be encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and the fastest
/// varying stride is always `1`.
///
/// Examples:
///   - memref<3x4x5xf32> has canonical stride expression
///     `20*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x?x5xf32> has canonical stride expression
///     `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x4x?xf32> has canonical stride expression
///     `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          ArrayRef<AffineExpr> exprs,
                                          MLIRContext *context);

/// Return the result of makeCanonicalStridedLayoutExpr for the common case
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          MLIRContext *context);
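
// Example (illustrative sketch, not part of the API): for sizes {3, 4, 5} the
// overload above produces an expression equivalent to `20*d0 + 5*d1 + d2`.
// Assumes an `MLIRContext *ctx` is in scope.
//
//   AffineExpr expr = makeCanonicalStridedLayoutExpr({3, 4, 5}, ctx);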

/// Return "true" if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);

/// Return "true" if the last dimension of the given type has a static unit
/// stride. Also return "true" for types with no strides.
bool isLastMemrefDimUnitStride(MemRefType type);

/// Return "true" if the last N dimensions of the given type are contiguous.
///
/// Examples:
///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
///     considering both _all_ and _only_ the trailing 3 dims,
///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
///     considering the trailing 3 dims.
///
bool trailingNDimsContiguous(MemRefType type, int64_t n);
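
// Example (illustrative sketch, not part of the API): the second memref from
// the examples above, built with the builtin StridedLayoutAttr. Assumes an
// `MLIRContext *ctx` is in scope.
//
//   auto layout = StridedLayoutAttr::get(ctx, /*offset=*/0,
//                                        /*strides=*/{48, 6, 2, 1});
//   MemRefType t = MemRefType::get({5, 4, 3, 2}, IntegerType::get(ctx, 8),
//                                  layout);
//   bool ok3 = trailingNDimsContiguous(t, 3); // true
//   bool ok4 = trailingNDimsContiguous(t, 4); // false: the outer stride is 48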

} // namespace mlir

#endif // MLIR_IR_BUILTINTYPES_H