1 | //===-- lib/CodeGen/GlobalISel/GICombinerHelper.cpp -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" |
9 | #include "llvm/ADT/APFloat.h" |
10 | #include "llvm/ADT/STLExtras.h" |
11 | #include "llvm/ADT/SetVector.h" |
12 | #include "llvm/ADT/SmallBitVector.h" |
13 | #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" |
14 | #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h" |
15 | #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" |
16 | #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" |
17 | #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" |
18 | #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" |
19 | #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
20 | #include "llvm/CodeGen/GlobalISel/Utils.h" |
21 | #include "llvm/CodeGen/LowLevelTypeUtils.h" |
22 | #include "llvm/CodeGen/MachineBasicBlock.h" |
23 | #include "llvm/CodeGen/MachineDominators.h" |
24 | #include "llvm/CodeGen/MachineInstr.h" |
25 | #include "llvm/CodeGen/MachineMemOperand.h" |
26 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
27 | #include "llvm/CodeGen/RegisterBankInfo.h" |
28 | #include "llvm/CodeGen/TargetInstrInfo.h" |
29 | #include "llvm/CodeGen/TargetLowering.h" |
30 | #include "llvm/CodeGen/TargetOpcodes.h" |
31 | #include "llvm/IR/ConstantRange.h" |
32 | #include "llvm/IR/DataLayout.h" |
33 | #include "llvm/IR/InstrTypes.h" |
34 | #include "llvm/Support/Casting.h" |
35 | #include "llvm/Support/DivisionByConstantInfo.h" |
36 | #include "llvm/Support/ErrorHandling.h" |
37 | #include "llvm/Support/MathExtras.h" |
38 | #include "llvm/Target/TargetMachine.h" |
39 | #include <cmath> |
40 | #include <optional> |
41 | #include <tuple> |
42 | |
43 | #define DEBUG_TYPE "gi-combiner" |
44 | |
45 | using namespace llvm; |
46 | using namespace MIPatternMatch; |
47 | |
48 | // Option to allow testing of the combiner while no targets know about indexed |
49 | // addressing. |
50 | static cl::opt<bool> |
51 | ForceLegalIndexing("force-legal-indexing" , cl::Hidden, cl::init(Val: false), |
52 | cl::desc("Force all indexed operations to be " |
53 | "legal for the GlobalISel combiner" )); |
54 | |
55 | CombinerHelper::CombinerHelper(GISelChangeObserver &Observer, |
56 | MachineIRBuilder &B, bool IsPreLegalize, |
57 | GISelKnownBits *KB, MachineDominatorTree *MDT, |
58 | const LegalizerInfo *LI) |
59 | : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB), |
60 | MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI), |
61 | RBI(Builder.getMF().getSubtarget().getRegBankInfo()), |
62 | TRI(Builder.getMF().getSubtarget().getRegisterInfo()) { |
63 | (void)this->KB; |
64 | } |
65 | |
66 | const TargetLowering &CombinerHelper::getTargetLowering() const { |
67 | return *Builder.getMF().getSubtarget().getTargetLowering(); |
68 | } |
69 | |
70 | /// \returns The little endian in-memory byte position of byte \p I in a |
71 | /// \p ByteWidth bytes wide type. |
72 | /// |
73 | /// E.g. Given a 4-byte type x, x[0] -> byte 0 |
74 | static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) { |
75 | assert(I < ByteWidth && "I must be in [0, ByteWidth)" ); |
76 | return I; |
77 | } |
78 | |
79 | /// Determines the LogBase2 value for a non-null input value using the |
80 | /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). |
81 | static Register buildLogBase2(Register V, MachineIRBuilder &MIB) { |
82 | auto &MRI = *MIB.getMRI(); |
83 | LLT Ty = MRI.getType(Reg: V); |
84 | auto Ctlz = MIB.buildCTLZ(Dst: Ty, Src0: V); |
85 | auto Base = MIB.buildConstant(Res: Ty, Val: Ty.getScalarSizeInBits() - 1); |
86 | return MIB.buildSub(Dst: Ty, Src0: Base, Src1: Ctlz).getReg(Idx: 0); |
87 | } |
88 | |
89 | /// \returns The big endian in-memory byte position of byte \p I in a |
90 | /// \p ByteWidth bytes wide type. |
91 | /// |
92 | /// E.g. Given a 4-byte type x, x[0] -> byte 3 |
93 | static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) { |
94 | assert(I < ByteWidth && "I must be in [0, ByteWidth)" ); |
95 | return ByteWidth - I - 1; |
96 | } |
97 | |
98 | /// Given a map from byte offsets in memory to indices in a load/store, |
99 | /// determine if that map corresponds to a little or big endian byte pattern. |
100 | /// |
101 | /// \param MemOffset2Idx maps memory offsets to address offsets. |
102 | /// \param LowestIdx is the lowest index in \p MemOffset2Idx. |
103 | /// |
104 | /// \returns true if the map corresponds to a big endian byte pattern, false if |
105 | /// it corresponds to a little endian byte pattern, and std::nullopt otherwise. |
106 | /// |
107 | /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns |
108 | /// are as follows: |
109 | /// |
110 | /// AddrOffset Little endian Big endian |
111 | /// 0 0 3 |
112 | /// 1 1 2 |
113 | /// 2 2 1 |
114 | /// 3 3 0 |
115 | static std::optional<bool> |
116 | isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
117 | int64_t LowestIdx) { |
118 | // Need at least two byte positions to decide on endianness. |
119 | unsigned Width = MemOffset2Idx.size(); |
120 | if (Width < 2) |
121 | return std::nullopt; |
122 | bool BigEndian = true, LittleEndian = true; |
123 | for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) { |
124 | auto MemOffsetAndIdx = MemOffset2Idx.find(Val: MemOffset); |
125 | if (MemOffsetAndIdx == MemOffset2Idx.end()) |
126 | return std::nullopt; |
127 | const int64_t Idx = MemOffsetAndIdx->second - LowestIdx; |
128 | assert(Idx >= 0 && "Expected non-negative byte offset?" ); |
129 | LittleEndian &= Idx == littleEndianByteAt(ByteWidth: Width, I: MemOffset); |
130 | BigEndian &= Idx == bigEndianByteAt(ByteWidth: Width, I: MemOffset); |
131 | if (!BigEndian && !LittleEndian) |
132 | return std::nullopt; |
133 | } |
134 | |
135 | assert((BigEndian != LittleEndian) && |
136 | "Pattern cannot be both big and little endian!" ); |
137 | return BigEndian; |
138 | } |
139 | |
140 | bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; } |
141 | |
142 | bool CombinerHelper::isLegal(const LegalityQuery &Query) const { |
143 | assert(LI && "Must have LegalizerInfo to query isLegal!" ); |
144 | return LI->getAction(Query).Action == LegalizeActions::Legal; |
145 | } |
146 | |
147 | bool CombinerHelper::isLegalOrBeforeLegalizer( |
148 | const LegalityQuery &Query) const { |
149 | return isPreLegalize() || isLegal(Query); |
150 | } |
151 | |
152 | bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const { |
153 | if (!Ty.isVector()) |
154 | return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_CONSTANT, {Ty}}); |
155 | // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs. |
156 | if (isPreLegalize()) |
157 | return true; |
158 | LLT EltTy = Ty.getElementType(); |
159 | return isLegal(Query: {TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) && |
160 | isLegal(Query: {TargetOpcode::G_CONSTANT, {EltTy}}); |
161 | } |
162 | |
163 | void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, |
164 | Register ToReg) const { |
165 | Observer.changingAllUsesOfReg(MRI, Reg: FromReg); |
166 | |
167 | if (MRI.constrainRegAttrs(Reg: ToReg, ConstrainingReg: FromReg)) |
168 | MRI.replaceRegWith(FromReg, ToReg); |
169 | else |
170 | Builder.buildCopy(Res: ToReg, Op: FromReg); |
171 | |
172 | Observer.finishedChangingAllUsesOfReg(); |
173 | } |
174 | |
175 | void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI, |
176 | MachineOperand &FromRegOp, |
177 | Register ToReg) const { |
178 | assert(FromRegOp.getParent() && "Expected an operand in an MI" ); |
179 | Observer.changingInstr(MI&: *FromRegOp.getParent()); |
180 | |
181 | FromRegOp.setReg(ToReg); |
182 | |
183 | Observer.changedInstr(MI&: *FromRegOp.getParent()); |
184 | } |
185 | |
186 | void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI, |
187 | unsigned ToOpcode) const { |
188 | Observer.changingInstr(MI&: FromMI); |
189 | |
190 | FromMI.setDesc(Builder.getTII().get(Opcode: ToOpcode)); |
191 | |
192 | Observer.changedInstr(MI&: FromMI); |
193 | } |
194 | |
195 | const RegisterBank *CombinerHelper::getRegBank(Register Reg) const { |
196 | return RBI->getRegBank(Reg, MRI, TRI: *TRI); |
197 | } |
198 | |
199 | void CombinerHelper::setRegBank(Register Reg, const RegisterBank *RegBank) { |
200 | if (RegBank) |
201 | MRI.setRegBank(Reg, RegBank: *RegBank); |
202 | } |
203 | |
204 | bool CombinerHelper::tryCombineCopy(MachineInstr &MI) { |
205 | if (matchCombineCopy(MI)) { |
206 | applyCombineCopy(MI); |
207 | return true; |
208 | } |
209 | return false; |
210 | } |
211 | bool CombinerHelper::matchCombineCopy(MachineInstr &MI) { |
212 | if (MI.getOpcode() != TargetOpcode::COPY) |
213 | return false; |
214 | Register DstReg = MI.getOperand(i: 0).getReg(); |
215 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
216 | return canReplaceReg(DstReg, SrcReg, MRI); |
217 | } |
218 | void CombinerHelper::applyCombineCopy(MachineInstr &MI) { |
219 | Register DstReg = MI.getOperand(i: 0).getReg(); |
220 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
221 | MI.eraseFromParent(); |
222 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
223 | } |
224 | |
225 | bool CombinerHelper::tryCombineConcatVectors(MachineInstr &MI) { |
226 | bool IsUndef = false; |
227 | SmallVector<Register, 4> Ops; |
228 | if (matchCombineConcatVectors(MI, IsUndef, Ops)) { |
229 | applyCombineConcatVectors(MI, IsUndef, Ops); |
230 | return true; |
231 | } |
232 | return false; |
233 | } |
234 | |
235 | bool CombinerHelper::matchCombineConcatVectors(MachineInstr &MI, bool &IsUndef, |
236 | SmallVectorImpl<Register> &Ops) { |
237 | assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS && |
238 | "Invalid instruction" ); |
239 | IsUndef = true; |
240 | MachineInstr *Undef = nullptr; |
241 | |
242 | // Walk over all the operands of concat vectors and check if they are |
243 | // build_vector themselves or undef. |
244 | // Then collect their operands in Ops. |
245 | for (const MachineOperand &MO : MI.uses()) { |
246 | Register Reg = MO.getReg(); |
247 | MachineInstr *Def = MRI.getVRegDef(Reg); |
248 | assert(Def && "Operand not defined" ); |
249 | switch (Def->getOpcode()) { |
250 | case TargetOpcode::G_BUILD_VECTOR: |
251 | IsUndef = false; |
252 | // Remember the operands of the build_vector to fold |
253 | // them into the yet-to-build flattened concat vectors. |
254 | for (const MachineOperand &BuildVecMO : Def->uses()) |
255 | Ops.push_back(Elt: BuildVecMO.getReg()); |
256 | break; |
257 | case TargetOpcode::G_IMPLICIT_DEF: { |
258 | LLT OpType = MRI.getType(Reg); |
259 | // Keep one undef value for all the undef operands. |
260 | if (!Undef) { |
261 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
262 | Undef = Builder.buildUndef(Res: OpType.getScalarType()); |
263 | } |
264 | assert(MRI.getType(Undef->getOperand(0).getReg()) == |
265 | OpType.getScalarType() && |
266 | "All undefs should have the same type" ); |
267 | // Break the undef vector in as many scalar elements as needed |
268 | // for the flattening. |
269 | for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements(); |
270 | EltIdx != EltEnd; ++EltIdx) |
271 | Ops.push_back(Elt: Undef->getOperand(i: 0).getReg()); |
272 | break; |
273 | } |
274 | default: |
275 | return false; |
276 | } |
277 | } |
278 | return true; |
279 | } |
280 | void CombinerHelper::applyCombineConcatVectors( |
281 | MachineInstr &MI, bool IsUndef, const ArrayRef<Register> Ops) { |
282 | // We determined that the concat_vectors can be flatten. |
283 | // Generate the flattened build_vector. |
284 | Register DstReg = MI.getOperand(i: 0).getReg(); |
285 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
286 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg); |
287 | |
288 | // Note: IsUndef is sort of redundant. We could have determine it by |
289 | // checking that at all Ops are undef. Alternatively, we could have |
290 | // generate a build_vector of undefs and rely on another combine to |
291 | // clean that up. For now, given we already gather this information |
292 | // in tryCombineConcatVectors, just save compile time and issue the |
293 | // right thing. |
294 | if (IsUndef) |
295 | Builder.buildUndef(Res: NewDstReg); |
296 | else |
297 | Builder.buildBuildVector(Res: NewDstReg, Ops); |
298 | MI.eraseFromParent(); |
299 | replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg); |
300 | } |
301 | |
302 | bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) { |
303 | SmallVector<Register, 4> Ops; |
304 | if (matchCombineShuffleVector(MI, Ops)) { |
305 | applyCombineShuffleVector(MI, Ops); |
306 | return true; |
307 | } |
308 | return false; |
309 | } |
310 | |
311 | bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI, |
312 | SmallVectorImpl<Register> &Ops) { |
313 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && |
314 | "Invalid instruction kind" ); |
315 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
316 | Register Src1 = MI.getOperand(i: 1).getReg(); |
317 | LLT SrcType = MRI.getType(Reg: Src1); |
318 | // As bizarre as it may look, shuffle vector can actually produce |
319 | // scalar! This is because at the IR level a <1 x ty> shuffle |
320 | // vector is perfectly valid. |
321 | unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1; |
322 | unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1; |
323 | |
324 | // If the resulting vector is smaller than the size of the source |
325 | // vectors being concatenated, we won't be able to replace the |
326 | // shuffle vector into a concat_vectors. |
327 | // |
328 | // Note: We may still be able to produce a concat_vectors fed by |
329 | // extract_vector_elt and so on. It is less clear that would |
330 | // be better though, so don't bother for now. |
331 | // |
332 | // If the destination is a scalar, the size of the sources doesn't |
333 | // matter. we will lower the shuffle to a plain copy. This will |
334 | // work only if the source and destination have the same size. But |
335 | // that's covered by the next condition. |
336 | // |
337 | // TODO: If the size between the source and destination don't match |
338 | // we could still emit an extract vector element in that case. |
339 | if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1) |
340 | return false; |
341 | |
342 | // Check that the shuffle mask can be broken evenly between the |
343 | // different sources. |
344 | if (DstNumElts % SrcNumElts != 0) |
345 | return false; |
346 | |
347 | // Mask length is a multiple of the source vector length. |
348 | // Check if the shuffle is some kind of concatenation of the input |
349 | // vectors. |
350 | unsigned NumConcat = DstNumElts / SrcNumElts; |
351 | SmallVector<int, 8> ConcatSrcs(NumConcat, -1); |
352 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
353 | for (unsigned i = 0; i != DstNumElts; ++i) { |
354 | int Idx = Mask[i]; |
355 | // Undef value. |
356 | if (Idx < 0) |
357 | continue; |
358 | // Ensure the indices in each SrcType sized piece are sequential and that |
359 | // the same source is used for the whole piece. |
360 | if ((Idx % SrcNumElts != (i % SrcNumElts)) || |
361 | (ConcatSrcs[i / SrcNumElts] >= 0 && |
362 | ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts))) |
363 | return false; |
364 | // Remember which source this index came from. |
365 | ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts; |
366 | } |
367 | |
368 | // The shuffle is concatenating multiple vectors together. |
369 | // Collect the different operands for that. |
370 | Register UndefReg; |
371 | Register Src2 = MI.getOperand(i: 2).getReg(); |
372 | for (auto Src : ConcatSrcs) { |
373 | if (Src < 0) { |
374 | if (!UndefReg) { |
375 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
376 | UndefReg = Builder.buildUndef(Res: SrcType).getReg(Idx: 0); |
377 | } |
378 | Ops.push_back(Elt: UndefReg); |
379 | } else if (Src == 0) |
380 | Ops.push_back(Elt: Src1); |
381 | else |
382 | Ops.push_back(Elt: Src2); |
383 | } |
384 | return true; |
385 | } |
386 | |
387 | void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI, |
388 | const ArrayRef<Register> Ops) { |
389 | Register DstReg = MI.getOperand(i: 0).getReg(); |
390 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
391 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: DstReg); |
392 | |
393 | if (Ops.size() == 1) |
394 | Builder.buildCopy(Res: NewDstReg, Op: Ops[0]); |
395 | else |
396 | Builder.buildMergeLikeInstr(Res: NewDstReg, Ops); |
397 | |
398 | MI.eraseFromParent(); |
399 | replaceRegWith(MRI, FromReg: DstReg, ToReg: NewDstReg); |
400 | } |
401 | |
402 | bool CombinerHelper::(MachineInstr &MI) { |
403 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && |
404 | "Invalid instruction kind" ); |
405 | |
406 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
407 | return Mask.size() == 1; |
408 | } |
409 | |
410 | void CombinerHelper::(MachineInstr &MI) { |
411 | Register DstReg = MI.getOperand(i: 0).getReg(); |
412 | Builder.setInsertPt(MBB&: *MI.getParent(), II: MI); |
413 | |
414 | int I = MI.getOperand(i: 3).getShuffleMask()[0]; |
415 | Register Src1 = MI.getOperand(i: 1).getReg(); |
416 | LLT Src1Ty = MRI.getType(Reg: Src1); |
417 | int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1; |
418 | Register SrcReg; |
419 | if (I >= Src1NumElts) { |
420 | SrcReg = MI.getOperand(i: 2).getReg(); |
421 | I -= Src1NumElts; |
422 | } else if (I >= 0) |
423 | SrcReg = Src1; |
424 | |
425 | if (I < 0) |
426 | Builder.buildUndef(Res: DstReg); |
427 | else if (!MRI.getType(Reg: SrcReg).isVector()) |
428 | Builder.buildCopy(Res: DstReg, Op: SrcReg); |
429 | else |
430 | Builder.buildExtractVectorElementConstant(Res: DstReg, Val: SrcReg, Idx: I); |
431 | |
432 | MI.eraseFromParent(); |
433 | } |
434 | |
435 | namespace { |
436 | |
437 | /// Select a preference between two uses. CurrentUse is the current preference |
438 | /// while *ForCandidate is attributes of the candidate under consideration. |
439 | PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI, |
440 | PreferredTuple &CurrentUse, |
441 | const LLT TyForCandidate, |
442 | unsigned OpcodeForCandidate, |
443 | MachineInstr *MIForCandidate) { |
444 | if (!CurrentUse.Ty.isValid()) { |
445 | if (CurrentUse.ExtendOpcode == OpcodeForCandidate || |
446 | CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT) |
447 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
448 | return CurrentUse; |
449 | } |
450 | |
451 | // We permit the extend to hoist through basic blocks but this is only |
452 | // sensible if the target has extending loads. If you end up lowering back |
453 | // into a load and extend during the legalizer then the end result is |
454 | // hoisting the extend up to the load. |
455 | |
456 | // Prefer defined extensions to undefined extensions as these are more |
457 | // likely to reduce the number of instructions. |
458 | if (OpcodeForCandidate == TargetOpcode::G_ANYEXT && |
459 | CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT) |
460 | return CurrentUse; |
461 | else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT && |
462 | OpcodeForCandidate != TargetOpcode::G_ANYEXT) |
463 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
464 | |
465 | // Prefer sign extensions to zero extensions as sign-extensions tend to be |
466 | // more expensive. Don't do this if the load is already a zero-extend load |
467 | // though, otherwise we'll rewrite a zero-extend load into a sign-extend |
468 | // later. |
469 | if (!isa<GZExtLoad>(Val: LoadMI) && CurrentUse.Ty == TyForCandidate) { |
470 | if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT && |
471 | OpcodeForCandidate == TargetOpcode::G_ZEXT) |
472 | return CurrentUse; |
473 | else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT && |
474 | OpcodeForCandidate == TargetOpcode::G_SEXT) |
475 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
476 | } |
477 | |
478 | // This is potentially target specific. We've chosen the largest type |
479 | // because G_TRUNC is usually free. One potential catch with this is that |
480 | // some targets have a reduced number of larger registers than smaller |
481 | // registers and this choice potentially increases the live-range for the |
482 | // larger value. |
483 | if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) { |
484 | return {.Ty: TyForCandidate, .ExtendOpcode: OpcodeForCandidate, .MI: MIForCandidate}; |
485 | } |
486 | return CurrentUse; |
487 | } |
488 | |
489 | /// Find a suitable place to insert some instructions and insert them. This |
490 | /// function accounts for special cases like inserting before a PHI node. |
491 | /// The current strategy for inserting before PHI's is to duplicate the |
492 | /// instructions for each predecessor. However, while that's ok for G_TRUNC |
493 | /// on most targets since it generally requires no code, other targets/cases may |
494 | /// want to try harder to find a dominating block. |
495 | static void InsertInsnsWithoutSideEffectsBeforeUse( |
496 | MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO, |
497 | std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator, |
498 | MachineOperand &UseMO)> |
499 | Inserter) { |
500 | MachineInstr &UseMI = *UseMO.getParent(); |
501 | |
502 | MachineBasicBlock *InsertBB = UseMI.getParent(); |
503 | |
504 | // If the use is a PHI then we want the predecessor block instead. |
505 | if (UseMI.isPHI()) { |
506 | MachineOperand *PredBB = std::next(x: &UseMO); |
507 | InsertBB = PredBB->getMBB(); |
508 | } |
509 | |
510 | // If the block is the same block as the def then we want to insert just after |
511 | // the def instead of at the start of the block. |
512 | if (InsertBB == DefMI.getParent()) { |
513 | MachineBasicBlock::iterator InsertPt = &DefMI; |
514 | Inserter(InsertBB, std::next(x: InsertPt), UseMO); |
515 | return; |
516 | } |
517 | |
518 | // Otherwise we want the start of the BB |
519 | Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO); |
520 | } |
521 | } // end anonymous namespace |
522 | |
523 | bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) { |
524 | PreferredTuple Preferred; |
525 | if (matchCombineExtendingLoads(MI, MatchInfo&: Preferred)) { |
526 | applyCombineExtendingLoads(MI, MatchInfo&: Preferred); |
527 | return true; |
528 | } |
529 | return false; |
530 | } |
531 | |
532 | static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) { |
533 | unsigned CandidateLoadOpc; |
534 | switch (ExtOpc) { |
535 | case TargetOpcode::G_ANYEXT: |
536 | CandidateLoadOpc = TargetOpcode::G_LOAD; |
537 | break; |
538 | case TargetOpcode::G_SEXT: |
539 | CandidateLoadOpc = TargetOpcode::G_SEXTLOAD; |
540 | break; |
541 | case TargetOpcode::G_ZEXT: |
542 | CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD; |
543 | break; |
544 | default: |
545 | llvm_unreachable("Unexpected extend opc" ); |
546 | } |
547 | return CandidateLoadOpc; |
548 | } |
549 | |
550 | bool CombinerHelper::matchCombineExtendingLoads(MachineInstr &MI, |
551 | PreferredTuple &Preferred) { |
552 | // We match the loads and follow the uses to the extend instead of matching |
553 | // the extends and following the def to the load. This is because the load |
554 | // must remain in the same position for correctness (unless we also add code |
555 | // to find a safe place to sink it) whereas the extend is freely movable. |
556 | // It also prevents us from duplicating the load for the volatile case or just |
557 | // for performance. |
558 | GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: &MI); |
559 | if (!LoadMI) |
560 | return false; |
561 | |
562 | Register LoadReg = LoadMI->getDstReg(); |
563 | |
564 | LLT LoadValueTy = MRI.getType(Reg: LoadReg); |
565 | if (!LoadValueTy.isScalar()) |
566 | return false; |
567 | |
568 | // Most architectures are going to legalize <s8 loads into at least a 1 byte |
569 | // load, and the MMOs can only describe memory accesses in multiples of bytes. |
570 | // If we try to perform extload combining on those, we can end up with |
571 | // %a(s8) = extload %ptr (load 1 byte from %ptr) |
572 | // ... which is an illegal extload instruction. |
573 | if (LoadValueTy.getSizeInBits() < 8) |
574 | return false; |
575 | |
576 | // For non power-of-2 types, they will very likely be legalized into multiple |
577 | // loads. Don't bother trying to match them into extending loads. |
578 | if (!llvm::has_single_bit<uint32_t>(Value: LoadValueTy.getSizeInBits())) |
579 | return false; |
580 | |
581 | // Find the preferred type aside from the any-extends (unless it's the only |
582 | // one) and non-extending ops. We'll emit an extending load to that type and |
583 | // and emit a variant of (extend (trunc X)) for the others according to the |
584 | // relative type sizes. At the same time, pick an extend to use based on the |
585 | // extend involved in the chosen type. |
586 | unsigned PreferredOpcode = |
587 | isa<GLoad>(Val: &MI) |
588 | ? TargetOpcode::G_ANYEXT |
589 | : isa<GSExtLoad>(Val: &MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT; |
590 | Preferred = {.Ty: LLT(), .ExtendOpcode: PreferredOpcode, .MI: nullptr}; |
591 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: LoadReg)) { |
592 | if (UseMI.getOpcode() == TargetOpcode::G_SEXT || |
593 | UseMI.getOpcode() == TargetOpcode::G_ZEXT || |
594 | (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) { |
595 | const auto &MMO = LoadMI->getMMO(); |
596 | // For atomics, only form anyextending loads. |
597 | if (MMO.isAtomic() && UseMI.getOpcode() != TargetOpcode::G_ANYEXT) |
598 | continue; |
599 | // Check for legality. |
600 | if (!isPreLegalize()) { |
601 | LegalityQuery::MemDesc MMDesc(MMO); |
602 | unsigned CandidateLoadOpc = getExtLoadOpcForExtend(ExtOpc: UseMI.getOpcode()); |
603 | LLT UseTy = MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()); |
604 | LLT SrcTy = MRI.getType(Reg: LoadMI->getPointerReg()); |
605 | if (LI->getAction(Query: {CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}}) |
606 | .Action != LegalizeActions::Legal) |
607 | continue; |
608 | } |
609 | Preferred = ChoosePreferredUse(LoadMI&: MI, CurrentUse&: Preferred, |
610 | TyForCandidate: MRI.getType(Reg: UseMI.getOperand(i: 0).getReg()), |
611 | OpcodeForCandidate: UseMI.getOpcode(), MIForCandidate: &UseMI); |
612 | } |
613 | } |
614 | |
615 | // There were no extends |
616 | if (!Preferred.MI) |
617 | return false; |
618 | // It should be impossible to chose an extend without selecting a different |
619 | // type since by definition the result of an extend is larger. |
620 | assert(Preferred.Ty != LoadValueTy && "Extending to same type?" ); |
621 | |
622 | LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI); |
623 | return true; |
624 | } |
625 | |
626 | void CombinerHelper::applyCombineExtendingLoads(MachineInstr &MI, |
627 | PreferredTuple &Preferred) { |
628 | // Rewrite the load to the chosen extending load. |
629 | Register ChosenDstReg = Preferred.MI->getOperand(i: 0).getReg(); |
630 | |
631 | // Inserter to insert a truncate back to the original type at a given point |
632 | // with some basic CSE to limit truncate duplication to one per BB. |
633 | DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns; |
634 | auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB, |
635 | MachineBasicBlock::iterator InsertBefore, |
636 | MachineOperand &UseMO) { |
637 | MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(Val: InsertIntoBB); |
638 | if (PreviouslyEmitted) { |
639 | Observer.changingInstr(MI&: *UseMO.getParent()); |
640 | UseMO.setReg(PreviouslyEmitted->getOperand(i: 0).getReg()); |
641 | Observer.changedInstr(MI&: *UseMO.getParent()); |
642 | return; |
643 | } |
644 | |
645 | Builder.setInsertPt(MBB&: *InsertIntoBB, II: InsertBefore); |
646 | Register NewDstReg = MRI.cloneVirtualRegister(VReg: MI.getOperand(i: 0).getReg()); |
647 | MachineInstr *NewMI = Builder.buildTrunc(Res: NewDstReg, Op: ChosenDstReg); |
648 | EmittedInsns[InsertIntoBB] = NewMI; |
649 | replaceRegOpWith(MRI, FromRegOp&: UseMO, ToReg: NewDstReg); |
650 | }; |
651 | |
652 | Observer.changingInstr(MI); |
653 | unsigned LoadOpc = getExtLoadOpcForExtend(ExtOpc: Preferred.ExtendOpcode); |
654 | MI.setDesc(Builder.getTII().get(Opcode: LoadOpc)); |
655 | |
656 | // Rewrite all the uses to fix up the types. |
657 | auto &LoadValue = MI.getOperand(i: 0); |
658 | SmallVector<MachineOperand *, 4> Uses; |
659 | for (auto &UseMO : MRI.use_operands(Reg: LoadValue.getReg())) |
660 | Uses.push_back(Elt: &UseMO); |
661 | |
662 | for (auto *UseMO : Uses) { |
663 | MachineInstr *UseMI = UseMO->getParent(); |
664 | |
665 | // If the extend is compatible with the preferred extend then we should fix |
666 | // up the type and extend so that it uses the preferred use. |
667 | if (UseMI->getOpcode() == Preferred.ExtendOpcode || |
668 | UseMI->getOpcode() == TargetOpcode::G_ANYEXT) { |
669 | Register UseDstReg = UseMI->getOperand(i: 0).getReg(); |
670 | MachineOperand &UseSrcMO = UseMI->getOperand(i: 1); |
671 | const LLT UseDstTy = MRI.getType(Reg: UseDstReg); |
672 | if (UseDstReg != ChosenDstReg) { |
673 | if (Preferred.Ty == UseDstTy) { |
674 | // If the use has the same type as the preferred use, then merge |
675 | // the vregs and erase the extend. For example: |
676 | // %1:_(s8) = G_LOAD ... |
677 | // %2:_(s32) = G_SEXT %1(s8) |
678 | // %3:_(s32) = G_ANYEXT %1(s8) |
679 | // ... = ... %3(s32) |
680 | // rewrites to: |
681 | // %2:_(s32) = G_SEXTLOAD ... |
682 | // ... = ... %2(s32) |
683 | replaceRegWith(MRI, FromReg: UseDstReg, ToReg: ChosenDstReg); |
684 | Observer.erasingInstr(MI&: *UseMO->getParent()); |
685 | UseMO->getParent()->eraseFromParent(); |
686 | } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) { |
687 | // If the preferred size is smaller, then keep the extend but extend |
688 | // from the result of the extending load. For example: |
689 | // %1:_(s8) = G_LOAD ... |
690 | // %2:_(s32) = G_SEXT %1(s8) |
691 | // %3:_(s64) = G_ANYEXT %1(s8) |
692 | // ... = ... %3(s64) |
693 | /// rewrites to: |
694 | // %2:_(s32) = G_SEXTLOAD ... |
695 | // %3:_(s64) = G_ANYEXT %2:_(s32) |
696 | // ... = ... %3(s64) |
697 | replaceRegOpWith(MRI, FromRegOp&: UseSrcMO, ToReg: ChosenDstReg); |
698 | } else { |
699 | // If the preferred size is large, then insert a truncate. For |
700 | // example: |
701 | // %1:_(s8) = G_LOAD ... |
702 | // %2:_(s64) = G_SEXT %1(s8) |
703 | // %3:_(s32) = G_ZEXT %1(s8) |
704 | // ... = ... %3(s32) |
705 | /// rewrites to: |
706 | // %2:_(s64) = G_SEXTLOAD ... |
707 | // %4:_(s8) = G_TRUNC %2:_(s32) |
708 | // %3:_(s64) = G_ZEXT %2:_(s8) |
709 | // ... = ... %3(s64) |
710 | InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, |
711 | Inserter: InsertTruncAt); |
712 | } |
713 | continue; |
714 | } |
715 | // The use is (one of) the uses of the preferred use we chose earlier. |
716 | // We're going to update the load to def this value later so just erase |
717 | // the old extend. |
718 | Observer.erasingInstr(MI&: *UseMO->getParent()); |
719 | UseMO->getParent()->eraseFromParent(); |
720 | continue; |
721 | } |
722 | |
723 | // The use isn't an extend. Truncate back to the type we originally loaded. |
724 | // This is free on many targets. |
725 | InsertInsnsWithoutSideEffectsBeforeUse(Builder, DefMI&: MI, UseMO&: *UseMO, Inserter: InsertTruncAt); |
726 | } |
727 | |
728 | MI.getOperand(i: 0).setReg(ChosenDstReg); |
729 | Observer.changedInstr(MI); |
730 | } |
731 | |
732 | bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI, |
733 | BuildFnTy &MatchInfo) { |
734 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
735 | |
736 | // If we have the following code: |
737 | // %mask = G_CONSTANT 255 |
738 | // %ld = G_LOAD %ptr, (load s16) |
739 | // %and = G_AND %ld, %mask |
740 | // |
741 | // Try to fold it into |
742 | // %ld = G_ZEXTLOAD %ptr, (load s8) |
743 | |
744 | Register Dst = MI.getOperand(i: 0).getReg(); |
745 | if (MRI.getType(Reg: Dst).isVector()) |
746 | return false; |
747 | |
748 | auto MaybeMask = |
749 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
750 | if (!MaybeMask) |
751 | return false; |
752 | |
753 | APInt MaskVal = MaybeMask->Value; |
754 | |
755 | if (!MaskVal.isMask()) |
756 | return false; |
757 | |
758 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
759 | // Don't use getOpcodeDef() here since intermediate instructions may have |
760 | // multiple users. |
761 | GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(Val: MRI.getVRegDef(Reg: SrcReg)); |
762 | if (!LoadMI || !MRI.hasOneNonDBGUse(RegNo: LoadMI->getDstReg())) |
763 | return false; |
764 | |
765 | Register LoadReg = LoadMI->getDstReg(); |
766 | LLT RegTy = MRI.getType(Reg: LoadReg); |
767 | Register PtrReg = LoadMI->getPointerReg(); |
768 | unsigned RegSize = RegTy.getSizeInBits(); |
769 | uint64_t LoadSizeBits = LoadMI->getMemSizeInBits(); |
770 | unsigned MaskSizeBits = MaskVal.countr_one(); |
771 | |
772 | // The mask may not be larger than the in-memory type, as it might cover sign |
773 | // extended bits |
774 | if (MaskSizeBits > LoadSizeBits) |
775 | return false; |
776 | |
777 | // If the mask covers the whole destination register, there's nothing to |
778 | // extend |
779 | if (MaskSizeBits >= RegSize) |
780 | return false; |
781 | |
782 | // Most targets cannot deal with loads of size < 8 and need to re-legalize to |
783 | // at least byte loads. Avoid creating such loads here |
784 | if (MaskSizeBits < 8 || !isPowerOf2_32(Value: MaskSizeBits)) |
785 | return false; |
786 | |
787 | const MachineMemOperand &MMO = LoadMI->getMMO(); |
788 | LegalityQuery::MemDesc MemDesc(MMO); |
789 | |
790 | // Don't modify the memory access size if this is atomic/volatile, but we can |
791 | // still adjust the opcode to indicate the high bit behavior. |
792 | if (LoadMI->isSimple()) |
793 | MemDesc.MemoryTy = LLT::scalar(SizeInBits: MaskSizeBits); |
794 | else if (LoadSizeBits > MaskSizeBits || LoadSizeBits == RegSize) |
795 | return false; |
796 | |
797 | // TODO: Could check if it's legal with the reduced or original memory size. |
798 | if (!isLegalOrBeforeLegalizer( |
799 | Query: {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(Reg: PtrReg)}, {MemDesc}})) |
800 | return false; |
801 | |
802 | MatchInfo = [=](MachineIRBuilder &B) { |
803 | B.setInstrAndDebugLoc(*LoadMI); |
804 | auto &MF = B.getMF(); |
805 | auto PtrInfo = MMO.getPointerInfo(); |
806 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: MemDesc.MemoryTy); |
807 | B.buildLoadInstr(Opcode: TargetOpcode::G_ZEXTLOAD, Res: Dst, Addr: PtrReg, MMO&: *NewMMO); |
808 | LoadMI->eraseFromParent(); |
809 | }; |
810 | return true; |
811 | } |
812 | |
813 | bool CombinerHelper::isPredecessor(const MachineInstr &DefMI, |
814 | const MachineInstr &UseMI) { |
815 | assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && |
816 | "shouldn't consider debug uses" ); |
817 | assert(DefMI.getParent() == UseMI.getParent()); |
818 | if (&DefMI == &UseMI) |
819 | return true; |
820 | const MachineBasicBlock &MBB = *DefMI.getParent(); |
821 | auto DefOrUse = find_if(Range: MBB, P: [&DefMI, &UseMI](const MachineInstr &MI) { |
822 | return &MI == &DefMI || &MI == &UseMI; |
823 | }); |
824 | if (DefOrUse == MBB.end()) |
825 | llvm_unreachable("Block must contain both DefMI and UseMI!" ); |
826 | return &*DefOrUse == &DefMI; |
827 | } |
828 | |
829 | bool CombinerHelper::dominates(const MachineInstr &DefMI, |
830 | const MachineInstr &UseMI) { |
831 | assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && |
832 | "shouldn't consider debug uses" ); |
833 | if (MDT) |
834 | return MDT->dominates(A: &DefMI, B: &UseMI); |
835 | else if (DefMI.getParent() != UseMI.getParent()) |
836 | return false; |
837 | |
838 | return isPredecessor(DefMI, UseMI); |
839 | } |
840 | |
841 | bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) { |
842 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
843 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
844 | Register LoadUser = SrcReg; |
845 | |
846 | if (MRI.getType(Reg: SrcReg).isVector()) |
847 | return false; |
848 | |
849 | Register TruncSrc; |
850 | if (mi_match(R: SrcReg, MRI, P: m_GTrunc(Src: m_Reg(R&: TruncSrc)))) |
851 | LoadUser = TruncSrc; |
852 | |
853 | uint64_t SizeInBits = MI.getOperand(i: 2).getImm(); |
854 | // If the source is a G_SEXTLOAD from the same bit width, then we don't |
855 | // need any extend at all, just a truncate. |
856 | if (auto *LoadMI = getOpcodeDef<GSExtLoad>(Reg: LoadUser, MRI)) { |
857 | // If truncating more than the original extended value, abort. |
858 | auto LoadSizeBits = LoadMI->getMemSizeInBits(); |
859 | if (TruncSrc && MRI.getType(Reg: TruncSrc).getSizeInBits() < LoadSizeBits) |
860 | return false; |
861 | if (LoadSizeBits == SizeInBits) |
862 | return true; |
863 | } |
864 | return false; |
865 | } |
866 | |
867 | void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) { |
868 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
869 | Builder.setInstrAndDebugLoc(MI); |
870 | Builder.buildCopy(Res: MI.getOperand(i: 0).getReg(), Op: MI.getOperand(i: 1).getReg()); |
871 | MI.eraseFromParent(); |
872 | } |
873 | |
874 | bool CombinerHelper::matchSextInRegOfLoad( |
875 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
876 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
877 | |
878 | Register DstReg = MI.getOperand(i: 0).getReg(); |
879 | LLT RegTy = MRI.getType(Reg: DstReg); |
880 | |
881 | // Only supports scalars for now. |
882 | if (RegTy.isVector()) |
883 | return false; |
884 | |
885 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
886 | auto *LoadDef = getOpcodeDef<GLoad>(Reg: SrcReg, MRI); |
887 | if (!LoadDef || !MRI.hasOneNonDBGUse(RegNo: DstReg)) |
888 | return false; |
889 | |
890 | uint64_t MemBits = LoadDef->getMemSizeInBits(); |
891 | |
892 | // If the sign extend extends from a narrower width than the load's width, |
893 | // then we can narrow the load width when we combine to a G_SEXTLOAD. |
894 | // Avoid widening the load at all. |
895 | unsigned NewSizeBits = std::min(a: (uint64_t)MI.getOperand(i: 2).getImm(), b: MemBits); |
896 | |
897 | // Don't generate G_SEXTLOADs with a < 1 byte width. |
898 | if (NewSizeBits < 8) |
899 | return false; |
900 | // Don't bother creating a non-power-2 sextload, it will likely be broken up |
901 | // anyway for most targets. |
902 | if (!isPowerOf2_32(Value: NewSizeBits)) |
903 | return false; |
904 | |
905 | const MachineMemOperand &MMO = LoadDef->getMMO(); |
906 | LegalityQuery::MemDesc MMDesc(MMO); |
907 | |
908 | // Don't modify the memory access size if this is atomic/volatile, but we can |
909 | // still adjust the opcode to indicate the high bit behavior. |
910 | if (LoadDef->isSimple()) |
911 | MMDesc.MemoryTy = LLT::scalar(SizeInBits: NewSizeBits); |
912 | else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits()) |
913 | return false; |
914 | |
915 | // TODO: Could check if it's legal with the reduced or original memory size. |
916 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SEXTLOAD, |
917 | {MRI.getType(Reg: LoadDef->getDstReg()), |
918 | MRI.getType(Reg: LoadDef->getPointerReg())}, |
919 | {MMDesc}})) |
920 | return false; |
921 | |
922 | MatchInfo = std::make_tuple(args: LoadDef->getDstReg(), args&: NewSizeBits); |
923 | return true; |
924 | } |
925 | |
926 | void CombinerHelper::applySextInRegOfLoad( |
927 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
928 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
929 | Register LoadReg; |
930 | unsigned ScalarSizeBits; |
931 | std::tie(args&: LoadReg, args&: ScalarSizeBits) = MatchInfo; |
932 | GLoad *LoadDef = cast<GLoad>(Val: MRI.getVRegDef(Reg: LoadReg)); |
933 | |
934 | // If we have the following: |
935 | // %ld = G_LOAD %ptr, (load 2) |
936 | // %ext = G_SEXT_INREG %ld, 8 |
937 | // ==> |
938 | // %ld = G_SEXTLOAD %ptr (load 1) |
939 | |
940 | auto &MMO = LoadDef->getMMO(); |
941 | Builder.setInstrAndDebugLoc(*LoadDef); |
942 | auto &MF = Builder.getMF(); |
943 | auto PtrInfo = MMO.getPointerInfo(); |
944 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: ScalarSizeBits / 8); |
945 | Builder.buildLoadInstr(Opcode: TargetOpcode::G_SEXTLOAD, Res: MI.getOperand(i: 0).getReg(), |
946 | Addr: LoadDef->getPointerReg(), MMO&: *NewMMO); |
947 | MI.eraseFromParent(); |
948 | } |
949 | |
950 | static Type *getTypeForLLT(LLT Ty, LLVMContext &C) { |
951 | if (Ty.isVector()) |
952 | return FixedVectorType::get(ElementType: IntegerType::get(C, NumBits: Ty.getScalarSizeInBits()), |
953 | NumElts: Ty.getNumElements()); |
954 | return IntegerType::get(C, NumBits: Ty.getSizeInBits()); |
955 | } |
956 | |
957 | /// Return true if 'MI' is a load or a store that may be fold it's address |
958 | /// operand into the load / store addressing mode. |
959 | static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI, |
960 | MachineRegisterInfo &MRI) { |
961 | TargetLowering::AddrMode AM; |
962 | auto *MF = MI->getMF(); |
963 | auto *Addr = getOpcodeDef<GPtrAdd>(Reg: MI->getPointerReg(), MRI); |
964 | if (!Addr) |
965 | return false; |
966 | |
967 | AM.HasBaseReg = true; |
968 | if (auto CstOff = getIConstantVRegVal(VReg: Addr->getOffsetReg(), MRI)) |
969 | AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm] |
970 | else |
971 | AM.Scale = 1; // [reg +/- reg] |
972 | |
973 | return TLI.isLegalAddressingMode( |
974 | DL: MF->getDataLayout(), AM, |
975 | Ty: getTypeForLLT(Ty: MI->getMMO().getMemoryType(), |
976 | C&: MF->getFunction().getContext()), |
977 | AddrSpace: MI->getMMO().getAddrSpace()); |
978 | } |
979 | |
980 | static unsigned getIndexedOpc(unsigned LdStOpc) { |
981 | switch (LdStOpc) { |
982 | case TargetOpcode::G_LOAD: |
983 | return TargetOpcode::G_INDEXED_LOAD; |
984 | case TargetOpcode::G_STORE: |
985 | return TargetOpcode::G_INDEXED_STORE; |
986 | case TargetOpcode::G_ZEXTLOAD: |
987 | return TargetOpcode::G_INDEXED_ZEXTLOAD; |
988 | case TargetOpcode::G_SEXTLOAD: |
989 | return TargetOpcode::G_INDEXED_SEXTLOAD; |
990 | default: |
991 | llvm_unreachable("Unexpected opcode" ); |
992 | } |
993 | } |
994 | |
995 | bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const { |
996 | // Check for legality. |
997 | LLT PtrTy = MRI.getType(Reg: LdSt.getPointerReg()); |
998 | LLT Ty = MRI.getType(Reg: LdSt.getReg(Idx: 0)); |
999 | LLT MemTy = LdSt.getMMO().getMemoryType(); |
1000 | SmallVector<LegalityQuery::MemDesc, 2> MemDescrs( |
1001 | {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}}); |
1002 | unsigned IndexedOpc = getIndexedOpc(LdStOpc: LdSt.getOpcode()); |
1003 | SmallVector<LLT> OpTys; |
1004 | if (IndexedOpc == TargetOpcode::G_INDEXED_STORE) |
1005 | OpTys = {PtrTy, Ty, Ty}; |
1006 | else |
1007 | OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD |
1008 | |
1009 | LegalityQuery Q(IndexedOpc, OpTys, MemDescrs); |
1010 | return isLegal(Query: Q); |
1011 | } |
1012 | |
1013 | static cl::opt<unsigned> PostIndexUseThreshold( |
1014 | "post-index-use-threshold" , cl::Hidden, cl::init(Val: 32), |
1015 | cl::desc("Number of uses of a base pointer to check before it is no longer " |
1016 | "considered for post-indexing." )); |
1017 | |
1018 | bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr, |
1019 | Register &Base, Register &Offset, |
1020 | bool &RematOffset) { |
1021 | // We're looking for the following pattern, for either load or store: |
1022 | // %baseptr:_(p0) = ... |
1023 | // G_STORE %val(s64), %baseptr(p0) |
1024 | // %offset:_(s64) = G_CONSTANT i64 -256 |
1025 | // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64) |
1026 | const auto &TLI = getTargetLowering(); |
1027 | |
1028 | Register Ptr = LdSt.getPointerReg(); |
1029 | // If the store is the only use, don't bother. |
1030 | if (MRI.hasOneNonDBGUse(RegNo: Ptr)) |
1031 | return false; |
1032 | |
1033 | if (!isIndexedLoadStoreLegal(LdSt)) |
1034 | return false; |
1035 | |
1036 | if (getOpcodeDef(Opcode: TargetOpcode::G_FRAME_INDEX, Reg: Ptr, MRI)) |
1037 | return false; |
1038 | |
1039 | MachineInstr *StoredValDef = getDefIgnoringCopies(Reg: LdSt.getReg(Idx: 0), MRI); |
1040 | auto *PtrDef = MRI.getVRegDef(Reg: Ptr); |
1041 | |
1042 | unsigned NumUsesChecked = 0; |
1043 | for (auto &Use : MRI.use_nodbg_instructions(Reg: Ptr)) { |
1044 | if (++NumUsesChecked > PostIndexUseThreshold) |
1045 | return false; // Try to avoid exploding compile time. |
1046 | |
1047 | auto *PtrAdd = dyn_cast<GPtrAdd>(Val: &Use); |
1048 | // The use itself might be dead. This can happen during combines if DCE |
1049 | // hasn't had a chance to run yet. Don't allow it to form an indexed op. |
1050 | if (!PtrAdd || MRI.use_nodbg_empty(RegNo: PtrAdd->getReg(Idx: 0))) |
1051 | continue; |
1052 | |
1053 | // Check the user of this isn't the store, otherwise we'd be generate a |
1054 | // indexed store defining its own use. |
1055 | if (StoredValDef == &Use) |
1056 | continue; |
1057 | |
1058 | Offset = PtrAdd->getOffsetReg(); |
1059 | if (!ForceLegalIndexing && |
1060 | !TLI.isIndexingLegal(MI&: LdSt, Base: PtrAdd->getBaseReg(), Offset, |
1061 | /*IsPre*/ false, MRI)) |
1062 | continue; |
1063 | |
1064 | // Make sure the offset calculation is before the potentially indexed op. |
1065 | MachineInstr *OffsetDef = MRI.getVRegDef(Reg: Offset); |
1066 | RematOffset = false; |
1067 | if (!dominates(DefMI: *OffsetDef, UseMI: LdSt)) { |
1068 | // If the offset however is just a G_CONSTANT, we can always just |
1069 | // rematerialize it where we need it. |
1070 | if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT) |
1071 | continue; |
1072 | RematOffset = true; |
1073 | } |
1074 | |
1075 | for (auto &BasePtrUse : MRI.use_nodbg_instructions(Reg: PtrAdd->getBaseReg())) { |
1076 | if (&BasePtrUse == PtrDef) |
1077 | continue; |
1078 | |
1079 | // If the user is a later load/store that can be post-indexed, then don't |
1080 | // combine this one. |
1081 | auto *BasePtrLdSt = dyn_cast<GLoadStore>(Val: &BasePtrUse); |
1082 | if (BasePtrLdSt && BasePtrLdSt != &LdSt && |
1083 | dominates(DefMI: LdSt, UseMI: *BasePtrLdSt) && |
1084 | isIndexedLoadStoreLegal(LdSt&: *BasePtrLdSt)) |
1085 | return false; |
1086 | |
1087 | // Now we're looking for the key G_PTR_ADD instruction, which contains |
1088 | // the offset add that we want to fold. |
1089 | if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(Val: &BasePtrUse)) { |
1090 | Register PtrAddDefReg = BasePtrUseDef->getReg(Idx: 0); |
1091 | for (auto &BaseUseUse : MRI.use_nodbg_instructions(Reg: PtrAddDefReg)) { |
1092 | // If the use is in a different block, then we may produce worse code |
1093 | // due to the extra register pressure. |
1094 | if (BaseUseUse.getParent() != LdSt.getParent()) |
1095 | return false; |
1096 | |
1097 | if (auto *UseUseLdSt = dyn_cast<GLoadStore>(Val: &BaseUseUse)) |
1098 | if (canFoldInAddressingMode(MI: UseUseLdSt, TLI, MRI)) |
1099 | return false; |
1100 | } |
1101 | if (!dominates(DefMI: LdSt, UseMI: BasePtrUse)) |
1102 | return false; // All use must be dominated by the load/store. |
1103 | } |
1104 | } |
1105 | |
1106 | Addr = PtrAdd->getReg(Idx: 0); |
1107 | Base = PtrAdd->getBaseReg(); |
1108 | return true; |
1109 | } |
1110 | |
1111 | return false; |
1112 | } |
1113 | |
1114 | bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr, |
1115 | Register &Base, Register &Offset) { |
1116 | auto &MF = *LdSt.getParent()->getParent(); |
1117 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1118 | |
1119 | Addr = LdSt.getPointerReg(); |
1120 | if (!mi_match(R: Addr, MRI, P: m_GPtrAdd(L: m_Reg(R&: Base), R: m_Reg(R&: Offset))) || |
1121 | MRI.hasOneNonDBGUse(RegNo: Addr)) |
1122 | return false; |
1123 | |
1124 | if (!ForceLegalIndexing && |
1125 | !TLI.isIndexingLegal(MI&: LdSt, Base, Offset, /*IsPre*/ true, MRI)) |
1126 | return false; |
1127 | |
1128 | if (!isIndexedLoadStoreLegal(LdSt)) |
1129 | return false; |
1130 | |
1131 | MachineInstr *BaseDef = getDefIgnoringCopies(Reg: Base, MRI); |
1132 | if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) |
1133 | return false; |
1134 | |
1135 | if (auto *St = dyn_cast<GStore>(Val: &LdSt)) { |
1136 | // Would require a copy. |
1137 | if (Base == St->getValueReg()) |
1138 | return false; |
1139 | |
1140 | // We're expecting one use of Addr in MI, but it could also be the |
1141 | // value stored, which isn't actually dominated by the instruction. |
1142 | if (St->getValueReg() == Addr) |
1143 | return false; |
1144 | } |
1145 | |
1146 | // Avoid increasing cross-block register pressure. |
1147 | for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) |
1148 | if (AddrUse.getParent() != LdSt.getParent()) |
1149 | return false; |
1150 | |
1151 | // FIXME: check whether all uses of the base pointer are constant PtrAdds. |
1152 | // That might allow us to end base's liveness here by adjusting the constant. |
1153 | bool RealUse = false; |
1154 | for (auto &AddrUse : MRI.use_nodbg_instructions(Reg: Addr)) { |
1155 | if (!dominates(DefMI: LdSt, UseMI: AddrUse)) |
1156 | return false; // All use must be dominated by the load/store. |
1157 | |
1158 | // If Ptr may be folded in addressing mode of other use, then it's |
1159 | // not profitable to do this transformation. |
1160 | if (auto *UseLdSt = dyn_cast<GLoadStore>(Val: &AddrUse)) { |
1161 | if (!canFoldInAddressingMode(MI: UseLdSt, TLI, MRI)) |
1162 | RealUse = true; |
1163 | } else { |
1164 | RealUse = true; |
1165 | } |
1166 | } |
1167 | return RealUse; |
1168 | } |
1169 | |
1170 | bool CombinerHelper::(MachineInstr &MI, |
1171 | BuildFnTy &MatchInfo) { |
1172 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
1173 | |
1174 | // Check if there is a load that defines the vector being extracted from. |
1175 | auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI); |
1176 | if (!LoadMI) |
1177 | return false; |
1178 | |
1179 | Register Vector = MI.getOperand(i: 1).getReg(); |
1180 | LLT VecEltTy = MRI.getType(Reg: Vector).getElementType(); |
1181 | |
1182 | assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy); |
1183 | |
1184 | // Checking whether we should reduce the load width. |
1185 | if (!MRI.hasOneNonDBGUse(RegNo: Vector)) |
1186 | return false; |
1187 | |
1188 | // Check if the defining load is simple. |
1189 | if (!LoadMI->isSimple()) |
1190 | return false; |
1191 | |
1192 | // If the vector element type is not a multiple of a byte then we are unable |
1193 | // to correctly compute an address to load only the extracted element as a |
1194 | // scalar. |
1195 | if (!VecEltTy.isByteSized()) |
1196 | return false; |
1197 | |
1198 | // Check if the new load that we are going to create is legal |
1199 | // if we are in the post-legalization phase. |
1200 | MachineMemOperand MMO = LoadMI->getMMO(); |
1201 | Align Alignment = MMO.getAlign(); |
1202 | MachinePointerInfo PtrInfo; |
1203 | uint64_t Offset; |
1204 | |
1205 | // Finding the appropriate PtrInfo if offset is a known constant. |
1206 | // This is required to create the memory operand for the narrowed load. |
1207 | // This machine memory operand object helps us infer about legality |
1208 | // before we proceed to combine the instruction. |
1209 | if (auto CVal = getIConstantVRegVal(VReg: Vector, MRI)) { |
1210 | int Elt = CVal->getZExtValue(); |
1211 | // FIXME: should be (ABI size)*Elt. |
1212 | Offset = VecEltTy.getSizeInBits() * Elt / 8; |
1213 | PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset); |
1214 | } else { |
1215 | // Discard the pointer info except the address space because the memory |
1216 | // operand can't represent this new access since the offset is variable. |
1217 | Offset = VecEltTy.getSizeInBits() / 8; |
1218 | PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace()); |
1219 | } |
1220 | |
1221 | Alignment = commonAlignment(A: Alignment, Offset); |
1222 | |
1223 | Register VecPtr = LoadMI->getPointerReg(); |
1224 | LLT PtrTy = MRI.getType(Reg: VecPtr); |
1225 | |
1226 | MachineFunction &MF = *MI.getMF(); |
1227 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy); |
1228 | |
1229 | LegalityQuery::MemDesc MMDesc(*NewMMO); |
1230 | |
1231 | LegalityQuery Q = {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}; |
1232 | |
1233 | if (!isLegalOrBeforeLegalizer(Query: Q)) |
1234 | return false; |
1235 | |
1236 | // Load must be allowed and fast on the target. |
1237 | LLVMContext &C = MF.getFunction().getContext(); |
1238 | auto &DL = MF.getDataLayout(); |
1239 | unsigned Fast = 0; |
1240 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO, |
1241 | Fast: &Fast) || |
1242 | !Fast) |
1243 | return false; |
1244 | |
1245 | Register Result = MI.getOperand(i: 0).getReg(); |
1246 | Register Index = MI.getOperand(i: 2).getReg(); |
1247 | |
1248 | MatchInfo = [=](MachineIRBuilder &B) { |
1249 | GISelObserverWrapper DummyObserver; |
1250 | LegalizerHelper Helper(B.getMF(), DummyObserver, B); |
1251 | //// Get pointer to the vector element. |
1252 | Register finalPtr = Helper.getVectorElementPointer( |
1253 | VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()), |
1254 | Index); |
1255 | // New G_LOAD instruction. |
1256 | B.buildLoad(Res: Result, Addr: finalPtr, PtrInfo, Alignment); |
1257 | // Remove original GLOAD instruction. |
1258 | LoadMI->eraseFromParent(); |
1259 | }; |
1260 | |
1261 | return true; |
1262 | } |
1263 | |
1264 | bool CombinerHelper::matchCombineIndexedLoadStore( |
1265 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) { |
1266 | auto &LdSt = cast<GLoadStore>(Val&: MI); |
1267 | |
1268 | if (LdSt.isAtomic()) |
1269 | return false; |
1270 | |
1271 | MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1272 | Offset&: MatchInfo.Offset); |
1273 | if (!MatchInfo.IsPre && |
1274 | !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1275 | Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset)) |
1276 | return false; |
1277 | |
1278 | return true; |
1279 | } |
1280 | |
1281 | void CombinerHelper::applyCombineIndexedLoadStore( |
1282 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) { |
1283 | MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr); |
1284 | Builder.setInstrAndDebugLoc(MI); |
1285 | unsigned Opcode = MI.getOpcode(); |
1286 | bool IsStore = Opcode == TargetOpcode::G_STORE; |
1287 | unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode); |
1288 | |
1289 | // If the offset constant didn't happen to dominate the load/store, we can |
1290 | // just clone it as needed. |
1291 | if (MatchInfo.RematOffset) { |
1292 | auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset); |
1293 | auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset), |
1294 | Val: *OldCst->getOperand(i: 1).getCImm()); |
1295 | MatchInfo.Offset = NewCst.getReg(Idx: 0); |
1296 | } |
1297 | |
1298 | auto MIB = Builder.buildInstr(Opcode: NewOpcode); |
1299 | if (IsStore) { |
1300 | MIB.addDef(RegNo: MatchInfo.Addr); |
1301 | MIB.addUse(RegNo: MI.getOperand(i: 0).getReg()); |
1302 | } else { |
1303 | MIB.addDef(RegNo: MI.getOperand(i: 0).getReg()); |
1304 | MIB.addDef(RegNo: MatchInfo.Addr); |
1305 | } |
1306 | |
1307 | MIB.addUse(RegNo: MatchInfo.Base); |
1308 | MIB.addUse(RegNo: MatchInfo.Offset); |
1309 | MIB.addImm(Val: MatchInfo.IsPre); |
1310 | MIB->cloneMemRefs(MF&: *MI.getMF(), MI); |
1311 | MI.eraseFromParent(); |
1312 | AddrDef.eraseFromParent(); |
1313 | |
1314 | LLVM_DEBUG(dbgs() << " Combinined to indexed operation" ); |
1315 | } |
1316 | |
1317 | bool CombinerHelper::matchCombineDivRem(MachineInstr &MI, |
1318 | MachineInstr *&OtherMI) { |
1319 | unsigned Opcode = MI.getOpcode(); |
1320 | bool IsDiv, IsSigned; |
1321 | |
1322 | switch (Opcode) { |
1323 | default: |
1324 | llvm_unreachable("Unexpected opcode!" ); |
1325 | case TargetOpcode::G_SDIV: |
1326 | case TargetOpcode::G_UDIV: { |
1327 | IsDiv = true; |
1328 | IsSigned = Opcode == TargetOpcode::G_SDIV; |
1329 | break; |
1330 | } |
1331 | case TargetOpcode::G_SREM: |
1332 | case TargetOpcode::G_UREM: { |
1333 | IsDiv = false; |
1334 | IsSigned = Opcode == TargetOpcode::G_SREM; |
1335 | break; |
1336 | } |
1337 | } |
1338 | |
1339 | Register Src1 = MI.getOperand(i: 1).getReg(); |
1340 | unsigned DivOpcode, RemOpcode, DivremOpcode; |
1341 | if (IsSigned) { |
1342 | DivOpcode = TargetOpcode::G_SDIV; |
1343 | RemOpcode = TargetOpcode::G_SREM; |
1344 | DivremOpcode = TargetOpcode::G_SDIVREM; |
1345 | } else { |
1346 | DivOpcode = TargetOpcode::G_UDIV; |
1347 | RemOpcode = TargetOpcode::G_UREM; |
1348 | DivremOpcode = TargetOpcode::G_UDIVREM; |
1349 | } |
1350 | |
1351 | if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}})) |
1352 | return false; |
1353 | |
1354 | // Combine: |
1355 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1356 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1357 | // into: |
1358 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1359 | |
1360 | // Combine: |
1361 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1362 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1363 | // into: |
1364 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1365 | |
1366 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) { |
1367 | if (MI.getParent() == UseMI.getParent() && |
1368 | ((IsDiv && UseMI.getOpcode() == RemOpcode) || |
1369 | (!IsDiv && UseMI.getOpcode() == DivOpcode)) && |
1370 | matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) && |
1371 | matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) { |
1372 | OtherMI = &UseMI; |
1373 | return true; |
1374 | } |
1375 | } |
1376 | |
1377 | return false; |
1378 | } |
1379 | |
1380 | void CombinerHelper::applyCombineDivRem(MachineInstr &MI, |
1381 | MachineInstr *&OtherMI) { |
1382 | unsigned Opcode = MI.getOpcode(); |
1383 | assert(OtherMI && "OtherMI shouldn't be empty." ); |
1384 | |
1385 | Register DestDivReg, DestRemReg; |
1386 | if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) { |
1387 | DestDivReg = MI.getOperand(i: 0).getReg(); |
1388 | DestRemReg = OtherMI->getOperand(i: 0).getReg(); |
1389 | } else { |
1390 | DestDivReg = OtherMI->getOperand(i: 0).getReg(); |
1391 | DestRemReg = MI.getOperand(i: 0).getReg(); |
1392 | } |
1393 | |
1394 | bool IsSigned = |
1395 | Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM; |
1396 | |
1397 | // Check which instruction is first in the block so we don't break def-use |
1398 | // deps by "moving" the instruction incorrectly. Also keep track of which |
1399 | // instruction is first so we pick it's operands, avoiding use-before-def |
1400 | // bugs. |
1401 | MachineInstr *FirstInst; |
1402 | if (dominates(DefMI: MI, UseMI: *OtherMI)) { |
1403 | Builder.setInstrAndDebugLoc(MI); |
1404 | FirstInst = &MI; |
1405 | } else { |
1406 | Builder.setInstrAndDebugLoc(*OtherMI); |
1407 | FirstInst = OtherMI; |
1408 | } |
1409 | |
1410 | Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM |
1411 | : TargetOpcode::G_UDIVREM, |
1412 | DstOps: {DestDivReg, DestRemReg}, |
1413 | SrcOps: { FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2) }); |
1414 | MI.eraseFromParent(); |
1415 | OtherMI->eraseFromParent(); |
1416 | } |
1417 | |
1418 | bool CombinerHelper::matchOptBrCondByInvertingCond(MachineInstr &MI, |
1419 | MachineInstr *&BrCond) { |
1420 | assert(MI.getOpcode() == TargetOpcode::G_BR); |
1421 | |
1422 | // Try to match the following: |
1423 | // bb1: |
1424 | // G_BRCOND %c1, %bb2 |
1425 | // G_BR %bb3 |
1426 | // bb2: |
1427 | // ... |
1428 | // bb3: |
1429 | |
1430 | // The above pattern does not have a fall through to the successor bb2, always |
1431 | // resulting in a branch no matter which path is taken. Here we try to find |
1432 | // and replace that pattern with conditional branch to bb3 and otherwise |
1433 | // fallthrough to bb2. This is generally better for branch predictors. |
1434 | |
1435 | MachineBasicBlock *MBB = MI.getParent(); |
1436 | MachineBasicBlock::iterator BrIt(MI); |
1437 | if (BrIt == MBB->begin()) |
1438 | return false; |
1439 | assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator" ); |
1440 | |
1441 | BrCond = &*std::prev(x: BrIt); |
1442 | if (BrCond->getOpcode() != TargetOpcode::G_BRCOND) |
1443 | return false; |
1444 | |
1445 | // Check that the next block is the conditional branch target. Also make sure |
1446 | // that it isn't the same as the G_BR's target (otherwise, this will loop.) |
1447 | MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB(); |
1448 | return BrCondTarget != MI.getOperand(i: 0).getMBB() && |
1449 | MBB->isLayoutSuccessor(MBB: BrCondTarget); |
1450 | } |
1451 | |
1452 | void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI, |
1453 | MachineInstr *&BrCond) { |
1454 | MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB(); |
1455 | Builder.setInstrAndDebugLoc(*BrCond); |
1456 | LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg()); |
1457 | // FIXME: Does int/fp matter for this? If so, we might need to restrict |
1458 | // this to i1 only since we might not know for sure what kind of |
1459 | // compare generated the condition value. |
1460 | auto True = Builder.buildConstant( |
1461 | Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false)); |
1462 | auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True); |
1463 | |
1464 | auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB(); |
1465 | Observer.changingInstr(MI); |
1466 | MI.getOperand(i: 0).setMBB(FallthroughBB); |
1467 | Observer.changedInstr(MI); |
1468 | |
1469 | // Change the conditional branch to use the inverted condition and |
1470 | // new target block. |
1471 | Observer.changingInstr(MI&: *BrCond); |
1472 | BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0)); |
1473 | BrCond->getOperand(i: 1).setMBB(BrTarget); |
1474 | Observer.changedInstr(MI&: *BrCond); |
1475 | } |
1476 | |
1477 | |
1478 | bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) { |
1479 | MachineIRBuilder HelperBuilder(MI); |
1480 | GISelObserverWrapper DummyObserver; |
1481 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1482 | return Helper.lowerMemcpyInline(MI) == |
1483 | LegalizerHelper::LegalizeResult::Legalized; |
1484 | } |
1485 | |
1486 | bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen) { |
1487 | MachineIRBuilder HelperBuilder(MI); |
1488 | GISelObserverWrapper DummyObserver; |
1489 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1490 | return Helper.lowerMemCpyFamily(MI, MaxLen) == |
1491 | LegalizerHelper::LegalizeResult::Legalized; |
1492 | } |
1493 | |
1494 | static APFloat constantFoldFpUnary(const MachineInstr &MI, |
1495 | const MachineRegisterInfo &MRI, |
1496 | const APFloat &Val) { |
1497 | APFloat Result(Val); |
1498 | switch (MI.getOpcode()) { |
1499 | default: |
1500 | llvm_unreachable("Unexpected opcode!" ); |
1501 | case TargetOpcode::G_FNEG: { |
1502 | Result.changeSign(); |
1503 | return Result; |
1504 | } |
1505 | case TargetOpcode::G_FABS: { |
1506 | Result.clearSign(); |
1507 | return Result; |
1508 | } |
1509 | case TargetOpcode::G_FPTRUNC: { |
1510 | bool Unused; |
1511 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1512 | Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven, |
1513 | losesInfo: &Unused); |
1514 | return Result; |
1515 | } |
1516 | case TargetOpcode::G_FSQRT: { |
1517 | bool Unused; |
1518 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1519 | losesInfo: &Unused); |
1520 | Result = APFloat(sqrt(x: Result.convertToDouble())); |
1521 | break; |
1522 | } |
1523 | case TargetOpcode::G_FLOG2: { |
1524 | bool Unused; |
1525 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1526 | losesInfo: &Unused); |
1527 | Result = APFloat(log2(x: Result.convertToDouble())); |
1528 | break; |
1529 | } |
1530 | } |
1531 | // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise, |
1532 | // `buildFConstant` will assert on size mismatch. Only `G_FSQRT`, and |
1533 | // `G_FLOG2` reach here. |
1534 | bool Unused; |
1535 | Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused); |
1536 | return Result; |
1537 | } |
1538 | |
1539 | void CombinerHelper::applyCombineConstantFoldFpUnary(MachineInstr &MI, |
1540 | const ConstantFP *Cst) { |
1541 | Builder.setInstrAndDebugLoc(MI); |
1542 | APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue()); |
1543 | const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded); |
1544 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst); |
1545 | MI.eraseFromParent(); |
1546 | } |
1547 | |
1548 | bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI, |
1549 | PtrAddChain &MatchInfo) { |
1550 | // We're trying to match the following pattern: |
1551 | // %t1 = G_PTR_ADD %base, G_CONSTANT imm1 |
1552 | // %root = G_PTR_ADD %t1, G_CONSTANT imm2 |
1553 | // --> |
1554 | // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2) |
1555 | |
1556 | if (MI.getOpcode() != TargetOpcode::G_PTR_ADD) |
1557 | return false; |
1558 | |
1559 | Register Add2 = MI.getOperand(i: 1).getReg(); |
1560 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1561 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1562 | if (!MaybeImmVal) |
1563 | return false; |
1564 | |
1565 | MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2); |
1566 | if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD) |
1567 | return false; |
1568 | |
1569 | Register Base = Add2Def->getOperand(i: 1).getReg(); |
1570 | Register Imm2 = Add2Def->getOperand(i: 2).getReg(); |
1571 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1572 | if (!MaybeImm2Val) |
1573 | return false; |
1574 | |
1575 | // Check if the new combined immediate forms an illegal addressing mode. |
1576 | // Do not combine if it was legal before but would get illegal. |
1577 | // To do so, we need to find a load/store user of the pointer to get |
1578 | // the access type. |
1579 | Type *AccessTy = nullptr; |
1580 | auto &MF = *MI.getMF(); |
1581 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) { |
1582 | if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) { |
1583 | AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)), |
1584 | C&: MF.getFunction().getContext()); |
1585 | break; |
1586 | } |
1587 | } |
1588 | TargetLoweringBase::AddrMode AMNew; |
1589 | APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value; |
1590 | AMNew.BaseOffs = CombinedImm.getSExtValue(); |
1591 | if (AccessTy) { |
1592 | AMNew.HasBaseReg = true; |
1593 | TargetLoweringBase::AddrMode AMOld; |
1594 | AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue(); |
1595 | AMOld.HasBaseReg = true; |
1596 | unsigned AS = MRI.getType(Reg: Add2).getAddressSpace(); |
1597 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1598 | if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) && |
1599 | !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS)) |
1600 | return false; |
1601 | } |
1602 | |
1603 | // Pass the combined immediate to the apply function. |
1604 | MatchInfo.Imm = AMNew.BaseOffs; |
1605 | MatchInfo.Base = Base; |
1606 | MatchInfo.Bank = getRegBank(Reg: Imm2); |
1607 | return true; |
1608 | } |
1609 | |
1610 | void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI, |
1611 | PtrAddChain &MatchInfo) { |
1612 | assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD" ); |
1613 | MachineIRBuilder MIB(MI); |
1614 | LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1615 | auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm); |
1616 | setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank); |
1617 | Observer.changingInstr(MI); |
1618 | MI.getOperand(i: 1).setReg(MatchInfo.Base); |
1619 | MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0)); |
1620 | Observer.changedInstr(MI); |
1621 | } |
1622 | |
1623 | bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI, |
1624 | RegisterImmPair &MatchInfo) { |
1625 | // We're trying to match the following pattern with any of |
1626 | // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions: |
1627 | // %t1 = SHIFT %base, G_CONSTANT imm1 |
1628 | // %root = SHIFT %t1, G_CONSTANT imm2 |
1629 | // --> |
1630 | // %root = SHIFT %base, G_CONSTANT (imm1 + imm2) |
1631 | |
1632 | unsigned Opcode = MI.getOpcode(); |
1633 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1634 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1635 | Opcode == TargetOpcode::G_USHLSAT) && |
1636 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1637 | |
1638 | Register Shl2 = MI.getOperand(i: 1).getReg(); |
1639 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1640 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1641 | if (!MaybeImmVal) |
1642 | return false; |
1643 | |
1644 | MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2); |
1645 | if (Shl2Def->getOpcode() != Opcode) |
1646 | return false; |
1647 | |
1648 | Register Base = Shl2Def->getOperand(i: 1).getReg(); |
1649 | Register Imm2 = Shl2Def->getOperand(i: 2).getReg(); |
1650 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1651 | if (!MaybeImm2Val) |
1652 | return false; |
1653 | |
1654 | // Pass the combined immediate to the apply function. |
1655 | MatchInfo.Imm = |
1656 | (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue(); |
1657 | MatchInfo.Reg = Base; |
1658 | |
1659 | // There is no simple replacement for a saturating unsigned left shift that |
1660 | // exceeds the scalar size. |
1661 | if (Opcode == TargetOpcode::G_USHLSAT && |
1662 | MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits()) |
1663 | return false; |
1664 | |
1665 | return true; |
1666 | } |
1667 | |
1668 | void CombinerHelper::applyShiftImmedChain(MachineInstr &MI, |
1669 | RegisterImmPair &MatchInfo) { |
1670 | unsigned Opcode = MI.getOpcode(); |
1671 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1672 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1673 | Opcode == TargetOpcode::G_USHLSAT) && |
1674 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1675 | |
1676 | Builder.setInstrAndDebugLoc(MI); |
1677 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
1678 | unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits(); |
1679 | auto Imm = MatchInfo.Imm; |
1680 | |
1681 | if (Imm >= ScalarSizeInBits) { |
1682 | // Any logical shift that exceeds scalar size will produce zero. |
1683 | if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) { |
1684 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0); |
1685 | MI.eraseFromParent(); |
1686 | return; |
1687 | } |
1688 | // Arithmetic shift and saturating signed left shift have no effect beyond |
1689 | // scalar size. |
1690 | Imm = ScalarSizeInBits - 1; |
1691 | } |
1692 | |
1693 | LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1694 | Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0); |
1695 | Observer.changingInstr(MI); |
1696 | MI.getOperand(i: 1).setReg(MatchInfo.Reg); |
1697 | MI.getOperand(i: 2).setReg(NewImm); |
1698 | Observer.changedInstr(MI); |
1699 | } |
1700 | |
1701 | bool CombinerHelper::matchShiftOfShiftedLogic(MachineInstr &MI, |
1702 | ShiftOfShiftedLogic &MatchInfo) { |
1703 | // We're trying to match the following pattern with any of |
1704 | // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination |
1705 | // with any of G_AND/G_OR/G_XOR logic instructions. |
1706 | // %t1 = SHIFT %X, G_CONSTANT C0 |
1707 | // %t2 = LOGIC %t1, %Y |
1708 | // %root = SHIFT %t2, G_CONSTANT C1 |
1709 | // --> |
1710 | // %t3 = SHIFT %X, G_CONSTANT (C0+C1) |
1711 | // %t4 = SHIFT %Y, G_CONSTANT C1 |
1712 | // %root = LOGIC %t3, %t4 |
1713 | unsigned ShiftOpcode = MI.getOpcode(); |
1714 | assert((ShiftOpcode == TargetOpcode::G_SHL || |
1715 | ShiftOpcode == TargetOpcode::G_ASHR || |
1716 | ShiftOpcode == TargetOpcode::G_LSHR || |
1717 | ShiftOpcode == TargetOpcode::G_USHLSAT || |
1718 | ShiftOpcode == TargetOpcode::G_SSHLSAT) && |
1719 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1720 | |
1721 | // Match a one-use bitwise logic op. |
1722 | Register LogicDest = MI.getOperand(i: 1).getReg(); |
1723 | if (!MRI.hasOneNonDBGUse(RegNo: LogicDest)) |
1724 | return false; |
1725 | |
1726 | MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest); |
1727 | unsigned LogicOpcode = LogicMI->getOpcode(); |
1728 | if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR && |
1729 | LogicOpcode != TargetOpcode::G_XOR) |
1730 | return false; |
1731 | |
1732 | // Find a matching one-use shift by constant. |
1733 | const Register C1 = MI.getOperand(i: 2).getReg(); |
1734 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI); |
1735 | if (!MaybeImmVal || MaybeImmVal->Value == 0) |
1736 | return false; |
1737 | |
1738 | const uint64_t C1Val = MaybeImmVal->Value.getZExtValue(); |
1739 | |
1740 | auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) { |
1741 | // Shift should match previous one and should be a one-use. |
1742 | if (MI->getOpcode() != ShiftOpcode || |
1743 | !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg())) |
1744 | return false; |
1745 | |
1746 | // Must be a constant. |
1747 | auto MaybeImmVal = |
1748 | getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI); |
1749 | if (!MaybeImmVal) |
1750 | return false; |
1751 | |
1752 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
1753 | return true; |
1754 | }; |
1755 | |
1756 | // Logic ops are commutative, so check each operand for a match. |
1757 | Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg(); |
1758 | MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1); |
1759 | Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg(); |
1760 | MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2); |
1761 | uint64_t C0Val; |
1762 | |
1763 | if (matchFirstShift(LogicMIOp1, C0Val)) { |
1764 | MatchInfo.LogicNonShiftReg = LogicMIReg2; |
1765 | MatchInfo.Shift2 = LogicMIOp1; |
1766 | } else if (matchFirstShift(LogicMIOp2, C0Val)) { |
1767 | MatchInfo.LogicNonShiftReg = LogicMIReg1; |
1768 | MatchInfo.Shift2 = LogicMIOp2; |
1769 | } else |
1770 | return false; |
1771 | |
1772 | MatchInfo.ValSum = C0Val + C1Val; |
1773 | |
1774 | // The fold is not valid if the sum of the shift values exceeds bitwidth. |
1775 | if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits()) |
1776 | return false; |
1777 | |
1778 | MatchInfo.Logic = LogicMI; |
1779 | return true; |
1780 | } |
1781 | |
1782 | void CombinerHelper::applyShiftOfShiftedLogic(MachineInstr &MI, |
1783 | ShiftOfShiftedLogic &MatchInfo) { |
1784 | unsigned Opcode = MI.getOpcode(); |
1785 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1786 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT || |
1787 | Opcode == TargetOpcode::G_SSHLSAT) && |
1788 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1789 | |
1790 | LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1791 | LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1792 | Builder.setInstrAndDebugLoc(MI); |
1793 | |
1794 | Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0); |
1795 | |
1796 | Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg(); |
1797 | Register Shift1 = |
1798 | Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0); |
1799 | |
1800 | // If LogicNonShiftReg is the same to Shift1Base, and shift1 const is the same |
1801 | // to MatchInfo.Shift2 const, CSEMIRBuilder will reuse the old shift1 when |
1802 | // build shift2. So, if we erase MatchInfo.Shift2 at the end, actually we |
1803 | // remove old shift1. And it will cause crash later. So erase it earlier to |
1804 | // avoid the crash. |
1805 | MatchInfo.Shift2->eraseFromParent(); |
1806 | |
1807 | Register Shift2Const = MI.getOperand(i: 2).getReg(); |
1808 | Register Shift2 = Builder |
1809 | .buildInstr(Opc: Opcode, DstOps: {DestType}, |
1810 | SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const}) |
1811 | .getReg(Idx: 0); |
1812 | |
1813 | Register Dest = MI.getOperand(i: 0).getReg(); |
1814 | Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2}); |
1815 | |
1816 | // This was one use so it's safe to remove it. |
1817 | MatchInfo.Logic->eraseFromParent(); |
1818 | |
1819 | MI.eraseFromParent(); |
1820 | } |
1821 | |
1822 | bool CombinerHelper::matchCommuteShift(MachineInstr &MI, BuildFnTy &MatchInfo) { |
1823 | assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL" ); |
1824 | // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2) |
1825 | // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2) |
1826 | auto &Shl = cast<GenericMachineInstr>(Val&: MI); |
1827 | Register DstReg = Shl.getReg(Idx: 0); |
1828 | Register SrcReg = Shl.getReg(Idx: 1); |
1829 | Register ShiftReg = Shl.getReg(Idx: 2); |
1830 | Register X, C1; |
1831 | |
1832 | if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize())) |
1833 | return false; |
1834 | |
1835 | if (!mi_match(R: SrcReg, MRI, |
1836 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)), |
1837 | preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1)))))) |
1838 | return false; |
1839 | |
1840 | APInt C1Val, C2Val; |
1841 | if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) || |
1842 | !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val))) |
1843 | return false; |
1844 | |
1845 | auto *SrcDef = MRI.getVRegDef(Reg: SrcReg); |
1846 | assert((SrcDef->getOpcode() == TargetOpcode::G_ADD || |
1847 | SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op" ); |
1848 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
1849 | MatchInfo = [=](MachineIRBuilder &B) { |
1850 | auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg); |
1851 | auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg); |
1852 | B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2}); |
1853 | }; |
1854 | return true; |
1855 | } |
1856 | |
1857 | bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI, |
1858 | unsigned &ShiftVal) { |
1859 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
1860 | auto MaybeImmVal = |
1861 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
1862 | if (!MaybeImmVal) |
1863 | return false; |
1864 | |
1865 | ShiftVal = MaybeImmVal->Value.exactLogBase2(); |
1866 | return (static_cast<int32_t>(ShiftVal) != -1); |
1867 | } |
1868 | |
1869 | void CombinerHelper::applyCombineMulToShl(MachineInstr &MI, |
1870 | unsigned &ShiftVal) { |
1871 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
1872 | MachineIRBuilder MIB(MI); |
1873 | LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1874 | auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal); |
1875 | Observer.changingInstr(MI); |
1876 | MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL)); |
1877 | MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0)); |
1878 | Observer.changedInstr(MI); |
1879 | } |
1880 | |
1881 | // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source |
1882 | bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI, |
1883 | RegisterImmPair &MatchData) { |
1884 | assert(MI.getOpcode() == TargetOpcode::G_SHL && KB); |
1885 | if (!getTargetLowering().isDesirableToPullExtFromShl(MI)) |
1886 | return false; |
1887 | |
1888 | Register LHS = MI.getOperand(i: 1).getReg(); |
1889 | |
1890 | Register ExtSrc; |
1891 | if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) && |
1892 | !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) && |
1893 | !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc)))) |
1894 | return false; |
1895 | |
1896 | Register RHS = MI.getOperand(i: 2).getReg(); |
1897 | MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS); |
1898 | auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI); |
1899 | if (!MaybeShiftAmtVal) |
1900 | return false; |
1901 | |
1902 | if (LI) { |
1903 | LLT SrcTy = MRI.getType(Reg: ExtSrc); |
1904 | |
1905 | // We only really care about the legality with the shifted value. We can |
1906 | // pick any type the constant shift amount, so ask the target what to |
1907 | // use. Otherwise we would have to guess and hope it is reported as legal. |
1908 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy); |
1909 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}})) |
1910 | return false; |
1911 | } |
1912 | |
1913 | int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue(); |
1914 | MatchData.Reg = ExtSrc; |
1915 | MatchData.Imm = ShiftAmt; |
1916 | |
1917 | unsigned MinLeadingZeros = KB->getKnownZeroes(R: ExtSrc).countl_one(); |
1918 | unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits(); |
1919 | return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize; |
1920 | } |
1921 | |
1922 | void CombinerHelper::applyCombineShlOfExtend(MachineInstr &MI, |
1923 | const RegisterImmPair &MatchData) { |
1924 | Register ExtSrcReg = MatchData.Reg; |
1925 | int64_t ShiftAmtVal = MatchData.Imm; |
1926 | |
1927 | LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg); |
1928 | Builder.setInstrAndDebugLoc(MI); |
1929 | auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal); |
1930 | auto NarrowShift = |
1931 | Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags()); |
1932 | Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift); |
1933 | MI.eraseFromParent(); |
1934 | } |
1935 | |
1936 | bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI, |
1937 | Register &MatchInfo) { |
1938 | GMerge &Merge = cast<GMerge>(Val&: MI); |
1939 | SmallVector<Register, 16> MergedValues; |
1940 | for (unsigned I = 0; I < Merge.getNumSources(); ++I) |
1941 | MergedValues.emplace_back(Args: Merge.getSourceReg(I)); |
1942 | |
1943 | auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI); |
1944 | if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources()) |
1945 | return false; |
1946 | |
1947 | for (unsigned I = 0; I < MergedValues.size(); ++I) |
1948 | if (MergedValues[I] != Unmerge->getReg(Idx: I)) |
1949 | return false; |
1950 | |
1951 | MatchInfo = Unmerge->getSourceReg(); |
1952 | return true; |
1953 | } |
1954 | |
1955 | static Register peekThroughBitcast(Register Reg, |
1956 | const MachineRegisterInfo &MRI) { |
1957 | while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg)))) |
1958 | ; |
1959 | |
1960 | return Reg; |
1961 | } |
1962 | |
1963 | bool CombinerHelper::matchCombineUnmergeMergeToPlainValues( |
1964 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) { |
1965 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
1966 | "Expected an unmerge" ); |
1967 | auto &Unmerge = cast<GUnmerge>(Val&: MI); |
1968 | Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI); |
1969 | |
1970 | auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI); |
1971 | if (!SrcInstr) |
1972 | return false; |
1973 | |
1974 | // Check the source type of the merge. |
1975 | LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0)); |
1976 | LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0)); |
1977 | bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits(); |
1978 | if (SrcMergeTy != Dst0Ty && !SameSize) |
1979 | return false; |
1980 | // They are the same now (modulo a bitcast). |
1981 | // We can collect all the src registers. |
1982 | for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx) |
1983 | Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx)); |
1984 | return true; |
1985 | } |
1986 | |
1987 | void CombinerHelper::applyCombineUnmergeMergeToPlainValues( |
1988 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) { |
1989 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
1990 | "Expected an unmerge" ); |
1991 | assert((MI.getNumOperands() - 1 == Operands.size()) && |
1992 | "Not enough operands to replace all defs" ); |
1993 | unsigned NumElems = MI.getNumOperands() - 1; |
1994 | |
1995 | LLT SrcTy = MRI.getType(Reg: Operands[0]); |
1996 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1997 | bool CanReuseInputDirectly = DstTy == SrcTy; |
1998 | Builder.setInstrAndDebugLoc(MI); |
1999 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2000 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2001 | Register SrcReg = Operands[Idx]; |
2002 | |
2003 | // This combine may run after RegBankSelect, so we need to be aware of |
2004 | // register banks. |
2005 | const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg); |
2006 | if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) { |
2007 | SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0); |
2008 | MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB); |
2009 | } |
2010 | |
2011 | if (CanReuseInputDirectly) |
2012 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2013 | else |
2014 | Builder.buildCast(Dst: DstReg, Src: SrcReg); |
2015 | } |
2016 | MI.eraseFromParent(); |
2017 | } |
2018 | |
2019 | bool CombinerHelper::matchCombineUnmergeConstant(MachineInstr &MI, |
2020 | SmallVectorImpl<APInt> &Csts) { |
2021 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2022 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2023 | MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg); |
2024 | if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT && |
2025 | SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT) |
2026 | return false; |
2027 | // Break down the big constant in smaller ones. |
2028 | const MachineOperand &CstVal = SrcInstr->getOperand(i: 1); |
2029 | APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT |
2030 | ? CstVal.getCImm()->getValue() |
2031 | : CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); |
2032 | |
2033 | LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2034 | unsigned ShiftAmt = Dst0Ty.getSizeInBits(); |
2035 | // Unmerge a constant. |
2036 | for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) { |
2037 | Csts.emplace_back(Args: Val.trunc(width: ShiftAmt)); |
2038 | Val = Val.lshr(shiftAmt: ShiftAmt); |
2039 | } |
2040 | |
2041 | return true; |
2042 | } |
2043 | |
2044 | void CombinerHelper::applyCombineUnmergeConstant(MachineInstr &MI, |
2045 | SmallVectorImpl<APInt> &Csts) { |
2046 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2047 | "Expected an unmerge" ); |
2048 | assert((MI.getNumOperands() - 1 == Csts.size()) && |
2049 | "Not enough operands to replace all defs" ); |
2050 | unsigned NumElems = MI.getNumOperands() - 1; |
2051 | Builder.setInstrAndDebugLoc(MI); |
2052 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2053 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2054 | Builder.buildConstant(Res: DstReg, Val: Csts[Idx]); |
2055 | } |
2056 | |
2057 | MI.eraseFromParent(); |
2058 | } |
2059 | |
2060 | bool CombinerHelper::matchCombineUnmergeUndef( |
2061 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
2062 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2063 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2064 | MatchInfo = [&MI](MachineIRBuilder &B) { |
2065 | unsigned NumElems = MI.getNumOperands() - 1; |
2066 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2067 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2068 | B.buildUndef(Res: DstReg); |
2069 | } |
2070 | }; |
2071 | return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg)); |
2072 | } |
2073 | |
2074 | bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) { |
2075 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2076 | "Expected an unmerge" ); |
2077 | // Check that all the lanes are dead except the first one. |
2078 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2079 | if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg())) |
2080 | return false; |
2081 | } |
2082 | return true; |
2083 | } |
2084 | |
2085 | void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) { |
2086 | Builder.setInstrAndDebugLoc(MI); |
2087 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2088 | // Truncating a vector is going to truncate every single lane, |
2089 | // whereas we want the full lowbits. |
2090 | // Do the operation on a scalar instead. |
2091 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2092 | if (SrcTy.isVector()) |
2093 | SrcReg = |
2094 | Builder.buildCast(Dst: LLT::scalar(SizeInBits: SrcTy.getSizeInBits()), Src: SrcReg).getReg(Idx: 0); |
2095 | |
2096 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2097 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2098 | if (Dst0Ty.isVector()) { |
2099 | auto MIB = Builder.buildTrunc(Res: LLT::scalar(SizeInBits: Dst0Ty.getSizeInBits()), Op: SrcReg); |
2100 | Builder.buildCast(Dst: Dst0Reg, Src: MIB); |
2101 | } else |
2102 | Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg); |
2103 | MI.eraseFromParent(); |
2104 | } |
2105 | |
2106 | bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) { |
2107 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2108 | "Expected an unmerge" ); |
2109 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2110 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2111 | // G_ZEXT on vector applies to each lane, so it will |
2112 | // affect all destinations. Therefore we won't be able |
2113 | // to simplify the unmerge to just the first definition. |
2114 | if (Dst0Ty.isVector()) |
2115 | return false; |
2116 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2117 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2118 | if (SrcTy.isVector()) |
2119 | return false; |
2120 | |
2121 | Register ZExtSrcReg; |
2122 | if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg)))) |
2123 | return false; |
2124 | |
2125 | // Finally we can replace the first definition with |
2126 | // a zext of the source if the definition is big enough to hold |
2127 | // all of ZExtSrc bits. |
2128 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2129 | return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits(); |
2130 | } |
2131 | |
2132 | void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) { |
2133 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2134 | "Expected an unmerge" ); |
2135 | |
2136 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2137 | |
2138 | MachineInstr *ZExtInstr = |
2139 | MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()); |
2140 | assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT && |
2141 | "Expecting a G_ZEXT" ); |
2142 | |
2143 | Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg(); |
2144 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2145 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2146 | |
2147 | Builder.setInstrAndDebugLoc(MI); |
2148 | |
2149 | if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) { |
2150 | Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg); |
2151 | } else { |
2152 | assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() && |
2153 | "ZExt src doesn't fit in destination" ); |
2154 | replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg); |
2155 | } |
2156 | |
2157 | Register ZeroReg; |
2158 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2159 | if (!ZeroReg) |
2160 | ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0); |
2161 | replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg); |
2162 | } |
2163 | MI.eraseFromParent(); |
2164 | } |
2165 | |
2166 | bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI, |
2167 | unsigned TargetShiftSize, |
2168 | unsigned &ShiftVal) { |
2169 | assert((MI.getOpcode() == TargetOpcode::G_SHL || |
2170 | MI.getOpcode() == TargetOpcode::G_LSHR || |
2171 | MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift" ); |
2172 | |
2173 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2174 | if (Ty.isVector()) // TODO: |
2175 | return false; |
2176 | |
2177 | // Don't narrow further than the requested size. |
2178 | unsigned Size = Ty.getSizeInBits(); |
2179 | if (Size <= TargetShiftSize) |
2180 | return false; |
2181 | |
2182 | auto MaybeImmVal = |
2183 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
2184 | if (!MaybeImmVal) |
2185 | return false; |
2186 | |
2187 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
2188 | return ShiftVal >= Size / 2 && ShiftVal < Size; |
2189 | } |
2190 | |
2191 | void CombinerHelper::applyCombineShiftToUnmerge(MachineInstr &MI, |
2192 | const unsigned &ShiftVal) { |
2193 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2194 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2195 | LLT Ty = MRI.getType(Reg: SrcReg); |
2196 | unsigned Size = Ty.getSizeInBits(); |
2197 | unsigned HalfSize = Size / 2; |
2198 | assert(ShiftVal >= HalfSize); |
2199 | |
2200 | LLT HalfTy = LLT::scalar(SizeInBits: HalfSize); |
2201 | |
2202 | Builder.setInstr(MI); |
2203 | auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg); |
2204 | unsigned NarrowShiftAmt = ShiftVal - HalfSize; |
2205 | |
2206 | if (MI.getOpcode() == TargetOpcode::G_LSHR) { |
2207 | Register Narrowed = Unmerge.getReg(Idx: 1); |
2208 | |
2209 | // dst = G_LSHR s64:x, C for C >= 32 |
2210 | // => |
2211 | // lo, hi = G_UNMERGE_VALUES x |
2212 | // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0 |
2213 | |
2214 | if (NarrowShiftAmt != 0) { |
2215 | Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed, |
2216 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2217 | } |
2218 | |
2219 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2220 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero}); |
2221 | } else if (MI.getOpcode() == TargetOpcode::G_SHL) { |
2222 | Register Narrowed = Unmerge.getReg(Idx: 0); |
2223 | // dst = G_SHL s64:x, C for C >= 32 |
2224 | // => |
2225 | // lo, hi = G_UNMERGE_VALUES x |
2226 | // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32) |
2227 | if (NarrowShiftAmt != 0) { |
2228 | Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed, |
2229 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2230 | } |
2231 | |
2232 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2233 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed}); |
2234 | } else { |
2235 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
2236 | auto Hi = Builder.buildAShr( |
2237 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2238 | Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1)); |
2239 | |
2240 | if (ShiftVal == HalfSize) { |
2241 | // (G_ASHR i64:x, 32) -> |
2242 | // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31) |
2243 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi}); |
2244 | } else if (ShiftVal == Size - 1) { |
2245 | // Don't need a second shift. |
2246 | // (G_ASHR i64:x, 63) -> |
2247 | // %narrowed = (G_ASHR hi_32(x), 31) |
2248 | // G_MERGE_VALUES %narrowed, %narrowed |
2249 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi}); |
2250 | } else { |
2251 | auto Lo = Builder.buildAShr( |
2252 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2253 | Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize)); |
2254 | |
2255 | // (G_ASHR i64:x, C) ->, for C >= 32 |
2256 | // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31) |
2257 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi}); |
2258 | } |
2259 | } |
2260 | |
2261 | MI.eraseFromParent(); |
2262 | } |
2263 | |
2264 | bool CombinerHelper::tryCombineShiftToUnmerge(MachineInstr &MI, |
2265 | unsigned TargetShiftAmount) { |
2266 | unsigned ShiftAmt; |
2267 | if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) { |
2268 | applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt); |
2269 | return true; |
2270 | } |
2271 | |
2272 | return false; |
2273 | } |
2274 | |
2275 | bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, Register &Reg) { |
2276 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2277 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2278 | LLT DstTy = MRI.getType(Reg: DstReg); |
2279 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2280 | return mi_match(R: SrcReg, MRI, |
2281 | P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg)))); |
2282 | } |
2283 | |
2284 | void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, Register &Reg) { |
2285 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2286 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2287 | Builder.setInstr(MI); |
2288 | Builder.buildCopy(Res: DstReg, Op: Reg); |
2289 | MI.eraseFromParent(); |
2290 | } |
2291 | |
2292 | void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, Register &Reg) { |
2293 | assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT" ); |
2294 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2295 | Builder.setInstr(MI); |
2296 | Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg); |
2297 | MI.eraseFromParent(); |
2298 | } |
2299 | |
2300 | bool CombinerHelper::matchCombineAddP2IToPtrAdd( |
2301 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) { |
2302 | assert(MI.getOpcode() == TargetOpcode::G_ADD); |
2303 | Register LHS = MI.getOperand(i: 1).getReg(); |
2304 | Register RHS = MI.getOperand(i: 2).getReg(); |
2305 | LLT IntTy = MRI.getType(Reg: LHS); |
2306 | |
2307 | // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the |
2308 | // instruction. |
2309 | PtrReg.second = false; |
2310 | for (Register SrcReg : {LHS, RHS}) { |
2311 | if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) { |
2312 | // Don't handle cases where the integer is implicitly converted to the |
2313 | // pointer width. |
2314 | LLT PtrTy = MRI.getType(Reg: PtrReg.first); |
2315 | if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits()) |
2316 | return true; |
2317 | } |
2318 | |
2319 | PtrReg.second = true; |
2320 | } |
2321 | |
2322 | return false; |
2323 | } |
2324 | |
2325 | void CombinerHelper::applyCombineAddP2IToPtrAdd( |
2326 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) { |
2327 | Register Dst = MI.getOperand(i: 0).getReg(); |
2328 | Register LHS = MI.getOperand(i: 1).getReg(); |
2329 | Register RHS = MI.getOperand(i: 2).getReg(); |
2330 | |
2331 | const bool DoCommute = PtrReg.second; |
2332 | if (DoCommute) |
2333 | std::swap(a&: LHS, b&: RHS); |
2334 | LHS = PtrReg.first; |
2335 | |
2336 | LLT PtrTy = MRI.getType(Reg: LHS); |
2337 | |
2338 | Builder.setInstrAndDebugLoc(MI); |
2339 | auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS); |
2340 | Builder.buildPtrToInt(Dst, Src: PtrAdd); |
2341 | MI.eraseFromParent(); |
2342 | } |
2343 | |
2344 | bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI, |
2345 | APInt &NewCst) { |
2346 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2347 | Register LHS = PtrAdd.getBaseReg(); |
2348 | Register RHS = PtrAdd.getOffsetReg(); |
2349 | MachineRegisterInfo &MRI = Builder.getMF().getRegInfo(); |
2350 | |
2351 | if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) { |
2352 | APInt Cst; |
2353 | if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) { |
2354 | auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0)); |
2355 | // G_INTTOPTR uses zero-extension |
2356 | NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits()); |
2357 | NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits()); |
2358 | return true; |
2359 | } |
2360 | } |
2361 | |
2362 | return false; |
2363 | } |
2364 | |
2365 | void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI, |
2366 | APInt &NewCst) { |
2367 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2368 | Register Dst = PtrAdd.getReg(Idx: 0); |
2369 | |
2370 | Builder.setInstrAndDebugLoc(MI); |
2371 | Builder.buildConstant(Res: Dst, Val: NewCst); |
2372 | PtrAdd.eraseFromParent(); |
2373 | } |
2374 | |
2375 | bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, Register &Reg) { |
2376 | assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT" ); |
2377 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2378 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2379 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2380 | if (OriginalSrcReg.isValid()) |
2381 | SrcReg = OriginalSrcReg; |
2382 | LLT DstTy = MRI.getType(Reg: DstReg); |
2383 | return mi_match(R: SrcReg, MRI, |
2384 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))); |
2385 | } |
2386 | |
2387 | bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, Register &Reg) { |
2388 | assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT" ); |
2389 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2390 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2391 | LLT DstTy = MRI.getType(Reg: DstReg); |
2392 | if (mi_match(R: SrcReg, MRI, |
2393 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy))))) { |
2394 | unsigned DstSize = DstTy.getScalarSizeInBits(); |
2395 | unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits(); |
2396 | return KB->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize; |
2397 | } |
2398 | return false; |
2399 | } |
2400 | |
2401 | bool CombinerHelper::matchCombineExtOfExt( |
2402 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
2403 | assert((MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2404 | MI.getOpcode() == TargetOpcode::G_SEXT || |
2405 | MI.getOpcode() == TargetOpcode::G_ZEXT) && |
2406 | "Expected a G_[ASZ]EXT" ); |
2407 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2408 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2409 | if (OriginalSrcReg.isValid()) |
2410 | SrcReg = OriginalSrcReg; |
2411 | MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
2412 | // Match exts with the same opcode, anyext([sz]ext) and sext(zext). |
2413 | unsigned Opc = MI.getOpcode(); |
2414 | unsigned SrcOpc = SrcMI->getOpcode(); |
2415 | if (Opc == SrcOpc || |
2416 | (Opc == TargetOpcode::G_ANYEXT && |
2417 | (SrcOpc == TargetOpcode::G_SEXT || SrcOpc == TargetOpcode::G_ZEXT)) || |
2418 | (Opc == TargetOpcode::G_SEXT && SrcOpc == TargetOpcode::G_ZEXT)) { |
2419 | MatchInfo = std::make_tuple(args: SrcMI->getOperand(i: 1).getReg(), args&: SrcOpc); |
2420 | return true; |
2421 | } |
2422 | return false; |
2423 | } |
2424 | |
2425 | void CombinerHelper::applyCombineExtOfExt( |
2426 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
2427 | assert((MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2428 | MI.getOpcode() == TargetOpcode::G_SEXT || |
2429 | MI.getOpcode() == TargetOpcode::G_ZEXT) && |
2430 | "Expected a G_[ASZ]EXT" ); |
2431 | |
2432 | Register Reg = std::get<0>(t&: MatchInfo); |
2433 | unsigned SrcExtOp = std::get<1>(t&: MatchInfo); |
2434 | |
2435 | // Combine exts with the same opcode. |
2436 | if (MI.getOpcode() == SrcExtOp) { |
2437 | Observer.changingInstr(MI); |
2438 | MI.getOperand(i: 1).setReg(Reg); |
2439 | Observer.changedInstr(MI); |
2440 | return; |
2441 | } |
2442 | |
2443 | // Combine: |
2444 | // - anyext([sz]ext x) to [sz]ext x |
2445 | // - sext(zext x) to zext x |
2446 | if (MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2447 | (MI.getOpcode() == TargetOpcode::G_SEXT && |
2448 | SrcExtOp == TargetOpcode::G_ZEXT)) { |
2449 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2450 | Builder.setInstrAndDebugLoc(MI); |
2451 | Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {Reg}); |
2452 | MI.eraseFromParent(); |
2453 | } |
2454 | } |
2455 | |
2456 | bool CombinerHelper::matchCombineTruncOfExt( |
2457 | MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) { |
2458 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2459 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2460 | MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
2461 | unsigned SrcOpc = SrcMI->getOpcode(); |
2462 | if (SrcOpc == TargetOpcode::G_ANYEXT || SrcOpc == TargetOpcode::G_SEXT || |
2463 | SrcOpc == TargetOpcode::G_ZEXT) { |
2464 | MatchInfo = std::make_pair(x: SrcMI->getOperand(i: 1).getReg(), y&: SrcOpc); |
2465 | return true; |
2466 | } |
2467 | return false; |
2468 | } |
2469 | |
2470 | void CombinerHelper::applyCombineTruncOfExt( |
2471 | MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) { |
2472 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2473 | Register SrcReg = MatchInfo.first; |
2474 | unsigned SrcExtOp = MatchInfo.second; |
2475 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2476 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2477 | LLT DstTy = MRI.getType(Reg: DstReg); |
2478 | if (SrcTy == DstTy) { |
2479 | MI.eraseFromParent(); |
2480 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2481 | return; |
2482 | } |
2483 | Builder.setInstrAndDebugLoc(MI); |
2484 | if (SrcTy.getSizeInBits() < DstTy.getSizeInBits()) |
2485 | Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {SrcReg}); |
2486 | else |
2487 | Builder.buildTrunc(Res: DstReg, Op: SrcReg); |
2488 | MI.eraseFromParent(); |
2489 | } |
2490 | |
2491 | static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) { |
2492 | const unsigned ShiftSize = ShiftTy.getScalarSizeInBits(); |
2493 | const unsigned TruncSize = TruncTy.getScalarSizeInBits(); |
2494 | |
2495 | // ShiftTy > 32 > TruncTy -> 32 |
2496 | if (ShiftSize > 32 && TruncSize < 32) |
2497 | return ShiftTy.changeElementSize(NewEltSize: 32); |
2498 | |
2499 | // TODO: We could also reduce to 16 bits, but that's more target-dependent. |
2500 | // Some targets like it, some don't, some only like it under certain |
2501 | // conditions/processor versions, etc. |
2502 | // A TL hook might be needed for this. |
2503 | |
2504 | // Don't combine |
2505 | return ShiftTy; |
2506 | } |
2507 | |
2508 | bool CombinerHelper::matchCombineTruncOfShift( |
2509 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) { |
2510 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2511 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2512 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2513 | |
2514 | if (!MRI.hasOneNonDBGUse(RegNo: SrcReg)) |
2515 | return false; |
2516 | |
2517 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2518 | LLT DstTy = MRI.getType(Reg: DstReg); |
2519 | |
2520 | MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI); |
2521 | const auto &TL = getTargetLowering(); |
2522 | |
2523 | LLT NewShiftTy; |
2524 | switch (SrcMI->getOpcode()) { |
2525 | default: |
2526 | return false; |
2527 | case TargetOpcode::G_SHL: { |
2528 | NewShiftTy = DstTy; |
2529 | |
2530 | // Make sure new shift amount is legal. |
2531 | KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2532 | if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits())) |
2533 | return false; |
2534 | break; |
2535 | } |
2536 | case TargetOpcode::G_LSHR: |
2537 | case TargetOpcode::G_ASHR: { |
2538 | // For right shifts, we conservatively do not do the transform if the TRUNC |
2539 | // has any STORE users. The reason is that if we change the type of the |
2540 | // shift, we may break the truncstore combine. |
2541 | // |
2542 | // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)). |
2543 | for (auto &User : MRI.use_instructions(Reg: DstReg)) |
2544 | if (User.getOpcode() == TargetOpcode::G_STORE) |
2545 | return false; |
2546 | |
2547 | NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy); |
2548 | if (NewShiftTy == SrcTy) |
2549 | return false; |
2550 | |
2551 | // Make sure we won't lose information by truncating the high bits. |
2552 | KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2553 | if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() - |
2554 | DstTy.getScalarSizeInBits())) |
2555 | return false; |
2556 | break; |
2557 | } |
2558 | } |
2559 | |
2560 | if (!isLegalOrBeforeLegalizer( |
2561 | Query: {SrcMI->getOpcode(), |
2562 | {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}})) |
2563 | return false; |
2564 | |
2565 | MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy); |
2566 | return true; |
2567 | } |
2568 | |
2569 | void CombinerHelper::applyCombineTruncOfShift( |
2570 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) { |
2571 | Builder.setInstrAndDebugLoc(MI); |
2572 | |
2573 | MachineInstr *ShiftMI = MatchInfo.first; |
2574 | LLT NewShiftTy = MatchInfo.second; |
2575 | |
2576 | Register Dst = MI.getOperand(i: 0).getReg(); |
2577 | LLT DstTy = MRI.getType(Reg: Dst); |
2578 | |
2579 | Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg(); |
2580 | Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg(); |
2581 | ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0); |
2582 | |
2583 | Register NewShift = |
2584 | Builder |
2585 | .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt}) |
2586 | .getReg(Idx: 0); |
2587 | |
2588 | if (NewShiftTy == DstTy) |
2589 | replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift); |
2590 | else |
2591 | Builder.buildTrunc(Res: Dst, Op: NewShift); |
2592 | |
2593 | eraseInst(MI); |
2594 | } |
2595 | |
2596 | bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) { |
2597 | return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2598 | return MO.isReg() && |
2599 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2600 | }); |
2601 | } |
2602 | |
2603 | bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) { |
2604 | return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2605 | return !MO.isReg() || |
2606 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2607 | }); |
2608 | } |
2609 | |
2610 | bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) { |
2611 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); |
2612 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
2613 | return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; }); |
2614 | } |
2615 | |
2616 | bool CombinerHelper::matchUndefStore(MachineInstr &MI) { |
2617 | assert(MI.getOpcode() == TargetOpcode::G_STORE); |
2618 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(), |
2619 | MRI); |
2620 | } |
2621 | |
2622 | bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) { |
2623 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2624 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(), |
2625 | MRI); |
2626 | } |
2627 | |
2628 | bool CombinerHelper::(MachineInstr &MI) { |
2629 | assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT || |
2630 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) && |
2631 | "Expected an insert/extract element op" ); |
2632 | LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
2633 | unsigned IdxIdx = |
2634 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3; |
2635 | auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI); |
2636 | if (!Idx) |
2637 | return false; |
2638 | return Idx->getZExtValue() >= VecTy.getNumElements(); |
2639 | } |
2640 | |
2641 | bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) { |
2642 | GSelect &SelMI = cast<GSelect>(Val&: MI); |
2643 | auto Cst = |
2644 | isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI); |
2645 | if (!Cst) |
2646 | return false; |
2647 | OpIdx = Cst->isZero() ? 3 : 2; |
2648 | return true; |
2649 | } |
2650 | |
2651 | void CombinerHelper::eraseInst(MachineInstr &MI) { MI.eraseFromParent(); } |
2652 | |
2653 | bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1, |
2654 | const MachineOperand &MOP2) { |
2655 | if (!MOP1.isReg() || !MOP2.isReg()) |
2656 | return false; |
2657 | auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI); |
2658 | if (!InstAndDef1) |
2659 | return false; |
2660 | auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI); |
2661 | if (!InstAndDef2) |
2662 | return false; |
2663 | MachineInstr *I1 = InstAndDef1->MI; |
2664 | MachineInstr *I2 = InstAndDef2->MI; |
2665 | |
2666 | // Handle a case like this: |
2667 | // |
2668 | // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>) |
2669 | // |
2670 | // Even though %0 and %1 are produced by the same instruction they are not |
2671 | // the same values. |
2672 | if (I1 == I2) |
2673 | return MOP1.getReg() == MOP2.getReg(); |
2674 | |
2675 | // If we have an instruction which loads or stores, we can't guarantee that |
2676 | // it is identical. |
2677 | // |
2678 | // For example, we may have |
2679 | // |
2680 | // %x1 = G_LOAD %addr (load N from @somewhere) |
2681 | // ... |
2682 | // call @foo |
2683 | // ... |
2684 | // %x2 = G_LOAD %addr (load N from @somewhere) |
2685 | // ... |
2686 | // %or = G_OR %x1, %x2 |
2687 | // |
2688 | // It's possible that @foo will modify whatever lives at the address we're |
2689 | // loading from. To be safe, let's just assume that all loads and stores |
2690 | // are different (unless we have something which is guaranteed to not |
2691 | // change.) |
2692 | if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad()) |
2693 | return false; |
2694 | |
2695 | // If both instructions are loads or stores, they are equal only if both |
2696 | // are dereferenceable invariant loads with the same number of bits. |
2697 | if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) { |
2698 | GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1); |
2699 | GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2); |
2700 | if (!LS1 || !LS2) |
2701 | return false; |
2702 | |
2703 | if (!I2->isDereferenceableInvariantLoad() || |
2704 | (LS1->getMemSizeInBits() != LS2->getMemSizeInBits())) |
2705 | return false; |
2706 | } |
2707 | |
2708 | // Check for physical registers on the instructions first to avoid cases |
2709 | // like this: |
2710 | // |
2711 | // %a = COPY $physreg |
2712 | // ... |
2713 | // SOMETHING implicit-def $physreg |
2714 | // ... |
2715 | // %b = COPY $physreg |
2716 | // |
2717 | // These copies are not equivalent. |
2718 | if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) { |
2719 | return MO.isReg() && MO.getReg().isPhysical(); |
2720 | })) { |
2721 | // Check if we have a case like this: |
2722 | // |
2723 | // %a = COPY $physreg |
2724 | // %b = COPY %a |
2725 | // |
2726 | // In this case, I1 and I2 will both be equal to %a = COPY $physreg. |
2727 | // From that, we know that they must have the same value, since they must |
2728 | // have come from the same COPY. |
2729 | return I1->isIdenticalTo(Other: *I2); |
2730 | } |
2731 | |
2732 | // We don't have any physical registers, so we don't necessarily need the |
2733 | // same vreg defs. |
2734 | // |
2735 | // On the off-chance that there's some target instruction feeding into the |
2736 | // instruction, let's use produceSameValue instead of isIdenticalTo. |
2737 | if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) { |
2738 | // Handle instructions with multiple defs that produce same values. Values |
2739 | // are same for operands with same index. |
2740 | // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) |
2741 | // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) |
2742 | // I1 and I2 are different instructions but produce same values, |
2743 | // %1 and %6 are same, %1 and %7 are not the same value. |
2744 | return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg) == |
2745 | I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg); |
2746 | } |
2747 | return false; |
2748 | } |
2749 | |
2750 | bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, int64_t C) { |
2751 | if (!MOP.isReg()) |
2752 | return false; |
2753 | auto *MI = MRI.getVRegDef(Reg: MOP.getReg()); |
2754 | auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI); |
2755 | return MaybeCst && MaybeCst->getBitWidth() <= 64 && |
2756 | MaybeCst->getSExtValue() == C; |
2757 | } |
2758 | |
2759 | bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP, double C) { |
2760 | if (!MOP.isReg()) |
2761 | return false; |
2762 | std::optional<FPValueAndVReg> MaybeCst; |
2763 | if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst))) |
2764 | return false; |
2765 | |
2766 | return MaybeCst->Value.isExactlyValue(V: C); |
2767 | } |
2768 | |
2769 | void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI, |
2770 | unsigned OpIdx) { |
2771 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2772 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2773 | Register Replacement = MI.getOperand(i: OpIdx).getReg(); |
2774 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2775 | MI.eraseFromParent(); |
2776 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2777 | } |
2778 | |
2779 | void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI, |
2780 | Register Replacement) { |
2781 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2782 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2783 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2784 | MI.eraseFromParent(); |
2785 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2786 | } |
2787 | |
2788 | bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI, |
2789 | unsigned ConstIdx) { |
2790 | Register ConstReg = MI.getOperand(i: ConstIdx).getReg(); |
2791 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2792 | |
2793 | // Get the shift amount |
2794 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2795 | if (!VRegAndVal) |
2796 | return false; |
2797 | |
2798 | // Return true of shift amount >= Bitwidth |
2799 | return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits())); |
2800 | } |
2801 | |
2802 | void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) { |
2803 | assert((MI.getOpcode() == TargetOpcode::G_FSHL || |
2804 | MI.getOpcode() == TargetOpcode::G_FSHR) && |
2805 | "This is not a funnel shift operation" ); |
2806 | |
2807 | Register ConstReg = MI.getOperand(i: 3).getReg(); |
2808 | LLT ConstTy = MRI.getType(Reg: ConstReg); |
2809 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2810 | |
2811 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2812 | assert((VRegAndVal) && "Value is not a constant" ); |
2813 | |
2814 | // Calculate the new Shift Amount = Old Shift Amount % BitWidth |
2815 | APInt NewConst = VRegAndVal->Value.urem( |
2816 | RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits())); |
2817 | |
2818 | Builder.setInstrAndDebugLoc(MI); |
2819 | auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue()); |
2820 | Builder.buildInstr( |
2821 | Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)}, |
2822 | SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)}); |
2823 | |
2824 | MI.eraseFromParent(); |
2825 | } |
2826 | |
2827 | bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) { |
2828 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2829 | // Match (cond ? x : x) |
2830 | return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) && |
2831 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(), |
2832 | MRI); |
2833 | } |
2834 | |
2835 | bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) { |
2836 | return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) && |
2837 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(), |
2838 | MRI); |
2839 | } |
2840 | |
2841 | bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, unsigned OpIdx) { |
2842 | return matchConstantOp(MOP: MI.getOperand(i: OpIdx), C: 0) && |
2843 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: OpIdx).getReg(), |
2844 | MRI); |
2845 | } |
2846 | |
2847 | bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, unsigned OpIdx) { |
2848 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2849 | return MO.isReg() && |
2850 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2851 | } |
2852 | |
2853 | bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI, |
2854 | unsigned OpIdx) { |
2855 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2856 | return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, KnownBits: KB); |
2857 | } |
2858 | |
2859 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, double C) { |
2860 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2861 | Builder.setInstr(MI); |
2862 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C); |
2863 | MI.eraseFromParent(); |
2864 | } |
2865 | |
2866 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, int64_t C) { |
2867 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2868 | Builder.setInstr(MI); |
2869 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
2870 | MI.eraseFromParent(); |
2871 | } |
2872 | |
2873 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) { |
2874 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2875 | Builder.setInstr(MI); |
2876 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
2877 | MI.eraseFromParent(); |
2878 | } |
2879 | |
2880 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, ConstantFP *CFP) { |
2881 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2882 | Builder.setInstr(MI); |
2883 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF()); |
2884 | MI.eraseFromParent(); |
2885 | } |
2886 | |
2887 | void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) { |
2888 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2889 | Builder.setInstr(MI); |
2890 | Builder.buildUndef(Res: MI.getOperand(i: 0)); |
2891 | MI.eraseFromParent(); |
2892 | } |
2893 | |
2894 | bool CombinerHelper::matchSimplifyAddToSub( |
2895 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) { |
2896 | Register LHS = MI.getOperand(i: 1).getReg(); |
2897 | Register RHS = MI.getOperand(i: 2).getReg(); |
2898 | Register &NewLHS = std::get<0>(t&: MatchInfo); |
2899 | Register &NewRHS = std::get<1>(t&: MatchInfo); |
2900 | |
2901 | // Helper lambda to check for opportunities for |
2902 | // ((0-A) + B) -> B - A |
2903 | // (A + (0-B)) -> A - B |
2904 | auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) { |
2905 | if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS)))) |
2906 | return false; |
2907 | NewLHS = MaybeNewLHS; |
2908 | return true; |
2909 | }; |
2910 | |
2911 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
2912 | } |
2913 | |
2914 | bool CombinerHelper::matchCombineInsertVecElts( |
2915 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) { |
2916 | assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT && |
2917 | "Invalid opcode" ); |
2918 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2919 | LLT DstTy = MRI.getType(Reg: DstReg); |
2920 | assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?" ); |
2921 | unsigned NumElts = DstTy.getNumElements(); |
2922 | // If this MI is part of a sequence of insert_vec_elts, then |
2923 | // don't do the combine in the middle of the sequence. |
2924 | if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() == |
2925 | TargetOpcode::G_INSERT_VECTOR_ELT) |
2926 | return false; |
2927 | MachineInstr *CurrInst = &MI; |
2928 | MachineInstr *TmpInst; |
2929 | int64_t IntImm; |
2930 | Register TmpReg; |
2931 | MatchInfo.resize(N: NumElts); |
2932 | while (mi_match( |
2933 | R: CurrInst->getOperand(i: 0).getReg(), MRI, |
2934 | P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) { |
2935 | if (IntImm >= NumElts || IntImm < 0) |
2936 | return false; |
2937 | if (!MatchInfo[IntImm]) |
2938 | MatchInfo[IntImm] = TmpReg; |
2939 | CurrInst = TmpInst; |
2940 | } |
2941 | // Variable index. |
2942 | if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT) |
2943 | return false; |
2944 | if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) { |
2945 | for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) { |
2946 | if (!MatchInfo[I - 1].isValid()) |
2947 | MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg(); |
2948 | } |
2949 | return true; |
2950 | } |
2951 | // If we didn't end in a G_IMPLICIT_DEF, bail out. |
2952 | return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF; |
2953 | } |
2954 | |
2955 | void CombinerHelper::applyCombineInsertVecElts( |
2956 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) { |
2957 | Builder.setInstr(MI); |
2958 | Register UndefReg; |
2959 | auto GetUndef = [&]() { |
2960 | if (UndefReg) |
2961 | return UndefReg; |
2962 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2963 | UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0); |
2964 | return UndefReg; |
2965 | }; |
2966 | for (unsigned I = 0; I < MatchInfo.size(); ++I) { |
2967 | if (!MatchInfo[I]) |
2968 | MatchInfo[I] = GetUndef(); |
2969 | } |
2970 | Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo); |
2971 | MI.eraseFromParent(); |
2972 | } |
2973 | |
2974 | void CombinerHelper::applySimplifyAddToSub( |
2975 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) { |
2976 | Builder.setInstr(MI); |
2977 | Register SubLHS, SubRHS; |
2978 | std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo; |
2979 | Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS); |
2980 | MI.eraseFromParent(); |
2981 | } |
2982 | |
2983 | bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands( |
2984 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) { |
2985 | // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ... |
2986 | // |
2987 | // Creates the new hand + logic instruction (but does not insert them.) |
2988 | // |
2989 | // On success, MatchInfo is populated with the new instructions. These are |
2990 | // inserted in applyHoistLogicOpWithSameOpcodeHands. |
2991 | unsigned LogicOpcode = MI.getOpcode(); |
2992 | assert(LogicOpcode == TargetOpcode::G_AND || |
2993 | LogicOpcode == TargetOpcode::G_OR || |
2994 | LogicOpcode == TargetOpcode::G_XOR); |
2995 | MachineIRBuilder MIB(MI); |
2996 | Register Dst = MI.getOperand(i: 0).getReg(); |
2997 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
2998 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
2999 | |
3000 | // Don't recompute anything. |
3001 | if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg)) |
3002 | return false; |
3003 | |
3004 | // Make sure we have (hand x, ...), (hand y, ...) |
3005 | MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI); |
3006 | MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI); |
3007 | if (!LeftHandInst || !RightHandInst) |
3008 | return false; |
3009 | unsigned HandOpcode = LeftHandInst->getOpcode(); |
3010 | if (HandOpcode != RightHandInst->getOpcode()) |
3011 | return false; |
3012 | if (!LeftHandInst->getOperand(i: 1).isReg() || |
3013 | !RightHandInst->getOperand(i: 1).isReg()) |
3014 | return false; |
3015 | |
3016 | // Make sure the types match up, and if we're doing this post-legalization, |
3017 | // we end up with legal types. |
3018 | Register X = LeftHandInst->getOperand(i: 1).getReg(); |
3019 | Register Y = RightHandInst->getOperand(i: 1).getReg(); |
3020 | LLT XTy = MRI.getType(Reg: X); |
3021 | LLT YTy = MRI.getType(Reg: Y); |
3022 | if (!XTy.isValid() || XTy != YTy) |
3023 | return false; |
3024 | |
3025 | // Optional extra source register. |
3026 | Register ExtraHandOpSrcReg; |
3027 | switch (HandOpcode) { |
3028 | default: |
3029 | return false; |
3030 | case TargetOpcode::G_ANYEXT: |
3031 | case TargetOpcode::G_SEXT: |
3032 | case TargetOpcode::G_ZEXT: { |
3033 | // Match: logic (ext X), (ext Y) --> ext (logic X, Y) |
3034 | break; |
3035 | } |
3036 | case TargetOpcode::G_AND: |
3037 | case TargetOpcode::G_ASHR: |
3038 | case TargetOpcode::G_LSHR: |
3039 | case TargetOpcode::G_SHL: { |
3040 | // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z |
3041 | MachineOperand &ZOp = LeftHandInst->getOperand(i: 2); |
3042 | if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2))) |
3043 | return false; |
3044 | ExtraHandOpSrcReg = ZOp.getReg(); |
3045 | break; |
3046 | } |
3047 | } |
3048 | |
3049 | if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}})) |
3050 | return false; |
3051 | |
3052 | // Record the steps to build the new instructions. |
3053 | // |
3054 | // Steps to build (logic x, y) |
3055 | auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy); |
3056 | OperandBuildSteps LogicBuildSteps = { |
3057 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); }, |
3058 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); }, |
3059 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }}; |
3060 | InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps); |
3061 | |
3062 | // Steps to build hand (logic x, y), ...z |
3063 | OperandBuildSteps HandBuildSteps = { |
3064 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); }, |
3065 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }}; |
3066 | if (ExtraHandOpSrcReg.isValid()) |
3067 | HandBuildSteps.push_back( |
3068 | Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); }); |
3069 | InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps); |
3070 | |
3071 | MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps}); |
3072 | return true; |
3073 | } |
3074 | |
3075 | void CombinerHelper::applyBuildInstructionSteps( |
3076 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) { |
3077 | assert(MatchInfo.InstrsToBuild.size() && |
3078 | "Expected at least one instr to build?" ); |
3079 | Builder.setInstr(MI); |
3080 | for (auto &InstrToBuild : MatchInfo.InstrsToBuild) { |
3081 | assert(InstrToBuild.Opcode && "Expected a valid opcode?" ); |
3082 | assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?" ); |
3083 | MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode); |
3084 | for (auto &OperandFn : InstrToBuild.OperandFns) |
3085 | OperandFn(Instr); |
3086 | } |
3087 | MI.eraseFromParent(); |
3088 | } |
3089 | |
3090 | bool CombinerHelper::matchAshrShlToSextInreg( |
3091 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) { |
3092 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3093 | int64_t ShlCst, AshrCst; |
3094 | Register Src; |
3095 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3096 | P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)), |
3097 | R: m_ICstOrSplat(Cst&: AshrCst)))) |
3098 | return false; |
3099 | if (ShlCst != AshrCst) |
3100 | return false; |
3101 | if (!isLegalOrBeforeLegalizer( |
3102 | Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}})) |
3103 | return false; |
3104 | MatchInfo = std::make_tuple(args&: Src, args&: ShlCst); |
3105 | return true; |
3106 | } |
3107 | |
3108 | void CombinerHelper::applyAshShlToSextInreg( |
3109 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) { |
3110 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3111 | Register Src; |
3112 | int64_t ShiftAmt; |
3113 | std::tie(args&: Src, args&: ShiftAmt) = MatchInfo; |
3114 | unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3115 | Builder.setInstrAndDebugLoc(MI); |
3116 | Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt); |
3117 | MI.eraseFromParent(); |
3118 | } |
3119 | |
3120 | /// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0 |
3121 | bool CombinerHelper::matchOverlappingAnd( |
3122 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
3123 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3124 | |
3125 | Register Dst = MI.getOperand(i: 0).getReg(); |
3126 | LLT Ty = MRI.getType(Reg: Dst); |
3127 | |
3128 | Register R; |
3129 | int64_t C1; |
3130 | int64_t C2; |
3131 | if (!mi_match( |
3132 | R: Dst, MRI, |
3133 | P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2)))) |
3134 | return false; |
3135 | |
3136 | MatchInfo = [=](MachineIRBuilder &B) { |
3137 | if (C1 & C2) { |
3138 | B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2)); |
3139 | return; |
3140 | } |
3141 | auto Zero = B.buildConstant(Res: Ty, Val: 0); |
3142 | replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg()); |
3143 | }; |
3144 | return true; |
3145 | } |
3146 | |
3147 | bool CombinerHelper::matchRedundantAnd(MachineInstr &MI, |
3148 | Register &Replacement) { |
3149 | // Given |
3150 | // |
3151 | // %y:_(sN) = G_SOMETHING |
3152 | // %x:_(sN) = G_SOMETHING |
3153 | // %res:_(sN) = G_AND %x, %y |
3154 | // |
3155 | // Eliminate the G_AND when it is known that x & y == x or x & y == y. |
3156 | // |
3157 | // Patterns like this can appear as a result of legalization. E.g. |
3158 | // |
3159 | // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y |
3160 | // %one:_(s32) = G_CONSTANT i32 1 |
3161 | // %and:_(s32) = G_AND %cmp, %one |
3162 | // |
3163 | // In this case, G_ICMP only produces a single bit, so x & 1 == x. |
3164 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3165 | if (!KB) |
3166 | return false; |
3167 | |
3168 | Register AndDst = MI.getOperand(i: 0).getReg(); |
3169 | Register LHS = MI.getOperand(i: 1).getReg(); |
3170 | Register RHS = MI.getOperand(i: 2).getReg(); |
3171 | KnownBits LHSBits = KB->getKnownBits(R: LHS); |
3172 | KnownBits RHSBits = KB->getKnownBits(R: RHS); |
3173 | |
3174 | // Check that x & Mask == x. |
3175 | // x & 1 == x, always |
3176 | // x & 0 == x, only if x is also 0 |
3177 | // Meaning Mask has no effect if every bit is either one in Mask or zero in x. |
3178 | // |
3179 | // Check if we can replace AndDst with the LHS of the G_AND |
3180 | if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) && |
3181 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3182 | Replacement = LHS; |
3183 | return true; |
3184 | } |
3185 | |
3186 | // Check if we can replace AndDst with the RHS of the G_AND |
3187 | if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) && |
3188 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3189 | Replacement = RHS; |
3190 | return true; |
3191 | } |
3192 | |
3193 | return false; |
3194 | } |
3195 | |
3196 | bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) { |
3197 | // Given |
3198 | // |
3199 | // %y:_(sN) = G_SOMETHING |
3200 | // %x:_(sN) = G_SOMETHING |
3201 | // %res:_(sN) = G_OR %x, %y |
3202 | // |
3203 | // Eliminate the G_OR when it is known that x | y == x or x | y == y. |
3204 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3205 | if (!KB) |
3206 | return false; |
3207 | |
3208 | Register OrDst = MI.getOperand(i: 0).getReg(); |
3209 | Register LHS = MI.getOperand(i: 1).getReg(); |
3210 | Register RHS = MI.getOperand(i: 2).getReg(); |
3211 | KnownBits LHSBits = KB->getKnownBits(R: LHS); |
3212 | KnownBits RHSBits = KB->getKnownBits(R: RHS); |
3213 | |
3214 | // Check that x | Mask == x. |
3215 | // x | 0 == x, always |
3216 | // x | 1 == x, only if x is also 1 |
3217 | // Meaning Mask has no effect if every bit is either zero in Mask or one in x. |
3218 | // |
3219 | // Check if we can replace OrDst with the LHS of the G_OR |
3220 | if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) && |
3221 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3222 | Replacement = LHS; |
3223 | return true; |
3224 | } |
3225 | |
3226 | // Check if we can replace OrDst with the RHS of the G_OR |
3227 | if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) && |
3228 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3229 | Replacement = RHS; |
3230 | return true; |
3231 | } |
3232 | |
3233 | return false; |
3234 | } |
3235 | |
3236 | bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) { |
3237 | // If the input is already sign extended, just drop the extension. |
3238 | Register Src = MI.getOperand(i: 1).getReg(); |
3239 | unsigned ExtBits = MI.getOperand(i: 2).getImm(); |
3240 | unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3241 | return KB->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1); |
3242 | } |
3243 | |
3244 | static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits, |
3245 | int64_t Cst, bool IsVector, bool IsFP) { |
3246 | // For i1, Cst will always be -1 regardless of boolean contents. |
3247 | return (ScalarSizeBits == 1 && Cst == -1) || |
3248 | isConstTrueVal(TLI, Val: Cst, IsVector, IsFP); |
3249 | } |
3250 | |
3251 | bool CombinerHelper::matchNotCmp(MachineInstr &MI, |
3252 | SmallVectorImpl<Register> &RegsToNegate) { |
3253 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3254 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3255 | const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering(); |
3256 | Register XorSrc; |
3257 | Register CstReg; |
3258 | // We match xor(src, true) here. |
3259 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3260 | P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg)))) |
3261 | return false; |
3262 | |
3263 | if (!MRI.hasOneNonDBGUse(RegNo: XorSrc)) |
3264 | return false; |
3265 | |
3266 | // Check that XorSrc is the root of a tree of comparisons combined with ANDs |
3267 | // and ORs. The suffix of RegsToNegate starting from index I is used a work |
3268 | // list of tree nodes to visit. |
3269 | RegsToNegate.push_back(Elt: XorSrc); |
3270 | // Remember whether the comparisons are all integer or all floating point. |
3271 | bool IsInt = false; |
3272 | bool IsFP = false; |
3273 | for (unsigned I = 0; I < RegsToNegate.size(); ++I) { |
3274 | Register Reg = RegsToNegate[I]; |
3275 | if (!MRI.hasOneNonDBGUse(RegNo: Reg)) |
3276 | return false; |
3277 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3278 | switch (Def->getOpcode()) { |
3279 | default: |
3280 | // Don't match if the tree contains anything other than ANDs, ORs and |
3281 | // comparisons. |
3282 | return false; |
3283 | case TargetOpcode::G_ICMP: |
3284 | if (IsFP) |
3285 | return false; |
3286 | IsInt = true; |
3287 | // When we apply the combine we will invert the predicate. |
3288 | break; |
3289 | case TargetOpcode::G_FCMP: |
3290 | if (IsInt) |
3291 | return false; |
3292 | IsFP = true; |
3293 | // When we apply the combine we will invert the predicate. |
3294 | break; |
3295 | case TargetOpcode::G_AND: |
3296 | case TargetOpcode::G_OR: |
3297 | // Implement De Morgan's laws: |
3298 | // ~(x & y) -> ~x | ~y |
3299 | // ~(x | y) -> ~x & ~y |
3300 | // When we apply the combine we will change the opcode and recursively |
3301 | // negate the operands. |
3302 | RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg()); |
3303 | RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg()); |
3304 | break; |
3305 | } |
3306 | } |
3307 | |
3308 | // Now we know whether the comparisons are integer or floating point, check |
3309 | // the constant in the xor. |
3310 | int64_t Cst; |
3311 | if (Ty.isVector()) { |
3312 | MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg); |
3313 | auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI); |
3314 | if (!MaybeCst) |
3315 | return false; |
3316 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP)) |
3317 | return false; |
3318 | } else { |
3319 | if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst))) |
3320 | return false; |
3321 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP)) |
3322 | return false; |
3323 | } |
3324 | |
3325 | return true; |
3326 | } |
3327 | |
3328 | void CombinerHelper::applyNotCmp(MachineInstr &MI, |
3329 | SmallVectorImpl<Register> &RegsToNegate) { |
3330 | for (Register Reg : RegsToNegate) { |
3331 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3332 | Observer.changingInstr(MI&: *Def); |
3333 | // For each comparison, invert the opcode. For each AND and OR, change the |
3334 | // opcode. |
3335 | switch (Def->getOpcode()) { |
3336 | default: |
3337 | llvm_unreachable("Unexpected opcode" ); |
3338 | case TargetOpcode::G_ICMP: |
3339 | case TargetOpcode::G_FCMP: { |
3340 | MachineOperand &PredOp = Def->getOperand(i: 1); |
3341 | CmpInst::Predicate NewP = CmpInst::getInversePredicate( |
3342 | pred: (CmpInst::Predicate)PredOp.getPredicate()); |
3343 | PredOp.setPredicate(NewP); |
3344 | break; |
3345 | } |
3346 | case TargetOpcode::G_AND: |
3347 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR)); |
3348 | break; |
3349 | case TargetOpcode::G_OR: |
3350 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3351 | break; |
3352 | } |
3353 | Observer.changedInstr(MI&: *Def); |
3354 | } |
3355 | |
3356 | replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg()); |
3357 | MI.eraseFromParent(); |
3358 | } |
3359 | |
3360 | bool CombinerHelper::matchXorOfAndWithSameReg( |
3361 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) { |
3362 | // Match (xor (and x, y), y) (or any of its commuted cases) |
3363 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3364 | Register &X = MatchInfo.first; |
3365 | Register &Y = MatchInfo.second; |
3366 | Register AndReg = MI.getOperand(i: 1).getReg(); |
3367 | Register SharedReg = MI.getOperand(i: 2).getReg(); |
3368 | |
3369 | // Find a G_AND on either side of the G_XOR. |
3370 | // Look for one of |
3371 | // |
3372 | // (xor (and x, y), SharedReg) |
3373 | // (xor SharedReg, (and x, y)) |
3374 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) { |
3375 | std::swap(a&: AndReg, b&: SharedReg); |
3376 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) |
3377 | return false; |
3378 | } |
3379 | |
3380 | // Only do this if we'll eliminate the G_AND. |
3381 | if (!MRI.hasOneNonDBGUse(RegNo: AndReg)) |
3382 | return false; |
3383 | |
3384 | // We can combine if SharedReg is the same as either the LHS or RHS of the |
3385 | // G_AND. |
3386 | if (Y != SharedReg) |
3387 | std::swap(a&: X, b&: Y); |
3388 | return Y == SharedReg; |
3389 | } |
3390 | |
3391 | void CombinerHelper::applyXorOfAndWithSameReg( |
3392 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) { |
3393 | // Fold (xor (and x, y), y) -> (and (not x), y) |
3394 | Builder.setInstrAndDebugLoc(MI); |
3395 | Register X, Y; |
3396 | std::tie(args&: X, args&: Y) = MatchInfo; |
3397 | auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X); |
3398 | Observer.changingInstr(MI); |
3399 | MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3400 | MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg()); |
3401 | MI.getOperand(i: 2).setReg(Y); |
3402 | Observer.changedInstr(MI); |
3403 | } |
3404 | |
3405 | bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) { |
3406 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3407 | Register DstReg = PtrAdd.getReg(Idx: 0); |
3408 | LLT Ty = MRI.getType(Reg: DstReg); |
3409 | const DataLayout &DL = Builder.getMF().getDataLayout(); |
3410 | |
3411 | if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace())) |
3412 | return false; |
3413 | |
3414 | if (Ty.isPointer()) { |
3415 | auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI); |
3416 | return ConstVal && *ConstVal == 0; |
3417 | } |
3418 | |
3419 | assert(Ty.isVector() && "Expecting a vector type" ); |
3420 | const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
3421 | return isBuildVectorAllZeros(MI: *VecMI, MRI); |
3422 | } |
3423 | |
3424 | void CombinerHelper::applyPtrAddZero(MachineInstr &MI) { |
3425 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3426 | Builder.setInstrAndDebugLoc(PtrAdd); |
3427 | Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg()); |
3428 | PtrAdd.eraseFromParent(); |
3429 | } |
3430 | |
3431 | /// The second source operand is known to be a power of 2. |
3432 | void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) { |
3433 | Register DstReg = MI.getOperand(i: 0).getReg(); |
3434 | Register Src0 = MI.getOperand(i: 1).getReg(); |
3435 | Register Pow2Src1 = MI.getOperand(i: 2).getReg(); |
3436 | LLT Ty = MRI.getType(Reg: DstReg); |
3437 | Builder.setInstrAndDebugLoc(MI); |
3438 | |
3439 | // Fold (urem x, pow2) -> (and x, pow2-1) |
3440 | auto NegOne = Builder.buildConstant(Res: Ty, Val: -1); |
3441 | auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne); |
3442 | Builder.buildAnd(Dst: DstReg, Src0, Src1: Add); |
3443 | MI.eraseFromParent(); |
3444 | } |
3445 | |
3446 | bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI, |
3447 | unsigned &SelectOpNo) { |
3448 | Register LHS = MI.getOperand(i: 1).getReg(); |
3449 | Register RHS = MI.getOperand(i: 2).getReg(); |
3450 | |
3451 | Register OtherOperandReg = RHS; |
3452 | SelectOpNo = 1; |
3453 | MachineInstr *Select = MRI.getVRegDef(Reg: LHS); |
3454 | |
3455 | // Don't do this unless the old select is going away. We want to eliminate the |
3456 | // binary operator, not replace a binop with a select. |
3457 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3458 | !MRI.hasOneNonDBGUse(RegNo: LHS)) { |
3459 | OtherOperandReg = LHS; |
3460 | SelectOpNo = 2; |
3461 | Select = MRI.getVRegDef(Reg: RHS); |
3462 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3463 | !MRI.hasOneNonDBGUse(RegNo: RHS)) |
3464 | return false; |
3465 | } |
3466 | |
3467 | MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg()); |
3468 | MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg()); |
3469 | |
3470 | if (!isConstantOrConstantVector(MI: *SelectLHS, MRI, |
3471 | /*AllowFP*/ true, |
3472 | /*AllowOpaqueConstants*/ false)) |
3473 | return false; |
3474 | if (!isConstantOrConstantVector(MI: *SelectRHS, MRI, |
3475 | /*AllowFP*/ true, |
3476 | /*AllowOpaqueConstants*/ false)) |
3477 | return false; |
3478 | |
3479 | unsigned BinOpcode = MI.getOpcode(); |
3480 | |
3481 | // We know that one of the operands is a select of constants. Now verify that |
3482 | // the other binary operator operand is either a constant, or we can handle a |
3483 | // variable. |
3484 | bool CanFoldNonConst = |
3485 | (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) && |
3486 | (isNullOrNullSplat(MI: *SelectLHS, MRI) || |
3487 | isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) && |
3488 | (isNullOrNullSplat(MI: *SelectRHS, MRI) || |
3489 | isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI)); |
3490 | if (CanFoldNonConst) |
3491 | return true; |
3492 | |
3493 | return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI, |
3494 | /*AllowFP*/ true, |
3495 | /*AllowOpaqueConstants*/ false); |
3496 | } |
3497 | |
3498 | /// \p SelectOperand is the operand in binary operator \p MI that is the select |
3499 | /// to fold. |
3500 | void CombinerHelper::applyFoldBinOpIntoSelect(MachineInstr &MI, |
3501 | const unsigned &SelectOperand) { |
3502 | Builder.setInstrAndDebugLoc(MI); |
3503 | |
3504 | Register Dst = MI.getOperand(i: 0).getReg(); |
3505 | Register LHS = MI.getOperand(i: 1).getReg(); |
3506 | Register RHS = MI.getOperand(i: 2).getReg(); |
3507 | MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg()); |
3508 | |
3509 | Register SelectCond = Select->getOperand(i: 1).getReg(); |
3510 | Register SelectTrue = Select->getOperand(i: 2).getReg(); |
3511 | Register SelectFalse = Select->getOperand(i: 3).getReg(); |
3512 | |
3513 | LLT Ty = MRI.getType(Reg: Dst); |
3514 | unsigned BinOpcode = MI.getOpcode(); |
3515 | |
3516 | Register FoldTrue, FoldFalse; |
3517 | |
3518 | // We have a select-of-constants followed by a binary operator with a |
3519 | // constant. Eliminate the binop by pulling the constant math into the select. |
3520 | // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO |
3521 | if (SelectOperand == 1) { |
3522 | // TODO: SelectionDAG verifies this actually constant folds before |
3523 | // committing to the combine. |
3524 | |
3525 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0); |
3526 | FoldFalse = |
3527 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0); |
3528 | } else { |
3529 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0); |
3530 | FoldFalse = |
3531 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0); |
3532 | } |
3533 | |
3534 | Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags()); |
3535 | MI.eraseFromParent(); |
3536 | } |
3537 | |
3538 | std::optional<SmallVector<Register, 8>> |
3539 | CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const { |
3540 | assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!" ); |
3541 | // We want to detect if Root is part of a tree which represents a bunch |
3542 | // of loads being merged into a larger load. We'll try to recognize patterns |
3543 | // like, for example: |
3544 | // |
3545 | // Reg Reg |
3546 | // \ / |
3547 | // OR_1 Reg |
3548 | // \ / |
3549 | // OR_2 |
3550 | // \ Reg |
3551 | // .. / |
3552 | // Root |
3553 | // |
3554 | // Reg Reg Reg Reg |
3555 | // \ / \ / |
3556 | // OR_1 OR_2 |
3557 | // \ / |
3558 | // \ / |
3559 | // ... |
3560 | // Root |
3561 | // |
3562 | // Each "Reg" may have been produced by a load + some arithmetic. This |
3563 | // function will save each of them. |
3564 | SmallVector<Register, 8> RegsToVisit; |
3565 | SmallVector<const MachineInstr *, 7> Ors = {Root}; |
3566 | |
3567 | // In the "worst" case, we're dealing with a load for each byte. So, there |
3568 | // are at most #bytes - 1 ORs. |
3569 | const unsigned MaxIter = |
3570 | MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1; |
3571 | for (unsigned Iter = 0; Iter < MaxIter; ++Iter) { |
3572 | if (Ors.empty()) |
3573 | break; |
3574 | const MachineInstr *Curr = Ors.pop_back_val(); |
3575 | Register OrLHS = Curr->getOperand(i: 1).getReg(); |
3576 | Register OrRHS = Curr->getOperand(i: 2).getReg(); |
3577 | |
3578 | // In the combine, we want to elimate the entire tree. |
3579 | if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS)) |
3580 | return std::nullopt; |
3581 | |
3582 | // If it's a G_OR, save it and continue to walk. If it's not, then it's |
3583 | // something that may be a load + arithmetic. |
3584 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI)) |
3585 | Ors.push_back(Elt: Or); |
3586 | else |
3587 | RegsToVisit.push_back(Elt: OrLHS); |
3588 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI)) |
3589 | Ors.push_back(Elt: Or); |
3590 | else |
3591 | RegsToVisit.push_back(Elt: OrRHS); |
3592 | } |
3593 | |
3594 | // We're going to try and merge each register into a wider power-of-2 type, |
3595 | // so we ought to have an even number of registers. |
3596 | if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0) |
3597 | return std::nullopt; |
3598 | return RegsToVisit; |
3599 | } |
3600 | |
3601 | /// Helper function for findLoadOffsetsForLoadOrCombine. |
3602 | /// |
3603 | /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value, |
3604 | /// and then moving that value into a specific byte offset. |
3605 | /// |
3606 | /// e.g. x[i] << 24 |
3607 | /// |
3608 | /// \returns The load instruction and the byte offset it is moved into. |
3609 | static std::optional<std::pair<GZExtLoad *, int64_t>> |
3610 | matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits, |
3611 | const MachineRegisterInfo &MRI) { |
3612 | assert(MRI.hasOneNonDBGUse(Reg) && |
3613 | "Expected Reg to only have one non-debug use?" ); |
3614 | Register MaybeLoad; |
3615 | int64_t Shift; |
3616 | if (!mi_match(R: Reg, MRI, |
3617 | P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) { |
3618 | Shift = 0; |
3619 | MaybeLoad = Reg; |
3620 | } |
3621 | |
3622 | if (Shift % MemSizeInBits != 0) |
3623 | return std::nullopt; |
3624 | |
3625 | // TODO: Handle other types of loads. |
3626 | auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI); |
3627 | if (!Load) |
3628 | return std::nullopt; |
3629 | |
3630 | if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits) |
3631 | return std::nullopt; |
3632 | |
3633 | return std::make_pair(x&: Load, y: Shift / MemSizeInBits); |
3634 | } |
3635 | |
3636 | std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>> |
3637 | CombinerHelper::findLoadOffsetsForLoadOrCombine( |
3638 | SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
3639 | const SmallVector<Register, 8> &RegsToVisit, const unsigned MemSizeInBits) { |
3640 | |
3641 | // Each load found for the pattern. There should be one for each RegsToVisit. |
3642 | SmallSetVector<const MachineInstr *, 8> Loads; |
3643 | |
3644 | // The lowest index used in any load. (The lowest "i" for each x[i].) |
3645 | int64_t LowestIdx = INT64_MAX; |
3646 | |
3647 | // The load which uses the lowest index. |
3648 | GZExtLoad *LowestIdxLoad = nullptr; |
3649 | |
3650 | // Keeps track of the load indices we see. We shouldn't see any indices twice. |
3651 | SmallSet<int64_t, 8> SeenIdx; |
3652 | |
3653 | // Ensure each load is in the same MBB. |
3654 | // TODO: Support multiple MachineBasicBlocks. |
3655 | MachineBasicBlock *MBB = nullptr; |
3656 | const MachineMemOperand *MMO = nullptr; |
3657 | |
3658 | // Earliest instruction-order load in the pattern. |
3659 | GZExtLoad *EarliestLoad = nullptr; |
3660 | |
3661 | // Latest instruction-order load in the pattern. |
3662 | GZExtLoad *LatestLoad = nullptr; |
3663 | |
3664 | // Base pointer which every load should share. |
3665 | Register BasePtr; |
3666 | |
3667 | // We want to find a load for each register. Each load should have some |
3668 | // appropriate bit twiddling arithmetic. During this loop, we will also keep |
3669 | // track of the load which uses the lowest index. Later, we will check if we |
3670 | // can use its pointer in the final, combined load. |
3671 | for (auto Reg : RegsToVisit) { |
3672 | // Find the load, and find the position that it will end up in (e.g. a |
3673 | // shifted) value. |
3674 | auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI); |
3675 | if (!LoadAndPos) |
3676 | return std::nullopt; |
3677 | GZExtLoad *Load; |
3678 | int64_t DstPos; |
3679 | std::tie(args&: Load, args&: DstPos) = *LoadAndPos; |
3680 | |
3681 | // TODO: Handle multiple MachineBasicBlocks. Currently not handled because |
3682 | // it is difficult to check for stores/calls/etc between loads. |
3683 | MachineBasicBlock *LoadMBB = Load->getParent(); |
3684 | if (!MBB) |
3685 | MBB = LoadMBB; |
3686 | if (LoadMBB != MBB) |
3687 | return std::nullopt; |
3688 | |
3689 | // Make sure that the MachineMemOperands of every seen load are compatible. |
3690 | auto &LoadMMO = Load->getMMO(); |
3691 | if (!MMO) |
3692 | MMO = &LoadMMO; |
3693 | if (MMO->getAddrSpace() != LoadMMO.getAddrSpace()) |
3694 | return std::nullopt; |
3695 | |
3696 | // Find out what the base pointer and index for the load is. |
3697 | Register LoadPtr; |
3698 | int64_t Idx; |
3699 | if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI, |
3700 | P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) { |
3701 | LoadPtr = Load->getOperand(i: 1).getReg(); |
3702 | Idx = 0; |
3703 | } |
3704 | |
3705 | // Don't combine things like a[i], a[i] -> a bigger load. |
3706 | if (!SeenIdx.insert(V: Idx).second) |
3707 | return std::nullopt; |
3708 | |
3709 | // Every load must share the same base pointer; don't combine things like: |
3710 | // |
3711 | // a[i], b[i + 1] -> a bigger load. |
3712 | if (!BasePtr.isValid()) |
3713 | BasePtr = LoadPtr; |
3714 | if (BasePtr != LoadPtr) |
3715 | return std::nullopt; |
3716 | |
3717 | if (Idx < LowestIdx) { |
3718 | LowestIdx = Idx; |
3719 | LowestIdxLoad = Load; |
3720 | } |
3721 | |
3722 | // Keep track of the byte offset that this load ends up at. If we have seen |
3723 | // the byte offset, then stop here. We do not want to combine: |
3724 | // |
3725 | // a[i] << 16, a[i + k] << 16 -> a bigger load. |
3726 | if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second) |
3727 | return std::nullopt; |
3728 | Loads.insert(X: Load); |
3729 | |
3730 | // Keep track of the position of the earliest/latest loads in the pattern. |
3731 | // We will check that there are no load fold barriers between them later |
3732 | // on. |
3733 | // |
3734 | // FIXME: Is there a better way to check for load fold barriers? |
3735 | if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad)) |
3736 | EarliestLoad = Load; |
3737 | if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load)) |
3738 | LatestLoad = Load; |
3739 | } |
3740 | |
3741 | // We found a load for each register. Let's check if each load satisfies the |
3742 | // pattern. |
3743 | assert(Loads.size() == RegsToVisit.size() && |
3744 | "Expected to find a load for each register?" ); |
3745 | assert(EarliestLoad != LatestLoad && EarliestLoad && |
3746 | LatestLoad && "Expected at least two loads?" ); |
3747 | |
3748 | // Check if there are any stores, calls, etc. between any of the loads. If |
3749 | // there are, then we can't safely perform the combine. |
3750 | // |
3751 | // MaxIter is chosen based off the (worst case) number of iterations it |
3752 | // typically takes to succeed in the LLVM test suite plus some padding. |
3753 | // |
3754 | // FIXME: Is there a better way to check for load fold barriers? |
3755 | const unsigned MaxIter = 20; |
3756 | unsigned Iter = 0; |
3757 | for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(), |
3758 | End: LatestLoad->getIterator())) { |
3759 | if (Loads.count(key: &MI)) |
3760 | continue; |
3761 | if (MI.isLoadFoldBarrier()) |
3762 | return std::nullopt; |
3763 | if (Iter++ == MaxIter) |
3764 | return std::nullopt; |
3765 | } |
3766 | |
3767 | return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad); |
3768 | } |
3769 | |
3770 | bool CombinerHelper::matchLoadOrCombine( |
3771 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
3772 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3773 | MachineFunction &MF = *MI.getMF(); |
3774 | // Assuming a little-endian target, transform: |
3775 | // s8 *a = ... |
3776 | // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) |
3777 | // => |
3778 | // s32 val = *((i32)a) |
3779 | // |
3780 | // s8 *a = ... |
3781 | // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] |
3782 | // => |
3783 | // s32 val = BSWAP(*((s32)a)) |
3784 | Register Dst = MI.getOperand(i: 0).getReg(); |
3785 | LLT Ty = MRI.getType(Reg: Dst); |
3786 | if (Ty.isVector()) |
3787 | return false; |
3788 | |
3789 | // We need to combine at least two loads into this type. Since the smallest |
3790 | // possible load is into a byte, we need at least a 16-bit wide type. |
3791 | const unsigned WideMemSizeInBits = Ty.getSizeInBits(); |
3792 | if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0) |
3793 | return false; |
3794 | |
3795 | // Match a collection of non-OR instructions in the pattern. |
3796 | auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI); |
3797 | if (!RegsToVisit) |
3798 | return false; |
3799 | |
3800 | // We have a collection of non-OR instructions. Figure out how wide each of |
3801 | // the small loads should be based off of the number of potential loads we |
3802 | // found. |
3803 | const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size(); |
3804 | if (NarrowMemSizeInBits % 8 != 0) |
3805 | return false; |
3806 | |
3807 | // Check if each register feeding into each OR is a load from the same |
3808 | // base pointer + some arithmetic. |
3809 | // |
3810 | // e.g. a[0], a[1] << 8, a[2] << 16, etc. |
3811 | // |
3812 | // Also verify that each of these ends up putting a[i] into the same memory |
3813 | // offset as a load into a wide type would. |
3814 | SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx; |
3815 | GZExtLoad *LowestIdxLoad, *LatestLoad; |
3816 | int64_t LowestIdx; |
3817 | auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine( |
3818 | MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits); |
3819 | if (!MaybeLoadInfo) |
3820 | return false; |
3821 | std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo; |
3822 | |
3823 | // We have a bunch of loads being OR'd together. Using the addresses + offsets |
3824 | // we found before, check if this corresponds to a big or little endian byte |
3825 | // pattern. If it does, then we can represent it using a load + possibly a |
3826 | // BSWAP. |
3827 | bool IsBigEndianTarget = MF.getDataLayout().isBigEndian(); |
3828 | std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx); |
3829 | if (!IsBigEndian) |
3830 | return false; |
3831 | bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian; |
3832 | if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}})) |
3833 | return false; |
3834 | |
3835 | // Make sure that the load from the lowest index produces offset 0 in the |
3836 | // final value. |
3837 | // |
3838 | // This ensures that we won't combine something like this: |
3839 | // |
3840 | // load x[i] -> byte 2 |
3841 | // load x[i+1] -> byte 0 ---> wide_load x[i] |
3842 | // load x[i+2] -> byte 1 |
3843 | const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits; |
3844 | const unsigned ZeroByteOffset = |
3845 | *IsBigEndian |
3846 | ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0) |
3847 | : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0); |
3848 | auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset); |
3849 | if (ZeroOffsetIdx == MemOffset2Idx.end() || |
3850 | ZeroOffsetIdx->second != LowestIdx) |
3851 | return false; |
3852 | |
3853 | // We wil reuse the pointer from the load which ends up at byte offset 0. It |
3854 | // may not use index 0. |
3855 | Register Ptr = LowestIdxLoad->getPointerReg(); |
3856 | const MachineMemOperand &MMO = LowestIdxLoad->getMMO(); |
3857 | LegalityQuery::MemDesc MMDesc(MMO); |
3858 | MMDesc.MemoryTy = Ty; |
3859 | if (!isLegalOrBeforeLegalizer( |
3860 | Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}})) |
3861 | return false; |
3862 | auto PtrInfo = MMO.getPointerInfo(); |
3863 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8); |
3864 | |
3865 | // Load must be allowed and fast on the target. |
3866 | LLVMContext &C = MF.getFunction().getContext(); |
3867 | auto &DL = MF.getDataLayout(); |
3868 | unsigned Fast = 0; |
3869 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) || |
3870 | !Fast) |
3871 | return false; |
3872 | |
3873 | MatchInfo = [=](MachineIRBuilder &MIB) { |
3874 | MIB.setInstrAndDebugLoc(*LatestLoad); |
3875 | Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst; |
3876 | MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO); |
3877 | if (NeedsBSwap) |
3878 | MIB.buildBSwap(Dst, Src0: LoadDst); |
3879 | }; |
3880 | return true; |
3881 | } |
3882 | |
3883 | bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI, |
3884 | MachineInstr *&ExtMI) { |
3885 | auto &PHI = cast<GPhi>(Val&: MI); |
3886 | Register DstReg = PHI.getReg(Idx: 0); |
3887 | |
3888 | // TODO: Extending a vector may be expensive, don't do this until heuristics |
3889 | // are better. |
3890 | if (MRI.getType(Reg: DstReg).isVector()) |
3891 | return false; |
3892 | |
3893 | // Try to match a phi, whose only use is an extend. |
3894 | if (!MRI.hasOneNonDBGUse(RegNo: DstReg)) |
3895 | return false; |
3896 | ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg); |
3897 | switch (ExtMI->getOpcode()) { |
3898 | case TargetOpcode::G_ANYEXT: |
3899 | return true; // G_ANYEXT is usually free. |
3900 | case TargetOpcode::G_ZEXT: |
3901 | case TargetOpcode::G_SEXT: |
3902 | break; |
3903 | default: |
3904 | return false; |
3905 | } |
3906 | |
3907 | // If the target is likely to fold this extend away, don't propagate. |
3908 | if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI)) |
3909 | return false; |
3910 | |
3911 | // We don't want to propagate the extends unless there's a good chance that |
3912 | // they'll be optimized in some way. |
3913 | // Collect the unique incoming values. |
3914 | SmallPtrSet<MachineInstr *, 4> InSrcs; |
3915 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
3916 | auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI); |
3917 | switch (DefMI->getOpcode()) { |
3918 | case TargetOpcode::G_LOAD: |
3919 | case TargetOpcode::G_TRUNC: |
3920 | case TargetOpcode::G_SEXT: |
3921 | case TargetOpcode::G_ZEXT: |
3922 | case TargetOpcode::G_ANYEXT: |
3923 | case TargetOpcode::G_CONSTANT: |
3924 | InSrcs.insert(Ptr: DefMI); |
3925 | // Don't try to propagate if there are too many places to create new |
3926 | // extends, chances are it'll increase code size. |
3927 | if (InSrcs.size() > 2) |
3928 | return false; |
3929 | break; |
3930 | default: |
3931 | return false; |
3932 | } |
3933 | } |
3934 | return true; |
3935 | } |
3936 | |
3937 | void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI, |
3938 | MachineInstr *&ExtMI) { |
3939 | auto &PHI = cast<GPhi>(Val&: MI); |
3940 | Register DstReg = ExtMI->getOperand(i: 0).getReg(); |
3941 | LLT ExtTy = MRI.getType(Reg: DstReg); |
3942 | |
3943 | // Propagate the extension into the block of each incoming reg's block. |
3944 | // Use a SetVector here because PHIs can have duplicate edges, and we want |
3945 | // deterministic iteration order. |
3946 | SmallSetVector<MachineInstr *, 8> SrcMIs; |
3947 | SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap; |
3948 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
3949 | auto SrcReg = PHI.getIncomingValue(I); |
3950 | auto *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
3951 | if (!SrcMIs.insert(X: SrcMI)) |
3952 | continue; |
3953 | |
3954 | // Build an extend after each src inst. |
3955 | auto *MBB = SrcMI->getParent(); |
3956 | MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator(); |
3957 | if (InsertPt != MBB->end() && InsertPt->isPHI()) |
3958 | InsertPt = MBB->getFirstNonPHI(); |
3959 | |
3960 | Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt); |
3961 | Builder.setDebugLoc(MI.getDebugLoc()); |
3962 | auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg); |
3963 | OldToNewSrcMap[SrcMI] = NewExt; |
3964 | } |
3965 | |
3966 | // Create a new phi with the extended inputs. |
3967 | Builder.setInstrAndDebugLoc(MI); |
3968 | auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI); |
3969 | NewPhi.addDef(RegNo: DstReg); |
3970 | for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) { |
3971 | if (!MO.isReg()) { |
3972 | NewPhi.addMBB(MBB: MO.getMBB()); |
3973 | continue; |
3974 | } |
3975 | auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())]; |
3976 | NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg()); |
3977 | } |
3978 | Builder.insertInstr(MIB: NewPhi); |
3979 | ExtMI->eraseFromParent(); |
3980 | } |
3981 | |
3982 | bool CombinerHelper::(MachineInstr &MI, |
3983 | Register &Reg) { |
3984 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
3985 | // If we have a constant index, look for a G_BUILD_VECTOR source |
3986 | // and find the source register that the index maps to. |
3987 | Register SrcVec = MI.getOperand(i: 1).getReg(); |
3988 | LLT SrcTy = MRI.getType(Reg: SrcVec); |
3989 | |
3990 | auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
3991 | if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements()) |
3992 | return false; |
3993 | |
3994 | unsigned VecIdx = Cst->Value.getZExtValue(); |
3995 | |
3996 | // Check if we have a build_vector or build_vector_trunc with an optional |
3997 | // trunc in front. |
3998 | MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec); |
3999 | if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) { |
4000 | SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg()); |
4001 | } |
4002 | |
4003 | if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR && |
4004 | SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC) |
4005 | return false; |
4006 | |
4007 | EVT Ty(getMVTForLLT(Ty: SrcTy)); |
4008 | if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) && |
4009 | !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty)) |
4010 | return false; |
4011 | |
4012 | Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg(); |
4013 | return true; |
4014 | } |
4015 | |
4016 | void CombinerHelper::(MachineInstr &MI, |
4017 | Register &Reg) { |
4018 | // Check the type of the register, since it may have come from a |
4019 | // G_BUILD_VECTOR_TRUNC. |
4020 | LLT ScalarTy = MRI.getType(Reg); |
4021 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4022 | LLT DstTy = MRI.getType(Reg: DstReg); |
4023 | |
4024 | Builder.setInstrAndDebugLoc(MI); |
4025 | if (ScalarTy != DstTy) { |
4026 | assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits()); |
4027 | Builder.buildTrunc(Res: DstReg, Op: Reg); |
4028 | MI.eraseFromParent(); |
4029 | return; |
4030 | } |
4031 | replaceSingleDefInstWithReg(MI, Replacement: Reg); |
4032 | } |
4033 | |
4034 | bool CombinerHelper::( |
4035 | MachineInstr &MI, |
4036 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) { |
4037 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4038 | // This combine tries to find build_vector's which have every source element |
4039 | // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like |
4040 | // the masked load scalarization is run late in the pipeline. There's already |
4041 | // a combine for a similar pattern starting from the extract, but that |
4042 | // doesn't attempt to do it if there are multiple uses of the build_vector, |
4043 | // which in this case is true. Starting the combine from the build_vector |
4044 | // feels more natural than trying to find sibling nodes of extracts. |
4045 | // E.g. |
4046 | // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4 |
4047 | // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0 |
4048 | // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1 |
4049 | // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2 |
4050 | // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3 |
4051 | // ==> |
4052 | // replace ext{1,2,3,4} with %s{1,2,3,4} |
4053 | |
4054 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4055 | LLT DstTy = MRI.getType(Reg: DstReg); |
4056 | unsigned NumElts = DstTy.getNumElements(); |
4057 | |
4058 | SmallBitVector (NumElts); |
4059 | for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) { |
4060 | if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT) |
4061 | return false; |
4062 | auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI); |
4063 | if (!Cst) |
4064 | return false; |
4065 | unsigned Idx = Cst->getZExtValue(); |
4066 | if (Idx >= NumElts) |
4067 | return false; // Out of range. |
4068 | ExtractedElts.set(Idx); |
4069 | SrcDstPairs.emplace_back( |
4070 | Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II)); |
4071 | } |
4072 | // Match if every element was extracted. |
4073 | return ExtractedElts.all(); |
4074 | } |
4075 | |
4076 | void CombinerHelper::( |
4077 | MachineInstr &MI, |
4078 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) { |
4079 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4080 | for (auto &Pair : SrcDstPairs) { |
4081 | auto *ExtMI = Pair.second; |
4082 | replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first); |
4083 | ExtMI->eraseFromParent(); |
4084 | } |
4085 | MI.eraseFromParent(); |
4086 | } |
4087 | |
4088 | void CombinerHelper::applyBuildFn( |
4089 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4090 | Builder.setInstrAndDebugLoc(MI); |
4091 | MatchInfo(Builder); |
4092 | MI.eraseFromParent(); |
4093 | } |
4094 | |
4095 | void CombinerHelper::applyBuildFnNoErase( |
4096 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4097 | Builder.setInstrAndDebugLoc(MI); |
4098 | MatchInfo(Builder); |
4099 | } |
4100 | |
4101 | bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI, |
4102 | BuildFnTy &MatchInfo) { |
4103 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
4104 | |
4105 | Register Dst = MI.getOperand(i: 0).getReg(); |
4106 | LLT Ty = MRI.getType(Reg: Dst); |
4107 | unsigned BitWidth = Ty.getScalarSizeInBits(); |
4108 | |
4109 | Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt; |
4110 | unsigned FshOpc = 0; |
4111 | |
4112 | // Match (or (shl ...), (lshr ...)). |
4113 | if (!mi_match(R: Dst, MRI, |
4114 | // m_GOr() handles the commuted version as well. |
4115 | P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)), |
4116 | R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt))))) |
4117 | return false; |
4118 | |
4119 | // Given constants C0 and C1 such that C0 + C1 is bit-width: |
4120 | // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1) |
4121 | int64_t CstShlAmt, CstLShrAmt; |
4122 | if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) && |
4123 | mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) && |
4124 | CstShlAmt + CstLShrAmt == BitWidth) { |
4125 | FshOpc = TargetOpcode::G_FSHR; |
4126 | Amt = LShrAmt; |
4127 | |
4128 | } else if (mi_match(R: LShrAmt, MRI, |
4129 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4130 | ShlAmt == Amt) { |
4131 | // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt) |
4132 | FshOpc = TargetOpcode::G_FSHL; |
4133 | |
4134 | } else if (mi_match(R: ShlAmt, MRI, |
4135 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4136 | LShrAmt == Amt) { |
4137 | // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt) |
4138 | FshOpc = TargetOpcode::G_FSHR; |
4139 | |
4140 | } else { |
4141 | return false; |
4142 | } |
4143 | |
4144 | LLT AmtTy = MRI.getType(Reg: Amt); |
4145 | if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}})) |
4146 | return false; |
4147 | |
4148 | MatchInfo = [=](MachineIRBuilder &B) { |
4149 | B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt}); |
4150 | }; |
4151 | return true; |
4152 | } |
4153 | |
4154 | /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate. |
4155 | bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) { |
4156 | unsigned Opc = MI.getOpcode(); |
4157 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4158 | Register X = MI.getOperand(i: 1).getReg(); |
4159 | Register Y = MI.getOperand(i: 2).getReg(); |
4160 | if (X != Y) |
4161 | return false; |
4162 | unsigned RotateOpc = |
4163 | Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR; |
4164 | return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}}); |
4165 | } |
4166 | |
4167 | void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) { |
4168 | unsigned Opc = MI.getOpcode(); |
4169 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4170 | bool IsFSHL = Opc == TargetOpcode::G_FSHL; |
4171 | Observer.changingInstr(MI); |
4172 | MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL |
4173 | : TargetOpcode::G_ROTR)); |
4174 | MI.removeOperand(OpNo: 2); |
4175 | Observer.changedInstr(MI); |
4176 | } |
4177 | |
4178 | // Fold (rot x, c) -> (rot x, c % BitSize) |
4179 | bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) { |
4180 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4181 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4182 | unsigned Bitsize = |
4183 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4184 | Register AmtReg = MI.getOperand(i: 2).getReg(); |
4185 | bool OutOfRange = false; |
4186 | auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) { |
4187 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
4188 | OutOfRange |= CI->getValue().uge(RHS: Bitsize); |
4189 | return true; |
4190 | }; |
4191 | return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange; |
4192 | } |
4193 | |
4194 | void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) { |
4195 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4196 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4197 | unsigned Bitsize = |
4198 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4199 | Builder.setInstrAndDebugLoc(MI); |
4200 | Register Amt = MI.getOperand(i: 2).getReg(); |
4201 | LLT AmtTy = MRI.getType(Reg: Amt); |
4202 | auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize); |
4203 | Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0); |
4204 | Observer.changingInstr(MI); |
4205 | MI.getOperand(i: 2).setReg(Amt); |
4206 | Observer.changedInstr(MI); |
4207 | } |
4208 | |
4209 | bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, |
4210 | int64_t &MatchInfo) { |
4211 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4212 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4213 | auto KnownLHS = KB->getKnownBits(R: MI.getOperand(i: 2).getReg()); |
4214 | auto KnownRHS = KB->getKnownBits(R: MI.getOperand(i: 3).getReg()); |
4215 | std::optional<bool> KnownVal; |
4216 | switch (Pred) { |
4217 | default: |
4218 | llvm_unreachable("Unexpected G_ICMP predicate?" ); |
4219 | case CmpInst::ICMP_EQ: |
4220 | KnownVal = KnownBits::eq(LHS: KnownLHS, RHS: KnownRHS); |
4221 | break; |
4222 | case CmpInst::ICMP_NE: |
4223 | KnownVal = KnownBits::ne(LHS: KnownLHS, RHS: KnownRHS); |
4224 | break; |
4225 | case CmpInst::ICMP_SGE: |
4226 | KnownVal = KnownBits::sge(LHS: KnownLHS, RHS: KnownRHS); |
4227 | break; |
4228 | case CmpInst::ICMP_SGT: |
4229 | KnownVal = KnownBits::sgt(LHS: KnownLHS, RHS: KnownRHS); |
4230 | break; |
4231 | case CmpInst::ICMP_SLE: |
4232 | KnownVal = KnownBits::sle(LHS: KnownLHS, RHS: KnownRHS); |
4233 | break; |
4234 | case CmpInst::ICMP_SLT: |
4235 | KnownVal = KnownBits::slt(LHS: KnownLHS, RHS: KnownRHS); |
4236 | break; |
4237 | case CmpInst::ICMP_UGE: |
4238 | KnownVal = KnownBits::uge(LHS: KnownLHS, RHS: KnownRHS); |
4239 | break; |
4240 | case CmpInst::ICMP_UGT: |
4241 | KnownVal = KnownBits::ugt(LHS: KnownLHS, RHS: KnownRHS); |
4242 | break; |
4243 | case CmpInst::ICMP_ULE: |
4244 | KnownVal = KnownBits::ule(LHS: KnownLHS, RHS: KnownRHS); |
4245 | break; |
4246 | case CmpInst::ICMP_ULT: |
4247 | KnownVal = KnownBits::ult(LHS: KnownLHS, RHS: KnownRHS); |
4248 | break; |
4249 | } |
4250 | if (!KnownVal) |
4251 | return false; |
4252 | MatchInfo = |
4253 | *KnownVal |
4254 | ? getICmpTrueVal(TLI: getTargetLowering(), |
4255 | /*IsVector = */ |
4256 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(), |
4257 | /* IsFP = */ false) |
4258 | : 0; |
4259 | return true; |
4260 | } |
4261 | |
4262 | bool CombinerHelper::matchICmpToLHSKnownBits( |
4263 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4264 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4265 | // Given: |
4266 | // |
4267 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4268 | // %cmp = G_ICMP ne %x, 0 |
4269 | // |
4270 | // Or: |
4271 | // |
4272 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4273 | // %cmp = G_ICMP eq %x, 1 |
4274 | // |
4275 | // We can replace %cmp with %x assuming true is 1 on the target. |
4276 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4277 | if (!CmpInst::isEquality(pred: Pred)) |
4278 | return false; |
4279 | Register Dst = MI.getOperand(i: 0).getReg(); |
4280 | LLT DstTy = MRI.getType(Reg: Dst); |
4281 | if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(), |
4282 | /* IsFP = */ false) != 1) |
4283 | return false; |
4284 | int64_t OneOrZero = Pred == CmpInst::ICMP_EQ; |
4285 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero))) |
4286 | return false; |
4287 | Register LHS = MI.getOperand(i: 2).getReg(); |
4288 | auto KnownLHS = KB->getKnownBits(R: LHS); |
4289 | if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1) |
4290 | return false; |
4291 | // Make sure replacing Dst with the LHS is a legal operation. |
4292 | LLT LHSTy = MRI.getType(Reg: LHS); |
4293 | unsigned LHSSize = LHSTy.getSizeInBits(); |
4294 | unsigned DstSize = DstTy.getSizeInBits(); |
4295 | unsigned Op = TargetOpcode::COPY; |
4296 | if (DstSize != LHSSize) |
4297 | Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT; |
4298 | if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}})) |
4299 | return false; |
4300 | MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); }; |
4301 | return true; |
4302 | } |
4303 | |
4304 | // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0 |
4305 | bool CombinerHelper::matchAndOrDisjointMask( |
4306 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4307 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4308 | |
4309 | // Ignore vector types to simplify matching the two constants. |
4310 | // TODO: do this for vectors and scalars via a demanded bits analysis. |
4311 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4312 | if (Ty.isVector()) |
4313 | return false; |
4314 | |
4315 | Register Src; |
4316 | Register AndMaskReg; |
4317 | int64_t AndMaskBits; |
4318 | int64_t OrMaskBits; |
4319 | if (!mi_match(MI, MRI, |
4320 | P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)), |
4321 | R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg))))) |
4322 | return false; |
4323 | |
4324 | // Check if OrMask could turn on any bits in Src. |
4325 | if (AndMaskBits & OrMaskBits) |
4326 | return false; |
4327 | |
4328 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4329 | Observer.changingInstr(MI); |
4330 | // Canonicalize the result to have the constant on the RHS. |
4331 | if (MI.getOperand(i: 1).getReg() == AndMaskReg) |
4332 | MI.getOperand(i: 2).setReg(AndMaskReg); |
4333 | MI.getOperand(i: 1).setReg(Src); |
4334 | Observer.changedInstr(MI); |
4335 | }; |
4336 | return true; |
4337 | } |
4338 | |
4339 | /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift. |
4340 | bool CombinerHelper::( |
4341 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4342 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
4343 | Register Dst = MI.getOperand(i: 0).getReg(); |
4344 | Register Src = MI.getOperand(i: 1).getReg(); |
4345 | LLT Ty = MRI.getType(Reg: Src); |
4346 | LLT = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
4347 | if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}})) |
4348 | return false; |
4349 | int64_t Width = MI.getOperand(i: 2).getImm(); |
4350 | Register ShiftSrc; |
4351 | int64_t ShiftImm; |
4352 | if (!mi_match( |
4353 | R: Src, MRI, |
4354 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)), |
4355 | preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)))))) |
4356 | return false; |
4357 | if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits()) |
4358 | return false; |
4359 | |
4360 | MatchInfo = [=](MachineIRBuilder &B) { |
4361 | auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm); |
4362 | auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width); |
4363 | B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2); |
4364 | }; |
4365 | return true; |
4366 | } |
4367 | |
4368 | /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants. |
4369 | bool CombinerHelper::matchBitfieldExtractFromAnd( |
4370 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4371 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4372 | Register Dst = MI.getOperand(i: 0).getReg(); |
4373 | LLT Ty = MRI.getType(Reg: Dst); |
4374 | LLT = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
4375 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4376 | return false; |
4377 | |
4378 | int64_t AndImm, LSBImm; |
4379 | Register ShiftSrc; |
4380 | const unsigned Size = Ty.getScalarSizeInBits(); |
4381 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
4382 | P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))), |
4383 | R: m_ICst(Cst&: AndImm)))) |
4384 | return false; |
4385 | |
4386 | // The mask is a mask of the low bits iff imm & (imm+1) == 0. |
4387 | auto MaybeMask = static_cast<uint64_t>(AndImm); |
4388 | if (MaybeMask & (MaybeMask + 1)) |
4389 | return false; |
4390 | |
4391 | // LSB must fit within the register. |
4392 | if (static_cast<uint64_t>(LSBImm) >= Size) |
4393 | return false; |
4394 | |
4395 | uint64_t Width = APInt(Size, AndImm).countr_one(); |
4396 | MatchInfo = [=](MachineIRBuilder &B) { |
4397 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4398 | auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm); |
4399 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst}); |
4400 | }; |
4401 | return true; |
4402 | } |
4403 | |
4404 | bool CombinerHelper::( |
4405 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4406 | const unsigned Opcode = MI.getOpcode(); |
4407 | assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR); |
4408 | |
4409 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4410 | |
4411 | const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR |
4412 | ? TargetOpcode::G_SBFX |
4413 | : TargetOpcode::G_UBFX; |
4414 | |
4415 | // Check if the type we would use for the extract is legal |
4416 | LLT Ty = MRI.getType(Reg: Dst); |
4417 | LLT = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
4418 | if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}})) |
4419 | return false; |
4420 | |
4421 | Register ShlSrc; |
4422 | int64_t ShrAmt; |
4423 | int64_t ShlAmt; |
4424 | const unsigned Size = Ty.getScalarSizeInBits(); |
4425 | |
4426 | // Try to match shr (shl x, c1), c2 |
4427 | if (!mi_match(R: Dst, MRI, |
4428 | P: m_BinOp(Opcode, |
4429 | L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))), |
4430 | R: m_ICst(Cst&: ShrAmt)))) |
4431 | return false; |
4432 | |
4433 | // Make sure that the shift sizes can fit a bitfield extract |
4434 | if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size) |
4435 | return false; |
4436 | |
4437 | // Skip this combine if the G_SEXT_INREG combine could handle it |
4438 | if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt) |
4439 | return false; |
4440 | |
4441 | // Calculate start position and width of the extract |
4442 | const int64_t Pos = ShrAmt - ShlAmt; |
4443 | const int64_t Width = Size - ShrAmt; |
4444 | |
4445 | MatchInfo = [=](MachineIRBuilder &B) { |
4446 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4447 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4448 | B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst}); |
4449 | }; |
4450 | return true; |
4451 | } |
4452 | |
4453 | bool CombinerHelper::matchBitfieldExtractFromShrAnd( |
4454 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4455 | const unsigned Opcode = MI.getOpcode(); |
4456 | assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR); |
4457 | |
4458 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4459 | LLT Ty = MRI.getType(Reg: Dst); |
4460 | LLT = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
4461 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4462 | return false; |
4463 | |
4464 | // Try to match shr (and x, c1), c2 |
4465 | Register AndSrc; |
4466 | int64_t ShrAmt; |
4467 | int64_t SMask; |
4468 | if (!mi_match(R: Dst, MRI, |
4469 | P: m_BinOp(Opcode, |
4470 | L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))), |
4471 | R: m_ICst(Cst&: ShrAmt)))) |
4472 | return false; |
4473 | |
4474 | const unsigned Size = Ty.getScalarSizeInBits(); |
4475 | if (ShrAmt < 0 || ShrAmt >= Size) |
4476 | return false; |
4477 | |
4478 | // If the shift subsumes the mask, emit the 0 directly. |
4479 | if (0 == (SMask >> ShrAmt)) { |
4480 | MatchInfo = [=](MachineIRBuilder &B) { |
4481 | B.buildConstant(Res: Dst, Val: 0); |
4482 | }; |
4483 | return true; |
4484 | } |
4485 | |
4486 | // Check that ubfx can do the extraction, with no holes in the mask. |
4487 | uint64_t UMask = SMask; |
4488 | UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt); |
4489 | UMask &= maskTrailingOnes<uint64_t>(N: Size); |
4490 | if (!isMask_64(Value: UMask)) |
4491 | return false; |
4492 | |
4493 | // Calculate start position and width of the extract. |
4494 | const int64_t Pos = ShrAmt; |
4495 | const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt; |
4496 | |
4497 | // It's preferable to keep the shift, rather than form G_SBFX. |
4498 | // TODO: remove the G_AND via demanded bits analysis. |
4499 | if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size) |
4500 | return false; |
4501 | |
4502 | MatchInfo = [=](MachineIRBuilder &B) { |
4503 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4504 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4505 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst}); |
4506 | }; |
4507 | return true; |
4508 | } |
4509 | |
4510 | bool CombinerHelper::reassociationCanBreakAddressingModePattern( |
4511 | MachineInstr &MI) { |
4512 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4513 | |
4514 | Register Src1Reg = PtrAdd.getBaseReg(); |
4515 | auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI); |
4516 | if (!Src1Def) |
4517 | return false; |
4518 | |
4519 | Register Src2Reg = PtrAdd.getOffsetReg(); |
4520 | |
4521 | if (MRI.hasOneNonDBGUse(RegNo: Src1Reg)) |
4522 | return false; |
4523 | |
4524 | auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI); |
4525 | if (!C1) |
4526 | return false; |
4527 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4528 | if (!C2) |
4529 | return false; |
4530 | |
4531 | const APInt &C1APIntVal = *C1; |
4532 | const APInt &C2APIntVal = *C2; |
4533 | const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue(); |
4534 | |
4535 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) { |
4536 | // This combine may end up running before ptrtoint/inttoptr combines |
4537 | // manage to eliminate redundant conversions, so try to look through them. |
4538 | MachineInstr *ConvUseMI = &UseMI; |
4539 | unsigned ConvUseOpc = ConvUseMI->getOpcode(); |
4540 | while (ConvUseOpc == TargetOpcode::G_INTTOPTR || |
4541 | ConvUseOpc == TargetOpcode::G_PTRTOINT) { |
4542 | Register DefReg = ConvUseMI->getOperand(i: 0).getReg(); |
4543 | if (!MRI.hasOneNonDBGUse(RegNo: DefReg)) |
4544 | break; |
4545 | ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg); |
4546 | ConvUseOpc = ConvUseMI->getOpcode(); |
4547 | } |
4548 | auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI); |
4549 | if (!LdStMI) |
4550 | continue; |
4551 | // Is x[offset2] already not a legal addressing mode? If so then |
4552 | // reassociating the constants breaks nothing (we test offset2 because |
4553 | // that's the one we hope to fold into the load or store). |
4554 | TargetLoweringBase::AddrMode AM; |
4555 | AM.HasBaseReg = true; |
4556 | AM.BaseOffs = C2APIntVal.getSExtValue(); |
4557 | unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace(); |
4558 | Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(), |
4559 | C&: PtrAdd.getMF()->getFunction().getContext()); |
4560 | const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering(); |
4561 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4562 | Ty: AccessTy, AddrSpace: AS)) |
4563 | continue; |
4564 | |
4565 | // Would x[offset1+offset2] still be a legal addressing mode? |
4566 | AM.BaseOffs = CombinedValue; |
4567 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4568 | Ty: AccessTy, AddrSpace: AS)) |
4569 | return true; |
4570 | } |
4571 | |
4572 | return false; |
4573 | } |
4574 | |
4575 | bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI, |
4576 | MachineInstr *RHS, |
4577 | BuildFnTy &MatchInfo) { |
4578 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4579 | Register Src1Reg = MI.getOperand(i: 1).getReg(); |
4580 | if (RHS->getOpcode() != TargetOpcode::G_ADD) |
4581 | return false; |
4582 | auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI); |
4583 | if (!C2) |
4584 | return false; |
4585 | |
4586 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4587 | LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4588 | |
4589 | auto NewBase = |
4590 | Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg()); |
4591 | Observer.changingInstr(MI); |
4592 | MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0)); |
4593 | MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg()); |
4594 | Observer.changedInstr(MI); |
4595 | }; |
4596 | return !reassociationCanBreakAddressingModePattern(MI); |
4597 | } |
4598 | |
4599 | bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI, |
4600 | MachineInstr *LHS, |
4601 | MachineInstr *RHS, |
4602 | BuildFnTy &MatchInfo) { |
4603 | // G_PTR_ADD (G_PTR_ADD X, C), Y) -> (G_PTR_ADD (G_PTR_ADD(X, Y), C) |
4604 | // if and only if (G_PTR_ADD X, C) has one use. |
4605 | Register LHSBase; |
4606 | std::optional<ValueAndVReg> LHSCstOff; |
4607 | if (!mi_match(R: MI.getBaseReg(), MRI, |
4608 | P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff))))) |
4609 | return false; |
4610 | |
4611 | auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS); |
4612 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4613 | // When we change LHSPtrAdd's offset register we might cause it to use a reg |
4614 | // before its def. Sink the instruction so the outer PTR_ADD to ensure this |
4615 | // doesn't happen. |
4616 | LHSPtrAdd->moveBefore(MovePos: &MI); |
4617 | Register RHSReg = MI.getOffsetReg(); |
4618 | // set VReg will cause type mismatch if it comes from extend/trunc |
4619 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value); |
4620 | Observer.changingInstr(MI); |
4621 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4622 | Observer.changedInstr(MI); |
4623 | Observer.changingInstr(MI&: *LHSPtrAdd); |
4624 | LHSPtrAdd->getOperand(i: 2).setReg(RHSReg); |
4625 | Observer.changedInstr(MI&: *LHSPtrAdd); |
4626 | }; |
4627 | return !reassociationCanBreakAddressingModePattern(MI); |
4628 | } |
4629 | |
4630 | bool CombinerHelper::matchReassocFoldConstantsInSubTree(GPtrAdd &MI, |
4631 | MachineInstr *LHS, |
4632 | MachineInstr *RHS, |
4633 | BuildFnTy &MatchInfo) { |
4634 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4635 | auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS); |
4636 | if (!LHSPtrAdd) |
4637 | return false; |
4638 | |
4639 | Register Src2Reg = MI.getOperand(i: 2).getReg(); |
4640 | Register LHSSrc1 = LHSPtrAdd->getBaseReg(); |
4641 | Register LHSSrc2 = LHSPtrAdd->getOffsetReg(); |
4642 | auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI); |
4643 | if (!C1) |
4644 | return false; |
4645 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4646 | if (!C2) |
4647 | return false; |
4648 | |
4649 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4650 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2); |
4651 | Observer.changingInstr(MI); |
4652 | MI.getOperand(i: 1).setReg(LHSSrc1); |
4653 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4654 | Observer.changedInstr(MI); |
4655 | }; |
4656 | return !reassociationCanBreakAddressingModePattern(MI); |
4657 | } |
4658 | |
4659 | bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI, |
4660 | BuildFnTy &MatchInfo) { |
4661 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4662 | // We're trying to match a few pointer computation patterns here for |
4663 | // re-association opportunities. |
4664 | // 1) Isolating a constant operand to be on the RHS, e.g.: |
4665 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4666 | // |
4667 | // 2) Folding two constants in each sub-tree as long as such folding |
4668 | // doesn't break a legal addressing mode. |
4669 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4670 | // |
4671 | // 3) Move a constant from the LHS of an inner op to the RHS of the outer. |
4672 | // G_PTR_ADD (G_PTR_ADD X, C), Y) -> G_PTR_ADD (G_PTR_ADD(X, Y), C) |
4673 | // iif (G_PTR_ADD X, C) has one use. |
4674 | MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
4675 | MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg()); |
4676 | |
4677 | // Try to match example 2. |
4678 | if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4679 | return true; |
4680 | |
4681 | // Try to match example 3. |
4682 | if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4683 | return true; |
4684 | |
4685 | // Try to match example 1. |
4686 | if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo)) |
4687 | return true; |
4688 | |
4689 | return false; |
4690 | } |
4691 | bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg, |
4692 | Register OpLHS, Register OpRHS, |
4693 | BuildFnTy &MatchInfo) { |
4694 | LLT OpRHSTy = MRI.getType(Reg: OpRHS); |
4695 | MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS); |
4696 | |
4697 | if (OpLHSDef->getOpcode() != Opc) |
4698 | return false; |
4699 | |
4700 | MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS); |
4701 | Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg(); |
4702 | Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg(); |
4703 | |
4704 | // If the inner op is (X op C), pull the constant out so it can be folded with |
4705 | // other constants in the expression tree. Folding is not guaranteed so we |
4706 | // might have (C1 op C2). In that case do not pull a constant out because it |
4707 | // won't help and can lead to infinite loops. |
4708 | if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) && |
4709 | !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) { |
4710 | if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) { |
4711 | // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2)) |
4712 | MatchInfo = [=](MachineIRBuilder &B) { |
4713 | auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS}); |
4714 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst}); |
4715 | }; |
4716 | return true; |
4717 | } |
4718 | if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) { |
4719 | // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) |
4720 | // iff (op x, c1) has one use |
4721 | MatchInfo = [=](MachineIRBuilder &B) { |
4722 | auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS}); |
4723 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS}); |
4724 | }; |
4725 | return true; |
4726 | } |
4727 | } |
4728 | |
4729 | return false; |
4730 | } |
4731 | |
4732 | bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI, |
4733 | BuildFnTy &MatchInfo) { |
4734 | // We don't check if the reassociation will break a legal addressing mode |
4735 | // here since pointer arithmetic is handled by G_PTR_ADD. |
4736 | unsigned Opc = MI.getOpcode(); |
4737 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4738 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
4739 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
4740 | |
4741 | if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo)) |
4742 | return true; |
4743 | if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo)) |
4744 | return true; |
4745 | return false; |
4746 | } |
4747 | |
4748 | bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI, APInt &MatchInfo) { |
4749 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4750 | Register SrcOp = MI.getOperand(i: 1).getReg(); |
4751 | |
4752 | if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) { |
4753 | MatchInfo = *MaybeCst; |
4754 | return true; |
4755 | } |
4756 | |
4757 | return false; |
4758 | } |
4759 | |
4760 | bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI, APInt &MatchInfo) { |
4761 | Register Op1 = MI.getOperand(i: 1).getReg(); |
4762 | Register Op2 = MI.getOperand(i: 2).getReg(); |
4763 | auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
4764 | if (!MaybeCst) |
4765 | return false; |
4766 | MatchInfo = *MaybeCst; |
4767 | return true; |
4768 | } |
4769 | |
4770 | bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, ConstantFP* &MatchInfo) { |
4771 | Register Op1 = MI.getOperand(i: 1).getReg(); |
4772 | Register Op2 = MI.getOperand(i: 2).getReg(); |
4773 | auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
4774 | if (!MaybeCst) |
4775 | return false; |
4776 | MatchInfo = |
4777 | ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst); |
4778 | return true; |
4779 | } |
4780 | |
4781 | bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, |
4782 | ConstantFP *&MatchInfo) { |
4783 | assert(MI.getOpcode() == TargetOpcode::G_FMA || |
4784 | MI.getOpcode() == TargetOpcode::G_FMAD); |
4785 | auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); |
4786 | |
4787 | const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI); |
4788 | if (!Op3Cst) |
4789 | return false; |
4790 | |
4791 | const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI); |
4792 | if (!Op2Cst) |
4793 | return false; |
4794 | |
4795 | const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI); |
4796 | if (!Op1Cst) |
4797 | return false; |
4798 | |
4799 | APFloat Op1F = Op1Cst->getValueAPF(); |
4800 | Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(), |
4801 | RM: APFloat::rmNearestTiesToEven); |
4802 | MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F); |
4803 | return true; |
4804 | } |
4805 | |
4806 | bool CombinerHelper::matchNarrowBinopFeedingAnd( |
4807 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4808 | // Look for a binop feeding into an AND with a mask: |
4809 | // |
4810 | // %add = G_ADD %lhs, %rhs |
4811 | // %and = G_AND %add, 000...11111111 |
4812 | // |
4813 | // Check if it's possible to perform the binop at a narrower width and zext |
4814 | // back to the original width like so: |
4815 | // |
4816 | // %narrow_lhs = G_TRUNC %lhs |
4817 | // %narrow_rhs = G_TRUNC %rhs |
4818 | // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs |
4819 | // %new_add = G_ZEXT %narrow_add |
4820 | // %and = G_AND %new_add, 000...11111111 |
4821 | // |
4822 | // This can allow later combines to eliminate the G_AND if it turns out |
4823 | // that the mask is irrelevant. |
4824 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4825 | Register Dst = MI.getOperand(i: 0).getReg(); |
4826 | Register AndLHS = MI.getOperand(i: 1).getReg(); |
4827 | Register AndRHS = MI.getOperand(i: 2).getReg(); |
4828 | LLT WideTy = MRI.getType(Reg: Dst); |
4829 | |
4830 | // If the potential binop has more than one use, then it's possible that one |
4831 | // of those uses will need its full width. |
4832 | if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS)) |
4833 | return false; |
4834 | |
4835 | // Check if the LHS feeding the AND is impacted by the high bits that we're |
4836 | // masking out. |
4837 | // |
4838 | // e.g. for 64-bit x, y: |
4839 | // |
4840 | // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535 |
4841 | MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI); |
4842 | if (!LHSInst) |
4843 | return false; |
4844 | unsigned LHSOpc = LHSInst->getOpcode(); |
4845 | switch (LHSOpc) { |
4846 | default: |
4847 | return false; |
4848 | case TargetOpcode::G_ADD: |
4849 | case TargetOpcode::G_SUB: |
4850 | case TargetOpcode::G_MUL: |
4851 | case TargetOpcode::G_AND: |
4852 | case TargetOpcode::G_OR: |
4853 | case TargetOpcode::G_XOR: |
4854 | break; |
4855 | } |
4856 | |
4857 | // Find the mask on the RHS. |
4858 | auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI); |
4859 | if (!Cst) |
4860 | return false; |
4861 | auto Mask = Cst->Value; |
4862 | if (!Mask.isMask()) |
4863 | return false; |
4864 | |
4865 | // No point in combining if there's nothing to truncate. |
4866 | unsigned NarrowWidth = Mask.countr_one(); |
4867 | if (NarrowWidth == WideTy.getSizeInBits()) |
4868 | return false; |
4869 | LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth); |
4870 | |
4871 | // Check if adding the zext + truncates could be harmful. |
4872 | auto &MF = *MI.getMF(); |
4873 | const auto &TLI = getTargetLowering(); |
4874 | LLVMContext &Ctx = MF.getFunction().getContext(); |
4875 | auto &DL = MF.getDataLayout(); |
4876 | if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, DL, Ctx) || |
4877 | !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, DL, Ctx)) |
4878 | return false; |
4879 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) || |
4880 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}})) |
4881 | return false; |
4882 | Register BinOpLHS = LHSInst->getOperand(i: 1).getReg(); |
4883 | Register BinOpRHS = LHSInst->getOperand(i: 2).getReg(); |
4884 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4885 | auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS); |
4886 | auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS); |
4887 | auto NarrowBinOp = |
4888 | Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS}); |
4889 | auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp); |
4890 | Observer.changingInstr(MI); |
4891 | MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0)); |
4892 | Observer.changedInstr(MI); |
4893 | }; |
4894 | return true; |
4895 | } |
4896 | |
4897 | bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4898 | unsigned Opc = MI.getOpcode(); |
4899 | assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO); |
4900 | |
4901 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2))) |
4902 | return false; |
4903 | |
4904 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4905 | Observer.changingInstr(MI); |
4906 | unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO |
4907 | : TargetOpcode::G_SADDO; |
4908 | MI.setDesc(Builder.getTII().get(Opcode: NewOpc)); |
4909 | MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg()); |
4910 | Observer.changedInstr(MI); |
4911 | }; |
4912 | return true; |
4913 | } |
4914 | |
4915 | bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4916 | // (G_*MULO x, 0) -> 0 + no carry out |
4917 | assert(MI.getOpcode() == TargetOpcode::G_UMULO || |
4918 | MI.getOpcode() == TargetOpcode::G_SMULO); |
4919 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
4920 | return false; |
4921 | Register Dst = MI.getOperand(i: 0).getReg(); |
4922 | Register Carry = MI.getOperand(i: 1).getReg(); |
4923 | if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) || |
4924 | !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry))) |
4925 | return false; |
4926 | MatchInfo = [=](MachineIRBuilder &B) { |
4927 | B.buildConstant(Res: Dst, Val: 0); |
4928 | B.buildConstant(Res: Carry, Val: 0); |
4929 | }; |
4930 | return true; |
4931 | } |
4932 | |
4933 | bool CombinerHelper::matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4934 | // (G_*ADDO x, 0) -> x + no carry out |
4935 | assert(MI.getOpcode() == TargetOpcode::G_UADDO || |
4936 | MI.getOpcode() == TargetOpcode::G_SADDO); |
4937 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
4938 | return false; |
4939 | Register Carry = MI.getOperand(i: 1).getReg(); |
4940 | if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry))) |
4941 | return false; |
4942 | Register Dst = MI.getOperand(i: 0).getReg(); |
4943 | Register LHS = MI.getOperand(i: 2).getReg(); |
4944 | MatchInfo = [=](MachineIRBuilder &B) { |
4945 | B.buildCopy(Res: Dst, Op: LHS); |
4946 | B.buildConstant(Res: Carry, Val: 0); |
4947 | }; |
4948 | return true; |
4949 | } |
4950 | |
4951 | bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4952 | // (G_*ADDE x, y, 0) -> (G_*ADDO x, y) |
4953 | // (G_*SUBE x, y, 0) -> (G_*SUBO x, y) |
4954 | assert(MI.getOpcode() == TargetOpcode::G_UADDE || |
4955 | MI.getOpcode() == TargetOpcode::G_SADDE || |
4956 | MI.getOpcode() == TargetOpcode::G_USUBE || |
4957 | MI.getOpcode() == TargetOpcode::G_SSUBE); |
4958 | if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
4959 | return false; |
4960 | MatchInfo = [&](MachineIRBuilder &B) { |
4961 | unsigned NewOpcode; |
4962 | switch (MI.getOpcode()) { |
4963 | case TargetOpcode::G_UADDE: |
4964 | NewOpcode = TargetOpcode::G_UADDO; |
4965 | break; |
4966 | case TargetOpcode::G_SADDE: |
4967 | NewOpcode = TargetOpcode::G_SADDO; |
4968 | break; |
4969 | case TargetOpcode::G_USUBE: |
4970 | NewOpcode = TargetOpcode::G_USUBO; |
4971 | break; |
4972 | case TargetOpcode::G_SSUBE: |
4973 | NewOpcode = TargetOpcode::G_SSUBO; |
4974 | break; |
4975 | } |
4976 | Observer.changingInstr(MI); |
4977 | MI.setDesc(B.getTII().get(Opcode: NewOpcode)); |
4978 | MI.removeOperand(OpNo: 4); |
4979 | Observer.changedInstr(MI); |
4980 | }; |
4981 | return true; |
4982 | } |
4983 | |
4984 | bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI, |
4985 | BuildFnTy &MatchInfo) { |
4986 | assert(MI.getOpcode() == TargetOpcode::G_SUB); |
4987 | Register Dst = MI.getOperand(i: 0).getReg(); |
4988 | // (x + y) - z -> x (if y == z) |
4989 | // (x + y) - z -> y (if x == z) |
4990 | Register X, Y, Z; |
4991 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) { |
4992 | Register ReplaceReg; |
4993 | int64_t CstX, CstY; |
4994 | if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) && |
4995 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY)))) |
4996 | ReplaceReg = X; |
4997 | else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
4998 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
4999 | ReplaceReg = Y; |
5000 | if (ReplaceReg) { |
5001 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); }; |
5002 | return true; |
5003 | } |
5004 | } |
5005 | |
5006 | // x - (y + z) -> 0 - y (if x == z) |
5007 | // x - (y + z) -> 0 - z (if x == y) |
5008 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) { |
5009 | Register ReplaceReg; |
5010 | int64_t CstX; |
5011 | if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5012 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5013 | ReplaceReg = Y; |
5014 | else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5015 | mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5016 | ReplaceReg = Z; |
5017 | if (ReplaceReg) { |
5018 | MatchInfo = [=](MachineIRBuilder &B) { |
5019 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0); |
5020 | B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg); |
5021 | }; |
5022 | return true; |
5023 | } |
5024 | } |
5025 | return false; |
5026 | } |
5027 | |
5028 | MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) { |
5029 | assert(MI.getOpcode() == TargetOpcode::G_UDIV); |
5030 | auto &UDiv = cast<GenericMachineInstr>(Val&: MI); |
5031 | Register Dst = UDiv.getReg(Idx: 0); |
5032 | Register LHS = UDiv.getReg(Idx: 1); |
5033 | Register RHS = UDiv.getReg(Idx: 2); |
5034 | LLT Ty = MRI.getType(Reg: Dst); |
5035 | LLT ScalarTy = Ty.getScalarType(); |
5036 | const unsigned EltBits = ScalarTy.getScalarSizeInBits(); |
5037 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5038 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5039 | auto &MIB = Builder; |
5040 | MIB.setInstrAndDebugLoc(MI); |
5041 | |
5042 | bool UseNPQ = false; |
5043 | SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; |
5044 | |
5045 | auto BuildUDIVPattern = [&](const Constant *C) { |
5046 | auto *CI = cast<ConstantInt>(Val: C); |
5047 | const APInt &Divisor = CI->getValue(); |
5048 | |
5049 | bool SelNPQ = false; |
5050 | APInt Magic(Divisor.getBitWidth(), 0); |
5051 | unsigned PreShift = 0, PostShift = 0; |
5052 | |
5053 | // Magic algorithm doesn't work for division by 1. We need to emit a select |
5054 | // at the end. |
5055 | // TODO: Use undef values for divisor of 1. |
5056 | if (!Divisor.isOne()) { |
5057 | UnsignedDivisionByConstantInfo magics = |
5058 | UnsignedDivisionByConstantInfo::get(D: Divisor); |
5059 | |
5060 | Magic = std::move(magics.Magic); |
5061 | |
5062 | assert(magics.PreShift < Divisor.getBitWidth() && |
5063 | "We shouldn't generate an undefined shift!" ); |
5064 | assert(magics.PostShift < Divisor.getBitWidth() && |
5065 | "We shouldn't generate an undefined shift!" ); |
5066 | assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift" ); |
5067 | PreShift = magics.PreShift; |
5068 | PostShift = magics.PostShift; |
5069 | SelNPQ = magics.IsAdd; |
5070 | } |
5071 | |
5072 | PreShifts.push_back( |
5073 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0)); |
5074 | MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0)); |
5075 | NPQFactors.push_back( |
5076 | Elt: MIB.buildConstant(Res: ScalarTy, |
5077 | Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1) |
5078 | : APInt::getZero(numBits: EltBits)) |
5079 | .getReg(Idx: 0)); |
5080 | PostShifts.push_back( |
5081 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0)); |
5082 | UseNPQ |= SelNPQ; |
5083 | return true; |
5084 | }; |
5085 | |
5086 | // Collect the shifts/magic values from each element. |
5087 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern); |
5088 | (void)Matched; |
5089 | assert(Matched && "Expected unary predicate match to succeed" ); |
5090 | |
5091 | Register PreShift, PostShift, MagicFactor, NPQFactor; |
5092 | auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI); |
5093 | if (RHSDef) { |
5094 | PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0); |
5095 | MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0); |
5096 | NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0); |
5097 | PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0); |
5098 | } else { |
5099 | assert(MRI.getType(RHS).isScalar() && |
5100 | "Non-build_vector operation should have been a scalar" ); |
5101 | PreShift = PreShifts[0]; |
5102 | MagicFactor = MagicFactors[0]; |
5103 | PostShift = PostShifts[0]; |
5104 | } |
5105 | |
5106 | Register Q = LHS; |
5107 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0); |
5108 | |
5109 | // Multiply the numerator (operand 0) by the magic value. |
5110 | Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0); |
5111 | |
5112 | if (UseNPQ) { |
5113 | Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0); |
5114 | |
5115 | // For vectors we might have a mix of non-NPQ/NPQ paths, so use |
5116 | // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero. |
5117 | if (Ty.isVector()) |
5118 | NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0); |
5119 | else |
5120 | NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0); |
5121 | |
5122 | Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0); |
5123 | } |
5124 | |
5125 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0); |
5126 | auto One = MIB.buildConstant(Res: Ty, Val: 1); |
5127 | auto IsOne = MIB.buildICmp( |
5128 | Pred: CmpInst::Predicate::ICMP_EQ, |
5129 | Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One); |
5130 | return MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q); |
5131 | } |
5132 | |
5133 | bool CombinerHelper::matchUDivByConst(MachineInstr &MI) { |
5134 | assert(MI.getOpcode() == TargetOpcode::G_UDIV); |
5135 | Register Dst = MI.getOperand(i: 0).getReg(); |
5136 | Register RHS = MI.getOperand(i: 2).getReg(); |
5137 | LLT DstTy = MRI.getType(Reg: Dst); |
5138 | auto *RHSDef = MRI.getVRegDef(Reg: RHS); |
5139 | if (!isConstantOrConstantVector(MI&: *RHSDef, MRI)) |
5140 | return false; |
5141 | |
5142 | auto &MF = *MI.getMF(); |
5143 | AttributeList Attr = MF.getFunction().getAttributes(); |
5144 | const auto &TLI = getTargetLowering(); |
5145 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5146 | auto &DL = MF.getDataLayout(); |
5147 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr)) |
5148 | return false; |
5149 | |
5150 | // Don't do this for minsize because the instruction sequence is usually |
5151 | // larger. |
5152 | if (MF.getFunction().hasMinSize()) |
5153 | return false; |
5154 | |
5155 | // Don't do this if the types are not going to be legal. |
5156 | if (LI) { |
5157 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}})) |
5158 | return false; |
5159 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}})) |
5160 | return false; |
5161 | if (!isLegalOrBeforeLegalizer( |
5162 | Query: {TargetOpcode::G_ICMP, |
5163 | {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1), |
5164 | DstTy}})) |
5165 | return false; |
5166 | } |
5167 | |
5168 | auto CheckEltValue = [&](const Constant *C) { |
5169 | if (auto *CI = dyn_cast_or_null<ConstantInt>(Val: C)) |
5170 | return !CI->isZero(); |
5171 | return false; |
5172 | }; |
5173 | return matchUnaryPredicate(MRI, Reg: RHS, Match: CheckEltValue); |
5174 | } |
5175 | |
5176 | void CombinerHelper::applyUDivByConst(MachineInstr &MI) { |
5177 | auto *NewMI = buildUDivUsingMul(MI); |
5178 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5179 | } |
5180 | |
5181 | bool CombinerHelper::matchSDivByConst(MachineInstr &MI) { |
5182 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5183 | Register Dst = MI.getOperand(i: 0).getReg(); |
5184 | Register RHS = MI.getOperand(i: 2).getReg(); |
5185 | LLT DstTy = MRI.getType(Reg: Dst); |
5186 | |
5187 | auto &MF = *MI.getMF(); |
5188 | AttributeList Attr = MF.getFunction().getAttributes(); |
5189 | const auto &TLI = getTargetLowering(); |
5190 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5191 | auto &DL = MF.getDataLayout(); |
5192 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr)) |
5193 | return false; |
5194 | |
5195 | // Don't do this for minsize because the instruction sequence is usually |
5196 | // larger. |
5197 | if (MF.getFunction().hasMinSize()) |
5198 | return false; |
5199 | |
5200 | // If the sdiv has an 'exact' flag we can use a simpler lowering. |
5201 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5202 | return matchUnaryPredicate( |
5203 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isZeroValue(); }); |
5204 | } |
5205 | |
5206 | // Don't support the general case for now. |
5207 | return false; |
5208 | } |
5209 | |
5210 | void CombinerHelper::applySDivByConst(MachineInstr &MI) { |
5211 | auto *NewMI = buildSDivUsingMul(MI); |
5212 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5213 | } |
5214 | |
5215 | MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) { |
5216 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5217 | auto &SDiv = cast<GenericMachineInstr>(Val&: MI); |
5218 | Register Dst = SDiv.getReg(Idx: 0); |
5219 | Register LHS = SDiv.getReg(Idx: 1); |
5220 | Register RHS = SDiv.getReg(Idx: 2); |
5221 | LLT Ty = MRI.getType(Reg: Dst); |
5222 | LLT ScalarTy = Ty.getScalarType(); |
5223 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5224 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5225 | auto &MIB = Builder; |
5226 | MIB.setInstrAndDebugLoc(MI); |
5227 | |
5228 | bool UseSRA = false; |
5229 | SmallVector<Register, 16> Shifts, Factors; |
5230 | |
5231 | auto *RHSDef = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI)); |
5232 | bool IsSplat = getIConstantSplatVal(MI: *RHSDef, MRI).has_value(); |
5233 | |
5234 | auto BuildSDIVPattern = [&](const Constant *C) { |
5235 | // Don't recompute inverses for each splat element. |
5236 | if (IsSplat && !Factors.empty()) { |
5237 | Shifts.push_back(Elt: Shifts[0]); |
5238 | Factors.push_back(Elt: Factors[0]); |
5239 | return true; |
5240 | } |
5241 | |
5242 | auto *CI = cast<ConstantInt>(Val: C); |
5243 | APInt Divisor = CI->getValue(); |
5244 | unsigned Shift = Divisor.countr_zero(); |
5245 | if (Shift) { |
5246 | Divisor.ashrInPlace(ShiftAmt: Shift); |
5247 | UseSRA = true; |
5248 | } |
5249 | |
5250 | // Calculate the multiplicative inverse modulo BW. |
5251 | // 2^W requires W + 1 bits, so we have to extend and then truncate. |
5252 | unsigned W = Divisor.getBitWidth(); |
5253 | APInt Factor = Divisor.zext(width: W + 1) |
5254 | .multiplicativeInverse(modulo: APInt::getSignedMinValue(numBits: W + 1)) |
5255 | .trunc(width: W); |
5256 | Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0)); |
5257 | Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0)); |
5258 | return true; |
5259 | }; |
5260 | |
5261 | // Collect all magic values from the build vector. |
5262 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern); |
5263 | (void)Matched; |
5264 | assert(Matched && "Expected unary predicate match to succeed" ); |
5265 | |
5266 | Register Shift, Factor; |
5267 | if (Ty.isVector()) { |
5268 | Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0); |
5269 | Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0); |
5270 | } else { |
5271 | Shift = Shifts[0]; |
5272 | Factor = Factors[0]; |
5273 | } |
5274 | |
5275 | Register Res = LHS; |
5276 | |
5277 | if (UseSRA) |
5278 | Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0); |
5279 | |
5280 | return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor); |
5281 | } |
5282 | |
5283 | bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) { |
5284 | assert(MI.getOpcode() == TargetOpcode::G_UMULH); |
5285 | Register RHS = MI.getOperand(i: 2).getReg(); |
5286 | Register Dst = MI.getOperand(i: 0).getReg(); |
5287 | LLT Ty = MRI.getType(Reg: Dst); |
5288 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5289 | auto MatchPow2ExceptOne = [&](const Constant *C) { |
5290 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
5291 | return CI->getValue().isPowerOf2() && !CI->getValue().isOne(); |
5292 | return false; |
5293 | }; |
5294 | if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false)) |
5295 | return false; |
5296 | return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}); |
5297 | } |
5298 | |
5299 | void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) { |
5300 | Register LHS = MI.getOperand(i: 1).getReg(); |
5301 | Register RHS = MI.getOperand(i: 2).getReg(); |
5302 | Register Dst = MI.getOperand(i: 0).getReg(); |
5303 | LLT Ty = MRI.getType(Reg: Dst); |
5304 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5305 | unsigned NumEltBits = Ty.getScalarSizeInBits(); |
5306 | |
5307 | Builder.setInstrAndDebugLoc(MI); |
5308 | auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder); |
5309 | auto ShiftAmt = |
5310 | Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2); |
5311 | auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt); |
5312 | Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc); |
5313 | MI.eraseFromParent(); |
5314 | } |
5315 | |
5316 | bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI, |
5317 | BuildFnTy &MatchInfo) { |
5318 | unsigned Opc = MI.getOpcode(); |
5319 | assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB || |
5320 | Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || |
5321 | Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA); |
5322 | |
5323 | Register Dst = MI.getOperand(i: 0).getReg(); |
5324 | Register X = MI.getOperand(i: 1).getReg(); |
5325 | Register Y = MI.getOperand(i: 2).getReg(); |
5326 | LLT Type = MRI.getType(Reg: Dst); |
5327 | |
5328 | // fold (fadd x, fneg(y)) -> (fsub x, y) |
5329 | // fold (fadd fneg(y), x) -> (fsub x, y) |
5330 | // G_ADD is commutative so both cases are checked by m_GFAdd |
5331 | if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) && |
5332 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) { |
5333 | Opc = TargetOpcode::G_FSUB; |
5334 | } |
5335 | /// fold (fsub x, fneg(y)) -> (fadd x, y) |
5336 | else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) && |
5337 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) { |
5338 | Opc = TargetOpcode::G_FADD; |
5339 | } |
5340 | // fold (fmul fneg(x), fneg(y)) -> (fmul x, y) |
5341 | // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y) |
5342 | // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z) |
5343 | // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z) |
5344 | else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || |
5345 | Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) && |
5346 | mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) && |
5347 | mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) { |
5348 | // no opcode change |
5349 | } else |
5350 | return false; |
5351 | |
5352 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5353 | Observer.changingInstr(MI); |
5354 | MI.setDesc(B.getTII().get(Opcode: Opc)); |
5355 | MI.getOperand(i: 1).setReg(X); |
5356 | MI.getOperand(i: 2).setReg(Y); |
5357 | Observer.changedInstr(MI); |
5358 | }; |
5359 | return true; |
5360 | } |
5361 | |
5362 | bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, Register &MatchInfo) { |
5363 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5364 | |
5365 | Register LHS = MI.getOperand(i: 1).getReg(); |
5366 | MatchInfo = MI.getOperand(i: 2).getReg(); |
5367 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5368 | |
5369 | const auto LHSCst = Ty.isVector() |
5370 | ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true) |
5371 | : getFConstantVRegValWithLookThrough(VReg: LHS, MRI); |
5372 | if (!LHSCst) |
5373 | return false; |
5374 | |
5375 | // -0.0 is always allowed |
5376 | if (LHSCst->Value.isNegZero()) |
5377 | return true; |
5378 | |
5379 | // +0.0 is only allowed if nsz is set. |
5380 | if (LHSCst->Value.isPosZero()) |
5381 | return MI.getFlag(Flag: MachineInstr::FmNsz); |
5382 | |
5383 | return false; |
5384 | } |
5385 | |
5386 | void CombinerHelper::applyFsubToFneg(MachineInstr &MI, Register &MatchInfo) { |
5387 | Builder.setInstrAndDebugLoc(MI); |
5388 | Register Dst = MI.getOperand(i: 0).getReg(); |
5389 | Builder.buildFNeg( |
5390 | Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0)); |
5391 | eraseInst(MI); |
5392 | } |
5393 | |
5394 | /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either |
5395 | /// due to global flags or MachineInstr flags. |
5396 | static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) { |
5397 | if (MI.getOpcode() != TargetOpcode::G_FMUL) |
5398 | return false; |
5399 | return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract); |
5400 | } |
5401 | |
5402 | static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1, |
5403 | const MachineRegisterInfo &MRI) { |
5404 | return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()), |
5405 | last: MRI.use_instr_nodbg_end()) > |
5406 | std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()), |
5407 | last: MRI.use_instr_nodbg_end()); |
5408 | } |
5409 | |
5410 | bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI, |
5411 | bool &AllowFusionGlobally, |
5412 | bool &HasFMAD, bool &Aggressive, |
5413 | bool CanReassociate) { |
5414 | |
5415 | auto *MF = MI.getMF(); |
5416 | const auto &TLI = *MF->getSubtarget().getTargetLowering(); |
5417 | const TargetOptions &Options = MF->getTarget().Options; |
5418 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5419 | |
5420 | if (CanReassociate && |
5421 | !(Options.UnsafeFPMath || MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc))) |
5422 | return false; |
5423 | |
5424 | // Floating-point multiply-add with intermediate rounding. |
5425 | HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType)); |
5426 | // Floating-point multiply-add without intermediate rounding. |
5427 | bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) && |
5428 | isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}}); |
5429 | // No valid opcode, do not combine. |
5430 | if (!HasFMAD && !HasFMA) |
5431 | return false; |
5432 | |
5433 | AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || |
5434 | Options.UnsafeFPMath || HasFMAD; |
5435 | // If the addition is not contractable, do not combine. |
5436 | if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract)) |
5437 | return false; |
5438 | |
5439 | Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType); |
5440 | return true; |
5441 | } |
5442 | |
5443 | bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA( |
5444 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5445 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5446 | |
5447 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5448 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5449 | return false; |
5450 | |
5451 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5452 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5453 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5454 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5455 | unsigned PreferredFusedOpcode = |
5456 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5457 | |
5458 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5459 | // prefer to fold the multiply with fewer uses. |
5460 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5461 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5462 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5463 | std::swap(a&: LHS, b&: RHS); |
5464 | } |
5465 | |
5466 | // fold (fadd (fmul x, y), z) -> (fma x, y, z) |
5467 | if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5468 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) { |
5469 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5470 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5471 | SrcOps: {LHS.MI->getOperand(i: 1).getReg(), |
5472 | LHS.MI->getOperand(i: 2).getReg(), RHS.Reg}); |
5473 | }; |
5474 | return true; |
5475 | } |
5476 | |
5477 | // fold (fadd x, (fmul y, z)) -> (fma y, z, x) |
5478 | if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
5479 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) { |
5480 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5481 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5482 | SrcOps: {RHS.MI->getOperand(i: 1).getReg(), |
5483 | RHS.MI->getOperand(i: 2).getReg(), LHS.Reg}); |
5484 | }; |
5485 | return true; |
5486 | } |
5487 | |
5488 | return false; |
5489 | } |
5490 | |
5491 | bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA( |
5492 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5493 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5494 | |
5495 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5496 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5497 | return false; |
5498 | |
5499 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
5500 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5501 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5502 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5503 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5504 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5505 | |
5506 | unsigned PreferredFusedOpcode = |
5507 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5508 | |
5509 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5510 | // prefer to fold the multiply with fewer uses. |
5511 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5512 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5513 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5514 | std::swap(a&: LHS, b&: RHS); |
5515 | } |
5516 | |
5517 | // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) |
5518 | MachineInstr *FpExtSrc; |
5519 | if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) && |
5520 | isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) && |
5521 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5522 | SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) { |
5523 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5524 | auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg()); |
5525 | auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg()); |
5526 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5527 | SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg}); |
5528 | }; |
5529 | return true; |
5530 | } |
5531 | |
5532 | // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z) |
5533 | // Note: Commutes FADD operands. |
5534 | if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) && |
5535 | isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) && |
5536 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5537 | SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) { |
5538 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5539 | auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg()); |
5540 | auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg()); |
5541 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5542 | SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg}); |
5543 | }; |
5544 | return true; |
5545 | } |
5546 | |
5547 | return false; |
5548 | } |
5549 | |
5550 | bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA( |
5551 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5552 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5553 | |
5554 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5555 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true)) |
5556 | return false; |
5557 | |
5558 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5559 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5560 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5561 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5562 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5563 | |
5564 | unsigned PreferredFusedOpcode = |
5565 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5566 | |
5567 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5568 | // prefer to fold the multiply with fewer uses. |
5569 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5570 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5571 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5572 | std::swap(a&: LHS, b&: RHS); |
5573 | } |
5574 | |
5575 | MachineInstr *FMA = nullptr; |
5576 | Register Z; |
5577 | // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z)) |
5578 | if (LHS.MI->getOpcode() == PreferredFusedOpcode && |
5579 | (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() == |
5580 | TargetOpcode::G_FMUL) && |
5581 | MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) && |
5582 | MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) { |
5583 | FMA = LHS.MI; |
5584 | Z = RHS.Reg; |
5585 | } |
5586 | // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z)) |
5587 | else if (RHS.MI->getOpcode() == PreferredFusedOpcode && |
5588 | (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() == |
5589 | TargetOpcode::G_FMUL) && |
5590 | MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) && |
5591 | MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) { |
5592 | Z = LHS.Reg; |
5593 | FMA = RHS.MI; |
5594 | } |
5595 | |
5596 | if (FMA) { |
5597 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg()); |
5598 | Register X = FMA->getOperand(i: 1).getReg(); |
5599 | Register Y = FMA->getOperand(i: 2).getReg(); |
5600 | Register U = FMulMI->getOperand(i: 1).getReg(); |
5601 | Register V = FMulMI->getOperand(i: 2).getReg(); |
5602 | |
5603 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5604 | Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy); |
5605 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z}); |
5606 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5607 | SrcOps: {X, Y, InnerFMA}); |
5608 | }; |
5609 | return true; |
5610 | } |
5611 | |
5612 | return false; |
5613 | } |
5614 | |
5615 | bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive( |
5616 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5617 | assert(MI.getOpcode() == TargetOpcode::G_FADD); |
5618 | |
5619 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5620 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5621 | return false; |
5622 | |
5623 | if (!Aggressive) |
5624 | return false; |
5625 | |
5626 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
5627 | LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5628 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5629 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5630 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5631 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5632 | |
5633 | unsigned PreferredFusedOpcode = |
5634 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5635 | |
5636 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5637 | // prefer to fold the multiply with fewer uses. |
5638 | if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5639 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) { |
5640 | if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5641 | std::swap(a&: LHS, b&: RHS); |
5642 | } |
5643 | |
5644 | // Builds: (fma x, y, (fma (fpext u), (fpext v), z)) |
5645 | auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X, |
5646 | Register Y, MachineIRBuilder &B) { |
5647 | Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0); |
5648 | Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0); |
5649 | Register InnerFMA = |
5650 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z}) |
5651 | .getReg(Idx: 0); |
5652 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5653 | SrcOps: {X, Y, InnerFMA}); |
5654 | }; |
5655 | |
5656 | MachineInstr *FMulMI, *FMAMI; |
5657 | // fold (fadd (fma x, y, (fpext (fmul u, v))), z) |
5658 | // -> (fma x, y, (fma (fpext u), (fpext v), z)) |
5659 | if (LHS.MI->getOpcode() == PreferredFusedOpcode && |
5660 | mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI, |
5661 | P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
5662 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5663 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5664 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
5665 | MatchInfo = [=](MachineIRBuilder &B) { |
5666 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5667 | FMulMI->getOperand(i: 2).getReg(), RHS.Reg, |
5668 | LHS.MI->getOperand(i: 1).getReg(), |
5669 | LHS.MI->getOperand(i: 2).getReg(), B); |
5670 | }; |
5671 | return true; |
5672 | } |
5673 | |
5674 | // fold (fadd (fpext (fma x, y, (fmul u, v))), z) |
5675 | // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) |
5676 | // FIXME: This turns two single-precision and one double-precision |
5677 | // operation into two double-precision operations, which might not be |
5678 | // interesting for all targets, especially GPUs. |
5679 | if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) && |
5680 | FMAMI->getOpcode() == PreferredFusedOpcode) { |
5681 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg()); |
5682 | if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5683 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5684 | SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) { |
5685 | MatchInfo = [=](MachineIRBuilder &B) { |
5686 | Register X = FMAMI->getOperand(i: 1).getReg(); |
5687 | Register Y = FMAMI->getOperand(i: 2).getReg(); |
5688 | X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0); |
5689 | Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0); |
5690 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5691 | FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B); |
5692 | }; |
5693 | |
5694 | return true; |
5695 | } |
5696 | } |
5697 | |
5698 | // fold (fadd z, (fma x, y, (fpext (fmul u, v))) |
5699 | // -> (fma x, y, (fma (fpext u), (fpext v), z)) |
5700 | if (RHS.MI->getOpcode() == PreferredFusedOpcode && |
5701 | mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI, |
5702 | P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
5703 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5704 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5705 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
5706 | MatchInfo = [=](MachineIRBuilder &B) { |
5707 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5708 | FMulMI->getOperand(i: 2).getReg(), LHS.Reg, |
5709 | RHS.MI->getOperand(i: 1).getReg(), |
5710 | RHS.MI->getOperand(i: 2).getReg(), B); |
5711 | }; |
5712 | return true; |
5713 | } |
5714 | |
5715 | // fold (fadd z, (fpext (fma x, y, (fmul u, v))) |
5716 | // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) |
5717 | // FIXME: This turns two single-precision and one double-precision |
5718 | // operation into two double-precision operations, which might not be |
5719 | // interesting for all targets, especially GPUs. |
5720 | if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) && |
5721 | FMAMI->getOpcode() == PreferredFusedOpcode) { |
5722 | MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg()); |
5723 | if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5724 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType, |
5725 | SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) { |
5726 | MatchInfo = [=](MachineIRBuilder &B) { |
5727 | Register X = FMAMI->getOperand(i: 1).getReg(); |
5728 | Register Y = FMAMI->getOperand(i: 2).getReg(); |
5729 | X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0); |
5730 | Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0); |
5731 | buildMatchInfo(FMulMI->getOperand(i: 1).getReg(), |
5732 | FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B); |
5733 | }; |
5734 | return true; |
5735 | } |
5736 | } |
5737 | |
5738 | return false; |
5739 | } |
5740 | |
5741 | bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA( |
5742 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5743 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5744 | |
5745 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5746 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5747 | return false; |
5748 | |
5749 | Register Op1 = MI.getOperand(i: 1).getReg(); |
5750 | Register Op2 = MI.getOperand(i: 2).getReg(); |
5751 | DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1}; |
5752 | DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2}; |
5753 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5754 | |
5755 | // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), |
5756 | // prefer to fold the multiply with fewer uses. |
5757 | int FirstMulHasFewerUses = true; |
5758 | if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5759 | isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
5760 | hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI)) |
5761 | FirstMulHasFewerUses = false; |
5762 | |
5763 | unsigned PreferredFusedOpcode = |
5764 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5765 | |
5766 | // fold (fsub (fmul x, y), z) -> (fma x, y, -z) |
5767 | if (FirstMulHasFewerUses && |
5768 | (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) && |
5769 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) { |
5770 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5771 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0); |
5772 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5773 | SrcOps: {LHS.MI->getOperand(i: 1).getReg(), |
5774 | LHS.MI->getOperand(i: 2).getReg(), NegZ}); |
5775 | }; |
5776 | return true; |
5777 | } |
5778 | // fold (fsub x, (fmul y, z)) -> (fma -y, z, x) |
5779 | else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) && |
5780 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) { |
5781 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5782 | Register NegY = |
5783 | B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
5784 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5785 | SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg}); |
5786 | }; |
5787 | return true; |
5788 | } |
5789 | |
5790 | return false; |
5791 | } |
5792 | |
5793 | bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA( |
5794 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5795 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5796 | |
5797 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5798 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5799 | return false; |
5800 | |
5801 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
5802 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
5803 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5804 | |
5805 | unsigned PreferredFusedOpcode = |
5806 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5807 | |
5808 | MachineInstr *FMulMI; |
5809 | // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z)) |
5810 | if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) && |
5811 | (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) && |
5812 | MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) && |
5813 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) { |
5814 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5815 | Register NegX = |
5816 | B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
5817 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0); |
5818 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5819 | SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ}); |
5820 | }; |
5821 | return true; |
5822 | } |
5823 | |
5824 | // fold (fsub x, (fneg (fmul, y, z))) -> (fma y, z, x) |
5825 | if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) && |
5826 | (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) && |
5827 | MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) && |
5828 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) { |
5829 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5830 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5831 | SrcOps: {FMulMI->getOperand(i: 1).getReg(), |
5832 | FMulMI->getOperand(i: 2).getReg(), LHSReg}); |
5833 | }; |
5834 | return true; |
5835 | } |
5836 | |
5837 | return false; |
5838 | } |
5839 | |
5840 | bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA( |
5841 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5842 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5843 | |
5844 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5845 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5846 | return false; |
5847 | |
5848 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
5849 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
5850 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5851 | |
5852 | unsigned PreferredFusedOpcode = |
5853 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5854 | |
5855 | MachineInstr *FMulMI; |
5856 | // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z)) |
5857 | if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
5858 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5859 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) { |
5860 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5861 | Register FpExtX = |
5862 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
5863 | Register FpExtY = |
5864 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0); |
5865 | Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0); |
5866 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5867 | SrcOps: {FpExtX, FpExtY, NegZ}); |
5868 | }; |
5869 | return true; |
5870 | } |
5871 | |
5872 | // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x) |
5873 | if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) && |
5874 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5875 | (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) { |
5876 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5877 | Register FpExtY = |
5878 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0); |
5879 | Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0); |
5880 | Register FpExtZ = |
5881 | B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0); |
5882 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()}, |
5883 | SrcOps: {NegY, FpExtZ, LHSReg}); |
5884 | }; |
5885 | return true; |
5886 | } |
5887 | |
5888 | return false; |
5889 | } |
5890 | |
5891 | bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA( |
5892 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
5893 | assert(MI.getOpcode() == TargetOpcode::G_FSUB); |
5894 | |
5895 | bool AllowFusionGlobally, HasFMAD, Aggressive; |
5896 | if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) |
5897 | return false; |
5898 | |
5899 | const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); |
5900 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
5901 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
5902 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
5903 | |
5904 | unsigned PreferredFusedOpcode = |
5905 | HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; |
5906 | |
5907 | auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z, |
5908 | MachineIRBuilder &B) { |
5909 | Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0); |
5910 | Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0); |
5911 | B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z}); |
5912 | }; |
5913 | |
5914 | MachineInstr *FMulMI; |
5915 | // fold (fsub (fpext (fneg (fmul x, y))), z) -> |
5916 | // (fneg (fma (fpext x), (fpext y), z)) |
5917 | // fold (fsub (fneg (fpext (fmul x, y))), z) -> |
5918 | // (fneg (fma (fpext x), (fpext y), z)) |
5919 | if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) || |
5920 | mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) && |
5921 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5922 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy, |
5923 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
5924 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5925 | Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy); |
5926 | buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(), |
5927 | FMulMI->getOperand(i: 2).getReg(), RHSReg, B); |
5928 | B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg); |
5929 | }; |
5930 | return true; |
5931 | } |
5932 | |
5933 | // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x) |
5934 | // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x) |
5935 | if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) || |
5936 | mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) && |
5937 | isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) && |
5938 | TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy, |
5939 | SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) { |
5940 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
5941 | buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(), |
5942 | FMulMI->getOperand(i: 2).getReg(), LHSReg, B); |
5943 | }; |
5944 | return true; |
5945 | } |
5946 | |
5947 | return false; |
5948 | } |
5949 | |
5950 | bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI, |
5951 | unsigned &IdxToPropagate) { |
5952 | bool PropagateNaN; |
5953 | switch (MI.getOpcode()) { |
5954 | default: |
5955 | return false; |
5956 | case TargetOpcode::G_FMINNUM: |
5957 | case TargetOpcode::G_FMAXNUM: |
5958 | PropagateNaN = false; |
5959 | break; |
5960 | case TargetOpcode::G_FMINIMUM: |
5961 | case TargetOpcode::G_FMAXIMUM: |
5962 | PropagateNaN = true; |
5963 | break; |
5964 | } |
5965 | |
5966 | auto MatchNaN = [&](unsigned Idx) { |
5967 | Register MaybeNaNReg = MI.getOperand(i: Idx).getReg(); |
5968 | const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI); |
5969 | if (!MaybeCst || !MaybeCst->getValueAPF().isNaN()) |
5970 | return false; |
5971 | IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1); |
5972 | return true; |
5973 | }; |
5974 | |
5975 | return MatchNaN(1) || MatchNaN(2); |
5976 | } |
5977 | |
5978 | bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) { |
5979 | assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD" ); |
5980 | Register LHS = MI.getOperand(i: 1).getReg(); |
5981 | Register RHS = MI.getOperand(i: 2).getReg(); |
5982 | |
5983 | // Helper lambda to check for opportunities for |
5984 | // A + (B - A) -> B |
5985 | // (B - A) + A -> B |
5986 | auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) { |
5987 | Register Reg; |
5988 | return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) && |
5989 | Reg == MaybeSameReg; |
5990 | }; |
5991 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
5992 | } |
5993 | |
5994 | bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI, |
5995 | Register &MatchInfo) { |
5996 | // This combine folds the following patterns: |
5997 | // |
5998 | // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k)) |
5999 | // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k))) |
6000 | // into |
6001 | // x |
6002 | // if |
6003 | // k == sizeof(VecEltTy)/2 |
6004 | // type(x) == type(dst) |
6005 | // |
6006 | // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef) |
6007 | // into |
6008 | // x |
6009 | // if |
6010 | // type(x) == type(dst) |
6011 | |
6012 | LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6013 | LLT DstEltTy = DstVecTy.getElementType(); |
6014 | |
6015 | Register Lo, Hi; |
6016 | |
6017 | if (mi_match( |
6018 | MI, MRI, |
6019 | P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) { |
6020 | MatchInfo = Lo; |
6021 | return MRI.getType(Reg: MatchInfo) == DstVecTy; |
6022 | } |
6023 | |
6024 | std::optional<ValueAndVReg> ShiftAmount; |
6025 | const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo)); |
6026 | const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount)); |
6027 | if (mi_match( |
6028 | MI, MRI, |
6029 | P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern), |
6030 | preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) { |
6031 | if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) { |
6032 | MatchInfo = Lo; |
6033 | return MRI.getType(Reg: MatchInfo) == DstVecTy; |
6034 | } |
6035 | } |
6036 | |
6037 | return false; |
6038 | } |
6039 | |
6040 | bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI, |
6041 | Register &MatchInfo) { |
6042 | // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x |
6043 | // if type(x) == type(G_TRUNC) |
6044 | if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI, |
6045 | P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg())))) |
6046 | return false; |
6047 | |
6048 | return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6049 | } |
6050 | |
6051 | bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI, |
6052 | Register &MatchInfo) { |
6053 | // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with |
6054 | // y if K == size of vector element type |
6055 | std::optional<ValueAndVReg> ShiftAmt; |
6056 | if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI, |
6057 | P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))), |
6058 | R: m_GCst(ValReg&: ShiftAmt)))) |
6059 | return false; |
6060 | |
6061 | LLT MatchTy = MRI.getType(Reg: MatchInfo); |
6062 | return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() && |
6063 | MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6064 | } |
6065 | |
6066 | unsigned CombinerHelper::getFPMinMaxOpcForSelect( |
6067 | CmpInst::Predicate Pred, LLT DstTy, |
6068 | SelectPatternNaNBehaviour VsNaNRetVal) const { |
6069 | assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE && |
6070 | "Expected a NaN behaviour?" ); |
6071 | // Choose an opcode based off of legality or the behaviour when one of the |
6072 | // LHS/RHS may be NaN. |
6073 | switch (Pred) { |
6074 | default: |
6075 | return 0; |
6076 | case CmpInst::FCMP_UGT: |
6077 | case CmpInst::FCMP_UGE: |
6078 | case CmpInst::FCMP_OGT: |
6079 | case CmpInst::FCMP_OGE: |
6080 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6081 | return TargetOpcode::G_FMAXNUM; |
6082 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) |
6083 | return TargetOpcode::G_FMAXIMUM; |
6084 | if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}})) |
6085 | return TargetOpcode::G_FMAXNUM; |
6086 | if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}})) |
6087 | return TargetOpcode::G_FMAXIMUM; |
6088 | return 0; |
6089 | case CmpInst::FCMP_ULT: |
6090 | case CmpInst::FCMP_ULE: |
6091 | case CmpInst::FCMP_OLT: |
6092 | case CmpInst::FCMP_OLE: |
6093 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6094 | return TargetOpcode::G_FMINNUM; |
6095 | if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) |
6096 | return TargetOpcode::G_FMINIMUM; |
6097 | if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}})) |
6098 | return TargetOpcode::G_FMINNUM; |
6099 | if (!isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}})) |
6100 | return 0; |
6101 | return TargetOpcode::G_FMINIMUM; |
6102 | } |
6103 | } |
6104 | |
6105 | CombinerHelper::SelectPatternNaNBehaviour |
6106 | CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS, |
6107 | bool IsOrderedComparison) const { |
6108 | bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI); |
6109 | bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI); |
6110 | // Completely unsafe. |
6111 | if (!LHSSafe && !RHSSafe) |
6112 | return SelectPatternNaNBehaviour::NOT_APPLICABLE; |
6113 | if (LHSSafe && RHSSafe) |
6114 | return SelectPatternNaNBehaviour::RETURNS_ANY; |
6115 | // An ordered comparison will return false when given a NaN, so it |
6116 | // returns the RHS. |
6117 | if (IsOrderedComparison) |
6118 | return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN |
6119 | : SelectPatternNaNBehaviour::RETURNS_OTHER; |
6120 | // An unordered comparison will return true when given a NaN, so it |
6121 | // returns the LHS. |
6122 | return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER |
6123 | : SelectPatternNaNBehaviour::RETURNS_NAN; |
6124 | } |
6125 | |
6126 | bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond, |
6127 | Register TrueVal, Register FalseVal, |
6128 | BuildFnTy &MatchInfo) { |
6129 | // Match: select (fcmp cond x, y) x, y |
6130 | // select (fcmp cond x, y) y, x |
6131 | // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition. |
6132 | LLT DstTy = MRI.getType(Reg: Dst); |
6133 | // Bail out early on pointers, since we'll never want to fold to a min/max. |
6134 | if (DstTy.isPointer()) |
6135 | return false; |
6136 | // Match a floating point compare with a less-than/greater-than predicate. |
6137 | // TODO: Allow multiple users of the compare if they are all selects. |
6138 | CmpInst::Predicate Pred; |
6139 | Register CmpLHS, CmpRHS; |
6140 | if (!mi_match(R: Cond, MRI, |
6141 | P: m_OneNonDBGUse( |
6142 | SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) || |
6143 | CmpInst::isEquality(pred: Pred)) |
6144 | return false; |
6145 | SelectPatternNaNBehaviour ResWithKnownNaNInfo = |
6146 | computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred)); |
6147 | if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE) |
6148 | return false; |
6149 | if (TrueVal == CmpRHS && FalseVal == CmpLHS) { |
6150 | std::swap(a&: CmpLHS, b&: CmpRHS); |
6151 | Pred = CmpInst::getSwappedPredicate(pred: Pred); |
6152 | if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN) |
6153 | ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER; |
6154 | else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER) |
6155 | ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN; |
6156 | } |
6157 | if (TrueVal != CmpLHS || FalseVal != CmpRHS) |
6158 | return false; |
6159 | // Decide what type of max/min this should be based off of the predicate. |
6160 | unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo); |
6161 | if (!Opc || !isLegal(Query: {Opc, {DstTy}})) |
6162 | return false; |
6163 | // Comparisons between signed zero and zero may have different results... |
6164 | // unless we have fmaximum/fminimum. In that case, we know -0 < 0. |
6165 | if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) { |
6166 | // We don't know if a comparison between two 0s will give us a consistent |
6167 | // result. Be conservative and only proceed if at least one side is |
6168 | // non-zero. |
6169 | auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI); |
6170 | if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) { |
6171 | KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI); |
6172 | if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) |
6173 | return false; |
6174 | } |
6175 | } |
6176 | MatchInfo = [=](MachineIRBuilder &B) { |
6177 | B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS}); |
6178 | }; |
6179 | return true; |
6180 | } |
6181 | |
6182 | bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI, |
6183 | BuildFnTy &MatchInfo) { |
6184 | // TODO: Handle integer cases. |
6185 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
6186 | // Condition may be fed by a truncated compare. |
6187 | Register Cond = MI.getOperand(i: 1).getReg(); |
6188 | Register MaybeTrunc; |
6189 | if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc))))) |
6190 | Cond = MaybeTrunc; |
6191 | Register Dst = MI.getOperand(i: 0).getReg(); |
6192 | Register TrueVal = MI.getOperand(i: 2).getReg(); |
6193 | Register FalseVal = MI.getOperand(i: 3).getReg(); |
6194 | return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo); |
6195 | } |
6196 | |
6197 | bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI, |
6198 | BuildFnTy &MatchInfo) { |
6199 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
6200 | // (X + Y) == X --> Y == 0 |
6201 | // (X + Y) != X --> Y != 0 |
6202 | // (X - Y) == X --> Y == 0 |
6203 | // (X - Y) != X --> Y != 0 |
6204 | // (X ^ Y) == X --> Y == 0 |
6205 | // (X ^ Y) != X --> Y != 0 |
6206 | Register Dst = MI.getOperand(i: 0).getReg(); |
6207 | CmpInst::Predicate Pred; |
6208 | Register X, Y, OpLHS, OpRHS; |
6209 | bool MatchedSub = mi_match( |
6210 | R: Dst, MRI, |
6211 | P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y)))); |
6212 | if (MatchedSub && X != OpLHS) |
6213 | return false; |
6214 | if (!MatchedSub) { |
6215 | if (!mi_match(R: Dst, MRI, |
6216 | P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), |
6217 | R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)), |
6218 | preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)))))) |
6219 | return false; |
6220 | Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register(); |
6221 | } |
6222 | MatchInfo = [=](MachineIRBuilder &B) { |
6223 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0); |
6224 | B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero); |
6225 | }; |
6226 | return CmpInst::isEquality(pred: Pred) && Y.isValid(); |
6227 | } |
6228 | |
6229 | bool CombinerHelper::matchShiftsTooBig(MachineInstr &MI) { |
6230 | Register ShiftReg = MI.getOperand(i: 2).getReg(); |
6231 | LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
6232 | auto IsShiftTooBig = [&](const Constant *C) { |
6233 | auto *CI = dyn_cast<ConstantInt>(Val: C); |
6234 | return CI && CI->uge(Num: ResTy.getScalarSizeInBits()); |
6235 | }; |
6236 | return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig); |
6237 | } |
6238 | |
6239 | bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) { |
6240 | Register LHS = MI.getOperand(i: 1).getReg(); |
6241 | Register RHS = MI.getOperand(i: 2).getReg(); |
6242 | auto *LHSDef = MRI.getVRegDef(Reg: LHS); |
6243 | if (getIConstantVRegVal(VReg: LHS, MRI).has_value()) |
6244 | return true; |
6245 | |
6246 | // LHS may be a G_CONSTANT_FOLD_BARRIER. If so we commute |
6247 | // as long as we don't already have a constant on the RHS. |
6248 | if (LHSDef->getOpcode() != TargetOpcode::G_CONSTANT_FOLD_BARRIER) |
6249 | return false; |
6250 | return MRI.getVRegDef(Reg: RHS)->getOpcode() != |
6251 | TargetOpcode::G_CONSTANT_FOLD_BARRIER && |
6252 | !getIConstantVRegVal(VReg: RHS, MRI); |
6253 | } |
6254 | |
6255 | bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) { |
6256 | Register LHS = MI.getOperand(i: 1).getReg(); |
6257 | Register RHS = MI.getOperand(i: 2).getReg(); |
6258 | std::optional<FPValueAndVReg> ValAndVReg; |
6259 | if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg))) |
6260 | return false; |
6261 | return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)); |
6262 | } |
6263 | |
6264 | void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) { |
6265 | Observer.changingInstr(MI); |
6266 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
6267 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
6268 | MI.getOperand(i: 1).setReg(RHSReg); |
6269 | MI.getOperand(i: 2).setReg(LHSReg); |
6270 | Observer.changedInstr(MI); |
6271 | } |
6272 | |
6273 | bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) { |
6274 | LLT SrcTy = MRI.getType(Reg: Src); |
6275 | if (SrcTy.isFixedVector()) |
6276 | return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs); |
6277 | if (SrcTy.isScalar()) { |
6278 | if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr) |
6279 | return true; |
6280 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6281 | return IConstant && IConstant->Value == 1; |
6282 | } |
6283 | return false; // scalable vector |
6284 | } |
6285 | |
6286 | bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) { |
6287 | LLT SrcTy = MRI.getType(Reg: Src); |
6288 | if (SrcTy.isFixedVector()) |
6289 | return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs); |
6290 | if (SrcTy.isScalar()) { |
6291 | if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr) |
6292 | return true; |
6293 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6294 | return IConstant && IConstant->Value == 0; |
6295 | } |
6296 | return false; // scalable vector |
6297 | } |
6298 | |
6299 | // Ignores COPYs during conformance checks. |
6300 | // FIXME scalable vectors. |
6301 | bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue, |
6302 | bool AllowUndefs) { |
6303 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6304 | if (!BuildVector) |
6305 | return false; |
6306 | unsigned NumSources = BuildVector->getNumSources(); |
6307 | |
6308 | for (unsigned I = 0; I < NumSources; ++I) { |
6309 | GImplicitDef *ImplicitDef = |
6310 | getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI); |
6311 | if (ImplicitDef && AllowUndefs) |
6312 | continue; |
6313 | if (ImplicitDef && !AllowUndefs) |
6314 | return false; |
6315 | std::optional<ValueAndVReg> IConstant = |
6316 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6317 | if (IConstant && IConstant->Value == SplatValue) |
6318 | continue; |
6319 | return false; |
6320 | } |
6321 | return true; |
6322 | } |
6323 | |
6324 | // Ignores COPYs during lookups. |
6325 | // FIXME scalable vectors |
6326 | std::optional<APInt> |
6327 | CombinerHelper::getConstantOrConstantSplatVector(Register Src) { |
6328 | auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI); |
6329 | if (IConstant) |
6330 | return IConstant->Value; |
6331 | |
6332 | GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI); |
6333 | if (!BuildVector) |
6334 | return std::nullopt; |
6335 | unsigned NumSources = BuildVector->getNumSources(); |
6336 | |
6337 | std::optional<APInt> Value = std::nullopt; |
6338 | for (unsigned I = 0; I < NumSources; ++I) { |
6339 | std::optional<ValueAndVReg> IConstant = |
6340 | getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI); |
6341 | if (!IConstant) |
6342 | return std::nullopt; |
6343 | if (!Value) |
6344 | Value = IConstant->Value; |
6345 | else if (*Value != IConstant->Value) |
6346 | return std::nullopt; |
6347 | } |
6348 | return Value; |
6349 | } |
6350 | |
6351 | // TODO: use knownbits to determine zeros |
6352 | bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, |
6353 | BuildFnTy &MatchInfo) { |
6354 | uint32_t Flags = Select->getFlags(); |
6355 | Register Dest = Select->getReg(Idx: 0); |
6356 | Register Cond = Select->getCondReg(); |
6357 | Register True = Select->getTrueReg(); |
6358 | Register False = Select->getFalseReg(); |
6359 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
6360 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
6361 | |
6362 | // We only do this combine for scalar boolean conditions. |
6363 | if (CondTy != LLT::scalar(SizeInBits: 1)) |
6364 | return false; |
6365 | |
6366 | // Both are scalars. |
6367 | std::optional<ValueAndVReg> TrueOpt = |
6368 | getIConstantVRegValWithLookThrough(VReg: True, MRI); |
6369 | std::optional<ValueAndVReg> FalseOpt = |
6370 | getIConstantVRegValWithLookThrough(VReg: False, MRI); |
6371 | |
6372 | if (!TrueOpt || !FalseOpt) |
6373 | return false; |
6374 | |
6375 | APInt TrueValue = TrueOpt->Value; |
6376 | APInt FalseValue = FalseOpt->Value; |
6377 | |
6378 | // select Cond, 1, 0 --> zext (Cond) |
6379 | if (TrueValue.isOne() && FalseValue.isZero()) { |
6380 | MatchInfo = [=](MachineIRBuilder &B) { |
6381 | B.setInstrAndDebugLoc(*Select); |
6382 | B.buildZExtOrTrunc(Res: Dest, Op: Cond); |
6383 | }; |
6384 | return true; |
6385 | } |
6386 | |
6387 | // select Cond, -1, 0 --> sext (Cond) |
6388 | if (TrueValue.isAllOnes() && FalseValue.isZero()) { |
6389 | MatchInfo = [=](MachineIRBuilder &B) { |
6390 | B.setInstrAndDebugLoc(*Select); |
6391 | B.buildSExtOrTrunc(Res: Dest, Op: Cond); |
6392 | }; |
6393 | return true; |
6394 | } |
6395 | |
6396 | // select Cond, 0, 1 --> zext (!Cond) |
6397 | if (TrueValue.isZero() && FalseValue.isOne()) { |
6398 | MatchInfo = [=](MachineIRBuilder &B) { |
6399 | B.setInstrAndDebugLoc(*Select); |
6400 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6401 | B.buildNot(Dst: Inner, Src0: Cond); |
6402 | B.buildZExtOrTrunc(Res: Dest, Op: Inner); |
6403 | }; |
6404 | return true; |
6405 | } |
6406 | |
6407 | // select Cond, 0, -1 --> sext (!Cond) |
6408 | if (TrueValue.isZero() && FalseValue.isAllOnes()) { |
6409 | MatchInfo = [=](MachineIRBuilder &B) { |
6410 | B.setInstrAndDebugLoc(*Select); |
6411 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6412 | B.buildNot(Dst: Inner, Src0: Cond); |
6413 | B.buildSExtOrTrunc(Res: Dest, Op: Inner); |
6414 | }; |
6415 | return true; |
6416 | } |
6417 | |
6418 | // select Cond, C1, C1-1 --> add (zext Cond), C1-1 |
6419 | if (TrueValue - 1 == FalseValue) { |
6420 | MatchInfo = [=](MachineIRBuilder &B) { |
6421 | B.setInstrAndDebugLoc(*Select); |
6422 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6423 | B.buildZExtOrTrunc(Res: Inner, Op: Cond); |
6424 | B.buildAdd(Dst: Dest, Src0: Inner, Src1: False); |
6425 | }; |
6426 | return true; |
6427 | } |
6428 | |
6429 | // select Cond, C1, C1+1 --> add (sext Cond), C1+1 |
6430 | if (TrueValue + 1 == FalseValue) { |
6431 | MatchInfo = [=](MachineIRBuilder &B) { |
6432 | B.setInstrAndDebugLoc(*Select); |
6433 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6434 | B.buildSExtOrTrunc(Res: Inner, Op: Cond); |
6435 | B.buildAdd(Dst: Dest, Src0: Inner, Src1: False); |
6436 | }; |
6437 | return true; |
6438 | } |
6439 | |
6440 | // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) |
6441 | if (TrueValue.isPowerOf2() && FalseValue.isZero()) { |
6442 | MatchInfo = [=](MachineIRBuilder &B) { |
6443 | B.setInstrAndDebugLoc(*Select); |
6444 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6445 | B.buildZExtOrTrunc(Res: Inner, Op: Cond); |
6446 | // The shift amount must be scalar. |
6447 | LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; |
6448 | auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2()); |
6449 | B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags); |
6450 | }; |
6451 | return true; |
6452 | } |
6453 | // select Cond, -1, C --> or (sext Cond), C |
6454 | if (TrueValue.isAllOnes()) { |
6455 | MatchInfo = [=](MachineIRBuilder &B) { |
6456 | B.setInstrAndDebugLoc(*Select); |
6457 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6458 | B.buildSExtOrTrunc(Res: Inner, Op: Cond); |
6459 | B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags); |
6460 | }; |
6461 | return true; |
6462 | } |
6463 | |
6464 | // select Cond, C, -1 --> or (sext (not Cond)), C |
6465 | if (FalseValue.isAllOnes()) { |
6466 | MatchInfo = [=](MachineIRBuilder &B) { |
6467 | B.setInstrAndDebugLoc(*Select); |
6468 | Register Not = MRI.createGenericVirtualRegister(Ty: CondTy); |
6469 | B.buildNot(Dst: Not, Src0: Cond); |
6470 | Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6471 | B.buildSExtOrTrunc(Res: Inner, Op: Not); |
6472 | B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags); |
6473 | }; |
6474 | return true; |
6475 | } |
6476 | |
6477 | return false; |
6478 | } |
6479 | |
6480 | // TODO: use knownbits to determine zeros |
6481 | bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, |
6482 | BuildFnTy &MatchInfo) { |
6483 | uint32_t Flags = Select->getFlags(); |
6484 | Register DstReg = Select->getReg(Idx: 0); |
6485 | Register Cond = Select->getCondReg(); |
6486 | Register True = Select->getTrueReg(); |
6487 | Register False = Select->getFalseReg(); |
6488 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
6489 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
6490 | |
6491 | // Boolean or fixed vector of booleans. |
6492 | if (CondTy.isScalableVector() || |
6493 | (CondTy.isFixedVector() && |
6494 | CondTy.getElementType().getScalarSizeInBits() != 1) || |
6495 | CondTy.getScalarSizeInBits() != 1) |
6496 | return false; |
6497 | |
6498 | if (CondTy != TrueTy) |
6499 | return false; |
6500 | |
6501 | // select Cond, Cond, F --> or Cond, F |
6502 | // select Cond, 1, F --> or Cond, F |
6503 | if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) { |
6504 | MatchInfo = [=](MachineIRBuilder &B) { |
6505 | B.setInstrAndDebugLoc(*Select); |
6506 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6507 | B.buildZExtOrTrunc(Res: Ext, Op: Cond); |
6508 | B.buildOr(Dst: DstReg, Src0: Ext, Src1: False, Flags); |
6509 | }; |
6510 | return true; |
6511 | } |
6512 | |
6513 | // select Cond, T, Cond --> and Cond, T |
6514 | // select Cond, T, 0 --> and Cond, T |
6515 | if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) { |
6516 | MatchInfo = [=](MachineIRBuilder &B) { |
6517 | B.setInstrAndDebugLoc(*Select); |
6518 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6519 | B.buildZExtOrTrunc(Res: Ext, Op: Cond); |
6520 | B.buildAnd(Dst: DstReg, Src0: Ext, Src1: True); |
6521 | }; |
6522 | return true; |
6523 | } |
6524 | |
6525 | // select Cond, T, 1 --> or (not Cond), T |
6526 | if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) { |
6527 | MatchInfo = [=](MachineIRBuilder &B) { |
6528 | B.setInstrAndDebugLoc(*Select); |
6529 | // First the not. |
6530 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6531 | B.buildNot(Dst: Inner, Src0: Cond); |
6532 | // Then an ext to match the destination register. |
6533 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6534 | B.buildZExtOrTrunc(Res: Ext, Op: Inner); |
6535 | B.buildOr(Dst: DstReg, Src0: Ext, Src1: True, Flags); |
6536 | }; |
6537 | return true; |
6538 | } |
6539 | |
6540 | // select Cond, 0, F --> and (not Cond), F |
6541 | if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) { |
6542 | MatchInfo = [=](MachineIRBuilder &B) { |
6543 | B.setInstrAndDebugLoc(*Select); |
6544 | // First the not. |
6545 | Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy); |
6546 | B.buildNot(Dst: Inner, Src0: Cond); |
6547 | // Then an ext to match the destination register. |
6548 | Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy); |
6549 | B.buildZExtOrTrunc(Res: Ext, Op: Inner); |
6550 | B.buildAnd(Dst: DstReg, Src0: Ext, Src1: False); |
6551 | }; |
6552 | return true; |
6553 | } |
6554 | |
6555 | return false; |
6556 | } |
6557 | |
6558 | bool CombinerHelper::tryFoldSelectToIntMinMax(GSelect *Select, |
6559 | BuildFnTy &MatchInfo) { |
6560 | Register DstReg = Select->getReg(Idx: 0); |
6561 | Register Cond = Select->getCondReg(); |
6562 | Register True = Select->getTrueReg(); |
6563 | Register False = Select->getFalseReg(); |
6564 | LLT DstTy = MRI.getType(Reg: DstReg); |
6565 | |
6566 | if (DstTy.isPointer()) |
6567 | return false; |
6568 | |
6569 | // We need an G_ICMP on the condition register. |
6570 | GICmp *Cmp = getOpcodeDef<GICmp>(Reg: Cond, MRI); |
6571 | if (!Cmp) |
6572 | return false; |
6573 | |
6574 | // We want to fold the icmp and replace the select. |
6575 | if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0))) |
6576 | return false; |
6577 | |
6578 | CmpInst::Predicate Pred = Cmp->getCond(); |
6579 | // We need a larger or smaller predicate for |
6580 | // canonicalization. |
6581 | if (CmpInst::isEquality(pred: Pred)) |
6582 | return false; |
6583 | |
6584 | Register CmpLHS = Cmp->getLHSReg(); |
6585 | Register CmpRHS = Cmp->getRHSReg(); |
6586 | |
6587 | // We can swap CmpLHS and CmpRHS for higher hitrate. |
6588 | if (True == CmpRHS && False == CmpLHS) { |
6589 | std::swap(a&: CmpLHS, b&: CmpRHS); |
6590 | Pred = CmpInst::getSwappedPredicate(pred: Pred); |
6591 | } |
6592 | |
6593 | // (icmp X, Y) ? X : Y -> integer minmax. |
6594 | // see matchSelectPattern in ValueTracking. |
6595 | // Legality between G_SELECT and integer minmax can differ. |
6596 | if (True == CmpLHS && False == CmpRHS) { |
6597 | switch (Pred) { |
6598 | case ICmpInst::ICMP_UGT: |
6599 | case ICmpInst::ICMP_UGE: { |
6600 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy})) |
6601 | return false; |
6602 | MatchInfo = [=](MachineIRBuilder &B) { |
6603 | B.buildUMax(Dst: DstReg, Src0: True, Src1: False); |
6604 | }; |
6605 | return true; |
6606 | } |
6607 | case ICmpInst::ICMP_SGT: |
6608 | case ICmpInst::ICMP_SGE: { |
6609 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy})) |
6610 | return false; |
6611 | MatchInfo = [=](MachineIRBuilder &B) { |
6612 | B.buildSMax(Dst: DstReg, Src0: True, Src1: False); |
6613 | }; |
6614 | return true; |
6615 | } |
6616 | case ICmpInst::ICMP_ULT: |
6617 | case ICmpInst::ICMP_ULE: { |
6618 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy})) |
6619 | return false; |
6620 | MatchInfo = [=](MachineIRBuilder &B) { |
6621 | B.buildUMin(Dst: DstReg, Src0: True, Src1: False); |
6622 | }; |
6623 | return true; |
6624 | } |
6625 | case ICmpInst::ICMP_SLT: |
6626 | case ICmpInst::ICMP_SLE: { |
6627 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy})) |
6628 | return false; |
6629 | MatchInfo = [=](MachineIRBuilder &B) { |
6630 | B.buildSMin(Dst: DstReg, Src0: True, Src1: False); |
6631 | }; |
6632 | return true; |
6633 | } |
6634 | default: |
6635 | return false; |
6636 | } |
6637 | } |
6638 | |
6639 | return false; |
6640 | } |
6641 | |
6642 | bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) { |
6643 | GSelect *Select = cast<GSelect>(Val: &MI); |
6644 | |
6645 | if (tryFoldSelectOfConstants(Select, MatchInfo)) |
6646 | return true; |
6647 | |
6648 | if (tryFoldBoolSelectToLogic(Select, MatchInfo)) |
6649 | return true; |
6650 | |
6651 | if (tryFoldSelectToIntMinMax(Select, MatchInfo)) |
6652 | return true; |
6653 | |
6654 | return false; |
6655 | } |
6656 | |
6657 | /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2) |
6658 | /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2) |
6659 | /// into a single comparison using range-based reasoning. |
6660 | /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges. |
6661 | bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic, |
6662 | BuildFnTy &MatchInfo) { |
6663 | assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor" ); |
6664 | bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; |
6665 | Register DstReg = Logic->getReg(Idx: 0); |
6666 | Register LHS = Logic->getLHSReg(); |
6667 | Register RHS = Logic->getRHSReg(); |
6668 | unsigned Flags = Logic->getFlags(); |
6669 | |
6670 | // We need an G_ICMP on the LHS register. |
6671 | GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI); |
6672 | if (!Cmp1) |
6673 | return false; |
6674 | |
6675 | // We need an G_ICMP on the RHS register. |
6676 | GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI); |
6677 | if (!Cmp2) |
6678 | return false; |
6679 | |
6680 | // We want to fold the icmps. |
6681 | if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) || |
6682 | !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0))) |
6683 | return false; |
6684 | |
6685 | APInt C1; |
6686 | APInt C2; |
6687 | std::optional<ValueAndVReg> MaybeC1 = |
6688 | getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI); |
6689 | if (!MaybeC1) |
6690 | return false; |
6691 | C1 = MaybeC1->Value; |
6692 | |
6693 | std::optional<ValueAndVReg> MaybeC2 = |
6694 | getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI); |
6695 | if (!MaybeC2) |
6696 | return false; |
6697 | C2 = MaybeC2->Value; |
6698 | |
6699 | Register R1 = Cmp1->getLHSReg(); |
6700 | Register R2 = Cmp2->getLHSReg(); |
6701 | CmpInst::Predicate Pred1 = Cmp1->getCond(); |
6702 | CmpInst::Predicate Pred2 = Cmp2->getCond(); |
6703 | LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0)); |
6704 | LLT CmpOperandTy = MRI.getType(Reg: R1); |
6705 | |
6706 | // We build ands, adds, and constants of type CmpOperandTy. |
6707 | // They must be legal to build. |
6708 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) || |
6709 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) || |
6710 | !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy)) |
6711 | return false; |
6712 | |
6713 | // Look through add of a constant offset on R1, R2, or both operands. This |
6714 | // allows us to interpret the R + C' < C'' range idiom into a proper range. |
6715 | std::optional<APInt> Offset1; |
6716 | std::optional<APInt> Offset2; |
6717 | if (R1 != R2) { |
6718 | if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) { |
6719 | std::optional<ValueAndVReg> MaybeOffset1 = |
6720 | getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI); |
6721 | if (MaybeOffset1) { |
6722 | R1 = Add->getLHSReg(); |
6723 | Offset1 = MaybeOffset1->Value; |
6724 | } |
6725 | } |
6726 | if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) { |
6727 | std::optional<ValueAndVReg> MaybeOffset2 = |
6728 | getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI); |
6729 | if (MaybeOffset2) { |
6730 | R2 = Add->getLHSReg(); |
6731 | Offset2 = MaybeOffset2->Value; |
6732 | } |
6733 | } |
6734 | } |
6735 | |
6736 | if (R1 != R2) |
6737 | return false; |
6738 | |
6739 | // We calculate the icmp ranges including maybe offsets. |
6740 | ConstantRange CR1 = ConstantRange::makeExactICmpRegion( |
6741 | Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1); |
6742 | if (Offset1) |
6743 | CR1 = CR1.subtract(CI: *Offset1); |
6744 | |
6745 | ConstantRange CR2 = ConstantRange::makeExactICmpRegion( |
6746 | Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2); |
6747 | if (Offset2) |
6748 | CR2 = CR2.subtract(CI: *Offset2); |
6749 | |
6750 | bool CreateMask = false; |
6751 | APInt LowerDiff; |
6752 | std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2); |
6753 | if (!CR) { |
6754 | // We need non-wrapping ranges. |
6755 | if (CR1.isWrappedSet() || CR2.isWrappedSet()) |
6756 | return false; |
6757 | |
6758 | // Check whether we have equal-size ranges that only differ by one bit. |
6759 | // In that case we can apply a mask to map one range onto the other. |
6760 | LowerDiff = CR1.getLower() ^ CR2.getLower(); |
6761 | APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); |
6762 | APInt CR1Size = CR1.getUpper() - CR1.getLower(); |
6763 | if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || |
6764 | CR1Size != CR2.getUpper() - CR2.getLower()) |
6765 | return false; |
6766 | |
6767 | CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2; |
6768 | CreateMask = true; |
6769 | } |
6770 | |
6771 | if (IsAnd) |
6772 | CR = CR->inverse(); |
6773 | |
6774 | CmpInst::Predicate NewPred; |
6775 | APInt NewC, Offset; |
6776 | CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset); |
6777 | |
6778 | // We take the result type of one of the original icmps, CmpTy, for |
6779 | // the to be build icmp. The operand type, CmpOperandTy, is used for |
6780 | // the other instructions and constants to be build. The types of |
6781 | // the parameters and output are the same for add and and. CmpTy |
6782 | // and the type of DstReg might differ. That is why we zext or trunc |
6783 | // the icmp into the destination register. |
6784 | |
6785 | MatchInfo = [=](MachineIRBuilder &B) { |
6786 | if (CreateMask && Offset != 0) { |
6787 | auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff); |
6788 | auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask. |
6789 | auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset); |
6790 | auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags); |
6791 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
6792 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon); |
6793 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
6794 | } else if (CreateMask && Offset == 0) { |
6795 | auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff); |
6796 | auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask. |
6797 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
6798 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon); |
6799 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
6800 | } else if (!CreateMask && Offset != 0) { |
6801 | auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset); |
6802 | auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags); |
6803 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
6804 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon); |
6805 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
6806 | } else if (!CreateMask && Offset == 0) { |
6807 | auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC); |
6808 | auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon); |
6809 | B.buildZExtOrTrunc(Res: DstReg, Op: ICmp); |
6810 | } else { |
6811 | llvm_unreachable("unexpected configuration of CreateMask and Offset" ); |
6812 | } |
6813 | }; |
6814 | return true; |
6815 | } |
6816 | |
6817 | bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) { |
6818 | GAnd *And = cast<GAnd>(Val: &MI); |
6819 | |
6820 | if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo)) |
6821 | return true; |
6822 | |
6823 | return false; |
6824 | } |
6825 | |
6826 | bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) { |
6827 | GOr *Or = cast<GOr>(Val: &MI); |
6828 | |
6829 | if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo)) |
6830 | return true; |
6831 | |
6832 | return false; |
6833 | } |
6834 | |