1//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MatrixBuilder class, which is used as a convenient way
10// to lower matrix operations to LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_IR_MATRIXBUILDER_H
15#define LLVM_IR_MATRIXBUILDER_H
16
17#include "llvm/IR/Constant.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/InstrTypes.h"
21#include "llvm/IR/Instruction.h"
22#include "llvm/IR/IntrinsicInst.h"
23#include "llvm/IR/Type.h"
24#include "llvm/IR/Value.h"
25#include "llvm/Support/Alignment.h"
26
27namespace llvm {
28
29class Function;
30class Twine;
31class Module;
32
33template <class IRBuilderTy> class MatrixBuilder {
34 IRBuilderTy &B;
35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36
37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38 Value *RHS) {
39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40 "One of the operands must be a matrix (embedded in a vector)");
41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42 assert(!isa<ScalableVectorType>(LHS->getType()) &&
43 "LHS Assumed to be fixed width");
44 RHS = B.CreateVectorSplat(
45 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46 "scalar.splat");
47 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48 assert(!isa<ScalableVectorType>(RHS->getType()) &&
49 "RHS Assumed to be fixed width");
50 LHS = B.CreateVectorSplat(
51 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52 "scalar.splat");
53 }
54 return {LHS, RHS};
55 }
56
57public:
58 MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
59
60 /// Create a column major, strided matrix load.
61 /// \p DataPtr - Start address of the matrix read
62 /// \p Rows - Number of rows in matrix (must be a constant)
63 /// \p Columns - Number of columns in matrix (must be a constant)
64 /// \p Stride - Space between columns
65 CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
66 Value *Stride, bool IsVolatile, unsigned Rows,
67 unsigned Columns, const Twine &Name = "") {
68
69 // Deal with the pointer
70 PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
71 Type *EltTy = PtrTy->getElementType();
72
73 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
74
75 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
76 B.getInt32(Columns)};
77 Type *OverloadedTypes[] = {RetType};
78
79 Function *TheFn = Intrinsic::getDeclaration(
80 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
81
82 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
83 Attribute AlignAttr =
84 Attribute::getWithAlignment(Call->getContext(), Alignment);
85 Call->addAttribute(1, AlignAttr);
86 return Call;
87 }
88
89 /// Create a column major, strided matrix store.
90 /// \p Matrix - Matrix to store
91 /// \p Ptr - Pointer to write back to
92 /// \p Stride - Space between columns
93 CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
94 Value *Stride, bool IsVolatile,
95 unsigned Rows, unsigned Columns,
96 const Twine &Name = "") {
97 Value *Ops[] = {Matrix, Ptr,
98 Stride, B.getInt1(IsVolatile),
99 B.getInt32(Rows), B.getInt32(Columns)};
100 Type *OverloadedTypes[] = {Matrix->getType()};
101
102 Function *TheFn = Intrinsic::getDeclaration(
103 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
104
105 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
106 Attribute AlignAttr =
107 Attribute::getWithAlignment(Call->getContext(), Alignment);
108 Call->addAttribute(2, AlignAttr);
109 return Call;
110 }
111
112 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
113 /// rows and \p Columns columns.
114 CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
115 unsigned Columns, const Twine &Name = "") {
116 auto *OpType = cast<VectorType>(Matrix->getType());
117 auto *ReturnType =
118 FixedVectorType::get(OpType->getElementType(), Rows * Columns);
119
120 Type *OverloadedTypes[] = {ReturnType};
121 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
122 Function *TheFn = Intrinsic::getDeclaration(
123 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
124
125 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
126 }
127
128 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
129 /// RHS.
130 CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
131 unsigned LHSColumns, unsigned RHSColumns,
132 const Twine &Name = "") {
133 auto *LHSType = cast<VectorType>(LHS->getType());
134 auto *RHSType = cast<VectorType>(RHS->getType());
135
136 auto *ReturnType =
137 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
138
139 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
140 B.getInt32(RHSColumns)};
141 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
142
143 Function *TheFn = Intrinsic::getDeclaration(
144 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
145 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
146 }
147
148 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
149 /// ColumnIdx).
150 Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
151 Value *ColumnIdx, unsigned NumRows) {
152 return B.CreateInsertElement(
153 Matrix, NewVal,
154 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
155 ColumnIdx->getType(), NumRows)),
156 RowIdx));
157 }
158
159 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
160 /// matrixes.
161 Value *CreateAdd(Value *LHS, Value *RHS) {
162 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
163 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
164 assert(!isa<ScalableVectorType>(LHS->getType()) &&
165 "LHS Assumed to be fixed width");
166 RHS = B.CreateVectorSplat(
167 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
168 "scalar.splat");
169 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
170 assert(!isa<ScalableVectorType>(RHS->getType()) &&
171 "RHS Assumed to be fixed width");
172 LHS = B.CreateVectorSplat(
173 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
174 "scalar.splat");
175 }
176
177 return cast<VectorType>(LHS->getType())
178 ->getElementType()
179 ->isFloatingPointTy()
180 ? B.CreateFAdd(LHS, RHS)
181 : B.CreateAdd(LHS, RHS);
182 }
183
184 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
185 /// point matrixes.
186 Value *CreateSub(Value *LHS, Value *RHS) {
187 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
188 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
189 assert(!isa<ScalableVectorType>(LHS->getType()) &&
190 "LHS Assumed to be fixed width");
191 RHS = B.CreateVectorSplat(
192 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
193 "scalar.splat");
194 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
195 assert(!isa<ScalableVectorType>(RHS->getType()) &&
196 "RHS Assumed to be fixed width");
197 LHS = B.CreateVectorSplat(
198 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
199 "scalar.splat");
200 }
201
202 return cast<VectorType>(LHS->getType())
203 ->getElementType()
204 ->isFloatingPointTy()
205 ? B.CreateFSub(LHS, RHS)
206 : B.CreateSub(LHS, RHS);
207 }
208
209 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
210 /// RHS.
211 Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
212 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
213 if (LHS->getType()->getScalarType()->isFloatingPointTy())
214 return B.CreateFMul(LHS, RHS);
215 return B.CreateMul(LHS, RHS);
216 }
217
218 /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
219 /// IsUnsigned indicates whether UDiv or SDiv should be used.
220 Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
221 assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
222 assert(!isa<ScalableVectorType>(LHS->getType()) &&
223 "LHS Assumed to be fixed width");
224 RHS =
225 B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
226 RHS, "scalar.splat");
227 return cast<VectorType>(LHS->getType())
228 ->getElementType()
229 ->isFloatingPointTy()
230 ? B.CreateFDiv(LHS, RHS)
231 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
232 }
233
234 /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
235 Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx,
236 unsigned NumRows, Twine const &Name = "") {
237
238 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
239 ColumnIdx->getType()->getScalarSizeInBits());
240 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
241 RowIdx = B.CreateZExt(RowIdx, IntTy);
242 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
243 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
244 return B.CreateExtractElement(
245 Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
246 "matext");
247 }
248};
249
250} // end namespace llvm
251
252#endif // LLVM_IR_MATRIXBUILDER_H
253