//===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for reduction intrinsics, allowing targets
// to enable the intrinsics until just before codegen.
//
//===----------------------------------------------------------------------===//
13 | |
#include "llvm/CodeGen/ExpandReductions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
24 | |
25 | using namespace llvm; |
26 | |
27 | namespace { |
28 | |
29 | unsigned getOpcode(Intrinsic::ID ID) { |
30 | switch (ID) { |
31 | case Intrinsic::vector_reduce_fadd: |
32 | return Instruction::FAdd; |
33 | case Intrinsic::vector_reduce_fmul: |
34 | return Instruction::FMul; |
35 | case Intrinsic::vector_reduce_add: |
36 | return Instruction::Add; |
37 | case Intrinsic::vector_reduce_mul: |
38 | return Instruction::Mul; |
39 | case Intrinsic::vector_reduce_and: |
40 | return Instruction::And; |
41 | case Intrinsic::vector_reduce_or: |
42 | return Instruction::Or; |
43 | case Intrinsic::vector_reduce_xor: |
44 | return Instruction::Xor; |
45 | case Intrinsic::vector_reduce_smax: |
46 | case Intrinsic::vector_reduce_smin: |
47 | case Intrinsic::vector_reduce_umax: |
48 | case Intrinsic::vector_reduce_umin: |
49 | return Instruction::ICmp; |
50 | case Intrinsic::vector_reduce_fmax: |
51 | case Intrinsic::vector_reduce_fmin: |
52 | return Instruction::FCmp; |
53 | default: |
54 | llvm_unreachable("Unexpected ID" ); |
55 | } |
56 | } |
57 | |
58 | RecurKind getRK(Intrinsic::ID ID) { |
59 | switch (ID) { |
60 | case Intrinsic::vector_reduce_smax: |
61 | return RecurKind::SMax; |
62 | case Intrinsic::vector_reduce_smin: |
63 | return RecurKind::SMin; |
64 | case Intrinsic::vector_reduce_umax: |
65 | return RecurKind::UMax; |
66 | case Intrinsic::vector_reduce_umin: |
67 | return RecurKind::UMin; |
68 | case Intrinsic::vector_reduce_fmax: |
69 | return RecurKind::FMax; |
70 | case Intrinsic::vector_reduce_fmin: |
71 | return RecurKind::FMin; |
72 | default: |
73 | return RecurKind::None; |
74 | } |
75 | } |
76 | |
77 | bool expandReductions(Function &F, const TargetTransformInfo *TTI) { |
78 | bool Changed = false; |
79 | SmallVector<IntrinsicInst *, 4> Worklist; |
80 | for (auto &I : instructions(F)) { |
81 | if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) { |
82 | switch (II->getIntrinsicID()) { |
83 | default: break; |
84 | case Intrinsic::vector_reduce_fadd: |
85 | case Intrinsic::vector_reduce_fmul: |
86 | case Intrinsic::vector_reduce_add: |
87 | case Intrinsic::vector_reduce_mul: |
88 | case Intrinsic::vector_reduce_and: |
89 | case Intrinsic::vector_reduce_or: |
90 | case Intrinsic::vector_reduce_xor: |
91 | case Intrinsic::vector_reduce_smax: |
92 | case Intrinsic::vector_reduce_smin: |
93 | case Intrinsic::vector_reduce_umax: |
94 | case Intrinsic::vector_reduce_umin: |
95 | case Intrinsic::vector_reduce_fmax: |
96 | case Intrinsic::vector_reduce_fmin: |
97 | if (TTI->shouldExpandReduction(II)) |
98 | Worklist.push_back(Elt: II); |
99 | |
100 | break; |
101 | } |
102 | } |
103 | } |
104 | |
105 | for (auto *II : Worklist) { |
106 | FastMathFlags FMF = |
107 | isa<FPMathOperator>(Val: II) ? II->getFastMathFlags() : FastMathFlags{}; |
108 | Intrinsic::ID ID = II->getIntrinsicID(); |
109 | RecurKind RK = getRK(ID); |
110 | |
111 | Value *Rdx = nullptr; |
112 | IRBuilder<> Builder(II); |
113 | IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); |
114 | Builder.setFastMathFlags(FMF); |
115 | switch (ID) { |
116 | default: llvm_unreachable("Unexpected intrinsic!" ); |
117 | case Intrinsic::vector_reduce_fadd: |
118 | case Intrinsic::vector_reduce_fmul: { |
119 | // FMFs must be attached to the call, otherwise it's an ordered reduction |
120 | // and it can't be handled by generating a shuffle sequence. |
121 | Value *Acc = II->getArgOperand(i: 0); |
122 | Value *Vec = II->getArgOperand(i: 1); |
123 | if (!FMF.allowReassoc()) |
124 | Rdx = getOrderedReduction(Builder, Acc, Src: Vec, Op: getOpcode(ID), MinMaxKind: RK); |
125 | else { |
126 | if (!isPowerOf2_32( |
127 | Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements())) |
128 | continue; |
129 | |
130 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: getOpcode(ID), MinMaxKind: RK); |
131 | Rdx = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)getOpcode(ID), |
132 | LHS: Acc, RHS: Rdx, Name: "bin.rdx" ); |
133 | } |
134 | break; |
135 | } |
136 | case Intrinsic::vector_reduce_and: |
137 | case Intrinsic::vector_reduce_or: { |
138 | // Canonicalize logical or/and reductions: |
139 | // Or reduction for i1 is represented as: |
140 | // %val = bitcast <ReduxWidth x i1> to iReduxWidth |
141 | // %res = cmp ne iReduxWidth %val, 0 |
142 | // And reduction for i1 is represented as: |
143 | // %val = bitcast <ReduxWidth x i1> to iReduxWidth |
144 | // %res = cmp eq iReduxWidth %val, 11111 |
145 | Value *Vec = II->getArgOperand(i: 0); |
146 | auto *FTy = cast<FixedVectorType>(Val: Vec->getType()); |
147 | unsigned NumElts = FTy->getNumElements(); |
148 | if (!isPowerOf2_32(Value: NumElts)) |
149 | continue; |
150 | |
151 | if (FTy->getElementType() == Builder.getInt1Ty()) { |
152 | Rdx = Builder.CreateBitCast(V: Vec, DestTy: Builder.getIntNTy(N: NumElts)); |
153 | if (ID == Intrinsic::vector_reduce_and) { |
154 | Rdx = Builder.CreateICmpEQ( |
155 | LHS: Rdx, RHS: ConstantInt::getAllOnesValue(Ty: Rdx->getType())); |
156 | } else { |
157 | assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction." ); |
158 | Rdx = Builder.CreateIsNotNull(Arg: Rdx); |
159 | } |
160 | break; |
161 | } |
162 | |
163 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: getOpcode(ID), MinMaxKind: RK); |
164 | break; |
165 | } |
166 | case Intrinsic::vector_reduce_add: |
167 | case Intrinsic::vector_reduce_mul: |
168 | case Intrinsic::vector_reduce_xor: |
169 | case Intrinsic::vector_reduce_smax: |
170 | case Intrinsic::vector_reduce_smin: |
171 | case Intrinsic::vector_reduce_umax: |
172 | case Intrinsic::vector_reduce_umin: { |
173 | Value *Vec = II->getArgOperand(i: 0); |
174 | if (!isPowerOf2_32( |
175 | Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements())) |
176 | continue; |
177 | |
178 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: getOpcode(ID), MinMaxKind: RK); |
179 | break; |
180 | } |
181 | case Intrinsic::vector_reduce_fmax: |
182 | case Intrinsic::vector_reduce_fmin: { |
183 | // We require "nnan" to use a shuffle reduction; "nsz" is implied by the |
184 | // semantics of the reduction. |
185 | Value *Vec = II->getArgOperand(i: 0); |
186 | if (!isPowerOf2_32( |
187 | Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()) || |
188 | !FMF.noNaNs()) |
189 | continue; |
190 | |
191 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: getOpcode(ID), MinMaxKind: RK); |
192 | break; |
193 | } |
194 | } |
195 | II->replaceAllUsesWith(V: Rdx); |
196 | II->eraseFromParent(); |
197 | Changed = true; |
198 | } |
199 | return Changed; |
200 | } |
201 | |
202 | class ExpandReductions : public FunctionPass { |
203 | public: |
204 | static char ID; |
205 | ExpandReductions() : FunctionPass(ID) { |
206 | initializeExpandReductionsPass(*PassRegistry::getPassRegistry()); |
207 | } |
208 | |
209 | bool runOnFunction(Function &F) override { |
210 | const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
211 | return expandReductions(F, TTI); |
212 | } |
213 | |
214 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
215 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
216 | AU.setPreservesCFG(); |
217 | } |
218 | }; |
219 | } |
220 | |
221 | char ExpandReductions::ID; |
222 | INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions" , |
223 | "Expand reduction intrinsics" , false, false) |
224 | INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
225 | INITIALIZE_PASS_END(ExpandReductions, "expand-reductions" , |
226 | "Expand reduction intrinsics" , false, false) |
227 | |
228 | FunctionPass *llvm::createExpandReductionsPass() { |
229 | return new ExpandReductions(); |
230 | } |
231 | |
232 | PreservedAnalyses ExpandReductionsPass::run(Function &F, |
233 | FunctionAnalysisManager &AM) { |
234 | const auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F); |
235 | if (!expandReductions(F, TTI: &TTI)) |
236 | return PreservedAnalyses::all(); |
237 | PreservedAnalyses PA; |
238 | PA.preserveSet<CFGAnalyses>(); |
239 | return PA; |
240 | } |
241 | |