1 | //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the MatrixBuilder class, which is used as a convenient way |
10 | // to lower matrix operations to LLVM IR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef LLVM_IR_MATRIXBUILDER_H |
15 | #define LLVM_IR_MATRIXBUILDER_H |
16 | |
17 | #include "llvm/IR/Constant.h" |
18 | #include "llvm/IR/Constants.h" |
19 | #include "llvm/IR/IRBuilder.h" |
20 | #include "llvm/IR/InstrTypes.h" |
21 | #include "llvm/IR/Instruction.h" |
22 | #include "llvm/IR/IntrinsicInst.h" |
23 | #include "llvm/IR/Type.h" |
24 | #include "llvm/IR/Value.h" |
25 | #include "llvm/Support/Alignment.h" |
26 | |
27 | namespace llvm { |
28 | |
29 | class Function; |
30 | class Twine; |
31 | class Module; |
32 | |
33 | template <class IRBuilderTy> class MatrixBuilder { |
34 | IRBuilderTy &B; |
35 | Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } |
36 | |
37 | std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, |
38 | Value *RHS) { |
39 | assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && |
40 | "One of the operands must be a matrix (embedded in a vector)" ); |
41 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
42 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
43 | "LHS Assumed to be fixed width" ); |
44 | RHS = B.CreateVectorSplat( |
45 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
46 | "scalar.splat" ); |
47 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
48 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
49 | "RHS Assumed to be fixed width" ); |
50 | LHS = B.CreateVectorSplat( |
51 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
52 | "scalar.splat" ); |
53 | } |
54 | return {LHS, RHS}; |
55 | } |
56 | |
57 | public: |
58 | MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} |
59 | |
60 | /// Create a column major, strided matrix load. |
61 | /// \p DataPtr - Start address of the matrix read |
62 | /// \p Rows - Number of rows in matrix (must be a constant) |
63 | /// \p Columns - Number of columns in matrix (must be a constant) |
64 | /// \p Stride - Space between columns |
65 | CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment, |
66 | Value *Stride, bool IsVolatile, unsigned Rows, |
67 | unsigned Columns, const Twine &Name = "" ) { |
68 | |
69 | // Deal with the pointer |
70 | PointerType *PtrTy = cast<PointerType>(DataPtr->getType()); |
71 | Type *EltTy = PtrTy->getElementType(); |
72 | |
73 | auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); |
74 | |
75 | Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), |
76 | B.getInt32(Columns)}; |
77 | Type *OverloadedTypes[] = {RetType}; |
78 | |
79 | Function *TheFn = Intrinsic::getDeclaration( |
80 | getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); |
81 | |
82 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
83 | Attribute AlignAttr = |
84 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
85 | Call->addAttribute(1, AlignAttr); |
86 | return Call; |
87 | } |
88 | |
89 | /// Create a column major, strided matrix store. |
90 | /// \p Matrix - Matrix to store |
91 | /// \p Ptr - Pointer to write back to |
92 | /// \p Stride - Space between columns |
93 | CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, |
94 | Value *Stride, bool IsVolatile, |
95 | unsigned Rows, unsigned Columns, |
96 | const Twine &Name = "" ) { |
97 | Value *Ops[] = {Matrix, Ptr, |
98 | Stride, B.getInt1(IsVolatile), |
99 | B.getInt32(Rows), B.getInt32(Columns)}; |
100 | Type *OverloadedTypes[] = {Matrix->getType()}; |
101 | |
102 | Function *TheFn = Intrinsic::getDeclaration( |
103 | getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); |
104 | |
105 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
106 | Attribute AlignAttr = |
107 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
108 | Call->addAttribute(2, AlignAttr); |
109 | return Call; |
110 | } |
111 | |
112 | /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows |
113 | /// rows and \p Columns columns. |
114 | CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, |
115 | unsigned Columns, const Twine &Name = "" ) { |
116 | auto *OpType = cast<VectorType>(Matrix->getType()); |
117 | auto *ReturnType = |
118 | FixedVectorType::get(OpType->getElementType(), Rows * Columns); |
119 | |
120 | Type *OverloadedTypes[] = {ReturnType}; |
121 | Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; |
122 | Function *TheFn = Intrinsic::getDeclaration( |
123 | getModule(), Intrinsic::matrix_transpose, OverloadedTypes); |
124 | |
125 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
126 | } |
127 | |
128 | /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p |
129 | /// RHS. |
130 | CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, |
131 | unsigned LHSColumns, unsigned RHSColumns, |
132 | const Twine &Name = "" ) { |
133 | auto *LHSType = cast<VectorType>(LHS->getType()); |
134 | auto *RHSType = cast<VectorType>(RHS->getType()); |
135 | |
136 | auto *ReturnType = |
137 | FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); |
138 | |
139 | Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), |
140 | B.getInt32(RHSColumns)}; |
141 | Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; |
142 | |
143 | Function *TheFn = Intrinsic::getDeclaration( |
144 | getModule(), Intrinsic::matrix_multiply, OverloadedTypes); |
145 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
146 | } |
147 | |
148 | /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p |
149 | /// ColumnIdx). |
150 | Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, |
151 | Value *ColumnIdx, unsigned NumRows) { |
152 | return B.CreateInsertElement( |
153 | Matrix, NewVal, |
154 | B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( |
155 | ColumnIdx->getType(), NumRows)), |
156 | RowIdx)); |
157 | } |
158 | |
159 | /// Add matrixes \p LHS and \p RHS. Support both integer and floating point |
160 | /// matrixes. |
161 | Value *CreateAdd(Value *LHS, Value *RHS) { |
162 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
163 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
164 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
165 | "LHS Assumed to be fixed width" ); |
166 | RHS = B.CreateVectorSplat( |
167 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
168 | "scalar.splat" ); |
169 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
170 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
171 | "RHS Assumed to be fixed width" ); |
172 | LHS = B.CreateVectorSplat( |
173 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
174 | "scalar.splat" ); |
175 | } |
176 | |
177 | return cast<VectorType>(LHS->getType()) |
178 | ->getElementType() |
179 | ->isFloatingPointTy() |
180 | ? B.CreateFAdd(LHS, RHS) |
181 | : B.CreateAdd(LHS, RHS); |
182 | } |
183 | |
184 | /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating |
185 | /// point matrixes. |
186 | Value *CreateSub(Value *LHS, Value *RHS) { |
187 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
188 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
189 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
190 | "LHS Assumed to be fixed width" ); |
191 | RHS = B.CreateVectorSplat( |
192 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
193 | "scalar.splat" ); |
194 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
195 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
196 | "RHS Assumed to be fixed width" ); |
197 | LHS = B.CreateVectorSplat( |
198 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
199 | "scalar.splat" ); |
200 | } |
201 | |
202 | return cast<VectorType>(LHS->getType()) |
203 | ->getElementType() |
204 | ->isFloatingPointTy() |
205 | ? B.CreateFSub(LHS, RHS) |
206 | : B.CreateSub(LHS, RHS); |
207 | } |
208 | |
209 | /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p |
210 | /// RHS. |
211 | Value *CreateScalarMultiply(Value *LHS, Value *RHS) { |
212 | std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); |
213 | if (LHS->getType()->getScalarType()->isFloatingPointTy()) |
214 | return B.CreateFMul(LHS, RHS); |
215 | return B.CreateMul(LHS, RHS); |
216 | } |
217 | |
218 | /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p |
219 | /// IsUnsigned indicates whether UDiv or SDiv should be used. |
220 | Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) { |
221 | assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()); |
222 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
223 | "LHS Assumed to be fixed width" ); |
224 | RHS = |
225 | B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(), |
226 | RHS, "scalar.splat" ); |
227 | return cast<VectorType>(LHS->getType()) |
228 | ->getElementType() |
229 | ->isFloatingPointTy() |
230 | ? B.CreateFDiv(LHS, RHS) |
231 | : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS)); |
232 | } |
233 | |
234 | /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. |
235 | Value *(Value *Matrix, Value *RowIdx, Value *ColumnIdx, |
236 | unsigned NumRows, Twine const &Name = "" ) { |
237 | |
238 | unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), |
239 | ColumnIdx->getType()->getScalarSizeInBits()); |
240 | Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); |
241 | RowIdx = B.CreateZExt(RowIdx, IntTy); |
242 | ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); |
243 | Value *NumRowsV = B.getIntN(MaxWidth, NumRows); |
244 | return B.CreateExtractElement( |
245 | Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx), |
246 | "matext" ); |
247 | } |
248 | }; |
249 | |
250 | } // end namespace llvm |
251 | |
252 | #endif // LLVM_IR_MATRIXBUILDER_H |
253 | |