1 | //===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the GPU kernel-related operations and puts them in the |
10 | // corresponding dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H |
15 | #define MLIR_DIALECT_GPU_IR_GPUDIALECT_H |
16 | |
17 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
18 | #include "mlir/Dialect/DLTI/Traits.h" |
19 | #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/IR/Dialect.h" |
23 | #include "mlir/IR/OpDefinition.h" |
24 | #include "mlir/IR/OpImplementation.h" |
25 | #include "mlir/IR/SymbolTable.h" |
26 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
27 | #include "mlir/Interfaces/FunctionInterfaces.h" |
28 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
29 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
30 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
31 | #include "llvm/ADT/STLExtras.h" |
32 | |
33 | namespace mlir { |
34 | namespace gpu { |
35 | |
36 | /// Utility class for the GPU dialect to represent triples of `Value`s |
37 | /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation. |
38 | struct KernelDim3 { |
39 | Value x; |
40 | Value y; |
41 | Value z; |
42 | }; |
43 | |
44 | class AsyncTokenType |
45 | : public Type::TypeBase<AsyncTokenType, Type, TypeStorage> { |
46 | public: |
47 | // Used for generic hooks in TypeBase. |
48 | using Base::Base; |
49 | |
50 | static constexpr StringLiteral name = "gpu.async_token" ; |
51 | }; |
52 | |
53 | /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape |
54 | /// and type. |
55 | struct MMAMatrixStorageType : public TypeStorage { |
56 | MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes, |
57 | Type elementType, StringRef operand) |
58 | : dimShapes(dimShapes), numDims(numDims), elementType(elementType), |
59 | operand(operand) {} |
60 | |
61 | /// The hash key for uniquing. |
62 | using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>; |
63 | bool operator==(const KeyTy &key) const { |
64 | return key == KeyTy(getShape(), elementType, operand); |
65 | } |
66 | |
67 | /// Construction. |
68 | static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator, |
69 | const KeyTy &key) { |
70 | ArrayRef<int64_t> shape = allocator.copyInto(elements: std::get<0>(t: key)); |
71 | StringRef operand = allocator.copyInto(str: std::get<2>(t: key)); |
72 | |
73 | return new (allocator.allocate<MMAMatrixStorageType>()) |
74 | MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(t: key), |
75 | operand); |
76 | } |
77 | |
78 | ArrayRef<int64_t> getShape() const { |
79 | return ArrayRef<int64_t>(dimShapes, numDims); |
80 | } |
81 | |
82 | StringRef getOperand() const { return operand; } |
83 | |
84 | /// Reference to the shape of the MMA matrix. |
85 | const int64_t *dimShapes; |
86 | |
87 | /// Number of dimensions in the MMA matrix. |
88 | unsigned numDims; |
89 | |
90 | /// Element type of elements held in the MMA matrix. |
91 | Type elementType; |
92 | |
93 | /// MMA operand that this MMAMatrix holds. The general form of operation this |
94 | /// type supports is given by the equation C += A*B. This field specifies |
95 | /// which operand in the given equation is held by this type. The valid values |
96 | /// are "AOp", "BOp" and "COp". |
97 | StringRef operand; |
98 | }; |
99 | |
100 | /// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply |
101 | /// accumulate operations. MMAMatrices are taken as direct operands by these |
102 | /// operations and are also produced as results. These matrices are meant to |
103 | /// reside in the registers. A limited number of pointwise operations can be |
104 | /// performed on these matrices, i.e., operations which operate uniformly on |
105 | /// all the elements in the matrix and do not change the order of matrix |
106 | /// elements. The above conditions exist because the layout of matrix elements |
107 | /// inside the matrix is opaque i.e., the elements may be present in the |
108 | /// matrix in any order. The general usage of this type is shown as follows:- |
109 | /// |
110 | /// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : |
111 | /// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> |
112 | /// |
113 | /// The MMAMatrixType describes the shape of the matrix being loaded and the |
114 | /// operand being loaded too. The operand needs to be specified to aid the |
115 | /// lowering of this type to dialects such as NVVM where each workitem may |
116 | /// hold different amount of elements depending on the elementType of the |
117 | /// matrix. For e.g., Each workitem holds 4 vector<2xf16>s for f16 data type |
118 | /// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage |
119 | /// are:- |
120 | /// |
121 | /// %3 = gpu.subgroup_mma_compute %0, %1, %2 : |
122 | /// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> |
123 | /// -> !gpu.mma_matrix<16x16xf32, "COp"> |
124 | /// |
125 | /// |
126 | /// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 |
127 | /// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32> |
128 | // TODO: consider moving this to ODS. |
129 | class MMAMatrixType |
130 | : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> { |
131 | public: |
132 | using Base::Base; |
133 | |
134 | static constexpr StringLiteral name = "gpu.mma_matrix" ; |
135 | |
136 | /// Get MMAMatrixType and verify construction Invariants. |
137 | static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType, |
138 | StringRef operand); |
139 | |
140 | /// Get MMAMatrixType at a particular location and verify construction |
141 | /// Invariants. |
142 | static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError, |
143 | ArrayRef<int64_t> shape, Type elementType, |
144 | StringRef operand); |
145 | |
146 | /// Check if a type is valid a MMAMatrixType elementType. |
147 | static bool isValidElementType(Type elementType); |
148 | |
149 | /// Verify that shape and elementType are actually allowed for the |
150 | /// MMAMatrixType. |
151 | static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, |
152 | ArrayRef<int64_t> shape, Type elementType, |
153 | StringRef operand); |
154 | |
155 | /// Get number of dims. |
156 | unsigned getNumDims() const; |
157 | |
158 | /// Get shape of the matrix. |
159 | ArrayRef<int64_t> getShape() const; |
160 | |
161 | /// Get elementType of a single element. |
162 | Type getElementType() const; |
163 | |
164 | /// The general form of operation this type supports is given by the equation |
165 | /// C += A*B. This function returns which operand in the given equation is |
166 | /// held by this type. String returned can be one of"AOp", "BOp" and "COp". |
167 | StringRef getOperand() const; |
168 | }; |
169 | |
170 | // Adds a `gpu.async.token` to the front of the argument list. |
171 | void addAsyncDependency(Operation *op, Value token); |
172 | |
/// Enumerates the kinds of handle used for sparse support.
enum class SparseHandleKind {
  SpMat,
  DnTensor,
  SpGEMMOp,
};
175 | |
176 | class SparseDnTensorHandleType |
177 | : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> { |
178 | public: |
179 | using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type, |
180 | TypeStorage>::Base; |
181 | using Base::Base; |
182 | |
183 | static constexpr StringLiteral name = "gpu.sparse.dntensor_handle" ; |
184 | }; |
185 | |
186 | class SparseSpMatHandleType |
187 | : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> { |
188 | public: |
189 | using Base = |
190 | typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base; |
191 | using Base::Base; |
192 | |
193 | static constexpr StringLiteral name = "gpu.sparse.spmat_handle" ; |
194 | }; |
195 | |
196 | class SparseSpGEMMOpHandleType |
197 | : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> { |
198 | public: |
199 | using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type, |
200 | TypeStorage>::Base; |
201 | using Base::Base; |
202 | |
203 | static constexpr StringLiteral name = "gpu.sparse.spgemmop_handle" ; |
204 | }; |
205 | |
206 | } // namespace gpu |
207 | } // namespace mlir |
208 | |
209 | #include "mlir/Dialect/GPU/IR/GPUOpsEnums.h.inc" |
210 | |
211 | #include "mlir/Dialect/GPU/IR/GPUOpsDialect.h.inc" |
212 | |
213 | #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc" |
214 | |
215 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
216 | |
217 | #define GET_ATTRDEF_CLASSES |
218 | #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc" |
219 | |
220 | #define GET_OP_CLASSES |
221 | #include "mlir/Dialect/GPU/IR/GPUOps.h.inc" |
222 | |
223 | #endif // MLIR_DIALECT_GPU_IR_GPUDIALECT_H |
224 | |