//===-- lib/CodeGen/GlobalISel/GICombinerHelper.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/LowLevelTypeUtils.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DivisionByConstantInfo.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cmath>
#include <optional>
#include <tuple>

#define DEBUG_TYPE "gi-combiner"

using namespace llvm;
using namespace MIPatternMatch;

// Option to allow testing of the combiner while no targets know about indexed
// addressing.
static cl::opt<bool>
    ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
                       cl::desc("Force all indexed operations to be "
                                "legal for the GlobalISel combiner"));

CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
                               MachineIRBuilder &B, bool IsPreLegalize,
                               GISelKnownBits *KB, MachineDominatorTree *MDT,
                               const LegalizerInfo *LI)
    : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB),
      MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
      RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
      TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
  (void)this->KB;
}

const TargetLowering &CombinerHelper::getTargetLowering() const {
  return *Builder.getMF().getSubtarget().getTargetLowering();
}

/// \returns The little endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 0
static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  return I;
}

/// Determines the LogBase2 value for a non-null input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
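///
/// E.g. for a 32-bit V = 16 (0b10000), ctlz(V) = 27, so
/// LogBase2(V) = 31 - 27 = 4.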
static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
  auto &MRI = *MIB.getMRI();
  LLT Ty = MRI.getType(V);
  auto Ctlz = MIB.buildCTLZ(Ty, V);
  auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
  return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
}

/// \returns The big endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 3
static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  return ByteWidth - I - 1;
}

/// Given a map from byte offsets in memory to indices in a load/store,
/// determine if that map corresponds to a little or big endian byte pattern.
///
/// \param MemOffset2Idx maps memory offsets to address offsets.
/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
///
/// \returns true if the map corresponds to a big endian byte pattern, false if
/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
///
/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
/// are as follows:
///
/// AddrOffset   Little endian    Big endian
/// 0            0                3
/// 1            1                2
/// 2            2                1
/// 3            3                0
static std::optional<bool>
isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
            int64_t LowestIdx) {
  // Need at least two byte positions to decide on endianness.
  unsigned Width = MemOffset2Idx.size();
  if (Width < 2)
    return std::nullopt;
  bool BigEndian = true, LittleEndian = true;
  for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
    auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset);
    if (MemOffsetAndIdx == MemOffset2Idx.end())
      return std::nullopt;
    const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
    assert(Idx >= 0 && "Expected non-negative byte offset?");
    LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset);
    BigEndian &= Idx == bigEndianByteAt(Width, MemOffset);
    if (!BigEndian && !LittleEndian)
      return std::nullopt;
  }

  assert((BigEndian != LittleEndian) &&
         "Pattern cannot be both big and little endian!");
  return BigEndian;
}

bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }

bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
  assert(LI && "Must have LegalizerInfo to query isLegal!");
  return LI->getAction(Query).Action == LegalizeActions::Legal;
}

bool CombinerHelper::isLegalOrBeforeLegalizer(
    const LegalityQuery &Query) const {
  return isPreLegalize() || isLegal(Query);
}

bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
  if (!Ty.isVector())
    return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
  // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
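  // E.g. %v:_(<2 x s32>) = G_BUILD_VECTOR %c0:_(s32), %c1:_(s32), so both the
  // build_vector and the scalar constant must be legal.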
  if (isPreLegalize())
    return true;
  LLT EltTy = Ty.getElementType();
  return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
         isLegal({TargetOpcode::G_CONSTANT, {EltTy}});
}

void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
                                    Register ToReg) const {
  Observer.changingAllUsesOfReg(MRI, FromReg);

  if (MRI.constrainRegAttrs(ToReg, FromReg))
    MRI.replaceRegWith(FromReg, ToReg);
  else
    Builder.buildCopy(ToReg, FromReg);

  Observer.finishedChangingAllUsesOfReg();
}

void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
                                      MachineOperand &FromRegOp,
                                      Register ToReg) const {
  assert(FromRegOp.getParent() && "Expected an operand in an MI");
  Observer.changingInstr(*FromRegOp.getParent());

  FromRegOp.setReg(ToReg);

  Observer.changedInstr(*FromRegOp.getParent());
}

void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
                                       unsigned ToOpcode) const {
  Observer.changingInstr(FromMI);

  FromMI.setDesc(Builder.getTII().get(ToOpcode));

  Observer.changedInstr(FromMI);
}

const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
  return RBI->getRegBank(Reg, MRI, *TRI);
}

void CombinerHelper::setRegBank(Register Reg, const RegisterBank *RegBank) {
  if (RegBank)
    MRI.setRegBank(Reg, *RegBank);
}

bool CombinerHelper::tryCombineCopy(MachineInstr &MI) {
  if (matchCombineCopy(MI)) {
    applyCombineCopy(MI);
    return true;
  }
  return false;
}
bool CombinerHelper::matchCombineCopy(MachineInstr &MI) {
  if (MI.getOpcode() != TargetOpcode::COPY)
    return false;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  return canReplaceReg(DstReg, SrcReg, MRI);
}
void CombinerHelper::applyCombineCopy(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  MI.eraseFromParent();
  replaceRegWith(MRI, DstReg, SrcReg);
}

bool CombinerHelper::tryCombineConcatVectors(MachineInstr &MI) {
  bool IsUndef = false;
  SmallVector<Register, 4> Ops;
  if (matchCombineConcatVectors(MI, IsUndef, Ops)) {
    applyCombineConcatVectors(MI, IsUndef, Ops);
    return true;
  }
  return false;
}

bool CombinerHelper::matchCombineConcatVectors(MachineInstr &MI, bool &IsUndef,
                                               SmallVectorImpl<Register> &Ops) {
  assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
         "Invalid instruction");
  IsUndef = true;
  MachineInstr *Undef = nullptr;

  // Walk over all the operands of concat vectors and check if they are
  // build_vector themselves or undef.
  // Then collect their operands in Ops.
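  //
  // E.g.:
  //   %v1:_(<2 x s32>) = G_BUILD_VECTOR %a, %b
  //   %v2:_(<2 x s32>) = G_BUILD_VECTOR %c, %d
  //   %cat:_(<4 x s32>) = G_CONCAT_VECTORS %v1, %v2
  // flattens into a single G_BUILD_VECTOR %a, %b, %c, %d.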
  for (const MachineOperand &MO : MI.uses()) {
    Register Reg = MO.getReg();
    MachineInstr *Def = MRI.getVRegDef(Reg);
    assert(Def && "Operand not defined");
    switch (Def->getOpcode()) {
    case TargetOpcode::G_BUILD_VECTOR:
      IsUndef = false;
      // Remember the operands of the build_vector to fold
      // them into the yet-to-build flattened concat vectors.
      for (const MachineOperand &BuildVecMO : Def->uses())
        Ops.push_back(BuildVecMO.getReg());
      break;
    case TargetOpcode::G_IMPLICIT_DEF: {
      LLT OpType = MRI.getType(Reg);
      // Keep one undef value for all the undef operands.
      if (!Undef) {
        Builder.setInsertPt(*MI.getParent(), MI);
        Undef = Builder.buildUndef(OpType.getScalarType());
      }
      assert(MRI.getType(Undef->getOperand(0).getReg()) ==
                 OpType.getScalarType() &&
             "All undefs should have the same type");
      // Break the undef vector into as many scalar elements as needed
      // for the flattening.
      for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
           EltIdx != EltEnd; ++EltIdx)
        Ops.push_back(Undef->getOperand(0).getReg());
      break;
    }
    default:
      return false;
    }
  }
  return true;
}
void CombinerHelper::applyCombineConcatVectors(
    MachineInstr &MI, bool IsUndef, const ArrayRef<Register> Ops) {
  // We determined that the concat_vectors can be flattened.
  // Generate the flattened build_vector.
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  // Note: IsUndef is sort of redundant. We could have determined it by
  // checking that all of the Ops are undef. Alternatively, we could have
  // generated a build_vector of undefs and relied on another combine to
  // clean that up. For now, given that we already gather this information
  // in tryCombineConcatVectors, just save compile time and issue the
  // right thing.
  if (IsUndef)
    Builder.buildUndef(NewDstReg);
  else
    Builder.buildBuildVector(NewDstReg, Ops);
  MI.eraseFromParent();
  replaceRegWith(MRI, DstReg, NewDstReg);
}

bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) {
  SmallVector<Register, 4> Ops;
  if (matchCombineShuffleVector(MI, Ops)) {
    applyCombineShuffleVector(MI, Ops);
    return true;
  }
  return false;
}

bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
                                               SmallVectorImpl<Register> &Ops) {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
  Register Src1 = MI.getOperand(1).getReg();
  LLT SrcType = MRI.getType(Src1);
  // As bizarre as it may look, shuffle vector can actually produce
  // scalar! This is because at the IR level a <1 x ty> shuffle
  // vector is perfectly valid.
  unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1;
  unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1;

  // If the resulting vector is smaller than the size of the source
  // vectors being concatenated, we won't be able to replace the
  // shuffle vector with a concat_vectors.
  //
  // Note: We may still be able to produce a concat_vectors fed by
  // extract_vector_elt and so on. It is less clear that would
  // be better though, so don't bother for now.
  //
  // If the destination is a scalar, the size of the sources doesn't
  // matter. We will lower the shuffle to a plain copy. This will only
  // work if the source and destination have the same size. But
  // that's covered by the next condition.
  //
  // TODO: If the sizes of the source and destination don't match,
  // we could still emit an extract vector element in that case.
  if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
    return false;

  // Check that the shuffle mask can be broken evenly between the
  // different sources.
  if (DstNumElts % SrcNumElts != 0)
    return false;

  // Mask length is a multiple of the source vector length.
  // Check if the shuffle is some kind of concatenation of the input
  // vectors.
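  //
  // E.g. with two <2 x s32> sources and a mask of <0, 1, 2, 3>, the shuffle is
  // simply a concatenation of %src1 and %src2.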
  unsigned NumConcat = DstNumElts / SrcNumElts;
  SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  for (unsigned i = 0; i != DstNumElts; ++i) {
    int Idx = Mask[i];
    // Undef value.
    if (Idx < 0)
      continue;
    // Ensure the indices in each SrcType sized piece are sequential and that
    // the same source is used for the whole piece.
    if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
        (ConcatSrcs[i / SrcNumElts] >= 0 &&
         ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
      return false;
    // Remember which source this index came from.
    ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
  }

  // The shuffle is concatenating multiple vectors together.
  // Collect the different operands for that.
  Register UndefReg;
  Register Src2 = MI.getOperand(2).getReg();
  for (auto Src : ConcatSrcs) {
    if (Src < 0) {
      if (!UndefReg) {
        Builder.setInsertPt(*MI.getParent(), MI);
        UndefReg = Builder.buildUndef(SrcType).getReg(0);
      }
      Ops.push_back(UndefReg);
    } else if (Src == 0)
      Ops.push_back(Src1);
    else
      Ops.push_back(Src2);
  }
  return true;
}

void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
                                               const ArrayRef<Register> Ops) {
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  if (Ops.size() == 1)
    Builder.buildCopy(NewDstReg, Ops[0]);
  else
    Builder.buildMergeLikeInstr(NewDstReg, Ops);

  MI.eraseFromParent();
  replaceRegWith(MRI, DstReg, NewDstReg);
}

bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");

  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  return Mask.size() == 1;
}

void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);

  int I = MI.getOperand(3).getShuffleMask()[0];
  Register Src1 = MI.getOperand(1).getReg();
  LLT Src1Ty = MRI.getType(Src1);
  int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;
  Register SrcReg;
  if (I >= Src1NumElts) {
    SrcReg = MI.getOperand(2).getReg();
    I -= Src1NumElts;
  } else if (I >= 0)
    SrcReg = Src1;

  if (I < 0)
    Builder.buildUndef(DstReg);
  else if (!MRI.getType(SrcReg).isVector())
    Builder.buildCopy(DstReg, SrcReg);
  else
    Builder.buildExtractVectorElementConstant(DstReg, SrcReg, I);

  MI.eraseFromParent();
}

namespace {

/// Select a preference between two uses. CurrentUse is the current preference
/// while *ForCandidate are the attributes of the candidate under consideration.
PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
                                  PreferredTuple &CurrentUse,
                                  const LLT TyForCandidate,
                                  unsigned OpcodeForCandidate,
                                  MachineInstr *MIForCandidate) {
  if (!CurrentUse.Ty.isValid()) {
    if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
        CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
    return CurrentUse;
  }

  // We permit the extend to hoist through basic blocks but this is only
  // sensible if the target has extending loads. If you end up lowering back
  // into a load and extend during the legalizer then the end result is
  // hoisting the extend up to the load.

  // Prefer defined extensions to undefined extensions as these are more
  // likely to reduce the number of instructions.
  if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
      CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
    return CurrentUse;
  else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
           OpcodeForCandidate != TargetOpcode::G_ANYEXT)
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};

  // Prefer sign extensions to zero extensions as sign-extensions tend to be
  // more expensive. Don't do this if the load is already a zero-extend load
  // though, otherwise we'll rewrite a zero-extend load into a sign-extend
  // later.
  if (!isa<GZExtLoad>(LoadMI) && CurrentUse.Ty == TyForCandidate) {
    if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
        OpcodeForCandidate == TargetOpcode::G_ZEXT)
      return CurrentUse;
    else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
             OpcodeForCandidate == TargetOpcode::G_SEXT)
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
  }

  // This is potentially target specific. We've chosen the largest type
  // because G_TRUNC is usually free. One potential catch with this is that
  // some targets have a reduced number of larger registers than smaller
  // registers and this choice potentially increases the live-range for the
  // larger value.
  if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
  }
  return CurrentUse;
}

/// Find a suitable place to insert some instructions and insert them. This
/// function accounts for special cases like inserting before a PHI node.
/// The current strategy for inserting before PHI's is to duplicate the
/// instructions for each predecessor. However, while that's ok for G_TRUNC
/// on most targets since it generally requires no code, other targets/cases may
/// want to try harder to find a dominating block.
static void InsertInsnsWithoutSideEffectsBeforeUse(
    MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
    std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
                       MachineOperand &UseMO)>
        Inserter) {
  MachineInstr &UseMI = *UseMO.getParent();

  MachineBasicBlock *InsertBB = UseMI.getParent();

  // If the use is a PHI then we want the predecessor block instead.
  if (UseMI.isPHI()) {
    MachineOperand *PredBB = std::next(&UseMO);
    InsertBB = PredBB->getMBB();
  }

  // If the block is the same block as the def then we want to insert just after
  // the def instead of at the start of the block.
  if (InsertBB == DefMI.getParent()) {
    MachineBasicBlock::iterator InsertPt = &DefMI;
    Inserter(InsertBB, std::next(InsertPt), UseMO);
    return;
  }

  // Otherwise we want the start of the BB
  Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
}
} // end anonymous namespace

bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) {
  PreferredTuple Preferred;
  if (matchCombineExtendingLoads(MI, Preferred)) {
    applyCombineExtendingLoads(MI, Preferred);
    return true;
  }
  return false;
}

static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
  unsigned CandidateLoadOpc;
  switch (ExtOpc) {
  case TargetOpcode::G_ANYEXT:
    CandidateLoadOpc = TargetOpcode::G_LOAD;
    break;
  case TargetOpcode::G_SEXT:
    CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
    break;
  case TargetOpcode::G_ZEXT:
    CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
    break;
  default:
    llvm_unreachable("Unexpected extend opc");
  }
  return CandidateLoadOpc;
}

bool CombinerHelper::matchCombineExtendingLoads(MachineInstr &MI,
                                                PreferredTuple &Preferred) {
  // We match the loads and follow the uses to the extend instead of matching
  // the extends and following the def to the load. This is because the load
  // must remain in the same position for correctness (unless we also add code
  // to find a safe place to sink it) whereas the extend is freely movable.
  // It also prevents us from duplicating the load for the volatile case or just
  // for performance.
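  //
  // E.g. (subject to the legality checks below):
  //   %ld:_(s8) = G_LOAD %ptr
  //   %ext:_(s32) = G_SEXT %ld(s8)
  // is rewritten into
  //   %ext:_(s32) = G_SEXTLOAD %ptr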
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
  if (!LoadMI)
    return false;

  Register LoadReg = LoadMI->getDstReg();

  LLT LoadValueTy = MRI.getType(LoadReg);
  if (!LoadValueTy.isScalar())
    return false;

  // Most architectures are going to legalize <s8 loads into at least a 1 byte
  // load, and the MMOs can only describe memory accesses in multiples of bytes.
  // If we try to perform extload combining on those, we can end up with
  //   %a(s8) = extload %ptr (load 1 byte from %ptr)
  // ... which is an illegal extload instruction.
  if (LoadValueTy.getSizeInBits() < 8)
    return false;

  // Non power-of-2 types will very likely be legalized into multiple loads.
  // Don't bother trying to match them into extending loads.
  if (!llvm::has_single_bit<uint32_t>(LoadValueTy.getSizeInBits()))
    return false;

  // Find the preferred type aside from the any-extends (unless it's the only
  // one) and non-extending ops. We'll emit an extending load to that type and
  // emit a variant of (extend (trunc X)) for the others according to the
  // relative type sizes. At the same time, pick an extend to use based on the
  // extend involved in the chosen type.
  unsigned PreferredOpcode =
      isa<GLoad>(&MI)
          ? TargetOpcode::G_ANYEXT
          : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Preferred = {LLT(), PreferredOpcode, nullptr};
  for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
    if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
        UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
        (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
      const auto &MMO = LoadMI->getMMO();
      // For atomics, only form anyextending loads.
      if (MMO.isAtomic() && UseMI.getOpcode() != TargetOpcode::G_ANYEXT)
        continue;
      // Check for legality.
      if (!isPreLegalize()) {
        LegalityQuery::MemDesc MMDesc(MMO);
        unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
        LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
        LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
        if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
                .Action != LegalizeActions::Legal)
          continue;
      }
      Preferred = ChoosePreferredUse(MI, Preferred,
                                     MRI.getType(UseMI.getOperand(0).getReg()),
                                     UseMI.getOpcode(), &UseMI);
    }
  }

  // There were no extends
  if (!Preferred.MI)
    return false;
  // It should be impossible to choose an extend without selecting a different
  // type since by definition the result of an extend is larger.
  assert(Preferred.Ty != LoadValueTy && "Extending to same type?");

  LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
  return true;
}

void CombinerHelper::applyCombineExtendingLoads(MachineInstr &MI,
                                                PreferredTuple &Preferred) {
  // Rewrite the load to the chosen extending load.
  Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();

  // Inserter to insert a truncate back to the original type at a given point
  // with some basic CSE to limit truncate duplication to one per BB.
  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
                           MachineBasicBlock::iterator InsertBefore,
                           MachineOperand &UseMO) {
    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
    if (PreviouslyEmitted) {
      Observer.changingInstr(*UseMO.getParent());
      UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
      Observer.changedInstr(*UseMO.getParent());
      return;
    }

    Builder.setInsertPt(*InsertIntoBB, InsertBefore);
    Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
    MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
    EmittedInsns[InsertIntoBB] = NewMI;
    replaceRegOpWith(MRI, UseMO, NewDstReg);
  };

  Observer.changingInstr(MI);
  unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
  MI.setDesc(Builder.getTII().get(LoadOpc));

  // Rewrite all the uses to fix up the types.
  auto &LoadValue = MI.getOperand(0);
  SmallVector<MachineOperand *, 4> Uses;
  for (auto &UseMO : MRI.use_operands(LoadValue.getReg()))
    Uses.push_back(&UseMO);

  for (auto *UseMO : Uses) {
    MachineInstr *UseMI = UseMO->getParent();

    // If the extend is compatible with the preferred extend then we should fix
    // up the type and extend so that it uses the preferred use.
    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
      Register UseDstReg = UseMI->getOperand(0).getReg();
      MachineOperand &UseSrcMO = UseMI->getOperand(1);
      const LLT UseDstTy = MRI.getType(UseDstReg);
      if (UseDstReg != ChosenDstReg) {
        if (Preferred.Ty == UseDstTy) {
          // If the use has the same type as the preferred use, then merge
          // the vregs and erase the extend. For example:
          //    %1:_(s8) = G_LOAD ...
          //    %2:_(s32) = G_SEXT %1(s8)
          //    %3:_(s32) = G_ANYEXT %1(s8)
          //    ... = ... %3(s32)
          // rewrites to:
          //    %2:_(s32) = G_SEXTLOAD ...
          //    ... = ... %2(s32)
          replaceRegWith(MRI, UseDstReg, ChosenDstReg);
          Observer.erasingInstr(*UseMO->getParent());
          UseMO->getParent()->eraseFromParent();
        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
          // If the preferred size is smaller, then keep the extend but extend
          // from the result of the extending load. For example:
          //    %1:_(s8) = G_LOAD ...
          //    %2:_(s32) = G_SEXT %1(s8)
          //    %3:_(s64) = G_ANYEXT %1(s8)
          //    ... = ... %3(s64)
          // rewrites to:
          //    %2:_(s32) = G_SEXTLOAD ...
          //    %3:_(s64) = G_ANYEXT %2:_(s32)
          //    ... = ... %3(s64)
          replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
        } else {
          // If the preferred size is larger, then insert a truncate. For
          // example:
          //    %1:_(s8) = G_LOAD ...
          //    %2:_(s64) = G_SEXT %1(s8)
          //    %3:_(s32) = G_ZEXT %1(s8)
          //    ... = ... %3(s32)
          // rewrites to:
          //    %2:_(s64) = G_SEXTLOAD ...
          //    %4:_(s8) = G_TRUNC %2:_(s64)
          //    %3:_(s32) = G_ZEXT %4:_(s8)
          //    ... = ... %3(s32)
          InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
                                                 InsertTruncAt);
        }
        continue;
      }
      // The use is (one of) the uses of the preferred use we chose earlier.
      // We're going to update the load to def this value later so just erase
      // the old extend.
      Observer.erasingInstr(*UseMO->getParent());
      UseMO->getParent()->eraseFromParent();
      continue;
    }

    // The use isn't an extend. Truncate back to the type we originally loaded.
    // This is free on many targets.
    InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt);
  }

  MI.getOperand(0).setReg(ChosenDstReg);
  Observer.changedInstr(MI);
}

bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_AND);

  // If we have the following code:
  //  %mask = G_CONSTANT 255
  //  %ld   = G_LOAD %ptr, (load s16)
  //  %and  = G_AND %ld, %mask
  //
  // Try to fold it into
  //   %ld = G_ZEXTLOAD %ptr, (load s8)

  Register Dst = MI.getOperand(0).getReg();
  if (MRI.getType(Dst).isVector())
    return false;

  auto MaybeMask =
      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
  if (!MaybeMask)
    return false;

  APInt MaskVal = MaybeMask->Value;

  if (!MaskVal.isMask())
    return false;

  Register SrcReg = MI.getOperand(1).getReg();
  // Don't use getOpcodeDef() here since intermediate instructions may have
  // multiple users.
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
  if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
    return false;

  Register LoadReg = LoadMI->getDstReg();
  LLT RegTy = MRI.getType(LoadReg);
  Register PtrReg = LoadMI->getPointerReg();
  unsigned RegSize = RegTy.getSizeInBits();
  uint64_t LoadSizeBits = LoadMI->getMemSizeInBits();
  unsigned MaskSizeBits = MaskVal.countr_one();

  // The mask may not be larger than the in-memory type, as it might cover sign
  // extended bits
  if (MaskSizeBits > LoadSizeBits)
    return false;

  // If the mask covers the whole destination register, there's nothing to
  // extend
  if (MaskSizeBits >= RegSize)
    return false;

  // Most targets cannot deal with loads of size < 8 and need to re-legalize to
  // at least byte loads. Avoid creating such loads here
  if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadMI->getMMO();
  LegalityQuery::MemDesc MemDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we can
  // still adjust the opcode to indicate the high bit behavior.
  if (LoadMI->isSimple())
    MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
  else if (LoadSizeBits > MaskSizeBits || LoadSizeBits == RegSize)
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    B.setInstrAndDebugLoc(*LoadMI);
    auto &MF = B.getMF();
    auto PtrInfo = MMO.getPointerInfo();
    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
    B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
    LoadMI->eraseFromParent();
  };
  return true;
}

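// Return true if \p DefMI and \p UseMI are in the same basic block and
// \p DefMI either is \p UseMI or appears before it.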
bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
                                   const MachineInstr &UseMI) {
  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
         "shouldn't consider debug uses");
  assert(DefMI.getParent() == UseMI.getParent());
  if (&DefMI == &UseMI)
    return true;
  const MachineBasicBlock &MBB = *DefMI.getParent();
  auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) {
    return &MI == &DefMI || &MI == &UseMI;
  });
  if (DefOrUse == MBB.end())
    llvm_unreachable("Block must contain both DefMI and UseMI!");
  return &*DefOrUse == &DefMI;
}

bool CombinerHelper::dominates(const MachineInstr &DefMI,
                               const MachineInstr &UseMI) {
  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
         "shouldn't consider debug uses");
  if (MDT)
    return MDT->dominates(&DefMI, &UseMI);
  else if (DefMI.getParent() != UseMI.getParent())
    return false;

  return isPredecessor(DefMI, UseMI);
}

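// Fold away a G_SEXT_INREG whose (possibly truncated) source is already a
// G_SEXTLOAD from the same bit width, e.g.:
//   %ld:_(s32) = G_SEXTLOAD %ptr (load (s8))
//   %ext:_(s32) = G_SEXT_INREG %ld, 8
// The sext_inreg is redundant and can be replaced by a plain copy of %ld.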
bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register SrcReg = MI.getOperand(1).getReg();
  Register LoadUser = SrcReg;

  if (MRI.getType(SrcReg).isVector())
    return false;

  Register TruncSrc;
  if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))))
    LoadUser = TruncSrc;

  uint64_t SizeInBits = MI.getOperand(2).getImm();
  // If the source is a G_SEXTLOAD from the same bit width, then we don't
  // need any extend at all, just a truncate.
  if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) {
    // If truncating more than the original extended value, abort.
    auto LoadSizeBits = LoadMI->getMemSizeInBits();
    if (TruncSrc && MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits)
      return false;
    if (LoadSizeBits == SizeInBits)
      return true;
  }
  return false;
}

void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Builder.setInstrAndDebugLoc(MI);
  Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
  MI.eraseFromParent();
}

bool CombinerHelper::matchSextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);

  Register DstReg = MI.getOperand(0).getReg();
  LLT RegTy = MRI.getType(DstReg);

  // Only supports scalars for now.
  if (RegTy.isVector())
    return false;

  Register SrcReg = MI.getOperand(1).getReg();
  auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
  if (!LoadDef || !MRI.hasOneNonDBGUse(DstReg))
    return false;

  uint64_t MemBits = LoadDef->getMemSizeInBits();

  // If the sign extend extends from a narrower width than the load's width,
  // then we can narrow the load width when we combine to a G_SEXTLOAD.
  // Avoid widening the load at all.
  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);

  // Don't generate G_SEXTLOADs with a < 1 byte width.
  if (NewSizeBits < 8)
    return false;
  // Don't bother creating a non-power-2 sextload, it will likely be broken up
  // anyway for most targets.
  if (!isPowerOf2_32(NewSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadDef->getMMO();
  LegalityQuery::MemDesc MMDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we can
  // still adjust the opcode to indicate the high bit behavior.
  if (LoadDef->isSimple())
    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
                                 {MRI.getType(LoadDef->getDstReg()),
                                  MRI.getType(LoadDef->getPointerReg())},
                                 {MMDesc}}))
    return false;

  MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
  return true;
}

void CombinerHelper::applySextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register LoadReg;
  unsigned ScalarSizeBits;
  std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
  GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));

  // If we have the following:
  //   %ld = G_LOAD %ptr, (load 2)
  //   %ext = G_SEXT_INREG %ld, 8
  // ==>
  //   %ld = G_SEXTLOAD %ptr (load 1)

  auto &MMO = LoadDef->getMMO();
  Builder.setInstrAndDebugLoc(*LoadDef);
  auto &MF = Builder.getMF();
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
  Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
                         LoadDef->getPointerReg(), *NewMMO);
  MI.eraseFromParent();
}

static Type *getTypeForLLT(LLT Ty, LLVMContext &C) {
  if (Ty.isVector())
    return FixedVectorType::get(IntegerType::get(C, Ty.getScalarSizeInBits()),
                                Ty.getNumElements());
  return IntegerType::get(C, Ty.getSizeInBits());
}

/// Return true if 'MI' is a load or a store that may have its address
/// operand folded into the load / store addressing mode.
static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
                                    MachineRegisterInfo &MRI) {
  TargetLowering::AddrMode AM;
  auto *MF = MI->getMF();
  auto *Addr = getOpcodeDef<GPtrAdd>(MI->getPointerReg(), MRI);
  if (!Addr)
    return false;

  AM.HasBaseReg = true;
  if (auto CstOff = getIConstantVRegVal(Addr->getOffsetReg(), MRI))
    AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
  else
    AM.Scale = 1; // [reg +/- reg]

  return TLI.isLegalAddressingMode(
      MF->getDataLayout(), AM,
      getTypeForLLT(MI->getMMO().getMemoryType(),
                    MF->getFunction().getContext()),
      MI->getMMO().getAddrSpace());
}

static unsigned getIndexedOpc(unsigned LdStOpc) {
  switch (LdStOpc) {
  case TargetOpcode::G_LOAD:
    return TargetOpcode::G_INDEXED_LOAD;
  case TargetOpcode::G_STORE:
    return TargetOpcode::G_INDEXED_STORE;
  case TargetOpcode::G_ZEXTLOAD:
    return TargetOpcode::G_INDEXED_ZEXTLOAD;
  case TargetOpcode::G_SEXTLOAD:
    return TargetOpcode::G_INDEXED_SEXTLOAD;
  default:
    llvm_unreachable("Unexpected opcode");
  }
}

bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
  // Check for legality.
  LLT PtrTy = MRI.getType(LdSt.getPointerReg());
  LLT Ty = MRI.getType(LdSt.getReg(0));
  LLT MemTy = LdSt.getMMO().getMemoryType();
  SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
      {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
  unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
  SmallVector<LLT> OpTys;
  if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
    OpTys = {PtrTy, Ty, Ty};
  else
    OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD

  LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
  return isLegal(Q);
}

static cl::opt<unsigned> PostIndexUseThreshold(
    "post-index-use-threshold", cl::Hidden, cl::init(32),
    cl::desc("Number of uses of a base pointer to check before it is no longer "
             "considered for post-indexing."));

bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                            Register &Base, Register &Offset,
                                            bool &RematOffset) {
  // We're looking for the following pattern, for either load or store:
  //   %baseptr:_(p0) = ...
  //   G_STORE %val(s64), %baseptr(p0)
  //   %offset:_(s64) = G_CONSTANT i64 -256
  //   %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
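  //
  // ... which can then be folded into a single post-indexed access, roughly:
  //   %new_addr:_(p0) = G_INDEXED_STORE %val(s64), %baseptr(p0), %offset(s64), 0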
  const auto &TLI = getTargetLowering();

  Register Ptr = LdSt.getPointerReg();
  // If the store is the only use, don't bother.
  if (MRI.hasOneNonDBGUse(Ptr))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  if (getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Ptr, MRI))
    return false;

  MachineInstr *StoredValDef = getDefIgnoringCopies(LdSt.getReg(0), MRI);
  auto *PtrDef = MRI.getVRegDef(Ptr);

  unsigned NumUsesChecked = 0;
  for (auto &Use : MRI.use_nodbg_instructions(Ptr)) {
    if (++NumUsesChecked > PostIndexUseThreshold)
      return false; // Try to avoid exploding compile time.

    auto *PtrAdd = dyn_cast<GPtrAdd>(&Use);
    // The use itself might be dead. This can happen during combines if DCE
    // hasn't had a chance to run yet. Don't allow it to form an indexed op.
    if (!PtrAdd || MRI.use_nodbg_empty(PtrAdd->getReg(0)))
      continue;

    // Check that the user of this isn't the store, otherwise we'd be generating
    // an indexed store defining its own use.
    if (StoredValDef == &Use)
      continue;

    Offset = PtrAdd->getOffsetReg();
    if (!ForceLegalIndexing &&
        !TLI.isIndexingLegal(LdSt, PtrAdd->getBaseReg(), Offset,
                             /*IsPre*/ false, MRI))
      continue;

    // Make sure the offset calculation is before the potentially indexed op.
    MachineInstr *OffsetDef = MRI.getVRegDef(Offset);
    RematOffset = false;
    if (!dominates(*OffsetDef, LdSt)) {
      // If the offset however is just a G_CONSTANT, we can always just
      // rematerialize it where we need it.
      if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
        continue;
      RematOffset = true;
    }

    for (auto &BasePtrUse : MRI.use_nodbg_instructions(PtrAdd->getBaseReg())) {
      if (&BasePtrUse == PtrDef)
        continue;

      // If the user is a later load/store that can be post-indexed, then don't
      // combine this one.
      auto *BasePtrLdSt = dyn_cast<GLoadStore>(&BasePtrUse);
      if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
          dominates(LdSt, *BasePtrLdSt) &&
          isIndexedLoadStoreLegal(*BasePtrLdSt))
        return false;

      // Now we're looking for the key G_PTR_ADD instruction, which contains
      // the offset add that we want to fold.
      if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(&BasePtrUse)) {
        Register PtrAddDefReg = BasePtrUseDef->getReg(0);
        for (auto &BaseUseUse : MRI.use_nodbg_instructions(PtrAddDefReg)) {
          // If the use is in a different block, then we may produce worse code
          // due to the extra register pressure.
          if (BaseUseUse.getParent() != LdSt.getParent())
            return false;

          if (auto *UseUseLdSt = dyn_cast<GLoadStore>(&BaseUseUse))
            if (canFoldInAddressingMode(UseUseLdSt, TLI, MRI))
              return false;
        }
        if (!dominates(LdSt, BasePtrUse))
          return false; // All uses must be dominated by the load/store.
      }
    }

    Addr = PtrAdd->getReg(0);
    Base = PtrAdd->getBaseReg();
    return true;
  }

  return false;
}

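// Look for a pre-indexing opportunity: the access's pointer is itself
//   %addr:_(p0) = G_PTR_ADD %base, %offset
// and every other use of %addr is in the same block and dominated by the
// load/store, so the pointer add can be folded into the access itself.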
bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                           Register &Base, Register &Offset) {
  auto &MF = *LdSt.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();

  Addr = LdSt.getPointerReg();
  if (!mi_match(Addr, MRI, m_GPtrAdd(m_Reg(Base), m_Reg(Offset))) ||
      MRI.hasOneNonDBGUse(Addr))
    return false;

  if (!ForceLegalIndexing &&
      !TLI.isIndexingLegal(LdSt, Base, Offset, /*IsPre*/ true, MRI))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
  if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
    return false;

  if (auto *St = dyn_cast<GStore>(&LdSt)) {
    // Would require a copy.
    if (Base == St->getValueReg())
      return false;

    // We're expecting one use of Addr in MI, but it could also be the
    // value stored, which isn't actually dominated by the instruction.
    if (St->getValueReg() == Addr)
      return false;
  }

  // Avoid increasing cross-block register pressure.
  for (auto &AddrUse : MRI.use_nodbg_instructions(Addr))
    if (AddrUse.getParent() != LdSt.getParent())
      return false;

  // FIXME: check whether all uses of the base pointer are constant PtrAdds.
  // That might allow us to end base's liveness here by adjusting the constant.
  bool RealUse = false;
  for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) {
    if (!dominates(LdSt, AddrUse))
      return false; // All uses must be dominated by the load/store.

    // If Ptr may be folded into the addressing mode of another use, then it's
    // not profitable to do this transformation.
    if (auto *UseLdSt = dyn_cast<GLoadStore>(&AddrUse)) {
      if (!canFoldInAddressingMode(UseLdSt, TLI, MRI))
        RealUse = true;
    } else {
      RealUse = true;
    }
  }
  return RealUse;
}

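// Combine a G_EXTRACT_VECTOR_ELT fed by a single-use, simple G_LOAD of the
// whole vector into a narrow scalar load of just the selected element:
//   %vec:_(<4 x s32>) = G_LOAD %ptr
//   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, %idx
// becomes a single s32 G_LOAD from %ptr advanced by the element offset.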
bool CombinerHelper::matchCombineExtractedVectorLoad(MachineInstr &MI,
                                                     BuildFnTy &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);

  // Check if there is a load that defines the vector being extracted from.
  auto *LoadMI = getOpcodeDef<GLoad>(MI.getOperand(1).getReg(), MRI);
  if (!LoadMI)
    return false;

  Register Vector = MI.getOperand(1).getReg();
  LLT VecEltTy = MRI.getType(Vector).getElementType();

  assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy);

  // Checking whether we should reduce the load width.
  if (!MRI.hasOneNonDBGUse(Vector))
    return false;

  // Check if the defining load is simple.
  if (!LoadMI->isSimple())
    return false;

  // If the vector element type is not a multiple of a byte then we are unable
  // to correctly compute an address to load only the extracted element as a
  // scalar.
  if (!VecEltTy.isByteSized())
    return false;

  // Check if the new load that we are going to create is legal
  // if we are in the post-legalization phase.
  MachineMemOperand MMO = LoadMI->getMMO();
  Align Alignment = MMO.getAlign();
  MachinePointerInfo PtrInfo;
  uint64_t Offset;

  // Find the appropriate PtrInfo if the index (and therefore the offset) is a
  // known constant. This is required to create the memory operand for the
  // narrowed load. This machine memory operand object helps us infer legality
  // before we proceed to combine the instruction.
  if (auto CVal = getIConstantVRegVal(MI.getOperand(2).getReg(), MRI)) {
    int Elt = CVal->getZExtValue();
    // FIXME: should be (ABI size)*Elt.
    Offset = VecEltTy.getSizeInBits() * Elt / 8;
    PtrInfo = MMO.getPointerInfo().getWithOffset(Offset);
  } else {
    // Discard the pointer info except the address space because the memory
    // operand can't represent this new access since the offset is variable.
    Offset = VecEltTy.getSizeInBits() / 8;
    PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace());
  }

  Alignment = commonAlignment(Alignment, Offset);

  Register VecPtr = LoadMI->getPointerReg();
  LLT PtrTy = MRI.getType(VecPtr);

  MachineFunction &MF = *MI.getMF();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, VecEltTy);

  LegalityQuery::MemDesc MMDesc(*NewMMO);

  LegalityQuery Q = {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}};

  if (!isLegalOrBeforeLegalizer(Q))
    return false;

  // Load must be allowed and fast on the target.
  LLVMContext &C = MF.getFunction().getContext();
  auto &DL = MF.getDataLayout();
  unsigned Fast = 0;
  if (!getTargetLowering().allowsMemoryAccess(C, DL, VecEltTy, *NewMMO,
                                              &Fast) ||
      !Fast)
    return false;

  Register Result = MI.getOperand(0).getReg();
  Register Index = MI.getOperand(2).getReg();

  MatchInfo = [=](MachineIRBuilder &B) {
    GISelObserverWrapper DummyObserver;
    LegalizerHelper Helper(B.getMF(), DummyObserver, B);
    // Get a pointer to the vector element.
    Register finalPtr = Helper.getVectorElementPointer(
        LoadMI->getPointerReg(), MRI.getType(LoadMI->getOperand(0).getReg()),
        Index);
    // New G_LOAD instruction.
    B.buildLoad(Result, finalPtr, PtrInfo, Alignment);
    // Remove the original G_LOAD instruction.
    LoadMI->eraseFromParent();
  };

  return true;
}

bool CombinerHelper::matchCombineIndexedLoadStore(
    MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) {
  auto &LdSt = cast<GLoadStore>(MI);

  if (LdSt.isAtomic())
    return false;

  MatchInfo.IsPre = findPreIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base,
                                          MatchInfo.Offset);
  if (!MatchInfo.IsPre &&
      !findPostIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base,
                              MatchInfo.Offset, MatchInfo.RematOffset))
    return false;

  return true;
}

void CombinerHelper::applyCombineIndexedLoadStore(
    MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) {
  MachineInstr &AddrDef = *MRI.getUniqueVRegDef(MatchInfo.Addr);
  Builder.setInstrAndDebugLoc(MI);
  unsigned Opcode = MI.getOpcode();
  bool IsStore = Opcode == TargetOpcode::G_STORE;
  unsigned NewOpcode = getIndexedOpc(Opcode);

  // If the offset constant didn't happen to dominate the load/store, we can
  // just clone it as needed.
  if (MatchInfo.RematOffset) {
    auto *OldCst = MRI.getVRegDef(MatchInfo.Offset);
    auto NewCst = Builder.buildConstant(MRI.getType(MatchInfo.Offset),
                                        *OldCst->getOperand(1).getCImm());
    MatchInfo.Offset = NewCst.getReg(0);
  }

  auto MIB = Builder.buildInstr(NewOpcode);
  if (IsStore) {
    MIB.addDef(MatchInfo.Addr);
    MIB.addUse(MI.getOperand(0).getReg());
  } else {
    MIB.addDef(MI.getOperand(0).getReg());
    MIB.addDef(MatchInfo.Addr);
  }

  MIB.addUse(MatchInfo.Base);
  MIB.addUse(MatchInfo.Offset);
  MIB.addImm(MatchInfo.IsPre);
  MIB->cloneMemRefs(*MI.getMF(), MI);
  MI.eraseFromParent();
  AddrDef.eraseFromParent();

  LLVM_DEBUG(dbgs() << "    Combined to indexed operation");
}

bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
                                        MachineInstr *&OtherMI) {
  unsigned Opcode = MI.getOpcode();
  bool IsDiv, IsSigned;

  switch (Opcode) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_UDIV: {
    IsDiv = true;
    IsSigned = Opcode == TargetOpcode::G_SDIV;
    break;
  }
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_UREM: {
    IsDiv = false;
    IsSigned = Opcode == TargetOpcode::G_SREM;
    break;
  }
  }

  Register Src1 = MI.getOperand(1).getReg();
  unsigned DivOpcode, RemOpcode, DivremOpcode;
  if (IsSigned) {
    DivOpcode = TargetOpcode::G_SDIV;
    RemOpcode = TargetOpcode::G_SREM;
    DivremOpcode = TargetOpcode::G_SDIVREM;
  } else {
    DivOpcode = TargetOpcode::G_UDIV;
    RemOpcode = TargetOpcode::G_UREM;
    DivremOpcode = TargetOpcode::G_UDIVREM;
  }

  if (!isLegalOrBeforeLegalizer({DivremOpcode, {MRI.getType(Src1)}}))
    return false;

  // Combine:
  //   %div:_ = G_[SU]DIV %src1:_, %src2:_
  //   %rem:_ = G_[SU]REM %src1:_, %src2:_
  // into:
  //   %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_

  // Combine:
  //   %rem:_ = G_[SU]REM %src1:_, %src2:_
  //   %div:_ = G_[SU]DIV %src1:_, %src2:_
  // into:
  //   %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_

  for (auto &UseMI : MRI.use_nodbg_instructions(Src1)) {
    if (MI.getParent() == UseMI.getParent() &&
        ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
         (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
        matchEqualDefs(MI.getOperand(2), UseMI.getOperand(2)) &&
        matchEqualDefs(MI.getOperand(1), UseMI.getOperand(1))) {
      OtherMI = &UseMI;
      return true;
    }
  }

  return false;
}

void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
                                        MachineInstr *&OtherMI) {
  unsigned Opcode = MI.getOpcode();
  assert(OtherMI && "OtherMI shouldn't be empty.");

  Register DestDivReg, DestRemReg;
  if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
    DestDivReg = MI.getOperand(0).getReg();
    DestRemReg = OtherMI->getOperand(0).getReg();
  } else {
    DestDivReg = OtherMI->getOperand(0).getReg();
    DestRemReg = MI.getOperand(0).getReg();
  }

  bool IsSigned =
      Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;

  // Check which instruction is first in the block so we don't break def-use
  // deps by "moving" the instruction incorrectly. Also keep track of which
  // instruction is first so we pick its operands, avoiding use-before-def
  // bugs.
  MachineInstr *FirstInst;
  if (dominates(MI, *OtherMI)) {
    Builder.setInstrAndDebugLoc(MI);
    FirstInst = &MI;
  } else {
    Builder.setInstrAndDebugLoc(*OtherMI);
    FirstInst = OtherMI;
  }

  Builder.buildInstr(IsSigned ? TargetOpcode::G_SDIVREM
                              : TargetOpcode::G_UDIVREM,
                     {DestDivReg, DestRemReg},
                     {FirstInst->getOperand(1), FirstInst->getOperand(2)});
  MI.eraseFromParent();
  OtherMI->eraseFromParent();
}

bool CombinerHelper::matchOptBrCondByInvertingCond(MachineInstr &MI,
                                                   MachineInstr *&BrCond) {
  assert(MI.getOpcode() == TargetOpcode::G_BR);

  // Try to match the following:
  // bb1:
  //   G_BRCOND %c1, %bb2
  //   G_BR %bb3
  // bb2:
  //   ...
  // bb3:

  // The above pattern does not have a fall through to the successor bb2, always
  // resulting in a branch no matter which path is taken. Here we try to find
  // and replace that pattern with a conditional branch to bb3 and a fallthrough
  // to bb2. This is generally better for branch predictors.
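  //
  // After the combine, bb1 roughly becomes:
  //   %true:_(s1) = G_CONSTANT i1 1     ; the target's "icmp true" value
  //   %c1_inv:_(s1) = G_XOR %c1, %true
  //   G_BRCOND %c1_inv, %bb3
  // and execution falls through to bb2.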
1434
1435 MachineBasicBlock *MBB = MI.getParent();
1436 MachineBasicBlock::iterator BrIt(MI);
1437 if (BrIt == MBB->begin())
1438 return false;
1439 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1440
1441 BrCond = &*std::prev(x: BrIt);
1442 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1443 return false;
1444
1445 // Check that the next block is the conditional branch target. Also make sure
1446 // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1447 MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB();
1448 return BrCondTarget != MI.getOperand(i: 0).getMBB() &&
1449 MBB->isLayoutSuccessor(MBB: BrCondTarget);
1450}
1451
1452void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI,
1453 MachineInstr *&BrCond) {
1454 MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB();
1455 Builder.setInstrAndDebugLoc(*BrCond);
1456 LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg());
1457 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1458 // this to i1 only since we might not know for sure what kind of
1459 // compare generated the condition value.
1460 auto True = Builder.buildConstant(
1461 Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false));
1462 auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True);
1463
1464 auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB();
1465 Observer.changingInstr(MI);
1466 MI.getOperand(i: 0).setMBB(FallthroughBB);
1467 Observer.changedInstr(MI);
1468
1469 // Change the conditional branch to use the inverted condition and
1470 // new target block.
1471 Observer.changingInstr(MI&: *BrCond);
1472 BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0));
1473 BrCond->getOperand(i: 1).setMBB(BrTarget);
1474 Observer.changedInstr(MI&: *BrCond);
1475}
1476
1477
1478bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) {
1479 MachineIRBuilder HelperBuilder(MI);
1480 GISelObserverWrapper DummyObserver;
1481 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1482 return Helper.lowerMemcpyInline(MI) ==
1483 LegalizerHelper::LegalizeResult::Legalized;
1484}
1485
1486bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen) {
1487 MachineIRBuilder HelperBuilder(MI);
1488 GISelObserverWrapper DummyObserver;
1489 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1490 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1491 LegalizerHelper::LegalizeResult::Legalized;
1492}
1493
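// Constant fold a floating-point unary operation on an immediate operand.
// E.g. (illustrative): G_FNEG of +2.0 folds to -2.0, and G_FSQRT of 4.0
// folds to 2.0.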
1494static APFloat constantFoldFpUnary(const MachineInstr &MI,
1495 const MachineRegisterInfo &MRI,
1496 const APFloat &Val) {
1497 APFloat Result(Val);
1498 switch (MI.getOpcode()) {
1499 default:
1500 llvm_unreachable("Unexpected opcode!");
1501 case TargetOpcode::G_FNEG: {
1502 Result.changeSign();
1503 return Result;
1504 }
1505 case TargetOpcode::G_FABS: {
1506 Result.clearSign();
1507 return Result;
1508 }
1509 case TargetOpcode::G_FPTRUNC: {
1510 bool Unused;
1511 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1512 Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven,
1513 losesInfo: &Unused);
1514 return Result;
1515 }
1516 case TargetOpcode::G_FSQRT: {
1517 bool Unused;
1518 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1519 losesInfo: &Unused);
1520 Result = APFloat(sqrt(x: Result.convertToDouble()));
1521 break;
1522 }
1523 case TargetOpcode::G_FLOG2: {
1524 bool Unused;
1525 Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
1526 losesInfo: &Unused);
1527 Result = APFloat(log2(x: Result.convertToDouble()));
1528 break;
1529 }
1530 }
1531 // Convert the result back to the original semantics (which match the
1532 // destination type); otherwise `buildFConstant` will assert on a size
1533 // mismatch. Only `G_FSQRT` and `G_FLOG2` reach here.
1534 bool Unused;
1535 Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused);
1536 return Result;
1537}
1538
1539void CombinerHelper::applyCombineConstantFoldFpUnary(MachineInstr &MI,
1540 const ConstantFP *Cst) {
1541 Builder.setInstrAndDebugLoc(MI);
1542 APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue());
1543 const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded);
1544 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst);
1545 MI.eraseFromParent();
1546}
1547
1548bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1549 PtrAddChain &MatchInfo) {
1550 // We're trying to match the following pattern:
1551 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1552 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1553 // -->
1554 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
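 //
 // For example (illustrative values): an inner offset of 16 followed by an
 // outer offset of 8 becomes a single G_PTR_ADD with offset 24, provided the
 // combined offset still forms a legal addressing mode (checked below).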
1555
1556 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1557 return false;
1558
1559 Register Add2 = MI.getOperand(i: 1).getReg();
1560 Register Imm1 = MI.getOperand(i: 2).getReg();
1561 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1562 if (!MaybeImmVal)
1563 return false;
1564
1565 MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2);
1566 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1567 return false;
1568
1569 Register Base = Add2Def->getOperand(i: 1).getReg();
1570 Register Imm2 = Add2Def->getOperand(i: 2).getReg();
1571 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1572 if (!MaybeImm2Val)
1573 return false;
1574
1575 // Check if the new combined immediate forms an illegal addressing mode.
1576 // Do not combine if it was legal before but would get illegal.
1577 // To do so, we need to find a load/store user of the pointer to get
1578 // the access type.
1579 Type *AccessTy = nullptr;
1580 auto &MF = *MI.getMF();
1581 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) {
1582 if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) {
1583 AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)),
1584 C&: MF.getFunction().getContext());
1585 break;
1586 }
1587 }
1588 TargetLoweringBase::AddrMode AMNew;
1589 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1590 AMNew.BaseOffs = CombinedImm.getSExtValue();
1591 if (AccessTy) {
1592 AMNew.HasBaseReg = true;
1593 TargetLoweringBase::AddrMode AMOld;
1594 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue();
1595 AMOld.HasBaseReg = true;
1596 unsigned AS = MRI.getType(Reg: Add2).getAddressSpace();
1597 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1598 if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) &&
1599 !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS))
1600 return false;
1601 }
1602
1603 // Pass the combined immediate to the apply function.
1604 MatchInfo.Imm = AMNew.BaseOffs;
1605 MatchInfo.Base = Base;
1606 MatchInfo.Bank = getRegBank(Reg: Imm2);
1607 return true;
1608}
1609
1610void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1611 PtrAddChain &MatchInfo) {
1612 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1613 MachineIRBuilder MIB(MI);
1614 LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1615 auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm);
1616 setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank);
1617 Observer.changingInstr(MI);
1618 MI.getOperand(i: 1).setReg(MatchInfo.Base);
1619 MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0));
1620 Observer.changedInstr(MI);
1621}
1622
1623bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1624 RegisterImmPair &MatchInfo) {
1625 // We're trying to match the following pattern with any of
1626 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1627 // %t1 = SHIFT %base, G_CONSTANT imm1
1628 // %root = SHIFT %t1, G_CONSTANT imm2
1629 // -->
1630 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
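 //
 // E.g. (illustrative, with G_SHL on s32):
 //   %t1:_(s32) = G_SHL %base, 3
 //   %root:_(s32) = G_SHL %t1, 4
 // becomes
 //   %root:_(s32) = G_SHL %base, 7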
1631
1632 unsigned Opcode = MI.getOpcode();
1633 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1634 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1635 Opcode == TargetOpcode::G_USHLSAT) &&
1636 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1637
1638 Register Shl2 = MI.getOperand(i: 1).getReg();
1639 Register Imm1 = MI.getOperand(i: 2).getReg();
1640 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI);
1641 if (!MaybeImmVal)
1642 return false;
1643
1644 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2);
1645 if (Shl2Def->getOpcode() != Opcode)
1646 return false;
1647
1648 Register Base = Shl2Def->getOperand(i: 1).getReg();
1649 Register Imm2 = Shl2Def->getOperand(i: 2).getReg();
1650 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI);
1651 if (!MaybeImm2Val)
1652 return false;
1653
1654 // Pass the combined immediate to the apply function.
1655 MatchInfo.Imm =
1656 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue();
1657 MatchInfo.Reg = Base;
1658
1659 // There is no simple replacement for a saturating unsigned left shift that
1660 // exceeds the scalar size.
1661 if (Opcode == TargetOpcode::G_USHLSAT &&
1662 MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits())
1663 return false;
1664
1665 return true;
1666}
1667
1668void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1669 RegisterImmPair &MatchInfo) {
1670 unsigned Opcode = MI.getOpcode();
1671 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1672 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1673 Opcode == TargetOpcode::G_USHLSAT) &&
1674 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1675
1676 Builder.setInstrAndDebugLoc(MI);
1677 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
1678 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1679 auto Imm = MatchInfo.Imm;
1680
1681 if (Imm >= ScalarSizeInBits) {
1682 // Any logical shift that exceeds scalar size will produce zero.
1683 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1684 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0);
1685 MI.eraseFromParent();
1686 return;
1687 }
1688 // Arithmetic shift and saturating signed left shift have no effect beyond
1689 // scalar size.
1690 Imm = ScalarSizeInBits - 1;
1691 }
1692
1693 LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1694 Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0);
1695 Observer.changingInstr(MI);
1696 MI.getOperand(i: 1).setReg(MatchInfo.Reg);
1697 MI.getOperand(i: 2).setReg(NewImm);
1698 Observer.changedInstr(MI);
1699}
1700
1701bool CombinerHelper::matchShiftOfShiftedLogic(MachineInstr &MI,
1702 ShiftOfShiftedLogic &MatchInfo) {
1703 // We're trying to match the following pattern with any of
1704 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1705 // with any of G_AND/G_OR/G_XOR logic instructions.
1706 // %t1 = SHIFT %X, G_CONSTANT C0
1707 // %t2 = LOGIC %t1, %Y
1708 // %root = SHIFT %t2, G_CONSTANT C1
1709 // -->
1710 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1711 // %t4 = SHIFT %Y, G_CONSTANT C1
1712 // %root = LOGIC %t3, %t4
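 //
 // E.g. (illustrative, with G_SHL and G_AND on s32):
 //   %t1:_(s32) = G_SHL %X, 2
 //   %t2:_(s32) = G_AND %t1, %Y
 //   %root:_(s32) = G_SHL %t2, 3
 // becomes
 //   %t3:_(s32) = G_SHL %X, 5
 //   %t4:_(s32) = G_SHL %Y, 3
 //   %root:_(s32) = G_AND %t3, %t4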
1713 unsigned ShiftOpcode = MI.getOpcode();
1714 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1715 ShiftOpcode == TargetOpcode::G_ASHR ||
1716 ShiftOpcode == TargetOpcode::G_LSHR ||
1717 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1718 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1719 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1720
1721 // Match a one-use bitwise logic op.
1722 Register LogicDest = MI.getOperand(i: 1).getReg();
1723 if (!MRI.hasOneNonDBGUse(RegNo: LogicDest))
1724 return false;
1725
1726 MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest);
1727 unsigned LogicOpcode = LogicMI->getOpcode();
1728 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1729 LogicOpcode != TargetOpcode::G_XOR)
1730 return false;
1731
1732 // Find a matching one-use shift by constant.
1733 const Register C1 = MI.getOperand(i: 2).getReg();
1734 auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI);
1735 if (!MaybeImmVal || MaybeImmVal->Value == 0)
1736 return false;
1737
1738 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1739
1740 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1741 // Shift should match the previous one and have only one use.
1742 if (MI->getOpcode() != ShiftOpcode ||
1743 !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg()))
1744 return false;
1745
1746 // Must be a constant.
1747 auto MaybeImmVal =
1748 getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI);
1749 if (!MaybeImmVal)
1750 return false;
1751
1752 ShiftVal = MaybeImmVal->Value.getSExtValue();
1753 return true;
1754 };
1755
1756 // Logic ops are commutative, so check each operand for a match.
1757 Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg();
1758 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1);
1759 Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg();
1760 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2);
1761 uint64_t C0Val;
1762
1763 if (matchFirstShift(LogicMIOp1, C0Val)) {
1764 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1765 MatchInfo.Shift2 = LogicMIOp1;
1766 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1767 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1768 MatchInfo.Shift2 = LogicMIOp2;
1769 } else
1770 return false;
1771
1772 MatchInfo.ValSum = C0Val + C1Val;
1773
1774 // The fold is not valid if the sum of the shift values is >= the bitwidth.
1775 if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits())
1776 return false;
1777
1778 MatchInfo.Logic = LogicMI;
1779 return true;
1780}
1781
1782void CombinerHelper::applyShiftOfShiftedLogic(MachineInstr &MI,
1783 ShiftOfShiftedLogic &MatchInfo) {
1784 unsigned Opcode = MI.getOpcode();
1785 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1786 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
1787 Opcode == TargetOpcode::G_SSHLSAT) &&
1788 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1789
1790 LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg());
1791 LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1792 Builder.setInstrAndDebugLoc(MI);
1793
1794 Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0);
1795
1796 Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg();
1797 Register Shift1 =
1798 Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0);
1799
1800 // If LogicNonShiftReg is the same as Shift1Base, and the shift1 constant is
1801 // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the old
1802 // shift1 when building shift2. In that case, erasing MatchInfo.Shift2 at the
1803 // end would actually remove the reused shift1 and cause a crash later on.
1804 // Erase it earlier to avoid the crash.
1805 MatchInfo.Shift2->eraseFromParent();
1806
1807 Register Shift2Const = MI.getOperand(i: 2).getReg();
1808 Register Shift2 = Builder
1809 .buildInstr(Opc: Opcode, DstOps: {DestType},
1810 SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const})
1811 .getReg(Idx: 0);
1812
1813 Register Dest = MI.getOperand(i: 0).getReg();
1814 Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2});
1815
1816 // This was one use so it's safe to remove it.
1817 MatchInfo.Logic->eraseFromParent();
1818
1819 MI.eraseFromParent();
1820}
1821
1822bool CombinerHelper::matchCommuteShift(MachineInstr &MI, BuildFnTy &MatchInfo) {
1823 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
1824 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
1825 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
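 // E.g. (illustrative): (shl (add x, 3), 2) -> (add (shl x, 2), 12), since
 // (x + 3) << 2 == (x << 2) + (3 << 2) in modular arithmetic.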
1826 auto &Shl = cast<GenericMachineInstr>(Val&: MI);
1827 Register DstReg = Shl.getReg(Idx: 0);
1828 Register SrcReg = Shl.getReg(Idx: 1);
1829 Register ShiftReg = Shl.getReg(Idx: 2);
1830 Register X, C1;
1831
1832 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize()))
1833 return false;
1834
1835 if (!mi_match(R: SrcReg, MRI,
1836 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)),
1837 preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1))))))
1838 return false;
1839
1840 APInt C1Val, C2Val;
1841 if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) ||
1842 !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val)))
1843 return false;
1844
1845 auto *SrcDef = MRI.getVRegDef(Reg: SrcReg);
1846 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
1847 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
1848 LLT SrcTy = MRI.getType(Reg: SrcReg);
1849 MatchInfo = [=](MachineIRBuilder &B) {
1850 auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg);
1851 auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg);
1852 B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2});
1853 };
1854 return true;
1855}
1856
1857bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
1858 unsigned &ShiftVal) {
1859 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
1860 auto MaybeImmVal =
1861 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
1862 if (!MaybeImmVal)
1863 return false;
1864
1865 ShiftVal = MaybeImmVal->Value.exactLogBase2();
1866 return (static_cast<int32_t>(ShiftVal) != -1);
1867}
1868
1869void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
1870 unsigned &ShiftVal) {
1871 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
1872 MachineIRBuilder MIB(MI);
1873 LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1874 auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal);
1875 Observer.changingInstr(MI);
1876 MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL));
1877 MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0));
1878 Observer.changedInstr(MI);
1879}
1880
1881// shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
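// E.g. (illustrative): if the top 4 bits of %x are known to be zero, so the
// narrow shift cannot overflow the source type,
//   %e:_(s32) = G_ZEXT %x:_(s16)
//   %d:_(s32) = G_SHL %e, 4
// becomes
//   %n:_(s16) = G_SHL %x, 4
//   %d:_(s32) = G_ZEXT %n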
1882bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
1883 RegisterImmPair &MatchData) {
1884 assert(MI.getOpcode() == TargetOpcode::G_SHL && KB);
1885 if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
1886 return false;
1887
1888 Register LHS = MI.getOperand(i: 1).getReg();
1889
1890 Register ExtSrc;
1891 if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) &&
1892 !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) &&
1893 !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc))))
1894 return false;
1895
1896 Register RHS = MI.getOperand(i: 2).getReg();
1897 MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS);
1898 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI);
1899 if (!MaybeShiftAmtVal)
1900 return false;
1901
1902 if (LI) {
1903 LLT SrcTy = MRI.getType(Reg: ExtSrc);
1904
1905 // We only really care about the legality of the shifted value. We can
1906 // pick any type for the constant shift amount, so ask the target what to
1907 // use. Otherwise we would have to guess and hope it is reported as legal.
1908 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy);
1909 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
1910 return false;
1911 }
1912
1913 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue();
1914 MatchData.Reg = ExtSrc;
1915 MatchData.Imm = ShiftAmt;
1916
1917 unsigned MinLeadingZeros = KB->getKnownZeroes(R: ExtSrc).countl_one();
1918 unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits();
1919 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize;
1920}
1921
1922void CombinerHelper::applyCombineShlOfExtend(MachineInstr &MI,
1923 const RegisterImmPair &MatchData) {
1924 Register ExtSrcReg = MatchData.Reg;
1925 int64_t ShiftAmtVal = MatchData.Imm;
1926
1927 LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg);
1928 Builder.setInstrAndDebugLoc(MI);
1929 auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal);
1930 auto NarrowShift =
1931 Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags());
1932 Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift);
1933 MI.eraseFromParent();
1934}
1935
1936bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
1937 Register &MatchInfo) {
1938 GMerge &Merge = cast<GMerge>(Val&: MI);
1939 SmallVector<Register, 16> MergedValues;
1940 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
1941 MergedValues.emplace_back(Args: Merge.getSourceReg(I));
1942
1943 auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI);
1944 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
1945 return false;
1946
1947 for (unsigned I = 0; I < MergedValues.size(); ++I)
1948 if (MergedValues[I] != Unmerge->getReg(Idx: I))
1949 return false;
1950
1951 MatchInfo = Unmerge->getSourceReg();
1952 return true;
1953}
1954
1955static Register peekThroughBitcast(Register Reg,
1956 const MachineRegisterInfo &MRI) {
1957 while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg))))
1958 ;
1959
1960 return Reg;
1961}
1962
1963bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
1964 MachineInstr &MI, SmallVectorImpl<Register> &Operands) {
1965 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1966 "Expected an unmerge");
1967 auto &Unmerge = cast<GUnmerge>(Val&: MI);
1968 Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI);
1969
1970 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI);
1971 if (!SrcInstr)
1972 return false;
1973
1974 // Check the source type of the merge.
1975 LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0));
1976 LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0));
1977 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
1978 if (SrcMergeTy != Dst0Ty && !SameSize)
1979 return false;
1980 // They are the same now (modulo a bitcast).
1981 // We can collect all the src registers.
1982 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
1983 Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx));
1984 return true;
1985}
1986
1987void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
1988 MachineInstr &MI, SmallVectorImpl<Register> &Operands) {
1989 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1990 "Expected an unmerge");
1991 assert((MI.getNumOperands() - 1 == Operands.size()) &&
1992 "Not enough operands to replace all defs");
1993 unsigned NumElems = MI.getNumOperands() - 1;
1994
1995 LLT SrcTy = MRI.getType(Reg: Operands[0]);
1996 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
1997 bool CanReuseInputDirectly = DstTy == SrcTy;
1998 Builder.setInstrAndDebugLoc(MI);
1999 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2000 Register DstReg = MI.getOperand(i: Idx).getReg();
2001 Register SrcReg = Operands[Idx];
2002
2003 // This combine may run after RegBankSelect, so we need to be aware of
2004 // register banks.
2005 const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg);
2006 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) {
2007 SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0);
2008 MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB);
2009 }
2010
2011 if (CanReuseInputDirectly)
2012 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
2013 else
2014 Builder.buildCast(Dst: DstReg, Src: SrcReg);
2015 }
2016 MI.eraseFromParent();
2017}
2018
2019bool CombinerHelper::matchCombineUnmergeConstant(MachineInstr &MI,
2020 SmallVectorImpl<APInt> &Csts) {
2021 unsigned SrcIdx = MI.getNumOperands() - 1;
2022 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2023 MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg);
2024 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
2025 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
2026 return false;
2027 // Break down the big constant into smaller ones.
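 // E.g. (illustrative): unmerging the s32 constant 0xAABBCCDD into four s8
 // pieces yields 0xDD, 0xCC, 0xBB, 0xAA, least-significant piece first.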
2028 const MachineOperand &CstVal = SrcInstr->getOperand(i: 1);
2029 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
2030 ? CstVal.getCImm()->getValue()
2031 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
2032
2033 LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2034 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
2035 // Unmerge a constant.
2036 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
2037 Csts.emplace_back(Args: Val.trunc(width: ShiftAmt));
2038 Val = Val.lshr(shiftAmt: ShiftAmt);
2039 }
2040
2041 return true;
2042}
2043
2044void CombinerHelper::applyCombineUnmergeConstant(MachineInstr &MI,
2045 SmallVectorImpl<APInt> &Csts) {
2046 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2047 "Expected an unmerge");
2048 assert((MI.getNumOperands() - 1 == Csts.size()) &&
2049 "Not enough operands to replace all defs");
2050 unsigned NumElems = MI.getNumOperands() - 1;
2051 Builder.setInstrAndDebugLoc(MI);
2052 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2053 Register DstReg = MI.getOperand(i: Idx).getReg();
2054 Builder.buildConstant(Res: DstReg, Val: Csts[Idx]);
2055 }
2056
2057 MI.eraseFromParent();
2058}
2059
2060bool CombinerHelper::matchCombineUnmergeUndef(
2061 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
2062 unsigned SrcIdx = MI.getNumOperands() - 1;
2063 Register SrcReg = MI.getOperand(i: SrcIdx).getReg();
2064 MatchInfo = [&MI](MachineIRBuilder &B) {
2065 unsigned NumElems = MI.getNumOperands() - 1;
2066 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
2067 Register DstReg = MI.getOperand(i: Idx).getReg();
2068 B.buildUndef(Res: DstReg);
2069 }
2070 };
2071 return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg));
2072}
2073
2074bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) {
2075 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2076 "Expected an unmerge");
2077 // Check that all the lanes are dead except the first one.
2078 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2079 if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg()))
2080 return false;
2081 }
2082 return true;
2083}
2084
2085void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) {
2086 Builder.setInstrAndDebugLoc(MI);
2087 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2088 // Truncating a vector is going to truncate every single lane,
2089 // whereas we want the full low bits of the source.
2090 // Do the operation on a scalar instead.
2091 LLT SrcTy = MRI.getType(Reg: SrcReg);
2092 if (SrcTy.isVector())
2093 SrcReg =
2094 Builder.buildCast(Dst: LLT::scalar(SizeInBits: SrcTy.getSizeInBits()), Src: SrcReg).getReg(Idx: 0);
2095
2096 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2097 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2098 if (Dst0Ty.isVector()) {
2099 auto MIB = Builder.buildTrunc(Res: LLT::scalar(SizeInBits: Dst0Ty.getSizeInBits()), Op: SrcReg);
2100 Builder.buildCast(Dst: Dst0Reg, Src: MIB);
2101 } else
2102 Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg);
2103 MI.eraseFromParent();
2104}
2105
2106bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) {
2107 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2108 "Expected an unmerge");
2109 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2110 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2111 // G_ZEXT on a vector applies to each lane, so it will
2112 // affect all destinations. Therefore we won't be able
2113 // to simplify the unmerge to just the first definition.
2114 if (Dst0Ty.isVector())
2115 return false;
2116 Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg();
2117 LLT SrcTy = MRI.getType(Reg: SrcReg);
2118 if (SrcTy.isVector())
2119 return false;
2120
2121 Register ZExtSrcReg;
2122 if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg))))
2123 return false;
2124
2125 // Finally we can replace the first definition with
2126 // a zext of the source if the definition is big enough to hold
2127 // all of ZExtSrc's bits.
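 // E.g. (illustrative):
 //   %z:_(s32) = G_ZEXT %x:_(s8)
 //   %a:_(s8), %b:_(s8), %c:_(s8), %d:_(s8) = G_UNMERGE_VALUES %z
 // can be simplified so that %a gets the value of %x and %b, %c, %d become 0.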
2128 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2129 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
2130}
2131
2132void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) {
2133 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
2134 "Expected an unmerge");
2135
2136 Register Dst0Reg = MI.getOperand(i: 0).getReg();
2137
2138 MachineInstr *ZExtInstr =
2139 MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg());
2140 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
2141 "Expecting a G_ZEXT");
2142
2143 Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg();
2144 LLT Dst0Ty = MRI.getType(Reg: Dst0Reg);
2145 LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg);
2146
2147 Builder.setInstrAndDebugLoc(MI);
2148
2149 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
2150 Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg);
2151 } else {
2152 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
2153 "ZExt src doesn't fit in destination");
2154 replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg);
2155 }
2156
2157 Register ZeroReg;
2158 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
2159 if (!ZeroReg)
2160 ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0);
2161 replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg);
2162 }
2163 MI.eraseFromParent();
2164}
2165
2166bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
2167 unsigned TargetShiftSize,
2168 unsigned &ShiftVal) {
2169 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
2170 MI.getOpcode() == TargetOpcode::G_LSHR ||
2171 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
2172
2173 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2174 if (Ty.isVector()) // TODO:
2175 return false;
2176
2177 // Don't narrow further than the requested size.
2178 unsigned Size = Ty.getSizeInBits();
2179 if (Size <= TargetShiftSize)
2180 return false;
2181
2182 auto MaybeImmVal =
2183 getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
2184 if (!MaybeImmVal)
2185 return false;
2186
2187 ShiftVal = MaybeImmVal->Value.getSExtValue();
2188 return ShiftVal >= Size / 2 && ShiftVal < Size;
2189}
2190
2191void CombinerHelper::applyCombineShiftToUnmerge(MachineInstr &MI,
2192 const unsigned &ShiftVal) {
2193 Register DstReg = MI.getOperand(i: 0).getReg();
2194 Register SrcReg = MI.getOperand(i: 1).getReg();
2195 LLT Ty = MRI.getType(Reg: SrcReg);
2196 unsigned Size = Ty.getSizeInBits();
2197 unsigned HalfSize = Size / 2;
2198 assert(ShiftVal >= HalfSize);
2199
2200 LLT HalfTy = LLT::scalar(SizeInBits: HalfSize);
2201
2202 Builder.setInstr(MI);
2203 auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg);
2204 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
2205
2206 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
2207 Register Narrowed = Unmerge.getReg(Idx: 1);
2208
2209 // dst = G_LSHR s64:x, C for C >= 32
2210 // =>
2211 // lo, hi = G_UNMERGE_VALUES x
2212 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
2213
2214 if (NarrowShiftAmt != 0) {
2215 Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed,
2216 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2217 }
2218
2219 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2220 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero});
2221 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
2222 Register Narrowed = Unmerge.getReg(Idx: 0);
2223 // dst = G_SHL s64:x, C for C >= 32
2224 // =>
2225 // lo, hi = G_UNMERGE_VALUES x
2226 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
2227 if (NarrowShiftAmt != 0) {
2228 Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed,
2229 Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0);
2230 }
2231
2232 auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0);
2233 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed});
2234 } else {
2235 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2236 auto Hi = Builder.buildAShr(
2237 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2238 Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1));
2239
2240 if (ShiftVal == HalfSize) {
2241 // (G_ASHR i64:x, 32) ->
2242 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
2243 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi});
2244 } else if (ShiftVal == Size - 1) {
2245 // Don't need a second shift.
2246 // (G_ASHR i64:x, 63) ->
2247 // %narrowed = (G_ASHR hi_32(x), 31)
2248 // G_MERGE_VALUES %narrowed, %narrowed
2249 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi});
2250 } else {
2251 auto Lo = Builder.buildAShr(
2252 Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1),
2253 Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize));
2254
2255 // (G_ASHR i64:x, C) ->, for C >= 32
2256 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2257 Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi});
2258 }
2259 }
2260
2261 MI.eraseFromParent();
2262}
2263
2264bool CombinerHelper::tryCombineShiftToUnmerge(MachineInstr &MI,
2265 unsigned TargetShiftAmount) {
2266 unsigned ShiftAmt;
2267 if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) {
2268 applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt);
2269 return true;
2270 }
2271
2272 return false;
2273}
2274
2275bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, Register &Reg) {
2276 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2277 Register DstReg = MI.getOperand(i: 0).getReg();
2278 LLT DstTy = MRI.getType(Reg: DstReg);
2279 Register SrcReg = MI.getOperand(i: 1).getReg();
2280 return mi_match(R: SrcReg, MRI,
2281 P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg))));
2282}
2283
2284void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, Register &Reg) {
2285 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2286 Register DstReg = MI.getOperand(i: 0).getReg();
2287 Builder.setInstr(MI);
2288 Builder.buildCopy(Res: DstReg, Op: Reg);
2289 MI.eraseFromParent();
2290}
2291
2292void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, Register &Reg) {
2293 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2294 Register DstReg = MI.getOperand(i: 0).getReg();
2295 Builder.setInstr(MI);
2296 Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg);
2297 MI.eraseFromParent();
2298}
2299
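// Fold (G_ADD (G_PTRTOINT %p), %x) -> (G_PTRTOINT (G_PTR_ADD %p, %x)).
// E.g. (illustrative):
//   %i:_(s64) = G_PTRTOINT %p:_(p0)
//   %a:_(s64) = G_ADD %i, %x
// becomes
//   %q:_(p0) = G_PTR_ADD %p, %x
//   %a:_(s64) = G_PTRTOINT %q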
2300bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2301 MachineInstr &MI, std::pair<Register, bool> &PtrReg) {
2302 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2303 Register LHS = MI.getOperand(i: 1).getReg();
2304 Register RHS = MI.getOperand(i: 2).getReg();
2305 LLT IntTy = MRI.getType(Reg: LHS);
2306
2307 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2308 // instruction.
2309 PtrReg.second = false;
2310 for (Register SrcReg : {LHS, RHS}) {
2311 if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) {
2312 // Don't handle cases where the integer is implicitly converted to the
2313 // pointer width.
2314 LLT PtrTy = MRI.getType(Reg: PtrReg.first);
2315 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2316 return true;
2317 }
2318
2319 PtrReg.second = true;
2320 }
2321
2322 return false;
2323}
2324
2325void CombinerHelper::applyCombineAddP2IToPtrAdd(
2326 MachineInstr &MI, std::pair<Register, bool> &PtrReg) {
2327 Register Dst = MI.getOperand(i: 0).getReg();
2328 Register LHS = MI.getOperand(i: 1).getReg();
2329 Register RHS = MI.getOperand(i: 2).getReg();
2330
2331 const bool DoCommute = PtrReg.second;
2332 if (DoCommute)
2333 std::swap(a&: LHS, b&: RHS);
2334 LHS = PtrReg.first;
2335
2336 LLT PtrTy = MRI.getType(Reg: LHS);
2337
2338 Builder.setInstrAndDebugLoc(MI);
2339 auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS);
2340 Builder.buildPtrToInt(Dst, Src: PtrAdd);
2341 MI.eraseFromParent();
2342}
2343
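// Fold a G_PTR_ADD of a constant G_INTTOPTR base and a constant offset into a
// single constant. E.g. (illustrative): a base of 4096 plus an offset of 16
// folds to the constant pointer value 4112.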
2344bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2345 APInt &NewCst) {
2346 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2347 Register LHS = PtrAdd.getBaseReg();
2348 Register RHS = PtrAdd.getOffsetReg();
2349 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2350
2351 if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) {
2352 APInt Cst;
2353 if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) {
2354 auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0));
2355 // G_INTTOPTR uses zero-extension
2356 NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits());
2357 NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits());
2358 return true;
2359 }
2360 }
2361
2362 return false;
2363}
2364
2365void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2366 APInt &NewCst) {
2367 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
2368 Register Dst = PtrAdd.getReg(Idx: 0);
2369
2370 Builder.setInstrAndDebugLoc(MI);
2371 Builder.buildConstant(Res: Dst, Val: NewCst);
2372 PtrAdd.eraseFromParent();
2373}
2374
2375bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, Register &Reg) {
2376 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2377 Register DstReg = MI.getOperand(i: 0).getReg();
2378 Register SrcReg = MI.getOperand(i: 1).getReg();
2379 Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI);
2380 if (OriginalSrcReg.isValid())
2381 SrcReg = OriginalSrcReg;
2382 LLT DstTy = MRI.getType(Reg: DstReg);
2383 return mi_match(R: SrcReg, MRI,
2384 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy))));
2385}
2386
2387bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, Register &Reg) {
2388 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2389 Register DstReg = MI.getOperand(i: 0).getReg();
2390 Register SrcReg = MI.getOperand(i: 1).getReg();
2391 LLT DstTy = MRI.getType(Reg: DstReg);
2392 if (mi_match(R: SrcReg, MRI,
2393 P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy))))) {
2394 unsigned DstSize = DstTy.getScalarSizeInBits();
2395 unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits();
2396 return KB->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2397 }
2398 return false;
2399}
2400
2401bool CombinerHelper::matchCombineExtOfExt(
2402 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
2403 assert((MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2404 MI.getOpcode() == TargetOpcode::G_SEXT ||
2405 MI.getOpcode() == TargetOpcode::G_ZEXT) &&
2406 "Expected a G_[ASZ]EXT");
2407 Register SrcReg = MI.getOperand(i: 1).getReg();
2408 Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI);
2409 if (OriginalSrcReg.isValid())
2410 SrcReg = OriginalSrcReg;
2411 MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg);
2412 // Match exts with the same opcode, anyext([sz]ext) and sext(zext).
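 // E.g. (illustrative): G_SEXT (G_ZEXT %x) -> G_ZEXT %x, and
 // G_ZEXT (G_ZEXT %x) -> a single, wider G_ZEXT of %x.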
2413 unsigned Opc = MI.getOpcode();
2414 unsigned SrcOpc = SrcMI->getOpcode();
2415 if (Opc == SrcOpc ||
2416 (Opc == TargetOpcode::G_ANYEXT &&
2417 (SrcOpc == TargetOpcode::G_SEXT || SrcOpc == TargetOpcode::G_ZEXT)) ||
2418 (Opc == TargetOpcode::G_SEXT && SrcOpc == TargetOpcode::G_ZEXT)) {
2419 MatchInfo = std::make_tuple(args: SrcMI->getOperand(i: 1).getReg(), args&: SrcOpc);
2420 return true;
2421 }
2422 return false;
2423}
2424
2425void CombinerHelper::applyCombineExtOfExt(
2426 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
2427 assert((MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2428 MI.getOpcode() == TargetOpcode::G_SEXT ||
2429 MI.getOpcode() == TargetOpcode::G_ZEXT) &&
2430 "Expected a G_[ASZ]EXT");
2431
2432 Register Reg = std::get<0>(t&: MatchInfo);
2433 unsigned SrcExtOp = std::get<1>(t&: MatchInfo);
2434
2435 // Combine exts with the same opcode.
2436 if (MI.getOpcode() == SrcExtOp) {
2437 Observer.changingInstr(MI);
2438 MI.getOperand(i: 1).setReg(Reg);
2439 Observer.changedInstr(MI);
2440 return;
2441 }
2442
2443 // Combine:
2444 // - anyext([sz]ext x) to [sz]ext x
2445 // - sext(zext x) to zext x
2446 if (MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2447 (MI.getOpcode() == TargetOpcode::G_SEXT &&
2448 SrcExtOp == TargetOpcode::G_ZEXT)) {
2449 Register DstReg = MI.getOperand(i: 0).getReg();
2450 Builder.setInstrAndDebugLoc(MI);
2451 Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {Reg});
2452 MI.eraseFromParent();
2453 }
2454}
2455
2456bool CombinerHelper::matchCombineTruncOfExt(
2457 MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
2458 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2459 Register SrcReg = MI.getOperand(i: 1).getReg();
2460 MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg);
2461 unsigned SrcOpc = SrcMI->getOpcode();
2462 if (SrcOpc == TargetOpcode::G_ANYEXT || SrcOpc == TargetOpcode::G_SEXT ||
2463 SrcOpc == TargetOpcode::G_ZEXT) {
2464 MatchInfo = std::make_pair(x: SrcMI->getOperand(i: 1).getReg(), y&: SrcOpc);
2465 return true;
2466 }
2467 return false;
2468}
2469
2470void CombinerHelper::applyCombineTruncOfExt(
2471 MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
2472 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2473 Register SrcReg = MatchInfo.first;
2474 unsigned SrcExtOp = MatchInfo.second;
2475 Register DstReg = MI.getOperand(i: 0).getReg();
2476 LLT SrcTy = MRI.getType(Reg: SrcReg);
2477 LLT DstTy = MRI.getType(Reg: DstReg);
2478 if (SrcTy == DstTy) {
2479 MI.eraseFromParent();
2480 replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg);
2481 return;
2482 }
2483 Builder.setInstrAndDebugLoc(MI);
2484 if (SrcTy.getSizeInBits() < DstTy.getSizeInBits())
2485 Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {SrcReg});
2486 else
2487 Builder.buildTrunc(Res: DstReg, Op: SrcReg);
2488 MI.eraseFromParent();
2489}
2490
2491static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2492 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2493 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2494
2495 // ShiftTy > 32 > TruncTy -> 32
2496 if (ShiftSize > 32 && TruncSize < 32)
2497 return ShiftTy.changeElementSize(NewEltSize: 32);
2498
2499 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2500 // Some targets like it, some don't, some only like it under certain
2501 // conditions/processor versions, etc.
2502 // A TL hook might be needed for this.
2503
2504 // Don't combine
2505 return ShiftTy;
2506}
2507
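// Narrow the type a shift is performed in when only a truncated portion of the
// result is used. E.g. (illustrative), when the shift amount is known to be at
// most 16:
//   %d:_(s16) = G_TRUNC (G_LSHR %x:_(s64), 8)
// can be done in an intermediate 32-bit type instead:
//   %m:_(s32) = G_TRUNC %x
//   %s:_(s32) = G_LSHR %m, 8
//   %d:_(s16) = G_TRUNC %s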
2508bool CombinerHelper::matchCombineTruncOfShift(
2509 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
2510 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2511 Register DstReg = MI.getOperand(i: 0).getReg();
2512 Register SrcReg = MI.getOperand(i: 1).getReg();
2513
2514 if (!MRI.hasOneNonDBGUse(RegNo: SrcReg))
2515 return false;
2516
2517 LLT SrcTy = MRI.getType(Reg: SrcReg);
2518 LLT DstTy = MRI.getType(Reg: DstReg);
2519
2520 MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI);
2521 const auto &TL = getTargetLowering();
2522
2523 LLT NewShiftTy;
2524 switch (SrcMI->getOpcode()) {
2525 default:
2526 return false;
2527 case TargetOpcode::G_SHL: {
2528 NewShiftTy = DstTy;
2529
2530 // Make sure new shift amount is legal.
2531 KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2532 if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits()))
2533 return false;
2534 break;
2535 }
2536 case TargetOpcode::G_LSHR:
2537 case TargetOpcode::G_ASHR: {
2538 // For right shifts, we conservatively do not do the transform if the TRUNC
2539 // has any STORE users. The reason is that if we change the type of the
2540 // shift, we may break the truncstore combine.
2541 //
2542 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2543 for (auto &User : MRI.use_instructions(Reg: DstReg))
2544 if (User.getOpcode() == TargetOpcode::G_STORE)
2545 return false;
2546
2547 NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy);
2548 if (NewShiftTy == SrcTy)
2549 return false;
2550
2551 // Make sure we won't lose information by truncating the high bits.
2552 KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg());
2553 if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() -
2554 DstTy.getScalarSizeInBits()))
2555 return false;
2556 break;
2557 }
2558 }
2559
2560 if (!isLegalOrBeforeLegalizer(
2561 Query: {SrcMI->getOpcode(),
2562 {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}}))
2563 return false;
2564
2565 MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy);
2566 return true;
2567}
2568
2569void CombinerHelper::applyCombineTruncOfShift(
2570 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
2571 Builder.setInstrAndDebugLoc(MI);
2572
2573 MachineInstr *ShiftMI = MatchInfo.first;
2574 LLT NewShiftTy = MatchInfo.second;
2575
2576 Register Dst = MI.getOperand(i: 0).getReg();
2577 LLT DstTy = MRI.getType(Reg: Dst);
2578
2579 Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg();
2580 Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg();
2581 ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0);
2582
2583 Register NewShift =
2584 Builder
2585 .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt})
2586 .getReg(Idx: 0);
2587
2588 if (NewShiftTy == DstTy)
2589 replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift);
2590 else
2591 Builder.buildTrunc(Res: Dst, Op: NewShift);
2592
2593 eraseInst(MI);
2594}
2595
2596bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) {
2597 return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2598 return MO.isReg() &&
2599 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2600 });
2601}
2602
2603bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) {
2604 return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) {
2605 return !MO.isReg() ||
2606 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2607 });
2608}
2609
2610bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) {
2611 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2612 ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask();
2613 return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; });
2614}
2615
2616bool CombinerHelper::matchUndefStore(MachineInstr &MI) {
2617 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2618 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(),
2619 MRI);
2620}
2621
2622bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) {
2623 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2624 return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(),
2625 MRI);
2626}
2627
2628bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
2629 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2630 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2631 "Expected an insert/extract element op");
2632 LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg());
2633 unsigned IdxIdx =
2634 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2635 auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI);
2636 if (!Idx)
2637 return false;
2638 return Idx->getZExtValue() >= VecTy.getNumElements();
2639}
2640
2641bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) {
2642 GSelect &SelMI = cast<GSelect>(Val&: MI);
2643 auto Cst =
2644 isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI);
2645 if (!Cst)
2646 return false;
2647 OpIdx = Cst->isZero() ? 3 : 2;
2648 return true;
2649}
2650
2651void CombinerHelper::eraseInst(MachineInstr &MI) { MI.eraseFromParent(); }
2652
2653bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2654 const MachineOperand &MOP2) {
2655 if (!MOP1.isReg() || !MOP2.isReg())
2656 return false;
2657 auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI);
2658 if (!InstAndDef1)
2659 return false;
2660 auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI);
2661 if (!InstAndDef2)
2662 return false;
2663 MachineInstr *I1 = InstAndDef1->MI;
2664 MachineInstr *I2 = InstAndDef2->MI;
2665
2666 // Handle a case like this:
2667 //
2668 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2669 //
2670 // Even though %0 and %1 are produced by the same instruction they are not
2671 // the same values.
2672 if (I1 == I2)
2673 return MOP1.getReg() == MOP2.getReg();
2674
2675 // If we have an instruction which loads or stores, we can't guarantee that
2676 // it is identical.
2677 //
2678 // For example, we may have
2679 //
2680 // %x1 = G_LOAD %addr (load N from @somewhere)
2681 // ...
2682 // call @foo
2683 // ...
2684 // %x2 = G_LOAD %addr (load N from @somewhere)
2685 // ...
2686 // %or = G_OR %x1, %x2
2687 //
2688 // It's possible that @foo will modify whatever lives at the address we're
2689 // loading from. To be safe, let's just assume that all loads and stores
2690 // are different (unless we have something which is guaranteed to not
2691 // change.)
2692 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2693 return false;
2694
2695 // If both instructions are loads or stores, they are equal only if both
2696 // are dereferenceable invariant loads with the same number of bits.
2697 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2698 GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1);
2699 GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2);
2700 if (!LS1 || !LS2)
2701 return false;
2702
2703 if (!I2->isDereferenceableInvariantLoad() ||
2704 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2705 return false;
2706 }
2707
2708 // Check for physical registers on the instructions first to avoid cases
2709 // like this:
2710 //
2711 // %a = COPY $physreg
2712 // ...
2713 // SOMETHING implicit-def $physreg
2714 // ...
2715 // %b = COPY $physreg
2716 //
2717 // These copies are not equivalent.
2718 if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) {
2719 return MO.isReg() && MO.getReg().isPhysical();
2720 })) {
2721 // Check if we have a case like this:
2722 //
2723 // %a = COPY $physreg
2724 // %b = COPY %a
2725 //
2726 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2727 // From that, we know that they must have the same value, since they must
2728 // have come from the same COPY.
2729 return I1->isIdenticalTo(Other: *I2);
2730 }
2731
2732 // We don't have any physical registers, so we don't necessarily need the
2733 // same vreg defs.
2734 //
2735 // On the off-chance that there's some target instruction feeding into the
2736 // instruction, let's use produceSameValue instead of isIdenticalTo.
2737 if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) {
2738 // Handle instructions with multiple defs that produce the same values. The
2739 // values are the same for operands with the same index.
2740 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2741 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2742 // I1 and I2 are different instructions that produce the same values:
2743 // %1 and %6 are the same, but %1 and %7 are not the same value.
2744 return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg) ==
2745 I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg);
2746 }
2747 return false;
2748}
2749
2750bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, int64_t C) {
2751 if (!MOP.isReg())
2752 return false;
2753 auto *MI = MRI.getVRegDef(Reg: MOP.getReg());
2754 auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI);
2755 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2756 MaybeCst->getSExtValue() == C;
2757}
2758
2759bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP, double C) {
2760 if (!MOP.isReg())
2761 return false;
2762 std::optional<FPValueAndVReg> MaybeCst;
2763 if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst)))
2764 return false;
2765
2766 return MaybeCst->Value.isExactlyValue(V: C);
2767}
2768
2769void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2770 unsigned OpIdx) {
2771 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2772 Register OldReg = MI.getOperand(i: 0).getReg();
2773 Register Replacement = MI.getOperand(i: OpIdx).getReg();
2774 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2775 MI.eraseFromParent();
2776 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
2777}
2778
2779void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2780 Register Replacement) {
2781 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2782 Register OldReg = MI.getOperand(i: 0).getReg();
2783 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2784 MI.eraseFromParent();
2785 replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement);
2786}
2787
2788bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
2789 unsigned ConstIdx) {
2790 Register ConstReg = MI.getOperand(i: ConstIdx).getReg();
2791 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2792
2793 // Get the shift amount
2794 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
2795 if (!VRegAndVal)
2796 return false;
2797
2798 // Return true if the shift amount >= bitwidth.
2799 return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits()));
2800}
2801
2802void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) {
2803 assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
2804 MI.getOpcode() == TargetOpcode::G_FSHR) &&
2805 "This is not a funnel shift operation");
2806
2807 Register ConstReg = MI.getOperand(i: 3).getReg();
2808 LLT ConstTy = MRI.getType(Reg: ConstReg);
2809 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2810
2811 auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI);
2812 assert((VRegAndVal) && "Value is not a constant");
2813
2814 // Calculate the new Shift Amount = Old Shift Amount % BitWidth
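 // E.g. (illustrative): a 32-bit G_FSHL with a shift amount of 37 is
 // equivalent to one with a shift amount of 5.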
2815 APInt NewConst = VRegAndVal->Value.urem(
2816 RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
2817
2818 Builder.setInstrAndDebugLoc(MI);
2819 auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue());
2820 Builder.buildInstr(
2821 Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)},
2822 SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)});
2823
2824 MI.eraseFromParent();
2825}
2826
2827bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) {
2828 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2829 // Match (cond ? x : x)
2830 return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) &&
2831 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(),
2832 MRI);
2833}
2834
2835bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) {
2836 return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) &&
2837 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(),
2838 MRI);
2839}
2840
2841bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, unsigned OpIdx) {
2842 return matchConstantOp(MOP: MI.getOperand(i: OpIdx), C: 0) &&
2843 canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: OpIdx).getReg(),
2844 MRI);
2845}
2846
2847bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, unsigned OpIdx) {
2848 MachineOperand &MO = MI.getOperand(i: OpIdx);
2849 return MO.isReg() &&
2850 getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI);
2851}
2852
2853bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
2854 unsigned OpIdx) {
2855 MachineOperand &MO = MI.getOperand(i: OpIdx);
2856 return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, KnownBits: KB);
2857}
2858
2859void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, double C) {
2860 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2861 Builder.setInstr(MI);
2862 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C);
2863 MI.eraseFromParent();
2864}
2865
2866void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, int64_t C) {
2867 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2868 Builder.setInstr(MI);
2869 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
2870 MI.eraseFromParent();
2871}
2872
2873void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) {
2874 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2875 Builder.setInstr(MI);
2876 Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C);
2877 MI.eraseFromParent();
2878}
2879
2880void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, ConstantFP *CFP) {
2881 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2882 Builder.setInstr(MI);
2883 Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF());
2884 MI.eraseFromParent();
2885}
2886
2887void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) {
2888 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2889 Builder.setInstr(MI);
2890 Builder.buildUndef(Res: MI.getOperand(i: 0));
2891 MI.eraseFromParent();
2892}
2893
2894bool CombinerHelper::matchSimplifyAddToSub(
2895 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) {
2896 Register LHS = MI.getOperand(i: 1).getReg();
2897 Register RHS = MI.getOperand(i: 2).getReg();
2898 Register &NewLHS = std::get<0>(t&: MatchInfo);
2899 Register &NewRHS = std::get<1>(t&: MatchInfo);
2900
2901 // Helper lambda to check for opportunities for
2902 // ((0-A) + B) -> B - A
2903 // (A + (0-B)) -> A - B
2904 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
2905 if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS))))
2906 return false;
2907 NewLHS = MaybeNewLHS;
2908 return true;
2909 };
2910
2911 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
2912}
2913
2914bool CombinerHelper::matchCombineInsertVecElts(
2915 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) {
2916 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
2917 "Invalid opcode");
2918 Register DstReg = MI.getOperand(i: 0).getReg();
2919 LLT DstTy = MRI.getType(Reg: DstReg);
2920 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
2921 unsigned NumElts = DstTy.getNumElements();
2922 // If this MI is part of a sequence of insert_vec_elts, then
2923 // don't do the combine in the middle of the sequence.
2924 if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() ==
2925 TargetOpcode::G_INSERT_VECTOR_ELT)
2926 return false;
2927 MachineInstr *CurrInst = &MI;
2928 MachineInstr *TmpInst;
2929 int64_t IntImm;
2930 Register TmpReg;
2931 MatchInfo.resize(N: NumElts);
2932 while (mi_match(
2933 R: CurrInst->getOperand(i: 0).getReg(), MRI,
2934 P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) {
2935 if (IntImm >= NumElts || IntImm < 0)
2936 return false;
2937 if (!MatchInfo[IntImm])
2938 MatchInfo[IntImm] = TmpReg;
2939 CurrInst = TmpInst;
2940 }
2941 // Variable index.
2942 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
2943 return false;
2944 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
2945 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
2946 if (!MatchInfo[I - 1].isValid())
2947 MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg();
2948 }
2949 return true;
2950 }
2951 // If we didn't end in a G_IMPLICIT_DEF, bail out.
2952 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF;
2953}
2954
2955void CombinerHelper::applyCombineInsertVecElts(
2956 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) {
2957 Builder.setInstr(MI);
2958 Register UndefReg;
2959 auto GetUndef = [&]() {
2960 if (UndefReg)
2961 return UndefReg;
2962 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
2963 UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0);
2964 return UndefReg;
2965 };
2966 for (unsigned I = 0; I < MatchInfo.size(); ++I) {
2967 if (!MatchInfo[I])
2968 MatchInfo[I] = GetUndef();
2969 }
2970 Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo);
2971 MI.eraseFromParent();
2972}
2973
2974void CombinerHelper::applySimplifyAddToSub(
2975 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) {
2976 Builder.setInstr(MI);
2977 Register SubLHS, SubRHS;
2978 std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo;
2979 Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS);
2980 MI.eraseFromParent();
2981}
2982
2983bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
2984 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) {
2985 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
2986 //
2987 // Creates the new hand + logic instructions (but does not insert them).
2988 //
2989 // On success, MatchInfo is populated with the new instructions. These are
2990 // inserted in applyHoistLogicOpWithSameOpcodeHands.
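 //
 // E.g. (and (sext x), (sext y)) -> (sext (and x, y)), and
 // (or (lshr x, z), (lshr y, z)) -> (lshr (or x, y), z).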
2991 unsigned LogicOpcode = MI.getOpcode();
2992 assert(LogicOpcode == TargetOpcode::G_AND ||
2993 LogicOpcode == TargetOpcode::G_OR ||
2994 LogicOpcode == TargetOpcode::G_XOR);
2995 MachineIRBuilder MIB(MI);
2996 Register Dst = MI.getOperand(i: 0).getReg();
2997 Register LHSReg = MI.getOperand(i: 1).getReg();
2998 Register RHSReg = MI.getOperand(i: 2).getReg();
2999
3000 // Don't recompute anything.
3001 if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg))
3002 return false;
3003
3004 // Make sure we have (hand x, ...), (hand y, ...)
3005 MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI);
3006 MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI);
3007 if (!LeftHandInst || !RightHandInst)
3008 return false;
3009 unsigned HandOpcode = LeftHandInst->getOpcode();
3010 if (HandOpcode != RightHandInst->getOpcode())
3011 return false;
3012 if (!LeftHandInst->getOperand(i: 1).isReg() ||
3013 !RightHandInst->getOperand(i: 1).isReg())
3014 return false;
3015
3016 // Make sure the types match up, and if we're doing this post-legalization,
3017 // we end up with legal types.
3018 Register X = LeftHandInst->getOperand(i: 1).getReg();
3019 Register Y = RightHandInst->getOperand(i: 1).getReg();
3020 LLT XTy = MRI.getType(Reg: X);
3021 LLT YTy = MRI.getType(Reg: Y);
3022 if (!XTy.isValid() || XTy != YTy)
3023 return false;
3024
3025 // Optional extra source register.
3026 Register ExtraHandOpSrcReg;
3027 switch (HandOpcode) {
3028 default:
3029 return false;
3030 case TargetOpcode::G_ANYEXT:
3031 case TargetOpcode::G_SEXT:
3032 case TargetOpcode::G_ZEXT: {
3033 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
3034 break;
3035 }
3036 case TargetOpcode::G_AND:
3037 case TargetOpcode::G_ASHR:
3038 case TargetOpcode::G_LSHR:
3039 case TargetOpcode::G_SHL: {
3040 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
3041 MachineOperand &ZOp = LeftHandInst->getOperand(i: 2);
3042 if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2)))
3043 return false;
3044 ExtraHandOpSrcReg = ZOp.getReg();
3045 break;
3046 }
3047 }
3048
3049 if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}}))
3050 return false;
3051
3052 // Record the steps to build the new instructions.
3053 //
3054 // Steps to build (logic x, y)
3055 auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy);
3056 OperandBuildSteps LogicBuildSteps = {
3057 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); },
3058 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); },
3059 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }};
3060 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
3061
3062 // Steps to build hand (logic x, y), ...z
3063 OperandBuildSteps HandBuildSteps = {
3064 [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); },
3065 [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }};
3066 if (ExtraHandOpSrcReg.isValid())
3067 HandBuildSteps.push_back(
3068 Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); });
3069 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
3070
3071 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
3072 return true;
3073}
3074
3075void CombinerHelper::applyBuildInstructionSteps(
3076 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) {
3077 assert(MatchInfo.InstrsToBuild.size() &&
3078 "Expected at least one instr to build?");
3079 Builder.setInstr(MI);
3080 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
3081 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
3082 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
3083 MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode);
3084 for (auto &OperandFn : InstrToBuild.OperandFns)
3085 OperandFn(Instr);
3086 }
3087 MI.eraseFromParent();
3088}
3089
3090bool CombinerHelper::matchAshrShlToSextInreg(
3091 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) {
3092 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3093 int64_t ShlCst, AshrCst;
3094 Register Src;
3095 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3096 P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)),
3097 R: m_ICstOrSplat(Cst&: AshrCst))))
3098 return false;
3099 if (ShlCst != AshrCst)
3100 return false;
3101 if (!isLegalOrBeforeLegalizer(
3102 Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}}))
3103 return false;
3104 MatchInfo = std::make_tuple(args&: Src, args&: ShlCst);
3105 return true;
3106}
3107
3108void CombinerHelper::applyAshShlToSextInreg(
3109 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) {
3110 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
3111 Register Src;
3112 int64_t ShiftAmt;
3113 std::tie(args&: Src, args&: ShiftAmt) = MatchInfo;
3114 unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits();
3115 Builder.setInstrAndDebugLoc(MI);
3116 Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt);
3117 MI.eraseFromParent();
3118}
3119
3120/// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
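/// E.g. and(and(x, 0x0F), 0xFC) -> and(x, 0x0C), while
/// and(and(x, 0x0F), 0xF0) -> 0.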
3121bool CombinerHelper::matchOverlappingAnd(
3122 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
3123 assert(MI.getOpcode() == TargetOpcode::G_AND);
3124
3125 Register Dst = MI.getOperand(i: 0).getReg();
3126 LLT Ty = MRI.getType(Reg: Dst);
3127
3128 Register R;
3129 int64_t C1;
3130 int64_t C2;
3131 if (!mi_match(
3132 R: Dst, MRI,
3133 P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2))))
3134 return false;
3135
3136 MatchInfo = [=](MachineIRBuilder &B) {
3137 if (C1 & C2) {
3138 B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2));
3139 return;
3140 }
3141 auto Zero = B.buildConstant(Res: Ty, Val: 0);
3142 replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg());
3143 };
3144 return true;
3145}
3146
3147bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
3148 Register &Replacement) {
3149 // Given
3150 //
3151 // %y:_(sN) = G_SOMETHING
3152 // %x:_(sN) = G_SOMETHING
3153 // %res:_(sN) = G_AND %x, %y
3154 //
3155 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
3156 //
3157 // Patterns like this can appear as a result of legalization. E.g.
3158 //
3159 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
3160 // %one:_(s32) = G_CONSTANT i32 1
3161 // %and:_(s32) = G_AND %cmp, %one
3162 //
3163 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
3164 assert(MI.getOpcode() == TargetOpcode::G_AND);
3165 if (!KB)
3166 return false;
3167
3168 Register AndDst = MI.getOperand(i: 0).getReg();
3169 Register LHS = MI.getOperand(i: 1).getReg();
3170 Register RHS = MI.getOperand(i: 2).getReg();
3171 KnownBits LHSBits = KB->getKnownBits(R: LHS);
3172 KnownBits RHSBits = KB->getKnownBits(R: RHS);
3173
3174 // Check that x & Mask == x.
3175 // x & 1 == x, always
3176 // x & 0 == x, only if x is also 0
3177 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
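 // E.g. if x's known-zero bits are 0b1100 and Mask's known-one bits are
 // 0b0111, every bit is covered by one of the two, so x & Mask == x.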
3178 //
3179 // Check if we can replace AndDst with the LHS of the G_AND
3180 if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) &&
3181 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3182 Replacement = LHS;
3183 return true;
3184 }
3185
3186 // Check if we can replace AndDst with the RHS of the G_AND
3187 if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) &&
3188 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3189 Replacement = RHS;
3190 return true;
3191 }
3192
3193 return false;
3194}
3195
3196bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) {
3197 // Given
3198 //
3199 // %y:_(sN) = G_SOMETHING
3200 // %x:_(sN) = G_SOMETHING
3201 // %res:_(sN) = G_OR %x, %y
3202 //
3203 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
3204 assert(MI.getOpcode() == TargetOpcode::G_OR);
3205 if (!KB)
3206 return false;
3207
3208 Register OrDst = MI.getOperand(i: 0).getReg();
3209 Register LHS = MI.getOperand(i: 1).getReg();
3210 Register RHS = MI.getOperand(i: 2).getReg();
3211 KnownBits LHSBits = KB->getKnownBits(R: LHS);
3212 KnownBits RHSBits = KB->getKnownBits(R: RHS);
3213
3214 // Check that x | Mask == x.
3215 // x | 0 == x, always
3216 // x | 1 == x, only if x is also 1
3217 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
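 // E.g. if x's known-one bits are 0b1100 and Mask's known-zero bits are
 // 0b0111, every bit is covered by one of the two, so x | Mask == x.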
3218 //
3219 // Check if we can replace OrDst with the LHS of the G_OR
3220 if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) &&
3221 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
3222 Replacement = LHS;
3223 return true;
3224 }
3225
3226 // Check if we can replace OrDst with the RHS of the G_OR
3227 if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) &&
3228 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
3229 Replacement = RHS;
3230 return true;
3231 }
3232
3233 return false;
3234}
3235
3236bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) {
3237 // If the input is already sign extended, just drop the extension.
3238 Register Src = MI.getOperand(i: 1).getReg();
3239 unsigned ExtBits = MI.getOperand(i: 2).getImm();
3240 unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits();
3241 return KB->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1);
3242}
3243
3244static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
3245 int64_t Cst, bool IsVector, bool IsFP) {
3246 // For i1, Cst will always be -1 regardless of boolean contents.
3247 return (ScalarSizeBits == 1 && Cst == -1) ||
3248 isConstTrueVal(TLI, Val: Cst, IsVector, IsFP);
3249}
3250
3251bool CombinerHelper::matchNotCmp(MachineInstr &MI,
3252 SmallVectorImpl<Register> &RegsToNegate) {
3253 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3254 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
3255 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
3256 Register XorSrc;
3257 Register CstReg;
3258 // We match xor(src, true) here.
3259 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
3260 P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg))))
3261 return false;
3262
3263 if (!MRI.hasOneNonDBGUse(RegNo: XorSrc))
3264 return false;
3265
3266 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
3267 // and ORs. The suffix of RegsToNegate starting from index I is used as a work
3268 // list of tree nodes to visit.
3269 RegsToNegate.push_back(Elt: XorSrc);
3270 // Remember whether the comparisons are all integer or all floating point.
3271 bool IsInt = false;
3272 bool IsFP = false;
3273 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3274 Register Reg = RegsToNegate[I];
3275 if (!MRI.hasOneNonDBGUse(RegNo: Reg))
3276 return false;
3277 MachineInstr *Def = MRI.getVRegDef(Reg);
3278 switch (Def->getOpcode()) {
3279 default:
3280 // Don't match if the tree contains anything other than ANDs, ORs and
3281 // comparisons.
3282 return false;
3283 case TargetOpcode::G_ICMP:
3284 if (IsFP)
3285 return false;
3286 IsInt = true;
3287 // When we apply the combine we will invert the predicate.
3288 break;
3289 case TargetOpcode::G_FCMP:
3290 if (IsInt)
3291 return false;
3292 IsFP = true;
3293 // When we apply the combine we will invert the predicate.
3294 break;
3295 case TargetOpcode::G_AND:
3296 case TargetOpcode::G_OR:
3297 // Implement De Morgan's laws:
3298 // ~(x & y) -> ~x | ~y
3299 // ~(x | y) -> ~x & ~y
3300 // When we apply the combine we will change the opcode and recursively
3301 // negate the operands.
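 // E.g. ~((icmp eq a, b) & (icmp sgt c, d))
 // -> (icmp ne a, b) | (icmp sle c, d).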
3302 RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg());
3303 RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg());
3304 break;
3305 }
3306 }
3307
3308 // Now we know whether the comparisons are integer or floating point; check
3309 // the constant in the xor.
3310 int64_t Cst;
3311 if (Ty.isVector()) {
3312 MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg);
3313 auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI);
3314 if (!MaybeCst)
3315 return false;
3316 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP))
3317 return false;
3318 } else {
3319 if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst)))
3320 return false;
3321 if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP))
3322 return false;
3323 }
3324
3325 return true;
3326}
3327
3328void CombinerHelper::applyNotCmp(MachineInstr &MI,
3329 SmallVectorImpl<Register> &RegsToNegate) {
3330 for (Register Reg : RegsToNegate) {
3331 MachineInstr *Def = MRI.getVRegDef(Reg);
3332 Observer.changingInstr(MI&: *Def);
3333 // For each comparison, invert the opcode. For each AND and OR, change the
3334 // opcode.
3335 switch (Def->getOpcode()) {
3336 default:
3337 llvm_unreachable("Unexpected opcode");
3338 case TargetOpcode::G_ICMP:
3339 case TargetOpcode::G_FCMP: {
3340 MachineOperand &PredOp = Def->getOperand(i: 1);
3341 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3342 pred: (CmpInst::Predicate)PredOp.getPredicate());
3343 PredOp.setPredicate(NewP);
3344 break;
3345 }
3346 case TargetOpcode::G_AND:
3347 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR));
3348 break;
3349 case TargetOpcode::G_OR:
3350 Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3351 break;
3352 }
3353 Observer.changedInstr(MI&: *Def);
3354 }
3355
3356 replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg());
3357 MI.eraseFromParent();
3358}
3359
3360bool CombinerHelper::matchXorOfAndWithSameReg(
3361 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
3362 // Match (xor (and x, y), y) (or any of its commuted cases)
3363 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3364 Register &X = MatchInfo.first;
3365 Register &Y = MatchInfo.second;
3366 Register AndReg = MI.getOperand(i: 1).getReg();
3367 Register SharedReg = MI.getOperand(i: 2).getReg();
3368
3369 // Find a G_AND on either side of the G_XOR.
3370 // Look for one of
3371 //
3372 // (xor (and x, y), SharedReg)
3373 // (xor SharedReg, (and x, y))
3374 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) {
3375 std::swap(a&: AndReg, b&: SharedReg);
3376 if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y))))
3377 return false;
3378 }
3379
3380 // Only do this if we'll eliminate the G_AND.
3381 if (!MRI.hasOneNonDBGUse(RegNo: AndReg))
3382 return false;
3383
3384 // We can combine if SharedReg is the same as either the LHS or RHS of the
3385 // G_AND.
3386 if (Y != SharedReg)
3387 std::swap(a&: X, b&: Y);
3388 return Y == SharedReg;
3389}
3390
3391void CombinerHelper::applyXorOfAndWithSameReg(
3392 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
3393 // Fold (xor (and x, y), y) -> (and (not x), y)
3394 Builder.setInstrAndDebugLoc(MI);
3395 Register X, Y;
3396 std::tie(args&: X, args&: Y) = MatchInfo;
3397 auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X);
3398 Observer.changingInstr(MI);
3399 MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND));
3400 MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg());
3401 MI.getOperand(i: 2).setReg(Y);
3402 Observer.changedInstr(MI);
3403}
3404
3405bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) {
3406 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3407 Register DstReg = PtrAdd.getReg(Idx: 0);
3408 LLT Ty = MRI.getType(Reg: DstReg);
3409 const DataLayout &DL = Builder.getMF().getDataLayout();
3410
3411 if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace()))
3412 return false;
3413
3414 if (Ty.isPointer()) {
3415 auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI);
3416 return ConstVal && *ConstVal == 0;
3417 }
3418
3419 assert(Ty.isVector() && "Expecting a vector type");
3420 const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
3421 return isBuildVectorAllZeros(MI: *VecMI, MRI);
3422}
3423
3424void CombinerHelper::applyPtrAddZero(MachineInstr &MI) {
3425 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
3426 Builder.setInstrAndDebugLoc(PtrAdd);
3427 Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg());
3428 PtrAdd.eraseFromParent();
3429}
3430
3431/// The second source operand is known to be a power of 2.
3432void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) {
3433 Register DstReg = MI.getOperand(i: 0).getReg();
3434 Register Src0 = MI.getOperand(i: 1).getReg();
3435 Register Pow2Src1 = MI.getOperand(i: 2).getReg();
3436 LLT Ty = MRI.getType(Reg: DstReg);
3437 Builder.setInstrAndDebugLoc(MI);
3438
3439 // Fold (urem x, pow2) -> (and x, pow2-1)
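 // E.g. (urem x, 8) -> (and x, 7).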
3440 auto NegOne = Builder.buildConstant(Res: Ty, Val: -1);
3441 auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne);
3442 Builder.buildAnd(Dst: DstReg, Src0, Src1: Add);
3443 MI.eraseFromParent();
3444}
3445
3446bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3447 unsigned &SelectOpNo) {
3448 Register LHS = MI.getOperand(i: 1).getReg();
3449 Register RHS = MI.getOperand(i: 2).getReg();
3450
3451 Register OtherOperandReg = RHS;
3452 SelectOpNo = 1;
3453 MachineInstr *Select = MRI.getVRegDef(Reg: LHS);
3454
3455 // Don't do this unless the old select is going away. We want to eliminate the
3456 // binary operator, not replace a binop with a select.
3457 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3458 !MRI.hasOneNonDBGUse(RegNo: LHS)) {
3459 OtherOperandReg = LHS;
3460 SelectOpNo = 2;
3461 Select = MRI.getVRegDef(Reg: RHS);
3462 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3463 !MRI.hasOneNonDBGUse(RegNo: RHS))
3464 return false;
3465 }
3466
3467 MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg());
3468 MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg());
3469
3470 if (!isConstantOrConstantVector(MI: *SelectLHS, MRI,
3471 /*AllowFP*/ true,
3472 /*AllowOpaqueConstants*/ false))
3473 return false;
3474 if (!isConstantOrConstantVector(MI: *SelectRHS, MRI,
3475 /*AllowFP*/ true,
3476 /*AllowOpaqueConstants*/ false))
3477 return false;
3478
3479 unsigned BinOpcode = MI.getOpcode();
3480
3481 // We know that one of the operands is a select of constants. Now verify that
3482 // the other binary operator operand is either a constant, or we can handle a
3483 // variable.
3484 bool CanFoldNonConst =
3485 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3486 (isNullOrNullSplat(MI: *SelectLHS, MRI) ||
3487 isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) &&
3488 (isNullOrNullSplat(MI: *SelectRHS, MRI) ||
3489 isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI));
3490 if (CanFoldNonConst)
3491 return true;
3492
3493 return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI,
3494 /*AllowFP*/ true,
3495 /*AllowOpaqueConstants*/ false);
3496}
3497
3498/// \p SelectOperand is the operand in binary operator \p MI that is the select
3499/// to fold.
3500void CombinerHelper::applyFoldBinOpIntoSelect(MachineInstr &MI,
3501 const unsigned &SelectOperand) {
3502 Builder.setInstrAndDebugLoc(MI);
3503
3504 Register Dst = MI.getOperand(i: 0).getReg();
3505 Register LHS = MI.getOperand(i: 1).getReg();
3506 Register RHS = MI.getOperand(i: 2).getReg();
3507 MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg());
3508
3509 Register SelectCond = Select->getOperand(i: 1).getReg();
3510 Register SelectTrue = Select->getOperand(i: 2).getReg();
3511 Register SelectFalse = Select->getOperand(i: 3).getReg();
3512
3513 LLT Ty = MRI.getType(Reg: Dst);
3514 unsigned BinOpcode = MI.getOpcode();
3515
3516 Register FoldTrue, FoldFalse;
3517
3518 // We have a select-of-constants followed by a binary operator with a
3519 // constant. Eliminate the binop by pulling the constant math into the select.
3520 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3521 if (SelectOperand == 1) {
3522 // TODO: SelectionDAG verifies this actually constant folds before
3523 // committing to the combine.
3524
3525 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0);
3526 FoldFalse =
3527 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0);
3528 } else {
3529 FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0);
3530 FoldFalse =
3531 Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0);
3532 }
3533
3534 Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags());
3535 MI.eraseFromParent();
3536}
3537
3538std::optional<SmallVector<Register, 8>>
3539CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3540 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3541 // We want to detect if Root is part of a tree which represents a bunch
3542 // of loads being merged into a larger load. We'll try to recognize patterns
3543 // like, for example:
3544 //
3545 // Reg Reg
3546 // \ /
3547 // OR_1 Reg
3548 // \ /
3549 // OR_2
3550 // \ Reg
3551 // .. /
3552 // Root
3553 //
3554 // Reg Reg Reg Reg
3555 // \ / \ /
3556 // OR_1 OR_2
3557 // \ /
3558 // \ /
3559 // ...
3560 // Root
3561 //
3562 // Each "Reg" may have been produced by a load + some arithmetic. This
3563 // function will save each of them.
3564 SmallVector<Register, 8> RegsToVisit;
3565 SmallVector<const MachineInstr *, 7> Ors = {Root};
3566
3567 // In the "worst" case, we're dealing with a load for each byte. So, there
3568 // are at most #bytes - 1 ORs.
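 // E.g. a 4-byte (s32) result implies at most 3 ORs.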
3569 const unsigned MaxIter =
3570 MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1;
3571 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
3572 if (Ors.empty())
3573 break;
3574 const MachineInstr *Curr = Ors.pop_back_val();
3575 Register OrLHS = Curr->getOperand(i: 1).getReg();
3576 Register OrRHS = Curr->getOperand(i: 2).getReg();
3577
3578 // In the combine, we want to eliminate the entire tree.
3579 if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS))
3580 return std::nullopt;
3581
3582 // If it's a G_OR, save it and continue to walk. If it's not, then it's
3583 // something that may be a load + arithmetic.
3584 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI))
3585 Ors.push_back(Elt: Or);
3586 else
3587 RegsToVisit.push_back(Elt: OrLHS);
3588 if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI))
3589 Ors.push_back(Elt: Or);
3590 else
3591 RegsToVisit.push_back(Elt: OrRHS);
3592 }
3593
3594 // We're going to try and merge each register into a wider power-of-2 type,
3595 // so we ought to have an even number of registers.
3596 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
3597 return std::nullopt;
3598 return RegsToVisit;
3599}
3600
3601/// Helper function for findLoadOffsetsForLoadOrCombine.
3602///
3603/// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
3604/// and then moving that value into a specific byte offset.
3605///
3606/// e.g. x[i] << 24
3607///
3608/// \returns The load instruction and the byte offset it is moved into.
3609static std::optional<std::pair<GZExtLoad *, int64_t>>
3610matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
3611 const MachineRegisterInfo &MRI) {
3612 assert(MRI.hasOneNonDBGUse(Reg) &&
3613 "Expected Reg to only have one non-debug use?");
3614 Register MaybeLoad;
3615 int64_t Shift;
3616 if (!mi_match(R: Reg, MRI,
3617 P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) {
3618 Shift = 0;
3619 MaybeLoad = Reg;
3620 }
3621
3622 if (Shift % MemSizeInBits != 0)
3623 return std::nullopt;
3624
3625 // TODO: Handle other types of loads.
3626 auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI);
3627 if (!Load)
3628 return std::nullopt;
3629
3630 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
3631 return std::nullopt;
3632
3633 return std::make_pair(x&: Load, y: Shift / MemSizeInBits);
3634}
3635
3636std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
3637CombinerHelper::findLoadOffsetsForLoadOrCombine(
3638 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
3639 const SmallVector<Register, 8> &RegsToVisit, const unsigned MemSizeInBits) {
3640
3641 // Each load found for the pattern. There should be one for each RegsToVisit.
3642 SmallSetVector<const MachineInstr *, 8> Loads;
3643
3644 // The lowest index used in any load. (The lowest "i" for each x[i].)
3645 int64_t LowestIdx = INT64_MAX;
3646
3647 // The load which uses the lowest index.
3648 GZExtLoad *LowestIdxLoad = nullptr;
3649
3650 // Keeps track of the load indices we see. We shouldn't see any indices twice.
3651 SmallSet<int64_t, 8> SeenIdx;
3652
3653 // Ensure each load is in the same MBB.
3654 // TODO: Support multiple MachineBasicBlocks.
3655 MachineBasicBlock *MBB = nullptr;
3656 const MachineMemOperand *MMO = nullptr;
3657
3658 // Earliest instruction-order load in the pattern.
3659 GZExtLoad *EarliestLoad = nullptr;
3660
3661 // Latest instruction-order load in the pattern.
3662 GZExtLoad *LatestLoad = nullptr;
3663
3664 // Base pointer which every load should share.
3665 Register BasePtr;
3666
3667 // We want to find a load for each register. Each load should have some
3668 // appropriate bit twiddling arithmetic. During this loop, we will also keep
3669 // track of the load which uses the lowest index. Later, we will check if we
3670 // can use its pointer in the final, combined load.
3671 for (auto Reg : RegsToVisit) {
3672 // Find the load, and find the position that it will end up at in the final
3673 // value (e.g. because it is shifted).
3674 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
3675 if (!LoadAndPos)
3676 return std::nullopt;
3677 GZExtLoad *Load;
3678 int64_t DstPos;
3679 std::tie(args&: Load, args&: DstPos) = *LoadAndPos;
3680
3681 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
3682 // it is difficult to check for stores/calls/etc between loads.
3683 MachineBasicBlock *LoadMBB = Load->getParent();
3684 if (!MBB)
3685 MBB = LoadMBB;
3686 if (LoadMBB != MBB)
3687 return std::nullopt;
3688
3689 // Make sure that the MachineMemOperands of every seen load are compatible.
3690 auto &LoadMMO = Load->getMMO();
3691 if (!MMO)
3692 MMO = &LoadMMO;
3693 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
3694 return std::nullopt;
3695
3696 // Find out what the base pointer and index for the load is.
3697 Register LoadPtr;
3698 int64_t Idx;
3699 if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI,
3700 P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) {
3701 LoadPtr = Load->getOperand(i: 1).getReg();
3702 Idx = 0;
3703 }
3704
3705 // Don't combine things like a[i], a[i] -> a bigger load.
3706 if (!SeenIdx.insert(V: Idx).second)
3707 return std::nullopt;
3708
3709 // Every load must share the same base pointer; don't combine things like:
3710 //
3711 // a[i], b[i + 1] -> a bigger load.
3712 if (!BasePtr.isValid())
3713 BasePtr = LoadPtr;
3714 if (BasePtr != LoadPtr)
3715 return std::nullopt;
3716
3717 if (Idx < LowestIdx) {
3718 LowestIdx = Idx;
3719 LowestIdxLoad = Load;
3720 }
3721
3722 // Keep track of the byte offset that this load ends up at. If we have seen
3723 // the byte offset, then stop here. We do not want to combine:
3724 //
3725 // a[i] << 16, a[i + k] << 16 -> a bigger load.
3726 if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second)
3727 return std::nullopt;
3728 Loads.insert(X: Load);
3729
3730 // Keep track of the position of the earliest/latest loads in the pattern.
3731 // We will check that there are no load fold barriers between them later
3732 // on.
3733 //
3734 // FIXME: Is there a better way to check for load fold barriers?
3735 if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad))
3736 EarliestLoad = Load;
3737 if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load))
3738 LatestLoad = Load;
3739 }
3740
3741 // We found a load for each register. Let's check if each load satisfies the
3742 // pattern.
3743 assert(Loads.size() == RegsToVisit.size() &&
3744 "Expected to find a load for each register?");
3745 assert(EarliestLoad != LatestLoad && EarliestLoad &&
3746 LatestLoad && "Expected at least two loads?");
3747
3748 // Check if there are any stores, calls, etc. between any of the loads. If
3749 // there are, then we can't safely perform the combine.
3750 //
3751 // MaxIter is chosen based off the (worst case) number of iterations it
3752 // typically takes to succeed in the LLVM test suite plus some padding.
3753 //
3754 // FIXME: Is there a better way to check for load fold barriers?
3755 const unsigned MaxIter = 20;
3756 unsigned Iter = 0;
3757 for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(),
3758 End: LatestLoad->getIterator())) {
3759 if (Loads.count(key: &MI))
3760 continue;
3761 if (MI.isLoadFoldBarrier())
3762 return std::nullopt;
3763 if (Iter++ == MaxIter)
3764 return std::nullopt;
3765 }
3766
3767 return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad);
3768}
3769
3770bool CombinerHelper::matchLoadOrCombine(
3771 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
3772 assert(MI.getOpcode() == TargetOpcode::G_OR);
3773 MachineFunction &MF = *MI.getMF();
3774 // Assuming a little-endian target, transform:
3775 // s8 *a = ...
3776 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
3777 // =>
3778 // s32 val = *((i32)a)
3779 //
3780 // s8 *a = ...
3781 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
3782 // =>
3783 // s32 val = BSWAP(*((s32)a))
3784 Register Dst = MI.getOperand(i: 0).getReg();
3785 LLT Ty = MRI.getType(Reg: Dst);
3786 if (Ty.isVector())
3787 return false;
3788
3789 // We need to combine at least two loads into this type. Since the smallest
3790 // possible load is into a byte, we need at least a 16-bit wide type.
3791 const unsigned WideMemSizeInBits = Ty.getSizeInBits();
3792 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
3793 return false;
3794
3795 // Match a collection of non-OR instructions in the pattern.
3796 auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI);
3797 if (!RegsToVisit)
3798 return false;
3799
3800 // We have a collection of non-OR instructions. Figure out how wide each of
3801 // the small loads should be based off of the number of potential loads we
3802 // found.
3803 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
3804 if (NarrowMemSizeInBits % 8 != 0)
3805 return false;
3806
3807 // Check if each register feeding into each OR is a load from the same
3808 // base pointer + some arithmetic.
3809 //
3810 // e.g. a[0], a[1] << 8, a[2] << 16, etc.
3811 //
3812 // Also verify that each of these ends up putting a[i] into the same memory
3813 // offset as a load into a wide type would.
3814 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
3815 GZExtLoad *LowestIdxLoad, *LatestLoad;
3816 int64_t LowestIdx;
3817 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
3818 MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits);
3819 if (!MaybeLoadInfo)
3820 return false;
3821 std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo;
3822
3823 // We have a bunch of loads being OR'd together. Using the addresses + offsets
3824 // we found before, check if this corresponds to a big or little endian byte
3825 // pattern. If it does, then we can represent it using a load + possibly a
3826 // BSWAP.
3827 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
3828 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
3829 if (!IsBigEndian)
3830 return false;
3831 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
3832 if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}}))
3833 return false;
3834
3835 // Make sure that the load from the lowest index produces offset 0 in the
3836 // final value.
3837 //
3838 // This ensures that we won't combine something like this:
3839 //
3840 // load x[i] -> byte 2
3841 // load x[i+1] -> byte 0 ---> wide_load x[i]
3842 // load x[i+2] -> byte 1
3843 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
3844 const unsigned ZeroByteOffset =
3845 *IsBigEndian
3846 ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0)
3847 : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0);
3848 auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset);
3849 if (ZeroOffsetIdx == MemOffset2Idx.end() ||
3850 ZeroOffsetIdx->second != LowestIdx)
3851 return false;
3852
3853 // We will reuse the pointer from the load which ends up at byte offset 0. It
3854 // may not use index 0.
3855 Register Ptr = LowestIdxLoad->getPointerReg();
3856 const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
3857 LegalityQuery::MemDesc MMDesc(MMO);
3858 MMDesc.MemoryTy = Ty;
3859 if (!isLegalOrBeforeLegalizer(
3860 Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}}))
3861 return false;
3862 auto PtrInfo = MMO.getPointerInfo();
3863 auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8);
3864
3865 // Load must be allowed and fast on the target.
3866 LLVMContext &C = MF.getFunction().getContext();
3867 auto &DL = MF.getDataLayout();
3868 unsigned Fast = 0;
3869 if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) ||
3870 !Fast)
3871 return false;
3872
3873 MatchInfo = [=](MachineIRBuilder &MIB) {
3874 MIB.setInstrAndDebugLoc(*LatestLoad);
3875 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst;
3876 MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO);
3877 if (NeedsBSwap)
3878 MIB.buildBSwap(Dst, Src0: LoadDst);
3879 };
3880 return true;
3881}
3882
3883bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
3884 MachineInstr *&ExtMI) {
3885 auto &PHI = cast<GPhi>(Val&: MI);
3886 Register DstReg = PHI.getReg(Idx: 0);
3887
3888 // TODO: Extending a vector may be expensive, don't do this until heuristics
3889 // are better.
3890 if (MRI.getType(Reg: DstReg).isVector())
3891 return false;
3892
3893 // Try to match a phi, whose only use is an extend.
3894 if (!MRI.hasOneNonDBGUse(RegNo: DstReg))
3895 return false;
3896 ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg);
3897 switch (ExtMI->getOpcode()) {
3898 case TargetOpcode::G_ANYEXT:
3899 return true; // G_ANYEXT is usually free.
3900 case TargetOpcode::G_ZEXT:
3901 case TargetOpcode::G_SEXT:
3902 break;
3903 default:
3904 return false;
3905 }
3906
3907 // If the target is likely to fold this extend away, don't propagate.
3908 if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI))
3909 return false;
3910
3911 // We don't want to propagate the extends unless there's a good chance that
3912 // they'll be optimized in some way.
3913 // Collect the unique incoming values.
3914 SmallPtrSet<MachineInstr *, 4> InSrcs;
3915 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
3916 auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI);
3917 switch (DefMI->getOpcode()) {
3918 case TargetOpcode::G_LOAD:
3919 case TargetOpcode::G_TRUNC:
3920 case TargetOpcode::G_SEXT:
3921 case TargetOpcode::G_ZEXT:
3922 case TargetOpcode::G_ANYEXT:
3923 case TargetOpcode::G_CONSTANT:
3924 InSrcs.insert(Ptr: DefMI);
3925 // Don't try to propagate if there are too many places to create new
3926 // extends, chances are it'll increase code size.
3927 if (InSrcs.size() > 2)
3928 return false;
3929 break;
3930 default:
3931 return false;
3932 }
3933 }
3934 return true;
3935}
3936
3937void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
3938 MachineInstr *&ExtMI) {
3939 auto &PHI = cast<GPhi>(Val&: MI);
3940 Register DstReg = ExtMI->getOperand(i: 0).getReg();
3941 LLT ExtTy = MRI.getType(Reg: DstReg);
3942
3943 // Propagate the extension into each incoming reg's defining block.
3944 // Use a SetVector here because PHIs can have duplicate edges, and we want
3945 // deterministic iteration order.
3946 SmallSetVector<MachineInstr *, 8> SrcMIs;
3947 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
3948 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
3949 auto SrcReg = PHI.getIncomingValue(I);
3950 auto *SrcMI = MRI.getVRegDef(Reg: SrcReg);
3951 if (!SrcMIs.insert(X: SrcMI))
3952 continue;
3953
3954 // Build an extend after each src inst.
3955 auto *MBB = SrcMI->getParent();
3956 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
3957 if (InsertPt != MBB->end() && InsertPt->isPHI())
3958 InsertPt = MBB->getFirstNonPHI();
3959
3960 Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt);
3961 Builder.setDebugLoc(MI.getDebugLoc());
3962 auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg);
3963 OldToNewSrcMap[SrcMI] = NewExt;
3964 }
3965
3966 // Create a new phi with the extended inputs.
3967 Builder.setInstrAndDebugLoc(MI);
3968 auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI);
3969 NewPhi.addDef(RegNo: DstReg);
3970 for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) {
3971 if (!MO.isReg()) {
3972 NewPhi.addMBB(MBB: MO.getMBB());
3973 continue;
3974 }
3975 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())];
3976 NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg());
3977 }
3978 Builder.insertInstr(MIB: NewPhi);
3979 ExtMI->eraseFromParent();
3980}
3981
3982bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
3983 Register &Reg) {
3984 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
3985 // If we have a constant index, look for a G_BUILD_VECTOR source
3986 // and find the source register that the index maps to.
3987 Register SrcVec = MI.getOperand(i: 1).getReg();
3988 LLT SrcTy = MRI.getType(Reg: SrcVec);
3989
3990 auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI);
3991 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
3992 return false;
3993
3994 unsigned VecIdx = Cst->Value.getZExtValue();
3995
3996 // Check if we have a build_vector or build_vector_trunc with an optional
3997 // trunc in front.
3998 MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec);
3999 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4000 SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg());
4001 }
4002
4003 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4004 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4005 return false;
4006
4007 EVT Ty(getMVTForLLT(Ty: SrcTy));
4008 if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) &&
4009 !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty))
4010 return false;
4011
4012 Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg();
4013 return true;
4014}
4015
4016void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4017 Register &Reg) {
4018 // Check the type of the register, since it may have come from a
4019 // G_BUILD_VECTOR_TRUNC.
4020 LLT ScalarTy = MRI.getType(Reg);
4021 Register DstReg = MI.getOperand(i: 0).getReg();
4022 LLT DstTy = MRI.getType(Reg: DstReg);
4023
4024 Builder.setInstrAndDebugLoc(MI);
4025 if (ScalarTy != DstTy) {
4026 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4027 Builder.buildTrunc(Res: DstReg, Op: Reg);
4028 MI.eraseFromParent();
4029 return;
4030 }
4031 replaceSingleDefInstWithReg(MI, Replacement: Reg);
4032}
4033
4034bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4035 MachineInstr &MI,
4036 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) {
4037 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4038 // This combine tries to find build_vector's which have every source element
4039 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4040 // the masked load scalarization is run late in the pipeline. There's already
4041 // a combine for a similar pattern starting from the extract, but that
4042 // doesn't attempt to do it if there are multiple uses of the build_vector,
4043 // which in this case is true. Starting the combine from the build_vector
4044 // feels more natural than trying to find sibling nodes of extracts.
4045 // E.g.
4046 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4047 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4048 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4049 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4050 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4051 // ==>
4052 // replace ext{1,2,3,4} with %s{1,2,3,4}
4053
4054 Register DstReg = MI.getOperand(i: 0).getReg();
4055 LLT DstTy = MRI.getType(Reg: DstReg);
4056 unsigned NumElts = DstTy.getNumElements();
4057
4058 SmallBitVector ExtractedElts(NumElts);
4059 for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) {
4060 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4061 return false;
4062 auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI);
4063 if (!Cst)
4064 return false;
4065 unsigned Idx = Cst->getZExtValue();
4066 if (Idx >= NumElts)
4067 return false; // Out of range.
4068 ExtractedElts.set(Idx);
4069 SrcDstPairs.emplace_back(
4070 Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II));
4071 }
4072 // Match if every element was extracted.
4073 return ExtractedElts.all();
4074}
4075
4076void CombinerHelper::applyExtractAllEltsFromBuildVector(
4077 MachineInstr &MI,
4078 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) {
4079 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4080 for (auto &Pair : SrcDstPairs) {
4081 auto *ExtMI = Pair.second;
4082 replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first);
4083 ExtMI->eraseFromParent();
4084 }
4085 MI.eraseFromParent();
4086}
4087
4088void CombinerHelper::applyBuildFn(
4089 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4090 Builder.setInstrAndDebugLoc(MI);
4091 MatchInfo(Builder);
4092 MI.eraseFromParent();
4093}
4094
4095void CombinerHelper::applyBuildFnNoErase(
4096 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4097 Builder.setInstrAndDebugLoc(MI);
4098 MatchInfo(Builder);
4099}
4100
4101bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4102 BuildFnTy &MatchInfo) {
4103 assert(MI.getOpcode() == TargetOpcode::G_OR);
4104
4105 Register Dst = MI.getOperand(i: 0).getReg();
4106 LLT Ty = MRI.getType(Reg: Dst);
4107 unsigned BitWidth = Ty.getScalarSizeInBits();
4108
4109 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4110 unsigned FshOpc = 0;
4111
4112 // Match (or (shl ...), (lshr ...)).
4113 if (!mi_match(R: Dst, MRI,
4114 // m_GOr() handles the commuted version as well.
4115 P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)),
4116 R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt)))))
4117 return false;
4118
4119 // Given constants C0 and C1 such that C0 + C1 is bit-width:
4120 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
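 // E.g. for a 32-bit type: (or (shl x, 8), (lshr y, 24)) -> (fshr x, y, 24).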
4121 int64_t CstShlAmt, CstLShrAmt;
4122 if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) &&
4123 mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) &&
4124 CstShlAmt + CstLShrAmt == BitWidth) {
4125 FshOpc = TargetOpcode::G_FSHR;
4126 Amt = LShrAmt;
4127
4128 } else if (mi_match(R: LShrAmt, MRI,
4129 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4130 ShlAmt == Amt) {
4131 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4132 FshOpc = TargetOpcode::G_FSHL;
4133
4134 } else if (mi_match(R: ShlAmt, MRI,
4135 P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) &&
4136 LShrAmt == Amt) {
4137 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4138 FshOpc = TargetOpcode::G_FSHR;
4139
4140 } else {
4141 return false;
4142 }
4143
4144 LLT AmtTy = MRI.getType(Reg: Amt);
4145 if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}}))
4146 return false;
4147
4148 MatchInfo = [=](MachineIRBuilder &B) {
4149 B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt});
4150 };
4151 return true;
4152}
4153
4154/// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
4155bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) {
4156 unsigned Opc = MI.getOpcode();
4157 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4158 Register X = MI.getOperand(i: 1).getReg();
4159 Register Y = MI.getOperand(i: 2).getReg();
4160 if (X != Y)
4161 return false;
4162 unsigned RotateOpc =
4163 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4164 return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}});
4165}
4166
4167void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) {
4168 unsigned Opc = MI.getOpcode();
4169 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4170 bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4171 Observer.changingInstr(MI);
4172 MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL
4173 : TargetOpcode::G_ROTR));
4174 MI.removeOperand(OpNo: 2);
4175 Observer.changedInstr(MI);
4176}
4177
4178// Fold (rot x, c) -> (rot x, c % BitSize)
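// E.g. (rotl x:s32, 33) -> (rotl x:s32, 1).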
4179bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) {
4180 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4181 MI.getOpcode() == TargetOpcode::G_ROTR);
4182 unsigned Bitsize =
4183 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4184 Register AmtReg = MI.getOperand(i: 2).getReg();
4185 bool OutOfRange = false;
4186 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4187 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
4188 OutOfRange |= CI->getValue().uge(RHS: Bitsize);
4189 return true;
4190 };
4191 return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange;
4192}
4193
4194void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) {
4195 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4196 MI.getOpcode() == TargetOpcode::G_ROTR);
4197 unsigned Bitsize =
4198 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits();
4199 Builder.setInstrAndDebugLoc(MI);
4200 Register Amt = MI.getOperand(i: 2).getReg();
4201 LLT AmtTy = MRI.getType(Reg: Amt);
4202 auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize);
4203 Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0);
4204 Observer.changingInstr(MI);
4205 MI.getOperand(i: 2).setReg(Amt);
4206 Observer.changedInstr(MI);
4207}
4208
4209bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4210 int64_t &MatchInfo) {
4211 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4212 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4213 auto KnownLHS = KB->getKnownBits(R: MI.getOperand(i: 2).getReg());
4214 auto KnownRHS = KB->getKnownBits(R: MI.getOperand(i: 3).getReg());
4215 std::optional<bool> KnownVal;
4216 switch (Pred) {
4217 default:
4218 llvm_unreachable("Unexpected G_ICMP predicate?");
4219 case CmpInst::ICMP_EQ:
4220 KnownVal = KnownBits::eq(LHS: KnownLHS, RHS: KnownRHS);
4221 break;
4222 case CmpInst::ICMP_NE:
4223 KnownVal = KnownBits::ne(LHS: KnownLHS, RHS: KnownRHS);
4224 break;
4225 case CmpInst::ICMP_SGE:
4226 KnownVal = KnownBits::sge(LHS: KnownLHS, RHS: KnownRHS);
4227 break;
4228 case CmpInst::ICMP_SGT:
4229 KnownVal = KnownBits::sgt(LHS: KnownLHS, RHS: KnownRHS);
4230 break;
4231 case CmpInst::ICMP_SLE:
4232 KnownVal = KnownBits::sle(LHS: KnownLHS, RHS: KnownRHS);
4233 break;
4234 case CmpInst::ICMP_SLT:
4235 KnownVal = KnownBits::slt(LHS: KnownLHS, RHS: KnownRHS);
4236 break;
4237 case CmpInst::ICMP_UGE:
4238 KnownVal = KnownBits::uge(LHS: KnownLHS, RHS: KnownRHS);
4239 break;
4240 case CmpInst::ICMP_UGT:
4241 KnownVal = KnownBits::ugt(LHS: KnownLHS, RHS: KnownRHS);
4242 break;
4243 case CmpInst::ICMP_ULE:
4244 KnownVal = KnownBits::ule(LHS: KnownLHS, RHS: KnownRHS);
4245 break;
4246 case CmpInst::ICMP_ULT:
4247 KnownVal = KnownBits::ult(LHS: KnownLHS, RHS: KnownRHS);
4248 break;
4249 }
4250 if (!KnownVal)
4251 return false;
4252 MatchInfo =
4253 *KnownVal
4254 ? getICmpTrueVal(TLI: getTargetLowering(),
4255 /*IsVector = */
4256 MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(),
4257 /* IsFP = */ false)
4258 : 0;
4259 return true;
4260}
4261
4262bool CombinerHelper::matchICmpToLHSKnownBits(
4263 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4264 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4265 // Given:
4266 //
4267 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4268 // %cmp = G_ICMP ne %x, 0
4269 //
4270 // Or:
4271 //
4272 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4273 // %cmp = G_ICMP eq %x, 1
4274 //
4275 // We can replace %cmp with %x assuming true is 1 on the target.
4276 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate());
4277 if (!CmpInst::isEquality(pred: Pred))
4278 return false;
4279 Register Dst = MI.getOperand(i: 0).getReg();
4280 LLT DstTy = MRI.getType(Reg: Dst);
4281 if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(),
4282 /* IsFP = */ false) != 1)
4283 return false;
4284 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4285 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero)))
4286 return false;
4287 Register LHS = MI.getOperand(i: 2).getReg();
4288 auto KnownLHS = KB->getKnownBits(R: LHS);
4289 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4290 return false;
4291 // Make sure replacing Dst with the LHS is a legal operation.
4292 LLT LHSTy = MRI.getType(Reg: LHS);
4293 unsigned LHSSize = LHSTy.getSizeInBits();
4294 unsigned DstSize = DstTy.getSizeInBits();
4295 unsigned Op = TargetOpcode::COPY;
4296 if (DstSize != LHSSize)
4297 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4298 if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}}))
4299 return false;
4300 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); };
4301 return true;
4302}
4303
4304// Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
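// E.g. (and (or x, 0xF0), 0x0F) -> (and x, 0x0F).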
4305bool CombinerHelper::matchAndOrDisjointMask(
4306 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4307 assert(MI.getOpcode() == TargetOpcode::G_AND);
4308
4309 // Ignore vector types to simplify matching the two constants.
4310 // TODO: do this for vectors and scalars via a demanded bits analysis.
4311 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
4312 if (Ty.isVector())
4313 return false;
4314
4315 Register Src;
4316 Register AndMaskReg;
4317 int64_t AndMaskBits;
4318 int64_t OrMaskBits;
4319 if (!mi_match(MI, MRI,
4320 P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)),
4321 R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg)))))
4322 return false;
4323
4324 // Check if OrMask could turn on any bits in Src.
4325 if (AndMaskBits & OrMaskBits)
4326 return false;
4327
4328 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4329 Observer.changingInstr(MI);
4330 // Canonicalize the result to have the constant on the RHS.
4331 if (MI.getOperand(i: 1).getReg() == AndMaskReg)
4332 MI.getOperand(i: 2).setReg(AndMaskReg);
4333 MI.getOperand(i: 1).setReg(Src);
4334 Observer.changedInstr(MI);
4335 };
4336 return true;
4337}
4338
4339/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
4340bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4341 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4342 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4343 Register Dst = MI.getOperand(i: 0).getReg();
4344 Register Src = MI.getOperand(i: 1).getReg();
4345 LLT Ty = MRI.getType(Reg: Src);
4346 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4347 if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4348 return false;
4349 int64_t Width = MI.getOperand(i: 2).getImm();
4350 Register ShiftSrc;
4351 int64_t ShiftImm;
4352 if (!mi_match(
4353 R: Src, MRI,
4354 P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)),
4355 preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm))))))
4356 return false;
4357 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4358 return false;
4359
4360 MatchInfo = [=](MachineIRBuilder &B) {
4361 auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm);
4362 auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width);
4363 B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2);
4364 };
4365 return true;
4366}
4367
4368/// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
4369bool CombinerHelper::matchBitfieldExtractFromAnd(
4370 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4371 assert(MI.getOpcode() == TargetOpcode::G_AND);
4372 Register Dst = MI.getOperand(i: 0).getReg();
4373 LLT Ty = MRI.getType(Reg: Dst);
4374 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4375 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4376 return false;
4377
4378 int64_t AndImm, LSBImm;
4379 Register ShiftSrc;
4380 const unsigned Size = Ty.getScalarSizeInBits();
4381 if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI,
4382 P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))),
4383 R: m_ICst(Cst&: AndImm))))
4384 return false;
4385
4386 // The mask is a mask of the low bits iff imm & (imm+1) == 0.
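 // E.g. 0x0F & 0x10 == 0 (a low-bit mask), but 0x1E & 0x1F != 0 (not a mask).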
4387 auto MaybeMask = static_cast<uint64_t>(AndImm);
4388 if (MaybeMask & (MaybeMask + 1))
4389 return false;
4390
4391 // LSB must fit within the register.
4392 if (static_cast<uint64_t>(LSBImm) >= Size)
4393 return false;
4394
4395 uint64_t Width = APInt(Size, AndImm).countr_one();
4396 MatchInfo = [=](MachineIRBuilder &B) {
4397 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4398 auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm);
4399 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst});
4400 };
4401 return true;
4402}
4403
4404bool CombinerHelper::matchBitfieldExtractFromShr(
4405 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4406 const unsigned Opcode = MI.getOpcode();
4407 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4408
4409 const Register Dst = MI.getOperand(i: 0).getReg();
4410
4411 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4412 ? TargetOpcode::G_SBFX
4413 : TargetOpcode::G_UBFX;
4414
4415 // Check if the type we would use for the extract is legal
4416 LLT Ty = MRI.getType(Reg: Dst);
4417 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4418 if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}}))
4419 return false;
4420
4421 Register ShlSrc;
4422 int64_t ShrAmt;
4423 int64_t ShlAmt;
4424 const unsigned Size = Ty.getScalarSizeInBits();
4425
4426 // Try to match shr (shl x, c1), c2
4427 if (!mi_match(R: Dst, MRI,
4428 P: m_BinOp(Opcode,
4429 L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))),
4430 R: m_ICst(Cst&: ShrAmt))))
4431 return false;
4432
4433 // Make sure that the shift sizes can fit a bitfield extract
4434 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4435 return false;
4436
4437 // Skip this combine if the G_SEXT_INREG combine could handle it
4438 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4439 return false;
4440
4441 // Calculate start position and width of the extract
4442 const int64_t Pos = ShrAmt - ShlAmt;
4443 const int64_t Width = Size - ShrAmt;
4444
4445 MatchInfo = [=](MachineIRBuilder &B) {
4446 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4447 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4448 B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst});
4449 };
4450 return true;
4451}
4452
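/// Form a G_UBFX from "(x & mask) >> c" with constant mask and shift amount,
/// provided the shifted mask leaves a contiguous run of low bits. For example,
/// with 32-bit x, (lshr (and x, 0xff0), 4) becomes G_UBFX x, 4, 8.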
4453bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4454 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4455 const unsigned Opcode = MI.getOpcode();
4456 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4457
4458 const Register Dst = MI.getOperand(i: 0).getReg();
4459 LLT Ty = MRI.getType(Reg: Dst);
4460 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4461 if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}}))
4462 return false;
4463
4464 // Try to match shr (and x, c1), c2
4465 Register AndSrc;
4466 int64_t ShrAmt;
4467 int64_t SMask;
4468 if (!mi_match(R: Dst, MRI,
4469 P: m_BinOp(Opcode,
4470 L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))),
4471 R: m_ICst(Cst&: ShrAmt))))
4472 return false;
4473
4474 const unsigned Size = Ty.getScalarSizeInBits();
4475 if (ShrAmt < 0 || ShrAmt >= Size)
4476 return false;
4477
4478 // If the shift subsumes the mask, emit the 0 directly.
4479 if (0 == (SMask >> ShrAmt)) {
4480 MatchInfo = [=](MachineIRBuilder &B) {
4481 B.buildConstant(Res: Dst, Val: 0);
4482 };
4483 return true;
4484 }
4485
4486 // Check that ubfx can do the extraction, with no holes in the mask.
4487 uint64_t UMask = SMask;
4488 UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt);
4489 UMask &= maskTrailingOnes<uint64_t>(N: Size);
4490 if (!isMask_64(Value: UMask))
4491 return false;
4492
4493 // Calculate start position and width of the extract.
4494 const int64_t Pos = ShrAmt;
4495 const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt;
4496
4497 // It's preferable to keep the shift, rather than form G_SBFX.
4498 // TODO: remove the G_AND via demanded bits analysis.
4499 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
4500 return false;
4501
4502 MatchInfo = [=](MachineIRBuilder &B) {
4503 auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width);
4504 auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos);
4505 B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst});
4506 };
4507 return true;
4508}
4509
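/// Return true if reassociating G_PTR_ADD(G_PTR_ADD(Base, C1), C2) into
/// G_PTR_ADD(Base, C1 + C2) would turn a load/store offset that the target can
/// currently fold into its addressing mode (C2) into one that it cannot
/// (C1 + C2).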
4510bool CombinerHelper::reassociationCanBreakAddressingModePattern(
4511 MachineInstr &MI) {
4512 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
4513
4514 Register Src1Reg = PtrAdd.getBaseReg();
4515 auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI);
4516 if (!Src1Def)
4517 return false;
4518
4519 Register Src2Reg = PtrAdd.getOffsetReg();
4520
4521 if (MRI.hasOneNonDBGUse(RegNo: Src1Reg))
4522 return false;
4523
4524 auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI);
4525 if (!C1)
4526 return false;
4527 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
4528 if (!C2)
4529 return false;
4530
4531 const APInt &C1APIntVal = *C1;
4532 const APInt &C2APIntVal = *C2;
4533 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
4534
4535 for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) {
4536 // This combine may end up running before ptrtoint/inttoptr combines
4537 // manage to eliminate redundant conversions, so try to look through them.
4538 MachineInstr *ConvUseMI = &UseMI;
4539 unsigned ConvUseOpc = ConvUseMI->getOpcode();
4540 while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
4541 ConvUseOpc == TargetOpcode::G_PTRTOINT) {
4542 Register DefReg = ConvUseMI->getOperand(i: 0).getReg();
4543 if (!MRI.hasOneNonDBGUse(RegNo: DefReg))
4544 break;
4545 ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg);
4546 ConvUseOpc = ConvUseMI->getOpcode();
4547 }
4548 auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI);
4549 if (!LdStMI)
4550 continue;
4551 // Is x[offset2] already not a legal addressing mode? If so then
4552 // reassociating the constants breaks nothing (we test offset2 because
4553 // that's the one we hope to fold into the load or store).
4554 TargetLoweringBase::AddrMode AM;
4555 AM.HasBaseReg = true;
4556 AM.BaseOffs = C2APIntVal.getSExtValue();
4557 unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace();
4558 Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(),
4559 C&: PtrAdd.getMF()->getFunction().getContext());
4560 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
4561 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
4562 Ty: AccessTy, AddrSpace: AS))
4563 continue;
4564
4565 // Would x[offset1+offset2] still be a legal addressing mode?
4566 AM.BaseOffs = CombinedValue;
4567 if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM,
4568 Ty: AccessTy, AddrSpace: AS))
4569 return true;
4570 }
4571
4572 return false;
4573}
4574
4575bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
4576 MachineInstr *RHS,
4577 BuildFnTy &MatchInfo) {
4578 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4579 Register Src1Reg = MI.getOperand(i: 1).getReg();
4580 if (RHS->getOpcode() != TargetOpcode::G_ADD)
4581 return false;
4582 auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI);
4583 if (!C2)
4584 return false;
4585
4586 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4587 LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
4588
4589 auto NewBase =
4590 Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg());
4591 Observer.changingInstr(MI);
4592 MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0));
4593 MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg());
4594 Observer.changedInstr(MI);
4595 };
4596 return !reassociationCanBreakAddressingModePattern(MI);
4597}
4598
4599bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
4600 MachineInstr *LHS,
4601 MachineInstr *RHS,
4602 BuildFnTy &MatchInfo) {
4603 // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4604 // if and only if (G_PTR_ADD X, C) has one use.
4605 Register LHSBase;
4606 std::optional<ValueAndVReg> LHSCstOff;
4607 if (!mi_match(R: MI.getBaseReg(), MRI,
4608 P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff)))))
4609 return false;
4610
4611 auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS);
4612 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4613 // When we change LHSPtrAdd's offset register we might cause it to use a reg
4614 // before its def. Sink the instruction to just before the outer PTR_ADD so
4615 // this doesn't happen.
4616 LHSPtrAdd->moveBefore(MovePos: &MI);
4617 Register RHSReg = MI.getOffsetReg();
4618 // Build a fresh constant; reusing LHSCstOff's vreg directly could give a type mismatch if it came through an extend/trunc.
4619 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value);
4620 Observer.changingInstr(MI);
4621 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
4622 Observer.changedInstr(MI);
4623 Observer.changingInstr(MI&: *LHSPtrAdd);
4624 LHSPtrAdd->getOperand(i: 2).setReg(RHSReg);
4625 Observer.changedInstr(MI&: *LHSPtrAdd);
4626 };
4627 return !reassociationCanBreakAddressingModePattern(MI);
4628}
4629
4630bool CombinerHelper::matchReassocFoldConstantsInSubTree(GPtrAdd &MI,
4631 MachineInstr *LHS,
4632 MachineInstr *RHS,
4633 BuildFnTy &MatchInfo) {
4634 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4635 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS);
4636 if (!LHSPtrAdd)
4637 return false;
4638
4639 Register Src2Reg = MI.getOperand(i: 2).getReg();
4640 Register LHSSrc1 = LHSPtrAdd->getBaseReg();
4641 Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
4642 auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI);
4643 if (!C1)
4644 return false;
4645 auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI);
4646 if (!C2)
4647 return false;
4648
4649 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4650 auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2);
4651 Observer.changingInstr(MI);
4652 MI.getOperand(i: 1).setReg(LHSSrc1);
4653 MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0));
4654 Observer.changedInstr(MI);
4655 };
4656 return !reassociationCanBreakAddressingModePattern(MI);
4657}
4658
4659bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
4660 BuildFnTy &MatchInfo) {
4661 auto &PtrAdd = cast<GPtrAdd>(Val&: MI);
4662 // We're trying to match a few pointer computation patterns here for
4663 // re-association opportunities.
4664 // 1) Isolating a constant operand to be on the RHS, e.g.:
4665 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4666 //
4667 // 2) Folding two constants in each sub-tree as long as such folding
4668 // doesn't break a legal addressing mode.
4669 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4670 //
4671 // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
4672 // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4673 // iff (G_PTR_ADD X, C) has one use.
4674 MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg());
4675 MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg());
4676
4677 // Try to match example 2.
4678 if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo))
4679 return true;
4680
4681 // Try to match example 3.
4682 if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo))
4683 return true;
4684
4685 // Try to match example 1.
4686 if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo))
4687 return true;
4688
4689 return false;
4690}
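
/// Reassociate a binop whose LHS is another instance of the same opcode so
/// that constants end up together on the outer RHS. For example,
/// (add (add x, 10), 20) -> (add x, (add 10, 20)), which a later constant-fold
/// combine can reduce to (add x, 30). When the outer RHS is not a constant and
/// the target considers it profitable, (op (op x, c1), y) is instead rewritten
/// as (op (op x, y), c1).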
4691bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
4692 Register OpLHS, Register OpRHS,
4693 BuildFnTy &MatchInfo) {
4694 LLT OpRHSTy = MRI.getType(Reg: OpRHS);
4695 MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS);
4696
4697 if (OpLHSDef->getOpcode() != Opc)
4698 return false;
4699
4700 MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS);
4701 Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg();
4702 Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg();
4703
4704 // If the inner op is (X op C), pull the constant out so it can be folded with
4705 // other constants in the expression tree. Folding is not guaranteed so we
4706 // might have (C1 op C2). In that case do not pull a constant out because it
4707 // won't help and can lead to infinite loops.
4708 if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) &&
4709 !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) {
4710 if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) {
4711 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2))
4712 MatchInfo = [=](MachineIRBuilder &B) {
4713 auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS});
4714 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst});
4715 };
4716 return true;
4717 }
4718 if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) {
4719 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
4720 // iff (op x, c1) has one use
4721 MatchInfo = [=](MachineIRBuilder &B) {
4722 auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS});
4723 B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS});
4724 };
4725 return true;
4726 }
4727 }
4728
4729 return false;
4730}
4731
4732bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI,
4733 BuildFnTy &MatchInfo) {
4734 // We don't check if the reassociation will break a legal addressing mode
4735 // here since pointer arithmetic is handled by G_PTR_ADD.
4736 unsigned Opc = MI.getOpcode();
4737 Register DstReg = MI.getOperand(i: 0).getReg();
4738 Register LHSReg = MI.getOperand(i: 1).getReg();
4739 Register RHSReg = MI.getOperand(i: 2).getReg();
4740
4741 if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo))
4742 return true;
4743 if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo))
4744 return true;
4745 return false;
4746}
4747
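/// Constant fold a cast-like operation (e.g. G_SEXT/G_ZEXT of a constant) into
/// an APInt, returned via \p MatchInfo.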
4748bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI, APInt &MatchInfo) {
4749 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
4750 Register SrcOp = MI.getOperand(i: 1).getReg();
4751
4752 if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) {
4753 MatchInfo = *MaybeCst;
4754 return true;
4755 }
4756
4757 return false;
4758}
4759
4760bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI, APInt &MatchInfo) {
4761 Register Op1 = MI.getOperand(i: 1).getReg();
4762 Register Op2 = MI.getOperand(i: 2).getReg();
4763 auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
4764 if (!MaybeCst)
4765 return false;
4766 MatchInfo = *MaybeCst;
4767 return true;
4768}
4769
4770bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, ConstantFP *&MatchInfo) {
4771 Register Op1 = MI.getOperand(i: 1).getReg();
4772 Register Op2 = MI.getOperand(i: 2).getReg();
4773 auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI);
4774 if (!MaybeCst)
4775 return false;
4776 MatchInfo =
4777 ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst);
4778 return true;
4779}
4780
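/// Constant fold G_FMA/G_FMAD when all three source operands are FP constants,
/// evaluating the fused multiply-add with round-to-nearest-even.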
4781bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI,
4782 ConstantFP *&MatchInfo) {
4783 assert(MI.getOpcode() == TargetOpcode::G_FMA ||
4784 MI.getOpcode() == TargetOpcode::G_FMAD);
4785 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs();
4786
4787 const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI);
4788 if (!Op3Cst)
4789 return false;
4790
4791 const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI);
4792 if (!Op2Cst)
4793 return false;
4794
4795 const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI);
4796 if (!Op1Cst)
4797 return false;
4798
4799 APFloat Op1F = Op1Cst->getValueAPF();
4800 Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(),
4801 RM: APFloat::rmNearestTiesToEven);
4802 MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F);
4803 return true;
4804}
4805
4806bool CombinerHelper::matchNarrowBinopFeedingAnd(
4807 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4808 // Look for a binop feeding into an AND with a mask:
4809 //
4810 // %add = G_ADD %lhs, %rhs
4811 // %and = G_AND %add, 000...11111111
4812 //
4813 // Check if it's possible to perform the binop at a narrower width and zext
4814 // back to the original width like so:
4815 //
4816 // %narrow_lhs = G_TRUNC %lhs
4817 // %narrow_rhs = G_TRUNC %rhs
4818 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
4819 // %new_add = G_ZEXT %narrow_add
4820 // %and = G_AND %new_add, 000...11111111
4821 //
4822 // This can allow later combines to eliminate the G_AND if it turns out
4823 // that the mask is irrelevant.
4824 assert(MI.getOpcode() == TargetOpcode::G_AND);
4825 Register Dst = MI.getOperand(i: 0).getReg();
4826 Register AndLHS = MI.getOperand(i: 1).getReg();
4827 Register AndRHS = MI.getOperand(i: 2).getReg();
4828 LLT WideTy = MRI.getType(Reg: Dst);
4829
4830 // If the potential binop has more than one use, then it's possible that one
4831 // of those uses will need its full width.
4832 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS))
4833 return false;
4834
4835 // Check that the operation feeding the AND computes the bits we keep
4836 // independently of the high bits that we're masking out.
4837 //
4838 // e.g. for 64-bit x, y:
4839 //
4840 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
4841 MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI);
4842 if (!LHSInst)
4843 return false;
4844 unsigned LHSOpc = LHSInst->getOpcode();
4845 switch (LHSOpc) {
4846 default:
4847 return false;
4848 case TargetOpcode::G_ADD:
4849 case TargetOpcode::G_SUB:
4850 case TargetOpcode::G_MUL:
4851 case TargetOpcode::G_AND:
4852 case TargetOpcode::G_OR:
4853 case TargetOpcode::G_XOR:
4854 break;
4855 }
4856
4857 // Find the mask on the RHS.
4858 auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI);
4859 if (!Cst)
4860 return false;
4861 auto Mask = Cst->Value;
4862 if (!Mask.isMask())
4863 return false;
4864
4865 // No point in combining if there's nothing to truncate.
4866 unsigned NarrowWidth = Mask.countr_one();
4867 if (NarrowWidth == WideTy.getSizeInBits())
4868 return false;
4869 LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth);
4870
4871 // Check if adding the zext + truncates could be harmful.
4872 auto &MF = *MI.getMF();
4873 const auto &TLI = getTargetLowering();
4874 LLVMContext &Ctx = MF.getFunction().getContext();
4875 auto &DL = MF.getDataLayout();
4876 if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, DL, Ctx) ||
4877 !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, DL, Ctx))
4878 return false;
4879 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
4880 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
4881 return false;
4882 Register BinOpLHS = LHSInst->getOperand(i: 1).getReg();
4883 Register BinOpRHS = LHSInst->getOperand(i: 2).getReg();
4884 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4885 auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS);
4886 auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS);
4887 auto NarrowBinOp =
4888 Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS});
4889 auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp);
4890 Observer.changingInstr(MI);
4891 MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0));
4892 Observer.changedInstr(MI);
4893 };
4894 return true;
4895}
4896
4897bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) {
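// (G_*MULO x, 2) -> (G_*ADDO x, x)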
4898 unsigned Opc = MI.getOpcode();
4899 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
4900
4901 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2)))
4902 return false;
4903
4904 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4905 Observer.changingInstr(MI);
4906 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
4907 : TargetOpcode::G_SADDO;
4908 MI.setDesc(Builder.getTII().get(Opcode: NewOpc));
4909 MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg());
4910 Observer.changedInstr(MI);
4911 };
4912 return true;
4913}
4914
4915bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4916 // (G_*MULO x, 0) -> 0 + no carry out
4917 assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
4918 MI.getOpcode() == TargetOpcode::G_SMULO);
4919 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
4920 return false;
4921 Register Dst = MI.getOperand(i: 0).getReg();
4922 Register Carry = MI.getOperand(i: 1).getReg();
4923 if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) ||
4924 !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry)))
4925 return false;
4926 MatchInfo = [=](MachineIRBuilder &B) {
4927 B.buildConstant(Res: Dst, Val: 0);
4928 B.buildConstant(Res: Carry, Val: 0);
4929 };
4930 return true;
4931}
4932
4933bool CombinerHelper::matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4934 // (G_*ADDO x, 0) -> x + no carry out
4935 assert(MI.getOpcode() == TargetOpcode::G_UADDO ||
4936 MI.getOpcode() == TargetOpcode::G_SADDO);
4937 if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
4938 return false;
4939 Register Carry = MI.getOperand(i: 1).getReg();
4940 if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry)))
4941 return false;
4942 Register Dst = MI.getOperand(i: 0).getReg();
4943 Register LHS = MI.getOperand(i: 2).getReg();
4944 MatchInfo = [=](MachineIRBuilder &B) {
4945 B.buildCopy(Res: Dst, Op: LHS);
4946 B.buildConstant(Res: Carry, Val: 0);
4947 };
4948 return true;
4949}
4950
4951bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) {
4952 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
4953 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
4954 assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
4955 MI.getOpcode() == TargetOpcode::G_SADDE ||
4956 MI.getOpcode() == TargetOpcode::G_USUBE ||
4957 MI.getOpcode() == TargetOpcode::G_SSUBE);
4958 if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0)))
4959 return false;
4960 MatchInfo = [&](MachineIRBuilder &B) {
4961 unsigned NewOpcode;
4962 switch (MI.getOpcode()) {
4963 case TargetOpcode::G_UADDE:
4964 NewOpcode = TargetOpcode::G_UADDO;
4965 break;
4966 case TargetOpcode::G_SADDE:
4967 NewOpcode = TargetOpcode::G_SADDO;
4968 break;
4969 case TargetOpcode::G_USUBE:
4970 NewOpcode = TargetOpcode::G_USUBO;
4971 break;
4972 case TargetOpcode::G_SSUBE:
4973 NewOpcode = TargetOpcode::G_SSUBO;
4974 break;
4975 }
4976 Observer.changingInstr(MI);
4977 MI.setDesc(B.getTII().get(Opcode: NewOpcode));
4978 MI.removeOperand(OpNo: 4);
4979 Observer.changedInstr(MI);
4980 };
4981 return true;
4982}
4983
4984bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
4985 BuildFnTy &MatchInfo) {
4986 assert(MI.getOpcode() == TargetOpcode::G_SUB);
4987 Register Dst = MI.getOperand(i: 0).getReg();
4988 // (x + y) - z -> x (if y == z)
4989 // (x + y) - z -> y (if x == z)
4990 Register X, Y, Z;
4991 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) {
4992 Register ReplaceReg;
4993 int64_t CstX, CstY;
4994 if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) &&
4995 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY))))
4996 ReplaceReg = X;
4997 else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
4998 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
4999 ReplaceReg = Y;
5000 if (ReplaceReg) {
5001 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); };
5002 return true;
5003 }
5004 }
5005
5006 // x - (y + z) -> 0 - y (if x == z)
5007 // x - (y + z) -> 0 - z (if x == y)
5008 if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) {
5009 Register ReplaceReg;
5010 int64_t CstX;
5011 if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5012 mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5013 ReplaceReg = Y;
5014 else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) &&
5015 mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX))))
5016 ReplaceReg = Z;
5017 if (ReplaceReg) {
5018 MatchInfo = [=](MachineIRBuilder &B) {
5019 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0);
5020 B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg);
5021 };
5022 return true;
5023 }
5024 }
5025 return false;
5026}
5027
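/// Build a G_UDIV-by-constant expansion using the standard magic-multiply
/// technique: multiply by a precomputed reciprocal with G_UMULH, with optional
/// pre/post shifts and an add-based fixup (NPQ) when the magic constant needs
/// the extra bit, plus a final select so that division by one still works.
/// For example, for 32-bit x, x / 3 becomes roughly
/// (G_UMULH x, 0xAAAAAAAB) >> 1.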
5028MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) {
5029 assert(MI.getOpcode() == TargetOpcode::G_UDIV);
5030 auto &UDiv = cast<GenericMachineInstr>(Val&: MI);
5031 Register Dst = UDiv.getReg(Idx: 0);
5032 Register LHS = UDiv.getReg(Idx: 1);
5033 Register RHS = UDiv.getReg(Idx: 2);
5034 LLT Ty = MRI.getType(Reg: Dst);
5035 LLT ScalarTy = Ty.getScalarType();
5036 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
5037 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5038 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5039 auto &MIB = Builder;
5040 MIB.setInstrAndDebugLoc(MI);
5041
5042 bool UseNPQ = false;
5043 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
5044
5045 auto BuildUDIVPattern = [&](const Constant *C) {
5046 auto *CI = cast<ConstantInt>(Val: C);
5047 const APInt &Divisor = CI->getValue();
5048
5049 bool SelNPQ = false;
5050 APInt Magic(Divisor.getBitWidth(), 0);
5051 unsigned PreShift = 0, PostShift = 0;
5052
5053 // Magic algorithm doesn't work for division by 1. We need to emit a select
5054 // at the end.
5055 // TODO: Use undef values for divisor of 1.
5056 if (!Divisor.isOne()) {
5057 UnsignedDivisionByConstantInfo magics =
5058 UnsignedDivisionByConstantInfo::get(D: Divisor);
5059
5060 Magic = std::move(magics.Magic);
5061
5062 assert(magics.PreShift < Divisor.getBitWidth() &&
5063 "We shouldn't generate an undefined shift!");
5064 assert(magics.PostShift < Divisor.getBitWidth() &&
5065 "We shouldn't generate an undefined shift!");
5066 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
5067 PreShift = magics.PreShift;
5068 PostShift = magics.PostShift;
5069 SelNPQ = magics.IsAdd;
5070 }
5071
5072 PreShifts.push_back(
5073 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0));
5074 MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0));
5075 NPQFactors.push_back(
5076 Elt: MIB.buildConstant(Res: ScalarTy,
5077 Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1)
5078 : APInt::getZero(numBits: EltBits))
5079 .getReg(Idx: 0));
5080 PostShifts.push_back(
5081 Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0));
5082 UseNPQ |= SelNPQ;
5083 return true;
5084 };
5085
5086 // Collect the shifts/magic values from each element.
5087 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern);
5088 (void)Matched;
5089 assert(Matched && "Expected unary predicate match to succeed");
5090
5091 Register PreShift, PostShift, MagicFactor, NPQFactor;
5092 auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI);
5093 if (RHSDef) {
5094 PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0);
5095 MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0);
5096 NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0);
5097 PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0);
5098 } else {
5099 assert(MRI.getType(RHS).isScalar() &&
5100 "Non-build_vector operation should have been a scalar");
5101 PreShift = PreShifts[0];
5102 MagicFactor = MagicFactors[0];
5103 PostShift = PostShifts[0];
5104 }
5105
5106 Register Q = LHS;
5107 Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0);
5108
5109 // Multiply the (pre-shifted) numerator by the magic value.
5110 Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0);
5111
5112 if (UseNPQ) {
5113 Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0);
5114
5115 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
5116 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
5117 if (Ty.isVector())
5118 NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0);
5119 else
5120 NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0);
5121
5122 Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0);
5123 }
5124
5125 Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0);
5126 auto One = MIB.buildConstant(Res: Ty, Val: 1);
5127 auto IsOne = MIB.buildICmp(
5128 Pred: CmpInst::Predicate::ICMP_EQ,
5129 Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One);
5130 return MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q);
5131}
5132
5133bool CombinerHelper::matchUDivByConst(MachineInstr &MI) {
5134 assert(MI.getOpcode() == TargetOpcode::G_UDIV);
5135 Register Dst = MI.getOperand(i: 0).getReg();
5136 Register RHS = MI.getOperand(i: 2).getReg();
5137 LLT DstTy = MRI.getType(Reg: Dst);
5138 auto *RHSDef = MRI.getVRegDef(Reg: RHS);
5139 if (!isConstantOrConstantVector(MI&: *RHSDef, MRI))
5140 return false;
5141
5142 auto &MF = *MI.getMF();
5143 AttributeList Attr = MF.getFunction().getAttributes();
5144 const auto &TLI = getTargetLowering();
5145 LLVMContext &Ctx = MF.getFunction().getContext();
5146 auto &DL = MF.getDataLayout();
5147 if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr))
5148 return false;
5149
5150 // Don't do this for minsize because the instruction sequence is usually
5151 // larger.
5152 if (MF.getFunction().hasMinSize())
5153 return false;
5154
5155 // Don't do this if the types are not going to be legal.
5156 if (LI) {
5157 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}}))
5158 return false;
5159 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}}))
5160 return false;
5161 if (!isLegalOrBeforeLegalizer(
5162 Query: {TargetOpcode::G_ICMP,
5163 {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1),
5164 DstTy}}))
5165 return false;
5166 }
5167
5168 auto CheckEltValue = [&](const Constant *C) {
5169 if (auto *CI = dyn_cast_or_null<ConstantInt>(Val: C))
5170 return !CI->isZero();
5171 return false;
5172 };
5173 return matchUnaryPredicate(MRI, Reg: RHS, Match: CheckEltValue);
5174}
5175
5176void CombinerHelper::applyUDivByConst(MachineInstr &MI) {
5177 auto *NewMI = buildUDivUsingMul(MI);
5178 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5179}
5180
5181bool CombinerHelper::matchSDivByConst(MachineInstr &MI) {
5182 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5183 Register Dst = MI.getOperand(i: 0).getReg();
5184 Register RHS = MI.getOperand(i: 2).getReg();
5185 LLT DstTy = MRI.getType(Reg: Dst);
5186
5187 auto &MF = *MI.getMF();
5188 AttributeList Attr = MF.getFunction().getAttributes();
5189 const auto &TLI = getTargetLowering();
5190 LLVMContext &Ctx = MF.getFunction().getContext();
5191 auto &DL = MF.getDataLayout();
5192 if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr))
5193 return false;
5194
5195 // Don't do this for minsize because the instruction sequence is usually
5196 // larger.
5197 if (MF.getFunction().hasMinSize())
5198 return false;
5199
5200 // If the sdiv has an 'exact' flag we can use a simpler lowering.
5201 if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) {
5202 return matchUnaryPredicate(
5203 MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isZeroValue(); });
5204 }
5205
5206 // Don't support the general case for now.
5207 return false;
5208}
5209
5210void CombinerHelper::applySDivByConst(MachineInstr &MI) {
5211 auto *NewMI = buildSDivUsingMul(MI);
5212 replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg());
5213}
5214
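/// Build the expansion for an exact G_SDIV by a constant: shift out the
/// power-of-two factor with an exact G_ASHR, then multiply by the
/// multiplicative inverse of the remaining odd factor modulo 2^W. For example,
/// an exact 32-bit sdiv by 6 becomes (ashr exact x, 1) multiplied by
/// 0xAAAAAAAB (the inverse of 3 mod 2^32).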
5215MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
5216 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5217 auto &SDiv = cast<GenericMachineInstr>(Val&: MI);
5218 Register Dst = SDiv.getReg(Idx: 0);
5219 Register LHS = SDiv.getReg(Idx: 1);
5220 Register RHS = SDiv.getReg(Idx: 2);
5221 LLT Ty = MRI.getType(Reg: Dst);
5222 LLT ScalarTy = Ty.getScalarType();
5223 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5224 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5225 auto &MIB = Builder;
5226 MIB.setInstrAndDebugLoc(MI);
5227
5228 bool UseSRA = false;
5229 SmallVector<Register, 16> Shifts, Factors;
5230
5231 auto *RHSDef = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI));
5232 bool IsSplat = getIConstantSplatVal(MI: *RHSDef, MRI).has_value();
5233
5234 auto BuildSDIVPattern = [&](const Constant *C) {
5235 // Don't recompute inverses for each splat element.
5236 if (IsSplat && !Factors.empty()) {
5237 Shifts.push_back(Elt: Shifts[0]);
5238 Factors.push_back(Elt: Factors[0]);
5239 return true;
5240 }
5241
5242 auto *CI = cast<ConstantInt>(Val: C);
5243 APInt Divisor = CI->getValue();
5244 unsigned Shift = Divisor.countr_zero();
5245 if (Shift) {
5246 Divisor.ashrInPlace(ShiftAmt: Shift);
5247 UseSRA = true;
5248 }
5249
5250 // Calculate the multiplicative inverse modulo 2^W.
5251 // 2^W requires W + 1 bits, so we have to extend and then truncate.
5252 unsigned W = Divisor.getBitWidth();
5253 APInt Factor = Divisor.zext(width: W + 1)
5254 .multiplicativeInverse(modulo: APInt::getSignedMinValue(numBits: W + 1))
5255 .trunc(width: W);
5256 Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0));
5257 Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0));
5258 return true;
5259 };
5260
5261 // Collect all magic values from the build vector.
5262 bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern);
5263 (void)Matched;
5264 assert(Matched && "Expected unary predicate match to succeed");
5265
5266 Register Shift, Factor;
5267 if (Ty.isVector()) {
5268 Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0);
5269 Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0);
5270 } else {
5271 Shift = Shifts[0];
5272 Factor = Factors[0];
5273 }
5274
5275 Register Res = LHS;
5276
5277 if (UseSRA)
5278 Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0);
5279
5280 return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor);
5281}
5282
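/// Match a G_UMULH whose RHS is a power of two greater than one. The high half
/// of x * 2^k is simply x >> (BitWidth - k), so the multiply can be rewritten
/// as a logical shift right (see applyUMulHToLShr).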
5283bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
5284 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
5285 Register RHS = MI.getOperand(i: 2).getReg();
5286 Register Dst = MI.getOperand(i: 0).getReg();
5287 LLT Ty = MRI.getType(Reg: Dst);
5288 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5289 auto MatchPow2ExceptOne = [&](const Constant *C) {
5290 if (auto *CI = dyn_cast<ConstantInt>(Val: C))
5291 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
5292 return false;
5293 };
5294 if (!matchUnaryPredicate(MRI, Reg: RHS, Match: MatchPow2ExceptOne, AllowUndefs: false))
5295 return false;
5296 return isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}});
5297}
5298
5299void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) {
5300 Register LHS = MI.getOperand(i: 1).getReg();
5301 Register RHS = MI.getOperand(i: 2).getReg();
5302 Register Dst = MI.getOperand(i: 0).getReg();
5303 LLT Ty = MRI.getType(Reg: Dst);
5304 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
5305 unsigned NumEltBits = Ty.getScalarSizeInBits();
5306
5307 Builder.setInstrAndDebugLoc(MI);
5308 auto LogBase2 = buildLogBase2(V: RHS, MIB&: Builder);
5309 auto ShiftAmt =
5310 Builder.buildSub(Dst: Ty, Src0: Builder.buildConstant(Res: Ty, Val: NumEltBits), Src1: LogBase2);
5311 auto Trunc = Builder.buildZExtOrTrunc(Res: ShiftAmtTy, Op: ShiftAmt);
5312 Builder.buildLShr(Dst, Src0: LHS, Src1: Trunc);
5313 MI.eraseFromParent();
5314}
5315
5316bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
5317 BuildFnTy &MatchInfo) {
5318 unsigned Opc = MI.getOpcode();
5319 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
5320 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5321 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
5322
5323 Register Dst = MI.getOperand(i: 0).getReg();
5324 Register X = MI.getOperand(i: 1).getReg();
5325 Register Y = MI.getOperand(i: 2).getReg();
5326 LLT Type = MRI.getType(Reg: Dst);
5327
5328 // fold (fadd x, fneg(y)) -> (fsub x, y)
5329 // fold (fadd fneg(y), x) -> (fsub x, y)
5330 // G_FADD is commutative, so both cases are checked by m_GFAdd.
5331 if (mi_match(R: Dst, MRI, P: m_GFAdd(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
5332 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FSUB, {Type}})) {
5333 Opc = TargetOpcode::G_FSUB;
5334 }
5335 // fold (fsub x, fneg(y)) -> (fadd x, y)
5336 else if (mi_match(R: Dst, MRI, P: m_GFSub(L: m_Reg(R&: X), R: m_GFNeg(Src: m_Reg(R&: Y)))) &&
5337 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FADD, {Type}})) {
5338 Opc = TargetOpcode::G_FADD;
5339 }
5340 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
5341 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
5342 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
5343 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
5344 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5345 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
5346 mi_match(R: X, MRI, P: m_GFNeg(Src: m_Reg(R&: X))) &&
5347 mi_match(R: Y, MRI, P: m_GFNeg(Src: m_Reg(R&: Y)))) {
5348 // no opcode change
5349 } else
5350 return false;
5351
5352 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5353 Observer.changingInstr(MI);
5354 MI.setDesc(B.getTII().get(Opcode: Opc));
5355 MI.getOperand(i: 1).setReg(X);
5356 MI.getOperand(i: 2).setReg(Y);
5357 Observer.changedInstr(MI);
5358 };
5359 return true;
5360}
5361
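/// Match (fsub -0.0, x), or (fsub +0.0, x) when the nsz flag is set, so it can
/// be rewritten as a negation of x; the register to negate is returned in
/// \p MatchInfo.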
5362bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
5363 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5364
5365 Register LHS = MI.getOperand(i: 1).getReg();
5366 MatchInfo = MI.getOperand(i: 2).getReg();
5367 LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5368
5369 const auto LHSCst = Ty.isVector()
5370 ? getFConstantSplat(VReg: LHS, MRI, /* allowUndef */ AllowUndef: true)
5371 : getFConstantVRegValWithLookThrough(VReg: LHS, MRI);
5372 if (!LHSCst)
5373 return false;
5374
5375 // -0.0 is always allowed
5376 if (LHSCst->Value.isNegZero())
5377 return true;
5378
5379 // +0.0 is only allowed if nsz is set.
5380 if (LHSCst->Value.isPosZero())
5381 return MI.getFlag(Flag: MachineInstr::FmNsz);
5382
5383 return false;
5384}
5385
5386void CombinerHelper::applyFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
5387 Builder.setInstrAndDebugLoc(MI);
5388 Register Dst = MI.getOperand(i: 0).getReg();
5389 Builder.buildFNeg(
5390 Dst, Src0: Builder.buildFCanonicalize(Dst: MRI.getType(Reg: Dst), Src0: MatchInfo).getReg(Idx: 0));
5391 eraseInst(MI);
5392}
5393
5394/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
5395/// due to global flags or MachineInstr flags.
5396static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
5397 if (MI.getOpcode() != TargetOpcode::G_FMUL)
5398 return false;
5399 return AllowFusionGlobally || MI.getFlag(Flag: MachineInstr::MIFlag::FmContract);
5400}
5401
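/// \returns true if the value defined by \p MI0 has more non-debug uses than
/// the value defined by \p MI1.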
5402static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
5403 const MachineRegisterInfo &MRI) {
5404 return std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI0.getOperand(i: 0).getReg()),
5405 last: MRI.use_instr_nodbg_end()) >
5406 std::distance(first: MRI.use_instr_nodbg_begin(RegNo: MI1.getOperand(i: 0).getReg()),
5407 last: MRI.use_instr_nodbg_end());
5408}
5409
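/// Decide whether a fused multiply-add may be formed for \p MI, reporting via
/// the out-parameters whether fusion is allowed globally, whether G_FMAD is
/// available as the fused opcode, and whether the target enables aggressive
/// FMA fusion for the destination type.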
5410bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
5411 bool &AllowFusionGlobally,
5412 bool &HasFMAD, bool &Aggressive,
5413 bool CanReassociate) {
5414
5415 auto *MF = MI.getMF();
5416 const auto &TLI = *MF->getSubtarget().getTargetLowering();
5417 const TargetOptions &Options = MF->getTarget().Options;
5418 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5419
5420 if (CanReassociate &&
5421 !(Options.UnsafeFPMath || MI.getFlag(Flag: MachineInstr::MIFlag::FmReassoc)))
5422 return false;
5423
5424 // Floating-point multiply-add with intermediate rounding.
5425 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, Ty: DstType));
5426 // Floating-point multiply-add without intermediate rounding.
5427 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(MF: *MF, DstType) &&
5428 isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_FMA, {DstType}});
5429 // No valid opcode, do not combine.
5430 if (!HasFMAD && !HasFMA)
5431 return false;
5432
5433 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast ||
5434 Options.UnsafeFPMath || HasFMAD;
5435 // If the addition is not contractable, do not combine.
5436 if (!AllowFusionGlobally && !MI.getFlag(Flag: MachineInstr::MIFlag::FmContract))
5437 return false;
5438
5439 Aggressive = TLI.enableAggressiveFMAFusion(Ty: DstType);
5440 return true;
5441}
5442
5443bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
5444 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5445 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5446
5447 bool AllowFusionGlobally, HasFMAD, Aggressive;
5448 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5449 return false;
5450
5451 Register Op1 = MI.getOperand(i: 1).getReg();
5452 Register Op2 = MI.getOperand(i: 2).getReg();
5453 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
5454 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
5455 unsigned PreferredFusedOpcode =
5456 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5457
5458 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5459 // prefer to fold the multiply with fewer uses.
5460 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5461 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
5462 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
5463 std::swap(a&: LHS, b&: RHS);
5464 }
5465
5466 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
5467 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5468 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg))) {
5469 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5470 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5471 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
5472 LHS.MI->getOperand(i: 2).getReg(), RHS.Reg});
5473 };
5474 return true;
5475 }
5476
5477 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
5478 if (isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
5479 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg))) {
5480 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5481 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5482 SrcOps: {RHS.MI->getOperand(i: 1).getReg(),
5483 RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
5484 };
5485 return true;
5486 }
5487
5488 return false;
5489}
5490
5491bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
5492 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5493 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5494
5495 bool AllowFusionGlobally, HasFMAD, Aggressive;
5496 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5497 return false;
5498
5499 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5500 Register Op1 = MI.getOperand(i: 1).getReg();
5501 Register Op2 = MI.getOperand(i: 2).getReg();
5502 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
5503 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
5504 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5505
5506 unsigned PreferredFusedOpcode =
5507 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5508
5509 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5510 // prefer to fold the multiply with fewer uses.
5511 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5512 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
5513 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
5514 std::swap(a&: LHS, b&: RHS);
5515 }
5516
5517 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
5518 MachineInstr *FpExtSrc;
5519 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
5520 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
5521 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
5522 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
5523 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5524 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
5525 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
5526 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5527 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), RHS.Reg});
5528 };
5529 return true;
5530 }
5531
5532 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
5533 // Note: Commutes FADD operands.
5534 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FpExtSrc))) &&
5535 isContractableFMul(MI&: *FpExtSrc, AllowFusionGlobally) &&
5536 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
5537 SrcTy: MRI.getType(Reg: FpExtSrc->getOperand(i: 1).getReg()))) {
5538 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5539 auto FpExtX = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 1).getReg());
5540 auto FpExtY = B.buildFPExt(Res: DstType, Op: FpExtSrc->getOperand(i: 2).getReg());
5541 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5542 SrcOps: {FpExtX.getReg(Idx: 0), FpExtY.getReg(Idx: 0), LHS.Reg});
5543 };
5544 return true;
5545 }
5546
5547 return false;
5548}
5549
5550bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
5551 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5552 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5553
5554 bool AllowFusionGlobally, HasFMAD, Aggressive;
5555 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, CanReassociate: true))
5556 return false;
5557
5558 Register Op1 = MI.getOperand(i: 1).getReg();
5559 Register Op2 = MI.getOperand(i: 2).getReg();
5560 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
5561 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
5562 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5563
5564 unsigned PreferredFusedOpcode =
5565 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5566
5567 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5568 // prefer to fold the multiply with fewer uses.
5569 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5570 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
5571 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
5572 std::swap(a&: LHS, b&: RHS);
5573 }
5574
5575 MachineInstr *FMA = nullptr;
5576 Register Z;
5577 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
5578 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5579 (MRI.getVRegDef(Reg: LHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
5580 TargetOpcode::G_FMUL) &&
5581 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 0).getReg()) &&
5582 MRI.hasOneNonDBGUse(RegNo: LHS.MI->getOperand(i: 3).getReg())) {
5583 FMA = LHS.MI;
5584 Z = RHS.Reg;
5585 }
5586 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
5587 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5588 (MRI.getVRegDef(Reg: RHS.MI->getOperand(i: 3).getReg())->getOpcode() ==
5589 TargetOpcode::G_FMUL) &&
5590 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 0).getReg()) &&
5591 MRI.hasOneNonDBGUse(RegNo: RHS.MI->getOperand(i: 3).getReg())) {
5592 Z = LHS.Reg;
5593 FMA = RHS.MI;
5594 }
5595
5596 if (FMA) {
5597 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMA->getOperand(i: 3).getReg());
5598 Register X = FMA->getOperand(i: 1).getReg();
5599 Register Y = FMA->getOperand(i: 2).getReg();
5600 Register U = FMulMI->getOperand(i: 1).getReg();
5601 Register V = FMulMI->getOperand(i: 2).getReg();
5602
5603 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5604 Register InnerFMA = MRI.createGenericVirtualRegister(Ty: DstTy);
5605 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {InnerFMA}, SrcOps: {U, V, Z});
5606 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5607 SrcOps: {X, Y, InnerFMA});
5608 };
5609 return true;
5610 }
5611
5612 return false;
5613}
5614
5615bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
5616 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5617 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5618
5619 bool AllowFusionGlobally, HasFMAD, Aggressive;
5620 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5621 return false;
5622
5623 if (!Aggressive)
5624 return false;
5625
5626 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5627 LLT DstType = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5628 Register Op1 = MI.getOperand(i: 1).getReg();
5629 Register Op2 = MI.getOperand(i: 2).getReg();
5630 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
5631 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
5632
5633 unsigned PreferredFusedOpcode =
5634 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5635
5636 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5637 // prefer to fold the multiply with fewer uses.
5638 if (Aggressive && isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5639 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally)) {
5640 if (hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
5641 std::swap(a&: LHS, b&: RHS);
5642 }
5643
5644 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
5645 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
5646 Register Y, MachineIRBuilder &B) {
5647 Register FpExtU = B.buildFPExt(Res: DstType, Op: U).getReg(Idx: 0);
5648 Register FpExtV = B.buildFPExt(Res: DstType, Op: V).getReg(Idx: 0);
5649 Register InnerFMA =
5650 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {DstType}, SrcOps: {FpExtU, FpExtV, Z})
5651 .getReg(Idx: 0);
5652 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5653 SrcOps: {X, Y, InnerFMA});
5654 };
5655
5656 MachineInstr *FMulMI, *FMAMI;
5657 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
5658 // -> (fma x, y, (fma (fpext u), (fpext v), z))
5659 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5660 mi_match(R: LHS.MI->getOperand(i: 3).getReg(), MRI,
5661 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
5662 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5663 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
5664 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
5665 MatchInfo = [=](MachineIRBuilder &B) {
5666 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
5667 FMulMI->getOperand(i: 2).getReg(), RHS.Reg,
5668 LHS.MI->getOperand(i: 1).getReg(),
5669 LHS.MI->getOperand(i: 2).getReg(), B);
5670 };
5671 return true;
5672 }
5673
5674 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
5675 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
5676 // FIXME: This turns two single-precision and one double-precision
5677 // operation into two double-precision operations, which might not be
5678 // interesting for all targets, especially GPUs.
5679 if (mi_match(R: LHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
5680 FMAMI->getOpcode() == PreferredFusedOpcode) {
5681 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
5682 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5683 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
5684 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
5685 MatchInfo = [=](MachineIRBuilder &B) {
5686 Register X = FMAMI->getOperand(i: 1).getReg();
5687 Register Y = FMAMI->getOperand(i: 2).getReg();
5688 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
5689 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
5690 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
5691 FMulMI->getOperand(i: 2).getReg(), RHS.Reg, X, Y, B);
5692 };
5693
5694 return true;
5695 }
5696 }
5697
5698 // fold (fadd z, (fma x, y, (fpext (fmul u, v))))
5699 // -> (fma x, y, (fma (fpext u), (fpext v), z))
5700 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5701 mi_match(R: RHS.MI->getOperand(i: 3).getReg(), MRI,
5702 P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
5703 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5704 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
5705 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
5706 MatchInfo = [=](MachineIRBuilder &B) {
5707 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
5708 FMulMI->getOperand(i: 2).getReg(), LHS.Reg,
5709 RHS.MI->getOperand(i: 1).getReg(),
5710 RHS.MI->getOperand(i: 2).getReg(), B);
5711 };
5712 return true;
5713 }
5714
5715 // fold (fadd z, (fpext (fma x, y, (fmul u, v))))
5716 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
5717 // FIXME: This turns two single-precision and one double-precision
5718 // operation into two double-precision operations, which might not be
5719 // interesting for all targets, especially GPUs.
5720 if (mi_match(R: RHS.Reg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMAMI))) &&
5721 FMAMI->getOpcode() == PreferredFusedOpcode) {
5722 MachineInstr *FMulMI = MRI.getVRegDef(Reg: FMAMI->getOperand(i: 3).getReg());
5723 if (isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5724 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstType,
5725 SrcTy: MRI.getType(Reg: FMAMI->getOperand(i: 0).getReg()))) {
5726 MatchInfo = [=](MachineIRBuilder &B) {
5727 Register X = FMAMI->getOperand(i: 1).getReg();
5728 Register Y = FMAMI->getOperand(i: 2).getReg();
5729 X = B.buildFPExt(Res: DstType, Op: X).getReg(Idx: 0);
5730 Y = B.buildFPExt(Res: DstType, Op: Y).getReg(Idx: 0);
5731 buildMatchInfo(FMulMI->getOperand(i: 1).getReg(),
5732 FMulMI->getOperand(i: 2).getReg(), LHS.Reg, X, Y, B);
5733 };
5734 return true;
5735 }
5736 }
5737
5738 return false;
5739}
5740
5741bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
5742 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5743 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5744
5745 bool AllowFusionGlobally, HasFMAD, Aggressive;
5746 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5747 return false;
5748
5749 Register Op1 = MI.getOperand(i: 1).getReg();
5750 Register Op2 = MI.getOperand(i: 2).getReg();
5751 DefinitionAndSourceRegister LHS = {.MI: MRI.getVRegDef(Reg: Op1), .Reg: Op1};
5752 DefinitionAndSourceRegister RHS = {.MI: MRI.getVRegDef(Reg: Op2), .Reg: Op2};
5753 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5754
5755 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
5756 // prefer to fold the multiply with fewer uses.
5757 bool FirstMulHasFewerUses = true;
5758 if (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5759 isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
5760 hasMoreUses(MI0: *LHS.MI, MI1: *RHS.MI, MRI))
5761 FirstMulHasFewerUses = false;
5762
5763 unsigned PreferredFusedOpcode =
5764 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5765
5766 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
5767 if (FirstMulHasFewerUses &&
5768 (isContractableFMul(MI&: *LHS.MI, AllowFusionGlobally) &&
5769 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHS.Reg)))) {
5770 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5771 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHS.Reg).getReg(Idx: 0);
5772 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5773 SrcOps: {LHS.MI->getOperand(i: 1).getReg(),
5774 LHS.MI->getOperand(i: 2).getReg(), NegZ});
5775 };
5776 return true;
5777 }
5778 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
5779 else if ((isContractableFMul(MI&: *RHS.MI, AllowFusionGlobally) &&
5780 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHS.Reg)))) {
5781 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5782 Register NegY =
5783 B.buildFNeg(Dst: DstTy, Src0: RHS.MI->getOperand(i: 1).getReg()).getReg(Idx: 0);
5784 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5785 SrcOps: {NegY, RHS.MI->getOperand(i: 2).getReg(), LHS.Reg});
5786 };
5787 return true;
5788 }
5789
5790 return false;
5791}
5792
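/// Transform (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z)) and
/// (fsub x, (fneg (fmul y, z))) -> (fma y, z, x).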
5793bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
5794 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5795 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5796
5797 bool AllowFusionGlobally, HasFMAD, Aggressive;
5798 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5799 return false;
5800
5801 Register LHSReg = MI.getOperand(i: 1).getReg();
5802 Register RHSReg = MI.getOperand(i: 2).getReg();
5803 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5804
5805 unsigned PreferredFusedOpcode =
5806 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5807
5808 MachineInstr *FMulMI;
5809 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
5810 if (mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
5811 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: LHSReg) &&
5812 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
5813 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
5814 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5815 Register NegX =
5816 B.buildFNeg(Dst: DstTy, Src0: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
5817 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
5818 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5819 SrcOps: {NegX, FMulMI->getOperand(i: 2).getReg(), NegZ});
5820 };
5821 return true;
5822 }
5823
5824 // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
5825 if (mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_MInstr(MI&: FMulMI))) &&
5826 (Aggressive || (MRI.hasOneNonDBGUse(RegNo: RHSReg) &&
5827 MRI.hasOneNonDBGUse(RegNo: FMulMI->getOperand(i: 0).getReg()))) &&
5828 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally)) {
5829 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5830 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5831 SrcOps: {FMulMI->getOperand(i: 1).getReg(),
5832 FMulMI->getOperand(i: 2).getReg(), LHSReg});
5833 };
5834 return true;
5835 }
5836
5837 return false;
5838}
5839
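/// Transform (fsub (fpext (fmul x, y)), z)
///   -> (fma (fpext x), (fpext y), (fneg z)) and
/// (fsub x, (fpext (fmul y, z)))
///   -> (fma (fneg (fpext y)), (fpext z), x).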
5840bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
5841 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5842 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5843
5844 bool AllowFusionGlobally, HasFMAD, Aggressive;
5845 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5846 return false;
5847
5848 Register LHSReg = MI.getOperand(i: 1).getReg();
5849 Register RHSReg = MI.getOperand(i: 2).getReg();
5850 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5851
5852 unsigned PreferredFusedOpcode =
5853 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5854
5855 MachineInstr *FMulMI;
5856 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
5857 if (mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
5858 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5859 (Aggressive || MRI.hasOneNonDBGUse(RegNo: LHSReg))) {
5860 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5861 Register FpExtX =
5862 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
5863 Register FpExtY =
5864 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
5865 Register NegZ = B.buildFNeg(Dst: DstTy, Src0: RHSReg).getReg(Idx: 0);
5866 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5867 SrcOps: {FpExtX, FpExtY, NegZ});
5868 };
5869 return true;
5870 }
5871
5872 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
5873 if (mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_MInstr(MI&: FMulMI))) &&
5874 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5875 (Aggressive || MRI.hasOneNonDBGUse(RegNo: RHSReg))) {
5876 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5877 Register FpExtY =
5878 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 1).getReg()).getReg(Idx: 0);
5879 Register NegY = B.buildFNeg(Dst: DstTy, Src0: FpExtY).getReg(Idx: 0);
5880 Register FpExtZ =
5881 B.buildFPExt(Res: DstTy, Op: FMulMI->getOperand(i: 2).getReg()).getReg(Idx: 0);
5882 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {MI.getOperand(i: 0).getReg()},
5883 SrcOps: {NegY, FpExtZ, LHSReg});
5884 };
5885 return true;
5886 }
5887
5888 return false;
5889}
5890
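/// Transform (fsub (fpext (fneg (fmul x, y))), z), and the form with fneg and
/// fpext swapped, into (fneg (fma (fpext x), (fpext y), z)); likewise fold
/// (fsub x, (fpext (fneg (fmul y, z)))) into (fma (fpext y), (fpext z), x).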
5891bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
5892 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5893 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5894
5895 bool AllowFusionGlobally, HasFMAD, Aggressive;
5896 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5897 return false;
5898
5899 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5900 LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
5901 Register LHSReg = MI.getOperand(i: 1).getReg();
5902 Register RHSReg = MI.getOperand(i: 2).getReg();
5903
5904 unsigned PreferredFusedOpcode =
5905 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5906
5907 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
5908 MachineIRBuilder &B) {
5909 Register FpExtX = B.buildFPExt(Res: DstTy, Op: X).getReg(Idx: 0);
5910 Register FpExtY = B.buildFPExt(Res: DstTy, Op: Y).getReg(Idx: 0);
5911 B.buildInstr(Opc: PreferredFusedOpcode, DstOps: {Dst}, SrcOps: {FpExtX, FpExtY, Z});
5912 };
5913
5914 MachineInstr *FMulMI;
5915 // fold (fsub (fpext (fneg (fmul x, y))), z) ->
5916 // (fneg (fma (fpext x), (fpext y), z))
5917 // fold (fsub (fneg (fpext (fmul x, y))), z) ->
5918 // (fneg (fma (fpext x), (fpext y), z))
5919 if ((mi_match(R: LHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
5920 mi_match(R: LHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
5921 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5922 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
5923 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
5924 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5925 Register FMAReg = MRI.createGenericVirtualRegister(Ty: DstTy);
5926 buildMatchInfo(FMAReg, FMulMI->getOperand(i: 1).getReg(),
5927 FMulMI->getOperand(i: 2).getReg(), RHSReg, B);
5928 B.buildFNeg(Dst: MI.getOperand(i: 0).getReg(), Src0: FMAReg);
5929 };
5930 return true;
5931 }
5932
5933 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
5934 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
5935 if ((mi_match(R: RHSReg, MRI, P: m_GFPExt(Src: m_GFNeg(Src: m_MInstr(MI&: FMulMI)))) ||
5936 mi_match(R: RHSReg, MRI, P: m_GFNeg(Src: m_GFPExt(Src: m_MInstr(MI&: FMulMI))))) &&
5937 isContractableFMul(MI&: *FMulMI, AllowFusionGlobally) &&
5938 TLI.isFPExtFoldable(MI, Opcode: PreferredFusedOpcode, DestTy: DstTy,
5939 SrcTy: MRI.getType(Reg: FMulMI->getOperand(i: 0).getReg()))) {
5940 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5941 buildMatchInfo(MI.getOperand(i: 0).getReg(), FMulMI->getOperand(i: 1).getReg(),
5942 FMulMI->getOperand(i: 2).getReg(), LHSReg, B);
5943 };
5944 return true;
5945 }
5946
5947 return false;
5948}
5949
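/// Match a G_FMINNUM/G_FMAXNUM/G_FMINIMUM/G_FMAXIMUM with a constant NaN
/// operand. \p IdxToPropagate is set to the index of the operand the
/// instruction can be replaced with: the non-NaN operand for
/// G_FMINNUM/G_FMAXNUM and the NaN operand for G_FMINIMUM/G_FMAXIMUM.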
5950bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
5951 unsigned &IdxToPropagate) {
5952 bool PropagateNaN;
5953 switch (MI.getOpcode()) {
5954 default:
5955 return false;
5956 case TargetOpcode::G_FMINNUM:
5957 case TargetOpcode::G_FMAXNUM:
5958 PropagateNaN = false;
5959 break;
5960 case TargetOpcode::G_FMINIMUM:
5961 case TargetOpcode::G_FMAXIMUM:
5962 PropagateNaN = true;
5963 break;
5964 }
5965
5966 auto MatchNaN = [&](unsigned Idx) {
5967 Register MaybeNaNReg = MI.getOperand(i: Idx).getReg();
5968 const ConstantFP *MaybeCst = getConstantFPVRegVal(VReg: MaybeNaNReg, MRI);
5969 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
5970 return false;
5971 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
5972 return true;
5973 };
5974
5975 return MatchNaN(1) || MatchNaN(2);
5976}
5977
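/// Match A + (B - A) or (B - A) + A and set \p Src to B, so the G_ADD can be
/// replaced with B.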
5978bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) {
5979 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
5980 Register LHS = MI.getOperand(i: 1).getReg();
5981 Register RHS = MI.getOperand(i: 2).getReg();
5982
5983 // Helper lambda to check for opportunities for
5984 // A + (B - A) -> B
5985 // (B - A) + A -> B
5986 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
5987 Register Reg;
5988 return mi_match(R: MaybeSub, MRI, P: m_GSub(L: m_Reg(R&: Src), R: m_Reg(R&: Reg))) &&
5989 Reg == MaybeSameReg;
5990 };
5991 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
5992}
5993
5994bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
5995 Register &MatchInfo) {
5996 // This combine folds the following patterns:
5997 //
5998 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
5999 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
6000 // into
6001 // x
6002 // if
6003 // k == sizeof(VecEltTy)/2
6004 // type(x) == type(dst)
6005 //
6006 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
6007 // into
6008 // x
6009 // if
6010 // type(x) == type(dst)
6011
6012 LLT DstVecTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6013 LLT DstEltTy = DstVecTy.getElementType();
6014
6015 Register Lo, Hi;
6016
6017 if (mi_match(
6018 MI, MRI,
6019 P: m_GBuildVector(L: m_GTrunc(Src: m_GBitcast(Src: m_Reg(R&: Lo))), R: m_GImplicitDef()))) {
6020 MatchInfo = Lo;
6021 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6022 }
6023
6024 std::optional<ValueAndVReg> ShiftAmount;
6025 const auto LoPattern = m_GBitcast(Src: m_Reg(R&: Lo));
6026 const auto HiPattern = m_GLShr(L: m_GBitcast(Src: m_Reg(R&: Hi)), R: m_GCst(ValReg&: ShiftAmount));
6027 if (mi_match(
6028 MI, MRI,
6029 P: m_any_of(preds: m_GBuildVectorTrunc(L: LoPattern, R: HiPattern),
6030 preds: m_GBuildVector(L: m_GTrunc(Src: LoPattern), R: m_GTrunc(Src: HiPattern))))) {
6031 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
6032 MatchInfo = Lo;
6033 return MRI.getType(Reg: MatchInfo) == DstVecTy;
6034 }
6035 }
6036
6037 return false;
6038}
6039
6040bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
6041 Register &MatchInfo) {
6042 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) with just x
6043 // if type(x) == type(G_TRUNC)
6044 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6045 P: m_GBitcast(Src: m_GBuildVector(L: m_Reg(R&: MatchInfo), R: m_Reg()))))
6046 return false;
6047
6048 return MRI.getType(Reg: MatchInfo) == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6049}
6050
6051bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
6052 Register &MatchInfo) {
6053 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
6054 // y if K == size of vector element type
6055 std::optional<ValueAndVReg> ShiftAmt;
6056 if (!mi_match(R: MI.getOperand(i: 1).getReg(), MRI,
6057 P: m_GLShr(L: m_GBitcast(Src: m_GBuildVector(L: m_Reg(), R: m_Reg(R&: MatchInfo))),
6058 R: m_GCst(ValReg&: ShiftAmt))))
6059 return false;
6060
6061 LLT MatchTy = MRI.getType(Reg: MatchInfo);
6062 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
6063 MatchTy == MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6064}
6065
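/// Return the floating-point min/max opcode (G_FMINNUM/G_FMAXNUM or
/// G_FMINIMUM/G_FMAXIMUM) that implements a select on \p Pred, based on the
/// known NaN behaviour \p VsNaNRetVal and on which opcodes are legal for
/// \p DstTy. Returns 0 if no suitable opcode exists.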
6066unsigned CombinerHelper::getFPMinMaxOpcForSelect(
6067 CmpInst::Predicate Pred, LLT DstTy,
6068 SelectPatternNaNBehaviour VsNaNRetVal) const {
6069 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
6070 "Expected a NaN behaviour?");
6071 // Choose an opcode based on legality, or on the behaviour when one of the
6072 // LHS/RHS may be NaN.
6073 switch (Pred) {
6074 default:
6075 return 0;
6076 case CmpInst::FCMP_UGT:
6077 case CmpInst::FCMP_UGE:
6078 case CmpInst::FCMP_OGT:
6079 case CmpInst::FCMP_OGE:
6080 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6081 return TargetOpcode::G_FMAXNUM;
6082 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6083 return TargetOpcode::G_FMAXIMUM;
6084 if (isLegal(Query: {TargetOpcode::G_FMAXNUM, {DstTy}}))
6085 return TargetOpcode::G_FMAXNUM;
6086 if (isLegal(Query: {TargetOpcode::G_FMAXIMUM, {DstTy}}))
6087 return TargetOpcode::G_FMAXIMUM;
6088 return 0;
6089 case CmpInst::FCMP_ULT:
6090 case CmpInst::FCMP_ULE:
6091 case CmpInst::FCMP_OLT:
6092 case CmpInst::FCMP_OLE:
6093 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
6094 return TargetOpcode::G_FMINNUM;
6095 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
6096 return TargetOpcode::G_FMINIMUM;
6097 if (isLegal(Query: {TargetOpcode::G_FMINNUM, {DstTy}}))
6098 return TargetOpcode::G_FMINNUM;
6099 if (isLegal(Query: {TargetOpcode::G_FMINIMUM, {DstTy}}))
6100 return TargetOpcode::G_FMINIMUM;
6101 return 0;
6102 }
6103}
6104
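/// For a select of the form (fcmp LHS, RHS) ? LHS : RHS, determine which
/// operand the select would return if one of LHS/RHS were NaN, based on which
/// operands are known never to be NaN and on whether the compare is ordered.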
6105CombinerHelper::SelectPatternNaNBehaviour
6106CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
6107 bool IsOrderedComparison) const {
6108 bool LHSSafe = isKnownNeverNaN(Val: LHS, MRI);
6109 bool RHSSafe = isKnownNeverNaN(Val: RHS, MRI);
6110 // Completely unsafe.
6111 if (!LHSSafe && !RHSSafe)
6112 return SelectPatternNaNBehaviour::NOT_APPLICABLE;
6113 if (LHSSafe && RHSSafe)
6114 return SelectPatternNaNBehaviour::RETURNS_ANY;
6115 // An ordered comparison will return false when given a NaN, so it
6116 // returns the RHS.
6117 if (IsOrderedComparison)
6118 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
6119 : SelectPatternNaNBehaviour::RETURNS_OTHER;
6120 // An unordered comparison will return true when given a NaN, so it
6121 // returns the LHS.
6122 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
6123 : SelectPatternNaNBehaviour::RETURNS_NAN;
6124}
6125
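/// Helper for matchSimplifySelectToMinMax: try to turn
/// select (fcmp Pred x, y), x, y (or the operand-swapped form) into a
/// floating-point min/max operation.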
6126bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
6127 Register TrueVal, Register FalseVal,
6128 BuildFnTy &MatchInfo) {
6129 // Match: select (fcmp cond x, y) x, y
6130 // select (fcmp cond x, y) y, x
6131 // And turn it into fminnum/fmaxnum or fminimum/fmaximum based on the condition.
6132 LLT DstTy = MRI.getType(Reg: Dst);
6133 // Bail out early on pointers, since we'll never want to fold to a min/max.
6134 if (DstTy.isPointer())
6135 return false;
6136 // Match a floating point compare with a less-than/greater-than predicate.
6137 // TODO: Allow multiple users of the compare if they are all selects.
6138 CmpInst::Predicate Pred;
6139 Register CmpLHS, CmpRHS;
6140 if (!mi_match(R: Cond, MRI,
6141 P: m_OneNonDBGUse(
6142 SP: m_GFCmp(P: m_Pred(P&: Pred), L: m_Reg(R&: CmpLHS), R: m_Reg(R&: CmpRHS)))) ||
6143 CmpInst::isEquality(pred: Pred))
6144 return false;
6145 SelectPatternNaNBehaviour ResWithKnownNaNInfo =
6146 computeRetValAgainstNaN(LHS: CmpLHS, RHS: CmpRHS, IsOrderedComparison: CmpInst::isOrdered(predicate: Pred));
6147 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
6148 return false;
6149 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
6150 std::swap(a&: CmpLHS, b&: CmpRHS);
6151 Pred = CmpInst::getSwappedPredicate(pred: Pred);
6152 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
6153 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
6154 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
6155 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
6156 }
6157 if (TrueVal != CmpLHS || FalseVal != CmpRHS)
6158 return false;
6159 // Decide what type of max/min this should be based on the predicate.
6160 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, VsNaNRetVal: ResWithKnownNaNInfo);
6161 if (!Opc || !isLegal(Query: {Opc, {DstTy}}))
6162 return false;
6163 // Comparisons between signed zero and zero may have different results...
6164 // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
6165 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
6166 // We don't know if a comparison between two 0s will give us a consistent
6167 // result. Be conservative and only proceed if at least one side is
6168 // non-zero.
6169 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpLHS, MRI);
6170 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
6171 KnownNonZeroSide = getFConstantVRegValWithLookThrough(VReg: CmpRHS, MRI);
6172 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
6173 return false;
6174 }
6175 }
6176 MatchInfo = [=](MachineIRBuilder &B) {
6177 B.buildInstr(Opc, DstOps: {Dst}, SrcOps: {CmpLHS, CmpRHS});
6178 };
6179 return true;
6180}
6181
6182bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
6183 BuildFnTy &MatchInfo) {
6184 // TODO: Handle integer cases.
6185 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
6186 // Condition may be fed by a truncated compare.
6187 Register Cond = MI.getOperand(i: 1).getReg();
6188 Register MaybeTrunc;
6189 if (mi_match(R: Cond, MRI, P: m_OneNonDBGUse(SP: m_GTrunc(Src: m_Reg(R&: MaybeTrunc)))))
6190 Cond = MaybeTrunc;
6191 Register Dst = MI.getOperand(i: 0).getReg();
6192 Register TrueVal = MI.getOperand(i: 2).getReg();
6193 Register FalseVal = MI.getOperand(i: 3).getReg();
6194 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
6195}
6196
6197bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
6198 BuildFnTy &MatchInfo) {
6199 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
6200 // (X + Y) == X --> Y == 0
6201 // (X + Y) != X --> Y != 0
6202 // (X - Y) == X --> Y == 0
6203 // (X - Y) != X --> Y != 0
6204 // (X ^ Y) == X --> Y == 0
6205 // (X ^ Y) != X --> Y != 0
6206 Register Dst = MI.getOperand(i: 0).getReg();
6207 CmpInst::Predicate Pred;
6208 Register X, Y, OpLHS, OpRHS;
6209 bool MatchedSub = mi_match(
6210 R: Dst, MRI,
6211 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X), R: m_GSub(L: m_Reg(R&: OpLHS), R: m_Reg(R&: Y))));
6212 if (MatchedSub && X != OpLHS)
6213 return false;
6214 if (!MatchedSub) {
6215 if (!mi_match(R: Dst, MRI,
6216 P: m_c_GICmp(P: m_Pred(P&: Pred), L: m_Reg(R&: X),
6217 R: m_any_of(preds: m_GAdd(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS)),
6218 preds: m_GXor(L: m_Reg(R&: OpLHS), R: m_Reg(R&: OpRHS))))))
6219 return false;
6220 Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
6221 }
6222 MatchInfo = [=](MachineIRBuilder &B) {
6223 auto Zero = B.buildConstant(Res: MRI.getType(Reg: Y), Val: 0);
6224 B.buildICmp(Pred, Res: Dst, Op0: Y, Op1: Zero);
6225 };
6226 return CmpInst::isEquality(pred: Pred) && Y.isValid();
6227}
6228
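/// Match a shift whose shift amount is a constant (or a build vector of
/// constants) greater than or equal to the scalar bit width of the result,
/// i.e. an out-of-range shift.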
6229bool CombinerHelper::matchShiftsTooBig(MachineInstr &MI) {
6230 Register ShiftReg = MI.getOperand(i: 2).getReg();
6231 LLT ResTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg());
6232 auto IsShiftTooBig = [&](const Constant *C) {
6233 auto *CI = dyn_cast<ConstantInt>(Val: C);
6234 return CI && CI->uge(Num: ResTy.getScalarSizeInBits());
6235 };
6236 return matchUnaryPredicate(MRI, Reg: ShiftReg, Match: IsShiftTooBig);
6237}
6238
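/// Match a commutative binary operation whose LHS should be moved to the RHS:
/// either the LHS is a constant, or it is a G_CONSTANT_FOLD_BARRIER and the
/// RHS is neither a constant nor a barrier.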
6239bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) {
6240 Register LHS = MI.getOperand(i: 1).getReg();
6241 Register RHS = MI.getOperand(i: 2).getReg();
6242 auto *LHSDef = MRI.getVRegDef(Reg: LHS);
6243 if (getIConstantVRegVal(VReg: LHS, MRI).has_value())
6244 return true;
6245
6246 // LHS may be a G_CONSTANT_FOLD_BARRIER. If so, commute it to the RHS
6247 // as long as the RHS is neither a constant nor another barrier.
6248 if (LHSDef->getOpcode() != TargetOpcode::G_CONSTANT_FOLD_BARRIER)
6249 return false;
6250 return MRI.getVRegDef(Reg: RHS)->getOpcode() !=
6251 TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
6252 !getIConstantVRegVal(VReg: RHS, MRI);
6253}
6254
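/// Match a floating-point binary operation whose LHS is a constant (or
/// constant splat) while the RHS is not, so the constant can be commuted to
/// the RHS.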
6255bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) {
6256 Register LHS = MI.getOperand(i: 1).getReg();
6257 Register RHS = MI.getOperand(i: 2).getReg();
6258 std::optional<FPValueAndVReg> ValAndVReg;
6259 if (!mi_match(R: LHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg)))
6260 return false;
6261 return !mi_match(R: RHS, MRI, P: m_GFCstOrSplat(FPValReg&: ValAndVReg));
6262}
6263
6264void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
6265 Observer.changingInstr(MI);
6266 Register LHSReg = MI.getOperand(i: 1).getReg();
6267 Register RHSReg = MI.getOperand(i: 2).getReg();
6268 MI.getOperand(i: 1).setReg(RHSReg);
6269 MI.getOperand(i: 2).setReg(LHSReg);
6270 Observer.changedInstr(MI);
6271}
6272
6273bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
6274 LLT SrcTy = MRI.getType(Reg: Src);
6275 if (SrcTy.isFixedVector())
6276 return isConstantSplatVector(Src, SplatValue: 1, AllowUndefs);
6277 if (SrcTy.isScalar()) {
6278 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
6279 return true;
6280 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
6281 return IConstant && IConstant->Value == 1;
6282 }
6283 return false; // scalable vector
6284}
6285
6286bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
6287 LLT SrcTy = MRI.getType(Reg: Src);
6288 if (SrcTy.isFixedVector())
6289 return isConstantSplatVector(Src, SplatValue: 0, AllowUndefs);
6290 if (SrcTy.isScalar()) {
6291 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Reg: Src, MRI) != nullptr)
6292 return true;
6293 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
6294 return IConstant && IConstant->Value == 0;
6295 }
6296 return false; // scalable vector
6297}
6298
6299 // Ignores COPYs when checking whether the sources form a constant splat.
6300 // FIXME: Support scalable vectors.
6301bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
6302 bool AllowUndefs) {
6303 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
6304 if (!BuildVector)
6305 return false;
6306 unsigned NumSources = BuildVector->getNumSources();
6307
6308 for (unsigned I = 0; I < NumSources; ++I) {
6309 GImplicitDef *ImplicitDef =
6310 getOpcodeDef<GImplicitDef>(Reg: BuildVector->getSourceReg(I), MRI);
6311 if (ImplicitDef && AllowUndefs)
6312 continue;
6313 if (ImplicitDef && !AllowUndefs)
6314 return false;
6315 std::optional<ValueAndVReg> IConstant =
6316 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
6317 if (IConstant && IConstant->Value == SplatValue)
6318 continue;
6319 return false;
6320 }
6321 return true;
6322}
6323
6324// Ignores COPYs during lookups.
6325 // FIXME: Support scalable vectors.
6326std::optional<APInt>
6327CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
6328 auto IConstant = getIConstantVRegValWithLookThrough(VReg: Src, MRI);
6329 if (IConstant)
6330 return IConstant->Value;
6331
6332 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Reg: Src, MRI);
6333 if (!BuildVector)
6334 return std::nullopt;
6335 unsigned NumSources = BuildVector->getNumSources();
6336
6337 std::optional<APInt> Value = std::nullopt;
6338 for (unsigned I = 0; I < NumSources; ++I) {
6339 std::optional<ValueAndVReg> IConstant =
6340 getIConstantVRegValWithLookThrough(VReg: BuildVector->getSourceReg(I), MRI);
6341 if (!IConstant)
6342 return std::nullopt;
6343 if (!Value)
6344 Value = IConstant->Value;
6345 else if (*Value != IConstant->Value)
6346 return std::nullopt;
6347 }
6348 return Value;
6349}
6350
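/// Fold a G_SELECT with a scalar boolean condition whose true and false
/// operands are both constants into a zext/sext of the (possibly inverted)
/// condition, optionally combined with an add, shift, or or.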
6351// TODO: use knownbits to determine zeros
6352bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
6353 BuildFnTy &MatchInfo) {
6354 uint32_t Flags = Select->getFlags();
6355 Register Dest = Select->getReg(Idx: 0);
6356 Register Cond = Select->getCondReg();
6357 Register True = Select->getTrueReg();
6358 Register False = Select->getFalseReg();
6359 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
6360 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
6361
6362 // We only do this combine for scalar boolean conditions.
6363 if (CondTy != LLT::scalar(SizeInBits: 1))
6364 return false;
6365
6366 // The true and false operands must both be scalar constants.
6367 std::optional<ValueAndVReg> TrueOpt =
6368 getIConstantVRegValWithLookThrough(VReg: True, MRI);
6369 std::optional<ValueAndVReg> FalseOpt =
6370 getIConstantVRegValWithLookThrough(VReg: False, MRI);
6371
6372 if (!TrueOpt || !FalseOpt)
6373 return false;
6374
6375 APInt TrueValue = TrueOpt->Value;
6376 APInt FalseValue = FalseOpt->Value;
6377
6378 // select Cond, 1, 0 --> zext (Cond)
6379 if (TrueValue.isOne() && FalseValue.isZero()) {
6380 MatchInfo = [=](MachineIRBuilder &B) {
6381 B.setInstrAndDebugLoc(*Select);
6382 B.buildZExtOrTrunc(Res: Dest, Op: Cond);
6383 };
6384 return true;
6385 }
6386
6387 // select Cond, -1, 0 --> sext (Cond)
6388 if (TrueValue.isAllOnes() && FalseValue.isZero()) {
6389 MatchInfo = [=](MachineIRBuilder &B) {
6390 B.setInstrAndDebugLoc(*Select);
6391 B.buildSExtOrTrunc(Res: Dest, Op: Cond);
6392 };
6393 return true;
6394 }
6395
6396 // select Cond, 0, 1 --> zext (!Cond)
6397 if (TrueValue.isZero() && FalseValue.isOne()) {
6398 MatchInfo = [=](MachineIRBuilder &B) {
6399 B.setInstrAndDebugLoc(*Select);
6400 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
6401 B.buildNot(Dst: Inner, Src0: Cond);
6402 B.buildZExtOrTrunc(Res: Dest, Op: Inner);
6403 };
6404 return true;
6405 }
6406
6407 // select Cond, 0, -1 --> sext (!Cond)
6408 if (TrueValue.isZero() && FalseValue.isAllOnes()) {
6409 MatchInfo = [=](MachineIRBuilder &B) {
6410 B.setInstrAndDebugLoc(*Select);
6411 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
6412 B.buildNot(Dst: Inner, Src0: Cond);
6413 B.buildSExtOrTrunc(Res: Dest, Op: Inner);
6414 };
6415 return true;
6416 }
6417
6418 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
6419 if (TrueValue - 1 == FalseValue) {
6420 MatchInfo = [=](MachineIRBuilder &B) {
6421 B.setInstrAndDebugLoc(*Select);
6422 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
6423 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
6424 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
6425 };
6426 return true;
6427 }
6428
6429 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
6430 if (TrueValue + 1 == FalseValue) {
6431 MatchInfo = [=](MachineIRBuilder &B) {
6432 B.setInstrAndDebugLoc(*Select);
6433 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
6434 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
6435 B.buildAdd(Dst: Dest, Src0: Inner, Src1: False);
6436 };
6437 return true;
6438 }
6439
6440 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
6441 if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
6442 MatchInfo = [=](MachineIRBuilder &B) {
6443 B.setInstrAndDebugLoc(*Select);
6444 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
6445 B.buildZExtOrTrunc(Res: Inner, Op: Cond);
6446 // The shift amount must be scalar.
6447 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
6448 auto ShAmtC = B.buildConstant(Res: ShiftTy, Val: TrueValue.exactLogBase2());
6449 B.buildShl(Dst: Dest, Src0: Inner, Src1: ShAmtC, Flags);
6450 };
6451 return true;
6452 }
6453 // select Cond, -1, C --> or (sext Cond), C
6454 if (TrueValue.isAllOnes()) {
6455 MatchInfo = [=](MachineIRBuilder &B) {
6456 B.setInstrAndDebugLoc(*Select);
6457 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
6458 B.buildSExtOrTrunc(Res: Inner, Op: Cond);
6459 B.buildOr(Dst: Dest, Src0: Inner, Src1: False, Flags);
6460 };
6461 return true;
6462 }
6463
6464 // select Cond, C, -1 --> or (sext (not Cond)), C
6465 if (FalseValue.isAllOnes()) {
6466 MatchInfo = [=](MachineIRBuilder &B) {
6467 B.setInstrAndDebugLoc(*Select);
6468 Register Not = MRI.createGenericVirtualRegister(Ty: CondTy);
6469 B.buildNot(Dst: Not, Src0: Cond);
6470 Register Inner = MRI.createGenericVirtualRegister(Ty: TrueTy);
6471 B.buildSExtOrTrunc(Res: Inner, Op: Not);
6472 B.buildOr(Dst: Dest, Src0: Inner, Src1: True, Flags);
6473 };
6474 return true;
6475 }
6476
6477 return false;
6478}
6479
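/// Fold a G_SELECT whose operands are booleans (scalar i1 or a fixed vector
/// of i1) into and/or logic on the condition, e.g.
/// select Cond, Cond, F -> or Cond, F and select Cond, T, 0 -> and Cond, T.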
6480// TODO: use knownbits to determine zeros
6481bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
6482 BuildFnTy &MatchInfo) {
6483 uint32_t Flags = Select->getFlags();
6484 Register DstReg = Select->getReg(Idx: 0);
6485 Register Cond = Select->getCondReg();
6486 Register True = Select->getTrueReg();
6487 Register False = Select->getFalseReg();
6488 LLT CondTy = MRI.getType(Reg: Select->getCondReg());
6489 LLT TrueTy = MRI.getType(Reg: Select->getTrueReg());
6490
6491 // Boolean or fixed vector of booleans.
6492 if (CondTy.isScalableVector() ||
6493 (CondTy.isFixedVector() &&
6494 CondTy.getElementType().getScalarSizeInBits() != 1) ||
6495 CondTy.getScalarSizeInBits() != 1)
6496 return false;
6497
6498 if (CondTy != TrueTy)
6499 return false;
6500
6501 // select Cond, Cond, F --> or Cond, F
6502 // select Cond, 1, F --> or Cond, F
6503 if ((Cond == True) || isOneOrOneSplat(Src: True, /* AllowUndefs */ true)) {
6504 MatchInfo = [=](MachineIRBuilder &B) {
6505 B.setInstrAndDebugLoc(*Select);
6506 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
6507 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
6508 B.buildOr(Dst: DstReg, Src0: Ext, Src1: False, Flags);
6509 };
6510 return true;
6511 }
6512
6513 // select Cond, T, Cond --> and Cond, T
6514 // select Cond, T, 0 --> and Cond, T
6515 if ((Cond == False) || isZeroOrZeroSplat(Src: False, /* AllowUndefs */ true)) {
6516 MatchInfo = [=](MachineIRBuilder &B) {
6517 B.setInstrAndDebugLoc(*Select);
6518 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
6519 B.buildZExtOrTrunc(Res: Ext, Op: Cond);
6520 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: True);
6521 };
6522 return true;
6523 }
6524
6525 // select Cond, T, 1 --> or (not Cond), T
6526 if (isOneOrOneSplat(Src: False, /* AllowUndefs */ true)) {
6527 MatchInfo = [=](MachineIRBuilder &B) {
6528 B.setInstrAndDebugLoc(*Select);
6529 // First the not.
6530 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
6531 B.buildNot(Dst: Inner, Src0: Cond);
6532 // Then an ext to match the destination register.
6533 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
6534 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
6535 B.buildOr(Dst: DstReg, Src0: Ext, Src1: True, Flags);
6536 };
6537 return true;
6538 }
6539
6540 // select Cond, 0, F --> and (not Cond), F
6541 if (isZeroOrZeroSplat(Src: True, /* AllowUndefs */ true)) {
6542 MatchInfo = [=](MachineIRBuilder &B) {
6543 B.setInstrAndDebugLoc(*Select);
6544 // First the not.
6545 Register Inner = MRI.createGenericVirtualRegister(Ty: CondTy);
6546 B.buildNot(Dst: Inner, Src0: Cond);
6547 // Then an ext to match the destination register.
6548 Register Ext = MRI.createGenericVirtualRegister(Ty: TrueTy);
6549 B.buildZExtOrTrunc(Res: Ext, Op: Inner);
6550 B.buildAnd(Dst: DstReg, Src0: Ext, Src1: False);
6551 };
6552 return true;
6553 }
6554
6555 return false;
6556}
6557
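/// Fold select (icmp Pred X, Y), X, Y (or the operand-swapped form) into the
/// corresponding G_UMIN/G_UMAX/G_SMIN/G_SMAX when the compare has a single
/// use and the min/max operation is legal (or we are before the legalizer).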
6558bool CombinerHelper::tryFoldSelectToIntMinMax(GSelect *Select,
6559 BuildFnTy &MatchInfo) {
6560 Register DstReg = Select->getReg(Idx: 0);
6561 Register Cond = Select->getCondReg();
6562 Register True = Select->getTrueReg();
6563 Register False = Select->getFalseReg();
6564 LLT DstTy = MRI.getType(Reg: DstReg);
6565
6566 if (DstTy.isPointer())
6567 return false;
6568
6569 // We need a G_ICMP on the condition register.
6570 GICmp *Cmp = getOpcodeDef<GICmp>(Reg: Cond, MRI);
6571 if (!Cmp)
6572 return false;
6573
6574 // We want to fold the icmp and replace the select.
6575 if (!MRI.hasOneNonDBGUse(RegNo: Cmp->getReg(Idx: 0)))
6576 return false;
6577
6578 CmpInst::Predicate Pred = Cmp->getCond();
6579 // We need an ordering (greater-than or less-than) predicate; equality
6580 // predicates cannot be canonicalized to a min/max.
6581 if (CmpInst::isEquality(pred: Pred))
6582 return false;
6583
6584 Register CmpLHS = Cmp->getLHSReg();
6585 Register CmpRHS = Cmp->getRHSReg();
6586
6587 // We can swap CmpLHS and CmpRHS for a higher hit rate.
6588 if (True == CmpRHS && False == CmpLHS) {
6589 std::swap(a&: CmpLHS, b&: CmpRHS);
6590 Pred = CmpInst::getSwappedPredicate(pred: Pred);
6591 }
6592
6593 // (icmp X, Y) ? X : Y -> integer minmax.
6594 // see matchSelectPattern in ValueTracking.
6595 // Legality between G_SELECT and integer minmax can differ.
6596 if (True == CmpLHS && False == CmpRHS) {
6597 switch (Pred) {
6598 case ICmpInst::ICMP_UGT:
6599 case ICmpInst::ICMP_UGE: {
6600 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMAX, DstTy}))
6601 return false;
6602 MatchInfo = [=](MachineIRBuilder &B) {
6603 B.buildUMax(Dst: DstReg, Src0: True, Src1: False);
6604 };
6605 return true;
6606 }
6607 case ICmpInst::ICMP_SGT:
6608 case ICmpInst::ICMP_SGE: {
6609 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMAX, DstTy}))
6610 return false;
6611 MatchInfo = [=](MachineIRBuilder &B) {
6612 B.buildSMax(Dst: DstReg, Src0: True, Src1: False);
6613 };
6614 return true;
6615 }
6616 case ICmpInst::ICMP_ULT:
6617 case ICmpInst::ICMP_ULE: {
6618 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMIN, DstTy}))
6619 return false;
6620 MatchInfo = [=](MachineIRBuilder &B) {
6621 B.buildUMin(Dst: DstReg, Src0: True, Src1: False);
6622 };
6623 return true;
6624 }
6625 case ICmpInst::ICMP_SLT:
6626 case ICmpInst::ICMP_SLE: {
6627 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SMIN, DstTy}))
6628 return false;
6629 MatchInfo = [=](MachineIRBuilder &B) {
6630 B.buildSMin(Dst: DstReg, Src0: True, Src1: False);
6631 };
6632 return true;
6633 }
6634 default:
6635 return false;
6636 }
6637 }
6638
6639 return false;
6640}
6641
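/// Top-level combine for G_SELECT: try folding selects of constants, selects
/// of booleans, and selects that form an integer min/max, in that order.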
6642bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
6643 GSelect *Select = cast<GSelect>(Val: &MI);
6644
6645 if (tryFoldSelectOfConstants(Select, MatchInfo))
6646 return true;
6647
6648 if (tryFoldBoolSelectToLogic(Select, MatchInfo))
6649 return true;
6650
6651 if (tryFoldSelectToIntMinMax(Select, MatchInfo))
6652 return true;
6653
6654 return false;
6655}
6656
6657/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
6658/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
6659/// into a single comparison using range-based reasoning.
6660/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
6661bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
6662 BuildFnTy &MatchInfo) {
6663 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
6664 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
6665 Register DstReg = Logic->getReg(Idx: 0);
6666 Register LHS = Logic->getLHSReg();
6667 Register RHS = Logic->getRHSReg();
6668 unsigned Flags = Logic->getFlags();
6669
6670 // We need a G_ICMP on the LHS register.
6671 GICmp *Cmp1 = getOpcodeDef<GICmp>(Reg: LHS, MRI);
6672 if (!Cmp1)
6673 return false;
6674
6675 // We need a G_ICMP on the RHS register.
6676 GICmp *Cmp2 = getOpcodeDef<GICmp>(Reg: RHS, MRI);
6677 if (!Cmp2)
6678 return false;
6679
6680 // We want to fold the icmps.
6681 if (!MRI.hasOneNonDBGUse(RegNo: Cmp1->getReg(Idx: 0)) ||
6682 !MRI.hasOneNonDBGUse(RegNo: Cmp2->getReg(Idx: 0)))
6683 return false;
6684
6685 APInt C1;
6686 APInt C2;
6687 std::optional<ValueAndVReg> MaybeC1 =
6688 getIConstantVRegValWithLookThrough(VReg: Cmp1->getRHSReg(), MRI);
6689 if (!MaybeC1)
6690 return false;
6691 C1 = MaybeC1->Value;
6692
6693 std::optional<ValueAndVReg> MaybeC2 =
6694 getIConstantVRegValWithLookThrough(VReg: Cmp2->getRHSReg(), MRI);
6695 if (!MaybeC2)
6696 return false;
6697 C2 = MaybeC2->Value;
6698
6699 Register R1 = Cmp1->getLHSReg();
6700 Register R2 = Cmp2->getLHSReg();
6701 CmpInst::Predicate Pred1 = Cmp1->getCond();
6702 CmpInst::Predicate Pred2 = Cmp2->getCond();
6703 LLT CmpTy = MRI.getType(Reg: Cmp1->getReg(Idx: 0));
6704 LLT CmpOperandTy = MRI.getType(Reg: R1);
6705
6706 // We build ands, adds, and constants of type CmpOperandTy.
6707 // They must be legal to build.
6708 if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_AND, CmpOperandTy}) ||
6709 !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ADD, CmpOperandTy}) ||
6710 !isConstantLegalOrBeforeLegalizer(Ty: CmpOperandTy))
6711 return false;
6712
6713 // Look through an add of a constant offset on R1, R2, or both operands. This
6714 // allows us to turn the R + C' < C'' range idiom into a proper range.
6715 std::optional<APInt> Offset1;
6716 std::optional<APInt> Offset2;
6717 if (R1 != R2) {
6718 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R1, MRI)) {
6719 std::optional<ValueAndVReg> MaybeOffset1 =
6720 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
6721 if (MaybeOffset1) {
6722 R1 = Add->getLHSReg();
6723 Offset1 = MaybeOffset1->Value;
6724 }
6725 }
6726 if (GAdd *Add = getOpcodeDef<GAdd>(Reg: R2, MRI)) {
6727 std::optional<ValueAndVReg> MaybeOffset2 =
6728 getIConstantVRegValWithLookThrough(VReg: Add->getRHSReg(), MRI);
6729 if (MaybeOffset2) {
6730 R2 = Add->getLHSReg();
6731 Offset2 = MaybeOffset2->Value;
6732 }
6733 }
6734 }
6735
6736 if (R1 != R2)
6737 return false;
6738
6739 // Calculate the icmp ranges, folding in the optional offsets.
6740 ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
6741 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred1) : Pred1, Other: C1);
6742 if (Offset1)
6743 CR1 = CR1.subtract(CI: *Offset1);
6744
6745 ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
6746 Pred: IsAnd ? ICmpInst::getInversePredicate(pred: Pred2) : Pred2, Other: C2);
6747 if (Offset2)
6748 CR2 = CR2.subtract(CI: *Offset2);
6749
6750 bool CreateMask = false;
6751 APInt LowerDiff;
6752 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR: CR2);
6753 if (!CR) {
6754 // We need non-wrapping ranges.
6755 if (CR1.isWrappedSet() || CR2.isWrappedSet())
6756 return false;
6757
6758 // Check whether we have equal-size ranges that only differ by one bit.
6759 // In that case we can apply a mask to map one range onto the other.
6760 LowerDiff = CR1.getLower() ^ CR2.getLower();
6761 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
6762 APInt CR1Size = CR1.getUpper() - CR1.getLower();
6763 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
6764 CR1Size != CR2.getUpper() - CR2.getLower())
6765 return false;
6766
6767 CR = CR1.getLower().ult(RHS: CR2.getLower()) ? CR1 : CR2;
6768 CreateMask = true;
6769 }
6770
6771 if (IsAnd)
6772 CR = CR->inverse();
6773
6774 CmpInst::Predicate NewPred;
6775 APInt NewC, Offset;
6776 CR->getEquivalentICmp(Pred&: NewPred, RHS&: NewC, Offset);
6777
6778 // We take the result type of one of the original icmps, CmpTy, for
6779 // the icmp to be built. The operand type, CmpOperandTy, is used for
6780 // the other instructions and constants to be built. The parameter and
6781 // result types are the same for add and and. CmpTy and the type of
6782 // DstReg might differ, which is why we zext or trunc the icmp into
6783 // the destination register.
6784
6785 MatchInfo = [=](MachineIRBuilder &B) {
6786 if (CreateMask && Offset != 0) {
6787 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
6788 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
6789 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
6790 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: And, Src1: OffsetC, Flags);
6791 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
6792 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
6793 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
6794 } else if (CreateMask && Offset == 0) {
6795 auto TildeLowerDiff = B.buildConstant(Res: CmpOperandTy, Val: ~LowerDiff);
6796 auto And = B.buildAnd(Dst: CmpOperandTy, Src0: R1, Src1: TildeLowerDiff); // the mask.
6797 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
6798 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: And, Op1: NewCon);
6799 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
6800 } else if (!CreateMask && Offset != 0) {
6801 auto OffsetC = B.buildConstant(Res: CmpOperandTy, Val: Offset);
6802 auto Add = B.buildAdd(Dst: CmpOperandTy, Src0: R1, Src1: OffsetC, Flags);
6803 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
6804 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: Add, Op1: NewCon);
6805 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
6806 } else if (!CreateMask && Offset == 0) {
6807 auto NewCon = B.buildConstant(Res: CmpOperandTy, Val: NewC);
6808 auto ICmp = B.buildICmp(Pred: NewPred, Res: CmpTy, Op0: R1, Op1: NewCon);
6809 B.buildZExtOrTrunc(Res: DstReg, Op: ICmp);
6810 } else {
6811 llvm_unreachable("unexpected configuration of CreateMask and Offset");
6812 }
6813 };
6814 return true;
6815}
6816
6817bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) {
6818 GAnd *And = cast<GAnd>(Val: &MI);
6819
6820 if (tryFoldAndOrOrICmpsUsingRanges(Logic: And, MatchInfo))
6821 return true;
6822
6823 return false;
6824}
6825
6826bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {
6827 GOr *Or = cast<GOr>(Val: &MI);
6828
6829 if (tryFoldAndOrOrICmpsUsingRanges(Logic: Or, MatchInfo))
6830 return true;
6831
6832 return false;
6833}
6834
