//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ByteProvider.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

using namespace llvm;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

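// The counter below can be driven from the command line, e.g.
// -debug-counter=dagcombine-skip=N,dagcombine-count=M skips the first N
// candidate combines and then allows only the next M.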
DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
              "Controls whether a DAG combine is performed for a node");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
    MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                      cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));

static cl::opt<bool> EnableVectorFCopySignExtendRound(
    "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
    cl::desc(
        "Enable merging extends and rounds into FCOPYSIGN on vector types"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level = BeforeLegalizeTypes;
    CodeGenOptLevel OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG.
    SmallVector<SDNode *, 64> Worklist;

    /// Mapping from an SDNode to its position on the worklist.
    ///
    /// This is used to find and remove nodes from the worklist (by nulling
    /// them) when they are deleted from the underlying DAG. It relies on
    /// stable indices of nodes within the worklist.
    DenseMap<SDNode *, unsigned> WorklistMap;

    /// This records all nodes attempted to be added to the worklist since we
    /// considered a new worklist entry. Because we do not add duplicate nodes
    /// to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Set of nodes which have been combined (at least once).
    ///
    /// This is used to allow us to reliably add any operands of a DAG node
    /// which have not yet been combined to the worklist.
    SmallPtrSet<SDNode *, 32> CombinedNodes;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count tracks how many times we have seen the StoreNode with the
    /// same RootNode bail out of the dependence check. Once the same pair has
    /// bailed out more than a limit number of times, we no longer consider the
    /// StoreNode with that RootNode as a store merging candidate.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// When an instruction is simplified, add all users of the instruction to
    /// the work lists because they might get more simplified now.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the
    /// worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        bool GoodWorklistEntry = WorklistMap.erase(N);
        (void)GoodWorklistEntry;
        assert(GoodWorklistEntry &&
               "Found a worklist entry without a corresponding map entry!");
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOptLevel OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist making sure its instance is at the back (next to be
    /// processed).
    void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      if (IsCandidateForPruning)
        ConsiderForPruning(N);

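      // Only queue the node if it is not already present; the index recorded
      // in WorklistMap is the slot N will occupy, which enables O(1) removal
      // by nulling the entry later.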
      if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
        Worklist.push_back(N);
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      CombinedNodes.erase(N);
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      auto It = WorklistMap.find(N);
      if (It == WorklistMap.end())
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[It->second] = nullptr;
      WorklistMap.erase(It);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnes(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      EVT VT = Op.getValueType();
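      // Scalars and scalable vectors are modelled as a single demanded
      // "element" here; only fixed-length vectors get a full element mask.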
      APInt DemandedElts = VT.isFixedLengthVector()
                               ? APInt::getAllOnes(VT.getVectorNumElements())
                               : APInt(1, 1);
      return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, false);
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnes(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Looks up the chain to find a unique (unaliased) store feeding the passed
    // load. If no such store is found, returns a nullptr.
    // Note: This will look past a CALLSEQ_START if the load is chained to it
    // so that it can find stack stores for byval params.
    StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    /// load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                SDValue RHS, SDValue True, SDValue False,
                                ISD::CondCode CC);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types. The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
                                    SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitUADDO_CARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitUSUBO_CARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitAVG(SDNode *N);
    SDValue visitABD(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
    SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitSHLSAT(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitVP_SELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitVP_FADD(SDNode *N);
    SDValue visitVP_FSUB(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    template <class MatchContextClass> SDValue visitFMA(SDNode *N);
    SDValue visitFMAD(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitXRINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFREXP(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMinMax(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
    SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);

    bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);

    SDValue visitSTORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitVPGATHER(SDNode *N);
    SDValue visitVPSCATTER(SDNode *N);
    SDValue visitVP_STRIDED_LOAD(SDNode *N);
    SDValue visitVP_STRIDED_STORE(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitFP_TO_BF16(SDNode *N);
    SDValue visitBF16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);
    SDValue visitVPOp(SDNode *N);
    SDValue visitGET_FPENV_MEM(SDNode *N);
    SDValue visitSET_FPENV_MEM(SDNode *N);

    template <class MatchContextClass>
    SDValue visitFADDForFMACombine(SDNode *N);
    template <class MatchContextClass>
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL,
                                                    SDNode *N,
                                                    SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1, SDNodeFlags Flags);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);
    SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
                                 EVT VT, SDValue N0, SDValue N1,
                                 SDNodeFlags Flags = SDNodeFlags());

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldSelectOfBinops(SDNode *N);
    SDValue foldSextSetcc(SDNode *N);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
    SDValue foldABSToABD(SDNode *N);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                       unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                 const TargetLowering &TLI);

    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
    SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildSREMPow2(SDNode *N);
    SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
                          bool KnownNeverZero = false,
                          bool InexpensiveOnly = false,
                          std::optional<EVT> OutVT = std::nullopt);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue reduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool mayAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::BUILD_VECTOR:
        if (ISD::isBuildVectorOfConstantSDNodes(StoreVal.getNode()) ||
            ISD::isBuildVectorOfConstantFPSDNodes(StoreVal.getNode()))
          return StoreSource::Constant;
        return StoreSource::Unknown;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
    bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                     SDValue ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
    /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// Helper function for mergeConsecutiveStores which checks if all the store
    /// nodes have the same underlying object. We can still reuse the first
    /// store's pointer info if all the stores are from the same object.
    bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary. \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. RootNode is
    /// a chain predecessor to all store candidates.
    void getStoreMergeCandidates(StoreSDNode *St,
                                 SmallVectorImpl<MemOpLink> &StoreNodes,
                                 SDNode *&Root);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///   (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use) and if missed an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }

  public:
    /// Runs the dag combiner on all nodes in the work list
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Returns a type large enough to hold any valid shift amount - before type
    /// legalization these can be huge.
    EVT getShiftAmountTy(EVT LHSTy) {
      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  // compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

class EmptyMatchContext {
  SelectionDAG &DAG;
  const TargetLowering &TLI;

public:
  EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
      : DAG(DAG), TLI(TLI) {}

  bool match(SDValue OpN, unsigned Opcode) const {
    return Opcode == OpN->getOpcode();
  }

  // Same as SelectionDAG::getNode().
  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
    return DAG.getNode(std::forward<ArgT>(Args)...);
  }

  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                bool LegalOnly = false) const {
    return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
  }
};

class VPMatchContext {
  SelectionDAG &DAG;
  const TargetLowering &TLI;
  SDValue RootMaskOp;
  SDValue RootVectorLenOp;

public:
  VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
      : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
    assert(Root->isVPOpcode());
    if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
      RootMaskOp = Root->getOperand(*RootMaskPos);
    else if (Root->getOpcode() == ISD::VP_SELECT)
      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
                                          Root->getOperand(0).getValueType());

    if (auto RootVLenPos =
            ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
      RootVectorLenOp = Root->getOperand(*RootVLenPos);
  }

  /// Return whether \p OpVal is a node that is functionally compatible with
  /// the NodeType \p Opc.
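  /// For example, a VP_ADD whose mask and EVL match the root's is considered
  /// compatible with ISD::ADD.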
  bool match(SDValue OpVal, unsigned Opc) const {
    if (!OpVal->isVPOpcode())
      return OpVal->getOpcode() == Opc;

    auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
                                           !OpVal->getFlags().hasNoFPExcept());
    if (BaseOpc != Opc)
      return false;

    // Make sure the mask of OpVal is an all-true mask or is the same as
    // Root's.
    unsigned VPOpcode = OpVal->getOpcode();
    if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
      SDValue MaskOp = OpVal.getOperand(*MaskPos);
      if (RootMaskOp != MaskOp &&
          !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
        return false;
    }

    // Make sure the EVL of OpVal is the same as Root's.
    if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
        return false;
    return true;
  }

  // Specialize based on number of operands.
  // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
  // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
  // DAG.getNode(Opcode, DL, VT); }
  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
    return DAG.getNode(VPOpcode, DL, VT,
                       {Operand, RootMaskOp, RootVectorLenOp});
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, RootMaskOp, RootVectorLenOp});
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2, SDValue N3) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
                  SDNodeFlags Flags) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
    return DAG.getNode(VPOpcode, DL, VT,
                       {Operand, RootMaskOp, RootVectorLenOp}, Flags);
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2, SDNodeFlags Flags) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, RootMaskOp, RootVectorLenOp}, Flags);
  }

  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
                  SDValue N2, SDValue N3, SDNodeFlags Flags) {
    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
    return DAG.getNode(VPOpcode, DL, VT,
                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
  }

  bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                bool LegalOnly = false) const {
    unsigned VPOp = ISD::getVPForBaseOpcode(Op);
    return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper function zero
// extends the shorter of the pair so that they match. We provide an Offset so
// that we can create bitwidths that won't overflow.
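// For example, an 8-bit LHS and a 16-bit RHS with Offset = 1 are both
// extended to 17 bits.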
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zext(Bits);
  RHS = RHS.zext(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
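// For example, (select_cc lhs, rhs, T, F, cc), where T and F are the target's
// boolean true and false values, is treated as (setcc lhs, rhs, cc).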
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC  = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
      !TLI.isConstFalseVal(N.getOperand(3)))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
    return true;
  return false;
}

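// Return true if \p N is a constant splat of the all-ones mask for the given
// scalar type, e.g. a splat of 0xFF when ScalarTy is i8.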
static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}

// Determines if \p N is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
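// For example, a v4i32 BUILD_VECTOR whose operands are i64 constants is
// rejected, since accepting it would implicitly truncate the operands.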
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all constants, possibly mixed
// with undefs.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if this is an indexed load with an opaque target constant index.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDNode *N,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).

  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C2)
    return false;

  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C2APIntVal.getSignificantBits() > 64)
    return false;

  if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
    if (N0.hasOneUse())
      return false;

    const APInt &C1APIntVal = C1->getAPIntValue();
    const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
    if (CombinedValueIntVal.getSignificantBits() > 64)
      return false;
    const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

    for (SDNode *Node : N->uses()) {
      if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
        // Is x[offset2] already not a legal addressing mode? If so then
        // reassociating the constants breaks nothing (we test offset2 because
        // that's the one we hope to fold into the load or store).
        TargetLoweringBase::AddrMode AM;
        AM.HasBaseReg = true;
        AM.BaseOffs = C2APIntVal.getSExtValue();
        EVT VT = LoadStore->getMemoryVT();
        unsigned AS = LoadStore->getAddressSpace();
        Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          continue;

        // Would x[offset1+offset2] still be a legal addressing mode?
        AM.BaseOffs = CombinedValue;
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          return true;
      }
    }
  } else {
    if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
      if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
        return false;

    for (SDNode *Node : N->uses()) {
      auto *LoadStore = dyn_cast<MemSDNode>(Node);
      if (!LoadStore)
        return false;

      // Is x[offset2] a legal addressing mode? If so then reassociating the
      // constants breaks the address pattern.
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return false;
    }
    return true;
  }

  return false;
}

// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1,
                                               SDNodeFlags Flags) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);

  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
    if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
        return DAG.getNode(Opc, DL, VT, N00, OpNode);
      return SDValue();
    }
    if (TLI.isReassocProfitable(DAG, N0, N1)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDNodeFlags NewFlags;
      if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
          Flags.hasNoUnsignedWrap())
        NewFlags.setNoUnsignedWrap(true);
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
      return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
    }
  }

  // Check for repeated operand logic simplifications.
  if (Opc == ISD::AND || Opc == ISD::OR) {
    // (N00 & N01) & N00 --> N00 & N01
    // (N00 & N01) & N01 --> N00 & N01
    // (N00 | N01) | N00 --> N00 | N01
    // (N00 | N01) | N01 --> N00 | N01
    if (N1 == N00 || N1 == N01)
      return N0;
  }
  if (Opc == ISD::XOR) {
    // (N00 ^ N01) ^ N00 --> N01
    if (N1 == N00)
      return N01;
    // (N00 ^ N01) ^ N01 --> N00
    if (N1 == N01)
      return N00;
  }

  if (TLI.isReassocProfitable(DAG, N0, N1)) {
    if (N1 != N01) {
      // Reassociate if (op N00, N1) already exists.
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
        // If (op (op N00, N1), N01) already exists as well, we have to stop
        // reassociating here to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
      }
    }

    if (N1 != N00) {
      // Reassociate if (op N01, N1) already exists.
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
        // If (op (op N01, N1), N00) already exists as well, we have to stop
        // reassociating here to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
      }
    }

    // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to (OR/AND
    // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
    // predicate, or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
    // comparisons with the same predicate. This enables optimizations such as
    // the following:
    // CMP(A,C) || CMP(B,C) => CMP(MIN/MAX(A,B), C)
    // CMP(A,C) && CMP(B,C) => CMP(MIN/MAX(A,B), C)
    if (Opc == ISD::AND || Opc == ISD::OR) {
      if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
          N01->getOpcode() == ISD::SETCC) {
        ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
        ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
        ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
        if (CC1 == CC00 && CC1 != CC01) {
          SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
          return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
        }
        if (CC1 == CC01 && CC1 != CC00) {
          SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
          return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
        }
      }
    }
  }

  return SDValue();
}

// Try to reassociate commutative binops.
SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                    SDValue N1, SDNodeFlags Flags) {
  assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");

  // Floating-point reassociation is not allowed without loose FP math.
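  // (Specifically, the node must carry both the 'reassoc' and 'nsz'
  // fast-math flags.)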
  if (N0.getValueType().isFloatingPoint() ||
      N1.getValueType().isFloatingPoint())
    if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
      return SDValue();

  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
    return Combined;
  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
    return Combined;
  return SDValue();
}

// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
// Note that we only expect Flags to be passed from FP operations. For integer
// operations they need to be dropped.
SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
                                          const SDLoc &DL, EVT VT, SDValue N0,
                                          SDValue N1, SDNodeFlags Flags) {
  if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
      N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
      N0->hasOneUse() && N1->hasOneUse() &&
      TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
      TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
    SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
    return DAG.getNode(RedOpc, DL, VT,
                       DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
                                   N0.getOperand(0), N1.getOperand(0)));
  }
  return SDValue();
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode())
        AddToWorklistWithUsers(To[i].getNode());
    }
  }

  // Finally, if the node is now dead, remove it from the graph. The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
1444 return SDValue(N, 0);
1445}
1446
1447void DAGCombiner::
1448CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1449 // Replace the old value with the new one.
1450 ++NodesCombined;
1451 LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1452 dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1453
1454 // Replace all uses.
1455 DAG.ReplaceAllUsesOfValueWith(From: TLO.Old, To: TLO.New);
1456
1457 // Push the new node and any (possibly new) users onto the worklist.
1458 AddToWorklistWithUsers(N: TLO.New.getNode());
1459
1460 // Finally, if the node is now dead, remove it from the graph.
1461 recursivelyDeleteUnusedNodes(N: TLO.Old.getNode());
1462}
1463
1464/// Check the specified integer node value to see if it can be simplified or if
1465/// things it uses can be simplified by bit propagation. If so, return true.
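/// For example, if only the low byte of Op is ever demanded, a wider mask such
/// as (and X, 0xFFFF) feeding it may be narrowed or removed.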
1466bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1467 const APInt &DemandedElts,
1468 bool AssumeSingleUse) {
1469 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1470 KnownBits Known;
1471 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth: 0,
1472 AssumeSingleUse))
1473 return false;
1474
1475 // Revisit the node.
1476 AddToWorklist(N: Op.getNode());
1477
1478 CommitTargetLoweringOpt(TLO);
1479 return true;
1480}
1481
1482/// Check the specified vector node value to see if it can be simplified or
1483/// if things it uses can be simplified as it only uses some of the elements.
1484/// If so, return true.
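/// For example, if only lane 0 of Op is demanded, work that only feeds the
/// other lanes may be dropped.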
1485bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1486 const APInt &DemandedElts,
1487 bool AssumeSingleUse) {
1488 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1489 APInt KnownUndef, KnownZero;
1490 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedEltMask: DemandedElts, KnownUndef, KnownZero,
1491 TLO, Depth: 0, AssumeSingleUse))
1492 return false;
1493
1494 // Revisit the node.
1495 AddToWorklist(N: Op.getNode());
1496
1497 CommitTargetLoweringOpt(TLO);
1498 return true;
1499}
1500
1501void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1502 SDLoc DL(Load);
1503 EVT VT = Load->getValueType(ResNo: 0);
1504 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SDValue(ExtLoad, 0));
1505
1506 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1507 Trunc.dump(&DAG); dbgs() << '\n');
1508
1509 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: Trunc);
1510 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: SDValue(ExtLoad, 1));
1511
1512 AddToWorklist(N: Trunc.getNode());
1513 recursivelyDeleteUnusedNodes(N: Load);
1514}
1515
1516SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1517 Replace = false;
1518 SDLoc DL(Op);
1519 if (ISD::isUNINDEXEDLoad(N: Op.getNode())) {
1520 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
1521 EVT MemVT = LD->getMemoryVT();
1522 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1523 : LD->getExtensionType();
1524 Replace = true;
1525 return DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1526 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1527 MemVT, MMO: LD->getMemOperand());
1528 }
1529
1530 unsigned Opc = Op.getOpcode();
1531 switch (Opc) {
1532 default: break;
1533 case ISD::AssertSext:
1534 if (SDValue Op0 = SExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1535 return DAG.getNode(Opcode: ISD::AssertSext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1536 break;
1537 case ISD::AssertZext:
1538 if (SDValue Op0 = ZExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1539 return DAG.getNode(Opcode: ISD::AssertZext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1540 break;
1541 case ISD::Constant: {
1542 unsigned ExtOpc =
1543 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1544 return DAG.getNode(Opcode: ExtOpc, DL, VT: PVT, Operand: Op);
1545 }
1546 }
1547
1548 if (!TLI.isOperationLegal(Op: ISD::ANY_EXTEND, VT: PVT))
1549 return SDValue();
1550 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: PVT, Operand: Op);
1551}
1552
1553SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1554 if (!TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG, VT: PVT))
1555 return SDValue();
1556 EVT OldVT = Op.getValueType();
1557 SDLoc DL(Op);
1558 bool Replace = false;
1559 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1560 if (!NewOp.getNode())
1561 return SDValue();
1562 AddToWorklist(N: NewOp.getNode());
1563
1564 if (Replace)
1565 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1566 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: NewOp.getValueType(), N1: NewOp,
1567 N2: DAG.getValueType(OldVT));
1568}
1569
1570SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1571 EVT OldVT = Op.getValueType();
1572 SDLoc DL(Op);
1573 bool Replace = false;
1574 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1575 if (!NewOp.getNode())
1576 return SDValue();
1577 AddToWorklist(N: NewOp.getNode());
1578
1579 if (Replace)
1580 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1581 return DAG.getZeroExtendInReg(Op: NewOp, DL, VT: OldVT);
1582}
1583
1584/// Promote the specified integer binary operation if the target indicates it is
1585/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1586/// i32 since i16 instructions are longer.
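/// For example, an i16 add may be rewritten as
///   (i16 (trunc (add (anyext x to i32), (anyext y to i32))))
/// so the arithmetic itself executes in the wider, preferred type.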
1587SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1588 if (!LegalOperations)
1589 return SDValue();
1590
1591 EVT VT = Op.getValueType();
1592 if (VT.isVector() || !VT.isInteger())
1593 return SDValue();
1594
1595 // If operation type is 'undesirable', e.g. i16 on x86, consider
1596 // promoting it.
1597 unsigned Opc = Op.getOpcode();
1598 if (TLI.isTypeDesirableForOp(Opc, VT))
1599 return SDValue();
1600
1601 EVT PVT = VT;
1602 // Consult target whether it is a good idea to promote this operation and
1603 // what's the right type to promote it to.
1604 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1605 assert(PVT != VT && "Don't know what type to promote to!");
1606
1607 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1608
1609 bool Replace0 = false;
1610 SDValue N0 = Op.getOperand(i: 0);
1611 SDValue NN0 = PromoteOperand(Op: N0, PVT, Replace&: Replace0);
1612
1613 bool Replace1 = false;
1614 SDValue N1 = Op.getOperand(i: 1);
1615 SDValue NN1 = PromoteOperand(Op: N1, PVT, Replace&: Replace1);
1616 SDLoc DL(Op);
1617
1618 SDValue RV =
1619 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: NN0, N2: NN1));
1620
1621 // We are always replacing N0/N1's use in N and only need additional
1622 // replacements if there are additional uses.
1623 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1624 // (SDValue) here because the node may reference multiple values
1625 // (for example, the chain value of a load node).
1626 Replace0 &= !N0->hasOneUse();
1627 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1628
1629 // Combine Op here so it is preserved past replacements.
1630 CombineTo(N: Op.getNode(), Res: RV);
1631
1632 // If operands have a use ordering, make sure we deal with
1633 // predecessor first.
1634 if (Replace0 && Replace1 && N0->isPredecessorOf(N: N1.getNode())) {
1635 std::swap(a&: N0, b&: N1);
1636 std::swap(a&: NN0, b&: NN1);
1637 }
1638
1639 if (Replace0) {
1640 AddToWorklist(N: NN0.getNode());
1641 ReplaceLoadWithPromotedLoad(Load: N0.getNode(), ExtLoad: NN0.getNode());
1642 }
1643 if (Replace1) {
1644 AddToWorklist(N: NN1.getNode());
1645 ReplaceLoadWithPromotedLoad(Load: N1.getNode(), ExtLoad: NN1.getNode());
1646 }
1647 return Op;
1648 }
1649 return SDValue();
1650}
1651
1652/// Promote the specified integer shift operation if the target indicates it is
1653/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1654/// i32 since i16 instructions are longer.
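/// For example, (i16 (srl x, c)) may become
///   (i16 (trunc (srl (zext x to i32), c)))
/// since SRL needs its promoted operand zero-extended (SRA needs it
/// sign-extended).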
1655SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1656 if (!LegalOperations)
1657 return SDValue();
1658
1659 EVT VT = Op.getValueType();
1660 if (VT.isVector() || !VT.isInteger())
1661 return SDValue();
1662
1663 // If operation type is 'undesirable', e.g. i16 on x86, consider
1664 // promoting it.
1665 unsigned Opc = Op.getOpcode();
1666 if (TLI.isTypeDesirableForOp(Opc, VT))
1667 return SDValue();
1668
1669 EVT PVT = VT;
1670 // Consult target whether it is a good idea to promote this operation and
1671 // what's the right type to promote it to.
1672 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1673 assert(PVT != VT && "Don't know what type to promote to!");
1674
1675 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1676
1677 bool Replace = false;
1678 SDValue N0 = Op.getOperand(i: 0);
1679 if (Opc == ISD::SRA)
1680 N0 = SExtPromoteOperand(Op: N0, PVT);
1681 else if (Opc == ISD::SRL)
1682 N0 = ZExtPromoteOperand(Op: N0, PVT);
1683 else
1684 N0 = PromoteOperand(Op: N0, PVT, Replace);
1685
1686 if (!N0.getNode())
1687 return SDValue();
1688
1689 SDLoc DL(Op);
1690 SDValue N1 = Op.getOperand(i: 1);
1691 SDValue RV =
1692 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: N0, N2: N1));
1693
1694 if (Replace)
1695 ReplaceLoadWithPromotedLoad(Load: Op.getOperand(i: 0).getNode(), ExtLoad: N0.getNode());
1696
1697 // Deal with Op being deleted.
1698 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1699 return RV;
1700 }
1701 return SDValue();
1702}
1703
1704SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1705 if (!LegalOperations)
1706 return SDValue();
1707
1708 EVT VT = Op.getValueType();
1709 if (VT.isVector() || !VT.isInteger())
1710 return SDValue();
1711
1712 // If operation type is 'undesirable', e.g. i16 on x86, consider
1713 // promoting it.
1714 unsigned Opc = Op.getOpcode();
1715 if (TLI.isTypeDesirableForOp(Opc, VT))
1716 return SDValue();
1717
1718 EVT PVT = VT;
1719 // Consult target whether it is a good idea to promote this operation and
1720 // what's the right type to promote it to.
1721 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1722 assert(PVT != VT && "Don't know what type to promote to!");
1723 // fold (aext (aext x)) -> (aext x)
1724 // fold (aext (zext x)) -> (zext x)
1725 // fold (aext (sext x)) -> (sext x)
1726 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1727 return DAG.getNode(Opcode: Op.getOpcode(), DL: SDLoc(Op), VT, Operand: Op.getOperand(i: 0));
1728 }
1729 return SDValue();
1730}
1731
1732bool DAGCombiner::PromoteLoad(SDValue Op) {
1733 if (!LegalOperations)
1734 return false;
1735
1736 if (!ISD::isUNINDEXEDLoad(N: Op.getNode()))
1737 return false;
1738
1739 EVT VT = Op.getValueType();
1740 if (VT.isVector() || !VT.isInteger())
1741 return false;
1742
1743 // If operation type is 'undesirable', e.g. i16 on x86, consider
1744 // promoting it.
1745 unsigned Opc = Op.getOpcode();
1746 if (TLI.isTypeDesirableForOp(Opc, VT))
1747 return false;
1748
1749 EVT PVT = VT;
1750 // Consult target whether it is a good idea to promote this operation and
1751 // what's the right type to promote it to.
1752 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1753 assert(PVT != VT && "Don't know what type to promote to!");
1754
1755 SDLoc DL(Op);
1756 SDNode *N = Op.getNode();
1757 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
1758 EVT MemVT = LD->getMemoryVT();
1759 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1760 : LD->getExtensionType();
1761 SDValue NewLD = DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1762 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1763 MemVT, MMO: LD->getMemOperand());
1764 SDValue Result = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewLD);
1765
1766 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1767 Result.dump(&DAG); dbgs() << '\n');
1768
1769 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
1770 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: NewLD.getValue(R: 1));
1771
1772 AddToWorklist(N: Result.getNode());
1773 recursivelyDeleteUnusedNodes(N);
1774 return true;
1775 }
1776
1777 return false;
1778}
1779
1780/// Recursively delete a node which has no uses and any operands for
1781/// which it is the only use.
1782///
1783/// Note that this both deletes the nodes and removes them from the worklist.
1784/// It also adds any nodes which have had a user deleted to the worklist as
1785/// they may now have only one use and be subject to other combines.
1786bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1787 if (!N->use_empty())
1788 return false;
1789
1790 SmallSetVector<SDNode *, 16> Nodes;
1791 Nodes.insert(X: N);
1792 do {
1793 N = Nodes.pop_back_val();
1794 if (!N)
1795 continue;
1796
1797 if (N->use_empty()) {
1798 for (const SDValue &ChildN : N->op_values())
1799 Nodes.insert(X: ChildN.getNode());
1800
1801 removeFromWorklist(N);
1802 DAG.DeleteNode(N);
1803 } else {
1804 AddToWorklist(N);
1805 }
1806 } while (!Nodes.empty());
1807 return true;
1808}
1809
1810//===----------------------------------------------------------------------===//
1811// Main DAG Combiner implementation
1812//===----------------------------------------------------------------------===//
1813
1814void DAGCombiner::Run(CombineLevel AtLevel) {
1815 // Set the instance variables, so that the various visit routines may use them.
1816 Level = AtLevel;
1817 LegalDAG = Level >= AfterLegalizeDAG;
1818 LegalOperations = Level >= AfterLegalizeVectorOps;
1819 LegalTypes = Level >= AfterLegalizeTypes;
1820
1821 WorklistInserter AddNodes(*this);
1822
1823 // Add all the dag nodes to the worklist.
1824 //
1825 // Note: Not all nodes are added to the PruningList here. Only nodes that
1826 // have no uses can be deleted at this point, and all other nodes that would
1827 // otherwise be added to the worklist by the first call to
1828 // getNextWorklistEntry are already present in it.
1829 for (SDNode &Node : DAG.allnodes())
1830 AddToWorklist(N: &Node, /* IsCandidateForPruning */ Node.use_empty());
1831
1832 // Create a dummy node (which is not added to allnodes) that adds a reference
1833 // to the root node, preventing it from being deleted, and tracking any
1834 // changes of the root.
1835 HandleSDNode Dummy(DAG.getRoot());
1836
1837 // While we have a valid worklist entry node, try to combine it.
1838 while (SDNode *N = getNextWorklistEntry()) {
1839 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1840 // N is deleted from the DAG, since they too may now be dead or may have a
1841 // reduced number of uses, allowing other xforms.
1842 if (recursivelyDeleteUnusedNodes(N))
1843 continue;
1844
1845 WorklistRemover DeadNodes(*this);
1846
1847 // If this combine is running after legalizing the DAG, re-legalize any
1848 // nodes pulled off the worklist.
1849 if (LegalDAG) {
1850 SmallSetVector<SDNode *, 16> UpdatedNodes;
1851 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1852
1853 for (SDNode *LN : UpdatedNodes)
1854 AddToWorklistWithUsers(N: LN);
1855
1856 if (!NIsValid)
1857 continue;
1858 }
1859
1860 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1861
1862 // Add any operands of the new node which have not yet been combined to the
1863 // worklist as well. Because the worklist uniques things already, this
1864 // won't repeatedly process the same operand.
1865 for (const SDValue &ChildN : N->op_values())
1866 if (!CombinedNodes.count(Ptr: ChildN.getNode()))
1867 AddToWorklist(N: ChildN.getNode());
1868
1869 CombinedNodes.insert(Ptr: N);
1870 SDValue RV = combine(N);
1871
1872 if (!RV.getNode())
1873 continue;
1874
1875 ++NodesCombined;
1876
1877 // If we get back the same node we passed in, rather than a new node or
1878 // zero, we know that the node must have defined multiple values and
1879 // CombineTo was used. Since CombineTo takes care of the worklist
1880 // mechanics for us, we have no work to do in this case.
1881 if (RV.getNode() == N)
1882 continue;
1883
1884 assert(N->getOpcode() != ISD::DELETED_NODE &&
1885 RV.getOpcode() != ISD::DELETED_NODE &&
1886 "Node was deleted but visit returned new node!");
1887
1888 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1889
1890 if (N->getNumValues() == RV->getNumValues())
1891 DAG.ReplaceAllUsesWith(From: N, To: RV.getNode());
1892 else {
1893 assert(N->getValueType(0) == RV.getValueType() &&
1894 N->getNumValues() == 1 && "Type mismatch");
1895 DAG.ReplaceAllUsesWith(From: N, To: &RV);
1896 }
1897
1898 // Push the new node and any users onto the worklist. Omit this if the
1899 // new node is the EntryToken (e.g. if a store managed to get optimized
1900 // out), because re-visiting the EntryToken and its users will not uncover
1901 // any additional opportunities, but there may be a large number of such
1902 // users, potentially causing compile time explosion.
1903 if (RV.getOpcode() != ISD::EntryToken)
1904 AddToWorklistWithUsers(N: RV.getNode());
1905
1906 // Finally, if the node is now dead, remove it from the graph. The node
1907 // may not be dead if the replacement process recursively simplified to
1908 // something else needing this node. This will also take care of adding any
1909 // operands which have lost a user to the worklist.
1910 recursivelyDeleteUnusedNodes(N);
1911 }
1912
1913 // If the root changed (e.g. it was a dead load), update the root.
1914 DAG.setRoot(Dummy.getValue());
1915 DAG.RemoveDeadNodes();
1916}
1917
1918SDValue DAGCombiner::visit(SDNode *N) {
1919 // clang-format off
1920 switch (N->getOpcode()) {
1921 default: break;
1922 case ISD::TokenFactor: return visitTokenFactor(N);
1923 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1924 case ISD::ADD: return visitADD(N);
1925 case ISD::SUB: return visitSUB(N);
1926 case ISD::SADDSAT:
1927 case ISD::UADDSAT: return visitADDSAT(N);
1928 case ISD::SSUBSAT:
1929 case ISD::USUBSAT: return visitSUBSAT(N);
1930 case ISD::ADDC: return visitADDC(N);
1931 case ISD::SADDO:
1932 case ISD::UADDO: return visitADDO(N);
1933 case ISD::SUBC: return visitSUBC(N);
1934 case ISD::SSUBO:
1935 case ISD::USUBO: return visitSUBO(N);
1936 case ISD::ADDE: return visitADDE(N);
1937 case ISD::UADDO_CARRY: return visitUADDO_CARRY(N);
1938 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1939 case ISD::SUBE: return visitSUBE(N);
1940 case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N);
1941 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1942 case ISD::SMULFIX:
1943 case ISD::SMULFIXSAT:
1944 case ISD::UMULFIX:
1945 case ISD::UMULFIXSAT: return visitMULFIX(N);
1946 case ISD::MUL: return visitMUL(N);
1947 case ISD::SDIV: return visitSDIV(N);
1948 case ISD::UDIV: return visitUDIV(N);
1949 case ISD::SREM:
1950 case ISD::UREM: return visitREM(N);
1951 case ISD::MULHU: return visitMULHU(N);
1952 case ISD::MULHS: return visitMULHS(N);
1953 case ISD::AVGFLOORS:
1954 case ISD::AVGFLOORU:
1955 case ISD::AVGCEILS:
1956 case ISD::AVGCEILU: return visitAVG(N);
1957 case ISD::ABDS:
1958 case ISD::ABDU: return visitABD(N);
1959 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1960 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1961 case ISD::SMULO:
1962 case ISD::UMULO: return visitMULO(N);
1963 case ISD::SMIN:
1964 case ISD::SMAX:
1965 case ISD::UMIN:
1966 case ISD::UMAX: return visitIMINMAX(N);
1967 case ISD::AND: return visitAND(N);
1968 case ISD::OR: return visitOR(N);
1969 case ISD::XOR: return visitXOR(N);
1970 case ISD::SHL: return visitSHL(N);
1971 case ISD::SRA: return visitSRA(N);
1972 case ISD::SRL: return visitSRL(N);
1973 case ISD::ROTR:
1974 case ISD::ROTL: return visitRotate(N);
1975 case ISD::FSHL:
1976 case ISD::FSHR: return visitFunnelShift(N);
1977 case ISD::SSHLSAT:
1978 case ISD::USHLSAT: return visitSHLSAT(N);
1979 case ISD::ABS: return visitABS(N);
1980 case ISD::BSWAP: return visitBSWAP(N);
1981 case ISD::BITREVERSE: return visitBITREVERSE(N);
1982 case ISD::CTLZ: return visitCTLZ(N);
1983 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1984 case ISD::CTTZ: return visitCTTZ(N);
1985 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1986 case ISD::CTPOP: return visitCTPOP(N);
1987 case ISD::SELECT: return visitSELECT(N);
1988 case ISD::VSELECT: return visitVSELECT(N);
1989 case ISD::SELECT_CC: return visitSELECT_CC(N);
1990 case ISD::SETCC: return visitSETCC(N);
1991 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1992 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1993 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
1994 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
1995 case ISD::AssertSext:
1996 case ISD::AssertZext: return visitAssertExt(N);
1997 case ISD::AssertAlign: return visitAssertAlign(N);
1998 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
1999 case ISD::SIGN_EXTEND_VECTOR_INREG:
2000 case ISD::ZERO_EXTEND_VECTOR_INREG:
2001 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
2002 case ISD::TRUNCATE: return visitTRUNCATE(N);
2003 case ISD::BITCAST: return visitBITCAST(N);
2004 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
2005 case ISD::FADD: return visitFADD(N);
2006 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
2007 case ISD::FSUB: return visitFSUB(N);
2008 case ISD::FMUL: return visitFMUL(N);
2009 case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
2010 case ISD::FMAD: return visitFMAD(N);
2011 case ISD::FDIV: return visitFDIV(N);
2012 case ISD::FREM: return visitFREM(N);
2013 case ISD::FSQRT: return visitFSQRT(N);
2014 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
2015 case ISD::FPOW: return visitFPOW(N);
2016 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
2017 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
2018 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
2019 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
2020 case ISD::LRINT:
2021 case ISD::LLRINT: return visitXRINT(N);
2022 case ISD::FP_ROUND: return visitFP_ROUND(N);
2023 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
2024 case ISD::FNEG: return visitFNEG(N);
2025 case ISD::FABS: return visitFABS(N);
2026 case ISD::FFLOOR: return visitFFLOOR(N);
2027 case ISD::FMINNUM:
2028 case ISD::FMAXNUM:
2029 case ISD::FMINIMUM:
2030 case ISD::FMAXIMUM: return visitFMinMax(N);
2031 case ISD::FCEIL: return visitFCEIL(N);
2032 case ISD::FTRUNC: return visitFTRUNC(N);
2033 case ISD::FFREXP: return visitFFREXP(N);
2034 case ISD::BRCOND: return visitBRCOND(N);
2035 case ISD::BR_CC: return visitBR_CC(N);
2036 case ISD::LOAD: return visitLOAD(N);
2037 case ISD::STORE: return visitSTORE(N);
2038 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
2039 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
2040 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
2041 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
2042 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
2043 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
2044 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
2045 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
2046 case ISD::MGATHER: return visitMGATHER(N);
2047 case ISD::MLOAD: return visitMLOAD(N);
2048 case ISD::MSCATTER: return visitMSCATTER(N);
2049 case ISD::MSTORE: return visitMSTORE(N);
2050 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
2051 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
2052 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
2053 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
2054 case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
2055 case ISD::FREEZE: return visitFREEZE(N);
2056 case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
2057 case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
2058 case ISD::VECREDUCE_FADD:
2059 case ISD::VECREDUCE_FMUL:
2060 case ISD::VECREDUCE_ADD:
2061 case ISD::VECREDUCE_MUL:
2062 case ISD::VECREDUCE_AND:
2063 case ISD::VECREDUCE_OR:
2064 case ISD::VECREDUCE_XOR:
2065 case ISD::VECREDUCE_SMAX:
2066 case ISD::VECREDUCE_SMIN:
2067 case ISD::VECREDUCE_UMAX:
2068 case ISD::VECREDUCE_UMIN:
2069 case ISD::VECREDUCE_FMAX:
2070 case ISD::VECREDUCE_FMIN:
2071 case ISD::VECREDUCE_FMAXIMUM:
2072 case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
2073#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
2074#include "llvm/IR/VPIntrinsics.def"
2075 return visitVPOp(N);
2076 }
2077 // clang-format on
2078 return SDValue();
2079}
2080
2081SDValue DAGCombiner::combine(SDNode *N) {
2082 if (!DebugCounter::shouldExecute(CounterName: DAGCombineCounter))
2083 return SDValue();
2084
2085 SDValue RV;
2086 if (!DisableGenericCombines)
2087 RV = visit(N);
2088
2089 // If nothing happened, try a target-specific DAG combine.
2090 if (!RV.getNode()) {
2091 assert(N->getOpcode() != ISD::DELETED_NODE &&
2092 "Node was deleted but visit returned NULL!");
2093
2094 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2095 TLI.hasTargetDAGCombine(NT: (ISD::NodeType)N->getOpcode())) {
2096
2097 // Expose the DAG combiner to the target combiner impls.
2098 TargetLowering::DAGCombinerInfo
2099 DagCombineInfo(DAG, Level, false, this);
2100
2101 RV = TLI.PerformDAGCombine(N, DCI&: DagCombineInfo);
2102 }
2103 }
2104
2105 // If nothing happened still, try promoting the operation.
2106 if (!RV.getNode()) {
2107 switch (N->getOpcode()) {
2108 default: break;
2109 case ISD::ADD:
2110 case ISD::SUB:
2111 case ISD::MUL:
2112 case ISD::AND:
2113 case ISD::OR:
2114 case ISD::XOR:
2115 RV = PromoteIntBinOp(Op: SDValue(N, 0));
2116 break;
2117 case ISD::SHL:
2118 case ISD::SRA:
2119 case ISD::SRL:
2120 RV = PromoteIntShiftOp(Op: SDValue(N, 0));
2121 break;
2122 case ISD::SIGN_EXTEND:
2123 case ISD::ZERO_EXTEND:
2124 case ISD::ANY_EXTEND:
2125 RV = PromoteExtend(Op: SDValue(N, 0));
2126 break;
2127 case ISD::LOAD:
2128 if (PromoteLoad(Op: SDValue(N, 0)))
2129 RV = SDValue(N, 0);
2130 break;
2131 }
2132 }
2133
2134 // If N is a commutative binary node, try to eliminate it if the commuted
2135 // version is already present in the DAG.
2136 if (!RV.getNode() && TLI.isCommutativeBinOp(Opcode: N->getOpcode())) {
2137 SDValue N0 = N->getOperand(Num: 0);
2138 SDValue N1 = N->getOperand(Num: 1);
2139
2140 // Constant operands are canonicalized to RHS.
2141 if (N0 != N1 && (isa<ConstantSDNode>(Val: N0) || !isa<ConstantSDNode>(Val: N1))) {
2142 SDValue Ops[] = {N1, N0};
2143 SDNode *CSENode = DAG.getNodeIfExists(Opcode: N->getOpcode(), VTList: N->getVTList(), Ops,
2144 Flags: N->getFlags());
2145 if (CSENode)
2146 return SDValue(CSENode, 0);
2147 }
2148 }
2149
2150 return RV;
2151}
2152
2153/// Given a node, return its input chain if it has one, otherwise return a null
2154/// sd operand.
2155static SDValue getInputChainForNode(SDNode *N) {
2156 if (unsigned NumOps = N->getNumOperands()) {
2157 if (N->getOperand(Num: 0).getValueType() == MVT::Other)
2158 return N->getOperand(Num: 0);
2159 if (N->getOperand(Num: NumOps-1).getValueType() == MVT::Other)
2160 return N->getOperand(Num: NumOps-1);
2161 for (unsigned i = 1; i < NumOps-1; ++i)
2162 if (N->getOperand(Num: i).getValueType() == MVT::Other)
2163 return N->getOperand(Num: i);
2164 }
2165 return SDValue();
2166}
2167
2168SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2169 // If N has two operands, where one has an input chain equal to the other,
2170 // the 'other' chain is redundant.
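// For example, TokenFactor(Ld.getValue(1), Ch) where Ch is also the input
// chain of the load Ld is equivalent to just Ld.getValue(1).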
2171 if (N->getNumOperands() == 2) {
2172 if (getInputChainForNode(N: N->getOperand(Num: 0).getNode()) == N->getOperand(Num: 1))
2173 return N->getOperand(Num: 0);
2174 if (getInputChainForNode(N: N->getOperand(Num: 1).getNode()) == N->getOperand(Num: 0))
2175 return N->getOperand(Num: 1);
2176 }
2177
2178 // Don't simplify token factors if optnone.
2179 if (OptLevel == CodeGenOptLevel::None)
2180 return SDValue();
2181
2182 // Don't simplify the token factor if the node itself has too many operands.
2183 if (N->getNumOperands() > TokenFactorInlineLimit)
2184 return SDValue();
2185
2186 // If the sole user is a token factor, we should make sure we have a
2187 // chance to merge them together. This prevents TF chains from inhibiting
2188 // optimizations.
2189 if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
2190 AddToWorklist(N: *(N->use_begin()));
2191
2192 SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
2193 SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
2194 SmallPtrSet<SDNode*, 16> SeenOps;
2195 bool Changed = false; // If we should replace this token factor.
2196
2197 // Start out with this token factor.
2198 TFs.push_back(Elt: N);
2199
2200 // Iterate through token factors. The TFs list grows when new token factors
2201 // are encountered.
2202 for (unsigned i = 0; i < TFs.size(); ++i) {
2203 // Limit number of nodes to inline, to avoid quadratic compile times.
2204 // We have to add the outstanding Token Factors to Ops, otherwise we might
2205 // drop Ops from the resulting Token Factors.
2206 if (Ops.size() > TokenFactorInlineLimit) {
2207 for (unsigned j = i; j < TFs.size(); j++)
2208 Ops.emplace_back(Args&: TFs[j], Args: 0);
2209 // Drop unprocessed Token Factors from TFs, so we do not add them to the
2210 // combiner worklist later.
2211 TFs.resize(N: i);
2212 break;
2213 }
2214
2215 SDNode *TF = TFs[i];
2216 // Check each of the operands.
2217 for (const SDValue &Op : TF->op_values()) {
2218 switch (Op.getOpcode()) {
2219 case ISD::EntryToken:
2220 // Entry tokens don't need to be added to the list. They are
2221 // redundant.
2222 Changed = true;
2223 break;
2224
2225 case ISD::TokenFactor:
2226 if (Op.hasOneUse() && !is_contained(Range&: TFs, Element: Op.getNode())) {
2227 // Queue up for processing.
2228 TFs.push_back(Elt: Op.getNode());
2229 Changed = true;
2230 break;
2231 }
2232 [[fallthrough]];
2233
2234 default:
2235 // Only add if it isn't already in the list.
2236 if (SeenOps.insert(Ptr: Op.getNode()).second)
2237 Ops.push_back(Elt: Op);
2238 else
2239 Changed = true;
2240 break;
2241 }
2242 }
2243 }
2244
2245 // Re-visit inlined Token Factors, to clean them up in case they have been
2246 // removed. Skip the first Token Factor, as this is the current node.
2247 for (unsigned i = 1, e = TFs.size(); i < e; i++)
2248 AddToWorklist(N: TFs[i]);
2249
2250 // Remove Nodes that are chained to another node in the list. Do so
2251 // by walking up chains breadth-first, stopping when we've seen
2252 // another operand. In general we must climb to the EntryNode, but we can exit
2253 // early if we find all remaining work is associated with just one operand as
2254 // no further pruning is possible.
2255
2256 // List of nodes to search through and original Ops from which they originate.
2257 SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2258 SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2259 SmallPtrSet<SDNode *, 16> SeenChains;
2260 bool DidPruneOps = false;
2261
2262 unsigned NumLeftToConsider = 0;
2263 for (const SDValue &Op : Ops) {
2264 Worklist.push_back(Elt: std::make_pair(x: Op.getNode(), y: NumLeftToConsider++));
2265 OpWorkCount.push_back(Elt: 1);
2266 }
2267
2268 auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2269 // If this is an Op, we can remove the op from the list. Re-mark any
2270 // searches associated with it as coming from the current OpNumber.
2271 if (SeenOps.contains(Ptr: Op)) {
2272 Changed = true;
2273 DidPruneOps = true;
2274 unsigned OrigOpNumber = 0;
2275 while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2276 OrigOpNumber++;
2277 assert((OrigOpNumber != Ops.size()) &&
2278 "expected to find TokenFactor Operand");
2279 // Re-mark worklist from OrigOpNumber to OpNumber
2280 for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2281 if (Worklist[i].second == OrigOpNumber) {
2282 Worklist[i].second = OpNumber;
2283 }
2284 }
2285 OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2286 OpWorkCount[OrigOpNumber] = 0;
2287 NumLeftToConsider--;
2288 }
2289 // Add if it's a new chain
2290 if (SeenChains.insert(Ptr: Op).second) {
2291 OpWorkCount[OpNumber]++;
2292 Worklist.push_back(Elt: std::make_pair(x&: Op, y&: OpNumber));
2293 }
2294 };
2295
2296 for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2297 // We need to consider at least 2 Ops to prune.
2298 if (NumLeftToConsider <= 1)
2299 break;
2300 auto CurNode = Worklist[i].first;
2301 auto CurOpNumber = Worklist[i].second;
2302 assert((OpWorkCount[CurOpNumber] > 0) &&
2303 "Node should not appear in worklist");
2304 switch (CurNode->getOpcode()) {
2305 case ISD::EntryToken:
2306 // Hitting EntryToken is the only way for the search to terminate without
2307 // hitting another operand's search. Prevent us from marking this operand
2308 // as considered.
2310 NumLeftToConsider++;
2311 break;
2312 case ISD::TokenFactor:
2313 for (const SDValue &Op : CurNode->op_values())
2314 AddToWorklist(i, Op.getNode(), CurOpNumber);
2315 break;
2316 case ISD::LIFETIME_START:
2317 case ISD::LIFETIME_END:
2318 case ISD::CopyFromReg:
2319 case ISD::CopyToReg:
2320 AddToWorklist(i, CurNode->getOperand(Num: 0).getNode(), CurOpNumber);
2321 break;
2322 default:
2323 if (auto *MemNode = dyn_cast<MemSDNode>(Val: CurNode))
2324 AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2325 break;
2326 }
2327 OpWorkCount[CurOpNumber]--;
2328 if (OpWorkCount[CurOpNumber] == 0)
2329 NumLeftToConsider--;
2330 }
2331
2332 // If we've changed things around then replace token factor.
2333 if (Changed) {
2334 SDValue Result;
2335 if (Ops.empty()) {
2336 // The entry token is the only possible outcome.
2337 Result = DAG.getEntryNode();
2338 } else {
2339 if (DidPruneOps) {
2340 SmallVector<SDValue, 8> PrunedOps;
2342 for (const SDValue &Op : Ops) {
2343 if (SeenChains.count(Ptr: Op.getNode()) == 0)
2344 PrunedOps.push_back(Elt: Op);
2345 }
2346 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: PrunedOps);
2347 } else {
2348 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: Ops);
2349 }
2350 }
2351 return Result;
2352 }
2353 return SDValue();
2354}
2355
2356/// MERGE_VALUES can always be eliminated.
2357SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2358 WorklistRemover DeadNodes(*this);
2359 // Replacing results may cause a different MERGE_VALUES to suddenly
2360 // be CSE'd with N, and carry its uses with it. Iterate until no
2361 // uses remain, to ensure that the node can be safely deleted.
2362 // First add the users of this node to the work list so that they
2363 // can be tried again once they have new operands.
2364 AddUsersToWorklist(N);
2365 do {
2366 // Do as a single replacement to avoid rewalking use lists.
2367 SmallVector<SDValue, 8> Ops;
2368 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2369 Ops.push_back(Elt: N->getOperand(Num: i));
2370 DAG.ReplaceAllUsesWith(From: N, To: Ops.data());
2371 } while (!N->use_empty());
2372 deleteAndRecombine(N);
2373 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2374}
2375
2376/// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2377/// ConstantSDNode pointer; otherwise return nullptr.
2378static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2379 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N);
2380 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2381}
2382
2383// isTruncateOf - If N is a truncate of some other value, return true and
2384// record the value being truncated in Op and which of Op's bits are zero/one
2385// in Known. This function computes KnownBits to avoid a duplicated call to
2386// computeKnownBits in the caller.
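// Besides an actual TRUNCATE, this also recognizes (setcc X, 0, ne) where all
// bits of X other than bit 0 are known to be zero, since such a setcc behaves
// like a truncation of X to i1.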
2387static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2388 KnownBits &Known) {
2389 if (N->getOpcode() == ISD::TRUNCATE) {
2390 Op = N->getOperand(Num: 0);
2391 Known = DAG.computeKnownBits(Op);
2392 return true;
2393 }
2394
2395 if (N.getOpcode() != ISD::SETCC ||
2396 N.getValueType().getScalarType() != MVT::i1 ||
2397 cast<CondCodeSDNode>(Val: N.getOperand(i: 2))->get() != ISD::SETNE)
2398 return false;
2399
2400 SDValue Op0 = N->getOperand(Num: 0);
2401 SDValue Op1 = N->getOperand(Num: 1);
2402 assert(Op0.getValueType() == Op1.getValueType());
2403
2404 if (isNullOrNullSplat(V: Op0))
2405 Op = Op1;
2406 else if (isNullOrNullSplat(V: Op1))
2407 Op = Op0;
2408 else
2409 return false;
2410
2411 Known = DAG.computeKnownBits(Op);
2412
2413 return (Known.Zero | 1).isAllOnes();
2414}
2415
2416/// Return true if 'Use' is a load or a store that uses N as its base pointer
2417/// and that N may be folded in the load / store addressing mode.
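/// For example, if N is (add BasePtr, 16) and 'Use' is a load based on N, the
/// fold is possible when the target reports a [reg + 16] addressing mode as
/// legal for the load's memory type and address space.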
2418static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2419 const TargetLowering &TLI) {
2420 EVT VT;
2421 unsigned AS;
2422
2423 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: Use)) {
2424 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2425 return false;
2426 VT = LD->getMemoryVT();
2427 AS = LD->getAddressSpace();
2428 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: Use)) {
2429 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2430 return false;
2431 VT = ST->getMemoryVT();
2432 AS = ST->getAddressSpace();
2433 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: Use)) {
2434 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2435 return false;
2436 VT = LD->getMemoryVT();
2437 AS = LD->getAddressSpace();
2438 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: Use)) {
2439 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2440 return false;
2441 VT = ST->getMemoryVT();
2442 AS = ST->getAddressSpace();
2443 } else {
2444 return false;
2445 }
2446
2447 TargetLowering::AddrMode AM;
2448 if (N->getOpcode() == ISD::ADD) {
2449 AM.HasBaseReg = true;
2450 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2451 if (Offset)
2452 // [reg +/- imm]
2453 AM.BaseOffs = Offset->getSExtValue();
2454 else
2455 // [reg +/- reg]
2456 AM.Scale = 1;
2457 } else if (N->getOpcode() == ISD::SUB) {
2458 AM.HasBaseReg = true;
2459 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2460 if (Offset)
2461 // [reg +/- imm]
2462 AM.BaseOffs = -Offset->getSExtValue();
2463 else
2464 // [reg +/- reg]
2465 AM.Scale = 1;
2466 } else {
2467 return false;
2468 }
2469
2470 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM,
2471 Ty: VT.getTypeForEVT(Context&: *DAG.getContext()), AddrSpace: AS);
2472}
2473
2474/// This inverts a canonicalization in IR that replaces a variable select arm
2475/// with an identity constant. Codegen improves if we re-use the variable
2476/// operand rather than load a constant. This can also be converted into a
2477/// masked vector operation if the target supports it.
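/// For example: add X, (vselect Cond, 0, Y) --> vselect Cond, X, (add X, Y)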
2478static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2479 bool ShouldCommuteOperands) {
2480 // Match a select as operand 1. The identity constant that we are looking for
2481 // is only valid as operand 1 of a non-commutative binop.
2482 SDValue N0 = N->getOperand(Num: 0);
2483 SDValue N1 = N->getOperand(Num: 1);
2484 if (ShouldCommuteOperands)
2485 std::swap(a&: N0, b&: N1);
2486
2487 // TODO: Should this apply to scalar select too?
2488 if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2489 return SDValue();
2490
2491 // We can't hoist all instructions because of immediate UB (not speculatable).
2492 // For example div/rem by zero.
2493 if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2494 return SDValue();
2495
2496 unsigned Opcode = N->getOpcode();
2497 EVT VT = N->getValueType(ResNo: 0);
2498 SDValue Cond = N1.getOperand(i: 0);
2499 SDValue TVal = N1.getOperand(i: 1);
2500 SDValue FVal = N1.getOperand(i: 2);
2501
2502 // This transform increases uses of N0, so freeze it to be safe.
2503 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2504 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2505 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: TVal, OperandNo: OpNo)) {
2506 SDValue F0 = DAG.getFreeze(V: N0);
2507 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: FVal, Flags: N->getFlags());
2508 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: F0, RHS: NewBO);
2509 }
2510 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2511 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: FVal, OperandNo: OpNo)) {
2512 SDValue F0 = DAG.getFreeze(V: N0);
2513 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: TVal, Flags: N->getFlags());
2514 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: NewBO, RHS: F0);
2515 }
2516
2517 return SDValue();
2518}
2519
2520SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2521 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2522 "Unexpected binary operator");
2523
2524 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2525 auto BinOpcode = BO->getOpcode();
2526 EVT VT = BO->getValueType(ResNo: 0);
2527 if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2528 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: false))
2529 return Sel;
2530
2531 if (TLI.isCommutativeBinOp(Opcode: BO->getOpcode()))
2532 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: true))
2533 return Sel;
2534 }
2535
2536 // Don't do this unless the old select is going away. We want to eliminate the
2537 // binary operator, not replace a binop with a select.
2538 // TODO: Handle ISD::SELECT_CC.
2539 unsigned SelOpNo = 0;
2540 SDValue Sel = BO->getOperand(Num: 0);
2541 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2542 SelOpNo = 1;
2543 Sel = BO->getOperand(Num: 1);
2544
2545 // Peek through trunc to shift amount type.
2546 if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2547 BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2548 // This is valid when the truncated bits of x are already zero.
2549 SDValue Op;
2550 KnownBits Known;
2551 if (isTruncateOf(DAG, N: Sel, Op, Known) &&
2552 Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2553 Sel = Op;
2554 }
2555 }
2556
2557 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2558 return SDValue();
2559
2560 SDValue CT = Sel.getOperand(i: 1);
2561 if (!isConstantOrConstantVector(N: CT, NoOpaques: true) &&
2562 !DAG.isConstantFPBuildVectorOrConstantFP(N: CT))
2563 return SDValue();
2564
2565 SDValue CF = Sel.getOperand(i: 2);
2566 if (!isConstantOrConstantVector(N: CF, NoOpaques: true) &&
2567 !DAG.isConstantFPBuildVectorOrConstantFP(N: CF))
2568 return SDValue();
2569
2570 // Bail out if any constants are opaque because we can't constant fold those.
2571 // The exception is "and" and "or" with either 0 or -1, in which case we can
2572 // propagate the non-constant operand into the select. I.e.:
2573 // and (select Cond, 0, -1), X --> select Cond, 0, X
2574 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2575 bool CanFoldNonConst =
2576 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2577 ((isNullOrNullSplat(V: CT) && isAllOnesOrAllOnesSplat(V: CF)) ||
2578 (isNullOrNullSplat(V: CF) && isAllOnesOrAllOnesSplat(V: CT)));
2579
2580 SDValue CBO = BO->getOperand(Num: SelOpNo ^ 1);
2581 if (!CanFoldNonConst &&
2582 !isConstantOrConstantVector(N: CBO, NoOpaques: true) &&
2583 !DAG.isConstantFPBuildVectorOrConstantFP(N: CBO))
2584 return SDValue();
2585
2586 SDLoc DL(Sel);
2587 SDValue NewCT, NewCF;
2588
2589 if (CanFoldNonConst) {
2590 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2591 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CT)) ||
2592 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CT)))
2593 NewCT = CT;
2594 else
2595 NewCT = CBO;
2596
2597 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CF)) ||
2598 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CF)))
2599 NewCF = CF;
2600 else
2601 NewCF = CBO;
2602 } else {
2603 // We have a select-of-constants followed by a binary operator with a
2604 // constant. Eliminate the binop by pulling the constant math into the
2605 // select. Example:
2606 // add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2607 NewCT = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CT})
2608 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CT, CBO});
2609 if (!NewCT)
2610 return SDValue();
2611
2612 NewCF = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CF})
2613 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CF, CBO});
2614 if (!NewCF)
2615 return SDValue();
2616 }
2617
2618 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: Sel.getOperand(i: 0), LHS: NewCT, RHS: NewCF);
2619 SelectOp->setFlags(BO->getFlags());
2620 return SelectOp;
2621}
2622
2623static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2624 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2625 "Expecting add or sub");
2626
2627 // Match a constant operand and a zext operand for the math instruction:
2628 // add Z, C
2629 // sub C, Z
2630 bool IsAdd = N->getOpcode() == ISD::ADD;
2631 SDValue C = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2632 SDValue Z = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2633 auto *CN = dyn_cast<ConstantSDNode>(Val&: C);
2634 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2635 return SDValue();
2636
2637 // Match the zext operand as a setcc of a boolean.
2638 if (Z.getOperand(i: 0).getOpcode() != ISD::SETCC ||
2639 Z.getOperand(i: 0).getValueType() != MVT::i1)
2640 return SDValue();
2641
2642 // Match the compare as: setcc (X & 1), 0, eq.
2643 SDValue SetCC = Z.getOperand(i: 0);
2644 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC->getOperand(Num: 2))->get();
2645 if (CC != ISD::SETEQ || !isNullConstant(V: SetCC.getOperand(i: 1)) ||
2646 SetCC.getOperand(i: 0).getOpcode() != ISD::AND ||
2647 !isOneConstant(V: SetCC.getOperand(i: 0).getOperand(i: 1)))
2648 return SDValue();
2649
2650 // We are adding/subtracting a constant and an inverted low bit. Turn that
2651 // into a subtract/add of the low bit with incremented/decremented constant:
2652 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2653 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2654 EVT VT = C.getValueType();
2655 SDLoc DL(N);
2656 SDValue LowBit = DAG.getZExtOrTrunc(Op: SetCC.getOperand(i: 0), DL, VT);
2657 SDValue C1 = IsAdd ? DAG.getConstant(Val: CN->getAPIntValue() + 1, DL, VT) :
2658 DAG.getConstant(Val: CN->getAPIntValue() - 1, DL, VT);
2659 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: C1, N2: LowBit);
2660}
2661
2662/// Try to fold a shifted 'not' of the sign bit combined with an add/sub of a
2663/// constant operand into a shift and an add with a different constant.
2664static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
2665 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2666 "Expecting add or sub");
2667
2668 // We need a constant operand for the add/sub, and the other operand is a
2669 // logical shift right: add (srl), C or sub C, (srl).
2670 bool IsAdd = N->getOpcode() == ISD::ADD;
2671 SDValue ConstantOp = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2672 SDValue ShiftOp = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2673 if (!DAG.isConstantIntBuildVectorOrConstantInt(N: ConstantOp) ||
2674 ShiftOp.getOpcode() != ISD::SRL)
2675 return SDValue();
2676
2677 // The shift must be of a 'not' value.
2678 SDValue Not = ShiftOp.getOperand(i: 0);
2679 if (!Not.hasOneUse() || !isBitwiseNot(V: Not))
2680 return SDValue();
2681
2682 // The shift must be moving the sign bit to the least-significant-bit.
2683 EVT VT = ShiftOp.getValueType();
2684 SDValue ShAmt = ShiftOp.getOperand(i: 1);
2685 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
2686 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2687 return SDValue();
2688
2689 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2690 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2691 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2692 SDLoc DL(N);
2693 if (SDValue NewC = DAG.FoldConstantArithmetic(
2694 Opcode: IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2695 Ops: {ConstantOp, DAG.getConstant(Val: 1, DL, VT)})) {
2696 SDValue NewShift = DAG.getNode(Opcode: IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2697 N1: Not.getOperand(i: 0), N2: ShAmt);
2698 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: NewShift, N2: NewC);
2699 }
2700
2701 return SDValue();
2702}
2703
2704static bool
2705areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2706 return (isBitwiseNot(V: Op0) && Op0.getOperand(i: 0) == Op1) ||
2707 (isBitwiseNot(V: Op1) && Op1.getOperand(i: 0) == Op0);
2708}
2709
2710/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2711/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2712/// are no common bits set in the operands).
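/// For example, (or x, 1) where x is known to be even adds exactly like
/// (add x, 1), so the folds below apply to it as well.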
2713SDValue DAGCombiner::visitADDLike(SDNode *N) {
2714 SDValue N0 = N->getOperand(Num: 0);
2715 SDValue N1 = N->getOperand(Num: 1);
2716 EVT VT = N0.getValueType();
2717 SDLoc DL(N);
2718
2719 // fold (add x, undef) -> undef
2720 if (N0.isUndef())
2721 return N0;
2722 if (N1.isUndef())
2723 return N1;
2724
2725 // fold (add c1, c2) -> c1+c2
2726 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N0, N1}))
2727 return C;
2728
2729 // canonicalize constant to RHS
2730 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
2731 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
2732 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
2733
2734 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
2735 return DAG.getConstant(Val: APInt::getAllOnes(numBits: VT.getScalarSizeInBits()),
2736 DL: SDLoc(N), VT);
2737
2738 // fold vector ops
2739 if (VT.isVector()) {
2740 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2741 return FoldedVOp;
2742
2743 // fold (add x, 0) -> x, vector edition
2744 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
2745 return N0;
2746 }
2747
2748 // fold (add x, 0) -> x
2749 if (isNullConstant(V: N1))
2750 return N0;
2751
2752 if (N0.getOpcode() == ISD::SUB) {
2753 SDValue N00 = N0.getOperand(i: 0);
2754 SDValue N01 = N0.getOperand(i: 1);
2755
2756 // fold ((A-c1)+c2) -> (A+(c2-c1))
2757 if (SDValue Sub = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N1, N01}))
2758 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Sub);
2759
2760 // fold ((c1-A)+c2) -> (c1+c2)-A
2761 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N00}))
2762 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
2763 }
2764
2765 // add (sext i1 X), 1 -> zext (not i1 X)
2766 // We don't transform this pattern:
2767 // add (zext i1 X), -1 -> sext (not i1 X)
2768 // because most (?) targets generate better code for the zext form.
2769 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2770 isOneOrOneSplat(V: N1)) {
2771 SDValue X = N0.getOperand(i: 0);
2772 if ((!LegalOperations ||
2773 (TLI.isOperationLegal(Op: ISD::XOR, VT: X.getValueType()) &&
2774 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) &&
2775 X.getScalarValueSizeInBits() == 1) {
2776 SDValue Not = DAG.getNOT(DL, Val: X, VT: X.getValueType());
2777 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Not);
2778 }
2779 }
2780
2781 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2782 // iff (or x, c0) is equivalent to (add x, c0).
2783 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2784 // iff (xor x, c0) is equivalent to (add x, c0).
2785 if (DAG.isADDLike(Op: N0)) {
2786 SDValue N01 = N0.getOperand(i: 1);
2787 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N01}))
2788 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
2789 }
2790
2791 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
2792 return NewSel;
2793
2794 // reassociate add
2795 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::ADD, DL, N, N0, N1)) {
2796 if (SDValue RADD = reassociateOps(Opc: ISD::ADD, DL, N0, N1, Flags: N->getFlags()))
2797 return RADD;
2798
2799 // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
2800 // equivalent to (add x, c).
2801 // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
2802 // equivalent to (add x, c).
2803 // Do this optimization only when adding c does not introduce instructions
2804 // for adding carries.
2805 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2806 if (DAG.isADDLike(Op: N0) && N0.hasOneUse() &&
2807 isConstantOrConstantVector(N: N0.getOperand(i: 1), /* NoOpaque */ NoOpaques: true)) {
2808 // If N0's type is not split up by type legalization, or c is the sign-mask
2809 // constant, adding c does not introduce a carry.
2810 auto TyActn = TLI.getTypeAction(Context&: *DAG.getContext(), VT: N0.getValueType());
2811 bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2812 TyActn == TargetLoweringBase::TypePromoteInteger ||
2813 isMinSignedConstant(V: N0.getOperand(i: 1));
2814 if (NoAddCarry)
2815 return DAG.getNode(
2816 Opcode: ISD::ADD, DL, VT,
2817 N1: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0.getOperand(i: 0)),
2818 N2: N0.getOperand(i: 1));
2819 }
2820 return SDValue();
2821 };
2822 if (SDValue Add = ReassociateAddOr(N0, N1))
2823 return Add;
2824 if (SDValue Add = ReassociateAddOr(N1, N0))
2825 return Add;
2826
2827 // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2828 if (SDValue SD =
2829 reassociateReduction(RedOpc: ISD::VECREDUCE_ADD, Opc: ISD::ADD, DL, VT, N0, N1))
2830 return SD;
2831 }
2832 // fold ((0-A) + B) -> B-A
2833 if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N0.getOperand(i: 0)))
2834 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: N0.getOperand(i: 1));
2835
2836 // fold (A + (0-B)) -> A-B
2837 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N1.getOperand(i: 0)))
2838 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
2839
2840 // fold (A+(B-A)) -> B
2841 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(i: 1))
2842 return N1.getOperand(i: 0);
2843
2844 // fold ((B-A)+A) -> B
2845 if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(i: 1))
2846 return N0.getOperand(i: 0);
2847
2848 // fold ((A-B)+(C-A)) -> (C-B)
2849 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2850 N0.getOperand(i: 0) == N1.getOperand(i: 1))
2851 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N1.getOperand(i: 0),
2852 N2: N0.getOperand(i: 1));
2853
2854 // fold ((A-B)+(B-C)) -> (A-C)
2855 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2856 N0.getOperand(i: 1) == N1.getOperand(i: 0))
2857 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0),
2858 N2: N1.getOperand(i: 1));
2859
2860 // fold (A+(B-(A+C))) to (B-C)
2861 if (N1.getOpcode() == ISD::SUB && N1.getOperand(i: 1).getOpcode() == ISD::ADD &&
2862 N0 == N1.getOperand(i: 1).getOperand(i: 0))
2863 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N1.getOperand(i: 0),
2864 N2: N1.getOperand(i: 1).getOperand(i: 1));
2865
2866 // fold (A+(B-(C+A))) to (B-C)
2867 if (N1.getOpcode() == ISD::SUB && N1.getOperand(i: 1).getOpcode() == ISD::ADD &&
2868 N0 == N1.getOperand(i: 1).getOperand(i: 1))
2869 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N1.getOperand(i: 0),
2870 N2: N1.getOperand(i: 1).getOperand(i: 0));
2871
2872 // fold (A+((B-A)+or-C)) to (B+or-C)
2873 if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
2874 N1.getOperand(i: 0).getOpcode() == ISD::SUB &&
2875 N0 == N1.getOperand(i: 0).getOperand(i: 1))
2876 return DAG.getNode(Opcode: N1.getOpcode(), DL, VT, N1: N1.getOperand(i: 0).getOperand(i: 0),
2877 N2: N1.getOperand(i: 1));
2878
2879 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2880 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2881 N0->hasOneUse() && N1->hasOneUse()) {
2882 SDValue N00 = N0.getOperand(i: 0);
2883 SDValue N01 = N0.getOperand(i: 1);
2884 SDValue N10 = N1.getOperand(i: 0);
2885 SDValue N11 = N1.getOperand(i: 1);
2886
2887 if (isConstantOrConstantVector(N: N00) || isConstantOrConstantVector(N: N10))
2888 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
2889 N1: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT, N1: N00, N2: N10),
2890 N2: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: N01, N2: N11));
2891 }
2892
2893 // fold (add (umax X, C), -C) --> (usubsat X, C)
2894 if (N0.getOpcode() == ISD::UMAX && hasOperation(Opcode: ISD::USUBSAT, VT)) {
2895 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2896 return (!Max && !Op) ||
2897 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2898 };
2899 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchUSUBSAT,
2900 /*AllowUndefs*/ true))
2901 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: N0.getOperand(i: 0),
2902 N2: N0.getOperand(i: 1));
2903 }
2904
2905 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
2906 return SDValue(N, 0);
2907
2908 if (isOneOrOneSplat(V: N1)) {
2909 // fold (add (xor a, -1), 1) -> (sub 0, a)
2910 if (isBitwiseNot(V: N0))
2911 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: 0, DL, VT),
2912 N2: N0.getOperand(i: 0));
2913
2914 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2915 if (N0.getOpcode() == ISD::ADD) {
2916 SDValue A, Xor;
2917
2918 if (isBitwiseNot(V: N0.getOperand(i: 0))) {
2919 A = N0.getOperand(i: 1);
2920 Xor = N0.getOperand(i: 0);
2921 } else if (isBitwiseNot(V: N0.getOperand(i: 1))) {
2922 A = N0.getOperand(i: 0);
2923 Xor = N0.getOperand(i: 1);
2924 }
2925
2926 if (Xor)
2927 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: Xor.getOperand(i: 0));
2928 }
2929
2930 // Look for:
2931 // add (add x, y), 1
2932 // And if the target does not like this form then turn into:
2933 // sub y, (xor x, -1)
2934 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2935 N0.hasOneUse() &&
2936 // Limit this to after legalization if the add has wrap flags
2937 (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
2938 !N->getFlags().hasNoSignedWrap()))) {
2939 SDValue Not = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: N0.getOperand(i: 0),
2940 N2: DAG.getAllOnesConstant(DL, VT));
2941 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 1), N2: Not);
2942 }
2943 }
2944
2945 // (x - y) + -1 -> add (xor y, -1), x
2946 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
2947 isAllOnesOrAllOnesSplat(V: N1)) {
2948 SDValue Xor = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
2949 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Xor, N2: N0.getOperand(i: 0));
2950 }
2951
2952 if (SDValue Combined = visitADDLikeCommutative(N0, N1, LocReference: N))
2953 return Combined;
2954
2955 if (SDValue Combined = visitADDLikeCommutative(N0: N1, N1: N0, LocReference: N))
2956 return Combined;
2957
2958 return SDValue();
2959}
2960
2961SDValue DAGCombiner::visitADD(SDNode *N) {
2962 SDValue N0 = N->getOperand(Num: 0);
2963 SDValue N1 = N->getOperand(Num: 1);
2964 EVT VT = N0.getValueType();
2965 SDLoc DL(N);
2966
2967 if (SDValue Combined = visitADDLike(N))
2968 return Combined;
2969
2970 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
2971 return V;
2972
2973 if (SDValue V = foldAddSubOfSignBit(N, DAG))
2974 return V;
2975
2976 // fold (a+b) -> (a|b) iff a and b share no bits.
2977 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
2978 DAG.haveNoCommonBitsSet(A: N0, B: N1))
2979 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1);
2980
2981 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2982 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2983 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
2984 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
2985 return DAG.getVScale(DL, VT, MulImm: C0 + C1);
2986 }
2987
2988 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2989 if (N0.getOpcode() == ISD::ADD &&
2990 N0.getOperand(i: 1).getOpcode() == ISD::VSCALE &&
2991 N1.getOpcode() == ISD::VSCALE) {
2992 const APInt &VS0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
2993 const APInt &VS1 = N1->getConstantOperandAPInt(Num: 0);
2994 SDValue VS = DAG.getVScale(DL, VT, MulImm: VS0 + VS1);
2995 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: VS);
2996 }
2997
2998 // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2999 if (N0.getOpcode() == ISD::STEP_VECTOR &&
3000 N1.getOpcode() == ISD::STEP_VECTOR) {
3001 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
3002 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
3003 APInt NewStep = C0 + C1;
3004 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3005 }
3006
3007 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3008 if (N0.getOpcode() == ISD::ADD &&
3009 N0.getOperand(i: 1).getOpcode() == ISD::STEP_VECTOR &&
3010 N1.getOpcode() == ISD::STEP_VECTOR) {
3011 const APInt &SV0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
3012 const APInt &SV1 = N1->getConstantOperandAPInt(Num: 0);
3013 APInt NewStep = SV0 + SV1;
3014 SDValue SV = DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3015 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: SV);
3016 }
3017
3018 return SDValue();
3019}
3020
3021SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3022 unsigned Opcode = N->getOpcode();
3023 SDValue N0 = N->getOperand(Num: 0);
3024 SDValue N1 = N->getOperand(Num: 1);
3025 EVT VT = N0.getValueType();
3026 bool IsSigned = Opcode == ISD::SADDSAT;
3027 SDLoc DL(N);
3028
3029 // fold (add_sat x, undef) -> -1
3030 if (N0.isUndef() || N1.isUndef())
3031 return DAG.getAllOnesConstant(DL, VT);
3032
3033 // fold (add_sat c1, c2) -> c3
3034 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
3035 return C;
3036
3037 // canonicalize constant to RHS
3038 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3039 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3040 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
3041
3042 // fold vector ops
3043 if (VT.isVector()) {
3044 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3045 return FoldedVOp;
3046
3047 // fold (add_sat x, 0) -> x, vector edition
3048 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3049 return N0;
3050 }
3051
3052 // fold (add_sat x, 0) -> x
3053 if (isNullConstant(V: N1))
3054 return N0;
3055
3056 // If it cannot overflow, transform into an add.
3057 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3058 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1);
3059
3060 return SDValue();
3061}
3062
3063static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3064 bool ForceCarryReconstruction = false) {
3065 bool Masked = false;
3066
3067 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3068 while (true) {
3069 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3070 V = V.getOperand(i: 0);
3071 continue;
3072 }
3073
3074 if (V.getOpcode() == ISD::AND && isOneConstant(V: V.getOperand(i: 1))) {
3075 if (ForceCarryReconstruction)
3076 return V;
3077
3078 Masked = true;
3079 V = V.getOperand(i: 0);
3080 continue;
3081 }
3082
3083 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3084 return V;
3085
3086 break;
3087 }
3088
3089 // If this is not a carry, return.
3090 if (V.getResNo() != 1)
3091 return SDValue();
3092
3093 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3094 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3095 return SDValue();
3096
3097 EVT VT = V->getValueType(ResNo: 0);
3098 if (!TLI.isOperationLegalOrCustom(Op: V.getOpcode(), VT))
3099 return SDValue();
3100
3101 // If the result is masked, it does not matter what kind of bool the target
3102 // uses. Otherwise, we must make sure the bool is known to be either 0 or 1
3103 // and never some other value.
3104 if (Masked ||
3105 TLI.getBooleanContents(Type: V.getValueType()) ==
3106 TargetLoweringBase::ZeroOrOneBooleanContent)
3107 return V;
3108
3109 return SDValue();
3110}
3111
3112/// Given the operands of an add/sub operation, see if the 2nd operand is a
3113/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3114/// the opcode and bypass the mask operation.
3115static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3116 SelectionDAG &DAG, const SDLoc &DL) {
3117 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3118 N1 = N1.getOperand(i: 0);
3119
3120 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(V: N1->getOperand(Num: 1)))
3121 return SDValue();
3122
3123 EVT VT = N0.getValueType();
3124 SDValue N10 = N1.getOperand(i: 0);
3125 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3126 N10 = N10.getOperand(i: 0);
3127
3128 if (N10.getValueType() != VT)
3129 return SDValue();
3130
3131 if (DAG.ComputeNumSignBits(Op: N10) != VT.getScalarSizeInBits())
3132 return SDValue();
3133
3134 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3135 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
3136 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: N0, N2: N10);
3137}
3138
3139/// Helper for doing combines based on N0 and N1 being added to each other.
3140SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3141 SDNode *LocReference) {
3142 EVT VT = N0.getValueType();
3143 SDLoc DL(LocReference);
3144
3145 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3146 if (N1.getOpcode() == ISD::SHL && N1.getOperand(i: 0).getOpcode() == ISD::SUB &&
3147 isNullOrNullSplat(V: N1.getOperand(i: 0).getOperand(i: 0)))
3148 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0,
3149 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT,
3150 N1: N1.getOperand(i: 0).getOperand(i: 1),
3151 N2: N1.getOperand(i: 1)));
3152
3153 if (SDValue V = foldAddSubMasked1(IsAdd: true, N0, N1, DAG, DL))
3154 return V;
3155
3156 // Look for:
3157 // add (add x, 1), y
3158 // And if the target does not like this form then turn into:
3159 // sub y, (xor x, -1)
3160 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3161 N0.hasOneUse() && isOneOrOneSplat(V: N0.getOperand(i: 1)) &&
3162 // Limit this to after legalization if the add has wrap flags
3163 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3164 !N0->getFlags().hasNoSignedWrap()))) {
3165 SDValue Not = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: N0.getOperand(i: 0),
3166 N2: DAG.getAllOnesConstant(DL, VT));
3167 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: Not);
3168 }
3169
3170 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3171 // Hoist one-use subtraction by non-opaque constant:
3172 // (x - C) + y -> (x + y) - C
3173 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3174 if (isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3175 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3176 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
3177 }
3178 // Hoist one-use subtraction from non-opaque constant:
3179 // (C - x) + y -> (y - x) + C
3180 if (isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
3181 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: N0.getOperand(i: 1));
3182 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 0));
3183 }
3184 }
3185
3186 // add (mul x, C), x -> mul x, C+1
3187 if (N0.getOpcode() == ISD::MUL && N0.getOperand(i: 0) == N1 &&
3188 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true) &&
3189 N0.hasOneUse()) {
3190 SDValue NewC = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
3191 N2: DAG.getConstant(Val: 1, DL, VT));
3192 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3193 }
3194
3195 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3196 // rather than 'add 0/-1' (the zext should get folded).
3197 // add (sext i1 Y), X --> sub X, (zext i1 Y)
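// (sext i1 Y) is 0 or -1 while (zext i1 Y) is 0 or 1, so adding the former
// computes the same value as subtracting the latter.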
3198 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3199 N0.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
3200 TLI.getBooleanContents(Type: VT) == TargetLowering::ZeroOrOneBooleanContent) {
3201 SDValue ZExt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
3202 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: ZExt);
3203 }
3204
3205 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3206 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3207 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
3208 if (TN->getVT() == MVT::i1) {
3209 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
3210 N2: DAG.getConstant(Val: 1, DL, VT));
3211 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: ZExt);
3212 }
3213 }
3214
3215 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3216 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1)) &&
3217 N1.getResNo() == 0)
3218 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N1->getVTList(),
3219 N1: N0, N2: N1.getOperand(i: 0), N3: N1.getOperand(i: 2));
3220
3221 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3222 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3223 if (SDValue Carry = getAsCarry(TLI, V: N1))
3224 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
3225 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: N0,
3226 N2: DAG.getConstant(Val: 0, DL, VT), N3: Carry);
3227
3228 return SDValue();
3229}
3230
3231SDValue DAGCombiner::visitADDC(SDNode *N) {
3232 SDValue N0 = N->getOperand(Num: 0);
3233 SDValue N1 = N->getOperand(Num: 1);
3234 EVT VT = N0.getValueType();
3235 SDLoc DL(N);
3236
3237 // If the flag result is dead, turn this into an ADD.
3238 if (!N->hasAnyUseOfValue(Value: 1))
3239 return CombineTo(N, DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3240 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3241
3242 // canonicalize constant to RHS.
3243 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3244 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3245 if (N0C && !N1C)
3246 return DAG.getNode(Opcode: ISD::ADDC, DL, VTList: N->getVTList(), N1, N2: N0);
3247
3248 // fold (addc x, 0) -> x + no carry out
3249 if (isNullConstant(V: N1))
3250 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
3251 DL, MVT::Glue));
3252
3253 // If it cannot overflow, transform into an add.
3254 if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3255 return CombineTo(N, DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3256 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3257
3258 return SDValue();
3259}
3260
3261/**
3262 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
3263 * then the flip also occurs if computing the inverse is the same cost.
3264 * This function returns an empty SDValue in case it cannot flip the boolean
3265 * without increasing the cost of the computation. If you want to flip a boolean
3266 * no matter what, use DAG.getLogicalNOT.
3267 */
3268static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3269 const TargetLowering &TLI,
3270 bool Force) {
3271 if (Force && isa<ConstantSDNode>(Val: V))
3272 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3273
3274 if (V.getOpcode() != ISD::XOR)
3275 return SDValue();
3276
3277 ConstantSDNode *Const = isConstOrConstSplat(N: V.getOperand(i: 1), AllowUndefs: false);
3278 if (!Const)
3279 return SDValue();
3280
3281 EVT VT = V.getValueType();
3282
3283 bool IsFlip = false;
3284 switch(TLI.getBooleanContents(Type: VT)) {
3285 case TargetLowering::ZeroOrOneBooleanContent:
3286 IsFlip = Const->isOne();
3287 break;
3288 case TargetLowering::ZeroOrNegativeOneBooleanContent:
3289 IsFlip = Const->isAllOnes();
3290 break;
3291 case TargetLowering::UndefinedBooleanContent:
3292 IsFlip = (Const->getAPIntValue() & 0x01) == 1;
3293 break;
3294 }
3295
3296 if (IsFlip)
3297 return V.getOperand(i: 0);
3298 if (Force)
3299 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3300 return SDValue();
3301}
3302
3303SDValue DAGCombiner::visitADDO(SDNode *N) {
3304 SDValue N0 = N->getOperand(Num: 0);
3305 SDValue N1 = N->getOperand(Num: 1);
3306 EVT VT = N0.getValueType();
3307 bool IsSigned = (ISD::SADDO == N->getOpcode());
3308
3309 EVT CarryVT = N->getValueType(ResNo: 1);
3310 SDLoc DL(N);
3311
3312 // If the flag result is dead, turn this into an ADD.
3313 if (!N->hasAnyUseOfValue(Value: 1))
3314 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3315 Res1: DAG.getUNDEF(VT: CarryVT));
3316
3317 // canonicalize constant to RHS.
3318 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3319 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3320 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
3321
3322 // fold (addo x, 0) -> x + no carry out
3323 if (isNullOrNullSplat(V: N1))
3324 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3325
3326 // If it cannot overflow, transform into an add.
3327 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3328 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3329 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3330
3331 if (IsSigned) {
3332 // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3333 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1))
3334 return DAG.getNode(Opcode: ISD::SSUBO, DL, VTList: N->getVTList(),
3335 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3336 } else {
3337 // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
3338 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1)) {
3339 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO, DL, VTList: N->getVTList(),
3340 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3341 return CombineTo(
3342 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3343 }
3344
3345 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3346 return Combined;
3347
3348 if (SDValue Combined = visitUADDOLike(N0: N1, N1: N0, N))
3349 return Combined;
3350 }
3351
3352 return SDValue();
3353}
3354
3355SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3356 EVT VT = N0.getValueType();
3357 if (VT.isVector())
3358 return SDValue();
3359
3360 // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3361 // If Y + 1 cannot overflow.
3362 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1))) {
3363 SDValue Y = N1.getOperand(i: 0);
3364 SDValue One = DAG.getConstant(Val: 1, DL: SDLoc(N), VT: Y.getValueType());
3365 if (DAG.computeOverflowForUnsignedAdd(N0: Y, N1: One) == SelectionDAG::OFK_Never)
3366 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: Y,
3367 N3: N1.getOperand(i: 2));
3368 }
3369
3370 // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3371 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3372 if (SDValue Carry = getAsCarry(TLI, V: N1))
3373 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0,
3374 N2: DAG.getConstant(Val: 0, DL: SDLoc(N), VT), N3: Carry);
3375
3376 return SDValue();
3377}
3378
3379SDValue DAGCombiner::visitADDE(SDNode *N) {
3380 SDValue N0 = N->getOperand(Num: 0);
3381 SDValue N1 = N->getOperand(Num: 1);
3382 SDValue CarryIn = N->getOperand(Num: 2);
3383
3384 // canonicalize constant to RHS
3385 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3386 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3387 if (N0C && !N1C)
3388 return DAG.getNode(Opcode: ISD::ADDE, DL: SDLoc(N), VTList: N->getVTList(),
3389 N1, N2: N0, N3: CarryIn);
3390
3391 // fold (adde x, y, false) -> (addc x, y)
3392 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3393 return DAG.getNode(Opcode: ISD::ADDC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
3394
3395 return SDValue();
3396}
3397
3398SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3399 SDValue N0 = N->getOperand(Num: 0);
3400 SDValue N1 = N->getOperand(Num: 1);
3401 SDValue CarryIn = N->getOperand(Num: 2);
3402 SDLoc DL(N);
3403
3404 // canonicalize constant to RHS
3405 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3406 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3407 if (N0C && !N1C)
3408 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3409
3410 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3411 if (isNullConstant(V: CarryIn)) {
3412 if (!LegalOperations ||
3413 TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT: N->getValueType(ResNo: 0)))
3414 return DAG.getNode(Opcode: ISD::UADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3415 }
3416
3417 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3418 if (isNullConstant(V: N0) && isNullConstant(V: N1)) {
3419 EVT VT = N0.getValueType();
3420 EVT CarryVT = CarryIn.getValueType();
3421 SDValue CarryExt = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT, OpVT: CarryVT);
3422 AddToWorklist(N: CarryExt.getNode());
3423 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CarryExt,
3424 N2: DAG.getConstant(Val: 1, DL, VT)),
3425 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3426 }
3427
3428 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3429 return Combined;
3430
3431 if (SDValue Combined = visitUADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3432 return Combined;
3433
3434 // We want to avoid useless duplication.
3435 // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3436 // not a binary operation, it is not possible to leverage that existing
3437 // mechanism for it. However, if more operations require the same
3438 // deduplication logic, it may be worth generalizing.
3439 SDValue Ops[] = {N1, N0, CarryIn};
3440 SDNode *CSENode =
3441 DAG.getNodeIfExists(Opcode: ISD::UADDO_CARRY, VTList: N->getVTList(), Ops, Flags: N->getFlags());
3442 if (CSENode)
3443 return SDValue(CSENode, 0);
3444
3445 return SDValue();
3446}
3447
3448/**
3449 * If we are facing some sort of diamond carry propagation pattern, try to
3450 * break it up to generate something like:
3451 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3452 *
3453 * The end result is usually an increase in the number of operations required,
3454 * but because the carry is now linearized, other transforms can kick in and optimize the DAG.
3455 *
3456 * Patterns typically look something like
3457 * (uaddo A, B)
3458 * / \
3459 * Carry Sum
3460 * | \
3461 * | (uaddo_carry *, 0, Z)
3462 * | /
3463 * \ Carry
3464 * | /
3465 * (uaddo_carry X, *, *)
3466 *
3467 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3468 * produce a combine with a single path for carry propagation.
3469 */
3470static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3471 SelectionDAG &DAG, SDValue X,
3472 SDValue Carry0, SDValue Carry1,
3473 SDNode *N) {
3474 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3475 return SDValue();
3476 if (Carry1.getOpcode() != ISD::UADDO)
3477 return SDValue();
3478
3479 SDValue Z;
3480
3481 /**
3482 * First look for a suitable Z. It will present itself in the form of
3483 * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3484 */
3485 if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3486 isNullConstant(V: Carry0.getOperand(i: 1))) {
3487 Z = Carry0.getOperand(i: 2);
3488 } else if (Carry0.getOpcode() == ISD::UADDO &&
3489 isOneConstant(V: Carry0.getOperand(i: 1))) {
3490 EVT VT = Combiner.getSetCCResultType(VT: Carry0.getValueType());
3491 Z = DAG.getConstant(Val: 1, DL: SDLoc(Carry0.getOperand(i: 1)), VT);
3492 } else {
3493 // We couldn't find a suitable Z.
3494 return SDValue();
3495 }
3496
3497
3498 auto cancelDiamond = [&](SDValue A, SDValue B) {
3499 SDLoc DL(N);
3500 SDValue NewY =
3501 DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: Carry0->getVTList(), N1: A, N2: B, N3: Z);
3502 Combiner.AddToWorklist(N: NewY.getNode());
3503 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1: X,
3504 N2: DAG.getConstant(Val: 0, DL, VT: X.getValueType()),
3505 N3: NewY.getValue(R: 1));
3506 };
3507
3508 /**
3509 * (uaddo A, B)
3510 * |
3511 * Sum
3512 * |
3513 * (uaddo_carry *, 0, Z)
3514 */
3515 if (Carry0.getOperand(i: 0) == Carry1.getValue(R: 0)) {
3516 return cancelDiamond(Carry1.getOperand(i: 0), Carry1.getOperand(i: 1));
3517 }
3518
3519 /**
3520 * (uaddo_carry A, 0, Z)
3521 * |
3522 * Sum
3523 * |
3524 * (uaddo *, B)
3525 */
3526 if (Carry1.getOperand(i: 0) == Carry0.getValue(R: 0)) {
3527 return cancelDiamond(Carry0.getOperand(i: 0), Carry1.getOperand(i: 1));
3528 }
3529
3530 if (Carry1.getOperand(i: 1) == Carry0.getValue(R: 0)) {
3531 return cancelDiamond(Carry1.getOperand(i: 0), Carry0.getOperand(i: 0));
3532 }
3533
3534 return SDValue();
3535}
3536
3537 // If we are facing some sort of diamond carry/borrow in/out pattern, try to
3538// match patterns like:
3539//
3540// (uaddo A, B) CarryIn
3541// | \ |
3542// | \ |
3543// PartialSum PartialCarryOutX /
3544// | | /
3545// | ____|____________/
3546// | / |
3547// (uaddo *, *) \________
3548// | \ \
3549// | \ |
3550// | PartialCarryOutY |
3551// | \ |
3552// | \ /
3553// AddCarrySum | ______/
3554// | /
3555// CarryOut = (or *, *)
3556//
3557// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3558//
3559// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3560//
3561// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3562// with a single path for carry/borrow out propagation.
3563static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3564 SDValue N0, SDValue N1, SDNode *N) {
3565 SDValue Carry0 = getAsCarry(TLI, V: N0);
3566 if (!Carry0)
3567 return SDValue();
3568 SDValue Carry1 = getAsCarry(TLI, V: N1);
3569 if (!Carry1)
3570 return SDValue();
3571
3572 unsigned Opcode = Carry0.getOpcode();
3573 if (Opcode != Carry1.getOpcode())
3574 return SDValue();
3575 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3576 return SDValue();
3577
3578 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3579 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3580 if (Carry1.getNode()->isOperandOf(N: Carry0.getNode()))
3581 std::swap(a&: Carry0, b&: Carry1);
3582
3583 // Check if nodes are connected in expected way.
3584 if (Carry1.getOperand(i: 0) != Carry0.getValue(R: 0) &&
3585 Carry1.getOperand(i: 1) != Carry0.getValue(R: 0))
3586 return SDValue();
3587
3588 // The carry-in value must be on the right-hand side for subtraction.
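// (usubo is not commutative: the borrow must be subtracted from the partial
// result, never the other way around.)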
3589 unsigned CarryInOperandNum =
3590 Carry1.getOperand(i: 0) == Carry0.getValue(R: 0) ? 1 : 0;
3591 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3592 return SDValue();
3593 SDValue CarryIn = Carry1.getOperand(i: CarryInOperandNum);
3594
3595 unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3596 if (!TLI.isOperationLegalOrCustom(Op: NewOp, VT: Carry0.getValue(R: 0).getValueType()))
3597 return SDValue();
3598
3599 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3600 CarryIn = getAsCarry(TLI, V: CarryIn, ForceCarryReconstruction: true);
3601 if (!CarryIn)
3602 return SDValue();
3603
3604 SDLoc DL(N);
3605 SDValue Merged =
3606 DAG.getNode(Opcode: NewOp, DL, VTList: Carry1->getVTList(), N1: Carry0.getOperand(i: 0),
3607 N2: Carry0.getOperand(i: 1), N3: CarryIn);
3608
3609 // Note that because we have proven that the result of the UADDO/USUBO of A
3610 // and B feeds into the UADDO/USUBO that consumes the carry/borrow in, we also
3611 // know that if the first UADDO/USUBO overflows, the second UADDO/USUBO
3612 // cannot. For example, consider 8-bit numbers where 0xFF is the
3613 // maximum value.
3614 //
3615 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3616 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3617 //
3618 // This is important because it means that OR and XOR can be used to merge
3619 // carry flags, and that AND can return a constant zero.
3620 //
3621 // TODO: match other operations that can merge flags (ADD, etc)
3622 DAG.ReplaceAllUsesOfValueWith(From: Carry1.getValue(R: 0), To: Merged.getValue(R: 0));
3623 if (N->getOpcode() == ISD::AND)
3624 return DAG.getConstant(0, DL, MVT::i1);
3625 return Merged.getValue(R: 1);
3626}
3627
3628SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3629 SDValue CarryIn, SDNode *N) {
3630 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3631 // carry.
3632 if (isBitwiseNot(V: N0))
3633 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true)) {
3634 SDLoc DL(N);
3635 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N->getVTList(), N1,
3636 N2: N0.getOperand(i: 0), N3: NotC);
3637 return CombineTo(
3638 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3639 }
3640
3641 // Iff the flag result is dead:
3642 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3643 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3644 // or the dependency between the instructions.
3645 if ((N0.getOpcode() == ISD::ADD ||
3646 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3647 N0.getValue(R: 1) != CarryIn)) &&
3648 isNullConstant(V: N1) && !N->hasAnyUseOfValue(Value: 1))
3649 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(),
3650 N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1), N3: CarryIn);
3651
3652 /**
3653 * When one of the uaddo_carry arguments is itself a carry, we may be facing
3654 * a diamond carry propagation, in which case we try to transform the DAG
3655 * to ensure linear carry propagation if that is possible.
3656 */
3657 if (auto Y = getAsCarry(TLI, V: N1)) {
3658 // Because both are carries, Y and Z can be swapped.
3659 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: Y, Carry1: CarryIn, N))
3660 return R;
3661 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: CarryIn, Carry1: Y, N))
3662 return R;
3663 }
3664
3665 return SDValue();
3666}
3667
3668SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3669 SDValue CarryIn, SDNode *N) {
3670 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3671 if (isBitwiseNot(V: N0)) {
3672 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true))
3673 return DAG.getNode(Opcode: ISD::SSUBO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1,
3674 N2: N0.getOperand(i: 0), N3: NotC);
3675 }
3676
3677 return SDValue();
3678}
3679
3680SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3681 SDValue N0 = N->getOperand(Num: 0);
3682 SDValue N1 = N->getOperand(Num: 1);
3683 SDValue CarryIn = N->getOperand(Num: 2);
3684 SDLoc DL(N);
3685
3686 // canonicalize constant to RHS
3687 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3688 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3689 if (N0C && !N1C)
3690 return DAG.getNode(Opcode: ISD::SADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3691
3692 // fold (saddo_carry x, y, false) -> (saddo x, y)
3693 if (isNullConstant(V: CarryIn)) {
3694 if (!LegalOperations ||
3695 TLI.isOperationLegalOrCustom(Op: ISD::SADDO, VT: N->getValueType(ResNo: 0)))
3696 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3697 }
3698
3699 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3700 return Combined;
3701
3702 if (SDValue Combined = visitSADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3703 return Combined;
3704
3705 return SDValue();
3706}
3707
3708// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3709// clamp/truncation if necessary.
3710static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3711 SDValue RHS, SelectionDAG &DAG,
3712 const SDLoc &DL) {
3713 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3714 "Illegal truncation");
3715
3716 if (DstVT == SrcVT)
3717 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3718
3719 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3720 // clamping RHS.
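// e.g. for SrcVT=i32 and DstVT=i16, when the top 16 bits of LHS are known
// zero: usubsat i32 LHS, RHS
//   -> usubsat i16 (trunc LHS), (trunc (umin RHS, 0xFFFF))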
3721 APInt UpperBits = APInt::getBitsSetFrom(numBits: SrcVT.getScalarSizeInBits(),
3722 loBit: DstVT.getScalarSizeInBits());
3723 if (!DAG.MaskedValueIsZero(Op: LHS, Mask: UpperBits))
3724 return SDValue();
3725
3726 SDValue SatLimit =
3727 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: SrcVT.getScalarSizeInBits(),
3728 loBitsSet: DstVT.getScalarSizeInBits()),
3729 DL, VT: SrcVT);
3730 RHS = DAG.getNode(Opcode: ISD::UMIN, DL, VT: SrcVT, N1: RHS, N2: SatLimit);
3731 RHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: RHS);
3732 LHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: LHS);
3733 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3734}
3735
3736// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3737// usubsat(a,b), optionally as a truncated type.
3738SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
3739 if (N->getOpcode() != ISD::SUB ||
3740 !(!LegalOperations || hasOperation(Opcode: ISD::USUBSAT, VT: DstVT)))
3741 return SDValue();
3742
3743 EVT SubVT = N->getValueType(ResNo: 0);
3744 SDValue Op0 = N->getOperand(Num: 0);
3745 SDValue Op1 = N->getOperand(Num: 1);
3746
3747 // Try to find umax(a,b) - b or a - umin(a,b) patterns
3748 // that may be converted to usubsat(a,b).
3749 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3750 SDValue MaxLHS = Op0.getOperand(i: 0);
3751 SDValue MaxRHS = Op0.getOperand(i: 1);
3752 if (MaxLHS == Op1)
3753 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxRHS, RHS: Op1, DAG, DL: SDLoc(N));
3754 if (MaxRHS == Op1)
3755 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxLHS, RHS: Op1, DAG, DL: SDLoc(N));
3756 }
3757
3758 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3759 SDValue MinLHS = Op1.getOperand(i: 0);
3760 SDValue MinRHS = Op1.getOperand(i: 1);
3761 if (MinLHS == Op0)
3762 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinRHS, DAG, DL: SDLoc(N));
3763 if (MinRHS == Op0)
3764 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinLHS, DAG, DL: SDLoc(N));
3765 }
3766
3767 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3768 if (Op1.getOpcode() == ISD::TRUNCATE &&
3769 Op1.getOperand(i: 0).getOpcode() == ISD::UMIN &&
3770 Op1.getOperand(i: 0).hasOneUse()) {
3771 SDValue MinLHS = Op1.getOperand(i: 0).getOperand(i: 0);
3772 SDValue MinRHS = Op1.getOperand(i: 0).getOperand(i: 1);
3773 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(i: 0) == Op0)
3774 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinLHS, RHS: MinRHS,
3775 DAG, DL: SDLoc(N));
3776 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(i: 0) == Op0)
3777 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinRHS, RHS: MinLHS,
3778 DAG, DL: SDLoc(N));
3779 }
3780
3781 return SDValue();
3782}
3783
3784 // Since it may not be valid to emit a fold to zero for vector initializers,
3785 // check that we can before folding.
3786static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3787 SelectionDAG &DAG, bool LegalOperations) {
3788 if (!VT.isVector())
3789 return DAG.getConstant(Val: 0, DL, VT);
3790 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT))
3791 return DAG.getConstant(Val: 0, DL, VT);
3792 return SDValue();
3793}
3794
3795SDValue DAGCombiner::visitSUB(SDNode *N) {
3796 SDValue N0 = N->getOperand(Num: 0);
3797 SDValue N1 = N->getOperand(Num: 1);
3798 EVT VT = N0.getValueType();
3799 SDLoc DL(N);
3800
3801 auto PeekThroughFreeze = [](SDValue N) {
3802 if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3803 return N->getOperand(Num: 0);
3804 return N;
3805 };
3806
3807 // fold (sub x, x) -> 0
3808 // FIXME: Refactor this and xor and other similar operations together.
3809 if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3810 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3811
3812 // fold (sub c1, c2) -> c3
3813 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N1}))
3814 return C;
3815
3816 // fold vector ops
3817 if (VT.isVector()) {
3818 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3819 return FoldedVOp;
3820
3821 // fold (sub x, 0) -> x, vector edition
3822 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3823 return N0;
3824 }
3825
3826 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
3827 return NewSel;
3828
3829 ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1);
3830
3831 // fold (sub x, c) -> (add x, -c)
3832 if (N1C) {
3833 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3834 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
3835 }
3836
3837 if (isNullOrNullSplat(V: N0)) {
3838 unsigned BitWidth = VT.getScalarSizeInBits();
3839 // Right-shifting everything out but the sign bit followed by negation is
3840 // the same as flipping arithmetic/logical shift type without the negation:
3841 // -(X >>u 31) -> (X >>s 31)
3842 // -(X >>s 31) -> (X >>u 31)
3843 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3844 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N: N1.getOperand(i: 1));
3845 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3846 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3847 if (!LegalOperations || TLI.isOperationLegal(Op: NewSh, VT))
3848 return DAG.getNode(Opcode: NewSh, DL, VT, N1: N1.getOperand(i: 0), N2: N1.getOperand(i: 1));
3849 }
3850 }
3851
3852 // 0 - X --> 0 if the sub is NUW.
3853 if (N->getFlags().hasNoUnsignedWrap())
3854 return N0;
3855
3856 if (DAG.MaskedValueIsZero(Op: N1, Mask: ~APInt::getSignMask(BitWidth))) {
3857 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3858 // N1 must be 0 because negating the minimum signed value is undefined.
3859 if (N->getFlags().hasNoSignedWrap())
3860 return N0;
3861
3862 // 0 - X --> X if X is 0 or the minimum signed value.
3863 return N1;
3864 }
3865
3866 // Convert 0 - abs(x).
3867 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3868 !TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
3869 if (SDValue Result = TLI.expandABS(N: N1.getNode(), DAG, IsNegative: true))
3870 return Result;
3871
3872 // Fold neg(splat(neg(x))) -> splat(x)
3873 if (VT.isVector()) {
3874 SDValue N1S = DAG.getSplatValue(V: N1, LegalTypes: true);
3875 if (N1S && N1S.getOpcode() == ISD::SUB &&
3876 isNullConstant(V: N1S.getOperand(i: 0)))
3877 return DAG.getSplat(VT, DL, Op: N1S.getOperand(i: 1));
3878 }
3879 }
3880
3881 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3882 if (isAllOnesOrAllOnesSplat(V: N0))
3883 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
3884
3885 // fold (A - (0-B)) -> A+B
3886 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N1.getOperand(i: 0)))
3887 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
3888
3889 // fold A-(A-B) -> B
3890 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(i: 0))
3891 return N1.getOperand(i: 1);
3892
3893 // fold (A+B)-A -> B
3894 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N1)
3895 return N0.getOperand(i: 1);
3896
3897 // fold (A+B)-B -> A
3898 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 1) == N1)
3899 return N0.getOperand(i: 0);
3900
3901 // fold (A+C1)-C2 -> A+(C1-C2)
3902 if (N0.getOpcode() == ISD::ADD) {
3903 SDValue N01 = N0.getOperand(i: 1);
3904 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N01, N1}))
3905 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3906 }
3907
3908 // fold C2-(A+C1) -> (C2-C1)-A
3909 if (N1.getOpcode() == ISD::ADD) {
3910 SDValue N11 = N1.getOperand(i: 1);
3911 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N11}))
3912 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N1.getOperand(i: 0));
3913 }
3914
3915 // fold (A-C1)-C2 -> A-(C1+C2)
3916 if (N0.getOpcode() == ISD::SUB) {
3917 SDValue N01 = N0.getOperand(i: 1);
3918 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N01, N1}))
3919 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3920 }
3921
3922 // fold (c1-A)-c2 -> (c1-c2)-A
3923 if (N0.getOpcode() == ISD::SUB) {
3924 SDValue N00 = N0.getOperand(i: 0);
3925 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N00, N1}))
3926 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N0.getOperand(i: 1));
3927 }
3928
3929 // fold ((A+(B+or-C))-B) -> A+or-C
3930 if (N0.getOpcode() == ISD::ADD &&
3931 (N0.getOperand(i: 1).getOpcode() == ISD::SUB ||
3932 N0.getOperand(i: 1).getOpcode() == ISD::ADD) &&
3933 N0.getOperand(i: 1).getOperand(i: 0) == N1)
3934 return DAG.getNode(Opcode: N0.getOperand(i: 1).getOpcode(), DL, VT, N1: N0.getOperand(i: 0),
3935 N2: N0.getOperand(i: 1).getOperand(i: 1));
3936
3937 // fold ((A+(C+B))-B) -> A+C
3938 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 1).getOpcode() == ISD::ADD &&
3939 N0.getOperand(i: 1).getOperand(i: 1) == N1)
3940 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0),
3941 N2: N0.getOperand(i: 1).getOperand(i: 0));
3942
3943 // fold ((A-(B-C))-C) -> A-B
3944 if (N0.getOpcode() == ISD::SUB && N0.getOperand(i: 1).getOpcode() == ISD::SUB &&
3945 N0.getOperand(i: 1).getOperand(i: 1) == N1)
3946 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0),
3947 N2: N0.getOperand(i: 1).getOperand(i: 0));
3948
3949 // fold (A-(B-C)) -> A+(C-B)
3950 if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3951 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3952 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N1.getOperand(i: 1),
3953 N2: N1.getOperand(i: 0)));
3954
3955 // A - (A & B) -> A & (~B)
3956 if (N1.getOpcode() == ISD::AND) {
3957 SDValue A = N1.getOperand(i: 0);
3958 SDValue B = N1.getOperand(i: 1);
3959 if (A != N0)
3960 std::swap(a&: A, b&: B);
3961 if (A == N0 &&
3962 (N1.hasOneUse() || isConstantOrConstantVector(N: B, /*NoOpaques=*/true))) {
3963 SDValue InvB =
3964 DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: B, N2: DAG.getAllOnesConstant(DL, VT));
3965 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: A, N2: InvB);
3966 }
3967 }
3968
3969 // fold (X - (-Y * Z)) -> (X + (Y * Z))
3970 if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3971 if (N1.getOperand(i: 0).getOpcode() == ISD::SUB &&
3972 isNullOrNullSplat(V: N1.getOperand(i: 0).getOperand(i: 0))) {
3973 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT,
3974 N1: N1.getOperand(i: 0).getOperand(i: 1),
3975 N2: N1.getOperand(i: 1));
3976 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Mul);
3977 }
3978 if (N1.getOperand(i: 1).getOpcode() == ISD::SUB &&
3979 isNullOrNullSplat(V: N1.getOperand(i: 1).getOperand(i: 0))) {
3980 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT,
3981 N1: N1.getOperand(i: 0),
3982 N2: N1.getOperand(i: 1).getOperand(i: 1));
3983 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Mul);
3984 }
3985 }
3986
3987 // If either operand of a sub is undef, the result is undef
3988 if (N0.isUndef())
3989 return N0;
3990 if (N1.isUndef())
3991 return N1;
3992
3993 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3994 return V;
3995
3996 if (SDValue V = foldAddSubOfSignBit(N, DAG))
3997 return V;
3998
3999 if (SDValue V = foldAddSubMasked1(IsAdd: false, N0, N1, DAG, DL: SDLoc(N)))
4000 return V;
4001
4002 if (SDValue V = foldSubToUSubSat(DstVT: VT, N))
4003 return V;
4004
4005 // (x - y) - 1 -> add (xor y, -1), x
4006 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() && isOneOrOneSplat(V: N1)) {
4007 SDValue Xor = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: N0.getOperand(i: 1),
4008 N2: DAG.getAllOnesConstant(DL, VT));
4009 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Xor, N2: N0.getOperand(i: 0));
4010 }
4011
4012 // Look for:
4013 // sub y, (xor x, -1)
4014 // And if the target does not like this form then turn into:
4015 // add (add x, y), 1
4016 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(V: N1)) {
4017 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4018 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Add, N2: DAG.getConstant(Val: 1, DL, VT));
4019 }
4020
4021 // Hoist one-use addition by non-opaque constant:
4022 // (x + C) - y -> (x - y) + C
4023 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
4024 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4025 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4026 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4027 }
4028 // y - (x + C) -> (y - x) - C
4029 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4030 isConstantOrConstantVector(N: N1.getOperand(i: 1), /*NoOpaques=*/true)) {
4031 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4032 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N1.getOperand(i: 1));
4033 }
4034 // (x - C) - y -> (x - y) - C
4035 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4036 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4037 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4038 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4039 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4040 }
4041 // (C - x) - y -> C - (x + y)
4042 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4043 isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
4044 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
4045 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
4046 }
4047
4048 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4049 // rather than 'sub 0/1' (the sext should get folded).
4050 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
4051 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4052 N1.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
4053 TLI.getBooleanContents(Type: VT) ==
4054 TargetLowering::ZeroOrNegativeOneBooleanContent) {
4055 SDValue SExt = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N1.getOperand(i: 0));
4056 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SExt);
4057 }
4058
4059 // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
4060 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT)) {
4061 if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
4062 SDValue X0 = N0.getOperand(i: 0), X1 = N0.getOperand(i: 1);
4063 SDValue S0 = N1.getOperand(i: 0);
4064 if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
4065 if (ConstantSDNode *C = isConstOrConstSplat(N: N1.getOperand(i: 1)))
4066 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
4067 return DAG.getNode(Opcode: ISD::ABS, DL: SDLoc(N), VT, Operand: S0);
4068 }
4069 }
4070
4071 // If the relocation model supports it, consider symbol offsets.
4072 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Val&: N0))
4073 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4074 // fold (sub Sym+c1, Sym+c2) -> c1-c2
4075 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(Val&: N1))
4076 if (GA->getGlobal() == GB->getGlobal())
4077 return DAG.getConstant(Val: (uint64_t)GA->getOffset() - GB->getOffset(),
4078 DL, VT);
4079 }
4080
4081 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4082 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4083 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
4084 if (TN->getVT() == MVT::i1) {
4085 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
4086 N2: DAG.getConstant(Val: 1, DL, VT));
4087 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: ZExt);
4088 }
4089 }
4090
4091 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4092 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4093 const APInt &IntVal = N1.getConstantOperandAPInt(i: 0);
4094 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getVScale(DL, VT, MulImm: -IntVal));
4095 }
4096
4097 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4098 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4099 APInt NewStep = -N1.getConstantOperandAPInt(i: 0);
4100 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4101 N2: DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep));
4102 }
4103
4104 // Prefer an add for more folding potential and possibly better codegen:
4105 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
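// (lshr N10, width-1) is 0 or 1 while (ashr N10, width-1) is 0 or -1, so the
// subtraction and the addition produce the same result.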
4106 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4107 SDValue ShAmt = N1.getOperand(i: 1);
4108 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
4109 if (ShAmtC &&
4110 ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
4111 SDValue SRA = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N1.getOperand(i: 0), N2: ShAmt);
4112 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SRA);
4113 }
4114 }
4115
4116 // As with the previous fold, prefer add for more folding potential.
4117 // Subtracting SMIN/0 is the same as adding SMIN/0:
4118 // N0 - (X << BW-1) --> N0 + (X << BW-1)
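// (X << BW-1) is either 0 or the sign-bit value SMIN, and SMIN is its own
// two's-complement negation, so subtracting it equals adding it.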
4119 if (N1.getOpcode() == ISD::SHL) {
4120 ConstantSDNode *ShlC = isConstOrConstSplat(N: N1.getOperand(i: 1));
4121 if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1)
4122 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
4123 }
4124
4125 // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4126 if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(V: N0.getOperand(i: 1)) &&
4127 N0.getResNo() == 0 && N0.hasOneUse())
4128 return DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N0->getVTList(),
4129 N1: N0.getOperand(i: 0), N2: N1, N3: N0.getOperand(i: 2));
4130
4131 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT)) {
4132 // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry)
4133 if (SDValue Carry = getAsCarry(TLI, V: N0)) {
4134 SDValue X = N1;
4135 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4136 SDValue NegX = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: X);
4137 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
4138 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: NegX, N2: Zero,
4139 N3: Carry);
4140 }
4141 }
4142
4143 // If there's no chance of borrowing from adjacent bits, then sub is xor:
4144 // sub C0, X --> xor X, C0
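// e.g. with C0 = 0xF0 and X known to only have bits that are also set in C0,
// no bit position borrows, so C0 - X == C0 ^ X (0xF0 - 0x30 == 0xF0 ^ 0x30).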
4145 if (ConstantSDNode *C0 = isConstOrConstSplat(N: N0)) {
4146 if (!C0->isOpaque()) {
4147 const APInt &C0Val = C0->getAPIntValue();
4148 const APInt &MaybeOnes = ~DAG.computeKnownBits(Op: N1).Zero;
4149 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4150 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4151 }
4152 }
4153
4154 // max(a,b) - min(a,b) --> abd(a,b)
4155 auto MatchSubMaxMin = [&](unsigned Max, unsigned Min, unsigned Abd) {
4156 if (N0.getOpcode() != Max || N1.getOpcode() != Min)
4157 return SDValue();
4158 if ((N0.getOperand(i: 0) != N1.getOperand(i: 0) ||
4159 N0.getOperand(i: 1) != N1.getOperand(i: 1)) &&
4160 (N0.getOperand(i: 0) != N1.getOperand(i: 1) ||
4161 N0.getOperand(i: 1) != N1.getOperand(i: 0)))
4162 return SDValue();
4163 if (!hasOperation(Opcode: Abd, VT))
4164 return SDValue();
4165 return DAG.getNode(Opcode: Abd, DL, VT, N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1));
4166 };
4167 if (SDValue R = MatchSubMaxMin(ISD::SMAX, ISD::SMIN, ISD::ABDS))
4168 return R;
4169 if (SDValue R = MatchSubMaxMin(ISD::UMAX, ISD::UMIN, ISD::ABDU))
4170 return R;
4171
4172 return SDValue();
4173}
4174
4175SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4176 unsigned Opcode = N->getOpcode();
4177 SDValue N0 = N->getOperand(Num: 0);
4178 SDValue N1 = N->getOperand(Num: 1);
4179 EVT VT = N0.getValueType();
4180 bool IsSigned = Opcode == ISD::SSUBSAT;
4181 SDLoc DL(N);
4182
4183 // fold (sub_sat x, undef) -> 0
4184 if (N0.isUndef() || N1.isUndef())
4185 return DAG.getConstant(Val: 0, DL, VT);
4186
4187 // fold (sub_sat x, x) -> 0
4188 if (N0 == N1)
4189 return DAG.getConstant(Val: 0, DL, VT);
4190
4191 // fold (sub_sat c1, c2) -> c3
4192 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4193 return C;
4194
4195 // fold vector ops
4196 if (VT.isVector()) {
4197 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4198 return FoldedVOp;
4199
4200 // fold (sub_sat x, 0) -> x, vector edition
4201 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4202 return N0;
4203 }
4204
4205 // fold (sub_sat x, 0) -> x
4206 if (isNullConstant(V: N1))
4207 return N0;
4208
4209 // If it cannot overflow, transform into a sub.
4210 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4211 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1);
4212
4213 return SDValue();
4214}
4215
4216SDValue DAGCombiner::visitSUBC(SDNode *N) {
4217 SDValue N0 = N->getOperand(Num: 0);
4218 SDValue N1 = N->getOperand(Num: 1);
4219 EVT VT = N0.getValueType();
4220 SDLoc DL(N);
4221
4222 // If the flag result is dead, turn this into an SUB.
4223 if (!N->hasAnyUseOfValue(Value: 1))
4224 return CombineTo(N, DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4225 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4226
4227 // fold (subc x, x) -> 0 + no borrow
4228 if (N0 == N1)
4229 return CombineTo(N, DAG.getConstant(Val: 0, DL, VT),
4230 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4231
4232 // fold (subc x, 0) -> x + no borrow
4233 if (isNullConstant(V: N1))
4234 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4235
4236 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4237 if (isAllOnesConstant(V: N0))
4238 return CombineTo(N, DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4239 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4240
4241 return SDValue();
4242}
4243
4244SDValue DAGCombiner::visitSUBO(SDNode *N) {
4245 SDValue N0 = N->getOperand(Num: 0);
4246 SDValue N1 = N->getOperand(Num: 1);
4247 EVT VT = N0.getValueType();
4248 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4249
4250 EVT CarryVT = N->getValueType(ResNo: 1);
4251 SDLoc DL(N);
4252
4253 // If the flag result is dead, turn this into an SUB.
4254 if (!N->hasAnyUseOfValue(Value: 1))
4255 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4256 Res1: DAG.getUNDEF(VT: CarryVT));
4257
4258 // fold (subo x, x) -> 0 + no borrow
4259 if (N0 == N1)
4260 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4261 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4262
4263 ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1);
4264
4265 // fold (subo x, c) -> (addo x, -c)
4266 if (IsSigned && N1C && !N1C->isMinSignedValue()) {
4267 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0,
4268 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4269 }
4270
4271 // fold (subo x, 0) -> x + no borrow
4272 if (isNullOrNullSplat(V: N1))
4273 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4274
4275 // If it cannot overflow, transform into a sub.
4276 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4277 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4278 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4279
4280 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4281 if (!IsSigned && isAllOnesOrAllOnesSplat(V: N0))
4282 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4283 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4284
4285 return SDValue();
4286}
4287
4288SDValue DAGCombiner::visitSUBE(SDNode *N) {
4289 SDValue N0 = N->getOperand(Num: 0);
4290 SDValue N1 = N->getOperand(Num: 1);
4291 SDValue CarryIn = N->getOperand(Num: 2);
4292
4293 // fold (sube x, y, false) -> (subc x, y)
4294 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4295 return DAG.getNode(Opcode: ISD::SUBC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4296
4297 return SDValue();
4298}
4299
4300SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4301 SDValue N0 = N->getOperand(Num: 0);
4302 SDValue N1 = N->getOperand(Num: 1);
4303 SDValue CarryIn = N->getOperand(Num: 2);
4304
4305 // fold (usubo_carry x, y, false) -> (usubo x, y)
4306 if (isNullConstant(V: CarryIn)) {
4307 if (!LegalOperations ||
4308 TLI.isOperationLegalOrCustom(Op: ISD::USUBO, VT: N->getValueType(ResNo: 0)))
4309 return DAG.getNode(Opcode: ISD::USUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4310 }
4311
4312 return SDValue();
4313}
4314
4315SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4316 SDValue N0 = N->getOperand(Num: 0);
4317 SDValue N1 = N->getOperand(Num: 1);
4318 SDValue CarryIn = N->getOperand(Num: 2);
4319
4320 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4321 if (isNullConstant(V: CarryIn)) {
4322 if (!LegalOperations ||
4323 TLI.isOperationLegalOrCustom(Op: ISD::SSUBO, VT: N->getValueType(ResNo: 0)))
4324 return DAG.getNode(Opcode: ISD::SSUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4325 }
4326
4327 return SDValue();
4328}
4329
4330// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4331// UMULFIXSAT here.
4332SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4333 SDValue N0 = N->getOperand(Num: 0);
4334 SDValue N1 = N->getOperand(Num: 1);
4335 SDValue Scale = N->getOperand(Num: 2);
4336 EVT VT = N0.getValueType();
4337
4338 // fold (mulfix x, undef, scale) -> 0
4339 if (N0.isUndef() || N1.isUndef())
4340 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4341
4342 // Canonicalize constant to RHS (vector doesn't have to splat)
4343 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4344 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4345 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0, N3: Scale);
4346
4347 // fold (mulfix x, 0, scale) -> 0
4348 if (isNullConstant(V: N1))
4349 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4350
4351 return SDValue();
4352}
4353
4354SDValue DAGCombiner::visitMUL(SDNode *N) {
4355 SDValue N0 = N->getOperand(Num: 0);
4356 SDValue N1 = N->getOperand(Num: 1);
4357 EVT VT = N0.getValueType();
4358 SDLoc DL(N);
4359
4360 // fold (mul x, undef) -> 0
4361 if (N0.isUndef() || N1.isUndef())
4362 return DAG.getConstant(Val: 0, DL, VT);
4363
4364 // fold (mul c1, c2) -> c1*c2
4365 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MUL, DL, VT, Ops: {N0, N1}))
4366 return C;
4367
4368 // canonicalize constant to RHS (vector doesn't have to splat)
4369 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4370 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4371 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1, N2: N0);
4372
4373 bool N1IsConst = false;
4374 bool N1IsOpaqueConst = false;
4375 APInt ConstValue1;
4376
4377 // fold vector ops
4378 if (VT.isVector()) {
4379 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4380 return FoldedVOp;
4381
4382 N1IsConst = ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ConstValue1);
4383 assert((!N1IsConst ||
4384 ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
4385 "Splat APInt should be element width");
4386 } else {
4387 N1IsConst = isa<ConstantSDNode>(Val: N1);
4388 if (N1IsConst) {
4389 ConstValue1 = N1->getAsAPIntVal();
4390 N1IsOpaqueConst = cast<ConstantSDNode>(Val&: N1)->isOpaque();
4391 }
4392 }
4393
4394 // fold (mul x, 0) -> 0
4395 if (N1IsConst && ConstValue1.isZero())
4396 return N1;
4397
4398 // fold (mul x, 1) -> x
4399 if (N1IsConst && ConstValue1.isOne())
4400 return N0;
4401
4402 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4403 return NewSel;
4404
4405 // fold (mul x, -1) -> 0-x
4406 if (N1IsConst && ConstValue1.isAllOnes())
4407 return DAG.getNegative(Val: N0, DL, VT);
4408
4409 // fold (mul x, (1 << c)) -> x << c
4410 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
4411 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4412 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4413 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4414 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4415 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: Trunc);
4416 }
4417 }
4418
4419 // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
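// For example, (mul x, -8) --> (sub 0, (shl x, 3)).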
4420 if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4421 unsigned Log2Val = (-ConstValue1).logBase2();
4422 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4423
4424 // FIXME: If the input is something that is easily negated (e.g. a
4425 // single-use add), we should put the negate there.
4426 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
4427 N1: DAG.getConstant(Val: 0, DL, VT),
4428 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4429 N2: DAG.getConstant(Val: Log2Val, DL, VT: ShiftVT)));
4430 }
4431
4432 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4433 // hi result is in use in case we hit this mid-legalization.
4434 for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4435 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: LoHiOpc, VT)) {
4436 SDVTList LoHiVT = DAG.getVTList(VT1: VT, VT2: VT);
4437 // TODO: Can we match commutable operands with getNodeIfExists?
4438 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N0, N1}))
4439 if (LoHi->hasAnyUseOfValue(Value: 1))
4440 return SDValue(LoHi, 0);
4441 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N1, N0}))
4442 if (LoHi->hasAnyUseOfValue(Value: 1))
4443 return SDValue(LoHi, 0);
4444 }
4445 }
4446
4447 // Try to transform:
4448 // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4449 // mul x, (2^N + 1) --> add (shl x, N), x
4450 // mul x, (2^N - 1) --> sub (shl x, N), x
4451 // Examples: x * 33 --> (x << 5) + x
4452 // x * 15 --> (x << 4) - x
4453 // x * -33 --> -((x << 5) + x)
4454 // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4455 // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4456 // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4457 // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4458 // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4459 // x * 0xf800 --> (x << 16) - (x << 11)
4460 // x * -0x8800 --> -((x << 15) + (x << 11))
4461 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4462 if (N1IsConst && TLI.decomposeMulByConstant(Context&: *DAG.getContext(), VT, C: N1)) {
4463 // TODO: We could handle more general decomposition of any constant by
4464 // having the target set a limit on number of ops and making a
4465 // callback to determine that sequence (similar to sqrt expansion).
4466 unsigned MathOp = ISD::DELETED_NODE;
4467 APInt MulC = ConstValue1.abs();
4468 // The constant `2` should be treated as (2^0 + 1).
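// Factor out trailing zero bits first: MulC = OddC << TZeros. OddC is then
// decomposed as (2^N +/- 1), and TZeros is folded back into both shift
// amounts below.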
4469 unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4470 MulC.lshrInPlace(ShiftAmt: TZeros);
4471 if ((MulC - 1).isPowerOf2())
4472 MathOp = ISD::ADD;
4473 else if ((MulC + 1).isPowerOf2())
4474 MathOp = ISD::SUB;
4475
4476 if (MathOp != ISD::DELETED_NODE) {
4477 unsigned ShAmt =
4478 MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4479 ShAmt += TZeros;
4480 assert(ShAmt < VT.getScalarSizeInBits() &&
4481 "multiply-by-constant generated out of bounds shift");
4482 SDValue Shl =
4483 DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: DAG.getConstant(Val: ShAmt, DL, VT));
4484 SDValue R =
4485 TZeros ? DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl,
4486 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4487 N2: DAG.getConstant(Val: TZeros, DL, VT)))
4488 : DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl, N2: N0);
4489 if (ConstValue1.isNegative())
4490 R = DAG.getNegative(Val: R, DL, VT);
4491 return R;
4492 }
4493 }
4494
4495 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4496 if (N0.getOpcode() == ISD::SHL) {
4497 SDValue N01 = N0.getOperand(i: 1);
4498 if (SDValue C3 = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N1, N01}))
4499 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: C3);
4500 }
4501
4502 // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4503 // use.
4504 {
4505 SDValue Sh, Y;
4506
4507 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4508 if (N0.getOpcode() == ISD::SHL &&
4509 isConstantOrConstantVector(N: N0.getOperand(i: 1)) && N0->hasOneUse()) {
4510 Sh = N0; Y = N1;
4511 } else if (N1.getOpcode() == ISD::SHL &&
4512 isConstantOrConstantVector(N: N1.getOperand(i: 1)) &&
4513 N1->hasOneUse()) {
4514 Sh = N1; Y = N0;
4515 }
4516
4517 if (Sh.getNode()) {
4518 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: Sh.getOperand(i: 0), N2: Y);
4519 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mul, N2: Sh.getOperand(i: 1));
4520 }
4521 }
4522
4523 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4524 if (N0.getOpcode() == ISD::ADD &&
4525 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
4526 DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1)) &&
4527 isMulAddWithConstProfitable(MulNode: N, AddNode: N0, ConstNode: N1))
4528 return DAG.getNode(
4529 Opcode: ISD::ADD, DL, VT,
4530 N1: DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1),
4531 N2: DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: N0.getOperand(i: 1), N2: N1));
4532
4533 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4534 ConstantSDNode *NC1 = isConstOrConstSplat(N: N1);
4535 if (N0.getOpcode() == ISD::VSCALE && NC1) {
4536 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4537 const APInt &C1 = NC1->getAPIntValue();
4538 return DAG.getVScale(DL, VT, MulImm: C0 * C1);
4539 }
4540
4541 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4542 APInt MulVal;
4543 if (N0.getOpcode() == ISD::STEP_VECTOR &&
4544 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: MulVal)) {
4545 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4546 APInt NewStep = C0 * MulVal;
4547 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
4548 }
4549
4550 // Fold (mul x, 0/undef) -> 0 and (mul x, 1) -> x
4551 //   -> and(x, mask)
4553 // We can replace vectors with '0' and '1' factors with a clearing mask.
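// For example, (mul x:v4i32, <0,1,0,1>) --> (and x, <0,-1,0,-1>).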
4554 if (VT.isFixedLengthVector()) {
4555 unsigned NumElts = VT.getVectorNumElements();
4556 SmallBitVector ClearMask;
4557 ClearMask.reserve(N: NumElts);
4558 auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4559 if (!V || V->isZero()) {
4560 ClearMask.push_back(Val: true);
4561 return true;
4562 }
4563 ClearMask.push_back(Val: false);
4564 return V->isOne();
4565 };
4566 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::AND, VT)) &&
4567 ISD::matchUnaryPredicate(Op: N1, Match: IsClearMask, /*AllowUndefs*/ true)) {
4568 assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4569 EVT LegalSVT = N1.getOperand(i: 0).getValueType();
4570 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: LegalSVT);
4571 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: LegalSVT);
4572 SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4573 for (unsigned I = 0; I != NumElts; ++I)
4574 if (ClearMask[I])
4575 Mask[I] = Zero;
4576 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getBuildVector(VT, DL, Ops: Mask));
4577 }
4578 }
4579
4580 // reassociate mul
4581 if (SDValue RMUL = reassociateOps(Opc: ISD::MUL, DL, N0, N1, Flags: N->getFlags()))
4582 return RMUL;
4583
4584 // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4585 if (SDValue SD =
4586 reassociateReduction(RedOpc: ISD::VECREDUCE_MUL, Opc: ISD::MUL, DL, VT, N0, N1))
4587 return SD;
4588
4589 // Simplify the operands using demanded-bits information.
4590 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
4591 return SDValue(N, 0);
4592
4593 return SDValue();
4594}
4595
4596 /// Return true if a divmod libcall is available for this node's type.
4597static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4598 const TargetLowering &TLI) {
4599 RTLIB::Libcall LC;
4600 EVT NodeType = Node->getValueType(ResNo: 0);
4601 if (!NodeType.isSimple())
4602 return false;
4603 switch (NodeType.getSimpleVT().SimpleTy) {
4604 default: return false; // No libcall for vector types.
4605 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4606 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4607 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4608 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4609 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4610 }
4611
4612 return TLI.getLibcallName(Call: LC) != nullptr;
4613}
4614
4615/// Issue divrem if both quotient and remainder are needed.
4616SDValue DAGCombiner::useDivRem(SDNode *Node) {
4617 if (Node->use_empty())
4618 return SDValue(); // This is a dead node, leave it alone.
4619
4620 unsigned Opcode = Node->getOpcode();
4621 bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4622 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4623
4624 // DivMod lib calls can still work on non-legal types if using lib-calls.
4625 EVT VT = Node->getValueType(ResNo: 0);
4626 if (VT.isVector() || !VT.isInteger())
4627 return SDValue();
4628
4629 if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(Op: DivRemOpc, VT))
4630 return SDValue();
4631
4632 // If DIVREM is going to get expanded into a libcall,
4633 // but there is no libcall available, then don't combine.
4634 if (!TLI.isOperationLegalOrCustom(Op: DivRemOpc, VT) &&
4635 !isDivRemLibcallAvailable(Node, isSigned, TLI))
4636 return SDValue();
4637
4638 // If div is legal, it's better to do the normal expansion
4639 unsigned OtherOpcode = 0;
4640 if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4641 OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4642 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT))
4643 return SDValue();
4644 } else {
4645 OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4646 if (TLI.isOperationLegalOrCustom(Op: OtherOpcode, VT))
4647 return SDValue();
4648 }
4649
4650 SDValue Op0 = Node->getOperand(Num: 0);
4651 SDValue Op1 = Node->getOperand(Num: 1);
4652 SDValue combined;
4653 for (SDNode *User : Op0->uses()) {
4654 if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4655 User->use_empty())
4656 continue;
4657 // Convert the other matching node(s), too;
4658 // otherwise, the DIVREM may get target-legalized into something
4659 // target-specific that we won't be able to recognize.
4660 unsigned UserOpc = User->getOpcode();
4661 if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4662 User->getOperand(Num: 0) == Op0 &&
4663 User->getOperand(Num: 1) == Op1) {
4664 if (!combined) {
4665 if (UserOpc == OtherOpcode) {
4666 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT);
4667 combined = DAG.getNode(Opcode: DivRemOpc, DL: SDLoc(Node), VTList: VTs, N1: Op0, N2: Op1);
4668 } else if (UserOpc == DivRemOpc) {
4669 combined = SDValue(User, 0);
4670 } else {
4671 assert(UserOpc == Opcode);
4672 continue;
4673 }
4674 }
4675 if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4676 CombineTo(N: User, Res: combined);
4677 else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4678 CombineTo(N: User, Res: combined.getValue(R: 1));
4679 }
4680 }
4681 return combined;
4682}
4683
4684static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4685 SDValue N0 = N->getOperand(Num: 0);
4686 SDValue N1 = N->getOperand(Num: 1);
4687 EVT VT = N->getValueType(ResNo: 0);
4688 SDLoc DL(N);
4689
4690 unsigned Opc = N->getOpcode();
4691 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4692 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4693
4694 // X / undef -> undef
4695 // X % undef -> undef
4696 // X / 0 -> undef
4697 // X % 0 -> undef
4698 // NOTE: This includes vectors where any divisor element is zero/undef.
4699 if (DAG.isUndef(Opcode: Opc, Ops: {N0, N1}))
4700 return DAG.getUNDEF(VT);
4701
4702 // undef / X -> 0
4703 // undef % X -> 0
4704 if (N0.isUndef())
4705 return DAG.getConstant(Val: 0, DL, VT);
4706
4707 // 0 / X -> 0
4708 // 0 % X -> 0
4709 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
4710 if (N0C && N0C->isZero())
4711 return N0;
4712
4713 // X / X -> 1
4714 // X % X -> 0
4715 if (N0 == N1)
4716 return DAG.getConstant(Val: IsDiv ? 1 : 0, DL, VT);
4717
4718 // X / 1 -> X
4719 // X % 1 -> 0
4720 // If this is a boolean op (single-bit element type), we can't have
4721 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4722 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4723 // it's a 1.
4724 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4725 return IsDiv ? N0 : DAG.getConstant(Val: 0, DL, VT);
4726
4727 return SDValue();
4728}
4729
4730SDValue DAGCombiner::visitSDIV(SDNode *N) {
4731 SDValue N0 = N->getOperand(Num: 0);
4732 SDValue N1 = N->getOperand(Num: 1);
4733 EVT VT = N->getValueType(ResNo: 0);
4734 EVT CCVT = getSetCCResultType(VT);
4735 SDLoc DL(N);
4736
4737 // fold (sdiv c1, c2) -> c1/c2
4738 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SDIV, DL, VT, Ops: {N0, N1}))
4739 return C;
4740
4741 // fold vector ops
4742 if (VT.isVector())
4743 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4744 return FoldedVOp;
4745
4746 // fold (sdiv X, -1) -> 0-X
4747 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4748 if (N1C && N1C->isAllOnes())
4749 return DAG.getNegative(Val: N0, DL, VT);
4750
4751 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
4752 if (N1C && N1C->isMinSignedValue())
4753 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
4754 LHS: DAG.getConstant(Val: 1, DL, VT),
4755 RHS: DAG.getConstant(Val: 0, DL, VT));
4756
4757 if (SDValue V = simplifyDivRem(N, DAG))
4758 return V;
4759
4760 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4761 return NewSel;
4762
4763 // If we know the sign bits of both operands are zero, strength reduce to a
4764 // udiv instead. Handles (X&15) /s 4 -> (X&15) >> 2
4765 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
4766 return DAG.getNode(Opcode: ISD::UDIV, DL, VT: N1.getValueType(), N1: N0, N2: N1);
4767
4768 if (SDValue V = visitSDIVLike(N0, N1, N)) {
4769 // If the corresponding remainder node exists, update its users with
4770 // (Dividend - (Quotient * Divisor)).
4771 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::SREM, VTList: N->getVTList(),
4772 Ops: { N0, N1 })) {
4773 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
4774 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4775 AddToWorklist(N: Mul.getNode());
4776 AddToWorklist(N: Sub.getNode());
4777 CombineTo(N: RemNode, Res: Sub);
4778 }
4779 return V;
4780 }
4781
4782 // sdiv, srem -> sdivrem
4783 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4784 // true. Otherwise, we break the simplification logic in visitREM().
4785 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4786 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4787 if (SDValue DivRem = useDivRem(Node: N))
4788 return DivRem;
4789
4790 return SDValue();
4791}
4792
4793static bool isDivisorPowerOfTwo(SDValue Divisor) {
4794 // Helper for determining whether a value is a power-of-2 constant scalar or a
4795 // vector of such elements.
4796 auto IsPowerOfTwo = [](ConstantSDNode *C) {
4797 if (C->isZero() || C->isOpaque())
4798 return false;
4799 if (C->getAPIntValue().isPowerOf2())
4800 return true;
4801 if (C->getAPIntValue().isNegatedPowerOf2())
4802 return true;
4803 return false;
4804 };
4805
4806 return ISD::matchUnaryPredicate(Op: Divisor, Match: IsPowerOfTwo);
4807}
4808
4809SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4810 SDLoc DL(N);
4811 EVT VT = N->getValueType(ResNo: 0);
4812 EVT CCVT = getSetCCResultType(VT);
4813 unsigned BitWidth = VT.getScalarSizeInBits();
4814
4815 // fold (sdiv X, pow2) -> simple ops after legalize
4816 // FIXME: We check for the exact bit here because the generic lowering gives
4817 // better results in that case. The target-specific lowering should learn how
4818 // to handle exact sdivs efficiently.
4819 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1)) {
4820 // Target-specific implementation of sdiv x, pow2.
4821 if (SDValue Res = BuildSDIVPow2(N))
4822 return Res;
4823
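// Illustrative sketch of the sequence built below for (sdiv i32 x, 4):
//   Sign = sra x, 31     ; splat the sign bit
//   Srl  = srl Sign, 30  ; 3 if x is negative, 0 otherwise
//   Add  = add x, Srl    ; bias negative dividends so the shift truncates
//   Res  = sra Add, 2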
4824 // Create constants that are functions of the shift amount value.
4825 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
4826 SDValue Bits = DAG.getConstant(Val: BitWidth, DL, VT: ShiftAmtTy);
4827 SDValue C1 = DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N1);
4828 C1 = DAG.getZExtOrTrunc(Op: C1, DL, VT: ShiftAmtTy);
4829 SDValue Inexact = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftAmtTy, N1: Bits, N2: C1);
4830 if (!isConstantOrConstantVector(N: Inexact))
4831 return SDValue();
4832
4833 // Splat the sign bit into the register
4834 SDValue Sign = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0,
4835 N2: DAG.getConstant(Val: BitWidth - 1, DL, VT: ShiftAmtTy));
4836 AddToWorklist(N: Sign.getNode());
4837
4838 // Add (N0 < 0) ? abs(divisor) - 1 : 0, so the arithmetic shift below rounds toward zero.
4839 SDValue Srl = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Sign, N2: Inexact);
4840 AddToWorklist(N: Srl.getNode());
4841 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Srl);
4842 AddToWorklist(N: Add.getNode());
4843 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Add, N2: C1);
4844 AddToWorklist(N: Sra.getNode());
4845
4846 // Special case: (sdiv X, 1) -> X
4847 // Special case: (sdiv X, -1) -> 0-X
4848 SDValue One = DAG.getConstant(Val: 1, DL, VT);
4849 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4850 SDValue IsOne = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: One, Cond: ISD::SETEQ);
4851 SDValue IsAllOnes = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: AllOnes, Cond: ISD::SETEQ);
4852 SDValue IsOneOrAllOnes = DAG.getNode(Opcode: ISD::OR, DL, VT: CCVT, N1: IsOne, N2: IsAllOnes);
4853 Sra = DAG.getSelect(DL, VT, Cond: IsOneOrAllOnes, LHS: N0, RHS: Sra);
4854
4855 // If dividing by a positive value, we're done. Otherwise, the result must
4856 // be negated.
4857 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4858 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: Sra);
4859
4860 // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4861 SDValue IsNeg = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: Zero, Cond: ISD::SETLT);
4862 SDValue Res = DAG.getSelect(DL, VT, Cond: IsNeg, LHS: Sub, RHS: Sra);
4863 return Res;
4864 }
4865
4866 // If integer divide is expensive and we satisfy the requirements, emit an
4867 // alternate sequence. Targets may check function attributes for size/speed
4868 // trade-offs.
4869 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4870 if (isConstantOrConstantVector(N: N1) &&
4871 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4872 if (SDValue Op = BuildSDIV(N))
4873 return Op;
4874
4875 return SDValue();
4876}
4877
4878SDValue DAGCombiner::visitUDIV(SDNode *N) {
4879 SDValue N0 = N->getOperand(Num: 0);
4880 SDValue N1 = N->getOperand(Num: 1);
4881 EVT VT = N->getValueType(ResNo: 0);
4882 EVT CCVT = getSetCCResultType(VT);
4883 SDLoc DL(N);
4884
4885 // fold (udiv c1, c2) -> c1/c2
4886 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UDIV, DL, VT, Ops: {N0, N1}))
4887 return C;
4888
4889 // fold vector ops
4890 if (VT.isVector())
4891 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4892 return FoldedVOp;
4893
4894 // fold (udiv X, -1) -> select(X == -1, 1, 0)
4895 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4896 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4897 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
4898 LHS: DAG.getConstant(Val: 1, DL, VT),
4899 RHS: DAG.getConstant(Val: 0, DL, VT));
4900 }
4901
4902 if (SDValue V = simplifyDivRem(N, DAG))
4903 return V;
4904
4905 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4906 return NewSel;
4907
4908 if (SDValue V = visitUDIVLike(N0, N1, N)) {
4909 // If the corresponding remainder node exists, update its users with
4910 // (Dividend - (Quotient * Divisor)).
4911 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::UREM, VTList: N->getVTList(),
4912 Ops: { N0, N1 })) {
4913 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
4914 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4915 AddToWorklist(N: Mul.getNode());
4916 AddToWorklist(N: Sub.getNode());
4917 CombineTo(N: RemNode, Res: Sub);
4918 }
4919 return V;
4920 }
4921
4922 // udiv, urem -> udivrem
4923 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4924 // true. Otherwise, we break the simplification logic in visitREM().
4925 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4926 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4927 if (SDValue DivRem = useDivRem(Node: N))
4928 return DivRem;
4929
4930 return SDValue();
4931}
4932
4933SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4934 SDLoc DL(N);
4935 EVT VT = N->getValueType(ResNo: 0);
4936
4937 // fold (udiv x, (1 << c)) -> x >>u c
4938 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true)) {
4939 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4940 AddToWorklist(N: LogBase2.getNode());
4941
4942 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4943 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4944 AddToWorklist(N: Trunc.getNode());
4945 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
4946 }
4947 }
4948
4949 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
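// e.g. (udiv x, (shl 4, y)) --> (srl x, (add y, 2))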
4950 if (N1.getOpcode() == ISD::SHL) {
4951 SDValue N10 = N1.getOperand(i: 0);
4952 if (isConstantOrConstantVector(N: N10, /*NoOpaques*/ true)) {
4953 if (SDValue LogBase2 = BuildLogBase2(V: N10, DL)) {
4954 AddToWorklist(N: LogBase2.getNode());
4955
4956 EVT ADDVT = N1.getOperand(i: 1).getValueType();
4957 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ADDVT);
4958 AddToWorklist(N: Trunc.getNode());
4959 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: ADDVT, N1: N1.getOperand(i: 1), N2: Trunc);
4960 AddToWorklist(N: Add.getNode());
4961 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Add);
4962 }
4963 }
4964 }
4965
4966 // fold (udiv x, c) -> alternate
4967 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4968 if (isConstantOrConstantVector(N: N1) &&
4969 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4970 if (SDValue Op = BuildUDIV(N))
4971 return Op;
4972
4973 return SDValue();
4974}
4975
4976SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4977 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1) &&
4978 !DAG.doesNodeExist(Opcode: ISD::SDIV, VTList: N->getVTList(), Ops: {N0, N1})) {
4979 // Target-specific implementation of srem x, pow2.
4980 if (SDValue Res = BuildSREMPow2(N))
4981 return Res;
4982 }
4983 return SDValue();
4984}
4985
4986// handles ISD::SREM and ISD::UREM
4987SDValue DAGCombiner::visitREM(SDNode *N) {
4988 unsigned Opcode = N->getOpcode();
4989 SDValue N0 = N->getOperand(Num: 0);
4990 SDValue N1 = N->getOperand(Num: 1);
4991 EVT VT = N->getValueType(ResNo: 0);
4992 EVT CCVT = getSetCCResultType(VT);
4993
4994 bool isSigned = (Opcode == ISD::SREM);
4995 SDLoc DL(N);
4996
4997 // fold (rem c1, c2) -> c1%c2
4998 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4999 return C;
5000
5001 // fold (urem X, -1) -> select(FX == -1, 0, FX)
5002 // Freeze the numerator to avoid a miscompile with an undefined value.
5003 if (!isSigned && llvm::isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs*/ false) &&
5004 CCVT.isVector() == VT.isVector()) {
5005 SDValue F0 = DAG.getFreeze(V: N0);
5006 SDValue EqualsNeg1 = DAG.getSetCC(DL, VT: CCVT, LHS: F0, RHS: N1, Cond: ISD::SETEQ);
5007 return DAG.getSelect(DL, VT, Cond: EqualsNeg1, LHS: DAG.getConstant(Val: 0, DL, VT), RHS: F0);
5008 }
5009
5010 if (SDValue V = simplifyDivRem(N, DAG))
5011 return V;
5012
5013 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5014 return NewSel;
5015
5016 if (isSigned) {
5017 // If we know the sign bits of both operands are zero, strength reduce to a
5018 // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
5019 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
5020 return DAG.getNode(Opcode: ISD::UREM, DL, VT, N1: N0, N2: N1);
5021 } else {
5022 if (DAG.isKnownToBeAPowerOfTwo(Val: N1)) {
5023 // fold (urem x, pow2) -> (and x, pow2-1)
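// e.g. (urem x, 16) --> (and x, 15)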
5024 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5025 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5026 AddToWorklist(N: Add.getNode());
5027 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5028 }
5029 // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
5030 // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
5031 // TODO: We should sink the following into isKnownToBePowerOfTwo
5032 // using a OrZero parameter analogous to our handling in ValueTracking.
5033 if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
5034 DAG.isKnownToBeAPowerOfTwo(Val: N1.getOperand(i: 0))) {
5035 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5036 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5037 AddToWorklist(N: Add.getNode());
5038 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5039 }
5040 }
5041
5042 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5043
5044 // If X/C can be simplified by the division-by-constant logic, lower
5045 // X%C to the equivalent of X-X/C*C.
5046 // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
5047 // speculative DIV must not cause a DIVREM conversion. We guard against this
5048 // by skipping the simplification if isIntDivCheap(). When div is not cheap,
5049 // combine will not return a DIVREM. Regardless, checking cheapness here
5050 // makes sense since the simplification results in fatter code.
5051 if (DAG.isKnownNeverZero(Op: N1) && !TLI.isIntDivCheap(VT, Attr)) {
5052 if (isSigned) {
5053 // check if we can build faster implementation for srem
5054 if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
5055 return OptimizedRem;
5056 }
5057
5058 SDValue OptimizedDiv =
5059 isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
5060 if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
5061 // If the equivalent Div node also exists, update its users.
5062 unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5063 if (SDNode *DivNode = DAG.getNodeIfExists(Opcode: DivOpcode, VTList: N->getVTList(),
5064 Ops: { N0, N1 }))
5065 CombineTo(N: DivNode, Res: OptimizedDiv);
5066 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: OptimizedDiv, N2: N1);
5067 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5068 AddToWorklist(N: OptimizedDiv.getNode());
5069 AddToWorklist(N: Mul.getNode());
5070 return Sub;
5071 }
5072 }
5073
5074 // sdiv, srem -> sdivrem; udiv, urem -> udivrem
5075 if (SDValue DivRem = useDivRem(Node: N))
5076 return DivRem.getValue(R: 1);
5077
5078 return SDValue();
5079}
5080
5081SDValue DAGCombiner::visitMULHS(SDNode *N) {
5082 SDValue N0 = N->getOperand(Num: 0);
5083 SDValue N1 = N->getOperand(Num: 1);
5084 EVT VT = N->getValueType(ResNo: 0);
5085 SDLoc DL(N);
5086
5087 // fold (mulhs c1, c2)
5088 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHS, DL, VT, Ops: {N0, N1}))
5089 return C;
5090
5091 // canonicalize constant to RHS.
5092 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5093 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5094 return DAG.getNode(Opcode: ISD::MULHS, DL, VTList: N->getVTList(), N1, N2: N0);
5095
5096 if (VT.isVector()) {
5097 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5098 return FoldedVOp;
5099
5100 // fold (mulhs x, 0) -> 0
5101 // Do not return N1; it may contain undef elements.
5102 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5103 return DAG.getConstant(Val: 0, DL, VT);
5104 }
5105
5106 // fold (mulhs x, 0) -> 0
5107 if (isNullConstant(V: N1))
5108 return N1;
5109
5110 // fold (mulhs x, 1) -> (sra x, size(x)-1)
5111 if (isOneConstant(V: N1))
5112 return DAG.getNode(Opcode: ISD::SRA, DL, VT: N0.getValueType(), N1: N0,
5113 N2: DAG.getConstant(Val: N0.getScalarValueSizeInBits() - 1, DL,
5114 VT: getShiftAmountTy(LHSTy: N0.getValueType())));
5115
5116 // fold (mulhs x, undef) -> 0
5117 if (N0.isUndef() || N1.isUndef())
5118 return DAG.getConstant(Val: 0, DL, VT);
5119
5120 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5121 // plus a shift.
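// e.g. for i32: mulhs a, b --> trunc (srl (mul (sext a to i64), (sext b to i64)), 32)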
5122 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHS, VT) && VT.isSimple() &&
5123 !VT.isVector()) {
5124 MVT Simple = VT.getSimpleVT();
5125 unsigned SimpleSize = Simple.getSizeInBits();
5126 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5127 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5128 N0 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5129 N1 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5130 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5131 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5132 N2: DAG.getConstant(Val: SimpleSize, DL,
5133 VT: getShiftAmountTy(LHSTy: N1.getValueType())));
5134 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5135 }
5136 }
5137
5138 return SDValue();
5139}
5140
5141SDValue DAGCombiner::visitMULHU(SDNode *N) {
5142 SDValue N0 = N->getOperand(Num: 0);
5143 SDValue N1 = N->getOperand(Num: 1);
5144 EVT VT = N->getValueType(ResNo: 0);
5145 SDLoc DL(N);
5146
5147 // fold (mulhu c1, c2)
5148 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHU, DL, VT, Ops: {N0, N1}))
5149 return C;
5150
5151 // canonicalize constant to RHS.
5152 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5153 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5154 return DAG.getNode(Opcode: ISD::MULHU, DL, VTList: N->getVTList(), N1, N2: N0);
5155
5156 if (VT.isVector()) {
5157 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5158 return FoldedVOp;
5159
5160 // fold (mulhu x, 0) -> 0
5161 // Do not return N1; it may contain undef elements.
5162 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5163 return DAG.getConstant(Val: 0, DL, VT);
5164 }
5165
5166 // fold (mulhu x, 0) -> 0
5167 if (isNullConstant(V: N1))
5168 return N1;
5169
5170 // fold (mulhu x, 1) -> 0
5171 if (isOneConstant(V: N1))
5172 return DAG.getConstant(Val: 0, DL, VT: N0.getValueType());
5173
5174 // fold (mulhu x, undef) -> 0
5175 if (N0.isUndef() || N1.isUndef())
5176 return DAG.getConstant(Val: 0, DL, VT);
5177
5178 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
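// e.g. for i32: (mulhu x, 16) --> (srl x, 28)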
5179 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
5180 hasOperation(Opcode: ISD::SRL, VT)) {
5181 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5182 unsigned NumEltBits = VT.getScalarSizeInBits();
5183 SDValue SRLAmt = DAG.getNode(
5184 Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: NumEltBits, DL, VT), N2: LogBase2);
5185 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5186 SDValue Trunc = DAG.getZExtOrTrunc(Op: SRLAmt, DL, VT: ShiftVT);
5187 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5188 }
5189 }
5190
5191 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5192 // plus a shift.
5193 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHU, VT) && VT.isSimple() &&
5194 !VT.isVector()) {
5195 MVT Simple = VT.getSimpleVT();
5196 unsigned SimpleSize = Simple.getSizeInBits();
5197 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5198 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5199 N0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5200 N1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5201 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5202 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5203 N2: DAG.getConstant(Val: SimpleSize, DL,
5204 VT: getShiftAmountTy(LHSTy: N1.getValueType())));
5205 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5206 }
5207 }
5208
5209 // Simplify the operands using demanded-bits information.
5210 // We don't have demanded bits support for MULHU so this just enables constant
5211 // folding based on known bits.
5212 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5213 return SDValue(N, 0);
5214
5215 return SDValue();
5216}
5217
5218SDValue DAGCombiner::visitAVG(SDNode *N) {
5219 unsigned Opcode = N->getOpcode();
5220 SDValue N0 = N->getOperand(Num: 0);
5221 SDValue N1 = N->getOperand(Num: 1);
5222 EVT VT = N->getValueType(ResNo: 0);
5223 SDLoc DL(N);
5224
5225 // fold (avg c1, c2)
5226 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5227 return C;
5228
5229 // canonicalize constant to RHS.
5230 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5231 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5232 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5233
5234 if (VT.isVector()) {
5235 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5236 return FoldedVOp;
5237
5238 // fold (avgfloor x, 0) -> x >> 1
5239 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
5240 if (Opcode == ISD::AVGFLOORS)
5241 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0, N2: DAG.getConstant(Val: 1, DL, VT));
5242 if (Opcode == ISD::AVGFLOORU)
5243 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: DAG.getConstant(Val: 1, DL, VT));
5244 }
5245 }
5246
5247 // fold (avg x, undef) -> x
5248 if (N0.isUndef())
5249 return N1;
5250 if (N1.isUndef())
5251 return N0;
5252
5253 // Fold (avg x, x) --> x
5254 if (N0 == N1 && Level >= AfterLegalizeTypes)
5255 return N0;
5256
5257 // TODO: If we use avg for scalars anywhere, we can add (avgfloor x, 0) -> x >> 1
5258
5259 return SDValue();
5260}
5261
5262SDValue DAGCombiner::visitABD(SDNode *N) {
5263 unsigned Opcode = N->getOpcode();
5264 SDValue N0 = N->getOperand(Num: 0);
5265 SDValue N1 = N->getOperand(Num: 1);
5266 EVT VT = N->getValueType(ResNo: 0);
5267 SDLoc DL(N);
5268
5269 // fold (abd c1, c2)
5270 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5271 return C;
5272
5273 // canonicalize constant to RHS.
5274 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5275 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5276 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5277
5278 if (VT.isVector()) {
5279 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5280 return FoldedVOp;
5281
5282 // fold (abds x, 0) -> abs x
5283 // fold (abdu x, 0) -> x
5284 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
5285 if (Opcode == ISD::ABDS)
5286 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: N0);
5287 if (Opcode == ISD::ABDU)
5288 return N0;
5289 }
5290 }
5291
5292 // fold (abd x, undef) -> 0
5293 if (N0.isUndef() || N1.isUndef())
5294 return DAG.getConstant(Val: 0, DL, VT);
5295
5296 // fold (abds x, y) -> (abdu x, y) iff both args are known positive
5297 if (Opcode == ISD::ABDS && hasOperation(Opcode: ISD::ABDU, VT) &&
5298 DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5299 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1, N2: N0);
5300
5301 return SDValue();
5302}
5303
5304/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5305/// give the opcodes for the two computations that are being performed. Return
5306 /// the simplified value if a simplification was made, otherwise a null SDValue.
5307SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5308 unsigned HiOp) {
5309 // If the high half is not needed, just compute the low half.
5310 bool HiExists = N->hasAnyUseOfValue(Value: 1);
5311 if (!HiExists && (!LegalOperations ||
5312 TLI.isOperationLegalOrCustom(Op: LoOp, VT: N->getValueType(ResNo: 0)))) {
5313 SDValue Res = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5314 return CombineTo(N, Res0: Res, Res1: Res);
5315 }
5316
5317 // If the low half is not needed, just compute the high half.
5318 bool LoExists = N->hasAnyUseOfValue(Value: 0);
5319 if (!LoExists && (!LegalOperations ||
5320 TLI.isOperationLegalOrCustom(Op: HiOp, VT: N->getValueType(ResNo: 1)))) {
5321 SDValue Res = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5322 return CombineTo(N, Res0: Res, Res1: Res);
5323 }
5324
5325 // If both halves are used, return as it is.
5326 if (LoExists && HiExists)
5327 return SDValue();
5328
5329 // If the two computed results can be simplified separately, separate them.
5330 if (LoExists) {
5331 SDValue Lo = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5332 AddToWorklist(N: Lo.getNode());
5333 SDValue LoOpt = combine(N: Lo.getNode());
5334 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5335 (!LegalOperations ||
5336 TLI.isOperationLegalOrCustom(Op: LoOpt.getOpcode(), VT: LoOpt.getValueType())))
5337 return CombineTo(N, Res0: LoOpt, Res1: LoOpt);
5338 }
5339
5340 if (HiExists) {
5341 SDValue Hi = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5342 AddToWorklist(N: Hi.getNode());
5343 SDValue HiOpt = combine(N: Hi.getNode());
5344 if (HiOpt.getNode() && HiOpt != Hi &&
5345 (!LegalOperations ||
5346 TLI.isOperationLegalOrCustom(Op: HiOpt.getOpcode(), VT: HiOpt.getValueType())))
5347 return CombineTo(N, Res0: HiOpt, Res1: HiOpt);
5348 }
5349
5350 return SDValue();
5351}
5352
5353SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5354 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHS))
5355 return Res;
5356
5357 SDValue N0 = N->getOperand(Num: 0);
5358 SDValue N1 = N->getOperand(Num: 1);
5359 EVT VT = N->getValueType(ResNo: 0);
5360 SDLoc DL(N);
5361
5362 // Constant fold.
5363 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5364 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5365
5366 // canonicalize constant to RHS (vector doesn't have to splat)
5367 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5368 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5369 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5370
5371 // If the type twice as wide is legal, transform the smul_lohi to a wider
5372 // multiply plus a shift.
5373 if (VT.isSimple() && !VT.isVector()) {
5374 MVT Simple = VT.getSimpleVT();
5375 unsigned SimpleSize = Simple.getSizeInBits();
5376 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5377 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5378 SDValue Lo = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5379 SDValue Hi = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5380 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5381 // Compute the high part by shifting the wide product down and truncating.
5382 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5383 N2: DAG.getConstant(Val: SimpleSize, DL,
5384 VT: getShiftAmountTy(LHSTy: Lo.getValueType())));
5385 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5386 // Compute the low part by truncating the wide product.
5387 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5388 return CombineTo(N, Res0: Lo, Res1: Hi);
5389 }
5390 }
5391
5392 return SDValue();
5393}
5394
5395SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5396 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHU))
5397 return Res;
5398
5399 SDValue N0 = N->getOperand(Num: 0);
5400 SDValue N1 = N->getOperand(Num: 1);
5401 EVT VT = N->getValueType(ResNo: 0);
5402 SDLoc DL(N);
5403
5404 // Constant fold.
5405 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5406 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5407
5408 // canonicalize constant to RHS (vector doesn't have to splat)
5409 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5410 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5411 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5412
5413 // (umul_lohi N0, 0) -> (0, 0)
5414 if (isNullConstant(V: N1)) {
5415 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5416 return CombineTo(N, Res0: Zero, Res1: Zero);
5417 }
5418
5419 // (umul_lohi N0, 1) -> (N0, 0)
5420 if (isOneConstant(V: N1)) {
5421 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5422 return CombineTo(N, Res0: N0, Res1: Zero);
5423 }
5424
5425 // If the type twice as wide is legal, transform the umul_lohi to a wider
5426 // multiply plus a shift.
5427 if (VT.isSimple() && !VT.isVector()) {
5428 MVT Simple = VT.getSimpleVT();
5429 unsigned SimpleSize = Simple.getSizeInBits();
5430 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5431 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5432 SDValue Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5433 SDValue Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5434 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5435 // Compute the high part by shifting the wide product down and truncating.
5436 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5437 N2: DAG.getConstant(Val: SimpleSize, DL,
5438 VT: getShiftAmountTy(LHSTy: Lo.getValueType())));
5439 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5440 // Compute the low part by truncating the wide product.
5441 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5442 return CombineTo(N, Res0: Lo, Res1: Hi);
5443 }
5444 }
5445
5446 return SDValue();
5447}
5448
5449SDValue DAGCombiner::visitMULO(SDNode *N) {
5450 SDValue N0 = N->getOperand(Num: 0);
5451 SDValue N1 = N->getOperand(Num: 1);
5452 EVT VT = N0.getValueType();
5453 bool IsSigned = (ISD::SMULO == N->getOpcode());
5454
5455 EVT CarryVT = N->getValueType(ResNo: 1);
5456 SDLoc DL(N);
5457
5458 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
5459 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5460
5461 // fold operation with constant operands.
5462 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5463 // multiple results.
5464 if (N0C && N1C) {
5465 bool Overflow;
5466 APInt Result =
5467 IsSigned ? N0C->getAPIntValue().smul_ov(RHS: N1C->getAPIntValue(), Overflow)
5468 : N0C->getAPIntValue().umul_ov(RHS: N1C->getAPIntValue(), Overflow);
5469 return CombineTo(N, Res0: DAG.getConstant(Val: Result, DL, VT),
5470 Res1: DAG.getBoolConstant(V: Overflow, DL, VT: CarryVT, OpVT: CarryVT));
5471 }
5472
5473 // canonicalize constant to RHS.
5474 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5475 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5476 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
5477
5478 // fold (mulo x, 0) -> 0 + no carry out
5479 if (isNullOrNullSplat(V: N1))
5480 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
5481 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5482
5483 // (mulo x, 2) -> (addo x, x)
5484 // FIXME: This needs a freeze.
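// For signed multiply we also require more than 2 bits: with 2-bit elements
// the constant bit pattern 2 actually denotes -2, so the fold would not hold.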
5485 if (N1C && N1C->getAPIntValue() == 2 &&
5486 (!IsSigned || VT.getScalarSizeInBits() > 2))
5487 return DAG.getNode(Opcode: IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5488 VTList: N->getVTList(), N1: N0, N2: N0);
5489
5490 // A 1 bit SMULO overflows if both inputs are 1.
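// In 1-bit signed arithmetic the only values are 0 and -1, and (-1 * -1) = 1
// is not representable, so the product overflows exactly when both bits are set.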
5491 if (IsSigned && VT.getScalarSizeInBits() == 1) {
5492 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: N1);
5493 SDValue Cmp = DAG.getSetCC(DL, VT: CarryVT, LHS: And,
5494 RHS: DAG.getConstant(Val: 0, DL, VT), Cond: ISD::SETNE);
5495 return CombineTo(N, Res0: And, Res1: Cmp);
5496 }
5497
5498 // If it cannot overflow, transform into a mul.
5499 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5500 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0, N2: N1),
5501 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5502 return SDValue();
5503}
5504
5505// Function to calculate whether the Min/Max pair of SDNodes (potentially
5506// swapped around) make a signed saturate pattern, clamping to between a signed
5507 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW-1.
5508// Returns the node being clamped and the bitwidth of the clamp in BW. Should
5509// work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5510// same as SimplifySelectCC. N0<N1 ? N2 : N3.
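// For example, smin(smax(x, -128), 127) clamps x to the signed i8 range,
// giving BW = 8 and Unsigned = false, while smin(smax(x, 0), 255) gives
// BW = 8 and Unsigned = true.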
5511static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5512 SDValue N3, ISD::CondCode CC, unsigned &BW,
5513 bool &Unsigned, SelectionDAG &DAG) {
5514 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5515 ISD::CondCode CC) {
5516 // The compare and select operand should be the same or the select operands
5517 // should be truncated versions of the comparison.
5518 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0)))
5519 return 0;
5520 // The constants need to be the same or a truncated version of each other.
5521 ConstantSDNode *N1C = isConstOrConstSplat(N: peekThroughTruncates(V: N1));
5522 ConstantSDNode *N3C = isConstOrConstSplat(N: peekThroughTruncates(V: N3));
5523 if (!N1C || !N3C)
5524 return 0;
5525 const APInt &C1 = N1C->getAPIntValue().trunc(width: N1.getScalarValueSizeInBits());
5526 const APInt &C2 = N3C->getAPIntValue().trunc(width: N3.getScalarValueSizeInBits());
5527 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(width: C1.getBitWidth()))
5528 return 0;
5529 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5530 };
5531
5532 // Check the initial value is a SMIN/SMAX equivalent.
5533 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5534 if (!Opcode0)
5535 return SDValue();
5536
5537 // We may need only one range check if the fptosi can never produce
5538 // the upper value.
5539 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5540 if (isNullOrNullSplat(V: N3)) {
5541 EVT IntVT = N0.getValueType().getScalarType();
5542 EVT FPVT = N0.getOperand(i: 0).getValueType().getScalarType();
5543 if (FPVT.isSimple()) {
5544 Type *InputTy = FPVT.getTypeForEVT(Context&: *DAG.getContext());
5545 const fltSemantics &Semantics = InputTy->getFltSemantics();
5546 uint32_t MinBitWidth =
5547 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5548 if (IntVT.getSizeInBits() >= MinBitWidth) {
5549 Unsigned = true;
5550 BW = PowerOf2Ceil(A: MinBitWidth);
5551 return N0;
5552 }
5553 }
5554 }
5555 }
5556
5557 SDValue N00, N01, N02, N03;
5558 ISD::CondCode N0CC;
5559 switch (N0.getOpcode()) {
5560 case ISD::SMIN:
5561 case ISD::SMAX:
5562 N00 = N02 = N0.getOperand(i: 0);
5563 N01 = N03 = N0.getOperand(i: 1);
5564 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5565 break;
5566 case ISD::SELECT_CC:
5567 N00 = N0.getOperand(i: 0);
5568 N01 = N0.getOperand(i: 1);
5569 N02 = N0.getOperand(i: 2);
5570 N03 = N0.getOperand(i: 3);
5571 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 4))->get();
5572 break;
5573 case ISD::SELECT:
5574 case ISD::VSELECT:
5575 if (N0.getOperand(i: 0).getOpcode() != ISD::SETCC)
5576 return SDValue();
5577 N00 = N0.getOperand(i: 0).getOperand(i: 0);
5578 N01 = N0.getOperand(i: 0).getOperand(i: 1);
5579 N02 = N0.getOperand(i: 1);
5580 N03 = N0.getOperand(i: 2);
5581 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 0).getOperand(i: 2))->get();
5582 break;
5583 default:
5584 return SDValue();
5585 }
5586
5587 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5588 if (!Opcode1 || Opcode0 == Opcode1)
5589 return SDValue();
5590
5591 ConstantSDNode *MinCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N1 : N01);
5592 ConstantSDNode *MaxCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N01 : N1);
5593 if (!MinCOp || !MaxCOp || MinCOp->getValueType(ResNo: 0) != MaxCOp->getValueType(ResNo: 0))
5594 return SDValue();
5595
5596 const APInt &MinC = MinCOp->getAPIntValue();
5597 const APInt &MaxC = MaxCOp->getAPIntValue();
5598 APInt MinCPlus1 = MinC + 1;
5599 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5600 BW = MinCPlus1.exactLogBase2() + 1;
5601 Unsigned = false;
5602 return N02;
5603 }
5604
5605 if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5606 BW = MinCPlus1.exactLogBase2();
5607 Unsigned = true;
5608 return N02;
5609 }
5610
5611 return SDValue();
5612}
5613
5614static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5615 SDValue N3, ISD::CondCode CC,
5616 SelectionDAG &DAG) {
5617 unsigned BW;
5618 bool Unsigned;
5619 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
5620 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5621 return SDValue();
5622 EVT FPVT = Fp.getOperand(i: 0).getValueType();
5623 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
5624 if (FPVT.isVector())
5625 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
5626 EC: FPVT.getVectorElementCount());
5627 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5628 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: NewOpc, FPVT, VT: NewVT))
5629 return SDValue();
5630 SDLoc DL(Fp);
5631 SDValue Sat = DAG.getNode(Opcode: NewOpc, DL, VT: NewVT, N1: Fp.getOperand(i: 0),
5632 N2: DAG.getValueType(NewVT.getScalarType()));
5633 return DAG.getExtOrTrunc(IsSigned: !Unsigned, Op: Sat, DL, VT: N2->getValueType(ResNo: 0));
5634}
5635
5636static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5637 SDValue N3, ISD::CondCode CC,
5638 SelectionDAG &DAG) {
5639 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
5640 // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
5641 // be truncated versions of the setcc (N0/N1).
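// For example, umin(fptoui(x), 255) becomes (fp_to_uint_sat x to i8),
// zero-extended back to the original result type.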
5642 if ((N0 != N2 &&
5643 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0))) ||
5644 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5645 return SDValue();
5646 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5647 ConstantSDNode *N3C = isConstOrConstSplat(N: N3);
5648 if (!N1C || !N3C)
5649 return SDValue();
5650 const APInt &C1 = N1C->getAPIntValue();
5651 const APInt &C3 = N3C->getAPIntValue();
5652 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5653 C1 != C3.zext(width: C1.getBitWidth()))
5654 return SDValue();
5655
5656 unsigned BW = (C1 + 1).exactLogBase2();
5657 EVT FPVT = N0.getOperand(i: 0).getValueType();
5658 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
5659 if (FPVT.isVector())
5660 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
5661 EC: FPVT.getVectorElementCount());
5662 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: ISD::FP_TO_UINT_SAT,
5663 FPVT, VT: NewVT))
5664 return SDValue();
5665
5666 SDValue Sat =
5667 DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT: NewVT, N1: N0.getOperand(i: 0),
5668 N2: DAG.getValueType(NewVT.getScalarType()));
5669 return DAG.getZExtOrTrunc(Op: Sat, DL: SDLoc(N0), VT: N3.getValueType());
5670}
5671
5672SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5673 SDValue N0 = N->getOperand(Num: 0);
5674 SDValue N1 = N->getOperand(Num: 1);
5675 EVT VT = N0.getValueType();
5676 unsigned Opcode = N->getOpcode();
5677 SDLoc DL(N);
5678
5679 // fold operation with constant operands.
5680 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5681 return C;
5682
5683 // If the operands are the same, this is a no-op.
5684 if (N0 == N1)
5685 return N0;
5686
5687 // canonicalize constant to RHS
5688 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5689 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5690 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
5691
5692 // fold vector ops
5693 if (VT.isVector())
5694 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5695 return FoldedVOp;
5696
5697 // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
5698 // Only do this if the current op isn't legal and the flipped is.
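// e.g. (smax x, y) --> (umax x, y) when the sign bits of both operands are
// known to be zero.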
5699 if (!TLI.isOperationLegal(Op: Opcode, VT) &&
5700 (N0.isUndef() || DAG.SignBitIsZero(Op: N0)) &&
5701 (N1.isUndef() || DAG.SignBitIsZero(Op: N1))) {
5702 unsigned AltOpcode;
5703 switch (Opcode) {
5704 case ISD::SMIN: AltOpcode = ISD::UMIN; break;
5705 case ISD::SMAX: AltOpcode = ISD::UMAX; break;
5706 case ISD::UMIN: AltOpcode = ISD::SMIN; break;
5707 case ISD::UMAX: AltOpcode = ISD::SMAX; break;
5708 default: llvm_unreachable("Unknown MINMAX opcode");
5709 }
5710 if (TLI.isOperationLegal(Op: AltOpcode, VT))
5711 return DAG.getNode(Opcode: AltOpcode, DL, VT, N1: N0, N2: N1);
5712 }
5713
5714 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
5715 if (SDValue S = PerformMinMaxFpToSatCombine(
5716 N0, N1, N2: N0, N3: N1, CC: Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
5717 return S;
5718 if (Opcode == ISD::UMIN)
5719 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2: N0, N3: N1, CC: ISD::SETULT, DAG))
5720 return S;
5721
5722 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
5723 auto ReductionOpcode = [](unsigned Opcode) {
5724 switch (Opcode) {
5725 case ISD::SMIN:
5726 return ISD::VECREDUCE_SMIN;
5727 case ISD::SMAX:
5728 return ISD::VECREDUCE_SMAX;
5729 case ISD::UMIN:
5730 return ISD::VECREDUCE_UMIN;
5731 case ISD::UMAX:
5732 return ISD::VECREDUCE_UMAX;
5733 default:
5734 llvm_unreachable("Unexpected opcode");
5735 }
5736 };
5737 if (SDValue SD = reassociateReduction(RedOpc: ReductionOpcode(Opcode), Opc: Opcode,
5738 DL: SDLoc(N), VT, N0, N1))
5739 return SD;
5740
5741 // Simplify the operands using demanded-bits information.
5742 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5743 return SDValue(N, 0);
5744
5745 return SDValue();
5746}
5747
5748/// If this is a bitwise logic instruction and both operands have the same
5749/// opcode, try to sink the other opcode after the logic instruction.
5750SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
5751 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
5752 EVT VT = N0.getValueType();
5753 unsigned LogicOpcode = N->getOpcode();
5754 unsigned HandOpcode = N0.getOpcode();
5755 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
5756 assert(HandOpcode == N1.getOpcode() && "Bad input!");
5757
5758 // Bail early if none of these transforms apply.
5759 if (N0.getNumOperands() == 0)
5760 return SDValue();
5761
5762 // FIXME: We should check number of uses of the operands to not increase
5763 // the instruction count for all transforms.
5764
5765 // Handle size-changing casts (or sign_extend_inreg).
5766 SDValue X = N0.getOperand(i: 0);
5767 SDValue Y = N1.getOperand(i: 0);
5768 EVT XVT = X.getValueType();
5769 SDLoc DL(N);
5770 if (ISD::isExtOpcode(Opcode: HandOpcode) || ISD::isExtVecInRegOpcode(Opcode: HandOpcode) ||
5771 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
5772 N0.getOperand(i: 1) == N1.getOperand(i: 1))) {
5773 // If both operands have other uses, this transform would create extra
5774 // instructions without eliminating anything.
5775 if (!N0.hasOneUse() && !N1.hasOneUse())
5776 return SDValue();
5777 // We need matching integer source types.
5778 if (XVT != Y.getValueType())
5779 return SDValue();
5780 // Don't create an illegal op during or after legalization. Don't ever
5781 // create an unsupported vector op.
5782 if ((VT.isVector() || LegalOperations) &&
5783 !TLI.isOperationLegalOrCustom(Op: LogicOpcode, VT: XVT))
5784 return SDValue();
5785 // Avoid infinite looping with PromoteIntBinOp.
5786 // TODO: Should we apply desirable/legal constraints to all opcodes?
5787 if ((HandOpcode == ISD::ANY_EXTEND ||
5788 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
5789 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, VT: XVT))
5790 return SDValue();
5791 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
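// A concrete instance of this fold:
// and (sext i8 X to i32), (sext i8 Y to i32) --> sext (and i8 X, Y) to i32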
5792 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5793 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
5794 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
5795 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5796 }
5797
5798 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
5799 if (HandOpcode == ISD::TRUNCATE) {
5800 // If both operands have other uses, this transform would create extra
5801 // instructions without eliminating anything.
5802 if (!N0.hasOneUse() && !N1.hasOneUse())
5803 return SDValue();
5804 // We need matching source types.
5805 if (XVT != Y.getValueType())
5806 return SDValue();
5807 // Don't create an illegal op during or after legalization.
5808 if (LegalOperations && !TLI.isOperationLegal(Op: LogicOpcode, VT: XVT))
5809 return SDValue();
5810 // Be extra careful sinking truncate. If it's free, there's no benefit in
5811 // widening a binop. Also, don't create a logic op on an illegal type.
5812 if (TLI.isZExtFree(FromTy: VT, ToTy: XVT) && TLI.isTruncateFree(FromVT: XVT, ToVT: VT))
5813 return SDValue();
5814 if (!TLI.isTypeLegal(VT: XVT))
5815 return SDValue();
5816 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5817 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5818 }
5819
5820 // For binops SHL/SRL/SRA/AND:
5821 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
5822 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
5823 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
5824 N0.getOperand(i: 1) == N1.getOperand(i: 1)) {
5825 // If either operand has other uses, this transform is not an improvement.
5826 if (!N0.hasOneUse() || !N1.hasOneUse())
5827 return SDValue();
5828 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5829 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
5830 }
5831
5832 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
5833 if (HandOpcode == ISD::BSWAP) {
5834 // If either operand has other uses, this transform is not an improvement.
5835 if (!N0.hasOneUse() || !N1.hasOneUse())
5836 return SDValue();
5837 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5838 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5839 }
5840
5841 // For funnel shifts FSHL/FSHR:
5842 // logic_op (OP x, x1, s), (OP y, y1, s) -->
5843 // --> OP (logic_op x, y), (logic_op, x1, y1), s
5844 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
5845 N0.getOperand(i: 2) == N1.getOperand(i: 2)) {
5846 if (!N0.hasOneUse() || !N1.hasOneUse())
5847 return SDValue();
5848 SDValue X1 = N0.getOperand(i: 1);
5849 SDValue Y1 = N1.getOperand(i: 1);
5850 SDValue S = N0.getOperand(i: 2);
5851 SDValue Logic0 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X, N2: Y);
5852 SDValue Logic1 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X1, N2: Y1);
5853 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic0, N2: Logic1, N3: S);
5854 }
5855
5856 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
5857 // Only perform this optimization up until type legalization, before
5858 // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
5859 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
5860 // we don't want to undo this promotion.
5861 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
5862 // on scalars.
5863 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
5864 Level <= AfterLegalizeTypes) {
5865 // Input types must be integer and the same.
5866 if (XVT.isInteger() && XVT == Y.getValueType() &&
5867 !(VT.isVector() && TLI.isTypeLegal(VT) &&
5868 !XVT.isVector() && !TLI.isTypeLegal(VT: XVT))) {
5869 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5870 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5871 }
5872 }
5873
5874 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
5875 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
5876 // If both shuffles use the same mask, and both shuffle within a single
5877 // vector, then it is worthwhile to move the swizzle after the operation.
5878 // The type-legalizer generates this pattern when loading illegal
5879 // vector types from memory. In many cases this allows additional shuffle
5880 // optimizations.
5881 // There are other cases where moving the shuffle after the xor/and/or
5882 // is profitable even if shuffles don't perform a swizzle.
5883 // If both shuffles use the same mask, and both shuffles have the same first
5884 // or second operand, then it might still be profitable to move the shuffle
5885 // after the xor/and/or operation.
5886 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
5887 auto *SVN0 = cast<ShuffleVectorSDNode>(Val&: N0);
5888 auto *SVN1 = cast<ShuffleVectorSDNode>(Val&: N1);
5889 assert(X.getValueType() == Y.getValueType() &&
5890 "Inputs to shuffles are not the same type");
5891
5892 // Check that both shuffles use the same mask. The masks are known to be of
5893 // the same length because the result vector type is the same.
5894 // Check also that shuffles have only one use to avoid introducing extra
5895 // instructions.
5896 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
5897 !SVN0->getMask().equals(RHS: SVN1->getMask()))
5898 return SDValue();
5899
5900 // Don't try to fold this node if it requires introducing a
5901 // build vector of all zeros that might be illegal at this stage.
5902 SDValue ShOp = N0.getOperand(i: 1);
5903 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5904 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5905
5906 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
5907 if (N0.getOperand(i: 1) == N1.getOperand(i: 1) && ShOp.getNode()) {
5908 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT,
5909 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
5910 return DAG.getVectorShuffle(VT, dl: DL, N1: Logic, N2: ShOp, Mask: SVN0->getMask());
5911 }
5912
5913 // Don't try to fold this node if it requires introducing a
5914 // build vector of all zeros that might be illegal at this stage.
5915 ShOp = N0.getOperand(i: 0);
5916 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5917 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5918
5919 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
5920 if (N0.getOperand(i: 0) == N1.getOperand(i: 0) && ShOp.getNode()) {
5921 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: N0.getOperand(i: 1),
5922 N2: N1.getOperand(i: 1));
5923 return DAG.getVectorShuffle(VT, dl: DL, N1: ShOp, N2: Logic, Mask: SVN0->getMask());
5924 }
5925 }
5926
5927 return SDValue();
5928}
5929
5930/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
5931SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
5932 const SDLoc &DL) {
5933 SDValue LL, LR, RL, RR, N0CC, N1CC;
5934 if (!isSetCCEquivalent(N: N0, LHS&: LL, RHS&: LR, CC&: N0CC) ||
5935 !isSetCCEquivalent(N: N1, LHS&: RL, RHS&: RR, CC&: N1CC))
5936 return SDValue();
5937
5938 assert(N0.getValueType() == N1.getValueType() &&
5939 "Unexpected operand types for bitwise logic op");
5940 assert(LL.getValueType() == LR.getValueType() &&
5941 RL.getValueType() == RR.getValueType() &&
5942 "Unexpected operand types for setcc");
5943
5944 // If we're here post-legalization or the logic op type is not i1, the logic
5945 // op type must match a setcc result type. Also, all folds require new
5946 // operations on the left and right operands, so those types must match.
5947 EVT VT = N0.getValueType();
5948 EVT OpVT = LL.getValueType();
5949 if (LegalOperations || VT.getScalarType() != MVT::i1)
5950 if (VT != getSetCCResultType(VT: OpVT))
5951 return SDValue();
5952 if (OpVT != RL.getValueType())
5953 return SDValue();
5954
5955 ISD::CondCode CC0 = cast<CondCodeSDNode>(Val&: N0CC)->get();
5956 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val&: N1CC)->get();
5957 bool IsInteger = OpVT.isInteger();
5958 if (LR == RR && CC0 == CC1 && IsInteger) {
5959 bool IsZero = isNullOrNullSplat(V: LR);
5960 bool IsNeg1 = isAllOnesOrAllOnesSplat(V: LR);
5961
5962 // All bits clear?
5963 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5964 // All sign bits clear?
5965 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5966 // Any bits set?
5967 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5968 // Any sign bits set?
5969 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5970
5971 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
5972 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5973 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
5974 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
5975 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5976 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
5977 AddToWorklist(N: Or.getNode());
5978 return DAG.getSetCC(DL, VT, LHS: Or, RHS: LR, Cond: CC1);
5979 }
5980
5981 // All bits set?
5982 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5983 // All sign bits set?
5984 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5985 // Any bits clear?
5986 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5987 // Any sign bits clear?
5988 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5989
5990 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
5991 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
5992 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
5993 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
5994 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
5995 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
5996 AddToWorklist(N: And.getNode());
5997 return DAG.getSetCC(DL, VT, LHS: And, RHS: LR, Cond: CC1);
5998 }
5999 }
6000
6001 // TODO: What is the 'or' equivalent of this fold?
6002 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
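// This holds because X != 0 and X != -1 together say that X + 1 is neither
// 1 nor 0, i.e. (X + 1) is unsigned >= 2.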
6003 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6004 IsInteger && CC0 == ISD::SETNE &&
6005 ((isNullConstant(V: LR) && isAllOnesConstant(V: RR)) ||
6006 (isAllOnesConstant(V: LR) && isNullConstant(V: RR)))) {
6007 SDValue One = DAG.getConstant(Val: 1, DL, VT: OpVT);
6008 SDValue Two = DAG.getConstant(Val: 2, DL, VT: OpVT);
6009 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: One);
6010 AddToWorklist(N: Add.getNode());
6011 return DAG.getSetCC(DL, VT, LHS: Add, RHS: Two, Cond: ISD::SETUGE);
6012 }
6013
6014 // Try more general transforms if the predicates match and the only user of
6015 // the compares is the 'and' or 'or'.
6016 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(VT: OpVT) && CC0 == CC1 &&
6017 N0.hasOneUse() && N1.hasOneUse()) {
6018 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6019 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6020 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6021 SDValue XorL = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: LR);
6022 SDValue XorR = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N1), VT: OpVT, N1: RL, N2: RR);
6023 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: OpVT, N1: XorL, N2: XorR);
6024 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6025 return DAG.getSetCC(DL, VT, LHS: Or, RHS: Zero, Cond: CC1);
6026 }
6027
6028 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6029 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6030 // Match a shared variable operand and 2 non-opaque constant operands.
6031 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6032 // The difference of the constants must be a single bit.
6033 const APInt &CMax =
6034 APIntOps::umax(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6035 const APInt &CMin =
6036 APIntOps::umin(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6037 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6038 };
6039 if (LL == RL && ISD::matchBinaryPredicate(LHS: LR, RHS: RR, Match: MatchDiffPow2)) {
6040 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6041 // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
6042 SDValue Max = DAG.getNode(Opcode: ISD::UMAX, DL, VT: OpVT, N1: LR, N2: RR);
6043 SDValue Min = DAG.getNode(Opcode: ISD::UMIN, DL, VT: OpVT, N1: LR, N2: RR);
6044 SDValue Offset = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: LL, N2: Min);
6045 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: Max, N2: Min);
6046 SDValue Mask = DAG.getNOT(DL, Val: Diff, VT: OpVT);
6047 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: Offset, N2: Mask);
6048 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6049 return DAG.getSetCC(DL, VT, LHS: And, RHS: Zero, Cond: CC0);
6050 }
6051 }
6052 }
6053
6054 // Canonicalize equivalent operands to LL == RL.
6055 if (LL == RR && LR == RL) {
6056 CC1 = ISD::getSetCCSwappedOperands(Operation: CC1);
6057 std::swap(a&: RL, b&: RR);
6058 }
6059
6060 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6061 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6062 if (LL == RL && LR == RR) {
6063 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(Op1: CC0, Op2: CC1, Type: OpVT)
6064 : ISD::getSetCCOrOperation(Op1: CC0, Op2: CC1, Type: OpVT);
6065 if (NewCC != ISD::SETCC_INVALID &&
6066 (!LegalOperations ||
6067 (TLI.isCondCodeLegal(CC: NewCC, VT: LL.getSimpleValueType()) &&
6068 TLI.isOperationLegal(Op: ISD::SETCC, VT: OpVT))))
6069 return DAG.getSetCC(DL, VT, LHS: LL, RHS: LR, Cond: NewCC);
6070 }
6071
6072 return SDValue();
6073}
6074
6075static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6076 SelectionDAG &DAG) {
6077 return DAG.isKnownNeverSNaN(Op: Operand2) && DAG.isKnownNeverSNaN(Op: Operand1);
6078}
6079
6080static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6081 SelectionDAG &DAG) {
6082 return DAG.isKnownNeverNaN(Op: Operand2) && DAG.isKnownNeverNaN(Op: Operand1);
6083}
6084
6085static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
6086 ISD::CondCode CC, unsigned OrAndOpcode,
6087 SelectionDAG &DAG,
6088 bool isFMAXNUMFMINNUM_IEEE,
6089 bool isFMAXNUMFMINNUM) {
6090 // The optimization cannot be applied for all the predicates because
6091 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6092 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6093 // applied at all if one of the operands is a signaling NaN.
6094
6095 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6096 // are known to be non-NaN values.
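// For example, when A and B are known not to be NaN, the caller can use the
// opcode returned here to fold (setlt A, X) | (setlt B, X) into
// setlt (fminnum_ieee A, B), X, since min(A, B) < X iff A < X or B < X.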
6097 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6098 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
6099 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6100 isFMAXNUMFMINNUM_IEEE
6101 ? ISD::FMINNUM_IEEE
6102 : ISD::DELETED_NODE;
6103 else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
6104 (OrAndOpcode == ISD::OR)) ||
6105 ((CC == ISD::SETLT || CC == ISD::SETLE) &&
6106 (OrAndOpcode == ISD::AND)))
6107 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6108 isFMAXNUMFMINNUM_IEEE
6109 ? ISD::FMAXNUM_IEEE
6110 : ISD::DELETED_NODE;
6111 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6112 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6113 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6114 // that there are not any sNaNs, then the optimization is not valid
6115 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6116 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6117 // we can prove that we do not have any sNaNs, then we can do the
6118 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6119 // cases.
6120 else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6121 (OrAndOpcode == ISD::OR)) ||
6122 ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6123 (OrAndOpcode == ISD::AND)))
6124 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6125 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6126 isFMAXNUMFMINNUM_IEEE
6127 ? ISD::FMINNUM_IEEE
6128 : ISD::DELETED_NODE;
6129 else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6130 (OrAndOpcode == ISD::OR)) ||
6131 ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6132 (OrAndOpcode == ISD::AND)))
6133 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6134 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6135 isFMAXNUMFMINNUM_IEEE
6136 ? ISD::FMAXNUM_IEEE
6137 : ISD::DELETED_NODE;
6138 return ISD::DELETED_NODE;
6139}
6140
6141static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6142 using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6143 assert(
6144 (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6145 "Invalid Op to combine SETCC with");
6146
6147 // TODO: Search past casts/truncates.
6148 SDValue LHS = LogicOp->getOperand(Num: 0);
6149 SDValue RHS = LogicOp->getOperand(Num: 1);
6150 if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6151 !LHS->hasOneUse() || !RHS->hasOneUse())
6152 return SDValue();
6153
6154 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6155 AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6156 LogicOp, SETCC0: LHS.getNode(), SETCC1: RHS.getNode());
6157
6158 SDValue LHS0 = LHS->getOperand(Num: 0);
6159 SDValue RHS0 = RHS->getOperand(Num: 0);
6160 SDValue LHS1 = LHS->getOperand(Num: 1);
6161 SDValue RHS1 = RHS->getOperand(Num: 1);
6162 // TODO: We don't actually need a splat here; for vectors we just need the
6163 // invariants to hold for each element.
6164 auto *LHS1C = isConstOrConstSplat(N: LHS1);
6165 auto *RHS1C = isConstOrConstSplat(N: RHS1);
6166 ISD::CondCode CCL = cast<CondCodeSDNode>(Val: LHS.getOperand(i: 2))->get();
6167 ISD::CondCode CCR = cast<CondCodeSDNode>(Val: RHS.getOperand(i: 2))->get();
6168 EVT VT = LogicOp->getValueType(ResNo: 0);
6169 EVT OpVT = LHS0.getValueType();
6170 SDLoc DL(LogicOp);
6171
6172 // Check if the operands of an and/or operation are comparisons and if they
6173 // compare against the same value. Replace the and/or-cmp-cmp sequence with a
6174 // min/max-cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6175 // sequence is replaced with a min-cmp sequence:
6176 // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6177 // and the and-cmp-cmp sequence is replaced with a max-cmp sequence:
6178 // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6179 // The optimization does not work for `==` or `!=`.
6180 // The two comparisons must either have the same predicate, or the
6181 // predicate of one comparison must be the swapped form of the other.
6182 bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(Op: ISD::FMAXNUM_IEEE, VT: OpVT) &&
6183 TLI.isOperationLegal(Op: ISD::FMINNUM_IEEE, VT: OpVT);
6184 bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(Op: ISD::FMAXNUM, VT: OpVT) &&
6185 TLI.isOperationLegalOrCustom(Op: ISD::FMINNUM, VT: OpVT);
6186 if (((OpVT.isInteger() && TLI.isOperationLegal(Op: ISD::UMAX, VT: OpVT) &&
6187 TLI.isOperationLegal(Op: ISD::SMAX, VT: OpVT) &&
6188 TLI.isOperationLegal(Op: ISD::UMIN, VT: OpVT) &&
6189 TLI.isOperationLegal(Op: ISD::SMIN, VT: OpVT)) ||
6190 (OpVT.isFloatingPoint() &&
6191 (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6192 !ISD::isIntEqualitySetCC(Code: CCL) && !ISD::isFPEqualitySetCC(Code: CCL) &&
6193 CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6194 CCL != ISD::SETTRUE &&
6195 (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(Operation: CCR))) {
6196
6197 SDValue CommonValue, Operand1, Operand2;
6198 ISD::CondCode CC = ISD::SETCC_INVALID;
6199 if (CCL == CCR) {
6200 if (LHS0 == RHS0) {
6201 CommonValue = LHS0;
6202 Operand1 = LHS1;
6203 Operand2 = RHS1;
6204 CC = ISD::getSetCCSwappedOperands(Operation: CCL);
6205 } else if (LHS1 == RHS1) {
6206 CommonValue = LHS1;
6207 Operand1 = LHS0;
6208 Operand2 = RHS0;
6209 CC = CCL;
6210 }
6211 } else {
6212 assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6213 if (LHS0 == RHS1) {
6214 CommonValue = LHS0;
6215 Operand1 = LHS1;
6216 Operand2 = RHS0;
6217 CC = CCR;
6218 } else if (RHS0 == LHS1) {
6219 CommonValue = LHS1;
6220 Operand1 = LHS0;
6221 Operand2 = RHS1;
6222 CC = CCL;
6223 }
6224 }
6225
6226 // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6227 // handle it using OR/AND.
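// For example, (setlt A, 0) | (setlt B, 0) is better folded to
// (setlt (or A, B), 0) there than to (setlt (smin A, B), 0) here.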
6228 if (CC == ISD::SETLT && isNullOrNullSplat(V: CommonValue))
6229 CC = ISD::SETCC_INVALID;
6230 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CommonValue))
6231 CC = ISD::SETCC_INVALID;
6232
6233 if (CC != ISD::SETCC_INVALID) {
6234 unsigned NewOpcode = ISD::DELETED_NODE;
6235 bool IsSigned = isSignedIntSetCC(Code: CC);
6236 if (OpVT.isInteger()) {
6237 bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6238 CC == ISD::SETLT || CC == ISD::SETULT);
6239 bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6240 if (IsLess == IsOr)
6241 NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6242 else
6243 NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6244 } else if (OpVT.isFloatingPoint())
6245 NewOpcode =
6246 getMinMaxOpcodeForFP(Operand1, Operand2, CC, OrAndOpcode: LogicOp->getOpcode(),
6247 DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6248
6249 if (NewOpcode != ISD::DELETED_NODE) {
6250 SDValue MinMaxValue =
6251 DAG.getNode(Opcode: NewOpcode, DL, VT: OpVT, N1: Operand1, N2: Operand2);
6252 return DAG.getSetCC(DL, VT, LHS: MinMaxValue, RHS: CommonValue, Cond: CC);
6253 }
6254 }
6255 }
6256
6257 if (TargetPreference == AndOrSETCCFoldKind::None)
6258 return SDValue();
6259
6260 if (CCL == CCR &&
6261 CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6262 LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6263 const APInt &APLhs = LHS1C->getAPIntValue();
6264 const APInt &APRhs = RHS1C->getAPIntValue();
6265
6266 // Do this if the target prefers ISD::ABS, or if an ISD::ABS of LHS0 already
6267 // exists (in which case this is just a compare).
6268 if (APLhs == (-APRhs) &&
6269 ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6270 DAG.doesNodeExist(Opcode: ISD::ABS, VTList: DAG.getVTList(VT: OpVT), Ops: {LHS0}))) {
6271 const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6272 // (icmp eq A, C) | (icmp eq A, -C)
6273 // -> (icmp eq Abs(A), C)
6274 // (icmp ne A, C) & (icmp ne A, -C)
6275 // -> (icmp ne Abs(A), C)
6276 SDValue AbsOp = DAG.getNode(Opcode: ISD::ABS, DL, VT: OpVT, Operand: LHS0);
6277 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AbsOp,
6278 N2: DAG.getConstant(Val: C, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6279 } else if (TargetPreference &
6280 (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6281
6282 // AndOrSETCCFoldKind::AddAnd:
6283 // A == C0 | A == C1
6284 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6285 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6286 // A != C0 & A != C1
6287 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6288 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6289
6290 // AndOrSETCCFoldKind::NotAnd:
6291 // A == C0 | A == C1
6292 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6293 // -> ~A & smin(C0, C1) == 0
6294 // A != C0 & A != C1
6295 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6296 // -> ~A & smin(C0, C1) != 0
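// A worked example of the AddAnd form: (A == 8) | (A == 12)
// --> ((A - 8) & ~4) == 0, since 12 - 8 == 4 is a power of 2.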
6297
6298 const APInt &MaxC = APIntOps::smax(A: APRhs, B: APLhs);
6299 const APInt &MinC = APIntOps::smin(A: APRhs, B: APLhs);
6300 APInt Dif = MaxC - MinC;
6301 if (!Dif.isZero() && Dif.isPowerOf2()) {
6302 if (MaxC.isAllOnes() &&
6303 (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6304 SDValue NotOp = DAG.getNOT(DL, Val: LHS0, VT: OpVT);
6305 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: NotOp,
6306 N2: DAG.getConstant(Val: MinC, DL, VT: OpVT));
6307 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6308 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6309 } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6310
6311 SDValue AddOp = DAG.getNode(Opcode: ISD::ADD, DL, VT: OpVT, N1: LHS0,
6312 N2: DAG.getConstant(Val: -MinC, DL, VT: OpVT));
6313 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: AddOp,
6314 N2: DAG.getConstant(Val: ~Dif, DL, VT: OpVT));
6315 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6316 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6317 }
6318 }
6319 }
6320 }
6321
6322 return SDValue();
6323}
6324
6325// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6326 // We canonicalize to the `select` form in the middle end, but the `and` form
6327 // gets better codegen on all tested targets (arm, x86, riscv).
6328static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6329 const SDLoc &DL, SelectionDAG &DAG) {
6330 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6331 if (!isNullConstant(V: F))
6332 return SDValue();
6333
6334 EVT CondVT = Cond.getValueType();
6335 if (TLI.getBooleanContents(Type: CondVT) !=
6336 TargetLoweringBase::ZeroOrOneBooleanContent)
6337 return SDValue();
6338
6339 if (T.getOpcode() != ISD::AND)
6340 return SDValue();
6341
6342 if (!isOneConstant(V: T.getOperand(i: 1)))
6343 return SDValue();
6344
6345 EVT OpVT = T.getValueType();
6346
6347 SDValue CondMask =
6348 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Op: Cond, SL: DL, VT: OpVT, OpVT: CondVT);
6349 return DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: CondMask, N2: T.getOperand(i: 0));
6350}
6351
6352/// This contains all DAGCombine rules which reduce two values combined by
6353/// an And operation to a single value. This makes them reusable in the context
6354/// of visitSELECT(). Rules involving constants are not included as
6355/// visitSELECT() already handles those cases.
6356SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6357 EVT VT = N1.getValueType();
6358 SDLoc DL(N);
6359
6360 // fold (and x, undef) -> 0
6361 if (N0.isUndef() || N1.isUndef())
6362 return DAG.getConstant(Val: 0, DL, VT);
6363
6364 if (SDValue V = foldLogicOfSetCCs(IsAnd: true, N0, N1, DL))
6365 return V;
6366
6367 // Canonicalize:
6368 // and(x, add) -> and(add, x)
6369 if (N1.getOpcode() == ISD::ADD)
6370 std::swap(a&: N0, b&: N1);
6371
6372 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6373 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6374 VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6375 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
6376 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1))) {
6377 // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
6378 // immediate for an add, but becomes legal once its top c2 bits are set,
6379 // transform the add so the immediate doesn't need to be materialized
6380 // in a register.
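// For example, with a 64-bit type and c2 == 48, only the low 16 bits of
// (lshr y, c2) can be nonzero, so the top 48 bits of c1 are don't-care bits
// and may be set if that produces a cheaper add immediate.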
6381 APInt ADDC = ADDI->getAPIntValue();
6382 APInt SRLC = SRLI->getAPIntValue();
6383 if (ADDC.getSignificantBits() <= 64 && SRLC.ult(RHS: VT.getSizeInBits()) &&
6384 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6385 APInt Mask = APInt::getHighBitsSet(numBits: VT.getSizeInBits(),
6386 hiBitsSet: SRLC.getZExtValue());
6387 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 1), Mask)) {
6388 ADDC |= Mask;
6389 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6390 SDLoc DL0(N0);
6391 SDValue NewAdd =
6392 DAG.getNode(Opcode: ISD::ADD, DL: DL0, VT,
6393 N1: N0.getOperand(i: 0), N2: DAG.getConstant(Val: ADDC, DL, VT));
6394 CombineTo(N: N0.getNode(), Res: NewAdd);
6395 // Return N so it doesn't get rechecked!
6396 return SDValue(N, 0);
6397 }
6398 }
6399 }
6400 }
6401 }
6402 }
6403
6404 return SDValue();
6405}
6406
6407bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6408 EVT LoadResultTy, EVT &ExtVT) {
6409 if (!AndC->getAPIntValue().isMask())
6410 return false;
6411
6412 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6413
6414 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6415 EVT LoadedVT = LoadN->getMemoryVT();
6416
6417 if (ExtVT == LoadedVT &&
6418 (!LegalOperations ||
6419 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))) {
6420 // ZEXTLOAD will match without needing to change the size of the value being
6421 // loaded.
6422 return true;
6423 }
6424
6425 // Do not change the width of volatile or atomic loads.
6426 if (!LoadN->isSimple())
6427 return false;
6428
6429 // Do not generate loads of non-round integer types since these can
6430 // be expensive (and would be wrong if the type is not byte sized).
6431 if (!LoadedVT.bitsGT(VT: ExtVT) || !ExtVT.isRound())
6432 return false;
6433
6434 if (LegalOperations &&
6435 !TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))
6436 return false;
6437
6438 if (!TLI.shouldReduceLoadWidth(Load: LoadN, ExtTy: ISD::ZEXTLOAD, NewVT: ExtVT))
6439 return false;
6440
6441 return true;
6442}
6443
6444bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6445 ISD::LoadExtType ExtType, EVT &MemVT,
6446 unsigned ShAmt) {
6447 if (!LDST)
6448 return false;
6449 // Only allow byte offsets.
6450 if (ShAmt % 8)
6451 return false;
6452
6453 // Do not generate loads of non-round integer types since these can
6454 // be expensive (and would be wrong if the type is not byte sized).
6455 if (!MemVT.isRound())
6456 return false;
6457
6458 // Don't change the width of volatile or atomic loads.
6459 if (!LDST->isSimple())
6460 return false;
6461
6462 EVT LdStMemVT = LDST->getMemoryVT();
6463
6464 // Bail out when changing the scalable property, since we can't be sure that
6465 // we're actually narrowing here.
6466 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6467 return false;
6468
6469 // Verify that we are actually reducing a load width here.
6470 if (LdStMemVT.bitsLT(VT: MemVT))
6471 return false;
6472
6473 // Ensure that this isn't going to produce an unsupported memory access.
6474 if (ShAmt) {
6475 assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
6476 const unsigned ByteShAmt = ShAmt / 8;
6477 const Align LDSTAlign = LDST->getAlign();
6478 const Align NarrowAlign = commonAlignment(A: LDSTAlign, Offset: ByteShAmt);
6479 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
6480 AddrSpace: LDST->getAddressSpace(), Alignment: NarrowAlign,
6481 Flags: LDST->getMemOperand()->getFlags()))
6482 return false;
6483 }
6484
6485 // It's not possible to generate a constant of extended or untyped type.
6486 EVT PtrType = LDST->getBasePtr().getValueType();
6487 if (PtrType == MVT::Untyped || PtrType.isExtended())
6488 return false;
6489
6490 if (isa<LoadSDNode>(Val: LDST)) {
6491 LoadSDNode *Load = cast<LoadSDNode>(Val: LDST);
6492 // Don't transform a load with multiple uses; this would require adding a
6493 // new load.
6494 if (!SDValue(Load, 0).hasOneUse())
6495 return false;
6496
6497 if (LegalOperations &&
6498 !TLI.isLoadExtLegal(ExtType, ValVT: Load->getValueType(ResNo: 0), MemVT))
6499 return false;
6500
6501 // For the transform to be legal, the load must produce only two values
6502 // (the value loaded and the chain). Don't transform a pre-increment
6503 // load, for example, which produces an extra value. Otherwise the
6504 // transformation is not equivalent, and the downstream logic to replace
6505 // uses gets things wrong.
6506 if (Load->getNumValues() > 2)
6507 return false;
6508
6509 // If the load that we're shrinking is an extload and we're not just
6510 // discarding the extension, we can't simply shrink the load. Bail.
6511 // TODO: It would be possible to merge the extensions in some cases.
6512 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6513 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6514 return false;
6515
6516 if (!TLI.shouldReduceLoadWidth(Load, ExtTy: ExtType, NewVT: MemVT))
6517 return false;
6518 } else {
6519 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6520 StoreSDNode *Store = cast<StoreSDNode>(Val: LDST);
6521 // Can't write outside the original store
6522 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6523 return false;
6524
6525 if (LegalOperations &&
6526 !TLI.isTruncStoreLegal(ValVT: Store->getValue().getValueType(), MemVT))
6527 return false;
6528 }
6529 return true;
6530}
6531
6532bool DAGCombiner::SearchForAndLoads(SDNode *N,
6533 SmallVectorImpl<LoadSDNode*> &Loads,
6534 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6535 ConstantSDNode *Mask,
6536 SDNode *&NodeToMask) {
6537 // Recursively search for the operands, looking for loads which can be
6538 // narrowed.
6539 for (SDValue Op : N->op_values()) {
6540 if (Op.getValueType().isVector())
6541 return false;
6542
6543 // Some constants may need fixing up later if they are too large.
6544 if (auto *C = dyn_cast<ConstantSDNode>(Val&: Op)) {
6545 if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
6546 (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
6547 NodesWithConsts.insert(Ptr: N);
6548 continue;
6549 }
6550
6551 if (!Op.hasOneUse())
6552 return false;
6553
6554 switch(Op.getOpcode()) {
6555 case ISD::LOAD: {
6556 auto *Load = cast<LoadSDNode>(Val&: Op);
6557 EVT ExtVT;
6558 if (isAndLoadExtLoad(AndC: Mask, LoadN: Load, LoadResultTy: Load->getValueType(ResNo: 0), ExtVT) &&
6559 isLegalNarrowLdSt(LDST: Load, ExtType: ISD::ZEXTLOAD, MemVT&: ExtVT)) {
6560
6561 // ZEXTLOAD is already small enough.
6562 if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6563 ExtVT.bitsGE(VT: Load->getMemoryVT()))
6564 continue;
6565
6566 // Use LE to convert equal sized loads to zext.
6567 if (ExtVT.bitsLE(VT: Load->getMemoryVT()))
6568 Loads.push_back(Elt: Load);
6569
6570 continue;
6571 }
6572 return false;
6573 }
6574 case ISD::ZERO_EXTEND:
6575 case ISD::AssertZext: {
6576 unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6577 EVT ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6578 EVT VT = Op.getOpcode() == ISD::AssertZext ?
6579 cast<VTSDNode>(Val: Op.getOperand(i: 1))->getVT() :
6580 Op.getOperand(i: 0).getValueType();
6581
6582 // We can accept extending nodes if the mask is wider than or equal in
6583 // width to the original type.
6584 if (ExtVT.bitsGE(VT))
6585 continue;
6586 break;
6587 }
6588 case ISD::OR:
6589 case ISD::XOR:
6590 case ISD::AND:
6591 if (!SearchForAndLoads(N: Op.getNode(), Loads, NodesWithConsts, Mask,
6592 NodeToMask))
6593 return false;
6594 continue;
6595 }
6596
6597 // Allow one node which will be masked along with any loads found.
6598 if (NodeToMask)
6599 return false;
6600
6601 // Also ensure that the node to be masked only produces one data result.
6602 NodeToMask = Op.getNode();
6603 if (NodeToMask->getNumValues() > 1) {
6604 bool HasValue = false;
6605 for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
6606 MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
6607 if (VT != MVT::Glue && VT != MVT::Other) {
6608 if (HasValue) {
6609 NodeToMask = nullptr;
6610 return false;
6611 }
6612 HasValue = true;
6613 }
6614 }
6615 assert(HasValue && "Node to be masked has no data result?");
6616 }
6617 }
6618 return true;
6619}
6620
6621bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
6622 auto *Mask = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
6623 if (!Mask)
6624 return false;
6625
6626 if (!Mask->getAPIntValue().isMask())
6627 return false;
6628
6629 // No need to do anything if the and directly uses a load.
6630 if (isa<LoadSDNode>(Val: N->getOperand(Num: 0)))
6631 return false;
6632
6633 SmallVector<LoadSDNode*, 8> Loads;
6634 SmallPtrSet<SDNode*, 2> NodesWithConsts;
6635 SDNode *FixupNode = nullptr;
6636 if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, NodeToMask&: FixupNode)) {
6637 if (Loads.empty())
6638 return false;
6639
6640 LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
6641 SDValue MaskOp = N->getOperand(Num: 1);
6642
6643 // If it exists, fixup the single node we allow in the tree that needs
6644 // masking.
6645 if (FixupNode) {
6646 LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
6647 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(FixupNode),
6648 VT: FixupNode->getValueType(ResNo: 0),
6649 N1: SDValue(FixupNode, 0), N2: MaskOp);
6650 DAG.ReplaceAllUsesOfValueWith(From: SDValue(FixupNode, 0), To: And);
6651 if (And.getOpcode() == ISD::AND)
6652 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(FixupNode, 0), Op2: MaskOp);
6653 }
6654
6655 // Narrow any constants that need it.
6656 for (auto *LogicN : NodesWithConsts) {
6657 SDValue Op0 = LogicN->getOperand(Num: 0);
6658 SDValue Op1 = LogicN->getOperand(Num: 1);
6659
6660 if (isa<ConstantSDNode>(Val: Op0))
6661 Op0 =
6662 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op0), VT: Op0.getValueType(), N1: Op0, N2: MaskOp);
6663
6664 if (isa<ConstantSDNode>(Val: Op1))
6665 Op1 =
6666 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op1), VT: Op1.getValueType(), N1: Op1, N2: MaskOp);
6667
6668 if (isa<ConstantSDNode>(Val: Op0) && !isa<ConstantSDNode>(Val: Op1))
6669 std::swap(a&: Op0, b&: Op1);
6670
6671 DAG.UpdateNodeOperands(N: LogicN, Op1: Op0, Op2: Op1);
6672 }
6673
6674 // Create narrow loads.
6675 for (auto *Load : Loads) {
6676 LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
6677 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Load), VT: Load->getValueType(ResNo: 0),
6678 N1: SDValue(Load, 0), N2: MaskOp);
6679 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: And);
6680 if (And.getOpcode() == ISD::AND)
6681 And = SDValue(
6682 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(Load, 0), Op2: MaskOp), 0);
6683 SDValue NewLoad = reduceLoadWidth(N: And.getNode());
6684 assert(NewLoad &&
6685 "Shouldn't be masking the load if it can't be narrowed");
6686 CombineTo(N: Load, Res0: NewLoad, Res1: NewLoad.getValue(R: 1));
6687 }
6688 DAG.ReplaceAllUsesWith(From: N, To: N->getOperand(Num: 0).getNode());
6689 return true;
6690 }
6691 return false;
6692}
6693
6694// Unfold
6695// x & (-1 'logical shift' y)
6696// To
6697// (x 'opposite logical shift' y) 'logical shift' y
6698// if it is better for performance.
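// For example, with logical shifts: x & (-1 << y) --> (x >> y) << y, and
// x & (-1 >> y) --> (x << y) >> y.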
6699SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6700 assert(N->getOpcode() == ISD::AND);
6701
6702 SDValue N0 = N->getOperand(Num: 0);
6703 SDValue N1 = N->getOperand(Num: 1);
6704
6705 // Do we actually prefer shifts over mask?
6706 if (!TLI.shouldFoldMaskToVariableShiftPair(X: N0))
6707 return SDValue();
6708
6709 // Try to match (-1 '[outer] logical shift' y)
6710 unsigned OuterShift;
6711 unsigned InnerShift; // The opposite direction to the OuterShift.
6712 SDValue Y; // Shift amount.
6713 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6714 if (!M.hasOneUse())
6715 return false;
6716 OuterShift = M->getOpcode();
6717 if (OuterShift == ISD::SHL)
6718 InnerShift = ISD::SRL;
6719 else if (OuterShift == ISD::SRL)
6720 InnerShift = ISD::SHL;
6721 else
6722 return false;
6723 if (!isAllOnesConstant(V: M->getOperand(Num: 0)))
6724 return false;
6725 Y = M->getOperand(Num: 1);
6726 return true;
6727 };
6728
6729 SDValue X;
6730 if (matchMask(N1))
6731 X = N0;
6732 else if (matchMask(N0))
6733 X = N1;
6734 else
6735 return SDValue();
6736
6737 SDLoc DL(N);
6738 EVT VT = N->getValueType(ResNo: 0);
6739
6740 // tmp = x 'opposite logical shift' y
6741 SDValue T0 = DAG.getNode(Opcode: InnerShift, DL, VT, N1: X, N2: Y);
6742 // ret = tmp 'logical shift' y
6743 SDValue T1 = DAG.getNode(Opcode: OuterShift, DL, VT, N1: T0, N2: Y);
6744
6745 return T1;
6746}
6747
6748/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6749/// For a target with a bit test, this is expected to become test + set and save
6750/// at least 1 instruction.
6751static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6752 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6753
6754 // Look through an optional extension.
6755 SDValue And0 = And->getOperand(Num: 0), And1 = And->getOperand(Num: 1);
6756 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6757 And0 = And0.getOperand(i: 0);
6758 if (!isOneConstant(V: And1) || !And0.hasOneUse())
6759 return SDValue();
6760
6761 SDValue Src = And0;
6762
6763 // Attempt to find a 'not' op.
6764 // TODO: Should we favor test+set even without the 'not' op?
6765 bool FoundNot = false;
6766 if (isBitwiseNot(V: Src)) {
6767 FoundNot = true;
6768 Src = Src.getOperand(i: 0);
6769
6770 // Look through an optional truncation. The source operand may not be the
6771 // same type as the original 'and', but that is ok because we are masking
6772 // off everything but the low bit.
6773 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6774 Src = Src.getOperand(i: 0);
6775 }
6776
6777 // Match a shift-right by constant.
6778 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6779 return SDValue();
6780
6781 // This is probably not worthwhile without a supported type.
6782 EVT SrcVT = Src.getValueType();
6783 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6784 if (!TLI.isTypeLegal(VT: SrcVT))
6785 return SDValue();
6786
6787 // We might have looked through casts that make this transform invalid.
6788 unsigned BitWidth = SrcVT.getScalarSizeInBits();
6789 SDValue ShiftAmt = Src.getOperand(i: 1);
6790 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(Val&: ShiftAmt);
6791 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(RHS: BitWidth))
6792 return SDValue();
6793
6794 // Set source to shift source.
6795 Src = Src.getOperand(i: 0);
6796
6797 // Try again to find a 'not' op.
6798 // TODO: Should we favor test+set even with two 'not' ops?
6799 if (!FoundNot) {
6800 if (!isBitwiseNot(V: Src))
6801 return SDValue();
6802 Src = Src.getOperand(i: 0);
6803 }
6804
6805 if (!TLI.hasBitTest(X: Src, Y: ShiftAmt))
6806 return SDValue();
6807
6808 // Turn this into a bit-test pattern using mask op + setcc:
6809 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6810 // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
6811 SDLoc DL(And);
6812 SDValue X = DAG.getZExtOrTrunc(Op: Src, DL, VT: SrcVT);
6813 EVT CCVT =
6814 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT: SrcVT);
6815 SDValue Mask = DAG.getConstant(
6816 Val: APInt::getOneBitSet(numBits: BitWidth, BitNo: ShiftAmtC->getZExtValue()), DL, VT: SrcVT);
6817 SDValue NewAnd = DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: X, N2: Mask);
6818 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: SrcVT);
6819 SDValue Setcc = DAG.getSetCC(DL, VT: CCVT, LHS: NewAnd, RHS: Zero, Cond: ISD::SETEQ);
6820 return DAG.getZExtOrTrunc(Op: Setcc, DL, VT: And->getValueType(ResNo: 0));
6821}
6822
6823/// For targets that support usubsat, match a bit-hack form of that operation
6824/// that ends in 'and' and convert it.
6825static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) {
6826 SDValue N0 = N->getOperand(Num: 0);
6827 SDValue N1 = N->getOperand(Num: 1);
6828 EVT VT = N1.getValueType();
6829
6830 // Canonicalize SRA as operand 1.
6831 if (N0.getOpcode() == ISD::SRA)
6832 std::swap(a&: N0, b&: N1);
6833
6834 // xor/add with the sign mask (smallest signed value) are logically equivalent.
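// For example, for i8: X + 0x80 and X ^ 0x80 both simply flip the sign bit,
// because the carry out of the top bit is discarded.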
6835 if (N0.getOpcode() != ISD::XOR && N0.getOpcode() != ISD::ADD)
6836 return SDValue();
6837
6838 if (N1.getOpcode() != ISD::SRA || !N0.hasOneUse() || !N1.hasOneUse() ||
6839 N0.getOperand(i: 0) != N1.getOperand(i: 0))
6840 return SDValue();
6841
6842 unsigned BitWidth = VT.getScalarSizeInBits();
6843 ConstantSDNode *XorC = isConstOrConstSplat(N: N0.getOperand(i: 1), AllowUndefs: true);
6844 ConstantSDNode *SraC = isConstOrConstSplat(N: N1.getOperand(i: 1), AllowUndefs: true);
6845 if (!XorC || !XorC->getAPIntValue().isSignMask() ||
6846 !SraC || SraC->getAPIntValue() != BitWidth - 1)
6847 return SDValue();
6848
6849 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6850 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
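// To see why (for i8): if X < 128 the arithmetic shift yields 0, so the
// result is 0; if X >= 128 it yields all-ones, so the result is X - 128.
// Both match usubsat X, 128.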
6851 SDLoc DL(N);
6852 SDValue SignMask = DAG.getConstant(Val: XorC->getAPIntValue(), DL, VT);
6853 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: N0.getOperand(i: 0), N2: SignMask);
6854}
6855
6856/// Given a bitwise logic operation N with a matching bitwise logic operand,
6857/// fold a pattern where 2 of the source operands are identically shifted
6858/// values. For example:
6859/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6860static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6861 SelectionDAG &DAG) {
6862 unsigned LogicOpcode = N->getOpcode();
6863 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6864 "Expected bitwise logic operation");
6865
6866 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6867 return SDValue();
6868
6869 // Match another bitwise logic op and a shift.
6870 unsigned ShiftOpcode = ShiftOp.getOpcode();
6871 if (LogicOp.getOpcode() != LogicOpcode ||
6872 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6873 ShiftOpcode == ISD::SRA))
6874 return SDValue();
6875
6876 // Match another shift op inside the first logic operand. Handle both commuted
6877 // possibilities.
6878 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6879 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6880 SDValue X1 = ShiftOp.getOperand(i: 0);
6881 SDValue Y = ShiftOp.getOperand(i: 1);
6882 SDValue X0, Z;
6883 if (LogicOp.getOperand(i: 0).getOpcode() == ShiftOpcode &&
6884 LogicOp.getOperand(i: 0).getOperand(i: 1) == Y) {
6885 X0 = LogicOp.getOperand(i: 0).getOperand(i: 0);
6886 Z = LogicOp.getOperand(i: 1);
6887 } else if (LogicOp.getOperand(i: 1).getOpcode() == ShiftOpcode &&
6888 LogicOp.getOperand(i: 1).getOperand(i: 1) == Y) {
6889 X0 = LogicOp.getOperand(i: 1).getOperand(i: 0);
6890 Z = LogicOp.getOperand(i: 0);
6891 } else {
6892 return SDValue();
6893 }
6894
6895 EVT VT = N->getValueType(ResNo: 0);
6896 SDLoc DL(N);
6897 SDValue LogicX = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X0, N2: X1);
6898 SDValue NewShift = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: LogicX, N2: Y);
6899 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift, N2: Z);
6900}
6901
6902/// Given a tree of logic operations with shape like
6903/// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6904/// try to match and fold shift operations with the same shift amount.
6905/// For example:
6906/// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6907/// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6908static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6909 SDValue RightHand, SelectionDAG &DAG) {
6910 unsigned LogicOpcode = N->getOpcode();
6911 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6912 "Expected bitwise logic operation");
6913 if (LeftHand.getOpcode() != LogicOpcode ||
6914 RightHand.getOpcode() != LogicOpcode)
6915 return SDValue();
6916 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6917 return SDValue();
6918
6919 // Try to match one of following patterns:
6920 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6921 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6922 // Note that foldLogicOfShifts will handle commuted versions of the left hand
6923 // itself.
6924 SDValue CombinedShifts, W;
6925 SDValue R0 = RightHand.getOperand(i: 0);
6926 SDValue R1 = RightHand.getOperand(i: 1);
6927 if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R0, DAG)))
6928 W = R1;
6929 else if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R1, DAG)))
6930 W = R0;
6931 else
6932 return SDValue();
6933
6934 EVT VT = N->getValueType(ResNo: 0);
6935 SDLoc DL(N);
6936 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: CombinedShifts, N2: W);
6937}
6938
6939SDValue DAGCombiner::visitAND(SDNode *N) {
6940 SDValue N0 = N->getOperand(Num: 0);
6941 SDValue N1 = N->getOperand(Num: 1);
6942 EVT VT = N1.getValueType();
6943
6944 // x & x --> x
6945 if (N0 == N1)
6946 return N0;
6947
6948 // fold (and c1, c2) -> c1&c2
6949 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL: SDLoc(N), VT, Ops: {N0, N1}))
6950 return C;
6951
6952 // canonicalize constant to RHS
6953 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
6954 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
6955 return DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N), VT, N1, N2: N0);
6956
6957 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
6958 return DAG.getConstant(Val: APInt::getZero(numBits: VT.getScalarSizeInBits()), DL: SDLoc(N),
6959 VT);
6960
6961 // fold vector ops
6962 if (VT.isVector()) {
6963 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
6964 return FoldedVOp;
6965
6966 // fold (and x, 0) -> 0, vector edition
6967 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
6968 // do not return N1, because an undef node may exist in N1
6969 return DAG.getConstant(Val: APInt::getZero(numBits: N1.getScalarValueSizeInBits()),
6970 DL: SDLoc(N), VT: N1.getValueType());
6971
6972 // fold (and x, -1) -> x, vector edition
6973 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
6974 return N0;
6975
6976 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6977 auto *MLoad = dyn_cast<MaskedLoadSDNode>(Val&: N0);
6978 ConstantSDNode *Splat = isConstOrConstSplat(N: N1, AllowUndefs: true, AllowTruncation: true);
6979 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6980 N1.hasOneUse()) {
6981 EVT LoadVT = MLoad->getMemoryVT();
6982 EVT ExtVT = VT;
6983 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ExtVT, MemVT: LoadVT)) {
6984 // For this AND to be a zero extension of the masked load, the elements
6985 // of the BuildVec must mask the bottom bits of the extended element
6986 // type.
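// e.g. and (masked_load that any-extends from <4 x i8>), splat(0xFF)
// --> masked_load that zero-extends from <4 x i8>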
6987 uint64_t ElementSize =
6988 LoadVT.getVectorElementType().getScalarSizeInBits();
6989 if (Splat->getAPIntValue().isMask(numBits: ElementSize)) {
6990 auto NewLoad = DAG.getMaskedLoad(
6991 VT: ExtVT, dl: SDLoc(N), Chain: MLoad->getChain(), Base: MLoad->getBasePtr(),
6992 Offset: MLoad->getOffset(), Mask: MLoad->getMask(), Src0: MLoad->getPassThru(),
6993 MemVT: LoadVT, MMO: MLoad->getMemOperand(), AM: MLoad->getAddressingMode(),
6994 ISD::ZEXTLOAD, IsExpanding: MLoad->isExpandingLoad());
6995 bool LoadHasOtherUsers = !N0.hasOneUse();
6996 CombineTo(N, Res: NewLoad);
6997 if (LoadHasOtherUsers)
6998 CombineTo(N: MLoad, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
6999 return SDValue(N, 0);
7000 }
7001 }
7002 }
7003 }
7004
7005 // fold (and x, -1) -> x
7006 if (isAllOnesConstant(V: N1))
7007 return N0;
7008
7009 // if (and x, c) is known to be zero, return 0
7010 unsigned BitWidth = VT.getScalarSizeInBits();
7011 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
7012 if (N1C && DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: BitWidth)))
7013 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
7014
7015 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7016 return R;
7017
7018 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7019 return NewSel;
7020
7021 // reassociate and
7022 if (SDValue RAND = reassociateOps(Opc: ISD::AND, DL: SDLoc(N), N0, N1, Flags: N->getFlags()))
7023 return RAND;
7024
7025 // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7026 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_AND, Opc: ISD::AND, DL: SDLoc(N),
7027 VT, N0, N1))
7028 return SD;
7029
7030 // fold (and (or x, C), D) -> D if (C & D) == D
7031 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7032 return RHS->getAPIntValue().isSubsetOf(RHS: LHS->getAPIntValue());
7033 };
7034 if (N0.getOpcode() == ISD::OR &&
7035 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchSubset))
7036 return N1;
7037
7038 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7039 SDValue N0Op0 = N0.getOperand(i: 0);
7040 EVT SrcVT = N0Op0.getValueType();
7041 unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7042 APInt Mask = ~N1C->getAPIntValue();
7043 Mask = Mask.trunc(width: SrcBitWidth);
7044
7045 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
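// e.g. and (any_extend i8 V to i32), 0xFF --> zero_extend i8 V to i32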
7046 if (DAG.MaskedValueIsZero(Op: N0Op0, Mask))
7047 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N), VT, Operand: N0Op0);
7048
7049 // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7050 if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7051 TLI.isTruncateFree(FromVT: VT, ToVT: SrcVT) && TLI.isZExtFree(FromTy: SrcVT, ToTy: VT) &&
7052 TLI.isTypeDesirableForOp(ISD::AND, VT: SrcVT) &&
7053 TLI.isNarrowingProfitable(SrcVT: VT, DestVT: SrcVT)) {
7054 SDLoc DL(N);
7055 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT,
7056 Operand: DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: N0Op0,
7057 N2: DAG.getZExtOrTrunc(Op: N1, DL, VT: SrcVT)));
7058 }
7059 }
7060
7061 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7062 if (ISD::isExtOpcode(Opcode: N0.getOpcode())) {
7063 unsigned ExtOpc = N0.getOpcode();
7064 SDValue N0Op0 = N0.getOperand(i: 0);
7065 if (N0Op0.getOpcode() == ISD::AND &&
7066 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(Val: N0Op0, VT2: VT)) &&
7067 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
7068 DAG.isConstantIntBuildVectorOrConstantInt(N: N0Op0.getOperand(i: 1)) &&
7069 N0->hasOneUse() && N0Op0->hasOneUse()) {
7070 SDLoc DL(N);
7071 SDValue NewMask =
7072 DAG.getNode(Opcode: ISD::AND, DL, VT, N1,
7073 N2: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 1)));
7074 return DAG.getNode(Opcode: ISD::AND, DL, VT,
7075 N1: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 0)),
7076 N2: NewMask);
7077 }
7078 }
7079
7080 // Similarly, fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7081 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7082 // already be zero by virtue of the width of the base type of the load.
7083 //
7084 // The 'X' node here can either be nothing or an extract_vector_elt to catch
7085 // more cases.
7086 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7087 N0.getValueSizeInBits() == N0.getOperand(i: 0).getScalarValueSizeInBits() &&
7088 N0.getOperand(i: 0).getOpcode() == ISD::LOAD &&
7089 N0.getOperand(i: 0).getResNo() == 0) ||
7090 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7091 LoadSDNode *Load = cast<LoadSDNode>( Val: (N0.getOpcode() == ISD::LOAD) ?
7092 N0 : N0.getOperand(i: 0) );
7093
7094 // Get the constant (if applicable) the zero'th operand is being ANDed with.
7095 // This can be a pure constant or a vector splat, in which case we treat the
7096 // vector as a scalar and use the splat value.
7097 APInt Constant = APInt::getZero(numBits: 1);
7098 if (const ConstantSDNode *C = isConstOrConstSplat(
7099 N: N1, /*AllowUndef=*/AllowUndefs: false, /*AllowTruncation=*/true)) {
7100 Constant = C->getAPIntValue();
7101 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(Val&: N1)) {
7102 unsigned EltBitWidth = Vector->getValueType(ResNo: 0).getScalarSizeInBits();
7103 APInt SplatValue, SplatUndef;
7104 unsigned SplatBitSize;
7105 bool HasAnyUndefs;
7106 // Endianness should not matter here. Code below makes sure that we only
7107 // use the result if the SplatBitSize is a multiple of the vector element
7108 // size. And after that we AND all element sized parts of the splat
7109 // together. So the end result should be the same regardless of in which
7110 // order we do those operations.
7111 const bool IsBigEndian = false;
7112 bool IsSplat =
7113 Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7114 HasAnyUndefs, MinSplatBits: EltBitWidth, isBigEndian: IsBigEndian);
7115
7116 // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7117 // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
7118 if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7119 // Undef bits can contribute to a possible optimisation if set, so
7120 // set them.
7121 SplatValue |= SplatUndef;
7122
7123 // The splat value may be something like "0x00FFFFFF", which means 0 for
7124 // the first vector value and FF for the rest, repeating. We need a mask
7125 // that will apply equally to all members of the vector, so AND all the
7126 // lanes of the constant together.
7127 Constant = APInt::getAllOnes(numBits: EltBitWidth);
7128 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7129 Constant &= SplatValue.extractBits(numBits: EltBitWidth, bitPosition: i * EltBitWidth);
7130 }
7131 }
7132
7133 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7134 // actually legal and isn't going to get expanded, else this is a false
7135 // optimisation.
7136 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD,
7137 ValVT: Load->getValueType(ResNo: 0),
7138 MemVT: Load->getMemoryVT());
7139
7140 // Resize the constant to the same size as the original memory access before
7141 // extension. If it is still the AllOnesValue then this AND is completely
7142 // unneeded.
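    // For example, with (and (extload i8 x), 0xFF) the constant truncates to
    // the memory type's all-ones value, so once the load is rewritten as a
    // zextload below, the AND contributes nothing.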
7143 Constant = Constant.zextOrTrunc(width: Load->getMemoryVT().getScalarSizeInBits());
7144
7145 bool B;
7146 switch (Load->getExtensionType()) {
7147 default: B = false; break;
7148 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7149 case ISD::ZEXTLOAD:
7150 case ISD::NON_EXTLOAD: B = true; break;
7151 }
7152
7153 if (B && Constant.isAllOnes()) {
7154 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7155 // preserve semantics once we get rid of the AND.
7156 SDValue NewLoad(Load, 0);
7157
7158 // Fold the AND away. NewLoad may get replaced immediately.
7159 CombineTo(N, Res: (N0.getNode() == Load) ? NewLoad : N0);
7160
7161 if (Load->getExtensionType() == ISD::EXTLOAD) {
7162 NewLoad = DAG.getLoad(AM: Load->getAddressingMode(), ExtType: ISD::ZEXTLOAD,
7163 VT: Load->getValueType(ResNo: 0), dl: SDLoc(Load),
7164 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
7165 Offset: Load->getOffset(), MemVT: Load->getMemoryVT(),
7166 MMO: Load->getMemOperand());
7167 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7168 if (Load->getNumValues() == 3) {
7169 // PRE/POST_INC loads have 3 values.
7170 SDValue To[] = { NewLoad.getValue(R: 0), NewLoad.getValue(R: 1),
7171 NewLoad.getValue(R: 2) };
7172 CombineTo(N: Load, To, NumTo: 3, AddTo: true);
7173 } else {
7174 CombineTo(N: Load, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
7175 }
7176 }
7177
7178 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7179 }
7180 }
7181
7182 // Try to convert a constant mask AND into a shuffle clear mask.
7183 if (VT.isVector())
7184 if (SDValue Shuffle = XformToShuffleWithZero(N))
7185 return Shuffle;
7186
7187 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7188 return Combined;
7189
7190 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7191 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
7192 SDValue Ext = N0.getOperand(i: 0);
7193 EVT ExtVT = Ext->getValueType(ResNo: 0);
7194 SDValue Extendee = Ext->getOperand(Num: 0);
7195
7196 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7197 if (N1C->getAPIntValue().isMask(numBits: ScalarWidth) &&
7198 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: ExtVT))) {
7199 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7200 // => (extract_subvector (iN_zeroext v))
7201 SDValue ZeroExtExtendee =
7202 DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N), VT: ExtVT, Operand: Extendee);
7203
7204 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT, N1: ZeroExtExtendee,
7205 N2: N0.getOperand(i: 1));
7206 }
7207 }
7208
7209 // fold (and (masked_gather x)) -> (zext_masked_gather x)
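  // The constant operand must be a splat of the all-ones mask for the
  // gather's memory element type; zero-extending the gathered elements then
  // makes the explicit AND redundant.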
7210 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
7211 EVT MemVT = GN0->getMemoryVT();
7212 EVT ScalarVT = MemVT.getScalarType();
7213
7214 if (SDValue(GN0, 0).hasOneUse() &&
7215 isConstantSplatVectorMaskForType(N: N1.getNode(), ScalarTy: ScalarVT) &&
7216 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(GN0, 0))) {
7217 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
7218 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
7219
7220 SDValue ZExtLoad = DAG.getMaskedGather(
7221 VTs: DAG.getVTList(VT, MVT::Other), MemVT, dl: SDLoc(N), Ops,
7222 MMO: GN0->getMemOperand(), IndexType: GN0->getIndexType(), ExtTy: ISD::ZEXTLOAD);
7223
7224 CombineTo(N, Res: ZExtLoad);
7225 AddToWorklist(N: ZExtLoad.getNode());
7226 // Avoid recheck of N.
7227 return SDValue(N, 0);
7228 }
7229 }
7230
7231 // fold (and (load x), 255) -> (zextload x, i8)
7232 // fold (and (extload x, i16), 255) -> (zextload x, i8)
7233 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7234 if (SDValue Res = reduceLoadWidth(N))
7235 return Res;
7236
7237 if (LegalTypes) {
7238 // Attempt to propagate the AND back up to the leaves which, if they're
7239 // loads, can be combined to narrow loads and the AND node can be removed.
7240 // Perform after legalization so that extend nodes will already be
7241 // combined into the loads.
7242 if (BackwardsPropagateMask(N))
7243 return SDValue(N, 0);
7244 }
7245
7246 if (SDValue Combined = visitANDLike(N0, N1, N))
7247 return Combined;
7248
7249 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
7250 if (N0.getOpcode() == N1.getOpcode())
7251 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7252 return V;
7253
7254 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7255 return R;
7256 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
7257 return R;
7258
7259 // Masking the negated extension of a boolean is just the zero-extended
7260 // boolean:
7261 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7262 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
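  // (If X is true, the negated extension is all-ones in both cases and
  // AND-ing with 1 yields 1; if X is false, both sides are 0.)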
7263 //
7264 // Note: the SimplifyDemandedBits fold below can make an information-losing
7265 // transform, and then we have no way to find this better fold.
7266 if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
7267 if (isNullOrNullSplat(V: N0.getOperand(i: 0))) {
7268 SDValue SubRHS = N0.getOperand(i: 1);
7269 if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
7270 SubRHS.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7271 return SubRHS;
7272 if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
7273 SubRHS.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7274 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N), VT, Operand: SubRHS.getOperand(i: 0));
7275 }
7276 }
7277
7278 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7279 // fold (and (sra)) -> (and (srl)) when possible.
7280 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7281 return SDValue(N, 0);
7282
7283 // fold (zext_inreg (extload x)) -> (zextload x)
7284 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
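  // For example, (and (sextload i16 x), 0xFFFF) becomes (zextload i16 x) when
  // the load has no other uses: the mask zeroes exactly the bits beyond the
  // 16-bit memory type.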
7285 if (ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
7286 (ISD::isEXTLoad(N: N0.getNode()) ||
7287 (ISD::isSEXTLoad(N: N0.getNode()) && N0.hasOneUse()))) {
7288 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
7289 EVT MemVT = LN0->getMemoryVT();
7290 // If we zero all the possible extended bits, then we can turn this into
7291 // a zextload if we are running before legalize or the operation is legal.
7292 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7293 unsigned MemBitSize = MemVT.getScalarSizeInBits();
7294 APInt ExtBits = APInt::getHighBitsSet(numBits: ExtBitSize, hiBitsSet: ExtBitSize - MemBitSize);
7295 if (DAG.MaskedValueIsZero(Op: N1, Mask: ExtBits) &&
7296 ((!LegalOperations && LN0->isSimple()) ||
7297 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT))) {
7298 SDValue ExtLoad =
7299 DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(N0), VT, Chain: LN0->getChain(),
7300 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
7301 AddToWorklist(N);
7302 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
7303 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7304 }
7305 }
7306
7307 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7308 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7309 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
7310 N1: N0.getOperand(i: 1), DemandHighBits: false))
7311 return BSwap;
7312 }
7313
7314 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7315 return Shifts;
7316
7317 if (SDValue V = combineShiftAnd1ToBitTest(And: N, DAG))
7318 return V;
7319
7320 // Recognize the following pattern:
7321 //
7322 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7323 //
7324 // where bitmask is a mask that clears the upper bits of AndVT. The
7325 // number of bits in bitmask must be a power of two.
7326 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7327 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7328 return false;
7329
7330 auto *C = dyn_cast<ConstantSDNode>(Val&: RHS);
7331 if (!C)
7332 return false;
7333
7334 if (!C->getAPIntValue().isMask(
7335 numBits: LHS.getOperand(i: 0).getValueType().getFixedSizeInBits()))
7336 return false;
7337
7338 return true;
7339 };
7340
7341 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7342 if (IsAndZeroExtMask(N0, N1))
7343 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
7344
7345 if (hasOperation(Opcode: ISD::USUBSAT, VT))
7346 if (SDValue V = foldAndToUsubsat(N, DAG))
7347 return V;
7348
7349 // Postpone until legalization completed to avoid interference with bswap
7350 // folding
7351 if (LegalOperations || VT.isVector())
7352 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
7353 return R;
7354
7355 return SDValue();
7356}
7357
7358/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
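/// For types wider than i16 this effectively matches, e.g. for i32:
///   (or (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff))
///     --> (srl (bswap a), 16)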
7359SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7360 bool DemandHighBits) {
7361 if (!LegalOperations)
7362 return SDValue();
7363
7364 EVT VT = N->getValueType(ResNo: 0);
7365 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7366 return SDValue();
7367 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7368 return SDValue();
7369
7370 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7371 bool LookPassAnd0 = false;
7372 bool LookPassAnd1 = false;
7373 if (N0.getOpcode() == ISD::AND && N0.getOperand(i: 0).getOpcode() == ISD::SRL)
7374 std::swap(a&: N0, b&: N1);
7375 if (N1.getOpcode() == ISD::AND && N1.getOperand(i: 0).getOpcode() == ISD::SHL)
7376 std::swap(a&: N0, b&: N1);
7377 if (N0.getOpcode() == ISD::AND) {
7378 if (!N0->hasOneUse())
7379 return SDValue();
7380 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7381 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7382 // This is needed for X86.
7383 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7384 N01C->getZExtValue() != 0xFFFF))
7385 return SDValue();
7386 N0 = N0.getOperand(i: 0);
7387 LookPassAnd0 = true;
7388 }
7389
7390 if (N1.getOpcode() == ISD::AND) {
7391 if (!N1->hasOneUse())
7392 return SDValue();
7393 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7394 if (!N11C || N11C->getZExtValue() != 0xFF)
7395 return SDValue();
7396 N1 = N1.getOperand(i: 0);
7397 LookPassAnd1 = true;
7398 }
7399
7400 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7401 std::swap(a&: N0, b&: N1);
7402 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7403 return SDValue();
7404 if (!N0->hasOneUse() || !N1->hasOneUse())
7405 return SDValue();
7406
7407 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7408 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7409 if (!N01C || !N11C)
7410 return SDValue();
7411 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7412 return SDValue();
7413
7414 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7415 SDValue N00 = N0->getOperand(Num: 0);
7416 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7417 if (!N00->hasOneUse())
7418 return SDValue();
7419 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(Val: N00.getOperand(i: 1));
7420 if (!N001C || N001C->getZExtValue() != 0xFF)
7421 return SDValue();
7422 N00 = N00.getOperand(i: 0);
7423 LookPassAnd0 = true;
7424 }
7425
7426 SDValue N10 = N1->getOperand(Num: 0);
7427 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7428 if (!N10->hasOneUse())
7429 return SDValue();
7430 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(Val: N10.getOperand(i: 1));
7431 // Also allow 0xFFFF since the bits will be shifted out. This is needed
7432 // for X86.
7433 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7434 N101C->getZExtValue() != 0xFFFF))
7435 return SDValue();
7436 N10 = N10.getOperand(i: 0);
7437 LookPassAnd1 = true;
7438 }
7439
7440 if (N00 != N10)
7441 return SDValue();
7442
7443 // Make sure everything beyond the low halfword gets set to zero since the SRL
7444 // 16 will clear the top bits.
7445 unsigned OpSizeInBits = VT.getSizeInBits();
7446 if (OpSizeInBits > 16) {
7447 // If the left-shift isn't masked out then the only way this is a bswap is
7448 // if all bits beyond the low 8 are 0. In that case the entire pattern
7449 // reduces to a left shift anyway: leave it for other parts of the combiner.
7450 if (DemandHighBits && !LookPassAnd0)
7451 return SDValue();
7452
7453 // However, if the right shift isn't masked out then it might be because
7454 // it's not needed. See if we can spot that too. If the high bits aren't
7455 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7456 // upper bits to be zero.
7457 if (!LookPassAnd1) {
7458 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7459 if (!DAG.MaskedValueIsZero(Op: N10,
7460 Mask: APInt::getBitsSet(numBits: OpSizeInBits, loBit: 16, hiBit: HighBit)))
7461 return SDValue();
7462 }
7463 }
7464
7465 SDValue Res = DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: N00);
7466 if (OpSizeInBits > 16) {
7467 SDLoc DL(N);
7468 Res = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Res,
7469 N2: DAG.getConstant(Val: OpSizeInBits - 16, DL,
7470 VT: getShiftAmountTy(LHSTy: VT)));
7471 }
7472 return Res;
7473}
7474
7475/// Return true if the specified node is an element that makes up a 32-bit
7476/// packed halfword byteswap.
7477/// ((x & 0x000000ff) << 8) |
7478/// ((x & 0x0000ff00) >> 8) |
7479/// ((x & 0x00ff0000) << 8) |
7480/// ((x & 0xff000000) >> 8)
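/// On a successful match the source value x is recorded in \p Parts at the
/// matched byte's index, so the caller can verify that all four elements
/// come from the same node.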
7481static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7482 if (!N->hasOneUse())
7483 return false;
7484
7485 unsigned Opc = N.getOpcode();
7486 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7487 return false;
7488
7489 SDValue N0 = N.getOperand(i: 0);
7490 unsigned Opc0 = N0.getOpcode();
7491 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7492 return false;
7493
7494 ConstantSDNode *N1C = nullptr;
7495 // SHL or SRL: look upstream for AND mask operand
7496 if (Opc == ISD::AND)
7497 N1C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7498 else if (Opc0 == ISD::AND)
7499 N1C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7500 if (!N1C)
7501 return false;
7502
7503 unsigned MaskByteOffset;
7504 switch (N1C->getZExtValue()) {
7505 default:
7506 return false;
7507 case 0xFF: MaskByteOffset = 0; break;
7508 case 0xFF00: MaskByteOffset = 1; break;
7509 case 0xFFFF:
7510 // In case demanded bits didn't clear the bits that will be shifted out.
7511 // This is needed for X86.
7512 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7513 MaskByteOffset = 1;
7514 break;
7515 }
7516 return false;
7517 case 0xFF0000: MaskByteOffset = 2; break;
7518 case 0xFF000000: MaskByteOffset = 3; break;
7519 }
7520
7521 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7522 if (Opc == ISD::AND) {
7523 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7524 // (x >> 8) & 0xff
7525 // (x >> 8) & 0xff0000
7526 if (Opc0 != ISD::SRL)
7527 return false;
7528 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7529 if (!C || C->getZExtValue() != 8)
7530 return false;
7531 } else {
7532 // (x << 8) & 0xff00
7533 // (x << 8) & 0xff000000
7534 if (Opc0 != ISD::SHL)
7535 return false;
7536 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7537 if (!C || C->getZExtValue() != 8)
7538 return false;
7539 }
7540 } else if (Opc == ISD::SHL) {
7541 // (x & 0xff) << 8
7542 // (x & 0xff0000) << 8
7543 if (MaskByteOffset != 0 && MaskByteOffset != 2)
7544 return false;
7545 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7546 if (!C || C->getZExtValue() != 8)
7547 return false;
7548 } else { // Opc == ISD::SRL
7549 // (x & 0xff00) >> 8
7550 // (x & 0xff000000) >> 8
7551 if (MaskByteOffset != 1 && MaskByteOffset != 3)
7552 return false;
7553 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7554 if (!C || C->getZExtValue() != 8)
7555 return false;
7556 }
7557
7558 if (Parts[MaskByteOffset])
7559 return false;
7560
7561 Parts[MaskByteOffset] = N0.getOperand(i: 0).getNode();
7562 return true;
7563}
7564
7565// Match 2 elements of a packed halfword bswap.
7566static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
7567 if (N.getOpcode() == ISD::OR)
7568 return isBSwapHWordElement(N: N.getOperand(i: 0), Parts) &&
7569 isBSwapHWordElement(N: N.getOperand(i: 1), Parts);
7570
7571 if (N.getOpcode() == ISD::SRL && N.getOperand(i: 0).getOpcode() == ISD::BSWAP) {
7572 ConstantSDNode *C = isConstOrConstSplat(N: N.getOperand(i: 1));
7573 if (!C || C->getAPIntValue() != 16)
7574 return false;
7575 Parts[0] = Parts[1] = N.getOperand(i: 0).getOperand(i: 0).getNode();
7576 return true;
7577 }
7578
7579 return false;
7580}
7581
7582// Match this pattern:
7583// (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
7584// And rewrite this to:
7585// (rotr (bswap A), 16)
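// For A = [a3 a2 a1 a0] the OR above evaluates to [a2 a3 a0 a1], and
// (rotr (bswap A), 16) gives the same bytes since bswap A = [a0 a1 a2 a3].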
7586static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
7587 SelectionDAG &DAG, SDNode *N, SDValue N0,
7588 SDValue N1, EVT VT, EVT ShiftAmountTy) {
7589 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
7590 "MatchBSwapHWordOrAndAnd: expecting i32");
7591 if (!TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
7592 return SDValue();
7593 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
7594 return SDValue();
7595 // TODO: this is too restrictive; lifting this restriction requires more tests
7596 if (!N0->hasOneUse() || !N1->hasOneUse())
7597 return SDValue();
7598 ConstantSDNode *Mask0 = isConstOrConstSplat(N: N0.getOperand(i: 1));
7599 ConstantSDNode *Mask1 = isConstOrConstSplat(N: N1.getOperand(i: 1));
7600 if (!Mask0 || !Mask1)
7601 return SDValue();
7602 if (Mask0->getAPIntValue() != 0xff00ff00 ||
7603 Mask1->getAPIntValue() != 0x00ff00ff)
7604 return SDValue();
7605 SDValue Shift0 = N0.getOperand(i: 0);
7606 SDValue Shift1 = N1.getOperand(i: 0);
7607 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
7608 return SDValue();
7609 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(N: Shift0.getOperand(i: 1));
7610 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(N: Shift1.getOperand(i: 1));
7611 if (!ShiftAmt0 || !ShiftAmt1)
7612 return SDValue();
7613 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
7614 return SDValue();
7615 if (Shift0.getOperand(i: 0) != Shift1.getOperand(i: 0))
7616 return SDValue();
7617
7618 SDLoc DL(N);
7619 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: Shift0.getOperand(i: 0));
7620 SDValue ShAmt = DAG.getConstant(Val: 16, DL, VT: ShiftAmountTy);
7621 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
7622}
7623
7624/// Match a 32-bit packed halfword bswap. That is
7625/// ((x & 0x000000ff) << 8) |
7626/// ((x & 0x0000ff00) >> 8) |
7627/// ((x & 0x00ff0000) << 8) |
7628/// ((x & 0xff000000) >> 8)
7629/// => (rotl (bswap x), 16)
7630SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
7631 if (!LegalOperations)
7632 return SDValue();
7633
7634 EVT VT = N->getValueType(ResNo: 0);
7635 if (VT != MVT::i32)
7636 return SDValue();
7637 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7638 return SDValue();
7639
7640 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
7641 ShiftAmountTy: getShiftAmountTy(LHSTy: VT)))
7642 return BSwap;
7643
7644 // Try again with commuted operands.
7645 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0: N1, N1: N0, VT,
7646 ShiftAmountTy: getShiftAmountTy(LHSTy: VT)))
7647 return BSwap;
7648
7649
7650 // Look for either
7651 // (or (bswaphpair), (bswaphpair))
7652 // (or (or (bswaphpair), (and)), (and))
7653 // (or (or (and), (bswaphpair)), (and))
7654 SDNode *Parts[4] = {};
7655
7656 if (isBSwapHWordPair(N: N0, Parts)) {
7657 // (or (or (and), (and)), (or (and), (and)))
7658 if (!isBSwapHWordPair(N: N1, Parts))
7659 return SDValue();
7660 } else if (N0.getOpcode() == ISD::OR) {
7661 // (or (or (or (and), (and)), (and)), (and))
7662 if (!isBSwapHWordElement(N: N1, Parts))
7663 return SDValue();
7664 SDValue N00 = N0.getOperand(i: 0);
7665 SDValue N01 = N0.getOperand(i: 1);
7666 if (!(isBSwapHWordElement(N: N01, Parts) && isBSwapHWordPair(N: N00, Parts)) &&
7667 !(isBSwapHWordElement(N: N00, Parts) && isBSwapHWordPair(N: N01, Parts)))
7668 return SDValue();
7669 } else {
7670 return SDValue();
7671 }
7672
7673 // Make sure the parts are all coming from the same node.
7674 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
7675 return SDValue();
7676
7677 SDLoc DL(N);
7678 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT,
7679 Operand: SDValue(Parts[0], 0));
7680
7681 // Result of the bswap should be rotated by 16. If it's not legal, then
7682 // do (x << 16) | (x >> 16).
7683 SDValue ShAmt = DAG.getConstant(Val: 16, DL, VT: getShiftAmountTy(LHSTy: VT));
7684 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT))
7685 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: BSwap, N2: ShAmt);
7686 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
7687 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
7688 return DAG.getNode(Opcode: ISD::OR, DL, VT,
7689 N1: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: BSwap, N2: ShAmt),
7690 N2: DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: BSwap, N2: ShAmt));
7691}
7692
7693/// This contains all DAGCombine rules which reduce two values combined by
7694/// an Or operation to a single value \see visitANDLike().
7695SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
7696 EVT VT = N1.getValueType();
7697 SDLoc DL(N);
7698
7699 // fold (or x, undef) -> -1
7700 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7701 return DAG.getAllOnesConstant(DL, VT);
7702
7703 if (SDValue V = foldLogicOfSetCCs(IsAnd: false, N0, N1, DL))
7704 return V;
7705
7706 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
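  // For example, (or (and X, 0xF0), (and Y, 0x0F)) --> (and (or X, Y), 0xFF)
  // provided X is known to have no bits set in 0x0F and Y none in 0xF0.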
7707 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7708 // Don't increase # computations.
7709 (N0->hasOneUse() || N1->hasOneUse())) {
7710 // We can only do this xform if we know that bits from X that are set in C2
7711 // but not in C1 are already zero. Likewise for Y.
7712 if (const ConstantSDNode *N0O1C =
7713 getAsNonOpaqueConstant(N: N0.getOperand(i: 1))) {
7714 if (const ConstantSDNode *N1O1C =
7715 getAsNonOpaqueConstant(N: N1.getOperand(i: 1))) {
7716 // We can only do this xform if we know that bits from X that are set in
7717 // C2 but not in C1 are already zero. Likewise for Y.
7718 const APInt &LHSMask = N0O1C->getAPIntValue();
7719 const APInt &RHSMask = N1O1C->getAPIntValue();
7720
7721 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 0), Mask: RHSMask&~LHSMask) &&
7722 DAG.MaskedValueIsZero(Op: N1.getOperand(i: 0), Mask: LHSMask&~RHSMask)) {
7723 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
7724 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
7725 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7726 N2: DAG.getConstant(Val: LHSMask | RHSMask, DL, VT));
7727 }
7728 }
7729 }
7730 }
7731
7732 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7733 if (N0.getOpcode() == ISD::AND &&
7734 N1.getOpcode() == ISD::AND &&
7735 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
7736 // Don't increase # computations.
7737 (N0->hasOneUse() || N1->hasOneUse())) {
7738 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
7739 N1: N0.getOperand(i: 1), N2: N1.getOperand(i: 1));
7740 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: X);
7741 }
7742
7743 return SDValue();
7744}
7745
7746/// OR combines for which the commuted variant will be tried as well.
7747static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7748 SDNode *N) {
7749 EVT VT = N0.getValueType();
7750
7751 auto peekThroughResize = [](SDValue V) {
7752 if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
7753 return V->getOperand(Num: 0);
7754 return V;
7755 };
7756
7757 SDValue N0Resized = peekThroughResize(N0);
7758 if (N0Resized.getOpcode() == ISD::AND) {
7759 SDValue N1Resized = peekThroughResize(N1);
7760 SDValue N00 = N0Resized.getOperand(i: 0);
7761 SDValue N01 = N0Resized.getOperand(i: 1);
7762
7763 // fold or (and x, y), x --> x
7764 if (N00 == N1Resized || N01 == N1Resized)
7765 return N1;
7766
7767 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7768 // TODO: Set AllowUndefs = true.
7769 if (SDValue NotOperand = getBitwiseNotOperand(V: N01, Mask: N00,
7770 /* AllowUndefs */ false)) {
7771 if (peekThroughResize(NotOperand) == N1Resized)
7772 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT,
7773 N1: DAG.getZExtOrTrunc(Op: N00, DL: SDLoc(N), VT), N2: N1);
7774 }
7775
7776 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7777 if (SDValue NotOperand = getBitwiseNotOperand(V: N00, Mask: N01,
7778 /* AllowUndefs */ false)) {
7779 if (peekThroughResize(NotOperand) == N1Resized)
7780 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT,
7781 N1: DAG.getZExtOrTrunc(Op: N01, DL: SDLoc(N), VT), N2: N1);
7782 }
7783 }
7784
7785 if (N0.getOpcode() == ISD::XOR) {
7786 // fold or (xor x, y), x --> or x, y
7787 // or (xor x, y), (x and/or y) --> or x, y
7788 SDValue N00 = N0.getOperand(i: 0);
7789 SDValue N01 = N0.getOperand(i: 1);
7790 if (N00 == N1)
7791 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: N01, N2: N1);
7792 if (N01 == N1)
7793 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: N00, N2: N1);
7794
7795 if (N1.getOpcode() == ISD::AND || N1.getOpcode() == ISD::OR) {
7796 SDValue N10 = N1.getOperand(i: 0);
7797 SDValue N11 = N1.getOperand(i: 1);
7798 if ((N00 == N10 && N01 == N11) || (N00 == N11 && N01 == N10))
7799 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: N00, N2: N01);
7800 }
7801 }
7802
7803 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7804 return R;
7805
7806 auto peekThroughZext = [](SDValue V) {
7807 if (V->getOpcode() == ISD::ZERO_EXTEND)
7808 return V->getOperand(Num: 0);
7809 return V;
7810 };
7811
7812 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7813 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7814 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
7815 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
7816 return N0;
7817
7818 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7819 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7820 N0.getOperand(i: 1) == N1.getOperand(i: 0) &&
7821 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
7822 return N0;
7823
7824 return SDValue();
7825}
7826
7827SDValue DAGCombiner::visitOR(SDNode *N) {
7828 SDValue N0 = N->getOperand(Num: 0);
7829 SDValue N1 = N->getOperand(Num: 1);
7830 EVT VT = N1.getValueType();
7831
7832 // x | x --> x
7833 if (N0 == N1)
7834 return N0;
7835
7836 // fold (or c1, c2) -> c1|c2
7837 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL: SDLoc(N), VT, Ops: {N0, N1}))
7838 return C;
7839
7840 // canonicalize constant to RHS
7841 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
7842 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
7843 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1, N2: N0);
7844
7845 // fold vector ops
7846 if (VT.isVector()) {
7847 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
7848 return FoldedVOp;
7849
7850 // fold (or x, 0) -> x, vector edition
7851 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
7852 return N0;
7853
7854 // fold (or x, -1) -> -1, vector edition
7855 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
7856 // Do not return N1, because an undef node may exist in N1.
7857 return DAG.getAllOnesConstant(DL: SDLoc(N), VT: N1.getValueType());
7858
7859 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
7860 // Do this only if the resulting type / shuffle is legal.
7861 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
7862 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(Val&: N1);
7863 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
7864 bool ZeroN00 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 0).getNode());
7865 bool ZeroN01 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 1).getNode());
7866 bool ZeroN10 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
7867 bool ZeroN11 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 1).getNode());
7868 // Ensure both shuffles have a zero input.
7869 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
7870 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
7871 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
7872 bool CanFold = true;
7873 int NumElts = VT.getVectorNumElements();
7874 SmallVector<int, 4> Mask(NumElts, -1);
7875
7876 for (int i = 0; i != NumElts; ++i) {
7877 int M0 = SV0->getMaskElt(Idx: i);
7878 int M1 = SV1->getMaskElt(Idx: i);
7879
7880 // Determine if either index is pointing to a zero vector.
7881 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
7882 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
7883
7884 // If one element is zero and the other side is undef, keep undef.
7885 // This also handles the case that both are undef.
7886 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
7887 continue;
7888
7889 // Make sure only one of the elements is zero.
7890 if (M0Zero == M1Zero) {
7891 CanFold = false;
7892 break;
7893 }
7894
7895 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
7896
7897 // We have a zero and non-zero element. If the non-zero came from
7898 // SV0 make the index a LHS index. If it came from SV1, make it
7899 // a RHS index. We need to mod by NumElts because we don't care
7900 // which operand it came from in the original shuffles.
7901 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
7902 }
7903
7904 if (CanFold) {
7905 SDValue NewLHS = ZeroN00 ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
7906 SDValue NewRHS = ZeroN10 ? N1.getOperand(i: 1) : N1.getOperand(i: 0);
7907
7908 SDValue LegalShuffle =
7909 TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: NewLHS, N1: NewRHS,
7910 Mask, DAG);
7911 if (LegalShuffle)
7912 return LegalShuffle;
7913 }
7914 }
7915 }
7916 }
7917
7918 // fold (or x, 0) -> x
7919 if (isNullConstant(V: N1))
7920 return N0;
7921
7922 // fold (or x, -1) -> -1
7923 if (isAllOnesConstant(V: N1))
7924 return N1;
7925
7926 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7927 return NewSel;
7928
7929 // fold (or x, c) -> c iff (x & ~c) == 0
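  // For example, (or (and x, 0x0F), 0xFF) --> 0xFF: every bit that could be
  // set in the LHS is already set in the constant.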
7930 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
7931 if (N1C && DAG.MaskedValueIsZero(Op: N0, Mask: ~N1C->getAPIntValue()))
7932 return N1;
7933
7934 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7935 return R;
7936
7937 if (SDValue Combined = visitORLike(N0, N1, N))
7938 return Combined;
7939
7940 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7941 return Combined;
7942
7943 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
7944 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
7945 return BSwap;
7946 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
7947 return BSwap;
7948
7949 // reassociate or
7950 if (SDValue ROR = reassociateOps(Opc: ISD::OR, DL: SDLoc(N), N0, N1, Flags: N->getFlags()))
7951 return ROR;
7952
7953 // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
7954 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_OR, Opc: ISD::OR, DL: SDLoc(N),
7955 VT, N0, N1))
7956 return SD;
7957
7958 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7959 // iff (c1 & c2) != 0 or c1/c2 are undef.
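  // For example, (or (and X, 0x0F), 0x3C) --> (and (or X, 0x3C), 0x3F).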
7960 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
7961 return !C1 || !C2 || C1->getAPIntValue().intersects(RHS: C2->getAPIntValue());
7962 };
7963 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
7964 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchIntersect, AllowUndefs: true)) {
7965 if (SDValue COR = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL: SDLoc(N1), VT,
7966 Ops: {N1, N0.getOperand(i: 1)})) {
7967 SDValue IOR = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
7968 AddToWorklist(N: IOR.getNode());
7969 return DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N), VT, N1: COR, N2: IOR);
7970 }
7971 }
7972
7973 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
7974 return Combined;
7975 if (SDValue Combined = visitORCommutative(DAG, N0: N1, N1: N0, N))
7976 return Combined;
7977
7978 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
7979 if (N0.getOpcode() == N1.getOpcode())
7980 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7981 return V;
7982
7983 // See if this is some rotate idiom.
7984 if (SDValue Rot = MatchRotate(LHS: N0, RHS: N1, DL: SDLoc(N)))
7985 return Rot;
7986
7987 if (SDValue Load = MatchLoadCombine(N))
7988 return Load;
7989
7990 // Simplify the operands using demanded-bits information.
7991 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7992 return SDValue(N, 0);
7993
7994 // If OR can be rewritten into ADD, try combines based on ADD.
7995 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
7996 DAG.isADDLike(Op: SDValue(N, 0)))
7997 if (SDValue Combined = visitADDLike(N))
7998 return Combined;
7999
8000 // Postpone until legalization completed to avoid interference with bswap
8001 // folding
8002 if (LegalOperations || VT.isVector())
8003 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
8004 return R;
8005
8006 return SDValue();
8007}
8008
8009static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8010 SDValue &Mask) {
8011 if (Op.getOpcode() == ISD::AND &&
8012 DAG.isConstantIntBuildVectorOrConstantInt(N: Op.getOperand(i: 1))) {
8013 Mask = Op.getOperand(i: 1);
8014 return Op.getOperand(i: 0);
8015 }
8016 return Op;
8017}
8018
8019/// Match "(X shl/srl V1) & V2" where V2 may not be present.
8020static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8021 SDValue &Mask) {
8022 Op = stripConstantMask(DAG, Op, Mask);
8023 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8024 Shift = Op;
8025 return true;
8026 }
8027 return false;
8028}
8029
8030/// Helper function for visitOR to extract the needed side of a rotate idiom
8031/// from a shl/srl/mul/udiv. This is meant to handle cases where
8032/// InstCombine merged some outside op with one of the shifts from
8033/// the rotate pattern.
8034/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8035/// Otherwise, returns an expansion of \p ExtractFrom based on the following
8036/// patterns:
8037///
8038/// (or (add v v) (srl v bitwidth-1)):
8039/// expands (add v v) -> (shl v 1)
8040///
8041/// (or (mul v c0) (srl (mul v c1) c2)):
8042/// expands (mul v c0) -> (shl (mul v c1) c3)
8043///
8044/// (or (udiv v c0) (shl (udiv v c1) c2)):
8045/// expands (udiv v c0) -> (srl (udiv v c1) c3)
8046///
8047/// (or (shl v c0) (srl (shl v c1) c2)):
8048/// expands (shl v c0) -> (shl (shl v c1) c3)
8049///
8050/// (or (srl v c0) (shl (srl v c1) c2)):
8051/// expands (srl v c0) -> (srl (srl v c1) c3)
8052///
8053/// Such that in all cases, c3+c2==bitwidth(op v c1).
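///
/// For example, on i32, (or (mul v, 24), (srl (mul v, 3), 29)) satisfies the
/// mul pattern with c0 == 24 == 3 << 3 and c3 == 3 == 32 - c2, so
/// (mul v, 24) is expanded to (shl (mul v, 3), 3) and a rotate can form.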
8054static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
8055 SDValue ExtractFrom, SDValue &Mask,
8056 const SDLoc &DL) {
8057 assert(OppShift && ExtractFrom && "Empty SDValue");
8058 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8059 return SDValue();
8060
8061 ExtractFrom = stripConstantMask(DAG, Op: ExtractFrom, Mask);
8062
8063 // Value and Type of the shift.
8064 SDValue OppShiftLHS = OppShift.getOperand(i: 0);
8065 EVT ShiftedVT = OppShiftLHS.getValueType();
8066
8067 // Amount of the existing shift.
8068 ConstantSDNode *OppShiftCst = isConstOrConstSplat(N: OppShift.getOperand(i: 1));
8069
8070 // (add v v) -> (shl v 1)
8071 // TODO: Should this be a general DAG canonicalization?
8072 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8073 ExtractFrom.getOpcode() == ISD::ADD &&
8074 ExtractFrom.getOperand(i: 0) == ExtractFrom.getOperand(i: 1) &&
8075 ExtractFrom.getOperand(i: 0) == OppShiftLHS &&
8076 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8077 return DAG.getNode(Opcode: ISD::SHL, DL, VT: ShiftedVT, N1: OppShiftLHS,
8078 N2: DAG.getShiftAmountConstant(Val: 1, VT: ShiftedVT, DL));
8079
8080 // Preconditions:
8081 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8082 //
8083 // Find opcode of the needed shift to be extracted from (op0 v c0).
8084 unsigned Opcode = ISD::DELETED_NODE;
8085 bool IsMulOrDiv = false;
8086 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8087 // opcode or its arithmetic (mul or udiv) variant.
8088 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8089 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8090 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8091 return false;
8092 Opcode = NeededShift;
8093 return true;
8094 };
8095 // op0 must be either the needed shift opcode or the mul/udiv equivalent
8096 // that the needed shift can be extracted from.
8097 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8098 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8099 return SDValue();
8100
8101 // op0 must be the same opcode on both sides, have the same LHS argument,
8102 // and produce the same value type.
8103 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8104 OppShiftLHS.getOperand(i: 0) != ExtractFrom.getOperand(i: 0) ||
8105 ShiftedVT != ExtractFrom.getValueType())
8106 return SDValue();
8107
8108 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8109 ConstantSDNode *OppLHSCst = isConstOrConstSplat(N: OppShiftLHS.getOperand(i: 1));
8110 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8111 ConstantSDNode *ExtractFromCst =
8112 isConstOrConstSplat(N: ExtractFrom.getOperand(i: 1));
8113 // TODO: We should be able to handle non-uniform constant vectors for these values
8114 // Check that we have constant values.
8115 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8116 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8117 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8118 return SDValue();
8119
8120 // Compute the shift amount we need to extract to complete the rotate.
8121 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8122 if (OppShiftCst->getAPIntValue().ugt(RHS: VTWidth))
8123 return SDValue();
8124 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8125 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8126 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8127 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8128 zeroExtendToMatch(LHS&: ExtractFromAmt, RHS&: OppLHSAmt);
8129
8130 // Now try extract the needed shift from the ExtractFrom op and see if the
8131 // result matches up with the existing shift's LHS op.
8132 if (IsMulOrDiv) {
8133 // Op to extract from is a mul or udiv by a constant.
8134 // Check:
8135 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8136 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8137 const APInt ExtractDiv = APInt::getOneBitSet(numBits: ExtractFromAmt.getBitWidth(),
8138 BitNo: NeededShiftAmt.getZExtValue());
8139 APInt ResultAmt;
8140 APInt Rem;
8141 APInt::udivrem(LHS: ExtractFromAmt, RHS: ExtractDiv, Quotient&: ResultAmt, Remainder&: Rem);
8142 if (Rem != 0 || ResultAmt != OppLHSAmt)
8143 return SDValue();
8144 } else {
8145 // Op to extract from is a shift by a constant.
8146 // Check:
8147 // c2 - (bitwidth(op0 v c0) - c1) == c0
8148 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8149 width: ExtractFromAmt.getBitWidth()))
8150 return SDValue();
8151 }
8152
8153 // Return the expanded shift op that should allow a rotate to be formed.
8154 EVT ShiftVT = OppShift.getOperand(i: 1).getValueType();
8155 EVT ResVT = ExtractFrom.getValueType();
8156 SDValue NewShiftNode = DAG.getConstant(Val: NeededShiftAmt, DL, VT: ShiftVT);
8157 return DAG.getNode(Opcode, DL, VT: ResVT, N1: OppShiftLHS, N2: NewShiftNode);
8158}
8159
8160// Return true if we can prove that, whenever Neg and Pos are both in the
8161// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8162// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8163//
8164// (or (shift1 X, Neg), (shift2 X, Pos))
8165//
8166// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8167// in direction shift1 by Neg. The range [0, EltSize) means that we only need
8168// to consider shift amounts with defined behavior.
8169//
8170// The IsRotate flag should be set when the LHS of both shifts is the same.
8171// Otherwise if matching a general funnel shift, it should be clear.
8172static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8173 SelectionDAG &DAG, bool IsRotate) {
8174 const auto &TLI = DAG.getTargetLoweringInfo();
8175 // If EltSize is a power of 2 then:
8176 //
8177 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8178 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8179 //
8180 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8181 // for the stronger condition:
8182 //
8183 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8184 //
8185 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8186 // we can just replace Neg with Neg' for the rest of the function.
8187 //
8188 // In other cases we check for the even stronger condition:
8189 //
8190 // Neg == EltSize - Pos [B]
8191 //
8192 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8193 // behavior if Pos == 0 (and consequently Neg == EltSize).
8194 //
8195 // We could actually use [A] whenever EltSize is a power of 2, but the
8196 // only extra cases that it would match are those uninteresting ones
8197 // where Neg and Pos are never in range at the same time. E.g. for
8198 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8199 // as well as (sub 32, Pos), but:
8200 //
8201 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8202 //
8203 // always invokes undefined behavior for 32-bit X.
8204 //
8205 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8206 // This allows us to peek through any operations that only affect Mask's
8207 // un-demanded bits.
8208 //
8209 // NOTE: We can only do this when matching operations which won't modify the
8210 // least Log2(EltSize) significant bits and not a general funnel shift.
8211 unsigned MaskLoBits = 0;
8212 if (IsRotate && isPowerOf2_64(Value: EltSize)) {
8213 unsigned Bits = Log2_64(Value: EltSize);
8214 unsigned NegBits = Neg.getScalarValueSizeInBits();
8215 if (NegBits >= Bits) {
8216 APInt DemandedBits = APInt::getLowBitsSet(numBits: NegBits, loBitsSet: Bits);
8217 if (SDValue Inner =
8218 TLI.SimplifyMultipleUseDemandedBits(Op: Neg, DemandedBits, DAG)) {
8219 Neg = Inner;
8220 MaskLoBits = Bits;
8221 }
8222 }
8223 }
8224
8225 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8226 if (Neg.getOpcode() != ISD::SUB)
8227 return false;
8228 ConstantSDNode *NegC = isConstOrConstSplat(N: Neg.getOperand(i: 0));
8229 if (!NegC)
8230 return false;
8231 SDValue NegOp1 = Neg.getOperand(i: 1);
8232
8233 // On the RHS of [A], if Pos is the result of an operation on Pos' that won't
8234 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8235 // are redundant for the purpose of the equality.
8236 if (MaskLoBits) {
8237 unsigned PosBits = Pos.getScalarValueSizeInBits();
8238 if (PosBits >= MaskLoBits) {
8239 APInt DemandedBits = APInt::getLowBitsSet(numBits: PosBits, loBitsSet: MaskLoBits);
8240 if (SDValue Inner =
8241 TLI.SimplifyMultipleUseDemandedBits(Op: Pos, DemandedBits, DAG)) {
8242 Pos = Inner;
8243 }
8244 }
8245 }
8246
8247 // The condition we need is now:
8248 //
8249 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8250 //
8251 // If NegOp1 == Pos then we need:
8252 //
8253 // EltSize & Mask == NegC & Mask
8254 //
8255 // (because "x & Mask" is a truncation and distributes through subtraction).
8256 //
8257 // We also need to account for a potential truncation of NegOp1 if the amount
8258 // has already been legalized to a shift amount type.
8259 APInt Width;
8260 if ((Pos == NegOp1) ||
8261 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(i: 0)))
8262 Width = NegC->getAPIntValue();
8263
8264 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8265 // Then the condition we want to prove becomes:
8266 //
8267 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8268 //
8269 // which, again because "x & Mask" is a truncation, becomes:
8270 //
8271 // NegC & Mask == (EltSize - PosC) & Mask
8272 // EltSize & Mask == (NegC + PosC) & Mask
8273 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(i: 0) == NegOp1) {
8274 if (ConstantSDNode *PosC = isConstOrConstSplat(N: Pos.getOperand(i: 1)))
8275 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8276 else
8277 return false;
8278 } else
8279 return false;
8280
8281 // Now we just need to check that EltSize & Mask == Width & Mask.
8282 if (MaskLoBits)
8283 // EltSize & Mask is 0 since Mask is EltSize - 1.
8284 return Width.getLoBits(numBits: MaskLoBits) == 0;
8285 return Width == EltSize;
8286}
8287
8288// A subroutine of MatchRotate used once we have found an OR of two opposite
8289// shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
8290// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8291// former being preferred if supported. InnerPos and InnerNeg are Pos and
8292// Neg with outer conversions stripped away.
8293SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8294 SDValue Neg, SDValue InnerPos,
8295 SDValue InnerNeg, bool HasPos,
8296 unsigned PosOpcode, unsigned NegOpcode,
8297 const SDLoc &DL) {
8298 // fold (or (shl x, (*ext y)),
8299 // (srl x, (*ext (sub 32, y)))) ->
8300 // (rotl x, y) or (rotr x, (sub 32, y))
8301 //
8302 // fold (or (shl x, (*ext (sub 32, y))),
8303 // (srl x, (*ext y))) ->
8304 // (rotr x, y) or (rotl x, (sub 32, y))
8305 EVT VT = Shifted.getValueType();
8306 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: VT.getScalarSizeInBits(), DAG,
8307 /*IsRotate*/ true)) {
8308 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: Shifted,
8309 N2: HasPos ? Pos : Neg);
8310 }
8311
8312 return SDValue();
8313}
8314
8315// A subroutine of MatchRotate used once we have found an OR of two opposite
8316// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
8317// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8318// former being preferred if supported. InnerPos and InnerNeg are Pos and
8319// Neg with outer conversions stripped away.
8320// TODO: Merge with MatchRotatePosNeg.
8321SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8322 SDValue Neg, SDValue InnerPos,
8323 SDValue InnerNeg, bool HasPos,
8324 unsigned PosOpcode, unsigned NegOpcode,
8325 const SDLoc &DL) {
8326 EVT VT = N0.getValueType();
8327 unsigned EltBits = VT.getScalarSizeInBits();
8328
8329 // fold (or (shl x0, (*ext y)),
8330 // (srl x1, (*ext (sub 32, y)))) ->
8331 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8332 //
8333 // fold (or (shl x0, (*ext (sub 32, y))),
8334 // (srl x1, (*ext y))) ->
8335 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8336 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: EltBits, DAG, /*IsRotate*/ N0 == N1)) {
8337 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: N0, N2: N1,
8338 N3: HasPos ? Pos : Neg);
8339 }
8340
8341 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
8342 // so for now just use the PosOpcode case if it's legal.
8343 // TODO: When can we use the NegOpcode case?
8344 if (PosOpcode == ISD::FSHL && isPowerOf2_32(Value: EltBits)) {
8345 auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
8346 if (Op.getOpcode() != BinOpc)
8347 return false;
8348 ConstantSDNode *Cst = isConstOrConstSplat(N: Op.getOperand(i: 1));
8349 return Cst && (Cst->getAPIntValue() == Imm);
8350 };
8351
8352 // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8353 // -> (fshl x0, x1, y)
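    // For a 32-bit type, (xor y, 31) == 31 - y for y in [0, 31], so x1 is
    // shifted right by 1 + (31 - y) == 32 - y in total; the separate srl-by-1
    // keeps the y == 0 case well defined while matching fshl semantics.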
8354 if (IsBinOpImm(N1, ISD::SRL, 1) &&
8355 IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
8356 InnerPos == InnerNeg.getOperand(i: 0) &&
8357 TLI.isOperationLegalOrCustom(Op: ISD::FSHL, VT)) {
8358 return DAG.getNode(Opcode: ISD::FSHL, DL, VT, N1: N0, N2: N1.getOperand(i: 0), N3: Pos);
8359 }
8360
8361 // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8362 // -> (fshr x0, x1, y)
8363 if (IsBinOpImm(N0, ISD::SHL, 1) &&
8364 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8365 InnerNeg == InnerPos.getOperand(i: 0) &&
8366 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8367 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: N0.getOperand(i: 0), N2: N1, N3: Neg);
8368 }
8369
8370 // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8371 // -> (fshr x0, x1, y)
8372 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8373 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
8374 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8375 InnerNeg == InnerPos.getOperand(i: 0) &&
8376 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8377 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: N0.getOperand(i: 0), N2: N1, N3: Neg);
8378 }
8379 }
8380
8381 return SDValue();
8382}
8383
8384// MatchRotate - Handle an 'or' of two operands. If this is one of the many
8385// idioms for rotate, and if the target supports rotation instructions, generate
8386// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
8387// with different shifted sources.
8388SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
8389 EVT VT = LHS.getValueType();
8390
8391 // The target must have at least one rotate/funnel flavor.
8392 // We still try to match rotate by constant pre-legalization.
8393 // TODO: Support pre-legalization funnel-shift by constant.
8394 bool HasROTL = hasOperation(Opcode: ISD::ROTL, VT);
8395 bool HasROTR = hasOperation(Opcode: ISD::ROTR, VT);
8396 bool HasFSHL = hasOperation(Opcode: ISD::FSHL, VT);
8397 bool HasFSHR = hasOperation(Opcode: ISD::FSHR, VT);
8398
8399 // If the type is going to be promoted and the target has enabled custom
8400 // lowering for rotate, allow matching rotate by non-constants. Only allow
8401 // this for scalar types.
8402 if (VT.isScalarInteger() && TLI.getTypeAction(Context&: *DAG.getContext(), VT) ==
8403 TargetLowering::TypePromoteInteger) {
8404 HasROTL |= TLI.getOperationAction(Op: ISD::ROTL, VT) == TargetLowering::Custom;
8405 HasROTR |= TLI.getOperationAction(Op: ISD::ROTR, VT) == TargetLowering::Custom;
8406 }
8407
8408 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8409 return SDValue();
8410
8411 // Check for truncated rotate.
8412 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8413 LHS.getOperand(i: 0).getValueType() == RHS.getOperand(i: 0).getValueType()) {
8414 assert(LHS.getValueType() == RHS.getValueType());
8415 if (SDValue Rot = MatchRotate(LHS: LHS.getOperand(i: 0), RHS: RHS.getOperand(i: 0), DL)) {
8416 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LHS), VT: LHS.getValueType(), Operand: Rot);
8417 }
8418 }
8419
8420 // Match "(X shl/srl V1) & V2" where V2 may not be present.
8421 SDValue LHSShift; // The shift.
8422 SDValue LHSMask; // AND value if any.
8423 matchRotateHalf(DAG, Op: LHS, Shift&: LHSShift, Mask&: LHSMask);
8424
8425 SDValue RHSShift; // The shift.
8426 SDValue RHSMask; // AND value if any.
8427 matchRotateHalf(DAG, Op: RHS, Shift&: RHSShift, Mask&: RHSMask);
8428
8429 // If neither side matched a rotate half, bail
8430 if (!LHSShift && !RHSShift)
8431 return SDValue();
8432
8433 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8434 // side of the rotate, so try to handle that here. In all cases we need to
8435 // pass the matched shift from the opposite side to compute the opcode and
8436 // needed shift amount to extract. We still want to do this if both sides
8437 // matched a rotate half because one half may be a potential overshift that
8438 // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
8439 // single one).
8440
8441 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8442 if (LHSShift)
8443 if (SDValue NewRHSShift =
8444 extractShiftForRotate(DAG, OppShift: LHSShift, ExtractFrom: RHS, Mask&: RHSMask, DL))
8445 RHSShift = NewRHSShift;
8446 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8447 if (RHSShift)
8448 if (SDValue NewLHSShift =
8449 extractShiftForRotate(DAG, OppShift: RHSShift, ExtractFrom: LHS, Mask&: LHSMask, DL))
8450 LHSShift = NewLHSShift;
8451
8452 // If a side is still missing, nothing else we can do.
8453 if (!RHSShift || !LHSShift)
8454 return SDValue();
8455
8456 // At this point we've matched or extracted a shift op on each side.
8457
8458 if (LHSShift.getOpcode() == RHSShift.getOpcode())
8459 return SDValue(); // Shifts must disagree.
8460
8461 // Canonicalize shl to left side in a shl/srl pair.
8462 if (RHSShift.getOpcode() == ISD::SHL) {
8463 std::swap(a&: LHS, b&: RHS);
8464 std::swap(a&: LHSShift, b&: RHSShift);
8465 std::swap(a&: LHSMask, b&: RHSMask);
8466 }
8467
8468 // Something has gone wrong - we've lost the shl/srl pair - bail.
8469 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8470 return SDValue();
8471
8472 unsigned EltSizeInBits = VT.getScalarSizeInBits();
8473 SDValue LHSShiftArg = LHSShift.getOperand(i: 0);
8474 SDValue LHSShiftAmt = LHSShift.getOperand(i: 1);
8475 SDValue RHSShiftArg = RHSShift.getOperand(i: 0);
8476 SDValue RHSShiftAmt = RHSShift.getOperand(i: 1);
8477
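  // A rotate or funnel shift by constant only forms when the two constant
  // shift amounts sum to the element width, e.g. for i32:
  //   (or (shl x, 8), (srl x, 24)) --> (rotl x, 8)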
8478 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8479 ConstantSDNode *RHS) {
8480 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8481 };
8482
8483 auto ApplyMasks = [&](SDValue Res) {
8484 // If there is an AND of either shifted operand, apply it to the result.
8485 if (LHSMask.getNode() || RHSMask.getNode()) {
8486 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8487 SDValue Mask = AllOnes;
8488
8489 if (LHSMask.getNode()) {
8490 SDValue RHSBits = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: AllOnes, N2: RHSShiftAmt);
8491 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8492 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHSMask, N2: RHSBits));
8493 }
8494 if (RHSMask.getNode()) {
8495 SDValue LHSBits = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllOnes, N2: LHSShiftAmt);
8496 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8497 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RHSMask, N2: LHSBits));
8498 }
8499
8500 Res = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Res, N2: Mask);
8501 }
8502
8503 return Res;
8504 };
8505
8506 // TODO: Support pre-legalization funnel-shift by constant.
8507 bool IsRotate = LHSShiftArg == RHSShiftArg;
8508 if (!IsRotate && !(HasFSHL || HasFSHR)) {
8509 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8510 ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8511 // Look for a disguised rotate by constant.
8512 // The common shifted operand X may be hidden inside another 'or'.
8513 SDValue X, Y;
8514 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8515 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8516 return false;
8517 if (CommonOp == Or.getOperand(i: 0)) {
8518 X = CommonOp;
8519 Y = Or.getOperand(i: 1);
8520 return true;
8521 }
8522 if (CommonOp == Or.getOperand(i: 1)) {
8523 X = CommonOp;
8524 Y = Or.getOperand(i: 0);
8525 return true;
8526 }
8527 return false;
8528 };
8529
8530 SDValue Res;
8531 if (matchOr(LHSShiftArg, RHSShiftArg)) {
8532 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
8533 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
8534 SDValue ShlY = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: LHSShiftAmt);
8535 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: ShlY);
8536 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
8537 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
8538 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
8539 SDValue SrlY = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Y, N2: RHSShiftAmt);
8540 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: SrlY);
8541 } else {
8542 return SDValue();
8543 }
8544
8545 return ApplyMasks(Res);
8546 }
8547
8548 return SDValue(); // Requires funnel shift support.
8549 }
8550
8551 // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
8552 // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
8553 // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
8554 // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
8555 // iff C1+C2 == EltSizeInBits
8556 if (ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8557 SDValue Res;
8558 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
8559 bool UseROTL = !LegalOperations || HasROTL;
8560 Res = DAG.getNode(Opcode: UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, N1: LHSShiftArg,
8561 N2: UseROTL ? LHSShiftAmt : RHSShiftAmt);
8562 } else {
8563 bool UseFSHL = !LegalOperations || HasFSHL;
8564 Res = DAG.getNode(Opcode: UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, N1: LHSShiftArg,
8565 N2: RHSShiftArg, N3: UseFSHL ? LHSShiftAmt : RHSShiftAmt);
8566 }
8567
8568 return ApplyMasks(Res);
8569 }
8570
8571 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
8572 // shift.
8573 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8574 return SDValue();
8575
8576 // If there is a mask here, and we have a variable shift, we can't be sure
8577 // that we're masking out the right stuff.
8578 if (LHSMask.getNode() || RHSMask.getNode())
8579 return SDValue();
8580
  // If the shift amount is sign/zext/any-extended or truncated, peel it off.
8582 SDValue LExtOp0 = LHSShiftAmt;
8583 SDValue RExtOp0 = RHSShiftAmt;
8584 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8585 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8586 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8587 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
8588 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8589 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8590 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8591 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
8592 LExtOp0 = LHSShiftAmt.getOperand(i: 0);
8593 RExtOp0 = RHSShiftAmt.getOperand(i: 0);
8594 }
8595
8596 if (IsRotate && (HasROTL || HasROTR)) {
8597 SDValue TryL =
8598 MatchRotatePosNeg(Shifted: LHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt, InnerPos: LExtOp0,
8599 InnerNeg: RExtOp0, HasPos: HasROTL, PosOpcode: ISD::ROTL, NegOpcode: ISD::ROTR, DL);
8600 if (TryL)
8601 return TryL;
8602
8603 SDValue TryR =
8604 MatchRotatePosNeg(Shifted: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt, InnerPos: RExtOp0,
8605 InnerNeg: LExtOp0, HasPos: HasROTR, PosOpcode: ISD::ROTR, NegOpcode: ISD::ROTL, DL);
8606 if (TryR)
8607 return TryR;
8608 }
8609
8610 SDValue TryL =
8611 MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt,
8612 InnerPos: LExtOp0, InnerNeg: RExtOp0, HasPos: HasFSHL, PosOpcode: ISD::FSHL, NegOpcode: ISD::FSHR, DL);
8613 if (TryL)
8614 return TryL;
8615
8616 SDValue TryR =
8617 MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt,
8618 InnerPos: RExtOp0, InnerNeg: LExtOp0, HasPos: HasFSHR, PosOpcode: ISD::FSHR, NegOpcode: ISD::FSHL, DL);
8619 if (TryR)
8620 return TryR;
8621
8622 return SDValue();
8623}
8624
8625/// Recursively traverses the expression calculating the origin of the requested
8626/// byte of the given value. Returns std::nullopt if the provider can't be
8627/// calculated.
8628///
8629/// For all the values except the root of the expression, we verify that the
8630/// value has exactly one use and if not then return std::nullopt. This way if
8631/// the origin of the byte is returned it's guaranteed that the values which
8632/// contribute to the byte are not used outside of this expression.
///
8634/// However, there is a special case when dealing with vector loads -- we allow
8635/// more than one use if the load is a vector type. Since the values that
8636/// contribute to the byte ultimately come from the ExtractVectorElements of the
8637/// Load, we don't care if the Load has uses other than ExtractVectorElements,
8638/// because those operations are independent from the pattern to be combined.
8639/// For vector loads, we simply care that the ByteProviders are adjacent
8640/// positions of the same vector, and their index matches the byte that is being
8641/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
8642/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
8643/// byte position we are trying to provide for the LoadCombine. If these do
8644/// not match, then we can not combine the vector loads. \p Index uses the
8645/// byte position we are trying to provide for and is matched against the
8646/// shl and load size. The \p Index algorithm ensures the requested byte is
8647/// provided for by the pattern, and the pattern does not over provide bytes.
8648///
8649///
8650/// The supported LoadCombine pattern for vector loads is as follows
8651/// or
8652/// / \
8653/// or shl
8654/// / \ |
8655/// or shl zext
8656/// / \ | |
8657/// shl zext zext EVE*
8658/// | | | |
8659/// zext EVE* EVE* LOAD
8660/// | | |
8661/// EVE* LOAD LOAD
8662/// |
8663/// LOAD
8664///
8665/// *ExtractVectorElement
8666using SDByteProvider = ByteProvider<SDNode *>;
8667
8668static std::optional<SDByteProvider>
8669calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
8670 std::optional<uint64_t> VectorIndex,
8671 unsigned StartingIndex = 0) {
8672
  // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
8674 if (Depth == 10)
8675 return std::nullopt;
8676
8677 // Only allow multiple uses if the instruction is a vector load (in which
8678 // case we will use the load for every ExtractVectorElement)
8679 if (Depth && !Op.hasOneUse() &&
8680 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
8681 return std::nullopt;
8682
8683 // Fail to combine if we have encountered anything but a LOAD after handling
8684 // an ExtractVectorElement.
8685 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
8686 return std::nullopt;
8687
8688 unsigned BitWidth = Op.getValueSizeInBits();
8689 if (BitWidth % 8 != 0)
8690 return std::nullopt;
8691 unsigned ByteWidth = BitWidth / 8;
8692 assert(Index < ByteWidth && "invalid index requested");
8693 (void) ByteWidth;
8694
8695 switch (Op.getOpcode()) {
8696 case ISD::OR: {
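    // For an OR, each byte must be wholly provided by one operand while the
    // other operand provides constant zero for that byte.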
8697 auto LHS =
8698 calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1, VectorIndex);
8699 if (!LHS)
8700 return std::nullopt;
8701 auto RHS =
8702 calculateByteProvider(Op: Op->getOperand(Num: 1), Index, Depth: Depth + 1, VectorIndex);
8703 if (!RHS)
8704 return std::nullopt;
8705
8706 if (LHS->isConstantZero())
8707 return RHS;
8708 if (RHS->isConstantZero())
8709 return LHS;
8710 return std::nullopt;
8711 }
8712 case ISD::SHL: {
8713 auto ShiftOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
8714 if (!ShiftOp)
8715 return std::nullopt;
8716
8717 uint64_t BitShift = ShiftOp->getZExtValue();
8718
8719 if (BitShift % 8 != 0)
8720 return std::nullopt;
8721 uint64_t ByteShift = BitShift / 8;
8722
8723 // If we are shifting by an amount greater than the index we are trying to
    // provide, then do not provide anything. Otherwise, subtract the byte
    // shift from the index.
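    // For example, with a byte shift of 2, byte 3 of the result comes from
    // byte 1 of the shifted operand, and bytes 0-1 are known zero.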
8726 return Index < ByteShift
8727 ? SDByteProvider::getConstantZero()
8728 : calculateByteProvider(Op: Op->getOperand(Num: 0), Index: Index - ByteShift,
8729 Depth: Depth + 1, VectorIndex, StartingIndex: Index);
8730 }
8731 case ISD::ANY_EXTEND:
8732 case ISD::SIGN_EXTEND:
8733 case ISD::ZERO_EXTEND: {
8734 SDValue NarrowOp = Op->getOperand(Num: 0);
8735 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8736 if (NarrowBitWidth % 8 != 0)
8737 return std::nullopt;
8738 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8739
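    // Bytes beyond the narrow source are known zero only for a zero-extend;
    // for sign- or any-extend their value is unknown, so give up on them.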
8740 if (Index >= NarrowByteWidth)
8741 return Op.getOpcode() == ISD::ZERO_EXTEND
8742 ? std::optional<SDByteProvider>(
8743 SDByteProvider::getConstantZero())
8744 : std::nullopt;
8745 return calculateByteProvider(Op: NarrowOp, Index, Depth: Depth + 1, VectorIndex,
8746 StartingIndex);
8747 }
8748 case ISD::BSWAP:
8749 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index: ByteWidth - Index - 1,
8750 Depth: Depth + 1, VectorIndex, StartingIndex);
8751 case ISD::EXTRACT_VECTOR_ELT: {
8752 auto OffsetOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
8753 if (!OffsetOp)
8754 return std::nullopt;
8755
8756 VectorIndex = OffsetOp->getZExtValue();
8757
8758 SDValue NarrowOp = Op->getOperand(Num: 0);
8759 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8760 if (NarrowBitWidth % 8 != 0)
8761 return std::nullopt;
8762 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8763
8764 // Check to see if the position of the element in the vector corresponds
8765 // with the byte we are trying to provide for. In the case of a vector of
    // i8, this simply means VectorIndex == StartingIndex. For non-i8 cases,
    // the element will provide a range of bytes. For example, if we have a
    // vector of i16s, each element provides two bytes (V[1] provides bytes 2 and
8769 // 3).
8770 if (*VectorIndex * NarrowByteWidth > StartingIndex)
8771 return std::nullopt;
8772 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
8773 return std::nullopt;
8774
8775 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1,
8776 VectorIndex, StartingIndex);
8777 }
8778 case ISD::LOAD: {
8779 auto L = cast<LoadSDNode>(Val: Op.getNode());
8780 if (!L->isSimple() || L->isIndexed())
8781 return std::nullopt;
8782
8783 unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
8784 if (NarrowBitWidth % 8 != 0)
8785 return std::nullopt;
8786 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8787
    // If the width of the load does not reach the byte we are trying to provide
    // for and it is not a ZEXTLOAD, then the load does not provide for the byte
    // in question.
8791 if (Index >= NarrowByteWidth)
8792 return L->getExtensionType() == ISD::ZEXTLOAD
8793 ? std::optional<SDByteProvider>(
8794 SDByteProvider::getConstantZero())
8795 : std::nullopt;
8796
8797 unsigned BPVectorIndex = VectorIndex.value_or(u: 0U);
8798 return SDByteProvider::getSrc(Val: L, ByteOffset: Index, VectorOffset: BPVectorIndex);
8799 }
8800 }
8801
8802 return std::nullopt;
8803}
8804
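// Map the i-th byte of a BW-byte value to its position in memory: the
// identity for little-endian layouts and the mirrored index for big-endian.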
8805static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
8806 return i;
8807}
8808
8809static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
8810 return BW - i - 1;
8811}
8812
// Check if the byte offsets we are looking at match either a big or little
// endian value load. Return true for big endian, false for little endian,
// and std::nullopt if the match failed.
8816static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
8817 int64_t FirstOffset) {
8818 // The endian can be decided only when it is 2 bytes at least.
8819 unsigned Width = ByteOffsets.size();
8820 if (Width < 2)
8821 return std::nullopt;
8822
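  // E.g. offsets {0,1,2,3} relative to FirstOffset match little endian and
  // {3,2,1,0} match big endian; anything else fails.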
8823 bool BigEndian = true, LittleEndian = true;
8824 for (unsigned i = 0; i < Width; i++) {
8825 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
8826 LittleEndian &= CurrentByteOffset == littleEndianByteAt(BW: Width, i);
8827 BigEndian &= CurrentByteOffset == bigEndianByteAt(BW: Width, i);
8828 if (!BigEndian && !LittleEndian)
8829 return std::nullopt;
8830 }
8831
  assert((BigEndian != LittleEndian) &&
         "It should be either big endian or little endian");
8834 return BigEndian;
8835}
8836
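/// Look through truncates and integer extensions to find the underlying
/// source value.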
8837static SDValue stripTruncAndExt(SDValue Value) {
8838 switch (Value.getOpcode()) {
8839 case ISD::TRUNCATE:
8840 case ISD::ZERO_EXTEND:
8841 case ISD::SIGN_EXTEND:
8842 case ISD::ANY_EXTEND:
8843 return stripTruncAndExt(Value: Value.getOperand(i: 0));
8844 }
8845 return Value;
8846}
8847
8848/// Match a pattern where a wide type scalar value is stored by several narrow
/// stores. Fold it into a single store or a BSWAP and a store if the target
8850/// supports it.
8851///
8852/// Assuming little endian target:
8853/// i8 *p = ...
8854/// i32 val = ...
8855/// p[0] = (val >> 0) & 0xFF;
8856/// p[1] = (val >> 8) & 0xFF;
8857/// p[2] = (val >> 16) & 0xFF;
8858/// p[3] = (val >> 24) & 0xFF;
8859/// =>
8860/// *((i32)p) = val;
8861///
8862/// i8 *p = ...
8863/// i32 val = ...
8864/// p[0] = (val >> 24) & 0xFF;
8865/// p[1] = (val >> 16) & 0xFF;
8866/// p[2] = (val >> 8) & 0xFF;
8867/// p[3] = (val >> 0) & 0xFF;
8868/// =>
8869/// *((i32)p) = BSWAP(val);
8870SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8871 // The matching looks for "store (trunc x)" patterns that appear early but are
8872 // likely to be replaced by truncating store nodes during combining.
8873 // TODO: If there is evidence that running this later would help, this
8874 // limitation could be removed. Legality checks may need to be added
8875 // for the created store and optional bswap/rotate.
8876 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
8877 return SDValue();
8878
8879 // We only handle merging simple stores of 1-4 bytes.
8880 // TODO: Allow unordered atomics when wider type is legal (see D66309)
8881 EVT MemVT = N->getMemoryVT();
8882 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8883 !N->isSimple() || N->isIndexed())
8884 return SDValue();
8885
  // Collect all of the stores in the chain, up to the maximum store width (i64).
8887 SDValue Chain = N->getChain();
8888 SmallVector<StoreSDNode *, 8> Stores = {N};
8889 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
8890 unsigned MaxWideNumBits = 64;
8891 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
8892 while (auto *Store = dyn_cast<StoreSDNode>(Val&: Chain)) {
8893 // All stores must be the same size to ensure that we are writing all of the
8894 // bytes in the wide value.
8895 // This store should have exactly one use as a chain operand for another
8896 // store in the merging set. If there are other chain uses, then the
8897 // transform may not be safe because order of loads/stores outside of this
8898 // set may not be preserved.
8899 // TODO: We could allow multiple sizes by tracking each stored byte.
8900 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8901 Store->isIndexed() || !Store->hasOneUse())
8902 return SDValue();
8903 Stores.push_back(Elt: Store);
8904 Chain = Store->getChain();
8905 if (MaxStores < Stores.size())
8906 return SDValue();
8907 }
8908 // There is no reason to continue if we do not have at least a pair of stores.
8909 if (Stores.size() < 2)
8910 return SDValue();
8911
8912 // Handle simple types only.
8913 LLVMContext &Context = *DAG.getContext();
8914 unsigned NumStores = Stores.size();
8915 unsigned WideNumBits = NumStores * NarrowNumBits;
8916 EVT WideVT = EVT::getIntegerVT(Context, BitWidth: WideNumBits);
8917 if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8918 return SDValue();
8919
8920 // Check if all bytes of the source value that we are looking at are stored
8921 // to the same base address. Collect offsets from Base address into OffsetMap.
8922 SDValue SourceValue;
8923 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8924 int64_t FirstOffset = INT64_MAX;
8925 StoreSDNode *FirstStore = nullptr;
8926 std::optional<BaseIndexOffset> Base;
8927 for (auto *Store : Stores) {
8928 // All the stores store different parts of the CombinedValue. A truncate is
8929 // required to get the partial value.
8930 SDValue Trunc = Store->getValue();
8931 if (Trunc.getOpcode() != ISD::TRUNCATE)
8932 return SDValue();
8933 // Other than the first/last part, a shift operation is required to get the
8934 // offset.
8935 int64_t Offset = 0;
8936 SDValue WideVal = Trunc.getOperand(i: 0);
8937 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8938 isa<ConstantSDNode>(Val: WideVal.getOperand(i: 1))) {
8939 // The shift amount must be a constant multiple of the narrow type.
8940 // It is translated to the offset address in the wide source value "y".
8941 //
8942 // x = srl y, ShiftAmtC
8943 // i8 z = trunc x
8944 // store z, ...
8945 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(i: 1);
8946 if (ShiftAmtC % NarrowNumBits != 0)
8947 return SDValue();
8948
8949 Offset = ShiftAmtC / NarrowNumBits;
8950 WideVal = WideVal.getOperand(i: 0);
8951 }
8952
8953 // Stores must share the same source value with different offsets.
8954 // Truncate and extends should be stripped to get the single source value.
8955 if (!SourceValue)
8956 SourceValue = WideVal;
8957 else if (stripTruncAndExt(Value: SourceValue) != stripTruncAndExt(Value: WideVal))
8958 return SDValue();
8959 else if (SourceValue.getValueType() != WideVT) {
8960 if (WideVal.getValueType() == WideVT ||
8961 WideVal.getScalarValueSizeInBits() >
8962 SourceValue.getScalarValueSizeInBits())
8963 SourceValue = WideVal;
8964 // Give up if the source value type is smaller than the store size.
8965 if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8966 return SDValue();
8967 }
8968
8969 // Stores must share the same base address.
8970 BaseIndexOffset Ptr = BaseIndexOffset::match(N: Store, DAG);
8971 int64_t ByteOffsetFromBase = 0;
8972 if (!Base)
8973 Base = Ptr;
8974 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
8975 return SDValue();
8976
8977 // Remember the first store.
8978 if (ByteOffsetFromBase < FirstOffset) {
8979 FirstStore = Store;
8980 FirstOffset = ByteOffsetFromBase;
8981 }
8982 // Map the offset in the store and the offset in the combined value, and
8983 // early return if it has been set before.
8984 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
8985 return SDValue();
8986 OffsetMap[Offset] = ByteOffsetFromBase;
8987 }
8988
8989 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8990 assert(FirstStore && "First store must be set");
8991
8992 // Check that a store of the wide type is both allowed and fast on the target
8993 const DataLayout &Layout = DAG.getDataLayout();
8994 unsigned Fast = 0;
8995 bool Allowed = TLI.allowsMemoryAccess(Context, DL: Layout, VT: WideVT,
8996 MMO: *FirstStore->getMemOperand(), Fast: &Fast);
8997 if (!Allowed || !Fast)
8998 return SDValue();
8999
9000 // Check if the pieces of the value are going to the expected places in memory
9001 // to merge the stores.
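  // E.g. with four i8 stores and FirstOffset F, a little-endian match
  // requires OffsetMap == {F, F+1, F+2, F+3}.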
9002 auto checkOffsets = [&](bool MatchLittleEndian) {
9003 if (MatchLittleEndian) {
9004 for (unsigned i = 0; i != NumStores; ++i)
9005 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9006 return false;
9007 } else { // MatchBigEndian by reversing loop counter.
9008 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9009 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9010 return false;
9011 }
9012 return true;
9013 };
9014
9015 // Check if the offsets line up for the native data layout of this target.
9016 bool NeedBswap = false;
9017 bool NeedRotate = false;
9018 if (!checkOffsets(Layout.isLittleEndian())) {
9019 // Special-case: check if byte offsets line up for the opposite endian.
9020 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9021 NeedBswap = true;
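    // With exactly two halves stored in swapped order, a rotate by half the
    // width (see below) recreates the wide value without a bswap.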
9022 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9023 NeedRotate = true;
9024 else
9025 return SDValue();
9026 }
9027
9028 SDLoc DL(N);
9029 if (WideVT != SourceValue.getValueType()) {
9030 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9031 "Unexpected store value to merge");
9032 SourceValue = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: WideVT, Operand: SourceValue);
9033 }
9034
9035 // Before legalize we can introduce illegal bswaps/rotates which will be later
9036 // converted to an explicit bswap sequence. This way we end up with a single
9037 // store and byte shuffling instead of several stores and byte shuffling.
9038 if (NeedBswap) {
9039 SourceValue = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: WideVT, Operand: SourceValue);
9040 } else if (NeedRotate) {
9041 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9042 SDValue RotAmt = DAG.getConstant(Val: WideNumBits / 2, DL, VT: WideVT);
9043 SourceValue = DAG.getNode(Opcode: ISD::ROTR, DL, VT: WideVT, N1: SourceValue, N2: RotAmt);
9044 }
9045
9046 SDValue NewStore =
9047 DAG.getStore(Chain, dl: DL, Val: SourceValue, Ptr: FirstStore->getBasePtr(),
9048 PtrInfo: FirstStore->getPointerInfo(), Alignment: FirstStore->getAlign());
9049
9050 // Rely on other DAG combine rules to remove the other individual stores.
9051 DAG.ReplaceAllUsesWith(From: N, To: NewStore.getNode());
9052 return NewStore;
9053}
9054
9055/// Match a pattern where a wide type scalar value is loaded by several narrow
9056/// loads and combined by shifts and ors. Fold it into a single load or a load
/// and a BSWAP if the target supports it.
9058///
9059/// Assuming little endian target:
9060/// i8 *a = ...
9061/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9062/// =>
9063/// i32 val = *((i32)a)
9064///
9065/// i8 *a = ...
9066/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9067/// =>
9068/// i32 val = BSWAP(*((i32)a))
9069///
9070/// TODO: This rule matches complex patterns with OR node roots and doesn't
9071/// interact well with the worklist mechanism. When a part of the pattern is
9072/// updated (e.g. one of the loads) its direct users are put into the worklist,
9073/// but the root node of the pattern which triggers the load combine is not
/// necessarily a direct user of the changed node. For example, once the address
/// of the t28 load is reassociated, the load combine won't be triggered:
9076/// t25: i32 = add t4, Constant:i32<2>
9077/// t26: i64 = sign_extend t25
9078/// t27: i64 = add t2, t26
9079/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9080/// t29: i32 = zero_extend t28
9081/// t32: i32 = shl t29, Constant:i8<8>
9082/// t33: i32 = or t23, t32
/// As a possible fix, visitLoad can check if the load can be a part of a load
9084/// combine pattern and add corresponding OR roots to the worklist.
9085SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9086 assert(N->getOpcode() == ISD::OR &&
9087 "Can only match load combining against OR nodes");
9088
9089 // Handles simple types only
9090 EVT VT = N->getValueType(ResNo: 0);
9091 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9092 return SDValue();
9093 unsigned ByteWidth = VT.getSizeInBits() / 8;
9094
9095 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
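  // Map the byte that a provider supplies (P.DestOffset within its load) to
  // that byte's offset in memory, according to the target's endianness.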
9096 auto MemoryByteOffset = [&](SDByteProvider P) {
9097 assert(P.hasSrc() && "Must be a memory byte provider");
9098 auto *Load = cast<LoadSDNode>(Val: P.Src.value());
9099
9100 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9101
9102 assert(LoadBitWidth % 8 == 0 &&
           "can only analyze providers for individual bytes, not bits");
9104 unsigned LoadByteWidth = LoadBitWidth / 8;
9105 return IsBigEndianTarget ? bigEndianByteAt(BW: LoadByteWidth, i: P.DestOffset)
9106 : littleEndianByteAt(BW: LoadByteWidth, i: P.DestOffset);
9107 };
9108
9109 std::optional<BaseIndexOffset> Base;
9110 SDValue Chain;
9111
9112 SmallPtrSet<LoadSDNode *, 8> Loads;
9113 std::optional<SDByteProvider> FirstByteProvider;
9114 int64_t FirstOffset = INT64_MAX;
9115
9116 // Check if all the bytes of the OR we are looking at are loaded from the same
9117 // base address. Collect bytes offsets from Base address in ByteOffsets.
9118 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9119 unsigned ZeroExtendedBytes = 0;
9120 for (int i = ByteWidth - 1; i >= 0; --i) {
9121 auto P =
9122 calculateByteProvider(Op: SDValue(N, 0), Index: i, Depth: 0, /*VectorIndex*/ std::nullopt,
9123 /*StartingIndex*/ i);
9124 if (!P)
9125 return SDValue();
9126
9127 if (P->isConstantZero()) {
9128 // It's OK for the N most significant bytes to be 0, we can just
9129 // zero-extend the load.
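      // Zero bytes are only accepted as one contiguous run at the most
      // significant end of the value; otherwise bail out.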
9130 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9131 return SDValue();
9132 continue;
9133 }
9134 assert(P->hasSrc() && "provenance should either be memory or zero");
9135 auto *L = cast<LoadSDNode>(Val: P->Src.value());
9136
9137 // All loads must share the same chain
9138 SDValue LChain = L->getChain();
9139 if (!Chain)
9140 Chain = LChain;
9141 else if (Chain != LChain)
9142 return SDValue();
9143
9144 // Loads must share the same base address
9145 BaseIndexOffset Ptr = BaseIndexOffset::match(N: L, DAG);
9146 int64_t ByteOffsetFromBase = 0;
9147
9148 // For vector loads, the expected load combine pattern will have an
9149 // ExtractElement for each index in the vector. While each of these
9150 // ExtractElements will be accessing the same base address as determined
9151 // by the load instruction, the actual bytes they interact with will differ
9152 // due to different ExtractElement indices. To accurately determine the
9153 // byte position of an ExtractElement, we offset the base load ptr with
9154 // the index multiplied by the byte size of each element in the vector.
9155 if (L->getMemoryVT().isVector()) {
9156 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9157 if (LoadWidthInBit % 8 != 0)
9158 return SDValue();
9159 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9160 Ptr.addToOffset(VectorOff: ByteOffsetFromVector);
9161 }
9162
9163 if (!Base)
9164 Base = Ptr;
9166 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9167 return SDValue();
9168
9169 // Calculate the offset of the current byte from the base address
9170 ByteOffsetFromBase += MemoryByteOffset(*P);
9171 ByteOffsets[i] = ByteOffsetFromBase;
9172
9173 // Remember the first byte load
9174 if (ByteOffsetFromBase < FirstOffset) {
9175 FirstByteProvider = P;
9176 FirstOffset = ByteOffsetFromBase;
9177 }
9178
9179 Loads.insert(Ptr: L);
9180 }
9181
9182 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9183 "memory, so there must be at least one load which produces the value");
9184 assert(Base && "Base address of the accessed memory location must be set");
9185 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9186
9187 bool NeedsZext = ZeroExtendedBytes > 0;
9188
9189 EVT MemVT =
9190 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: (ByteWidth - ZeroExtendedBytes) * 8);
9191
9192 if (!MemVT.isSimple())
9193 return SDValue();
9194
9195 // Before legalize we can introduce too wide illegal loads which will be later
9196 // split into legal sized loads. This enables us to combine i64 load by i8
9197 // patterns to a couple of i32 loads on 32 bit targets.
9198 if (LegalOperations &&
9199 !TLI.isOperationLegal(Op: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
9200 VT: MemVT))
9201 return SDValue();
9202
9203 // Check if the bytes of the OR we are looking at match with either big or
9204 // little endian value load
9205 std::optional<bool> IsBigEndian = isBigEndian(
9206 ByteOffsets: ArrayRef(ByteOffsets).drop_back(N: ZeroExtendedBytes), FirstOffset);
9207 if (!IsBigEndian)
9208 return SDValue();
9209
9210 assert(FirstByteProvider && "must be set");
9211
  // Ensure that the first byte is loaded from the zero offset of the first
  // load, so the combined value can be loaded from the first load's address.
9214 if (MemoryByteOffset(*FirstByteProvider) != 0)
9215 return SDValue();
9216 auto *FirstLoad = cast<LoadSDNode>(Val: FirstByteProvider->Src.value());
9217
9218 // The node we are looking at matches with the pattern, check if we can
9219 // replace it with a single (possibly zero-extended) load and bswap + shift if
9220 // needed.
9221
9222 // If the load needs byte swap check if the target supports it
9223 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9224
9225 // Before legalize we can introduce illegal bswaps which will be later
9226 // converted to an explicit bswap sequence. This way we end up with a single
9227 // load and byte shuffling instead of several loads and byte shuffling.
9228 // We do not introduce illegal bswaps when zero-extending as this tends to
9229 // introduce too many arithmetic instructions.
9230 if (NeedsBswap && (LegalOperations || NeedsZext) &&
9231 !TLI.isOperationLegal(Op: ISD::BSWAP, VT))
9232 return SDValue();
9233
9234 // If we need to bswap and zero extend, we have to insert a shift. Check that
9235 // it is legal.
9236 if (NeedsBswap && NeedsZext && LegalOperations &&
9237 !TLI.isOperationLegal(Op: ISD::SHL, VT))
9238 return SDValue();
9239
9240 // Check that a load of the wide type is both allowed and fast on the target
9241 unsigned Fast = 0;
9242 bool Allowed =
9243 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
9244 MMO: *FirstLoad->getMemOperand(), Fast: &Fast);
9245 if (!Allowed || !Fast)
9246 return SDValue();
9247
9248 SDValue NewLoad =
9249 DAG.getExtLoad(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, dl: SDLoc(N), VT,
9250 Chain, Ptr: FirstLoad->getBasePtr(),
9251 PtrInfo: FirstLoad->getPointerInfo(), MemVT, Alignment: FirstLoad->getAlign());
9252
9253 // Transfer chain users from old loads to the new load.
9254 for (LoadSDNode *L : Loads)
9255 DAG.makeEquivalentMemoryOrdering(OldLoad: L, NewMemOp: NewLoad);
9256
9257 if (!NeedsBswap)
9258 return NewLoad;
9259
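  // When the combined load is narrower than VT (i.e. zero-extended), shift the
  // loaded bytes up first so that the bswap places them correctly.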
9260 SDValue ShiftedLoad =
9261 NeedsZext
9262 ? DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: NewLoad,
9263 N2: DAG.getShiftAmountConstant(Val: ZeroExtendedBytes * 8, VT,
9264 DL: SDLoc(N), LegalTypes: LegalOperations))
9265 : NewLoad;
9266 return DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: ShiftedLoad);
9267}
9268
9269// If the target has andn, bsl, or a similar bit-select instruction,
9270// we want to unfold masked merge, with canonical pattern of:
9271// | A | |B|
9272// ((x ^ y) & m) ^ y
9273// | D |
9274// Into:
9275// (x & m) | (y & ~m)
9276// If y is a constant, m is not a 'not', and the 'andn' does not work with
9277// immediates, we unfold into a different pattern:
9278// ~(~x & m) & (m | y)
9279// If x is a constant, m is a 'not', and the 'andn' does not work with
9280// immediates, we unfold into a different pattern:
9281// (x | ~m) & ~(~m & ~y)
9282// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9283// the very least that breaks andnpd / andnps patterns, and because those
9284// patterns are simplified in IR and shouldn't be created in the DAG
9285SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9286 assert(N->getOpcode() == ISD::XOR);
9287
9288 // Don't touch 'not' (i.e. where y = -1).
9289 if (isAllOnesOrAllOnesSplat(V: N->getOperand(Num: 1)))
9290 return SDValue();
9291
9292 EVT VT = N->getValueType(ResNo: 0);
9293
9294 // There are 3 commutable operators in the pattern,
9295 // so we have to deal with 8 possible variants of the basic pattern.
9296 SDValue X, Y, M;
9297 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9298 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9299 return false;
9300 SDValue Xor = And.getOperand(i: XorIdx);
9301 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9302 return false;
9303 SDValue Xor0 = Xor.getOperand(i: 0);
9304 SDValue Xor1 = Xor.getOperand(i: 1);
9305 // Don't touch 'not' (i.e. where y = -1).
9306 if (isAllOnesOrAllOnesSplat(V: Xor1))
9307 return false;
9308 if (Other == Xor0)
9309 std::swap(a&: Xor0, b&: Xor1);
9310 if (Other != Xor1)
9311 return false;
9312 X = Xor0;
9313 Y = Xor1;
9314 M = And.getOperand(i: XorIdx ? 0 : 1);
9315 return true;
9316 };
9317
9318 SDValue N0 = N->getOperand(Num: 0);
9319 SDValue N1 = N->getOperand(Num: 1);
9320 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9321 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9322 return SDValue();
9323
9324 // Don't do anything if the mask is constant. This should not be reachable.
9325 // InstCombine should have already unfolded this pattern, and DAGCombiner
  // probably shouldn't produce it either.
9327 if (isa<ConstantSDNode>(Val: M.getNode()))
9328 return SDValue();
9329
9330 // We can transform if the target has AndNot
9331 if (!TLI.hasAndNot(X: M))
9332 return SDValue();
9333
9334 SDLoc DL(N);
9335
9336 // If Y is a constant, check that 'andn' works with immediates. Unless M is
9337 // a bitwise not that would already allow ANDN to be used.
9338 if (!TLI.hasAndNot(X: Y) && !isBitwiseNot(V: M)) {
9339 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9340 // If not, we need to do a bit more work to make sure andn is still used.
9341 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
9342 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: M);
9343 SDValue NotLHS = DAG.getNOT(DL, Val: LHS, VT);
9344 SDValue RHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: M, N2: Y);
9345 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotLHS, N2: RHS);
9346 }
9347
9348 // If X is a constant and M is a bitwise not, check that 'andn' works with
9349 // immediates.
9350 if (!TLI.hasAndNot(X) && isBitwiseNot(V: M)) {
9351 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9352 // If not, we need to do a bit more work to make sure andn is still used.
9353 SDValue NotM = M.getOperand(i: 0);
9354 SDValue LHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: NotM);
9355 SDValue NotY = DAG.getNOT(DL, Val: Y, VT);
9356 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotM, N2: NotY);
9357 SDValue NotRHS = DAG.getNOT(DL, Val: RHS, VT);
9358 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: LHS, N2: NotRHS);
9359 }
9360
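  // Default case: unfold into (x & m) | (y & ~m).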
9361 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: M);
9362 SDValue NotM = DAG.getNOT(DL, Val: M, VT);
9363 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Y, N2: NotM);
9364
9365 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHS, N2: RHS);
9366}
9367
9368SDValue DAGCombiner::visitXOR(SDNode *N) {
9369 SDValue N0 = N->getOperand(Num: 0);
9370 SDValue N1 = N->getOperand(Num: 1);
9371 EVT VT = N0.getValueType();
9372 SDLoc DL(N);
9373
9374 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9375 if (N0.isUndef() && N1.isUndef())
9376 return DAG.getConstant(Val: 0, DL, VT);
9377
9378 // fold (xor x, undef) -> undef
9379 if (N0.isUndef())
9380 return N0;
9381 if (N1.isUndef())
9382 return N1;
9383
9384 // fold (xor c1, c2) -> c1^c2
9385 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::XOR, DL, VT, Ops: {N0, N1}))
9386 return C;
9387
9388 // canonicalize constant to RHS
9389 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
9390 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
9391 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
9392
9393 // fold vector ops
9394 if (VT.isVector()) {
9395 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9396 return FoldedVOp;
9397
9398 // fold (xor x, 0) -> x, vector edition
9399 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
9400 return N0;
9401 }
9402
9403 // fold (xor x, 0) -> x
9404 if (isNullConstant(V: N1))
9405 return N0;
9406
9407 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9408 return NewSel;
9409
9410 // reassociate xor
9411 if (SDValue RXOR = reassociateOps(Opc: ISD::XOR, DL, N0, N1, Flags: N->getFlags()))
9412 return RXOR;
9413
9414 // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9415 if (SDValue SD =
9416 reassociateReduction(RedOpc: ISD::VECREDUCE_XOR, Opc: ISD::XOR, DL, VT, N0, N1))
9417 return SD;
9418
9419 // fold (a^b) -> (a|b) iff a and b share no bits.
9420 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
9421 DAG.haveNoCommonBitsSet(A: N0, B: N1))
9422 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1);
9423
9424 // look for 'add-like' folds:
9425 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
9426 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
9427 isMinSignedConstant(V: N1))
9428 if (SDValue Combined = visitADDLike(N))
9429 return Combined;
9430
9431 // fold !(x cc y) -> (x !cc y)
9432 unsigned N0Opcode = N0.getOpcode();
9433 SDValue LHS, RHS, CC;
9434 if (TLI.isConstTrueVal(N: N1) &&
9435 isSetCCEquivalent(N: N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
9436 ISD::CondCode NotCC = ISD::getSetCCInverse(Operation: cast<CondCodeSDNode>(Val&: CC)->get(),
9437 Type: LHS.getValueType());
9438 if (!LegalOperations ||
9439 TLI.isCondCodeLegal(CC: NotCC, VT: LHS.getSimpleValueType())) {
9440 switch (N0Opcode) {
9441 default:
9442 llvm_unreachable("Unhandled SetCC Equivalent!");
9443 case ISD::SETCC:
9444 return DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC);
9445 case ISD::SELECT_CC:
9446 return DAG.getSelectCC(DL: SDLoc(N0), LHS, RHS, True: N0.getOperand(i: 2),
9447 False: N0.getOperand(i: 3), Cond: NotCC);
9448 case ISD::STRICT_FSETCC:
9449 case ISD::STRICT_FSETCCS: {
9450 if (N0.hasOneUse()) {
9451 // FIXME Can we handle multiple uses? Could we token factor the chain
9452 // results from the new/old setcc?
9453 SDValue SetCC =
9454 DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC,
9455 Chain: N0.getOperand(i: 0), IsSignaling: N0Opcode == ISD::STRICT_FSETCCS);
9456 CombineTo(N, Res: SetCC);
9457 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: SetCC.getValue(R: 1));
9458 recursivelyDeleteUnusedNodes(N: N0.getNode());
9459 return SDValue(N, 0); // Return N so it doesn't get rechecked!
9460 }
9461 break;
9462 }
9463 }
9464 }
9465 }
9466
9467 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9468 if (isOneConstant(V: N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9469 isSetCCEquivalent(N: N0.getOperand(i: 0), LHS, RHS, CC)){
9470 SDValue V = N0.getOperand(i: 0);
9471 SDLoc DL0(N0);
9472 V = DAG.getNode(Opcode: ISD::XOR, DL: DL0, VT: V.getValueType(), N1: V,
9473 N2: DAG.getConstant(Val: 1, DL: DL0, VT: V.getValueType()));
9474 AddToWorklist(N: V.getNode());
9475 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: V);
9476 }
9477
9478 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9479 if (isOneConstant(V: N1) && VT == MVT::i1 && N0.hasOneUse() &&
9480 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9481 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9482 if (isOneUseSetCC(N: N01) || isOneUseSetCC(N: N00)) {
9483 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9484 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9485 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9486 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9487 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9488 }
9489 }
9490 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9491 if (isAllOnesConstant(V: N1) && N0.hasOneUse() &&
9492 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9493 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9494 if (isa<ConstantSDNode>(Val: N01) || isa<ConstantSDNode>(Val: N00)) {
9495 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9496 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9497 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9498 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9499 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9500 }
9501 }
9502
9503 // fold (not (neg x)) -> (add X, -1)
9504 // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9505 // Y is a constant or the subtract has a single use.
9506 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::SUB &&
9507 isNullConstant(V: N0.getOperand(i: 0))) {
9508 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
9509 N2: DAG.getAllOnesConstant(DL, VT));
9510 }
9511
9512 // fold (not (add X, -1)) -> (neg X)
9513 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::ADD &&
9514 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1))) {
9515 return DAG.getNegative(Val: N0.getOperand(i: 0), DL, VT);
9516 }
9517
9518 // fold (xor (and x, y), y) -> (and (not x), y)
9519 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(Num: 1) == N1) {
9520 SDValue X = N0.getOperand(i: 0);
9521 SDValue NotX = DAG.getNOT(DL: SDLoc(X), Val: X, VT);
9522 AddToWorklist(N: NotX.getNode());
9523 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: N1);
9524 }
9525
9526 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
9527 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT)) {
9528 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
9529 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
9530 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
9531 SDValue A0 = A.getOperand(i: 0), A1 = A.getOperand(i: 1);
9532 SDValue S0 = S.getOperand(i: 0);
9533 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
9534 if (ConstantSDNode *C = isConstOrConstSplat(N: S.getOperand(i: 1)))
9535 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
9536 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: S0);
9537 }
9538 }
9539
9540 // fold (xor x, x) -> 0
9541 if (N0 == N1)
9542 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
9543
9544 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
9545 // Here is a concrete example of this equivalence:
9546 // i16 x == 14
9547 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
9548 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
9549 //
9550 // =>
9551 //
9552 // i16 ~1 == 0b1111111111111110
9553 // i16 rol(~1, 14) == 0b1011111111111111
9554 //
9555 // Some additional tips to help conceptualize this transform:
9556 // - Try to see the operation as placing a single zero in a value of all ones.
9557 // - There exists no value for x which would allow the result to contain zero.
9558 // - Values of x larger than the bitwidth are undefined and do not require a
9559 // consistent result.
9560 // - Pushing the zero left requires shifting one bits in from the right.
9561 // A rotate left of ~1 is a nice way of achieving the desired result.
9562 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
9563 isAllOnesConstant(V: N1) && isOneConstant(V: N0.getOperand(i: 0))) {
9564 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: DAG.getConstant(Val: ~1, DL, VT),
9565 N2: N0.getOperand(i: 1));
9566 }
9567
9568 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
9569 if (N0Opcode == N1.getOpcode())
9570 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
9571 return V;
9572
9573 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
9574 return R;
9575 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
9576 return R;
9577 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
9578 return R;
9579
9580 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
9581 if (SDValue MM = unfoldMaskedMerge(N))
9582 return MM;
9583
9584 // Simplify the expression using non-local knowledge.
9585 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9586 return SDValue(N, 0);
9587
9588 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9589 return Combined;
9590
9591 return SDValue();
9592}
9593
9594/// If we have a shift-by-constant of a bitwise logic op that itself has a
9595/// shift-by-constant operand with identical opcode, we may be able to convert
9596/// that into 2 independent shifts followed by the logic op. This is a
9597/// throughput improvement.
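/// For example:
///   (shl (xor (shl x, 2), y), 3) --> (xor (shl x, 5), (shl y, 3))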
9598static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
9599 // Match a one-use bitwise logic op.
9600 SDValue LogicOp = Shift->getOperand(Num: 0);
9601 if (!LogicOp.hasOneUse())
9602 return SDValue();
9603
9604 unsigned LogicOpcode = LogicOp.getOpcode();
9605 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
9606 LogicOpcode != ISD::XOR)
9607 return SDValue();
9608
9609 // Find a matching one-use shift by constant.
9610 unsigned ShiftOpcode = Shift->getOpcode();
9611 SDValue C1 = Shift->getOperand(Num: 1);
9612 ConstantSDNode *C1Node = isConstOrConstSplat(N: C1);
9613 assert(C1Node && "Expected a shift with constant operand");
9614 const APInt &C1Val = C1Node->getAPIntValue();
9615 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
9616 const APInt *&ShiftAmtVal) {
9617 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
9618 return false;
9619
9620 ConstantSDNode *ShiftCNode = isConstOrConstSplat(N: V.getOperand(i: 1));
9621 if (!ShiftCNode)
9622 return false;
9623
9624 // Capture the shifted operand and shift amount value.
9625 ShiftOp = V.getOperand(i: 0);
9626 ShiftAmtVal = &ShiftCNode->getAPIntValue();
9627
9628 // Shift amount types do not have to match their operand type, so check that
9629 // the constants are the same width.
9630 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
9631 return false;
9632
9633 // The fold is not valid if the sum of the shift values exceeds bitwidth.
9634 if ((*ShiftAmtVal + C1Val).uge(RHS: V.getScalarValueSizeInBits()))
9635 return false;
9636
9637 return true;
9638 };
9639
9640 // Logic ops are commutative, so check each operand for a match.
9641 SDValue X, Y;
9642 const APInt *C0Val;
9643 if (matchFirstShift(LogicOp.getOperand(i: 0), X, C0Val))
9644 Y = LogicOp.getOperand(i: 1);
9645 else if (matchFirstShift(LogicOp.getOperand(i: 1), X, C0Val))
9646 Y = LogicOp.getOperand(i: 0);
9647 else
9648 return SDValue();
9649
9650 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
9651 SDLoc DL(Shift);
9652 EVT VT = Shift->getValueType(ResNo: 0);
9653 EVT ShiftAmtVT = Shift->getOperand(Num: 1).getValueType();
9654 SDValue ShiftSumC = DAG.getConstant(Val: *C0Val + C1Val, DL, VT: ShiftAmtVT);
9655 SDValue NewShift1 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: X, N2: ShiftSumC);
9656 SDValue NewShift2 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: Y, N2: C1);
9657 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift1, N2: NewShift2);
9658}
9659
9660/// Handle transforms common to the three shifts, when the shift amount is a
9661/// constant.
9662/// We are looking for: (shift being one of shl/sra/srl)
9663/// shift (binop X, C0), C1
9664/// And want to transform into:
9665/// binop (shift X, C1), (shift C0, C1)
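/// For example: (shl (or X, 5), 2) --> (or (shl X, 2), 20)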
9666SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
9667 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
9668
9669 // Do not turn a 'not' into a regular xor.
9670 if (isBitwiseNot(V: N->getOperand(Num: 0)))
9671 return SDValue();
9672
9673 // The inner binop must be one-use, since we want to replace it.
9674 SDValue LHS = N->getOperand(Num: 0);
9675 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
9676 return SDValue();
9677
9678 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
9679 if (SDValue R = combineShiftOfShiftedLogic(Shift: N, DAG))
9680 return R;
9681
9682 // We want to pull some binops through shifts, so that we have (and (shift))
9683 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
9684 // thing happens with address calculations, so it's important to canonicalize
9685 // it.
9686 switch (LHS.getOpcode()) {
9687 default:
9688 return SDValue();
9689 case ISD::OR:
9690 case ISD::XOR:
9691 case ISD::AND:
9692 break;
9693 case ISD::ADD:
9694 if (N->getOpcode() != ISD::SHL)
9695 return SDValue(); // only shl(add) not sr[al](add).
9696 break;
9697 }
9698
9699 // FIXME: disable this unless the input to the binop is a shift by a constant
  // or is copy/select. Enable this in other cases once we figure out when it
  // is exactly profitable.
9702 SDValue BinOpLHSVal = LHS.getOperand(i: 0);
9703 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9704 BinOpLHSVal.getOpcode() == ISD::SRA ||
9705 BinOpLHSVal.getOpcode() == ISD::SRL) &&
9706 isa<ConstantSDNode>(Val: BinOpLHSVal.getOperand(i: 1));
9707 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9708 BinOpLHSVal.getOpcode() == ISD::SELECT;
9709
9710 if (!IsShiftByConstant && !IsCopyOrSelect)
9711 return SDValue();
9712
9713 if (IsCopyOrSelect && N->hasOneUse())
9714 return SDValue();
9715
9716 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9717 SDLoc DL(N);
9718 EVT VT = N->getValueType(ResNo: 0);
9719 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9720 Opcode: N->getOpcode(), DL, VT, Ops: {LHS.getOperand(i: 1), N->getOperand(Num: 1)})) {
9721 SDValue NewShift = DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: LHS.getOperand(i: 0),
9722 N2: N->getOperand(Num: 1));
9723 return DAG.getNode(Opcode: LHS.getOpcode(), DL, VT, N1: NewShift, N2: NewRHS);
9724 }
9725
9726 return SDValue();
9727}
9728
9729SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9730 assert(N->getOpcode() == ISD::TRUNCATE);
9731 assert(N->getOperand(0).getOpcode() == ISD::AND);
9732
9733 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
9734 EVT TruncVT = N->getValueType(ResNo: 0);
9735 if (N->hasOneUse() && N->getOperand(Num: 0).hasOneUse() &&
9736 TLI.isTypeDesirableForOp(ISD::AND, VT: TruncVT)) {
9737 SDValue N01 = N->getOperand(Num: 0).getOperand(i: 1);
9738 if (isConstantOrConstantVector(N: N01, /* NoOpaques */ true)) {
9739 SDLoc DL(N);
9740 SDValue N00 = N->getOperand(Num: 0).getOperand(i: 0);
9741 SDValue Trunc00 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N00);
9742 SDValue Trunc01 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N01);
9743 AddToWorklist(N: Trunc00.getNode());
9744 AddToWorklist(N: Trunc01.getNode());
9745 return DAG.getNode(Opcode: ISD::AND, DL, VT: TruncVT, N1: Trunc00, N2: Trunc01);
9746 }
9747 }
9748
9749 return SDValue();
9750}
9751
9752SDValue DAGCombiner::visitRotate(SDNode *N) {
9753 SDLoc dl(N);
9754 SDValue N0 = N->getOperand(Num: 0);
9755 SDValue N1 = N->getOperand(Num: 1);
9756 EVT VT = N->getValueType(ResNo: 0);
9757 unsigned Bitsize = VT.getScalarSizeInBits();
9758
9759 // fold (rot x, 0) -> x
9760 if (isNullOrNullSplat(V: N1))
9761 return N0;
9762
9763 // fold (rot x, c) -> x iff (c % BitSize) == 0
9764 if (isPowerOf2_32(Value: Bitsize) && Bitsize > 1) {
9765 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9766 if (DAG.MaskedValueIsZero(Op: N1, Mask: ModuloMask))
9767 return N0;
9768 }
9769
9770 // fold (rot x, c) -> (rot x, c % BitSize)
9771 bool OutOfRange = false;
9772 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9773 OutOfRange |= C->getAPIntValue().uge(RHS: Bitsize);
9774 return true;
9775 };
9776 if (ISD::matchUnaryPredicate(Op: N1, Match: MatchOutOfRange) && OutOfRange) {
9777 EVT AmtVT = N1.getValueType();
9778 SDValue Bits = DAG.getConstant(Val: Bitsize, DL: dl, VT: AmtVT);
9779 if (SDValue Amt =
9780 DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: AmtVT, Ops: {N1, Bits}))
9781 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: Amt);
9782 }
9783
9784 // rot i16 X, 8 --> bswap X
9785 auto *RotAmtC = isConstOrConstSplat(N: N1);
9786 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9787 VT.getScalarSizeInBits() == 16 && hasOperation(Opcode: ISD::BSWAP, VT))
9788 return DAG.getNode(Opcode: ISD::BSWAP, DL: dl, VT, Operand: N0);
9789
9790 // Simplify the operands using demanded-bits information.
9791 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9792 return SDValue(N, 0);
9793
9794 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9795 if (N1.getOpcode() == ISD::TRUNCATE &&
9796 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
9797 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
9798 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: NewOp1);
9799 }
9800
9801 unsigned NextOp = N0.getOpcode();
9802
9803 // fold (rot* (rot* x, c2), c1)
9804 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
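  // e.g. on i32: (rotl (rotr x, 5), 3) --> (rotl x, (3 - 5 + 32) % 32)
  //                                     --> (rotl x, 30)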
9805 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9806 SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N: N1);
9807 SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1));
9808 if (C1 && C2 && C1->getValueType(ResNo: 0) == C2->getValueType(ResNo: 0)) {
9809 EVT ShiftVT = C1->getValueType(ResNo: 0);
9810 bool SameSide = (N->getOpcode() == NextOp);
9811 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9812 SDValue BitsizeC = DAG.getConstant(Val: Bitsize, DL: dl, VT: ShiftVT);
9813 SDValue Norm1 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
9814 Ops: {N1, BitsizeC});
9815 SDValue Norm2 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
9816 Ops: {N0.getOperand(i: 1), BitsizeC});
9817 if (Norm1 && Norm2)
9818 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9819 Opcode: CombineOp, DL: dl, VT: ShiftVT, Ops: {Norm1, Norm2})) {
9820 CombinedShift = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL: dl, VT: ShiftVT,
9821 Ops: {CombinedShift, BitsizeC});
9822 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9823 Opcode: ISD::UREM, DL: dl, VT: ShiftVT, Ops: {CombinedShift, BitsizeC});
9824 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0->getOperand(Num: 0),
9825 N2: CombinedShiftNorm);
9826 }
9827 }
9828 }
9829 return SDValue();
9830}
9831
9832SDValue DAGCombiner::visitSHL(SDNode *N) {
9833 SDValue N0 = N->getOperand(Num: 0);
9834 SDValue N1 = N->getOperand(Num: 1);
9835 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
9836 return V;
9837
9838 EVT VT = N0.getValueType();
9839 EVT ShiftVT = N1.getValueType();
9840 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9841
9842 // fold (shl c1, c2) -> c1<<c2
9843 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N), VT, Ops: {N0, N1}))
9844 return C;
9845
9846 // fold vector ops
9847 if (VT.isVector()) {
9848 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
9849 return FoldedVOp;
9850
9851 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(Val&: N1);
9852 // If setcc produces an all-ones true value then:
9853 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9854 if (N1CV && N1CV->isConstant()) {
9855 if (N0.getOpcode() == ISD::AND) {
9856 SDValue N00 = N0->getOperand(Num: 0);
9857 SDValue N01 = N0->getOperand(Num: 1);
9858 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(Val&: N01);
9859
9860 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9861 TLI.getBooleanContents(Type: N00.getOperand(i: 0).getValueType()) ==
9862 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9863 if (SDValue C =
9864 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N), VT, Ops: {N01, N1}))
9865 return DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N), VT, N1: N00, N2: C);
9866 }
9867 }
9868 }
9869 }
9870
9871 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9872 return NewSel;
9873
9874 // if (shl x, c) is known to be zero, return 0
9875 if (DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
9876 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
9877
9878 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9879 if (N1.getOpcode() == ISD::TRUNCATE &&
9880 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
9881 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
9882 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2: NewOp1);
9883 }
9884
9885 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
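// e.g. for i32: (shl (shl x, 3), 5) -> (shl x, 8), while
//               (shl (shl x, 20), 16) -> 0 because 36 >= 32 shifts out all bits.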
9886 if (N0.getOpcode() == ISD::SHL) {
9887 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9888 ConstantSDNode *RHS) {
9889 APInt c1 = LHS->getAPIntValue();
9890 APInt c2 = RHS->getAPIntValue();
9891 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9892 return (c1 + c2).uge(RHS: OpSizeInBits);
9893 };
9894 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
9895 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
9896
9897 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9898 ConstantSDNode *RHS) {
9899 APInt c1 = LHS->getAPIntValue();
9900 APInt c2 = RHS->getAPIntValue();
9901 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9902 return (c1 + c2).ult(RHS: OpSizeInBits);
9903 };
9904 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
9905 SDLoc DL(N);
9906 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
9907 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
9908 }
9909 }
9910
9911 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9912 // For this to be valid, the second form must not preserve any of the bits
9913 // that are shifted out by the inner shift in the first form. This means
9914 // the outer shift size must be >= the number of bits added by the ext.
9915 // As a corollary, we don't care what kind of ext it is.
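// e.g. (shl (zext (shl i16 x, 2) to i32), 16) -> (shl (zext i16 x to i32), 18):
// the outer shift by 16 already discards the 16 bits added by the extension.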
9916 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9917 N0.getOpcode() == ISD::ANY_EXTEND ||
9918 N0.getOpcode() == ISD::SIGN_EXTEND) &&
9919 N0.getOperand(i: 0).getOpcode() == ISD::SHL) {
9920 SDValue N0Op0 = N0.getOperand(i: 0);
9921 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
9922 EVT InnerVT = N0Op0.getValueType();
9923 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9924
9925 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9926 ConstantSDNode *RHS) {
9927 APInt c1 = LHS->getAPIntValue();
9928 APInt c2 = RHS->getAPIntValue();
9929 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9930 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
9931 (c1 + c2).uge(RHS: OpSizeInBits);
9932 };
9933 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchOutOfRange,
9934 /*AllowUndefs*/ false,
9935 /*AllowTypeMismatch*/ true))
9936 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
9937
9938 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9939 ConstantSDNode *RHS) {
9940 APInt c1 = LHS->getAPIntValue();
9941 APInt c2 = RHS->getAPIntValue();
9942 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9943 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
9944 (c1 + c2).ult(RHS: OpSizeInBits);
9945 };
9946 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchInRange,
9947 /*AllowUndefs*/ false,
9948 /*AllowTypeMismatch*/ true)) {
9949 SDLoc DL(N);
9950 SDValue Ext = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0Op0.getOperand(i: 0));
9951 SDValue Sum = DAG.getZExtOrTrunc(Op: InnerShiftAmt, DL, VT: ShiftVT);
9952 Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1: Sum, N2: N1);
9953 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Ext, N2: Sum);
9954 }
9955 }
9956
9957 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9958 // Only fold this if the inner zext has no other uses to avoid increasing
9959 // the total number of instructions.
9960 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9961 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
9962 SDValue N0Op0 = N0.getOperand(i: 0);
9963 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
9964
9965 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9966 APInt c1 = LHS->getAPIntValue();
9967 APInt c2 = RHS->getAPIntValue();
9968 zeroExtendToMatch(LHS&: c1, RHS&: c2);
9969 return c1.ult(RHS: VT.getScalarSizeInBits()) && (c1 == c2);
9970 };
9971 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchEqual,
9972 /*AllowUndefs*/ false,
9973 /*AllowTypeMismatch*/ true)) {
9974 SDLoc DL(N);
9975 EVT InnerShiftAmtVT = N0Op0.getOperand(i: 1).getValueType();
9976 SDValue NewSHL = DAG.getZExtOrTrunc(Op: N1, DL, VT: InnerShiftAmtVT);
9977 NewSHL = DAG.getNode(Opcode: ISD::SHL, DL, VT: N0Op0.getValueType(), N1: N0Op0, N2: NewSHL);
9978 AddToWorklist(N: NewSHL.getNode());
9979 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N0), VT, Operand: NewSHL);
9980 }
9981 }
9982
9983 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
9984 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9985 ConstantSDNode *RHS) {
9986 const APInt &LHSC = LHS->getAPIntValue();
9987 const APInt &RHSC = RHS->getAPIntValue();
9988 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
9989 LHSC.getZExtValue() <= RHSC.getZExtValue();
9990 };
9991
9992 SDLoc DL(N);
9993
9994 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
9995 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
9996 if (N0->getFlags().hasExact()) {
9997 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
9998 /*AllowUndefs*/ false,
9999 /*AllowTypeMismatch*/ true)) {
10000 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10001 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10002 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10003 }
10004 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10005 /*AllowUndefs*/ false,
10006 /*AllowTypeMismatch*/ true)) {
10007 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10008 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10009 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10010 }
10011 }
10012
10013 // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
10014 //                               (and (srl x, (sub c1, c2)), MASK)
10015 // Only fold this if the inner shift has no other uses -- if it does,
10016 // folding this will increase the total number of instructions.
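// e.g. for i32: (shl (srl x, 3), 5) -> (and (shl x, 2), 0xFFFFFFE0).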
10017 if (N0.getOpcode() == ISD::SRL &&
10018 (N0.getOperand(i: 1) == N1 || N0.hasOneUse()) &&
10019 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10020 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10021 /*AllowUndefs*/ false,
10022 /*AllowTypeMismatch*/ true)) {
10023 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10024 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10025 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10026 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N01);
10027 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: Diff);
10028 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10029 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10030 }
10031 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10032 /*AllowUndefs*/ false,
10033 /*AllowTypeMismatch*/ true)) {
10034 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10035 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10036 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10037 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N1);
10038 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10039 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10040 }
10041 }
10042 }
10043
10044 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10045 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(i: 1) &&
10046 isConstantOrConstantVector(N: N1, /* No Opaques */ NoOpaques: true)) {
10047 SDLoc DL(N);
10048 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10049 SDValue HiBitsMask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllBits, N2: N1);
10050 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: HiBitsMask);
10051 }
10052
10053 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10054 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10055 // Variant of version done on multiply, except mul by a power of 2 is turned
10056 // into a shift.
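// e.g. (shl (add x, 3), 4) -> (add (shl x, 4), 48), exposing the shifted
// constant to further folding.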
10057 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10058 N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
10059 SDValue N01 = N0.getOperand(i: 1);
10060 if (SDValue Shl1 =
10061 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1})) {
10062 SDValue Shl0 = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
10063 AddToWorklist(N: Shl0.getNode());
10064 SDNodeFlags Flags;
10065 // Preserve the disjoint flag for Or.
10066 if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10067 Flags.setDisjoint(true);
10068 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT, N1: Shl0, N2: Shl1, Flags);
10069 }
10070 }
10071
10072 // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10073 // TODO: Add zext/add_nuw variant with suitable test coverage
10074 // TODO: Should we limit this with isLegalAddImmediate?
10075 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10076 N0.getOperand(i: 0).getOpcode() == ISD::ADD &&
10077 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap() && N0->hasOneUse() &&
10078 N0.getOperand(i: 0)->hasOneUse() &&
10079 TLI.isDesirableToCommuteWithShift(N, Level)) {
10080 SDValue Add = N0.getOperand(i: 0);
10081 SDLoc DL(N0);
10082 if (SDValue ExtC = DAG.FoldConstantArithmetic(Opcode: N0.getOpcode(), DL, VT,
10083 Ops: {Add.getOperand(i: 1)})) {
10084 if (SDValue ShlC =
10085 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {ExtC, N1})) {
10086 SDValue ExtX = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: Add.getOperand(i: 0));
10087 SDValue ShlX = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ExtX, N2: N1);
10088 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ShlX, N2: ShlC);
10089 }
10090 }
10091 }
10092
10093 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10094 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10095 SDValue N01 = N0.getOperand(i: 1);
10096 if (SDValue Shl =
10097 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1}))
10098 return DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0), N2: Shl);
10099 }
10100
10101 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10102 if (N1C && !N1C->isOpaque())
10103 if (SDValue NewSHL = visitShiftByConstant(N))
10104 return NewSHL;
10105
10106 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10107 return SDValue(N, 0);
10108
10109 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10110 if (N0.getOpcode() == ISD::VSCALE && N1C) {
10111 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10112 const APInt &C1 = N1C->getAPIntValue();
10113 return DAG.getVScale(DL: SDLoc(N), VT, MulImm: C0 << C1);
10114 }
10115
10116 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10117 APInt ShlVal;
10118 if (N0.getOpcode() == ISD::STEP_VECTOR &&
10119 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ShlVal)) {
10120 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10121 if (ShlVal.ult(RHS: C0.getBitWidth())) {
10122 APInt NewStep = C0 << ShlVal;
10123 return DAG.getStepVector(DL: SDLoc(N), ResVT: VT, StepVal: NewStep);
10124 }
10125 }
10126
10127 return SDValue();
10128}
10129
10130// Transform a right shift of a multiply into a multiply-high.
10131// Examples:
10132 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
10133 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
10134static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
10135 const TargetLowering &TLI) {
10136 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10137 "SRL or SRA node is required here!");
10138
10139 // Check the shift amount. Proceed with the transformation if the shift
10140 // amount is constant.
10141 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N: N->getOperand(Num: 1));
10142 if (!ShiftAmtSrc)
10143 return SDValue();
10144
10145 SDLoc DL(N);
10146
10147 // The operation feeding into the shift must be a multiply.
10148 SDValue ShiftOperand = N->getOperand(Num: 0);
10149 if (ShiftOperand.getOpcode() != ISD::MUL)
10150 return SDValue();
10151
10152 // Both operands must be equivalent extend nodes.
10153 SDValue LeftOp = ShiftOperand.getOperand(i: 0);
10154 SDValue RightOp = ShiftOperand.getOperand(i: 1);
10155
10156 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10157 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10158
10159 if (!IsSignExt && !IsZeroExt)
10160 return SDValue();
10161
10162 EVT NarrowVT = LeftOp.getOperand(i: 0).getValueType();
10163 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10164
10165 // return true if U may use the lower bits of its operands
10166 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10167 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10168 return true;
10169 }
10170 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(N: U->getOperand(Num: 1));
10171 if (!UShiftAmtSrc) {
10172 return true;
10173 }
10174 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10175 return UShiftAmt < NarrowVTSize;
10176 };
10177
10178 // If the lower part of the MUL is also used and MUL_LOHI is supported,
10179 // do not introduce the MULH in favor of MUL_LOHI.
10180 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10181 if (!ShiftOperand.hasOneUse() &&
10182 TLI.isOperationLegalOrCustom(Op: MulLoHiOp, VT: NarrowVT) &&
10183 llvm::any_of(Range: ShiftOperand->uses(), P: UserOfLowerBits)) {
10184 return SDValue();
10185 }
10186
10187 SDValue MulhRightOp;
10188 if (ConstantSDNode *Constant = isConstOrConstSplat(N: RightOp)) {
10189 unsigned ActiveBits = IsSignExt
10190 ? Constant->getAPIntValue().getSignificantBits()
10191 : Constant->getAPIntValue().getActiveBits();
10192 if (ActiveBits > NarrowVTSize)
10193 return SDValue();
10194 MulhRightOp = DAG.getConstant(
10195 Val: Constant->getAPIntValue().trunc(width: NarrowVT.getScalarSizeInBits()), DL,
10196 VT: NarrowVT);
10197 } else {
10198 if (LeftOp.getOpcode() != RightOp.getOpcode())
10199 return SDValue();
10200 // Check that the two extend nodes are the same type.
10201 if (NarrowVT != RightOp.getOperand(i: 0).getValueType())
10202 return SDValue();
10203 MulhRightOp = RightOp.getOperand(i: 0);
10204 }
10205
10206 EVT WideVT = LeftOp.getValueType();
10207 // Proceed with the transformation if the wide types match.
10208 assert((WideVT == RightOp.getValueType()) &&
10209 "Cannot have a multiply node with two different operand types.");
10210
10211 // Proceed with the transformation if the wide type is twice as large
10212 // as the narrow type.
10213 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10214 return SDValue();
10215
10216 // Check the shift amount with the narrow type size.
10217 // Proceed with the transformation if the shift amount is the width
10218 // of the narrow type.
10219 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10220 if (ShiftAmt != NarrowVTSize)
10221 return SDValue();
10222
10223 // If the operation feeding into the MUL is a sign extend (sext),
10224 // we use mulhs. Otherwise, zero extends (zext) use mulhu.
10225 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10226
10227 // Combine to mulh if mulh is legal/custom for the narrow type on the target
10228 // or if it is a vector type then we could transform to an acceptable type and
10229 // rely on legalization to split/combine the result.
10230 if (NarrowVT.isVector()) {
10231 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: NarrowVT);
10232 if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10233 !TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: TransformVT))
10234 return SDValue();
10235 } else {
10236 if (!TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: NarrowVT))
10237 return SDValue();
10238 }
10239
10240 SDValue Result =
10241 DAG.getNode(Opcode: MulhOpcode, DL, VT: NarrowVT, N1: LeftOp.getOperand(i: 0), N2: MulhRightOp);
10242 bool IsSigned = N->getOpcode() == ISD::SRA;
10243 return DAG.getExtOrTrunc(IsSigned, Op: Result, DL, VT: WideVT);
10244}
10245
10246// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10247 // This helper accepts SDNodes with opcode ISD::BSWAP or ISD::BITREVERSE.
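// e.g. (bswap (xor (bswap x), y)) -> (xor x, (bswap y)); if both logic operands
// are already byte-swapped, the swaps cancel and we emit (xor x, y) directly.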
10248static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10249 unsigned Opcode = N->getOpcode();
10250 if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10251 return SDValue();
10252
10253 SDValue N0 = N->getOperand(Num: 0);
10254 EVT VT = N->getValueType(ResNo: 0);
10255 SDLoc DL(N);
10256 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && N0.hasOneUse()) {
10257 SDValue OldLHS = N0.getOperand(i: 0);
10258 SDValue OldRHS = N0.getOperand(i: 1);
10259
10260 // If both operands are bswap/bitreverse, the multiuse check can be skipped;
10261 // otherwise ensure both the logic_op and the bswap/bitreverse(x) have one use.
10262 if (OldLHS.getOpcode() == Opcode && OldRHS.getOpcode() == Opcode) {
10263 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: OldLHS.getOperand(i: 0),
10264 N2: OldRHS.getOperand(i: 0));
10265 }
10266
10267 if (OldLHS.getOpcode() == Opcode && OldLHS.hasOneUse()) {
10268 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: OldRHS);
10269 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: OldLHS.getOperand(i: 0),
10270 N2: NewBitReorder);
10271 }
10272
10273 if (OldRHS.getOpcode() == Opcode && OldRHS.hasOneUse()) {
10274 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: OldLHS);
10275 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NewBitReorder,
10276 N2: OldRHS.getOperand(i: 0));
10277 }
10278 }
10279 return SDValue();
10280}
10281
10282SDValue DAGCombiner::visitSRA(SDNode *N) {
10283 SDValue N0 = N->getOperand(Num: 0);
10284 SDValue N1 = N->getOperand(Num: 1);
10285 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10286 return V;
10287
10288 EVT VT = N0.getValueType();
10289 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10290
10291 // fold (sra c1, c2) -> c1 >>s c2
10292 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRA, DL: SDLoc(N), VT, Ops: {N0, N1}))
10293 return C;
10294
10295 // Arithmetic shifting an all-sign-bit value is a no-op.
10296 // fold (sra 0, x) -> 0
10297 // fold (sra -1, x) -> -1
10298 if (DAG.ComputeNumSignBits(Op: N0) == OpSizeInBits)
10299 return N0;
10300
10301 // fold vector ops
10302 if (VT.isVector())
10303 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
10304 return FoldedVOp;
10305
10306 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10307 return NewSel;
10308
10309 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10310
10311 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10312 // clamp (add c1, c2) to max shift.
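// e.g. for i32: (sra (sra x, 3), 5) -> (sra x, 8), while
//               (sra (sra x, 20), 16) -> (sra x, 31) since sra saturates at 31.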
10313 if (N0.getOpcode() == ISD::SRA) {
10314 SDLoc DL(N);
10315 EVT ShiftVT = N1.getValueType();
10316 EVT ShiftSVT = ShiftVT.getScalarType();
10317 SmallVector<SDValue, 16> ShiftValues;
10318
10319 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10320 APInt c1 = LHS->getAPIntValue();
10321 APInt c2 = RHS->getAPIntValue();
10322 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10323 APInt Sum = c1 + c2;
10324 unsigned ShiftSum =
10325 Sum.uge(RHS: OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10326 ShiftValues.push_back(Elt: DAG.getConstant(Val: ShiftSum, DL, VT: ShiftSVT));
10327 return true;
10328 };
10329 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: SumOfShifts)) {
10330 SDValue ShiftValue;
10331 if (N1.getOpcode() == ISD::BUILD_VECTOR)
10332 ShiftValue = DAG.getBuildVector(VT: ShiftVT, DL, Ops: ShiftValues);
10333 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10334 assert(ShiftValues.size() == 1 &&
10335 "Expected matchBinaryPredicate to return one element for "
10336 "SPLAT_VECTORs");
10337 ShiftValue = DAG.getSplatVector(VT: ShiftVT, DL, Op: ShiftValues[0]);
10338 } else
10339 ShiftValue = ShiftValues[0];
10340 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0), N2: ShiftValue);
10341 }
10342 }
10343
10344 // fold (sra (shl X, m), (sub result_size, n))
10345 // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
10346 // result_size - n != m.
10347 // If truncate is free for the target sext(shl) is likely to result in better
10348 // code.
10349 if (N0.getOpcode() == ISD::SHL && N1C) {
10350 // Get the two constants of the shifts, CN0 = m, CN = n.
10351 const ConstantSDNode *N01C = isConstOrConstSplat(N: N0.getOperand(i: 1));
10352 if (N01C) {
10353 LLVMContext &Ctx = *DAG.getContext();
10354 // Determine what the truncate's result bitsize and type would be.
10355 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - N1C->getZExtValue());
10356
10357 if (VT.isVector())
10358 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10359
10360 // Determine the residual right-shift amount.
10361 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10362
10363 // If the shift is not a no-op (in which case this should be just a sign
10364 // extend already), the type we truncate to is legal, sign_extend is legal
10365 // on that type, and the truncate to that type is both legal and free,
10366 // perform the transform.
10367 if ((ShiftAmt > 0) &&
10368 TLI.isOperationLegalOrCustom(Op: ISD::SIGN_EXTEND, VT: TruncVT) &&
10369 TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT) &&
10370 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10371 SDLoc DL(N);
10372 SDValue Amt = DAG.getConstant(Val: ShiftAmt, DL,
10373 VT: getShiftAmountTy(LHSTy: N0.getOperand(i: 0).getValueType()));
10374 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT,
10375 N1: N0.getOperand(i: 0), N2: Amt);
10376 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT,
10377 Operand: Shift);
10378 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL,
10379 VT: N->getValueType(ResNo: 0), Operand: Trunc);
10380 }
10381 }
10382 }
10383
10384 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10385 // sra (add (shl X, N1C), AddC), N1C -->
10386 // sext (add (trunc X to (width - N1C)), AddC')
10387 // sra (sub AddC, (shl X, N1C)), N1C -->
10388 // sext (sub AddC1',(trunc X to (width - N1C)))
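// e.g. for i32 with N1C == 16:
//   (sra (add (shl X, 16), 0xA0000), 16) -> (sext (add (trunc X to i16), 0xA))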
10389 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10390 N0.hasOneUse()) {
10391 bool IsAdd = N0.getOpcode() == ISD::ADD;
10392 SDValue Shl = N0.getOperand(i: IsAdd ? 0 : 1);
10393 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(i: 1) == N1 &&
10394 Shl.hasOneUse()) {
10395 // TODO: AddC does not need to be a splat.
10396 if (ConstantSDNode *AddC =
10397 isConstOrConstSplat(N: N0.getOperand(i: IsAdd ? 1 : 0))) {
10398 // Determine what the truncate's type would be and ask the target if
10399 // that is a free operation.
10400 LLVMContext &Ctx = *DAG.getContext();
10401 unsigned ShiftAmt = N1C->getZExtValue();
10402 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - ShiftAmt);
10403 if (VT.isVector())
10404 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10405
10406 // TODO: The simple type check probably belongs in the default hook
10407 // implementation and/or target-specific overrides (because
10408 // non-simple types likely require masking when legalized), but
10409 // that restriction may conflict with other transforms.
10410 if (TruncVT.isSimple() && isTypeLegal(VT: TruncVT) &&
10411 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10412 SDLoc DL(N);
10413 SDValue Trunc = DAG.getZExtOrTrunc(Op: Shl.getOperand(i: 0), DL, VT: TruncVT);
10414 SDValue ShiftC =
10415 DAG.getConstant(Val: AddC->getAPIntValue().lshr(shiftAmt: ShiftAmt).trunc(
10416 width: TruncVT.getScalarSizeInBits()),
10417 DL, VT: TruncVT);
10418 SDValue Add;
10419 if (IsAdd)
10420 Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: TruncVT, N1: Trunc, N2: ShiftC);
10421 else
10422 Add = DAG.getNode(Opcode: ISD::SUB, DL, VT: TruncVT, N1: ShiftC, N2: Trunc);
10423 return DAG.getSExtOrTrunc(Op: Add, DL, VT);
10424 }
10425 }
10426 }
10427 }
10428
10429 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10430 if (N1.getOpcode() == ISD::TRUNCATE &&
10431 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10432 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10433 return DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N), VT, N1: N0, N2: NewOp1);
10434 }
10435
10436 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10437 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10438 // if c1 is equal to the number of bits the trunc removes
10439 // TODO - support non-uniform vector shift amounts.
10440 if (N0.getOpcode() == ISD::TRUNCATE &&
10441 (N0.getOperand(i: 0).getOpcode() == ISD::SRL ||
10442 N0.getOperand(i: 0).getOpcode() == ISD::SRA) &&
10443 N0.getOperand(i: 0).hasOneUse() &&
10444 N0.getOperand(i: 0).getOperand(i: 1).hasOneUse() && N1C) {
10445 SDValue N0Op0 = N0.getOperand(i: 0);
10446 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N: N0Op0.getOperand(i: 1))) {
10447 EVT LargeVT = N0Op0.getValueType();
10448 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10449 if (LargeShift->getAPIntValue() == TruncBits) {
10450 SDLoc DL(N);
10451 EVT LargeShiftVT = getShiftAmountTy(LHSTy: LargeVT);
10452 SDValue Amt = DAG.getZExtOrTrunc(Op: N1, DL, VT: LargeShiftVT);
10453 Amt = DAG.getNode(Opcode: ISD::ADD, DL, VT: LargeShiftVT, N1: Amt,
10454 N2: DAG.getConstant(Val: TruncBits, DL, VT: LargeShiftVT));
10455 SDValue SRA =
10456 DAG.getNode(Opcode: ISD::SRA, DL, VT: LargeVT, N1: N0Op0.getOperand(i: 0), N2: Amt);
10457 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SRA);
10458 }
10459 }
10460 }
10461
10462 // Simplify, based on bits shifted out of the LHS.
10463 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10464 return SDValue(N, 0);
10465
10466 // If the sign bit is known to be zero, switch this to a SRL.
10467 if (DAG.SignBitIsZero(Op: N0))
10468 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1: N0, N2: N1);
10469
10470 if (N1C && !N1C->isOpaque())
10471 if (SDValue NewSRA = visitShiftByConstant(N))
10472 return NewSRA;
10473
10474 // Try to transform this shift into a multiply-high if
10475 // it matches the appropriate pattern detected in combineShiftToMULH.
10476 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
10477 return MULH;
10478
10479 // Attempt to convert a sra of a load into a narrower sign-extending load.
10480 if (SDValue NarrowLoad = reduceLoadWidth(N))
10481 return NarrowLoad;
10482
10483 return SDValue();
10484}
10485
10486SDValue DAGCombiner::visitSRL(SDNode *N) {
10487 SDValue N0 = N->getOperand(Num: 0);
10488 SDValue N1 = N->getOperand(Num: 1);
10489 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10490 return V;
10491
10492 EVT VT = N0.getValueType();
10493 EVT ShiftVT = N1.getValueType();
10494 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10495
10496 // fold (srl c1, c2) -> c1 >>u c2
10497 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRL, DL: SDLoc(N), VT, Ops: {N0, N1}))
10498 return C;
10499
10500 // fold vector ops
10501 if (VT.isVector())
10502 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
10503 return FoldedVOp;
10504
10505 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10506 return NewSel;
10507
10508 // if (srl x, c) is known to be zero, return 0
10509 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10510 if (N1C &&
10511 DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
10512 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
10513
10514 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
10515 if (N0.getOpcode() == ISD::SRL) {
10516 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10517 ConstantSDNode *RHS) {
10518 APInt c1 = LHS->getAPIntValue();
10519 APInt c2 = RHS->getAPIntValue();
10520 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10521 return (c1 + c2).uge(RHS: OpSizeInBits);
10522 };
10523 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
10524 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
10525
10526 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10527 ConstantSDNode *RHS) {
10528 APInt c1 = LHS->getAPIntValue();
10529 APInt c2 = RHS->getAPIntValue();
10530 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10531 return (c1 + c2).ult(RHS: OpSizeInBits);
10532 };
10533 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
10534 SDLoc DL(N);
10535 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
10536 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
10537 }
10538 }
10539
10540 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
10541 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
10542 SDValue InnerShift = N0.getOperand(i: 0);
10543 // TODO - support non-uniform vector shift amounts.
10544 if (auto *N001C = isConstOrConstSplat(N: InnerShift.getOperand(i: 1))) {
10545 uint64_t c1 = N001C->getZExtValue();
10546 uint64_t c2 = N1C->getZExtValue();
10547 EVT InnerShiftVT = InnerShift.getValueType();
10548 EVT ShiftAmtVT = InnerShift.getOperand(i: 1).getValueType();
10549 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
10550 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
10551 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
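// e.g. with x:i64 truncated to i32:
//   (srl (trunc (srl x, 32)), 8) -> (trunc (srl x, 40))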
10552 if (c1 + OpSizeInBits == InnerShiftSize) {
10553 SDLoc DL(N);
10554 if (c1 + c2 >= InnerShiftSize)
10555 return DAG.getConstant(Val: 0, DL, VT);
10556 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
10557 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
10558 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
10559 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewShift);
10560 }
10561 // In the more general case, we can clear the high bits after the shift:
10562 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
10563 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
10564 c1 + c2 < InnerShiftSize) {
10565 SDLoc DL(N);
10566 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
10567 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
10568 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
10569 SDValue Mask = DAG.getConstant(Val: APInt::getLowBitsSet(numBits: InnerShiftSize,
10570 loBitsSet: OpSizeInBits - c2),
10571 DL, VT: InnerShiftVT);
10572 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: InnerShiftVT, N1: NewShift, N2: Mask);
10573 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: And);
10574 }
10575 }
10576 }
10577
10578 // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
10579 //                               (and (srl x, (sub c2, c1)), MASK)
10580 if (N0.getOpcode() == ISD::SHL &&
10581 (N0.getOperand(i: 1) == N1 || N0->hasOneUse()) &&
10582 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10583 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10584 ConstantSDNode *RHS) {
10585 const APInt &LHSC = LHS->getAPIntValue();
10586 const APInt &RHSC = RHS->getAPIntValue();
10587 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
10588 LHSC.getZExtValue() <= RHSC.getZExtValue();
10589 };
10590 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10591 /*AllowUndefs*/ false,
10592 /*AllowTypeMismatch*/ true)) {
10593 SDLoc DL(N);
10594 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10595 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10596 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10597 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N01);
10598 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: Diff);
10599 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10600 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10601 }
10602 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10603 /*AllowUndefs*/ false,
10604 /*AllowTypeMismatch*/ true)) {
10605 SDLoc DL(N);
10606 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10607 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10608 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10609 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N1);
10610 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10611 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10612 }
10613 }
10614
10615 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
10616 // TODO - support non-uniform vector shift amounts.
10617 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
10618 // Shifting in all undef bits?
10619 EVT SmallVT = N0.getOperand(i: 0).getValueType();
10620 unsigned BitSize = SmallVT.getScalarSizeInBits();
10621 if (N1C->getAPIntValue().uge(RHS: BitSize))
10622 return DAG.getUNDEF(VT);
10623
10624 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, VT: SmallVT)) {
10625 uint64_t ShiftAmt = N1C->getZExtValue();
10626 SDLoc DL0(N0);
10627 SDValue SmallShift = DAG.getNode(Opcode: ISD::SRL, DL: DL0, VT: SmallVT,
10628 N1: N0.getOperand(i: 0),
10629 N2: DAG.getConstant(Val: ShiftAmt, DL: DL0,
10630 VT: getShiftAmountTy(LHSTy: SmallVT)));
10631 AddToWorklist(N: SmallShift.getNode());
10632 APInt Mask = APInt::getLowBitsSet(numBits: OpSizeInBits, loBitsSet: OpSizeInBits - ShiftAmt);
10633 SDLoc DL(N);
10634 return DAG.getNode(Opcode: ISD::AND, DL, VT,
10635 N1: DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: SmallShift),
10636 N2: DAG.getConstant(Val: Mask, DL, VT));
10637 }
10638 }
10639
10640 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
10641 // bit, which is unmodified by sra.
10642 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
10643 if (N0.getOpcode() == ISD::SRA)
10644 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0), N2: N1);
10645 }
10646
10647 // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x has a power
10648 // of two bitwidth. The "5" represents (log2 (bitwidth x)).
10649 if (N1C && N0.getOpcode() == ISD::CTLZ &&
10650 isPowerOf2_32(Value: OpSizeInBits) &&
10651 N1C->getAPIntValue() == Log2_32(Value: OpSizeInBits)) {
10652 KnownBits Known = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
10653
10654 // If any of the input bits are KnownOne, then the input couldn't be all
10655 // zeros, thus the result of the srl will always be zero.
10656 if (Known.One.getBoolValue()) return DAG.getConstant(Val: 0, DL: SDLoc(N0), VT);
10657
10658 // If all of the bits input to the ctlz node are known to be zero, then
10659 // the result of the ctlz is "32" and the result of the shift is one.
10660 APInt UnknownBits = ~Known.Zero;
10661 if (UnknownBits == 0) return DAG.getConstant(Val: 1, DL: SDLoc(N0), VT);
10662
10663 // Otherwise, check to see if there is exactly one bit input to the ctlz.
10664 if (UnknownBits.isPowerOf2()) {
10665 // Okay, we know that only the single bit specified by UnknownBits
10666 // could be set on input to the CTLZ node. If this bit is set, the SRL
10667 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
10668 // to an SRL/XOR pair, which is likely to simplify more.
10669 unsigned ShAmt = UnknownBits.countr_zero();
10670 SDValue Op = N0.getOperand(i: 0);
10671
10672 if (ShAmt) {
10673 SDLoc DL(N0);
10674 Op = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Op,
10675 N2: DAG.getConstant(Val: ShAmt, DL,
10676 VT: getShiftAmountTy(LHSTy: Op.getValueType())));
10677 AddToWorklist(N: Op.getNode());
10678 }
10679
10680 SDLoc DL(N);
10681 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
10682 N1: Op, N2: DAG.getConstant(Val: 1, DL, VT));
10683 }
10684 }
10685
10686 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
10687 if (N1.getOpcode() == ISD::TRUNCATE &&
10688 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10689 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10690 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1: N0, N2: NewOp1);
10691 }
10692
10693 // fold operands of srl based on knowledge that the low bits are not
10694 // demanded.
10695 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10696 return SDValue(N, 0);
10697
10698 if (N1C && !N1C->isOpaque())
10699 if (SDValue NewSRL = visitShiftByConstant(N))
10700 return NewSRL;
10701
10702 // Attempt to convert a srl of a load into a narrower zero-extending load.
10703 if (SDValue NarrowLoad = reduceLoadWidth(N))
10704 return NarrowLoad;
10705
10706 // Here is a common situation. We want to optimize:
10707 //
10708 // %a = ...
10709 // %b = and i32 %a, 2
10710 // %c = srl i32 %b, 1
10711 // brcond i32 %c ...
10712 //
10713 // into
10714 //
10715 // %a = ...
10716 // %b = and %a, 2
10717 // %c = setcc eq %b, 0
10718 // brcond %c ...
10719 //
10720 // However, after the source operand of SRL is optimized into AND, the SRL
10721 // itself may not be optimized further. Look for it and add the BRCOND into
10722 // the worklist.
10723 //
10724 // This also tends to happen for binary operations when SimplifyDemandedBits
10725 // is involved.
10726 //
10727 // FIXME: This is unnecessary if we process the DAG in topological order,
10728 // which we plan to do. This workaround can be removed once the DAG is
10729 // processed in topological order.
10730 if (N->hasOneUse()) {
10731 SDNode *Use = *N->use_begin();
10732
10733 // Look past the truncate.
10734 if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
10735 Use = *Use->use_begin();
10736
10737 if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
10738 Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
10739 AddToWorklist(N: Use);
10740 }
10741
10742 // Try to transform this shift into a multiply-high if
10743 // it matches the appropriate pattern detected in combineShiftToMULH.
10744 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
10745 return MULH;
10746
10747 return SDValue();
10748}
10749
10750SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
10751 EVT VT = N->getValueType(ResNo: 0);
10752 SDValue N0 = N->getOperand(Num: 0);
10753 SDValue N1 = N->getOperand(Num: 1);
10754 SDValue N2 = N->getOperand(Num: 2);
10755 bool IsFSHL = N->getOpcode() == ISD::FSHL;
10756 unsigned BitWidth = VT.getScalarSizeInBits();
10757
10758 // fold (fshl N0, N1, 0) -> N0
10759 // fold (fshr N0, N1, 0) -> N1
10760 if (isPowerOf2_32(Value: BitWidth))
10761 if (DAG.MaskedValueIsZero(
10762 Op: N2, Mask: APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10763 return IsFSHL ? N0 : N1;
10764
10765 auto IsUndefOrZero = [](SDValue V) {
10766 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10767 };
10768
10769 // TODO - support non-uniform vector shift amounts.
10770 if (ConstantSDNode *Cst = isConstOrConstSplat(N: N2)) {
10771 EVT ShAmtTy = N2.getValueType();
10772
10773 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10774 if (Cst->getAPIntValue().uge(RHS: BitWidth)) {
10775 uint64_t RotAmt = Cst->getAPIntValue().urem(RHS: BitWidth);
10776 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1: N0, N2: N1,
10777 N3: DAG.getConstant(Val: RotAmt, DL: SDLoc(N), VT: ShAmtTy));
10778 }
10779
10780 unsigned ShAmt = Cst->getZExtValue();
10781 if (ShAmt == 0)
10782 return IsFSHL ? N0 : N1;
10783
10784 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10785 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10786 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10787 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
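// e.g. for i32: fshl(0, N1, 3) -> (srl N1, 29) and fshr(N0, 0, 3) -> (shl N0, 29).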
10788 if (IsUndefOrZero(N0))
10789 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1,
10790 N2: DAG.getConstant(Val: IsFSHL ? BitWidth - ShAmt : ShAmt,
10791 DL: SDLoc(N), VT: ShAmtTy));
10792 if (IsUndefOrZero(N1))
10793 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0,
10794 N2: DAG.getConstant(Val: IsFSHL ? ShAmt : BitWidth - ShAmt,
10795 DL: SDLoc(N), VT: ShAmtTy));
10796
10797 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10798 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10799 // TODO - bigendian support once we have test coverage.
10800 // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10801 // TODO - permit LHS EXTLOAD if extensions are shifted out.
10802 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10803 !DAG.getDataLayout().isBigEndian()) {
10804 auto *LHS = dyn_cast<LoadSDNode>(Val&: N0);
10805 auto *RHS = dyn_cast<LoadSDNode>(Val&: N1);
10806 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10807 LHS->getAddressSpace() == RHS->getAddressSpace() &&
10808 (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(N: RHS) &&
10809 ISD::isNON_EXTLoad(N: LHS)) {
10810 if (DAG.areNonVolatileConsecutiveLoads(LD: LHS, Base: RHS, Bytes: BitWidth / 8, Dist: 1)) {
10811 SDLoc DL(RHS);
10812 uint64_t PtrOff =
10813 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10814 Align NewAlign = commonAlignment(A: RHS->getAlign(), Offset: PtrOff);
10815 unsigned Fast = 0;
10816 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
10817 AddrSpace: RHS->getAddressSpace(), Alignment: NewAlign,
10818 Flags: RHS->getMemOperand()->getFlags(), Fast: &Fast) &&
10819 Fast) {
10820 SDValue NewPtr = DAG.getMemBasePlusOffset(
10821 Base: RHS->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL);
10822 AddToWorklist(N: NewPtr.getNode());
10823 SDValue Load = DAG.getLoad(
10824 VT, dl: DL, Chain: RHS->getChain(), Ptr: NewPtr,
10825 PtrInfo: RHS->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
10826 MMOFlags: RHS->getMemOperand()->getFlags(), AAInfo: RHS->getAAInfo());
10827 // Replace the old load's chain with the new load's chain.
10828 WorklistRemover DeadNodes(*this);
10829 DAG.ReplaceAllUsesOfValueWith(From: N1.getValue(R: 1), To: Load.getValue(R: 1));
10830 return Load;
10831 }
10832 }
10833 }
10834 }
10835 }
10836
10837 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10838 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10839 // iff we know the shift amount is in range.
10840 // TODO: when is it worth doing SUB(BW, N2) as well?
10841 if (isPowerOf2_32(Value: BitWidth)) {
10842 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10843 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
10844 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1, N2);
10845 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
10846 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2);
10847 }
10848
10849 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10850 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10851 // TODO: Investigate flipping this rotate if only one is legal; if the funnel
10852 // shift is legal as well, we might be better off avoiding a non-constant (BW - N2).
10853 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10854 if (N0 == N1 && hasOperation(Opcode: RotOpc, VT))
10855 return DAG.getNode(Opcode: RotOpc, DL: SDLoc(N), VT, N1: N0, N2);
10856
10857 // Simplify, based on bits shifted out of N0/N1.
10858 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10859 return SDValue(N, 0);
10860
10861 return SDValue();
10862}
10863
10864SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10865 SDValue N0 = N->getOperand(Num: 0);
10866 SDValue N1 = N->getOperand(Num: 1);
10867 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10868 return V;
10869
10870 EVT VT = N0.getValueType();
10871
10872 // fold (*shlsat c1, c2) -> c1<<c2
10873 if (SDValue C =
10874 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Ops: {N0, N1}))
10875 return C;
10876
10877 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10878
10879 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::SHL, VT)) {
10880 // fold (sshlsat x, c) -> (shl x, c)
10881 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10882 N1C->getAPIntValue().ult(RHS: DAG.ComputeNumSignBits(Op: N0)))
10883 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2: N1);
10884
10885 // fold (ushlsat x, c) -> (shl x, c)
10886 if (N->getOpcode() == ISD::USHLSAT && N1C &&
10887 N1C->getAPIntValue().ule(
10888 RHS: DAG.computeKnownBits(Op: N0).countMinLeadingZeros()))
10889 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2: N1);
10890 }
10891
10892 return SDValue();
10893}
10894
10895 // Given an ABS node, detect the following patterns:
10896 // (ABS (SUB (EXTEND a), (EXTEND b))).
10897 // (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
10898 // Generates a UABD/SABD instruction.
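// e.g. (abs (sub (sext i8 a to i32), (sext i8 b to i32))) -> (zext (abds a, b))
// when ABDS is available for i8.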
10899SDValue DAGCombiner::foldABSToABD(SDNode *N) {
10900 EVT SrcVT = N->getValueType(ResNo: 0);
10901
10902 if (N->getOpcode() == ISD::TRUNCATE)
10903 N = N->getOperand(Num: 0).getNode();
10904
10905 if (N->getOpcode() != ISD::ABS)
10906 return SDValue();
10907
10908 EVT VT = N->getValueType(ResNo: 0);
10909 SDValue AbsOp1 = N->getOperand(Num: 0);
10910 SDValue Op0, Op1;
10911 SDLoc DL(N);
10912
10913 if (AbsOp1.getOpcode() != ISD::SUB)
10914 return SDValue();
10915
10916 Op0 = AbsOp1.getOperand(i: 0);
10917 Op1 = AbsOp1.getOperand(i: 1);
10918
10919 unsigned Opc0 = Op0.getOpcode();
10920
10921 // Check if the operands of the sub are (zero|sign)-extended.
10922 // TODO: Should we use ValueTracking instead?
10923 if (Opc0 != Op1.getOpcode() ||
10924 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
10925 Opc0 != ISD::SIGN_EXTEND_INREG)) {
10926 // fold (abs (sub nsw x, y)) -> abds(x, y)
10927 if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(Opcode: ISD::ABDS, VT) &&
10928 TLI.preferABDSToABSWithNSW(VT)) {
10929 SDValue ABD = DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: Op0, N2: Op1);
10930 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10931 }
10932 return SDValue();
10933 }
10934
10935 EVT VT0, VT1;
10936 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
10937 VT0 = cast<VTSDNode>(Val: Op0.getOperand(i: 1))->getVT();
10938 VT1 = cast<VTSDNode>(Val: Op1.getOperand(i: 1))->getVT();
10939 } else {
10940 VT0 = Op0.getOperand(i: 0).getValueType();
10941 VT1 = Op1.getOperand(i: 0).getValueType();
10942 }
10943 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
10944
10945 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10946 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
10947 EVT MaxVT = VT0.bitsGT(VT: VT1) ? VT0 : VT1;
10948 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10949 (VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(Opcode: ABDOpcode, VT: MaxVT)) {
10950 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT: MaxVT,
10951 N1: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op0),
10952 N2: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op1));
10953 ABD = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ABD);
10954 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10955 }
10956
10957 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10958 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10959 if (hasOperation(Opcode: ABDOpcode, VT)) {
10960 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT, N1: Op0, N2: Op1);
10961 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10962 }
10963
10964 return SDValue();
10965}
10966
10967SDValue DAGCombiner::visitABS(SDNode *N) {
10968 SDValue N0 = N->getOperand(Num: 0);
10969 EVT VT = N->getValueType(ResNo: 0);
10970
10971 // fold (abs c1) -> c2
10972 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ABS, DL: SDLoc(N), VT, Ops: {N0}))
10973 return C;
10974 // fold (abs (abs x)) -> (abs x)
10975 if (N0.getOpcode() == ISD::ABS)
10976 return N0;
10977 // fold (abs x) -> x iff not-negative
10978 if (DAG.SignBitIsZero(Op: N0))
10979 return N0;
10980
10981 if (SDValue ABD = foldABSToABD(N))
10982 return ABD;
10983
10984 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
10985 // iff zero_extend/truncate are free.
10986 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
10987 EVT ExtVT = cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT();
10988 if (TLI.isTruncateFree(FromVT: VT, ToVT: ExtVT) && TLI.isZExtFree(FromTy: ExtVT, ToTy: VT) &&
10989 TLI.isTypeDesirableForOp(ISD::ABS, VT: ExtVT) &&
10990 hasOperation(Opcode: ISD::ABS, VT: ExtVT)) {
10991 SDLoc DL(N);
10992 return DAG.getNode(
10993 Opcode: ISD::ZERO_EXTEND, DL, VT,
10994 Operand: DAG.getNode(Opcode: ISD::ABS, DL, VT: ExtVT,
10995 Operand: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N0.getOperand(i: 0))));
10996 }
10997 }
10998
10999 return SDValue();
11000}
11001
11002SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11003 SDValue N0 = N->getOperand(Num: 0);
11004 EVT VT = N->getValueType(ResNo: 0);
11005 SDLoc DL(N);
11006
11007 // fold (bswap c1) -> c2
11008 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BSWAP, DL, VT, Ops: {N0}))
11009 return C;
11010 // fold (bswap (bswap x)) -> x
11011 if (N0.getOpcode() == ISD::BSWAP)
11012 return N0.getOperand(i: 0);
11013
11014 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11015 // isn't supported, it will be expanded to bswap followed by a manual reversal
11016 // of bits in each byte. By placing bswaps before bitreverse, we can remove
11017 // the two bswaps if the bitreverse gets expanded.
11018 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11019 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11020 return DAG.getNode(Opcode: ISD::BITREVERSE, DL, VT, Operand: BSwap);
11021 }
11022
11023 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
11024 // iff c >= bw/2 (i.e. the lower half of the result is known zero)
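// e.g. for i64: (bswap (shl x, 48)) -> (zext (bswap (trunc (shl x, 16) to i32))).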
11025 unsigned BW = VT.getScalarSizeInBits();
11026 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11027 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11028 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW / 2);
11029 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11030 ShAmt->getZExtValue() >= (BW / 2) &&
11031 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(VT: HalfVT) &&
11032 TLI.isTruncateFree(FromVT: VT, ToVT: HalfVT) &&
11033 (!LegalOperations || hasOperation(Opcode: ISD::BSWAP, VT: HalfVT))) {
11034 SDValue Res = N0.getOperand(i: 0);
11035 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11036 Res = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Res,
11037 N2: DAG.getConstant(Val: NewShAmt, DL, VT: getShiftAmountTy(LHSTy: VT)));
11038 Res = DAG.getZExtOrTrunc(Op: Res, DL, VT: HalfVT);
11039 Res = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: HalfVT, Operand: Res);
11040 return DAG.getZExtOrTrunc(Op: Res, DL, VT);
11041 }
11042 }
11043
11044 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11045 // inverse-shift-of-bswap:
11046 // bswap (X u<< C) --> (bswap X) u>> C
11047 // bswap (X u>> C) --> (bswap X) u<< C
11048 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11049 N0.hasOneUse()) {
11050 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11051 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11052 ShAmt->getZExtValue() % 8 == 0) {
11053 SDValue NewSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11054 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11055 return DAG.getNode(Opcode: InverseShift, DL, VT, N1: NewSwap, N2: N0.getOperand(i: 1));
11056 }
11057 }
11058
11059 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11060 return V;
11061
11062 return SDValue();
11063}
11064
11065SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11066 SDValue N0 = N->getOperand(Num: 0);
11067 EVT VT = N->getValueType(ResNo: 0);
11068 SDLoc DL(N);
11069
11070 // fold (bitreverse c1) -> c2
11071 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BITREVERSE, DL, VT, Ops: {N0}))
11072 return C;
11073 // fold (bitreverse (bitreverse x)) -> x
11074 if (N0.getOpcode() == ISD::BITREVERSE)
11075 return N0.getOperand(i: 0);
11076 return SDValue();
11077}
11078
11079SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11080 SDValue N0 = N->getOperand(Num: 0);
11081 EVT VT = N->getValueType(ResNo: 0);
11082 SDLoc DL(N);
11083
11084 // fold (ctlz c1) -> c2
11085 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ, DL, VT, Ops: {N0}))
11086 return C;
11087
11088 // If the value is known never to be zero, switch to the undef version.
11089 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ_ZERO_UNDEF, VT))
11090 if (DAG.isKnownNeverZero(Op: N0))
11091 return DAG.getNode(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Operand: N0);
11092
11093 return SDValue();
11094}
11095
11096SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11097 SDValue N0 = N->getOperand(Num: 0);
11098 EVT VT = N->getValueType(ResNo: 0);
11099 SDLoc DL(N);
11100
11101 // fold (ctlz_zero_undef c1) -> c2
11102 if (SDValue C =
11103 DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11104 return C;
11105 return SDValue();
11106}
11107
11108SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11109 SDValue N0 = N->getOperand(Num: 0);
11110 EVT VT = N->getValueType(ResNo: 0);
11111 SDLoc DL(N);
11112
11113 // fold (cttz c1) -> c2
11114 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ, DL, VT, Ops: {N0}))
11115 return C;
11116
11117 // If the value is known never to be zero, switch to the undef version.
11118 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ_ZERO_UNDEF, VT))
11119 if (DAG.isKnownNeverZero(Op: N0))
11120 return DAG.getNode(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Operand: N0);
11121
11122 return SDValue();
11123}
11124
11125SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11126 SDValue N0 = N->getOperand(Num: 0);
11127 EVT VT = N->getValueType(ResNo: 0);
11128 SDLoc DL(N);
11129
11130 // fold (cttz_zero_undef c1) -> c2
11131 if (SDValue C =
11132 DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11133 return C;
11134 return SDValue();
11135}
11136
11137SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11138 SDValue N0 = N->getOperand(Num: 0);
11139 EVT VT = N->getValueType(ResNo: 0);
11140 unsigned NumBits = VT.getScalarSizeInBits();
11141 SDLoc DL(N);
11142
11143 // fold (ctpop c1) -> c2
11144 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTPOP, DL, VT, Ops: {N0}))
11145 return C;
11146
11147  // If the upper bits are known to be zero, then see if it's profitable to
11148  // only count the lower bits.
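  // For example (illustrative): if only the low 16 bits of an i32 value can be
  // nonzero and i16 CTPOP is legal, (ctpop i32 x) can be computed as
  // (zext (ctpop (trunc x to i16)) to i32).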
11149 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11150 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumBits / 2);
11151 if (hasOperation(Opcode: ISD::CTPOP, VT: HalfVT) &&
11152 TLI.isTypeDesirableForOp(ISD::CTPOP, VT: HalfVT) &&
11153 TLI.isTruncateFree(Val: N0, VT2: HalfVT) && TLI.isZExtFree(FromTy: HalfVT, ToTy: VT)) {
11154 APInt UpperBits = APInt::getHighBitsSet(numBits: NumBits, hiBitsSet: NumBits / 2);
11155 if (DAG.MaskedValueIsZero(Op: N0, Mask: UpperBits)) {
11156 SDValue PopCnt = DAG.getNode(Opcode: ISD::CTPOP, DL, VT: HalfVT,
11157 Operand: DAG.getZExtOrTrunc(Op: N0, DL, VT: HalfVT));
11158 return DAG.getZExtOrTrunc(Op: PopCnt, DL, VT);
11159 }
11160 }
11161 }
11162
11163 return SDValue();
11164}
11165
11166// FIXME: This should be checking for no signed zeros on individual operands, as
11167// well as no NaNs.
11168static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11169 SDValue RHS,
11170 const TargetLowering &TLI) {
11171 const TargetOptions &Options = DAG.getTarget().Options;
11172 EVT VT = LHS.getValueType();
11173
11174 return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
11175 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11176 DAG.isKnownNeverNaN(Op: LHS) && DAG.isKnownNeverNaN(Op: RHS);
11177}
11178
11179static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11180 SDValue RHS, SDValue True, SDValue False,
11181 ISD::CondCode CC,
11182 const TargetLowering &TLI,
11183 SelectionDAG &DAG) {
11184 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT);
11185 switch (CC) {
11186 case ISD::SETOLT:
11187 case ISD::SETOLE:
11188 case ISD::SETLT:
11189 case ISD::SETLE:
11190 case ISD::SETULT:
11191 case ISD::SETULE: {
11192    // Since the operands are known never to be NaN if we get here, either
11193    // fminnum or fminnum_ieee is OK. Try the IEEE version first, since
11194    // fminnum is expanded in terms of it.
11195 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11196 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11197 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11198
11199 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11200 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11201 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11202 return SDValue();
11203 }
11204 case ISD::SETOGT:
11205 case ISD::SETOGE:
11206 case ISD::SETGT:
11207 case ISD::SETGE:
11208 case ISD::SETUGT:
11209 case ISD::SETUGE: {
11210 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11211 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11212 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11213
11214 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11215 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11216 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11217 return SDValue();
11218 }
11219 default:
11220 return SDValue();
11221 }
11222}
11223
11224/// Generate Min/Max node
11225SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11226 SDValue RHS, SDValue True,
11227 SDValue False, ISD::CondCode CC) {
11228 if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11229 return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11230
11231 // If we can't directly match this, try to see if we can pull an fneg out of
11232 // the select.
11233 SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11234 Op: True, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11235 if (!NegTrue)
11236 return SDValue();
11237
11238 HandleSDNode NegTrueHandle(NegTrue);
11239
11240 // Try to unfold an fneg from the select if we are comparing the negated
11241 // constant.
11242 //
11243 // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
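  // Illustrative example (assuming fminnum is legal for the type):
  //   select (setcc x, 2.0, olt), (fneg x), -2.0 --> fneg (fminnum x, 2.0)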
11244 //
11245 // TODO: Handle fabs
11246 if (LHS == NegTrue) {
11247    // The comparison LHS matches the negated true operand, so also try to
11248    // negate the RHS constant and check that it matches the false operand.
11249 SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11250 Op: RHS, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11251 if (NegRHS) {
11252 HandleSDNode NegRHSHandle(NegRHS);
11253 if (NegRHS == False) {
11254 SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True: NegTrue,
11255 False, CC, TLI, DAG);
11256 if (Combined)
11257 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Combined);
11258 }
11259 }
11260 }
11261
11262 return SDValue();
11263}
11264
11265/// If a (v)select has a condition value that is a sign-bit test, try to smear
11266/// the condition operand sign-bit across the value width and use it as a mask.
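/// For example (illustrative): select (setlt X, 0), C, 0 becomes
/// and (sra X, BW-1), C, because the arithmetic shift produces all-ones when X
/// is negative and all-zeros otherwise.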
11267static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
11268 SDValue Cond = N->getOperand(Num: 0);
11269 SDValue C1 = N->getOperand(Num: 1);
11270 SDValue C2 = N->getOperand(Num: 2);
11271 if (!isConstantOrConstantVector(N: C1) || !isConstantOrConstantVector(N: C2))
11272 return SDValue();
11273
11274 EVT VT = N->getValueType(ResNo: 0);
11275 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11276 VT != Cond.getOperand(i: 0).getValueType())
11277 return SDValue();
11278
11279 // The inverted-condition + commuted-select variants of these patterns are
11280 // canonicalized to these forms in IR.
11281 SDValue X = Cond.getOperand(i: 0);
11282 SDValue CondC = Cond.getOperand(i: 1);
11283 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11284 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CondC) &&
11285 isAllOnesOrAllOnesSplat(V: C2)) {
11286 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11287 SDLoc DL(N);
11288 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11289 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11290 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: C1);
11291 }
11292 if (CC == ISD::SETLT && isNullOrNullSplat(V: CondC) && isNullOrNullSplat(V: C2)) {
11293 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11294 SDLoc DL(N);
11295 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11296 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11297 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: C1);
11298 }
11299 return SDValue();
11300}
11301
11302static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11303 const TargetLowering &TLI) {
11304 if (!TLI.convertSelectOfConstantsToMath(VT))
11305 return false;
11306
11307 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11308 return true;
11309 if (!TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))
11310 return true;
11311
11312 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11313 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond.getOperand(i: 1)))
11314 return true;
11315 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond.getOperand(i: 1)))
11316 return true;
11317
11318 return false;
11319}
11320
11321SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11322 SDValue Cond = N->getOperand(Num: 0);
11323 SDValue N1 = N->getOperand(Num: 1);
11324 SDValue N2 = N->getOperand(Num: 2);
11325 EVT VT = N->getValueType(ResNo: 0);
11326 EVT CondVT = Cond.getValueType();
11327 SDLoc DL(N);
11328
11329 if (!VT.isInteger())
11330 return SDValue();
11331
11332 auto *C1 = dyn_cast<ConstantSDNode>(Val&: N1);
11333 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N2);
11334 if (!C1 || !C2)
11335 return SDValue();
11336
11337 if (CondVT != MVT::i1 || LegalOperations) {
11338 // fold (select Cond, 0, 1) -> (xor Cond, 1)
11339    // We can't do this reliably if integer-based booleans have different
11340    // contents from floating-point-based booleans. This is because we can't
11341    // tell whether we have an integer-based boolean or a floating-point-based
11342    // boolean unless we can find the SETCC that produced it and inspect its
11343    // operands. This is fairly easy if C is the SETCC node, but it can
11344    // potentially be undiscoverable (or not reasonably discoverable). For
11345    // example, it could be in another basic block or it could require
11346    // searching a complicated expression.
11347 if (CondVT.isInteger() &&
11348 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11349 TargetLowering::ZeroOrOneBooleanContent &&
11350 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11351 TargetLowering::ZeroOrOneBooleanContent &&
11352 C1->isZero() && C2->isOne()) {
11353 SDValue NotCond =
11354 DAG.getNode(Opcode: ISD::XOR, DL, VT: CondVT, N1: Cond, N2: DAG.getConstant(Val: 1, DL, VT: CondVT));
11355 if (VT.bitsEq(VT: CondVT))
11356 return NotCond;
11357 return DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11358 }
11359
11360 return SDValue();
11361 }
11362
11363  // Only do this before legalization to avoid conflicting with target-specific
11364  // transforms in the other direction (creating a select from a zext/sext).
11365  // There is also a target-independent combine here in DAGCombiner in the other
11366  // direction for (select Cond, -1, 0) when the condition is not i1.
11367 assert(CondVT == MVT::i1 && !LegalOperations);
11368
11369 // select Cond, 1, 0 --> zext (Cond)
11370 if (C1->isOne() && C2->isZero())
11371 return DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11372
11373 // select Cond, -1, 0 --> sext (Cond)
11374 if (C1->isAllOnes() && C2->isZero())
11375 return DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11376
11377 // select Cond, 0, 1 --> zext (!Cond)
11378 if (C1->isZero() && C2->isOne()) {
11379 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11380 NotCond = DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11381 return NotCond;
11382 }
11383
11384 // select Cond, 0, -1 --> sext (!Cond)
11385 if (C1->isZero() && C2->isAllOnes()) {
11386 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11387 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11388 return NotCond;
11389 }
11390
11391 // Use a target hook because some targets may prefer to transform in the
11392 // other direction.
11393 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11394 return SDValue();
11395
11396 // For any constants that differ by 1, we can transform the select into
11397 // an extend and add.
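  // Illustrative example: select i1 Cond, 7, 6 --> add (zext Cond), 6, and
  // select i1 Cond, 6, 7 --> add (sext Cond), 7.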
11398 const APInt &C1Val = C1->getAPIntValue();
11399 const APInt &C2Val = C2->getAPIntValue();
11400
11401 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11402 if (C1Val - 1 == C2Val) {
11403 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11404 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11405 }
11406
11407 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11408 if (C1Val + 1 == C2Val) {
11409 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11410 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11411 }
11412
11413 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
11414 if (C1Val.isPowerOf2() && C2Val.isZero()) {
11415 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11416 SDValue ShAmtC =
11417 DAG.getShiftAmountConstant(Val: C1Val.exactLogBase2(), VT, DL);
11418 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Cond, N2: ShAmtC);
11419 }
11420
11421 // select Cond, -1, C --> or (sext Cond), C
11422 if (C1->isAllOnes()) {
11423 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11424 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Cond, N2);
11425 }
11426
11427 // select Cond, C, -1 --> or (sext (not Cond)), C
11428 if (C2->isAllOnes()) {
11429 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11430 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11431 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: NotCond, N2: N1);
11432 }
11433
11434 if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
11435 return V;
11436
11437 return SDValue();
11438}
11439
11440template <class MatchContextClass>
11441static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
11442 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
11443 N->getOpcode() == ISD::VP_SELECT) &&
11444 "Expected a (v)(vp.)select");
11445 SDValue Cond = N->getOperand(Num: 0);
11446 SDValue T = N->getOperand(Num: 1), F = N->getOperand(Num: 2);
11447 EVT VT = N->getValueType(ResNo: 0);
11448 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11449 MatchContextClass matcher(DAG, TLI, N);
11450
11451 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
11452 return SDValue();
11453
11454 // select Cond, Cond, F --> or Cond, F
11455 // select Cond, 1, F --> or Cond, F
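  // These hold because, for 1-bit values, select is equivalent to the
  // corresponding logic op; e.g. (illustrative) select i1 Cond, true, F is
  // exactly (or Cond, F).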
11456 if (Cond == T || isOneOrOneSplat(V: T, /* AllowUndefs */ true))
11457 return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
11458
11459 // select Cond, T, Cond --> and Cond, T
11460 // select Cond, T, 0 --> and Cond, T
11461 if (Cond == F || isNullOrNullSplat(V: F, /* AllowUndefs */ true))
11462 return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
11463
11464 // select Cond, T, 1 --> or (not Cond), T
11465 if (isOneOrOneSplat(V: F, /* AllowUndefs */ true)) {
11466 SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
11467 DAG.getAllOnesConstant(DL: SDLoc(N), VT));
11468 return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
11469 }
11470
11471 // select Cond, 0, F --> and (not Cond), F
11472 if (isNullOrNullSplat(V: T, /* AllowUndefs */ true)) {
11473 SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
11474 DAG.getAllOnesConstant(DL: SDLoc(N), VT));
11475 return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
11476 }
11477
11478 return SDValue();
11479}
11480
11481static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
11482 SDValue N0 = N->getOperand(Num: 0);
11483 SDValue N1 = N->getOperand(Num: 1);
11484 SDValue N2 = N->getOperand(Num: 2);
11485 EVT VT = N->getValueType(ResNo: 0);
11486 if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
11487 return SDValue();
11488
11489 SDValue Cond0 = N0.getOperand(i: 0);
11490 SDValue Cond1 = N0.getOperand(i: 1);
11491 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
11492 if (VT != Cond0.getValueType())
11493 return SDValue();
11494
11495 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
11496 // compare is inverted from that pattern ("Cond0 s> -1").
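  // Illustrative example (for v4i32): (Cond0 s< 0) ? N1 : 0 becomes
  // (sra Cond0, 31) & N1, since the arithmetic shift splats each lane's sign
  // bit across the whole lane.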
11497 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond1))
11498 ; // This is the pattern we are looking for.
11499 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond1))
11500 std::swap(a&: N1, b&: N2);
11501 else
11502 return SDValue();
11503
11504 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
11505 if (isNullOrNullSplat(V: N2)) {
11506 SDLoc DL(N);
11507 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11508 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11509 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: N1);
11510 }
11511
11512 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
11513 if (isAllOnesOrAllOnesSplat(V: N1)) {
11514 SDLoc DL(N);
11515 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11516 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11517 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2);
11518 }
11519
11520 // If we have to invert the sign bit mask, only do that transform if the
11521 // target has a bitwise 'and not' instruction (the invert is free).
11522  // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
11523 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11524 if (isNullOrNullSplat(V: N1) && TLI.hasAndNot(X: N1)) {
11525 SDLoc DL(N);
11526 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11527 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11528 SDValue Not = DAG.getNOT(DL, Val: Sra, VT);
11529 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Not, N2);
11530 }
11531
11532 // TODO: There's another pattern in this family, but it may require
11533 // implementing hasOrNot() to check for profitability:
11534 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
11535
11536 return SDValue();
11537}
11538
11539SDValue DAGCombiner::visitSELECT(SDNode *N) {
11540 SDValue N0 = N->getOperand(Num: 0);
11541 SDValue N1 = N->getOperand(Num: 1);
11542 SDValue N2 = N->getOperand(Num: 2);
11543 EVT VT = N->getValueType(ResNo: 0);
11544 EVT VT0 = N0.getValueType();
11545 SDLoc DL(N);
11546 SDNodeFlags Flags = N->getFlags();
11547
11548 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
11549 return V;
11550
11551 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
11552 return V;
11553
11554 // select (not Cond), N1, N2 -> select Cond, N2, N1
11555 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false)) {
11556 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
11557 SelectOp->setFlags(Flags);
11558 return SelectOp;
11559 }
11560
11561 if (SDValue V = foldSelectOfConstants(N))
11562 return V;
11563
11564 // If we can fold this based on the true/false value, do so.
11565 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
11566 return SDValue(N, 0); // Don't revisit N.
11567
11568 if (VT0 == MVT::i1) {
11569 // The code in this block deals with the following 2 equivalences:
11570 // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
11571 // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
11572 // The target can specify its preferred form with the
11573    // shouldNormalizeToSelectSequence() callback. However, we always transform
11574    // to the right-hand form if the inner select already exists in the DAG,
11575    // and we always transform to the left-hand form if we know that we can
11576    // further optimize the combination of the conditions.
11577 bool normalizeToSequence =
11578 TLI.shouldNormalizeToSelectSequence(Context&: *DAG.getContext(), VT);
11579 // select (and Cond0, Cond1), X, Y
11580 // -> select Cond0, (select Cond1, X, Y), Y
11581 if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
11582 SDValue Cond0 = N0->getOperand(Num: 0);
11583 SDValue Cond1 = N0->getOperand(Num: 1);
11584 SDValue InnerSelect =
11585 DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond1, N2: N1, N3: N2, Flags);
11586 if (normalizeToSequence || !InnerSelect.use_empty())
11587 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0,
11588 N2: InnerSelect, N3: N2, Flags);
11589 // Cleanup on failure.
11590 if (InnerSelect.use_empty())
11591 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
11592 }
11593 // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
11594 if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
11595 SDValue Cond0 = N0->getOperand(Num: 0);
11596 SDValue Cond1 = N0->getOperand(Num: 1);
11597 SDValue InnerSelect = DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(),
11598 N1: Cond1, N2: N1, N3: N2, Flags);
11599 if (normalizeToSequence || !InnerSelect.use_empty())
11600 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0, N2: N1,
11601 N3: InnerSelect, Flags);
11602 // Cleanup on failure.
11603 if (InnerSelect.use_empty())
11604 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
11605 }
11606
11607 // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
11608 if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
11609 SDValue N1_0 = N1->getOperand(Num: 0);
11610 SDValue N1_1 = N1->getOperand(Num: 1);
11611 SDValue N1_2 = N1->getOperand(Num: 2);
11612 if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
11613 // Create the actual and node if we can generate good code for it.
11614 if (!normalizeToSequence) {
11615 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N0, N2: N1_0);
11616 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: And, N2: N1_1,
11617 N3: N2, Flags);
11618 }
11619 // Otherwise see if we can optimize the "and" to a better pattern.
11620 if (SDValue Combined = visitANDLike(N0, N1: N1_0, N)) {
11621 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1_1,
11622 N3: N2, Flags);
11623 }
11624 }
11625 }
11626 // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
11627 if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
11628 SDValue N2_0 = N2->getOperand(Num: 0);
11629 SDValue N2_1 = N2->getOperand(Num: 1);
11630 SDValue N2_2 = N2->getOperand(Num: 2);
11631 if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
11632 // Create the actual or node if we can generate good code for it.
11633 if (!normalizeToSequence) {
11634 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: N0.getValueType(), N1: N0, N2: N2_0);
11635 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Or, N2: N1,
11636 N3: N2_2, Flags);
11637 }
11638 // Otherwise see if we can optimize to a better pattern.
11639 if (SDValue Combined = visitORLike(N0, N1: N2_0, N))
11640 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1,
11641 N3: N2_2, Flags);
11642 }
11643 }
11644 }
11645
11646 // Fold selects based on a setcc into other things, such as min/max/abs.
11647 if (N0.getOpcode() == ISD::SETCC) {
11648 SDValue Cond0 = N0.getOperand(i: 0), Cond1 = N0.getOperand(i: 1);
11649 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
11650
11651 // select (fcmp lt x, y), x, y -> fminnum x, y
11652 // select (fcmp gt x, y), x, y -> fmaxnum x, y
11653 //
11654 // This is OK if we don't care what happens if either operand is a NaN.
11655 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS: N1, RHS: N2, TLI))
11656 if (SDValue FMinMax =
11657 combineMinNumMaxNum(DL, VT, LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC))
11658 return FMinMax;
11659
11660 // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
11661 // This is conservatively limited to pre-legal-operations to give targets
11662 // a chance to reverse the transform if they want to do that. Also, it is
11663 // unlikely that the pattern would be formed late, so it's probably not
11664 // worth going through the other checks.
11665 if (!LegalOperations && TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT) &&
11666 CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
11667 N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(i: 0)) {
11668 auto *C = dyn_cast<ConstantSDNode>(Val: N2.getOperand(i: 1));
11669 auto *NotC = dyn_cast<ConstantSDNode>(Val&: Cond1);
11670 if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
11671 // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
11672 // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
11673 //
11674 // The IR equivalent of this transform would have this form:
11675 // %a = add %x, C
11676 // %c = icmp ugt %x, ~C
11677 // %r = select %c, -1, %a
11678 // =>
11679 // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
11680 // %u0 = extractvalue %u, 0
11681 // %u1 = extractvalue %u, 1
11682 // %r = select %u1, -1, %u0
11683 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT0);
11684 SDValue UAO = DAG.getNode(Opcode: ISD::UADDO, DL, VTList: VTs, N1: Cond0, N2: N2.getOperand(i: 1));
11685 return DAG.getSelect(DL, VT, Cond: UAO.getValue(R: 1), LHS: N1, RHS: UAO.getValue(R: 0));
11686 }
11687 }
11688
11689 if (TLI.isOperationLegal(Op: ISD::SELECT_CC, VT) ||
11690 (!LegalOperations &&
11691 TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))) {
11692 // Any flags available in a select/setcc fold will be on the setcc as they
11693      // migrated from the fcmp.
11694 Flags = N0->getFlags();
11695 SDValue SelectNode = DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT, N1: Cond0, N2: Cond1, N3: N1,
11696 N4: N2, N5: N0.getOperand(i: 2));
11697 SelectNode->setFlags(Flags);
11698 return SelectNode;
11699 }
11700
11701 if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
11702 return NewSel;
11703 }
11704
11705 if (!VT.isVector())
11706 if (SDValue BinOp = foldSelectOfBinops(N))
11707 return BinOp;
11708
11709 if (SDValue R = combineSelectAsExtAnd(Cond: N0, T: N1, F: N2, DL, DAG))
11710 return R;
11711
11712 return SDValue();
11713}
11714
11715// This function assumes all the vselect's arguments are CONCAT_VECTOR
11716// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
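// Illustrative example: for a v4i32 vselect whose condition build_vector is
// <1, 1, 0, 0> and whose operands each concatenate two v2i32 halves, the
// result is concat_vectors(low half of LHS, high half of RHS).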
11717static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
11718 SDLoc DL(N);
11719 SDValue Cond = N->getOperand(Num: 0);
11720 SDValue LHS = N->getOperand(Num: 1);
11721 SDValue RHS = N->getOperand(Num: 2);
11722 EVT VT = N->getValueType(ResNo: 0);
11723 int NumElems = VT.getVectorNumElements();
11724 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
11725 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
11726 Cond.getOpcode() == ISD::BUILD_VECTOR);
11727
11728  // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
11729  // about binary ones here.
11730 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
11731 return SDValue();
11732
11733 // We're sure we have an even number of elements due to the
11734 // concat_vectors we have as arguments to vselect.
11735  // Skip BV elements until we find one that's not an UNDEF. After we find a
11736  // non-UNDEF element, keep looping up to half the length of the BV and check
11737  // that all the non-undef elements are the same value.
11738 ConstantSDNode *BottomHalf = nullptr;
11739 for (int i = 0; i < NumElems / 2; ++i) {
11740 if (Cond->getOperand(Num: i)->isUndef())
11741 continue;
11742
11743 if (BottomHalf == nullptr)
11744 BottomHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
11745 else if (Cond->getOperand(Num: i).getNode() != BottomHalf)
11746 return SDValue();
11747 }
11748
11749 // Do the same for the second half of the BuildVector
11750 ConstantSDNode *TopHalf = nullptr;
11751 for (int i = NumElems / 2; i < NumElems; ++i) {
11752 if (Cond->getOperand(Num: i)->isUndef())
11753 continue;
11754
11755 if (TopHalf == nullptr)
11756 TopHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
11757 else if (Cond->getOperand(Num: i).getNode() != TopHalf)
11758 return SDValue();
11759 }
11760
11761 assert(TopHalf && BottomHalf &&
11762 "One half of the selector was all UNDEFs and the other was all the "
11763 "same value. This should have been addressed before this function.");
11764 return DAG.getNode(
11765 Opcode: ISD::CONCAT_VECTORS, DL, VT,
11766 N1: BottomHalf->isZero() ? RHS->getOperand(Num: 0) : LHS->getOperand(Num: 0),
11767 N2: TopHalf->isZero() ? RHS->getOperand(Num: 1) : LHS->getOperand(Num: 1));
11768}
11769
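// Try to fold a splat (or splat plus non-splat add) index into the scalar base
// pointer so the gather/scatter can use a simpler index. Illustrative example:
// a zero base with index (splat p) can instead use base p with an all-zero
// index.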
11770bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
11771 SelectionDAG &DAG, const SDLoc &DL) {
11772
11773 // Only perform the transformation when existing operands can be reused.
11774 if (IndexIsScaled)
11775 return false;
11776
11777 if (!isNullConstant(V: BasePtr) && !Index.hasOneUse())
11778 return false;
11779
11780 EVT VT = BasePtr.getValueType();
11781
11782 if (SDValue SplatVal = DAG.getSplatValue(V: Index);
11783 SplatVal && !isNullConstant(V: SplatVal) &&
11784 SplatVal.getValueType() == VT) {
11785 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11786 Index = DAG.getSplat(VT: Index.getValueType(), DL, Op: DAG.getConstant(Val: 0, DL, VT));
11787 return true;
11788 }
11789
11790 if (Index.getOpcode() != ISD::ADD)
11791 return false;
11792
11793 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 0));
11794 SplatVal && SplatVal.getValueType() == VT) {
11795 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11796 Index = Index.getOperand(i: 1);
11797 return true;
11798 }
11799 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 1));
11800 SplatVal && SplatVal.getValueType() == VT) {
11801 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11802 Index = Index.getOperand(i: 0);
11803 return true;
11804 }
11805 return false;
11806}
11807
11808// Fold sext/zext of index into index type.
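// Illustrative example: a gather whose index is (zext i32-index to i64) can
// instead use the narrower i32 index directly with an unsigned index type,
// when the target reports the extend as removable (shouldRemoveExtendFromGSIndex).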
11809bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
11810 SelectionDAG &DAG) {
11811 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11812
11813 // It's always safe to look through zero extends.
11814 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
11815 if (TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
11816 IndexType = ISD::UNSIGNED_SCALED;
11817 Index = Index.getOperand(i: 0);
11818 return true;
11819 }
11820 if (ISD::isIndexTypeSigned(IndexType)) {
11821 IndexType = ISD::UNSIGNED_SCALED;
11822 return true;
11823 }
11824 }
11825
11826 // It's only safe to look through sign extends when Index is signed.
11827 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11828 ISD::isIndexTypeSigned(IndexType) &&
11829 TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
11830 Index = Index.getOperand(i: 0);
11831 return true;
11832 }
11833
11834 return false;
11835}
11836
11837SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11838 VPScatterSDNode *MSC = cast<VPScatterSDNode>(Val: N);
11839 SDValue Mask = MSC->getMask();
11840 SDValue Chain = MSC->getChain();
11841 SDValue Index = MSC->getIndex();
11842 SDValue Scale = MSC->getScale();
11843 SDValue StoreVal = MSC->getValue();
11844 SDValue BasePtr = MSC->getBasePtr();
11845 SDValue VL = MSC->getVectorLength();
11846 ISD::MemIndexType IndexType = MSC->getIndexType();
11847 SDLoc DL(N);
11848
11849 // Zap scatters with a zero mask.
11850 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11851 return Chain;
11852
11853 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
11854 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11855 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11856 DL, Ops, MSC->getMemOperand(), IndexType);
11857 }
11858
11859 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
11860 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11861 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11862 DL, Ops, MSC->getMemOperand(), IndexType);
11863 }
11864
11865 return SDValue();
11866}
11867
11868SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11869 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Val: N);
11870 SDValue Mask = MSC->getMask();
11871 SDValue Chain = MSC->getChain();
11872 SDValue Index = MSC->getIndex();
11873 SDValue Scale = MSC->getScale();
11874 SDValue StoreVal = MSC->getValue();
11875 SDValue BasePtr = MSC->getBasePtr();
11876 ISD::MemIndexType IndexType = MSC->getIndexType();
11877 SDLoc DL(N);
11878
11879 // Zap scatters with a zero mask.
11880 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11881 return Chain;
11882
11883 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
11884 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11885 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11886 DL, Ops, MSC->getMemOperand(), IndexType,
11887 MSC->isTruncatingStore());
11888 }
11889
11890 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
11891 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11892 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11893 DL, Ops, MSC->getMemOperand(), IndexType,
11894 MSC->isTruncatingStore());
11895 }
11896
11897 return SDValue();
11898}
11899
11900SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11901 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(Val: N);
11902 SDValue Mask = MST->getMask();
11903 SDValue Chain = MST->getChain();
11904 SDValue Value = MST->getValue();
11905 SDValue Ptr = MST->getBasePtr();
11906 SDLoc DL(N);
11907
11908 // Zap masked stores with a zero mask.
11909 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11910 return Chain;
11911
11912 // Remove a masked store if base pointers and masks are equal.
11913 if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Val&: Chain)) {
11914 if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
11915 MST1->isSimple() && MST1->getBasePtr() == Ptr &&
11916 !MST->getBasePtr().isUndef() &&
11917 ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
11918 MST1->getMemoryVT().getStoreSize()) ||
11919 ISD::isConstantSplatVectorAllOnes(N: Mask.getNode())) &&
11920 TypeSize::isKnownLE(LHS: MST1->getMemoryVT().getStoreSize(),
11921 RHS: MST->getMemoryVT().getStoreSize())) {
11922 CombineTo(N: MST1, Res: MST1->getChain());
11923 if (N->getOpcode() != ISD::DELETED_NODE)
11924 AddToWorklist(N);
11925 return SDValue(N, 0);
11926 }
11927 }
11928
11929  // If this is a masked store with an all-ones mask, use an unmasked store.
11930 // FIXME: Can we do this for indexed, compressing, or truncating stores?
11931 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MST->isUnindexed() &&
11932 !MST->isCompressingStore() && !MST->isTruncatingStore())
11933 return DAG.getStore(Chain: MST->getChain(), dl: SDLoc(N), Val: MST->getValue(),
11934 Ptr: MST->getBasePtr(), PtrInfo: MST->getPointerInfo(),
11935 Alignment: MST->getOriginalAlign(), MMOFlags: MachineMemOperand::MOStore,
11936 AAInfo: MST->getAAInfo());
11937
11938 // Try transforming N to an indexed store.
11939 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11940 return SDValue(N, 0);
11941
11942 if (MST->isTruncatingStore() && MST->isUnindexed() &&
11943 Value.getValueType().isInteger() &&
11944 (!isa<ConstantSDNode>(Val: Value) ||
11945 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
11946 APInt TruncDemandedBits =
11947 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
11948 loBitsSet: MST->getMemoryVT().getScalarSizeInBits());
11949
11950 // See if we can simplify the operation with
11951 // SimplifyDemandedBits, which only works if the value has a single use.
11952 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
11953 // Re-visit the store if anything changed and the store hasn't been merged
11954      // with another node (N is deleted). SimplifyDemandedBits will add Value's
11955 // node back to the worklist if necessary, but we also need to re-visit
11956 // the Store node itself.
11957 if (N->getOpcode() != ISD::DELETED_NODE)
11958 AddToWorklist(N);
11959 return SDValue(N, 0);
11960 }
11961 }
11962
11963 // If this is a TRUNC followed by a masked store, fold this into a masked
11964 // truncating store. We can do this even if this is already a masked
11965 // truncstore.
11966  // TODO: Try combining to a masked compress store if possible.
11967 if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
11968 MST->isUnindexed() && !MST->isCompressingStore() &&
11969 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
11970 MemVT: MST->getMemoryVT(), LegalOnly: LegalOperations)) {
11971 auto Mask = TLI.promoteTargetBoolean(DAG, Bool: MST->getMask(),
11972 ValVT: Value.getOperand(i: 0).getValueType());
11973 return DAG.getMaskedStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Base: Ptr,
11974 Offset: MST->getOffset(), Mask, MemVT: MST->getMemoryVT(),
11975 MMO: MST->getMemOperand(), AM: MST->getAddressingMode(),
11976 /*IsTruncating=*/true);
11977 }
11978
11979 return SDValue();
11980}
11981
11982SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
11983 auto *SST = cast<VPStridedStoreSDNode>(Val: N);
11984 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
11985 // Combine strided stores with unit-stride to a regular VP store.
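  // e.g. (illustrative) a strided store of f32 elements with a stride of 4
  // bytes writes consecutive elements, so it is equivalent to a plain VP store.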
11986 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SST->getStride());
11987 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
11988 return DAG.getStoreVP(Chain: SST->getChain(), dl: SDLoc(N), Val: SST->getValue(),
11989 Ptr: SST->getBasePtr(), Offset: SST->getOffset(), Mask: SST->getMask(),
11990 EVL: SST->getVectorLength(), MemVT: SST->getMemoryVT(),
11991 MMO: SST->getMemOperand(), AM: SST->getAddressingMode(),
11992 IsTruncating: SST->isTruncatingStore(), IsCompressing: SST->isCompressingStore());
11993 }
11994 return SDValue();
11995}
11996
11997SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
11998 VPGatherSDNode *MGT = cast<VPGatherSDNode>(Val: N);
11999 SDValue Mask = MGT->getMask();
12000 SDValue Chain = MGT->getChain();
12001 SDValue Index = MGT->getIndex();
12002 SDValue Scale = MGT->getScale();
12003 SDValue BasePtr = MGT->getBasePtr();
12004 SDValue VL = MGT->getVectorLength();
12005 ISD::MemIndexType IndexType = MGT->getIndexType();
12006 SDLoc DL(N);
12007
12008 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
12009 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12010 return DAG.getGatherVP(
12011 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12012 Ops, MGT->getMemOperand(), IndexType);
12013 }
12014
12015 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
12016 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12017 return DAG.getGatherVP(
12018 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12019 Ops, MGT->getMemOperand(), IndexType);
12020 }
12021
12022 return SDValue();
12023}
12024
12025SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12026 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Val: N);
12027 SDValue Mask = MGT->getMask();
12028 SDValue Chain = MGT->getChain();
12029 SDValue Index = MGT->getIndex();
12030 SDValue Scale = MGT->getScale();
12031 SDValue PassThru = MGT->getPassThru();
12032 SDValue BasePtr = MGT->getBasePtr();
12033 ISD::MemIndexType IndexType = MGT->getIndexType();
12034 SDLoc DL(N);
12035
12036 // Zap gathers with a zero mask.
12037 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12038 return CombineTo(N, Res0: PassThru, Res1: MGT->getChain());
12039
12040 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
12041 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12042 return DAG.getMaskedGather(
12043 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12044 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12045 }
12046
12047 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
12048 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12049 return DAG.getMaskedGather(
12050 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12051 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12052 }
12053
12054 return SDValue();
12055}
12056
12057SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12058 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(Val: N);
12059 SDValue Mask = MLD->getMask();
12060 SDLoc DL(N);
12061
12062 // Zap masked loads with a zero mask.
12063 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12064 return CombineTo(N, Res0: MLD->getPassThru(), Res1: MLD->getChain());
12065
12066  // If this is a masked load with an all-ones mask, we can use an unmasked load.
12067 // FIXME: Can we do this for indexed, expanding, or extending loads?
12068 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MLD->isUnindexed() &&
12069 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12070 SDValue NewLd = DAG.getLoad(
12071 VT: N->getValueType(ResNo: 0), dl: SDLoc(N), Chain: MLD->getChain(), Ptr: MLD->getBasePtr(),
12072 PtrInfo: MLD->getPointerInfo(), Alignment: MLD->getOriginalAlign(),
12073 MMOFlags: MachineMemOperand::MOLoad, AAInfo: MLD->getAAInfo(), Ranges: MLD->getRanges());
12074 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
12075 }
12076
12077 // Try transforming N to an indexed load.
12078 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12079 return SDValue(N, 0);
12080
12081 return SDValue();
12082}
12083
12084SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12085 auto *SLD = cast<VPStridedLoadSDNode>(Val: N);
12086 EVT EltVT = SLD->getValueType(ResNo: 0).getVectorElementType();
12087 // Combine strided loads with unit-stride to a regular VP load.
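  // e.g. (illustrative) a strided load of i16 elements with a stride of 2
  // bytes reads consecutive elements, so it is equivalent to a plain VP load.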
12088 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SLD->getStride());
12089 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12090 SDValue NewLd = DAG.getLoadVP(
12091 AM: SLD->getAddressingMode(), ExtType: SLD->getExtensionType(), VT: SLD->getValueType(ResNo: 0),
12092 dl: SDLoc(N), Chain: SLD->getChain(), Ptr: SLD->getBasePtr(), Offset: SLD->getOffset(),
12093 Mask: SLD->getMask(), EVL: SLD->getVectorLength(), MemVT: SLD->getMemoryVT(),
12094 MMO: SLD->getMemOperand(), IsExpanding: SLD->isExpandingLoad());
12095 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
12096 }
12097 return SDValue();
12098}
12099
12100/// A vector select of 2 constant vectors can be simplified to math/logic to
12101/// avoid a variable select instruction and possibly avoid constant loads.
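/// For example (illustrative): vselect <4 x i1> Cond, <7,7,7,7>, <6,6,6,6>
/// can become add (zext Cond to <4 x i32>), <6,6,6,6>.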
12102SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12103 SDValue Cond = N->getOperand(Num: 0);
12104 SDValue N1 = N->getOperand(Num: 1);
12105 SDValue N2 = N->getOperand(Num: 2);
12106 EVT VT = N->getValueType(ResNo: 0);
12107 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
12108 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
12109 !ISD::isBuildVectorOfConstantSDNodes(N: N1.getNode()) ||
12110 !ISD::isBuildVectorOfConstantSDNodes(N: N2.getNode()))
12111 return SDValue();
12112
12113 // Check if we can use the condition value to increment/decrement a single
12114 // constant value. This simplifies a select to an add and removes a constant
12115 // load/materialization from the general case.
12116 bool AllAddOne = true;
12117 bool AllSubOne = true;
12118 unsigned Elts = VT.getVectorNumElements();
12119 for (unsigned i = 0; i != Elts; ++i) {
12120 SDValue N1Elt = N1.getOperand(i);
12121 SDValue N2Elt = N2.getOperand(i);
12122 if (N1Elt.isUndef() || N2Elt.isUndef())
12123 continue;
12124 if (N1Elt.getValueType() != N2Elt.getValueType())
12125 continue;
12126
12127 const APInt &C1 = N1Elt->getAsAPIntVal();
12128 const APInt &C2 = N2Elt->getAsAPIntVal();
12129 if (C1 != C2 + 1)
12130 AllAddOne = false;
12131 if (C1 != C2 - 1)
12132 AllSubOne = false;
12133 }
12134
12135 // Further simplifications for the extra-special cases where the constants are
12136 // all 0 or all -1 should be implemented as folds of these patterns.
12137 SDLoc DL(N);
12138 if (AllAddOne || AllSubOne) {
12139 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
12140 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
12141 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
12142 SDValue ExtendedCond = DAG.getNode(Opcode: ExtendOpcode, DL, VT, Operand: Cond);
12143 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ExtendedCond, N2);
12144 }
12145
12146 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
12147 APInt Pow2C;
12148 if (ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: Pow2C) && Pow2C.isPowerOf2() &&
12149 isNullOrNullSplat(V: N2)) {
12150 SDValue ZextCond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
12151 SDValue ShAmtC = DAG.getConstant(Val: Pow2C.exactLogBase2(), DL, VT);
12152 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ZextCond, N2: ShAmtC);
12153 }
12154
12155 if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
12156 return V;
12157
12158 // The general case for select-of-constants:
12159 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
12160 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
12161 // leave that to a machine-specific pass.
12162 return SDValue();
12163}
12164
12165SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
12166 if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DAG))
12167 return V;
12168
12169 return SDValue();
12170}
12171
12172SDValue DAGCombiner::visitVSELECT(SDNode *N) {
12173 SDValue N0 = N->getOperand(Num: 0);
12174 SDValue N1 = N->getOperand(Num: 1);
12175 SDValue N2 = N->getOperand(Num: 2);
12176 EVT VT = N->getValueType(ResNo: 0);
12177 SDLoc DL(N);
12178
12179 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12180 return V;
12181
12182 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
12183 return V;
12184
12185 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
12186 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false))
12187 return DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
12188
12189 // Canonicalize integer abs.
12190 // vselect (setg[te] X, 0), X, -X ->
12191 // vselect (setgt X, -1), X, -X ->
12192 // vselect (setl[te] X, 0), -X, X ->
12193 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
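  // This works because Y is 0 when X is non-negative and -1 when X is
  // negative, so (X + Y) ^ Y negates X exactly when it is negative
  // (two's-complement negation of X is (X - 1) ^ -1).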
12194 if (N0.getOpcode() == ISD::SETCC) {
12195 SDValue LHS = N0.getOperand(i: 0), RHS = N0.getOperand(i: 1);
12196 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
12197 bool isAbs = false;
12198 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(N: RHS.getNode());
12199
12200 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
12201 (ISD::isBuildVectorAllOnes(N: RHS.getNode()) && CC == ISD::SETGT)) &&
12202 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(i: 1))
12203 isAbs = ISD::isBuildVectorAllZeros(N: N2.getOperand(i: 0).getNode());
12204 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
12205 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(i: 1))
12206 isAbs = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
12207
12208 if (isAbs) {
12209 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
12210 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: LHS);
12211
12212 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: LHS,
12213 N2: DAG.getConstant(Val: VT.getScalarSizeInBits() - 1,
12214 DL, VT: getShiftAmountTy(LHSTy: VT)));
12215 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LHS, N2: Shift);
12216 AddToWorklist(N: Shift.getNode());
12217 AddToWorklist(N: Add.getNode());
12218 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Add, N2: Shift);
12219 }
12220
12221    // vselect (fcmp lt x, y), x, y -> fminnum x, y
12222    // vselect (fcmp gt x, y), x, y -> fmaxnum x, y
12223 //
12224 // This is OK if we don't care about what happens if either operand is a
12225 // NaN.
12226 //
12227 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
12228 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, True: N1, False: N2, CC))
12229 return FMinMax;
12230 }
12231
12232 if (SDValue S = PerformMinMaxFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
12233 return S;
12234 if (SDValue S = PerformUMinFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
12235 return S;
12236
12237 // If this select has a condition (setcc) with narrower operands than the
12238 // select, try to widen the compare to match the select width.
12239 // TODO: This should be extended to handle any constant.
12240 // TODO: This could be extended to handle non-loading patterns, but that
12241 // requires thorough testing to avoid regressions.
12242 if (isNullOrNullSplat(V: RHS)) {
12243 EVT NarrowVT = LHS.getValueType();
12244 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
12245 EVT SetCCVT = getSetCCResultType(VT: LHS.getValueType());
12246 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
12247 unsigned WideWidth = WideVT.getScalarSizeInBits();
12248 bool IsSigned = isSignedIntSetCC(Code: CC);
12249 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12250 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
12251 SetCCWidth != 1 && SetCCWidth < WideWidth &&
12252 TLI.isLoadExtLegalOrCustom(ExtType: LoadExtOpcode, ValVT: WideVT, MemVT: NarrowVT) &&
12253 TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: WideVT)) {
12254 // Both compare operands can be widened for free. The LHS can use an
12255 // extended load, and the RHS is a constant:
12256 // vselect (ext (setcc load(X), C)), N1, N2 -->
12257 // vselect (setcc extload(X), C'), N1, N2
12258 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12259 SDValue WideLHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: LHS);
12260 SDValue WideRHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: RHS);
12261 EVT WideSetCCVT = getSetCCResultType(VT: WideVT);
12262 SDValue WideSetCC = DAG.getSetCC(DL, VT: WideSetCCVT, LHS: WideLHS, RHS: WideRHS, Cond: CC);
12263 return DAG.getSelect(DL, VT: N1.getValueType(), Cond: WideSetCC, LHS: N1, RHS: N2);
12264 }
12265 }
12266
12267 // Match VSELECTs with absolute difference patterns.
12268 // (vselect (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12269 // (vselect (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12270 // (vselect (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12271 // (vselect (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12272 if (N1.getOpcode() == ISD::SUB && N2.getOpcode() == ISD::SUB &&
12273 N1.getOperand(i: 0) == N2.getOperand(i: 1) &&
12274 N1.getOperand(i: 1) == N2.getOperand(i: 0)) {
12275 bool IsSigned = isSignedIntSetCC(Code: CC);
12276 unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12277 if (hasOperation(Opcode: ABDOpc, VT)) {
12278 switch (CC) {
12279 case ISD::SETGT:
12280 case ISD::SETGE:
12281 case ISD::SETUGT:
12282 case ISD::SETUGE:
12283 if (LHS == N1.getOperand(i: 0) && RHS == N1.getOperand(i: 1))
12284 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12285 break;
12286 case ISD::SETLT:
12287 case ISD::SETLE:
12288 case ISD::SETULT:
12289 case ISD::SETULE:
12290        if (RHS == N1.getOperand(i: 0) && LHS == N1.getOperand(i: 1))
12291 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12292 break;
12293 default:
12294 break;
12295 }
12296 }
12297 }
12298
12299 // Match VSELECTs into add with unsigned saturation.
12300 if (hasOperation(Opcode: ISD::UADDSAT, VT)) {
12301 // Check if one of the arms of the VSELECT is vector with all bits set.
12302 // If it's on the left side invert the predicate to simplify logic below.
12303 SDValue Other;
12304 ISD::CondCode SatCC = CC;
12305 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode())) {
12306 Other = N2;
12307 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
12308 } else if (ISD::isConstantSplatVectorAllOnes(N: N2.getNode())) {
12309 Other = N1;
12310 }
12311
12312 if (Other && Other.getOpcode() == ISD::ADD) {
12313 SDValue CondLHS = LHS, CondRHS = RHS;
12314 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
12315
12316 // Canonicalize condition operands.
12317 if (SatCC == ISD::SETUGE) {
12318 std::swap(a&: CondLHS, b&: CondRHS);
12319 SatCC = ISD::SETULE;
12320 }
12321
12322 // We can test against either of the addition operands.
12323 // x <= x+y ? x+y : ~0 --> uaddsat x, y
12324 // x+y >= x ? x+y : ~0 --> uaddsat x, y
12325 if (SatCC == ISD::SETULE && Other == CondRHS &&
12326 (OpLHS == CondLHS || OpRHS == CondLHS))
12327 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12328
12329 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
12330 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12331 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
12332 CondLHS == OpLHS) {
12333 // If the RHS is a constant we have to reverse the const
12334 // canonicalization.
12335 // x >= ~C ? x+C : ~0 --> uaddsat x, C
12336 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12337 return Cond->getAPIntValue() == ~Op->getAPIntValue();
12338 };
12339 if (SatCC == ISD::SETULE &&
12340 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUADDSAT))
12341 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12342 }
12343 }
12344 }
12345
12346 // Match VSELECTs into sub with unsigned saturation.
12347 if (hasOperation(Opcode: ISD::USUBSAT, VT)) {
12348 // Check if one of the arms of the VSELECT is a zero vector. If it's on
12349 // the left side invert the predicate to simplify logic below.
12350 SDValue Other;
12351 ISD::CondCode SatCC = CC;
12352 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
12353 Other = N2;
12354 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
12355 } else if (ISD::isConstantSplatVectorAllZeros(N: N2.getNode())) {
12356 Other = N1;
12357 }
12358
12359 // zext(x) >= y ? trunc(zext(x) - y) : 0
12360 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12361 // zext(x) > y ? trunc(zext(x) - y) : 0
12362 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12363 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
12364 Other.getOperand(i: 0).getOpcode() == ISD::SUB &&
12365 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
12366 SDValue OpLHS = Other.getOperand(i: 0).getOperand(i: 0);
12367 SDValue OpRHS = Other.getOperand(i: 0).getOperand(i: 1);
12368 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
12369 if (SDValue R = getTruncatedUSUBSAT(DstVT: VT, SrcVT: LHS.getValueType(), LHS, RHS,
12370 DAG, DL))
12371 return R;
12372 }
12373
12374 if (Other && Other.getNumOperands() == 2) {
12375 SDValue CondRHS = RHS;
12376 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
12377
12378 if (OpLHS == LHS) {
12379 // Look for a general sub with unsigned saturation first.
12380 // x >= y ? x-y : 0 --> usubsat x, y
12381 // x > y ? x-y : 0 --> usubsat x, y
12382 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
12383 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
12384 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12385
12386 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12387 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12388 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
12389 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12390 // If the RHS is a constant we have to reverse the const
12391 // canonicalization.
12392 // x > C-1 ? x+-C : 0 --> usubsat x, C
12393 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12394 return (!Op && !Cond) ||
12395 (Op && Cond &&
12396 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
12397 };
12398 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
12399 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUSUBSAT,
12400 /*AllowUndefs*/ true)) {
12401 OpRHS = DAG.getNegative(Val: OpRHS, DL, VT);
12402 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12403 }
12404
12405 // Another special case: If C was a sign bit, the sub has been
12406 // canonicalized into a xor.
12407 // FIXME: Would it be better to use computeKnownBits to
12408 // determine whether it's safe to decanonicalize the xor?
12409 // x s< 0 ? x^C : 0 --> usubsat x, C
12410 APInt SplatValue;
12411 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
12412 ISD::isConstantSplatVector(N: OpRHS.getNode(), SplatValue) &&
12413 ISD::isConstantSplatVectorAllZeros(N: CondRHS.getNode()) &&
12414 SplatValue.isSignMask()) {
12415 // Note that we have to rebuild the RHS constant here to
12416 // ensure we don't rely on particular values of undef lanes.
12417 OpRHS = DAG.getConstant(Val: SplatValue, DL, VT);
12418 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12419 }
12420 }
12421 }
12422 }
12423 }
12424 }
12425 }
12426
12427 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
12428 return SDValue(N, 0); // Don't revisit N.
12429
12430 // Fold (vselect all_ones, N1, N2) -> N1
12431 if (ISD::isConstantSplatVectorAllOnes(N: N0.getNode()))
12432 return N1;
12433 // Fold (vselect all_zeros, N1, N2) -> N2
12434 if (ISD::isConstantSplatVectorAllZeros(N: N0.getNode()))
12435 return N2;
12436
12437  // The ConvertSelectToConcatVector function assumes that both of the above
12438  // checks for (vselect (build_vector all{ones,zeros}) ...) have already been
12439  // made and addressed.
12440 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
12441 N2.getOpcode() == ISD::CONCAT_VECTORS &&
12442 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
12443 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
12444 return CV;
12445 }
12446
12447 if (SDValue V = foldVSelectOfConstants(N))
12448 return V;
12449
12450 if (hasOperation(Opcode: ISD::SRA, VT))
12451 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
12452 return V;
12453
12454 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
12455 return SDValue(N, 0);
12456
12457 return SDValue();
12458}
12459
12460SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
12461 SDValue N0 = N->getOperand(Num: 0);
12462 SDValue N1 = N->getOperand(Num: 1);
12463 SDValue N2 = N->getOperand(Num: 2);
12464 SDValue N3 = N->getOperand(Num: 3);
12465 SDValue N4 = N->getOperand(Num: 4);
12466 ISD::CondCode CC = cast<CondCodeSDNode>(Val&: N4)->get();
12467
12468 // fold select_cc lhs, rhs, x, x, cc -> x
12469 if (N2 == N3)
12470 return N2;
12471
12472 // select_cc bool, 0, x, y, seteq -> select bool, y, x
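// (When bool == 0 the select_cc picks x and otherwise picks y, so a plain
// select on bool expresses the same choice with the operands swapped.)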
12473 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
12474 isNullConstant(N1))
12475 return DAG.getSelect(DL: SDLoc(N), VT: N2.getValueType(), Cond: N0, LHS: N3, RHS: N2);
12476
12477 // Determine if the condition we're dealing with is constant
12478 if (SDValue SCC = SimplifySetCC(VT: getSetCCResultType(VT: N0.getValueType()), N0, N1,
12479 Cond: CC, DL: SDLoc(N), foldBooleans: false)) {
12480 AddToWorklist(N: SCC.getNode());
12481
12482 // cond always true -> true val
12483 // cond always false -> false val
12484 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val: SCC.getNode()))
12485 return SCCC->isZero() ? N3 : N2;
12486
12487 // When the condition is UNDEF, just return the first operand. This is
12488 // consistent with DAG creation: no setcc node is created in that case.
12489 if (SCC->isUndef())
12490 return N2;
12491
12492 // Fold to a simpler select_cc
12493 if (SCC.getOpcode() == ISD::SETCC) {
12494 SDValue SelectOp = DAG.getNode(
12495 Opcode: ISD::SELECT_CC, DL: SDLoc(N), VT: N2.getValueType(), N1: SCC.getOperand(i: 0),
12496 N2: SCC.getOperand(i: 1), N3: N2, N4: N3, N5: SCC.getOperand(i: 2));
12497 SelectOp->setFlags(SCC->getFlags());
12498 return SelectOp;
12499 }
12500 }
12501
12502 // If we can fold this based on the true/false value, do so.
12503 if (SimplifySelectOps(SELECT: N, LHS: N2, RHS: N3))
12504 return SDValue(N, 0); // Don't revisit N.
12505
12506 // fold select_cc into other things, such as min/max/abs
12507 return SimplifySelectCC(DL: SDLoc(N), N0, N1, N2, N3, CC);
12508}
12509
12510SDValue DAGCombiner::visitSETCC(SDNode *N) {
12511 // setcc is very commonly used as an argument to brcond. This pattern
12512 // also lends itself to numerous combines and, as a result, it is desirable
12513 // to keep the argument to a brcond as a setcc as much as possible.
12514 bool PreferSetCC =
12515 N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
12516
12517 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N->getOperand(Num: 2))->get();
12518 EVT VT = N->getValueType(ResNo: 0);
12519 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
12520
12521 SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL: SDLoc(N), foldBooleans: !PreferSetCC);
12522
12523 if (Combined) {
12524 // If we prefer to have a setcc and the combined node is not one, we'll try
12525 // our best to recreate one using rebuildSetCC.
12526 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
12527 SDValue NewSetCC = rebuildSetCC(N: Combined);
12528
12529 // We don't have anything interesting to combine to.
12530 if (NewSetCC.getNode() == N)
12531 return SDValue();
12532
12533 if (NewSetCC)
12534 return NewSetCC;
12535 }
12536 return Combined;
12537 }
12538
12539 // Optimize
12540 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
12541 // or
12542 // 2) (icmp eq/ne X, (rotate X, C1))
12543 // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
12544 // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`)
12545 // Then:
12546 // If C1 is a power of 2, then the rotate and shift+and versions are
12547 // equivalent, so we can interchange them depending on target preference.
12548 // Otherwise, if we have the shift+and version we can interchange srl/shl,
12549 // which in turn affects the constant C0. We can use this to get better
12550 // constants, again determined by target preference.
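// For example (an illustrative case, not exhaustive): with a 64-bit x and a
// shift amount of 32,
//   (x & 0xFFFFFFFF) == (x >> 32)
//   (x & 0xFFFFFFFF00000000) == (x << 32)
//   x == rotl(x, 32)
// all test whether the low and high halves of x are equal, so we are free to
// pick whichever form the target prefers.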
12551 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
12552 auto IsAndWithShift = [](SDValue A, SDValue B) {
12553 return A.getOpcode() == ISD::AND &&
12554 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
12555 A.getOperand(i: 0) == B.getOperand(i: 0);
12556 };
12557 auto IsRotateWithOp = [](SDValue A, SDValue B) {
12558 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
12559 B.getOperand(i: 0) == A;
12560 };
12561 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
12562 bool IsRotate = false;
12563
12564 // Find either shift+and or rotate pattern.
12565 if (IsAndWithShift(N0, N1)) {
12566 AndOrOp = N0;
12567 ShiftOrRotate = N1;
12568 } else if (IsAndWithShift(N1, N0)) {
12569 AndOrOp = N1;
12570 ShiftOrRotate = N0;
12571 } else if (IsRotateWithOp(N0, N1)) {
12572 IsRotate = true;
12573 AndOrOp = N0;
12574 ShiftOrRotate = N1;
12575 } else if (IsRotateWithOp(N1, N0)) {
12576 IsRotate = true;
12577 AndOrOp = N1;
12578 ShiftOrRotate = N0;
12579 }
12580
12581 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
12582 (IsRotate || AndOrOp.hasOneUse())) {
12583 EVT OpVT = N0.getValueType();
12584 // Get the constant shift/rotate amount and possibly the mask (if it's the
12585 // shift+and variant).
12586 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
12587 ConstantSDNode *CNode = isConstOrConstSplat(N: Op, /*AllowUndefs*/ false,
12588 /*AllowTrunc*/ AllowTruncation: false);
12589 if (CNode == nullptr)
12590 return std::nullopt;
12591 return CNode->getAPIntValue();
12592 };
12593 std::optional<APInt> AndCMask =
12594 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(i: 1));
12595 std::optional<APInt> ShiftCAmt =
12596 GetAPIntValue(ShiftOrRotate.getOperand(i: 1));
12597 unsigned NumBits = OpVT.getScalarSizeInBits();
12598
12599 // We found constants.
12600 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(RHS: NumBits)) {
12601 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
12602 // Check that the constants meet the constraints.
12603 bool CanTransform = IsRotate;
12604 if (!CanTransform) {
12605 // Check that the mask and shift amount complement each other.
12606 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
12607 // Check that we are comparing all bits
12608 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
12609 // Check that the and mask is correct for the shift
12610 CanTransform &=
12611 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
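// (Assumed example for the SHL variant: NumBits = 32, *ShiftCAmt = 8 and
// *AndCMask = 0xFFFFFF00 pass all three checks, since popcount(~mask) == 8,
// 8 + 24 == 32, and ~mask is a low mask.)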
12612 }
12613
12614 // See if target prefers another shift/rotate opcode.
12615 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
12616 VT: OpVT, ShiftOpc, MayTransformRotate: ShiftCAmt->isPowerOf2(), ShiftOrRotateAmt: *ShiftCAmt, AndMask: AndCMask);
12617 // Transform is valid and we have a new preference.
12618 if (CanTransform && NewShiftOpc != ShiftOpc) {
12619 SDLoc DL(N);
12620 SDValue NewShiftOrRotate =
12621 DAG.getNode(Opcode: NewShiftOpc, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
12622 N2: ShiftOrRotate.getOperand(i: 1));
12623 SDValue NewAndOrOp = SDValue();
12624
12625 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
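// Rebuild the mask for the new shift direction. Illustrative (assumed)
// values: with NumBits = 32 and a shift amount of 8, the new SHL form
// masks with the high 24 bits and the new SRL form with the low 24 bits.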
12626 APInt NewMask =
12627 NewShiftOpc == ISD::SHL
12628 ? APInt::getHighBitsSet(numBits: NumBits,
12629 hiBitsSet: NumBits - ShiftCAmt->getZExtValue())
12630 : APInt::getLowBitsSet(numBits: NumBits,
12631 loBitsSet: NumBits - ShiftCAmt->getZExtValue());
12632 NewAndOrOp =
12633 DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
12634 N2: DAG.getConstant(Val: NewMask, DL, VT: OpVT));
12635 } else {
12636 NewAndOrOp = ShiftOrRotate.getOperand(i: 0);
12637 }
12638
12639 return DAG.getSetCC(DL, VT, LHS: NewAndOrOp, RHS: NewShiftOrRotate, Cond);
12640 }
12641 }
12642 }
12643 }
12644 return SDValue();
12645}
12646
12647SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
12648 SDValue LHS = N->getOperand(Num: 0);
12649 SDValue RHS = N->getOperand(Num: 1);
12650 SDValue Carry = N->getOperand(Num: 2);
12651 SDValue Cond = N->getOperand(Num: 3);
12652
12653 // If Carry is false, fold to a regular SETCC.
12654 if (isNullConstant(V: Carry))
12655 return DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N), VTList: N->getVTList(), N1: LHS, N2: RHS, N3: Cond);
12656
12657 return SDValue();
12658}
12659
12660/// Check if N satisfies:
12661/// N is used once.
12662/// N is a Load.
12663 /// The load is compatible with ExtOpcode, meaning:
12664 /// If the load has an explicit zero/sign extension, ExtOpcode must be
12665 /// the same kind of extension.
12666 /// Otherwise, any ExtOpcode is considered compatible.
12667static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
12668 if (!N.hasOneUse())
12669 return false;
12670
12671 if (!isa<LoadSDNode>(Val: N))
12672 return false;
12673
12674 LoadSDNode *Load = cast<LoadSDNode>(Val&: N);
12675 ISD::LoadExtType LoadExt = Load->getExtensionType();
12676 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
12677 return true;
12678
12679 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
12680 // extension.
12681 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
12682 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
12683 return false;
12684
12685 return true;
12686}
12687
12688/// Fold
12689/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
12690/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
12691/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
12692/// This function is called by the DAGCombiner when visiting sext/zext/aext
12693/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12694static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
12695 SelectionDAG &DAG,
12696 CombineLevel Level) {
12697 unsigned Opcode = N->getOpcode();
12698 SDValue N0 = N->getOperand(Num: 0);
12699 EVT VT = N->getValueType(ResNo: 0);
12700 SDLoc DL(N);
12701
12702 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
12703 Opcode == ISD::ANY_EXTEND) &&
12704 "Expected EXTEND dag node in input!");
12705
12706 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
12707 !N0.hasOneUse())
12708 return SDValue();
12709
12710 SDValue Op1 = N0->getOperand(Num: 1);
12711 SDValue Op2 = N0->getOperand(Num: 2);
12712 if (!isCompatibleLoad(N: Op1, ExtOpcode: Opcode) || !isCompatibleLoad(N: Op2, ExtOpcode: Opcode))
12713 return SDValue();
12714
12715 auto ExtLoadOpcode = ISD::EXTLOAD;
12716 if (Opcode == ISD::SIGN_EXTEND)
12717 ExtLoadOpcode = ISD::SEXTLOAD;
12718 else if (Opcode == ISD::ZERO_EXTEND)
12719 ExtLoadOpcode = ISD::ZEXTLOAD;
12720
12721 // An illegal VSELECT may cause ISel to fail if it occurs after legalization
12722 // (DAG Combine2), so we should conservatively check the OperationAction.
12723 LoadSDNode *Load1 = cast<LoadSDNode>(Val&: Op1);
12724 LoadSDNode *Load2 = cast<LoadSDNode>(Val&: Op2);
12725 if (!TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load1->getMemoryVT()) ||
12726 !TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load2->getMemoryVT()) ||
12727 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
12728 TLI.getOperationAction(Op: ISD::VSELECT, VT) != TargetLowering::Legal))
12729 return SDValue();
12730
12731 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Operand: Op1);
12732 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Operand: Op2);
12733 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0), LHS: Ext1, RHS: Ext2);
12734}
12735
12736/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
12737/// a build_vector of constants.
12738/// This function is called by the DAGCombiner when visiting sext/zext/aext
12739/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12740/// Vector extends are not folded if operations are legal; this is to
12741/// avoid introducing illegal build_vector dag nodes.
12742static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
12743 const TargetLowering &TLI,
12744 SelectionDAG &DAG, bool LegalTypes) {
12745 unsigned Opcode = N->getOpcode();
12746 SDValue N0 = N->getOperand(Num: 0);
12747 EVT VT = N->getValueType(ResNo: 0);
12748
12749 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
12750 "Expected EXTEND dag node in input!");
12751
12752 // fold (sext c1) -> c1
12753 // fold (zext c1) -> c1
12754 // fold (aext c1) -> c1
12755 if (isa<ConstantSDNode>(Val: N0))
12756 return DAG.getNode(Opcode, DL, VT, Operand: N0);
12757
12758 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12759 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
12760 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12761 if (N0->getOpcode() == ISD::SELECT) {
12762 SDValue Op1 = N0->getOperand(Num: 1);
12763 SDValue Op2 = N0->getOperand(Num: 2);
12764 if (isa<ConstantSDNode>(Val: Op1) && isa<ConstantSDNode>(Val: Op2) &&
12765 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
12766 // For any_extend, choose sign extension of the constants to allow a
12767 // possible further transform to sign_extend_inreg, i.e.:
12768 //
12769 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
12770 // t2: i64 = any_extend t1
12771 // -->
12772 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
12773 // -->
12774 // t4: i64 = sign_extend_inreg t3
12775 unsigned FoldOpc = Opcode;
12776 if (FoldOpc == ISD::ANY_EXTEND)
12777 FoldOpc = ISD::SIGN_EXTEND;
12778 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0),
12779 LHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op1),
12780 RHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op2));
12781 }
12782 }
12783
12784 // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
12785 // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
12786 // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
12787 EVT SVT = VT.getScalarType();
12788 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(VT: SVT)) &&
12789 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())))
12790 return SDValue();
12791
12792 // We can fold this node into a build_vector.
12793 unsigned VTBits = SVT.getSizeInBits();
12794 unsigned EVTBits = N0->getValueType(ResNo: 0).getScalarSizeInBits();
12795 SmallVector<SDValue, 8> Elts;
12796 unsigned NumElts = VT.getVectorNumElements();
12797
12798 for (unsigned i = 0; i != NumElts; ++i) {
12799 SDValue Op = N0.getOperand(i);
12800 if (Op.isUndef()) {
12801 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
12802 Elts.push_back(Elt: DAG.getUNDEF(VT: SVT));
12803 else
12804 Elts.push_back(Elt: DAG.getConstant(Val: 0, DL, VT: SVT));
12805 continue;
12806 }
12807
12808 SDLoc DL(Op);
12809 // Get the constant value and, if needed, truncate it to the size of the type.
12810 // Nodes like build_vector might have constants wider than the scalar type.
12811 APInt C = Op->getAsAPIntVal().zextOrTrunc(width: EVTBits);
12812 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
12813 Elts.push_back(Elt: DAG.getConstant(Val: C.sext(width: VTBits), DL, VT: SVT));
12814 else
12815 Elts.push_back(Elt: DAG.getConstant(Val: C.zext(width: VTBits), DL, VT: SVT));
12816 }
12817
12818 return DAG.getBuildVector(VT, DL, Ops: Elts);
12819}
12820
12821 // ExtendUsesToFormExtLoad - Try to extend uses of a load to enable this:
12822// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
12823 // transformation. Returns true if the extensions are possible and the
12824 // above-mentioned transformation is profitable.
12825static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
12826 unsigned ExtOpc,
12827 SmallVectorImpl<SDNode *> &ExtendNodes,
12828 const TargetLowering &TLI) {
12829 bool HasCopyToRegUses = false;
12830 bool isTruncFree = TLI.isTruncateFree(FromVT: VT, ToVT: N0.getValueType());
12831 for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
12832 ++UI) {
12833 SDNode *User = *UI;
12834 if (User == N)
12835 continue;
12836 if (UI.getUse().getResNo() != N0.getResNo())
12837 continue;
12838 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
12839 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
12840 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
12841 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(Code: CC))
12842 // Sign bits will be lost after a zext.
12843 return false;
12844 bool Add = false;
12845 for (unsigned i = 0; i != 2; ++i) {
12846 SDValue UseOp = User->getOperand(Num: i);
12847 if (UseOp == N0)
12848 continue;
12849 if (!isa<ConstantSDNode>(Val: UseOp))
12850 return false;
12851 Add = true;
12852 }
12853 if (Add)
12854 ExtendNodes.push_back(Elt: User);
12855 continue;
12856 }
12857 // If truncates aren't free and there are users we can't
12858 // extend, it isn't worthwhile.
12859 if (!isTruncFree)
12860 return false;
12861 // Remember if this value is live-out.
12862 if (User->getOpcode() == ISD::CopyToReg)
12863 HasCopyToRegUses = true;
12864 }
12865
12866 if (HasCopyToRegUses) {
12867 bool BothLiveOut = false;
12868 for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
12869 UI != UE; ++UI) {
12870 SDUse &Use = UI.getUse();
12871 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
12872 BothLiveOut = true;
12873 break;
12874 }
12875 }
12876 if (BothLiveOut)
12877 // Both unextended and extended values are live out. There had better be
12878 // a good reason for the transformation.
12879 return !ExtendNodes.empty();
12880 }
12881 return true;
12882}
12883
12884void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
12885 SDValue OrigLoad, SDValue ExtLoad,
12886 ISD::NodeType ExtType) {
12887 // Extend SetCC uses if necessary.
12888 SDLoc DL(ExtLoad);
12889 for (SDNode *SetCC : SetCCs) {
12890 SmallVector<SDValue, 4> Ops;
12891
12892 for (unsigned j = 0; j != 2; ++j) {
12893 SDValue SOp = SetCC->getOperand(Num: j);
12894 if (SOp == OrigLoad)
12895 Ops.push_back(Elt: ExtLoad);
12896 else
12897 Ops.push_back(Elt: DAG.getNode(Opcode: ExtType, DL, VT: ExtLoad->getValueType(ResNo: 0), Operand: SOp));
12898 }
12899
12900 Ops.push_back(Elt: SetCC->getOperand(Num: 2));
12901 CombineTo(N: SetCC, Res: DAG.getNode(Opcode: ISD::SETCC, DL, VT: SetCC->getValueType(ResNo: 0), Ops));
12902 }
12903}
12904
12905// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
12906SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
12907 SDValue N0 = N->getOperand(Num: 0);
12908 EVT DstVT = N->getValueType(ResNo: 0);
12909 EVT SrcVT = N0.getValueType();
12910
12911 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
12912 N->getOpcode() == ISD::ZERO_EXTEND) &&
12913 "Unexpected node type (not an extend)!");
12914
12915 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
12916 // For example, on a target with legal v4i32, but illegal v8i32, turn:
12917 // (v8i32 (sext (v8i16 (load x))))
12918 // into:
12919 // (v8i32 (concat_vectors (v4i32 (sextload x)),
12920 // (v4i32 (sextload (x + 16)))))
12921 // Where uses of the original load, i.e.:
12922 // (v8i16 (load x))
12923 // are replaced with:
12924 // (v8i16 (truncate
12925 // (v8i32 (concat_vectors (v4i32 (sextload x)),
12926 // (v4i32 (sextload (x + 16)))))))
12927 //
12928 // This combine is only applicable to illegal, but splittable, vectors.
12929 // All legal types, and illegal non-vector types, are handled elsewhere.
12930 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
12931 //
12932 if (N0->getOpcode() != ISD::LOAD)
12933 return SDValue();
12934
12935 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
12936
12937 if (!ISD::isNON_EXTLoad(N: LN0) || !ISD::isUNINDEXEDLoad(N: LN0) ||
12938 !N0.hasOneUse() || !LN0->isSimple() ||
12939 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
12940 !TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
12941 return SDValue();
12942
12943 SmallVector<SDNode *, 4> SetCCs;
12944 if (!ExtendUsesToFormExtLoad(VT: DstVT, N, N0, ExtOpc: N->getOpcode(), ExtendNodes&: SetCCs, TLI))
12945 return SDValue();
12946
12947 ISD::LoadExtType ExtType =
12948 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12949
12950 // Try to split the vector types to get down to legal types.
12951 EVT SplitSrcVT = SrcVT;
12952 EVT SplitDstVT = DstVT;
12953 while (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT) &&
12954 SplitSrcVT.getVectorNumElements() > 1) {
12955 SplitDstVT = DAG.GetSplitDestVTs(VT: SplitDstVT).first;
12956 SplitSrcVT = DAG.GetSplitDestVTs(VT: SplitSrcVT).first;
12957 }
12958
12959 if (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT))
12960 return SDValue();
12961
12962 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
12963
12964 SDLoc DL(N);
12965 const unsigned NumSplits =
12966 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
12967 const unsigned Stride = SplitSrcVT.getStoreSize();
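// Illustrative (assumed) sizes matching the example above: splitting
// v8i32/v8i16 into v4i32/v4i16 halves gives NumSplits = 2 and a Stride of
// 8 bytes (the store size of v4i16) between the partial loads.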
12968 SmallVector<SDValue, 4> Loads;
12969 SmallVector<SDValue, 4> Chains;
12970
12971 SDValue BasePtr = LN0->getBasePtr();
12972 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
12973 const unsigned Offset = Idx * Stride;
12974
12975 SDValue SplitLoad =
12976 DAG.getExtLoad(ExtType, dl: SDLoc(LN0), VT: SplitDstVT, Chain: LN0->getChain(),
12977 Ptr: BasePtr, PtrInfo: LN0->getPointerInfo().getWithOffset(O: Offset),
12978 MemVT: SplitSrcVT, Alignment: LN0->getOriginalAlign(),
12979 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
12980
12981 BasePtr = DAG.getMemBasePlusOffset(Base: BasePtr, Offset: TypeSize::getFixed(ExactSize: Stride), DL);
12982
12983 Loads.push_back(Elt: SplitLoad.getValue(R: 0));
12984 Chains.push_back(Elt: SplitLoad.getValue(R: 1));
12985 }
12986
12987 SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
12988 SDValue NewValue = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: DstVT, Ops: Loads);
12989
12990 // Simplify TF.
12991 AddToWorklist(N: NewChain.getNode());
12992
12993 CombineTo(N, Res: NewValue);
12994
12995 // Replace uses of the original load (before extension)
12996 // with a truncate of the concatenated sextloaded vectors.
12997 SDValue Trunc =
12998 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: NewValue);
12999 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad: NewValue, ExtType: (ISD::NodeType)N->getOpcode());
13000 CombineTo(N: N0.getNode(), Res0: Trunc, Res1: NewChain);
13001 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13002}
13003
13004// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13005// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
13006SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
13007 assert(N->getOpcode() == ISD::ZERO_EXTEND);
13008 EVT VT = N->getValueType(ResNo: 0);
13009 EVT OrigVT = N->getOperand(Num: 0).getValueType();
13010 if (TLI.isZExtFree(FromTy: OrigVT, ToTy: VT))
13011 return SDValue();
13012
13013 // and/or/xor
13014 SDValue N0 = N->getOperand(Num: 0);
13015 if (!ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) ||
13016 N0.getOperand(i: 1).getOpcode() != ISD::Constant ||
13017 (LegalOperations && !TLI.isOperationLegal(Op: N0.getOpcode(), VT)))
13018 return SDValue();
13019
13020 // shl/shr
13021 SDValue N1 = N0->getOperand(Num: 0);
13022 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
13023 N1.getOperand(i: 1).getOpcode() != ISD::Constant ||
13024 (LegalOperations && !TLI.isOperationLegal(Op: N1.getOpcode(), VT)))
13025 return SDValue();
13026
13027 // load
13028 if (!isa<LoadSDNode>(Val: N1.getOperand(i: 0)))
13029 return SDValue();
13030 LoadSDNode *Load = cast<LoadSDNode>(Val: N1.getOperand(i: 0));
13031 EVT MemVT = Load->getMemoryVT();
13032 if (!TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) ||
13033 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
13034 return SDValue();
13035
13037 // If the shift op is SHL, the logic op must be AND, otherwise the result
13038 // will be wrong.
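// (For clarity: in the narrow type an SHL discards bits shifted above the
// narrow width, but in the wide type they survive. An AND with the
// zero-extended mask clears those bits again, whereas OR/XOR with a
// zero-extended constant would let them through.)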
13039 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
13040 return SDValue();
13041
13042 if (!N0.hasOneUse() || !N1.hasOneUse())
13043 return SDValue();
13044
13045 SmallVector<SDNode*, 4> SetCCs;
13046 if (!ExtendUsesToFormExtLoad(VT, N: N1.getNode(), N0: N1.getOperand(i: 0),
13047 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI))
13048 return SDValue();
13049
13050 // Actually do the transformation.
13051 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Load), VT,
13052 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
13053 MemVT: Load->getMemoryVT(), MMO: Load->getMemOperand());
13054
13055 SDLoc DL1(N1);
13056 SDValue Shift = DAG.getNode(Opcode: N1.getOpcode(), DL: DL1, VT, N1: ExtLoad,
13057 N2: N1.getOperand(i: 1));
13058
13059 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13060 SDLoc DL0(N0);
13061 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL: DL0, VT, N1: Shift,
13062 N2: DAG.getConstant(Val: Mask, DL: DL0, VT));
13063
13064 ExtendSetCCUses(SetCCs, OrigLoad: N1.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
13065 CombineTo(N, Res: And);
13066 if (SDValue(Load, 0).hasOneUse()) {
13067 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
13068 } else {
13069 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Load),
13070 VT: Load->getValueType(ResNo: 0), Operand: ExtLoad);
13071 CombineTo(N: Load, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13072 }
13073
13074 // N0 is dead at this point.
13075 recursivelyDeleteUnusedNodes(N: N0.getNode());
13076
13077 return SDValue(N,0); // Return N so it doesn't get rechecked!
13078}
13079
13080/// If we're narrowing or widening the result of a vector select and the final
13081/// size is the same size as a setcc (compare) feeding the select, then try to
13082/// apply the cast operation to the select's operands because matching vector
13083/// sizes for a select condition and other operands should be more efficient.
13084SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13085 unsigned CastOpcode = Cast->getOpcode();
13086 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13087 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13088 CastOpcode == ISD::FP_ROUND) &&
13089 "Unexpected opcode for vector select narrowing/widening");
13090
13091 // We only do this transform before legal ops because the pattern may be
13092 // obfuscated by target-specific operations after legalization. Do not create
13093 // an illegal select op, however, because that may be difficult to lower.
13094 EVT VT = Cast->getValueType(ResNo: 0);
13095 if (LegalOperations || !TLI.isOperationLegalOrCustom(Op: ISD::VSELECT, VT))
13096 return SDValue();
13097
13098 SDValue VSel = Cast->getOperand(Num: 0);
13099 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
13100 VSel.getOperand(i: 0).getOpcode() != ISD::SETCC)
13101 return SDValue();
13102
13103 // Does the setcc have the same vector size as the casted select?
13104 SDValue SetCC = VSel.getOperand(i: 0);
13105 EVT SetCCVT = getSetCCResultType(VT: SetCC.getOperand(i: 0).getValueType());
13106 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
13107 return SDValue();
13108
13109 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
13110 SDValue A = VSel.getOperand(i: 1);
13111 SDValue B = VSel.getOperand(i: 2);
13112 SDValue CastA, CastB;
13113 SDLoc DL(Cast);
13114 if (CastOpcode == ISD::FP_ROUND) {
13115 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
13116 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: A, N2: Cast->getOperand(Num: 1));
13117 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: B, N2: Cast->getOperand(Num: 1));
13118 } else {
13119 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: A);
13120 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: B);
13121 }
13122 return DAG.getNode(Opcode: ISD::VSELECT, DL, VT, N1: SetCC, N2: CastA, N3: CastB);
13123}
13124
13125// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13126// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13127static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
13128 const TargetLowering &TLI, EVT VT,
13129 bool LegalOperations, SDNode *N,
13130 SDValue N0, ISD::LoadExtType ExtLoadType) {
13131 SDNode *N0Node = N0.getNode();
13132 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N: N0Node)
13133 : ISD::isZEXTLoad(N: N0Node);
13134 if ((!isAExtLoad && !ISD::isEXTLoad(N: N0Node)) ||
13135 !ISD::isUNINDEXEDLoad(N: N0Node) || !N0.hasOneUse())
13136 return SDValue();
13137
13138 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13139 EVT MemVT = LN0->getMemoryVT();
13140 if ((LegalOperations || !LN0->isSimple() ||
13141 VT.isVector()) &&
13142 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT))
13143 return SDValue();
13144
13145 SDValue ExtLoad =
13146 DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
13147 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
13148 Combiner.CombineTo(N, Res: ExtLoad);
13149 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
13150 if (LN0->use_empty())
13151 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
13152 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13153}
13154
13155// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13156// Only generate vector extloads when 1) they're legal, and 2) they are
13157// deemed desirable by the target.
13158static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
13159 const TargetLowering &TLI, EVT VT,
13160 bool LegalOperations, SDNode *N, SDValue N0,
13161 ISD::LoadExtType ExtLoadType,
13162 ISD::NodeType ExtOpc) {
13163 // TODO: isFixedLengthVector() should be removed; any negative effects on
13164 // code generation should then be the responsibility of that target's
13165 // implementation of isVectorLoadExtDesirable().
13166 if (!ISD::isNON_EXTLoad(N: N0.getNode()) ||
13167 !ISD::isUNINDEXEDLoad(N: N0.getNode()) ||
13168 ((LegalOperations || VT.isFixedLengthVector() ||
13169 !cast<LoadSDNode>(Val&: N0)->isSimple()) &&
13170 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: N0.getValueType())))
13171 return {};
13172
13173 bool DoXform = true;
13174 SmallVector<SDNode *, 4> SetCCs;
13175 if (!N0.hasOneUse())
13176 DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, ExtendNodes&: SetCCs, TLI);
13177 if (VT.isVector())
13178 DoXform &= TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0));
13179 if (!DoXform)
13180 return {};
13181
13182 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13183 SDValue ExtLoad = DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
13184 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
13185 MMO: LN0->getMemOperand());
13186 Combiner.ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ExtOpc);
13187 // If the load value is used only by N, replace it via CombineTo N.
13188 bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
13189 Combiner.CombineTo(N, Res: ExtLoad);
13190 if (NoReplaceTrunc) {
13191 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
13192 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
13193 } else {
13194 SDValue Trunc =
13195 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
13196 Combiner.CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13197 }
13198 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13199}
13200
13201static SDValue
13202tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
13203 bool LegalOperations, SDNode *N, SDValue N0,
13204 ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
13205 if (!N0.hasOneUse())
13206 return SDValue();
13207
13208 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0);
13209 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
13210 return SDValue();
13211
13212 if ((LegalOperations || !cast<MaskedLoadSDNode>(Val&: N0)->isSimple()) &&
13213 !TLI.isLoadExtLegalOrCustom(ExtType: ExtLoadType, ValVT: VT, MemVT: Ld->getValueType(ResNo: 0)))
13214 return SDValue();
13215
13216 if (!TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
13217 return SDValue();
13218
13219 SDLoc dl(Ld);
13220 SDValue PassThru = DAG.getNode(Opcode: ExtOpc, DL: dl, VT, Operand: Ld->getPassThru());
13221 SDValue NewLoad = DAG.getMaskedLoad(
13222 VT, dl, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(), Mask: Ld->getMask(),
13223 Src0: PassThru, MemVT: Ld->getMemoryVT(), MMO: Ld->getMemOperand(), AM: Ld->getAddressingMode(),
13224 ExtLoadType, IsExpanding: Ld->isExpandingLoad());
13225 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1), To: SDValue(NewLoad.getNode(), 1));
13226 return NewLoad;
13227}
13228
13229static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
13230 bool LegalOperations) {
13231 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13232 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
13233
13234 SDValue SetCC = N->getOperand(Num: 0);
13235 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
13236 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
13237 return SDValue();
13238
13239 SDValue X = SetCC.getOperand(i: 0);
13240 SDValue Ones = SetCC.getOperand(i: 1);
13241 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC.getOperand(i: 2))->get();
13242 EVT VT = N->getValueType(ResNo: 0);
13243 EVT XVT = X.getValueType();
13244 // setge X, C is canonicalized to setgt, so we do not need to match that
13245 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
13246 // not require the 'not' op.
13247 if (CC == ISD::SETGT && isAllOnesConstant(V: Ones) && VT == XVT) {
13248 // Invert and smear/shift the sign bit:
13249 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
13250 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
13251 SDLoc DL(N);
13252 unsigned ShCt = VT.getSizeInBits() - 1;
13253 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13254 if (!TLI.shouldAvoidTransformToShift(VT, Amount: ShCt)) {
13255 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
13256 SDValue ShiftAmount = DAG.getConstant(Val: ShCt, DL, VT);
13257 auto ShiftOpcode =
13258 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
13259 return DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: NotX, N2: ShiftAmount);
13260 }
13261 }
13262 return SDValue();
13263}
13264
13265SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
13266 SDValue N0 = N->getOperand(Num: 0);
13267 if (N0.getOpcode() != ISD::SETCC)
13268 return SDValue();
13269
13270 SDValue N00 = N0.getOperand(i: 0);
13271 SDValue N01 = N0.getOperand(i: 1);
13272 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
13273 EVT VT = N->getValueType(ResNo: 0);
13274 EVT N00VT = N00.getValueType();
13275 SDLoc DL(N);
13276
13277 // Propagate fast-math-flags.
13278 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13279
13280 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
13281 // the same size as the compared operands. Try to optimize sext(setcc())
13282 // if this is the case.
13283 if (VT.isVector() && !LegalOperations &&
13284 TLI.getBooleanContents(Type: N00VT) ==
13285 TargetLowering::ZeroOrNegativeOneBooleanContent) {
13286 EVT SVT = getSetCCResultType(VT: N00VT);
13287
13288 // If we already have the desired type, don't change it.
13289 if (SVT != N0.getValueType()) {
13290 // We know that the # elements of the results is the same as the
13291 // # elements of the compare (and the # elements of the compare result
13292 // for that matter). Check to see that they are the same size. If so,
13293 // we know that the element size of the sext'd result matches the
13294 // element size of the compare operands.
13295 if (VT.getSizeInBits() == SVT.getSizeInBits())
13296 return DAG.getSetCC(DL, VT, LHS: N00, RHS: N01, Cond: CC);
13297
13298 // If the desired elements are smaller or larger than the source
13299 // elements, we can use a matching integer vector type and then
13300 // truncate/sign extend.
13301 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
13302 if (SVT == MatchingVecType) {
13303 SDValue VsetCC = DAG.getSetCC(DL, VT: MatchingVecType, LHS: N00, RHS: N01, Cond: CC);
13304 return DAG.getSExtOrTrunc(Op: VsetCC, DL, VT);
13305 }
13306 }
13307
13308 // Try to eliminate the sext of a setcc by zexting the compare operands.
13309 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT) &&
13310 !TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: SVT)) {
13311 bool IsSignedCmp = ISD::isSignedIntSetCC(Code: CC);
13312 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13313 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13314
13315 // We have an unsupported narrow vector compare op that would be legal
13316 // if extended to the destination type. See if the compare operands
13317 // can be freely extended to the destination type.
13318 auto IsFreeToExtend = [&](SDValue V) {
13319 if (isConstantOrConstantVector(N: V, /*NoOpaques*/ true))
13320 return true;
13321 // Match a simple, non-extended load that can be converted to a
13322 // legal {z/s}ext-load.
13323 // TODO: Allow widening of an existing {z/s}ext-load?
13324 if (!(ISD::isNON_EXTLoad(N: V.getNode()) &&
13325 ISD::isUNINDEXEDLoad(N: V.getNode()) &&
13326 cast<LoadSDNode>(Val&: V)->isSimple() &&
13327 TLI.isLoadExtLegal(ExtType: LoadOpcode, ValVT: VT, MemVT: V.getValueType())))
13328 return false;
13329
13330 // Non-chain users of this value must either be the setcc in this
13331 // sequence or extends that can be folded into the new {z/s}ext-load.
13332 for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
13333 UI != UE; ++UI) {
13334 // Skip uses of the chain and the setcc.
13335 SDNode *User = *UI;
13336 if (UI.getUse().getResNo() != 0 || User == N0.getNode())
13337 continue;
13338 // Extra users must have exactly the same cast we are about to create.
13339 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
13340 // is enhanced similarly.
13341 if (User->getOpcode() != ExtOpcode || User->getValueType(ResNo: 0) != VT)
13342 return false;
13343 }
13344 return true;
13345 };
13346
13347 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
13348 SDValue Ext0 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N00);
13349 SDValue Ext1 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N01);
13350 return DAG.getSetCC(DL, VT, LHS: Ext0, RHS: Ext1, Cond: CC);
13351 }
13352 }
13353 }
13354
13355 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
13356 // Here, T can be 1 or -1, depending on the type of the setcc and
13357 // getBooleanContents().
13358 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
13359
13360 // To determine the "true" side of the select, we need to know the high bit
13361 // of the value returned by the setcc if it evaluates to true.
13362 // If the type of the setcc is i1, then the true case of the select is just
13363 // sext(i1 1), that is, -1.
13364 // If the type of the setcc is larger (say, i8) then the value of the high
13365 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
13366 // of the appropriate width.
13367 SDValue ExtTrueVal = (SetCCWidth == 1)
13368 ? DAG.getAllOnesConstant(DL, VT)
13369 : DAG.getBoolConstant(V: true, DL, VT, OpVT: N00VT);
13370 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
13371 if (SDValue SCC = SimplifySelectCC(DL, N0: N00, N1: N01, N2: ExtTrueVal, N3: Zero, CC, NotExtCompare: true))
13372 return SCC;
13373
13374 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(Cond: N0, VT, TLI)) {
13375 EVT SetCCVT = getSetCCResultType(VT: N00VT);
13376 // Don't do this transform for i1 because there's a select transform
13377 // that would reverse it.
13378 // TODO: We should not do this transform at all without a target hook
13379 // because a sext is likely cheaper than a select?
13380 if (SetCCVT.getScalarSizeInBits() != 1 &&
13381 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: N00VT))) {
13382 SDValue SetCC = DAG.getSetCC(DL, VT: SetCCVT, LHS: N00, RHS: N01, Cond: CC);
13383 return DAG.getSelect(DL, VT, Cond: SetCC, LHS: ExtTrueVal, RHS: Zero);
13384 }
13385 }
13386
13387 return SDValue();
13388}
13389
13390SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
13391 SDValue N0 = N->getOperand(Num: 0);
13392 EVT VT = N->getValueType(ResNo: 0);
13393 SDLoc DL(N);
13394
13395 if (VT.isVector())
13396 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13397 return FoldedVOp;
13398
13399 // sext(undef) = 0 because the top bits will all be the same.
13400 if (N0.isUndef())
13401 return DAG.getConstant(Val: 0, DL, VT);
13402
13403 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13404 return Res;
13405
13406 // fold (sext (sext x)) -> (sext x)
13407 // fold (sext (aext x)) -> (sext x)
13408 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13409 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13410
13411 // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13412 // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13413 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13414 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13415 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
13416 Operand: N0.getOperand(i: 0));
13417
13418 // fold (sext (sext_inreg x)) -> (sext (trunc x))
13419 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
13420 SDValue N00 = N0.getOperand(i: 0);
13421 EVT ExtVT = cast<VTSDNode>(Val: N0->getOperand(Num: 1))->getVT();
13422 if ((N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(Val: N00, VT2: ExtVT)) &&
13423 (!LegalTypes || TLI.isTypeLegal(VT: ExtVT))) {
13424 SDValue T = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N00);
13425 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: T);
13426 }
13427 }
13428
13429 if (N0.getOpcode() == ISD::TRUNCATE) {
13430 // fold (sext (truncate (load x))) -> (sext (smaller load x))
13431 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
13432 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13433 SDNode *oye = N0.getOperand(i: 0).getNode();
13434 if (NarrowLoad.getNode() != N0.getNode()) {
13435 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13436 // CombineTo deleted the truncate, if needed, but not what's under it.
13437 AddToWorklist(N: oye);
13438 }
13439 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13440 }
13441
13442 // See if the value being truncated is already sign extended. If so, just
13443 // eliminate the trunc/sext pair.
13444 SDValue Op = N0.getOperand(i: 0);
13445 unsigned OpBits = Op.getScalarValueSizeInBits();
13446 unsigned MidBits = N0.getScalarValueSizeInBits();
13447 unsigned DestBits = VT.getScalarSizeInBits();
13448 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13449
13450 if (OpBits == DestBits) {
13451 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
13452 // bits, it is already ready.
13453 if (NumSignBits > DestBits-MidBits)
13454 return Op;
13455 } else if (OpBits < DestBits) {
13456 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
13457 // bits, just sext from i32.
13458 if (NumSignBits > OpBits-MidBits)
13459 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
13460 } else {
13461 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
13462 // bits, just truncate to i32.
13463 if (NumSignBits > OpBits-MidBits)
13464 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op);
13465 }
13466
13467 // fold (sext (truncate x)) -> (sextinreg x).
13468 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG,
13469 VT: N0.getValueType())) {
13470 if (OpBits < DestBits)
13471 Op = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N0), VT, Operand: Op);
13472 else if (OpBits > DestBits)
13473 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT, Operand: Op);
13474 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: Op,
13475 N2: DAG.getValueType(N0.getValueType()));
13476 }
13477 }
13478
13479 // Try to simplify (sext (load x)).
13480 if (SDValue foldedExt =
13481 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
13482 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
13483 return foldedExt;
13484
13485 if (SDValue foldedExt =
13486 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13487 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
13488 return foldedExt;
13489
13490 // fold (sext (load x)) to multiple smaller sextloads.
13491 // Only on illegal but splittable vectors.
13492 if (SDValue ExtLoad = CombineExtLoad(N))
13493 return ExtLoad;
13494
13495 // Try to simplify (sext (sextload x)).
13496 if (SDValue foldedExt = tryToFoldExtOfExtload(
13497 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::SEXTLOAD))
13498 return foldedExt;
13499
13500 // fold (sext (and/or/xor (load x), cst)) ->
13501 // (and/or/xor (sextload x), (sext cst))
13502 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) &&
13503 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
13504 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13505 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
13506 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
13507 EVT MemVT = LN00->getMemoryVT();
13508 if (TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT) &&
13509 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
13510 SmallVector<SDNode*, 4> SetCCs;
13511 bool DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
13512 ExtOpc: ISD::SIGN_EXTEND, ExtendNodes&: SetCCs, TLI);
13513 if (DoXform) {
13514 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(LN00), VT,
13515 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
13516 MemVT: LN00->getMemoryVT(),
13517 MMO: LN00->getMemOperand());
13518 APInt Mask = N0.getConstantOperandAPInt(i: 1).sext(width: VT.getSizeInBits());
13519 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13520 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
13521 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::SIGN_EXTEND);
13522 bool NoReplaceTruncAnd = !N0.hasOneUse();
13523 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13524 CombineTo(N, Res: And);
13525 // If N0 has multiple uses, change other uses as well.
13526 if (NoReplaceTruncAnd) {
13527 SDValue TruncAnd =
13528 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
13529 CombineTo(N: N0.getNode(), Res: TruncAnd);
13530 }
13531 if (NoReplaceTrunc) {
13532 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
13533 } else {
13534 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
13535 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
13536 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13537 }
13538 return SDValue(N,0); // Return N so it doesn't get rechecked!
13539 }
13540 }
13541 }
13542
13543 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13544 return V;
13545
13546 if (SDValue V = foldSextSetcc(N))
13547 return V;
13548
13549 // fold (sext x) -> (zext x) if the sign bit is known zero.
13550 if (!TLI.isSExtCheaperThanZExt(FromTy: N0.getValueType(), ToTy: VT) &&
13551 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT)) &&
13552 DAG.SignBitIsZero(Op: N0)) {
13553 SDNodeFlags Flags;
13554 Flags.setNonNeg(true);
13555 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0, Flags);
13556 }
13557
13558 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
13559 return NewVSel;
13560
13561 // Eliminate this sign extend by doing a negation in the destination type:
13562 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
13563 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
13564 isNullOrNullSplat(V: N0.getOperand(i: 0)) &&
13565 N0.getOperand(i: 1).getOpcode() == ISD::ZERO_EXTEND &&
13566 TLI.isOperationLegalOrCustom(Op: ISD::SUB, VT)) {
13567 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1).getOperand(i: 0), DL, VT);
13568 return DAG.getNegative(Val: Zext, DL, VT);
13569 }
13570 // Eliminate this sign extend by doing a decrement in the destination type:
13571 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
13572 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
13573 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1)) &&
13574 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
13575 TLI.isOperationLegalOrCustom(Op: ISD::ADD, VT)) {
13576 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
13577 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
13578 }
13579
13580 // fold sext (not i1 X) -> add (zext i1 X), -1
13581 // TODO: This could be extended to handle bool vectors.
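// (For clarity: with X:i1, 'not X' is 1 exactly when X is 0, so sext (not X)
// is -1 when X is 0 and 0 when X is 1, which matches (zext X) + (-1).)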
13582 if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
13583 (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
13584 TLI.isOperationLegal(ISD::ADD, VT)))) {
13585 // If we can eliminate the 'not', the sext form should be better
13586 if (SDValue NewXor = visitXOR(N: N0.getNode())) {
13587 // Returning N0 is a form of in-visit replacement that may have
13588 // invalidated N0.
13589 if (NewXor.getNode() == N0.getNode()) {
13590 // Return SDValue here as the xor should have already been replaced in
13591 // this sext.
13592 return SDValue();
13593 }
13594
13595 // Return a new sext with the new xor.
13596 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: NewXor);
13597 }
13598
13599 SDValue Zext = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13600 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
13601 }
13602
13603 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
13604 return Res;
13605
13606 return SDValue();
13607}
13608
13609/// Given an extending node with a pop-count operand, if the target does not
13610/// support a pop-count in the narrow source type but does support it in the
13611/// destination type, widen the pop-count to the destination type.
13612static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
13613 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
13614 Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
13615
13616 SDValue CtPop = Extend->getOperand(Num: 0);
13617 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
13618 return SDValue();
13619
13620 EVT VT = Extend->getValueType(ResNo: 0);
13621 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13622 if (TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT: CtPop.getValueType()) ||
13623 !TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT))
13624 return SDValue();
13625
13626 // zext (ctpop X) --> ctpop (zext X)
13627 SDLoc DL(Extend);
13628 SDValue NewZext = DAG.getZExtOrTrunc(Op: CtPop.getOperand(i: 0), DL, VT);
13629 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: NewZext);
13630}
13631
13632// If we have (zext (abs X)) where X is a type that will be promoted by type
13633// legalization, convert to (abs (sext X)). But don't extend past a legal type.
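// (Illustrative, assuming a target that promotes i16 to i32:
//   (i64 (zext (abs i16 X))) --> (i64 (zext (abs (sext X to i32)))).)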
13634static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
13635 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
13636
13637 EVT VT = Extend->getValueType(ResNo: 0);
13638 if (VT.isVector())
13639 return SDValue();
13640
13641 SDValue Abs = Extend->getOperand(Num: 0);
13642 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
13643 return SDValue();
13644
13645 EVT AbsVT = Abs.getValueType();
13646 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13647 if (TLI.getTypeAction(Context&: *DAG.getContext(), VT: AbsVT) !=
13648 TargetLowering::TypePromoteInteger)
13649 return SDValue();
13650
13651 EVT LegalVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: AbsVT);
13652
13653 SDValue SExt =
13654 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(Abs), VT: LegalVT, Operand: Abs.getOperand(i: 0));
13655 SDValue NewAbs = DAG.getNode(Opcode: ISD::ABS, DL: SDLoc(Abs), VT: LegalVT, Operand: SExt);
13656 return DAG.getZExtOrTrunc(Op: NewAbs, DL: SDLoc(Extend), VT);
13657}
13658
13659SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
13660 SDValue N0 = N->getOperand(Num: 0);
13661 EVT VT = N->getValueType(ResNo: 0);
13662 SDLoc DL(N);
13663
13664 if (VT.isVector())
13665 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13666 return FoldedVOp;
13667
13668 // zext(undef) = 0
13669 if (N0.isUndef())
13670 return DAG.getConstant(Val: 0, DL, VT);
13671
13672 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13673 return Res;
13674
13675 // fold (zext (zext x)) -> (zext x)
13676 // fold (zext (aext x)) -> (zext x)
13677 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13678 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13679
13680 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13681 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13682 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13683 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
13684 return DAG.getNode(Opcode: ISD::ZERO_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
13685 Operand: N0.getOperand(i: 0));
13686
13687 // fold (zext (truncate x)) -> (zext x) or
13688 // (zext (truncate x)) -> (truncate x)
13689 // This is valid when the truncated bits of x are already zero.
13690 SDValue Op;
13691 KnownBits Known;
13692 if (isTruncateOf(DAG, N: N0, Op, Known)) {
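// (Illustrative, assumed types: if Op is i64, N0 is i32 and VT is i64,
// TruncatedBits below covers bits [32, 64) of Op; the fold is safe only if
// those bits are already known to be zero.)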
13693 APInt TruncatedBits =
13694 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
13695 APInt(Op.getScalarValueSizeInBits(), 0) :
13696 APInt::getBitsSet(numBits: Op.getScalarValueSizeInBits(),
13697 loBit: N0.getScalarValueSizeInBits(),
13698 hiBit: std::min(a: Op.getScalarValueSizeInBits(),
13699 b: VT.getScalarSizeInBits()));
13700 if (TruncatedBits.isSubsetOf(RHS: Known.Zero)) {
13701 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13702 DAG.salvageDebugInfo(N&: *N0.getNode());
13703
13704 return ZExtOrTrunc;
13705 }
13706 }
13707
13708 // fold (zext (truncate x)) -> (and x, mask)
13709 if (N0.getOpcode() == ISD::TRUNCATE) {
13710 // fold (zext (truncate (load x))) -> (zext (smaller load x))
13711 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
13712 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13713 SDNode *oye = N0.getOperand(i: 0).getNode();
13714 if (NarrowLoad.getNode() != N0.getNode()) {
13715 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13716 // CombineTo deleted the truncate, if needed, but not what's under it.
13717 AddToWorklist(N: oye);
13718 }
13719 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13720 }
13721
13722 EVT SrcVT = N0.getOperand(i: 0).getValueType();
13723 EVT MinVT = N0.getValueType();
13724
13725 // Try to mask before the extension to avoid having to generate a larger
13726 // mask, possibly over several sub-vectors.
13727 if (SrcVT.bitsLT(VT) && VT.isVector()) {
13728 if (!LegalOperations || (TLI.isOperationLegal(Op: ISD::AND, VT: SrcVT) &&
13729 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) {
13730 SDValue Op = N0.getOperand(i: 0);
13731 Op = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
13732 AddToWorklist(N: Op.getNode());
13733 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13734 // Transfer the debug info; the new node is equivalent to N0.
13735 DAG.transferDbgValues(From: N0, To: ZExtOrTrunc);
13736 return ZExtOrTrunc;
13737 }
13738 }
13739
13740 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::AND, VT)) {
13741 SDValue Op = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
13742 AddToWorklist(N: Op.getNode());
13743 SDValue And = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
13744 // We may safely transfer the debug info describing the truncate node over
13745 // to the equivalent and operation.
13746 DAG.transferDbgValues(From: N0, To: And);
13747 return And;
13748 }
13749 }
13750
13751 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
13752 // if either of the casts is not free.
13753 if (N0.getOpcode() == ISD::AND &&
13754 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
13755 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13756 (!TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType()) ||
13757 !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
13758 SDValue X = N0.getOperand(i: 0).getOperand(i: 0);
13759 X = DAG.getAnyExtOrTrunc(Op: X, DL: SDLoc(X), VT);
13760 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13761 return DAG.getNode(Opcode: ISD::AND, DL, VT,
13762 N1: X, N2: DAG.getConstant(Val: Mask, DL, VT));
13763 }
13764
13765 // Try to simplify (zext (load x)).
13766 if (SDValue foldedExt =
13767 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
13768 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
13769 return foldedExt;
13770
13771 if (SDValue foldedExt =
13772 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13773 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
13774 return foldedExt;
13775
13776 // fold (zext (load x)) to multiple smaller zextloads.
13777 // Only on illegal but splittable vectors.
13778 if (SDValue ExtLoad = CombineExtLoad(N))
13779 return ExtLoad;
13780
13781 // fold (zext (and/or/xor (load x), cst)) ->
13782 // (and/or/xor (zextload x), (zext cst))
13783 // Unless (and (load x) cst) will match as a zextload already and has
13784 // additional users, or the zext is already free.
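// For example, (zext i64 (and (load i32 p), 42)) can become
// (and (zextload i64 from i32 p), 42), subject to the legality and use
// checks below.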
13785 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && !TLI.isZExtFree(Val: N0, VT2: VT) &&
13786 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
13787 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13788 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
13789 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
13790 EVT MemVT = LN00->getMemoryVT();
13791 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) &&
13792 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
13793 bool DoXform = true;
13794 SmallVector<SDNode*, 4> SetCCs;
13795 if (!N0.hasOneUse()) {
13796 if (N0.getOpcode() == ISD::AND) {
13797 auto *AndC = cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
13798 EVT LoadResultTy = AndC->getValueType(ResNo: 0);
13799 EVT ExtVT;
13800 if (isAndLoadExtLoad(AndC, LoadN: LN00, LoadResultTy, ExtVT))
13801 DoXform = false;
13802 }
13803 }
13804 if (DoXform)
13805 DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
13806 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI);
13807 if (DoXform) {
13808 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(LN00), VT,
13809 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
13810 MemVT: LN00->getMemoryVT(),
13811 MMO: LN00->getMemOperand());
13812 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13813 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13814 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
13815 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
13816 bool NoReplaceTruncAnd = !N0.hasOneUse();
13817 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13818 CombineTo(N, Res: And);
13819 // If N0 has multiple uses, change other uses as well.
13820 if (NoReplaceTruncAnd) {
13821 SDValue TruncAnd =
13822 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
13823 CombineTo(N: N0.getNode(), Res: TruncAnd);
13824 }
13825 if (NoReplaceTrunc) {
13826 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
13827 } else {
13828 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
13829 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
13830 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13831 }
13832 return SDValue(N,0); // Return N so it doesn't get rechecked!
13833 }
13834 }
13835 }
13836
13837 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13838 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
13839 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
13840 return ZExtLoad;
13841
13842 // Try to simplify (zext (zextload x)).
13843 if (SDValue foldedExt = tryToFoldExtOfExtload(
13844 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD))
13845 return foldedExt;
13846
13847 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13848 return V;
13849
13850 if (N0.getOpcode() == ISD::SETCC) {
13851 // Propagate fast-math-flags.
13852 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13853
13854 // Only do this before legalize for now.
13855 if (!LegalOperations && VT.isVector() &&
13856 N0.getValueType().getVectorElementType() == MVT::i1) {
13857 EVT N00VT = N0.getOperand(i: 0).getValueType();
13858 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
13859 return SDValue();
13860
13861 // We know that the # elements of the result is the same as the #
13862 // elements of the compare (and the # elements of the compare result for
13863 // that matter). Check to see that they are the same size. If so, we know
13864 // that the element size of the zext'd result matches the element size of
13865 // the compare operands.
13866 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
13867 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
13868 SDValue VSetCC = DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: N0.getOperand(i: 0),
13869 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
13870 return DAG.getZeroExtendInReg(Op: VSetCC, DL, VT: N0.getValueType());
13871 }
13872
13873 // If the desired elements are smaller or larger than the source
13874 // elements we can use a matching integer vector type and then
13875 // truncate/any extend followed by zext_in_reg.
13876 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
13877 SDValue VsetCC =
13878 DAG.getNode(Opcode: ISD::SETCC, DL, VT: MatchingVectorType, N1: N0.getOperand(i: 0),
13879 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
13880 return DAG.getZeroExtendInReg(Op: DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT), DL,
13881 VT: N0.getValueType());
13882 }
13883
13884 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
13885 EVT N0VT = N0.getValueType();
13886 EVT N00VT = N0.getOperand(i: 0).getValueType();
13887 if (SDValue SCC = SimplifySelectCC(
13888 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1),
13889 N2: DAG.getBoolConstant(V: true, DL, VT: N0VT, OpVT: N00VT),
13890 N3: DAG.getBoolConstant(V: false, DL, VT: N0VT, OpVT: N00VT),
13891 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
13892 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SCC);
13893 }
13894
13895 // (zext (shl/srl (zext x), cst)) -> (shl/srl (zext x), cst)
13896 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
13897 !TLI.isZExtFree(Val: N0, VT2: VT)) {
13898 SDValue ShVal = N0.getOperand(i: 0);
13899 SDValue ShAmt = N0.getOperand(i: 1);
13900 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val&: ShAmt)) {
13901 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
13902 if (N0.getOpcode() == ISD::SHL) {
13903 // If the original shl may be shifting out bits, do not perform this
13904 // transformation.
13905 // TODO: Add MaskedValueIsZero check.
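// A zero extend guarantees that the top (wide - narrow) bits of ShVal are
// zero, so a left shift by at most that many bits cannot shift out set bits.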
13906 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
13907 ShVal.getOperand(i: 0).getValueSizeInBits();
13908 if (ShAmtC->getAPIntValue().ugt(RHS: KnownZeroBits))
13909 return SDValue();
13910 }
13911
13912 // Ensure that the shift amount is wide enough for the shifted value.
13913 if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
13914 ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
13915
13916 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13917 N1: DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ShVal), N2: ShAmt);
13918 }
13919 }
13920 }
13921
13922 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
13923 return NewVSel;
13924
13925 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG))
13926 return NewCtPop;
13927
13928 if (SDValue V = widenAbs(Extend: N, DAG))
13929 return V;
13930
13931 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
13932 return Res;
13933
13934 return SDValue();
13935}
13936
13937SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
13938 SDValue N0 = N->getOperand(Num: 0);
13939 EVT VT = N->getValueType(ResNo: 0);
13940 SDLoc DL(N);
13941
13942 // aext(undef) = undef
13943 if (N0.isUndef())
13944 return DAG.getUNDEF(VT);
13945
13946 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13947 return Res;
13948
13949 // fold (aext (aext x)) -> (aext x)
13950 // fold (aext (zext x)) -> (zext x)
13951 // fold (aext (sext x)) -> (sext x)
13952 if (N0.getOpcode() == ISD::ANY_EXTEND ||
13953 N0.getOpcode() == ISD::ZERO_EXTEND ||
13954 N0.getOpcode() == ISD::SIGN_EXTEND)
13955 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
13956
13957 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
13958 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13959 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13960 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13961 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
13962 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13963 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
13964
13965 // fold (aext (truncate (load x))) -> (aext (smaller load x))
13966 // fold (aext (truncate (srl (load x), c))) -> (aext (smaller load (x+c/n)))
13967 if (N0.getOpcode() == ISD::TRUNCATE) {
13968 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13969 SDNode *oye = N0.getOperand(i: 0).getNode();
13970 if (NarrowLoad.getNode() != N0.getNode()) {
13971 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13972 // CombineTo deleted the truncate, if needed, but not what's under it.
13973 AddToWorklist(N: oye);
13974 }
13975 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13976 }
13977 }
13978
13979 // fold (aext (truncate x))
13980 if (N0.getOpcode() == ISD::TRUNCATE)
13981 return DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
13982
13983 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
13984 // if the trunc is not free.
13985 if (N0.getOpcode() == ISD::AND &&
13986 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
13987 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13988 !TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType())) {
13989 SDValue X = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
13990 SDValue Y = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: N0.getOperand(i: 1));
13991 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
13992 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: Y);
13993 }
13994
13995 // fold (aext (load x)) -> (aext (truncate (extload x)))
13996 // None of the supported targets knows how to perform load and any_ext
13997 // on vectors in one instruction, so attempt to fold to zext instead.
13998 if (VT.isVector()) {
13999 // Try to simplify (zext (load x)).
14000 if (SDValue foldedExt =
14001 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
14002 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
14003 return foldedExt;
14004 } else if (ISD::isNON_EXTLoad(N: N0.getNode()) &&
14005 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14006 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
14007 bool DoXform = true;
14008 SmallVector<SDNode *, 4> SetCCs;
14009 if (!N0.hasOneUse())
14010 DoXform =
14011 ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc: ISD::ANY_EXTEND, ExtendNodes&: SetCCs, TLI);
14012 if (DoXform) {
14013 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14014 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
14015 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
14016 MMO: LN0->getMemOperand());
14017 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ISD::ANY_EXTEND);
14018 // If the load value is used only by N, replace it via CombineTo N.
14019 bool NoReplaceTrunc = N0.hasOneUse();
14020 CombineTo(N, Res: ExtLoad);
14021 if (NoReplaceTrunc) {
14022 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14023 recursivelyDeleteUnusedNodes(N: LN0);
14024 } else {
14025 SDValue Trunc =
14026 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
14027 CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14028 }
14029 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14030 }
14031 }
14032
14033 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
14034 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
14035 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
14036 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N: N0.getNode()) &&
14037 ISD::isUNINDEXEDLoad(N: N0.getNode()) && N0.hasOneUse()) {
14038 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14039 ISD::LoadExtType ExtType = LN0->getExtensionType();
14040 EVT MemVT = LN0->getMemoryVT();
14041 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, ValVT: VT, MemVT)) {
14042 SDValue ExtLoad =
14043 DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
14044 MemVT, MMO: LN0->getMemOperand());
14045 CombineTo(N, Res: ExtLoad);
14046 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14047 recursivelyDeleteUnusedNodes(N: LN0);
14048 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14049 }
14050 }
14051
14052 if (N0.getOpcode() == ISD::SETCC) {
14053 // Propagate fast-math-flags.
14054 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14055
14056 // For vectors:
14057 // aext(setcc) -> vsetcc
14058 // aext(setcc) -> truncate(vsetcc)
14059 // aext(setcc) -> aext(vsetcc)
14060 // Only do this before legalize for now.
14061 if (VT.isVector() && !LegalOperations) {
14062 EVT N00VT = N0.getOperand(i: 0).getValueType();
14063 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
14064 return SDValue();
14065
14066 // We know that the # elements of the result is the same as the
14067 // # elements of the compare (and the # elements of the compare result
14068 // for that matter). Check to see that they are the same size. If so,
14069 // we know that the element size of the extended result matches the
14070 // element size of the compare operands.
14071 if (VT.getSizeInBits() == N00VT.getSizeInBits())
14072 return DAG.getSetCC(DL, VT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
14073 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
14074
14075 // If the desired elements are smaller or larger than the source
14076 // elements we can use a matching integer vector type and then
14077 // truncate/any extend
14078 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14079 SDValue VsetCC = DAG.getSetCC(
14080 DL, VT: MatchingVectorType, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
14081 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
14082 return DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT);
14083 }
14084
14085 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
14086 if (SDValue SCC = SimplifySelectCC(
14087 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: DAG.getConstant(Val: 1, DL, VT),
14088 N3: DAG.getConstant(Val: 0, DL, VT),
14089 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
14090 return SCC;
14091 }
14092
14093 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG))
14094 return NewCtPop;
14095
14096 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
14097 return Res;
14098
14099 return SDValue();
14100}
14101
14102SDValue DAGCombiner::visitAssertExt(SDNode *N) {
14103 unsigned Opcode = N->getOpcode();
14104 SDValue N0 = N->getOperand(Num: 0);
14105 SDValue N1 = N->getOperand(Num: 1);
14106 EVT AssertVT = cast<VTSDNode>(Val&: N1)->getVT();
14107
14108 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
14109 if (N0.getOpcode() == Opcode &&
14110 AssertVT == cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT())
14111 return N0;
14112
14113 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14114 N0.getOperand(i: 0).getOpcode() == Opcode) {
14115 // We have an assert, truncate, assert sandwich. Make one stronger assert
14116 // by asserting the smaller of the two asserted types directly on the wider
14117 // source value. This eliminates the later assert:
14118 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
14119 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
14120 SDLoc DL(N);
14121 SDValue BigA = N0.getOperand(i: 0);
14122 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
14123 EVT MinAssertVT = AssertVT.bitsLT(VT: BigA_AssertVT) ? AssertVT : BigA_AssertVT;
14124 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
14125 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
14126 N1: BigA.getOperand(i: 0), N2: MinAssertVTVal);
14127 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
14128 }
14129
14130 // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
14131 // than X, just move the AssertZext in front of the truncate and drop the
14132 // AssertSext.
14133 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14134 N0.getOperand(i: 0).getOpcode() == ISD::AssertSext &&
14135 Opcode == ISD::AssertZext) {
14136 SDValue BigA = N0.getOperand(i: 0);
14137 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
14138 if (AssertVT.bitsLT(VT: BigA_AssertVT)) {
14139 SDLoc DL(N);
14140 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
14141 N1: BigA.getOperand(i: 0), N2: N1);
14142 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
14143 }
14144 }
14145
14146 return SDValue();
14147}
14148
14149SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
14150 SDLoc DL(N);
14151
14152 Align AL = cast<AssertAlignSDNode>(Val: N)->getAlign();
14153 SDValue N0 = N->getOperand(Num: 0);
14154
14155 // Fold (assertalign (assertalign x, AL0), AL1) ->
14156 // (assertalign x, max(AL0, AL1))
14157 if (auto *AAN = dyn_cast<AssertAlignSDNode>(Val&: N0))
14158 return DAG.getAssertAlign(DL, V: N0.getOperand(i: 0),
14159 A: std::max(a: AL, b: AAN->getAlign()));
14160
14161 // In rare cases, there are trivial arithmetic ops in source operands. Sink
14162 // this assert down to source operands so that those arithmetic ops can be
14163 // exposed to DAG combining.
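// For example, (assertalign (add x, 16), align 8) can become
// (add (assertalign x, align 8), 16), since 16 is a multiple of the asserted
// alignment and does not disturb the relevant low bits.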
14164 switch (N0.getOpcode()) {
14165 default:
14166 break;
14167 case ISD::ADD:
14168 case ISD::SUB: {
14169 unsigned AlignShift = Log2(A: AL);
14170 SDValue LHS = N0.getOperand(i: 0);
14171 SDValue RHS = N0.getOperand(i: 1);
14172 unsigned LHSAlignShift = DAG.computeKnownBits(Op: LHS).countMinTrailingZeros();
14173 unsigned RHSAlignShift = DAG.computeKnownBits(Op: RHS).countMinTrailingZeros();
14174 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
14175 if (LHSAlignShift < AlignShift)
14176 LHS = DAG.getAssertAlign(DL, V: LHS, A: AL);
14177 if (RHSAlignShift < AlignShift)
14178 RHS = DAG.getAssertAlign(DL, V: RHS, A: AL);
14179 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT: N0.getValueType(), N1: LHS, N2: RHS);
14180 }
14181 break;
14182 }
14183 }
14184
14185 return SDValue();
14186}
14187
14188/// If the result of a load is shifted/masked/truncated to an effectively
14189/// narrower type, try to transform the load to a narrower type and/or
14190/// use an extending load.
14191SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
14192 unsigned Opc = N->getOpcode();
14193
14194 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
14195 SDValue N0 = N->getOperand(Num: 0);
14196 EVT VT = N->getValueType(ResNo: 0);
14197 EVT ExtVT = VT;
14198
14199 // This transformation isn't valid for vector loads.
14200 if (VT.isVector())
14201 return SDValue();
14202
14203 // The ShAmt variable is used to indicate that we've consumed a right
14204 // shift. I.e. we want to narrow the width of the load by skipping to load the
14205 // ShAmt least significant bits.
14206 unsigned ShAmt = 0;
14207 // A special case is when the least significant bits from the load are masked
14208 // away, but using an AND rather than a right shift. ShiftedOffset is used to
14209 // indicate that the narrowed load should be left-shifted by that many bits to
14210 // reconstruct the result.
14211 unsigned ShiftedOffset = 0;
14212 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT and then
14213 // sign extending back to VT.
14214 if (Opc == ISD::SIGN_EXTEND_INREG) {
14215 ExtType = ISD::SEXTLOAD;
14216 ExtVT = cast<VTSDNode>(Val: N->getOperand(Num: 1))->getVT();
14217 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
14218 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
14219 // value, or it may be shifting a higher subword, half or byte into the
14220 // lowest bits.
14221
14222 // Only handle shift with constant shift amount, and the shiftee must be a
14223 // load.
14224 auto *LN = dyn_cast<LoadSDNode>(Val&: N0);
14225 auto *N1C = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
14226 if (!N1C || !LN)
14227 return SDValue();
14228 // If the shift amount is larger than the memory type then we're not
14229 // accessing any of the loaded bytes.
14230 ShAmt = N1C->getZExtValue();
14231 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
14232 if (MemoryWidth <= ShAmt)
14233 return SDValue();
14234 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
14235 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
14236 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
14237 // If original load is a SEXTLOAD then we can't simply replace it by a
14238 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
14239 // followed by a ZEXT, but that is not handled at the moment). Similarly if
14240 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
14241 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
14242 LN->getExtensionType() == ISD::ZEXTLOAD) &&
14243 LN->getExtensionType() != ExtType)
14244 return SDValue();
14245 } else if (Opc == ISD::AND) {
14246 // An AND with a constant mask is the same as a truncate + zero-extend.
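// For example, on a little-endian target (and (load i32 p), 0xffff) can
// become an i16 zextload, and (and (load i32 p), 0xff00) can become
// (shl (zextload i8 from p+1), 8).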
14247 auto AndC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
14248 if (!AndC)
14249 return SDValue();
14250
14251 const APInt &Mask = AndC->getAPIntValue();
14252 unsigned ActiveBits = 0;
14253 if (Mask.isMask()) {
14254 ActiveBits = Mask.countr_one();
14255 } else if (Mask.isShiftedMask(MaskIdx&: ShAmt, MaskLen&: ActiveBits)) {
14256 ShiftedOffset = ShAmt;
14257 } else {
14258 return SDValue();
14259 }
14260
14261 ExtType = ISD::ZEXTLOAD;
14262 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
14263 }
14264
14265 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
14266 // a right shift. Here we redo some of those checks, to possibly adjust the
14267 // ExtVT even further based on "a masking AND". We could also end up here for
14268 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
14269 // need to be done here as well.
14270 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
14271 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
14272 // Bail out when the SRL has more than one use. This is done for historical
14273 // (undocumented) reasons. Maybe the intent was to guard the AND-masking
14274 // check below? And maybe it could be non-profitable to do the transform in
14275 // case the SRL has multiple uses and we get here with Opc != ISD::SRL?
14276 // FIXME: Can't we just skip this check for the Opc == ISD::SRL case?
14277 if (!SRL.hasOneUse())
14278 return SDValue();
14279
14280 // Only handle shift with constant shift amount, and the shiftee must be a
14281 // load.
14282 auto *LN = dyn_cast<LoadSDNode>(Val: SRL.getOperand(i: 0));
14283 auto *SRL1C = dyn_cast<ConstantSDNode>(Val: SRL.getOperand(i: 1));
14284 if (!SRL1C || !LN)
14285 return SDValue();
14286
14287 // If the shift amount is larger than the input type then we're not
14288 // accessing any of the loaded bytes. If the load was a zextload/extload
14289 // then the result of the shift+trunc is zero/undef (handled elsewhere).
14290 ShAmt = SRL1C->getZExtValue();
14291 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
14292 if (ShAmt >= MemoryWidth)
14293 return SDValue();
14294
14295 // Because a SRL must be assumed to *need* to zero-extend the high bits
14296 // (as opposed to anyext the high bits), we can't combine the zextload
14297 // lowering of SRL and an sextload.
14298 if (LN->getExtensionType() == ISD::SEXTLOAD)
14299 return SDValue();
14300
14301 // Avoid reading outside the memory accessed by the original load (could
14302 // happen if we only adjust the load base pointer by ShAmt). Instead we
14303 // try to narrow the load even further. The typical scenario here is:
14304 // (i64 (truncate (i96 (srl (load x), 64)))) ->
14305 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
14306 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
14307 // Don't replace sextload by zextload.
14308 if (ExtType == ISD::SEXTLOAD)
14309 return SDValue();
14310 // Narrow the load.
14311 ExtType = ISD::ZEXTLOAD;
14312 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
14313 }
14314
14315 // If the SRL is only used by a masking AND, we may be able to adjust
14316 // the ExtVT to make the AND redundant.
14317 SDNode *Mask = *(SRL->use_begin());
14318 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
14319 isa<ConstantSDNode>(Val: Mask->getOperand(Num: 1))) {
14320 unsigned Offset, ActiveBits;
14321 const APInt& ShiftMask = Mask->getConstantOperandAPInt(Num: 1);
14322 if (ShiftMask.isMask()) {
14323 EVT MaskedVT =
14324 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ShiftMask.countr_one());
14325 // If the mask is smaller, recompute the type.
14326 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
14327 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT))
14328 ExtVT = MaskedVT;
14329 } else if (ExtType == ISD::ZEXTLOAD &&
14330 ShiftMask.isShiftedMask(MaskIdx&: Offset, MaskLen&: ActiveBits) &&
14331 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
14332 EVT MaskedVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
14333 // If the mask is shifted we can use a narrower load and a shl to insert
14334 // the trailing zeros.
14335 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
14336 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT)) {
14337 ExtVT = MaskedVT;
14338 ShAmt = Offset + ShAmt;
14339 ShiftedOffset = Offset;
14340 }
14341 }
14342 }
14343
14344 N0 = SRL.getOperand(i: 0);
14345 }
14346
14347 // If the load is shifted left (and the result isn't shifted back right), we
14348 // can fold a truncate through the shift. The typical scenario is that N
14349 // points at a TRUNCATE here so the attempted fold is:
14350 // (truncate (shl (load x), c)) -> (shl (narrow load x), c)
14351 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
14352 unsigned ShLeftAmt = 0;
14353 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14354 ExtVT == VT && TLI.isNarrowingProfitable(SrcVT: N0.getValueType(), DestVT: VT)) {
14355 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
14356 ShLeftAmt = N01->getZExtValue();
14357 N0 = N0.getOperand(i: 0);
14358 }
14359 }
14360
14361 // If we haven't found a load, we can't narrow it.
14362 if (!isa<LoadSDNode>(Val: N0))
14363 return SDValue();
14364
14365 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14366 // Reducing the width of a volatile load is illegal. For atomics, we may be
14367 // able to reduce the width provided we never widen again. (see D66309)
14368 if (!LN0->isSimple() ||
14369 !isLegalNarrowLdSt(LDST: LN0, ExtType, MemVT&: ExtVT, ShAmt))
14370 return SDValue();
14371
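// On big-endian targets the bytes holding the least significant bits sit at
// the high end of the in-memory value, so the pointer adjustment must be
// measured from the opposite end of the original load.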
14372 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
14373 unsigned LVTStoreBits =
14374 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
14375 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
14376 return LVTStoreBits - EVTStoreBits - ShAmt;
14377 };
14378
14379 // We need to adjust the pointer to the load by ShAmt bits in order to load
14380 // the correct bytes.
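// For example, narrowing (srl (load i32 p), 16) on a little-endian target
// ends up loading an i16 from p+2.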
14381 unsigned PtrAdjustmentInBits =
14382 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
14383
14384 uint64_t PtrOff = PtrAdjustmentInBits / 8;
14385 SDLoc DL(LN0);
14386 // The original load itself didn't wrap, so an offset within it doesn't.
14387 SDNodeFlags Flags;
14388 Flags.setNoUnsignedWrap(true);
14389 SDValue NewPtr = DAG.getMemBasePlusOffset(
14390 Base: LN0->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL, Flags);
14391 AddToWorklist(N: NewPtr.getNode());
14392
14393 SDValue Load;
14394 if (ExtType == ISD::NON_EXTLOAD)
14395 Load = DAG.getLoad(VT, dl: DL, Chain: LN0->getChain(), Ptr: NewPtr,
14396 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff),
14397 Alignment: LN0->getOriginalAlign(),
14398 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14399 else
14400 Load = DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: NewPtr,
14401 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff), MemVT: ExtVT,
14402 Alignment: LN0->getOriginalAlign(),
14403 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14404
14405 // Replace the old load's chain with the new load's chain.
14406 WorklistRemover DeadNodes(*this);
14407 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
14408
14409 // Shift the result left, if we've swallowed a left shift.
14410 SDValue Result = Load;
14411 if (ShLeftAmt != 0) {
14412 EVT ShImmTy = getShiftAmountTy(LHSTy: Result.getValueType());
14413 if (!isUIntN(N: ShImmTy.getScalarSizeInBits(), x: ShLeftAmt))
14414 ShImmTy = VT;
14415 // If the shift amount is as large as the result size (but, presumably,
14416 // no larger than the source) then the useful bits of the result are
14417 // zero; we can't simply return the shortened shift, because the result
14418 // of that operation is undefined.
14419 if (ShLeftAmt >= VT.getScalarSizeInBits())
14420 Result = DAG.getConstant(Val: 0, DL, VT);
14421 else
14422 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT,
14423 N1: Result, N2: DAG.getConstant(Val: ShLeftAmt, DL, VT: ShImmTy));
14424 }
14425
14426 if (ShiftedOffset != 0) {
14427 // We're using a shifted mask, so the load now has an offset. This means
14428 // that data has been loaded into lower bits of the register than it would
14429 // have been before, so we need to shl the loaded data into the correct
14430 // position in the register.
14431 SDValue ShiftC = DAG.getConstant(Val: ShiftedOffset, DL, VT);
14432 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result, N2: ShiftC);
14433 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
14434 }
14435
14436 // Return the new loaded value.
14437 return Result;
14438}
14439
14440SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
14441 SDValue N0 = N->getOperand(Num: 0);
14442 SDValue N1 = N->getOperand(Num: 1);
14443 EVT VT = N->getValueType(ResNo: 0);
14444 EVT ExtVT = cast<VTSDNode>(Val&: N1)->getVT();
14445 unsigned VTBits = VT.getScalarSizeInBits();
14446 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
14447
14448 // sext_in_reg(undef) = 0 because the top bits will all be the same.
14449 if (N0.isUndef())
14450 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
14451
14452 // fold (sext_in_reg c1) -> c1
14453 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0))
14454 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: N0, N2: N1);
14455
14456 // If the input is already sign extended, just drop the extension.
14457 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(Op: N0))
14458 return N0;
14459
14460 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
14461 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14462 ExtVT.bitsLT(VT: cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT()))
14463 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
14464 N2: N1);
14465
14466 // fold (sext_in_reg (sext x)) -> (sext x)
14467 // fold (sext_in_reg (aext x)) -> (sext x)
14468 // if x is small enough or if we know that x has more than 1 sign bit and the
14469 // sign_extend_inreg is extending from one of them.
14470 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14471 SDValue N00 = N0.getOperand(i: 0);
14472 unsigned N00Bits = N00.getScalarValueSizeInBits();
14473 if ((N00Bits <= ExtVTBits ||
14474 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits) &&
14475 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
14476 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: N00);
14477 }
14478
14479 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
14480 // if x is small enough or if we know that x has more than 1 sign bit and the
14481 // sign_extend_inreg is extending from one of them.
14482 if (ISD::isExtVecInRegOpcode(Opcode: N0.getOpcode())) {
14483 SDValue N00 = N0.getOperand(i: 0);
14484 unsigned N00Bits = N00.getScalarValueSizeInBits();
14485 unsigned DstElts = N0.getValueType().getVectorMinNumElements();
14486 unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
14487 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
14488 APInt DemandedSrcElts = APInt::getLowBitsSet(numBits: SrcElts, loBitsSet: DstElts);
14489 if ((N00Bits == ExtVTBits ||
14490 (!IsZext && (N00Bits < ExtVTBits ||
14491 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits))) &&
14492 (!LegalOperations ||
14493 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
14494 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT, Operand: N00);
14495 }
14496
14497 // fold (sext_in_reg (zext x)) -> (sext x)
14498 // iff we are extending the source sign bit.
14499 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
14500 SDValue N00 = N0.getOperand(i: 0);
14501 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
14502 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
14503 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: N00);
14504 }
14505
14506 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
14507 if (DAG.MaskedValueIsZero(Op: N0, Mask: APInt::getOneBitSet(numBits: VTBits, BitNo: ExtVTBits - 1)))
14508 return DAG.getZeroExtendInReg(Op: N0, DL: SDLoc(N), VT: ExtVT);
14509
14510 // fold operands of sext_in_reg based on knowledge that the top bits are not
14511 // demanded.
14512 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
14513 return SDValue(N, 0);
14514
14515 // fold (sext_in_reg (load x)) -> (smaller sextload x)
14516 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
14517 if (SDValue NarrowLoad = reduceLoadWidth(N))
14518 return NarrowLoad;
14519
14520 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
14521 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
14522 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
14523 if (N0.getOpcode() == ISD::SRL) {
14524 if (auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1)))
14525 if (ShAmt->getAPIntValue().ule(RHS: VTBits - ExtVTBits)) {
14526 // We can turn this into an SRA iff the input to the SRL is already sign
14527 // extended enough.
14528 unsigned InSignBits = DAG.ComputeNumSignBits(Op: N0.getOperand(i: 0));
14529 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
14530 return DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
14531 N2: N0.getOperand(i: 1));
14532 }
14533 }
14534
14535 // fold (sext_inreg (extload x)) -> (sextload x)
14536 // If sextload is not supported by the target, we can only do the combine
14537 // when the load has one use. Doing otherwise can block folding the extload with other
14538 // extends that the target does support.
14539 if (ISD::isEXTLoad(N: N0.getNode()) &&
14540 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14541 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
14542 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple() &&
14543 N0.hasOneUse()) ||
14544 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
14545 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14546 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(N), VT,
14547 Chain: LN0->getChain(),
14548 Ptr: LN0->getBasePtr(), MemVT: ExtVT,
14549 MMO: LN0->getMemOperand());
14550 CombineTo(N, Res: ExtLoad);
14551 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14552 AddToWorklist(N: ExtLoad.getNode());
14553 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14554 }
14555
14556 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
14557 if (ISD::isZEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14558 N0.hasOneUse() &&
14559 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
14560 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) &&
14561 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
14562 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14563 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(N), VT,
14564 Chain: LN0->getChain(),
14565 Ptr: LN0->getBasePtr(), MemVT: ExtVT,
14566 MMO: LN0->getMemOperand());
14567 CombineTo(N, Res: ExtLoad);
14568 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14569 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14570 }
14571
14572 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
14573 // ignore it if the masked load is already sign extended
14574 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0)) {
14575 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
14576 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
14577 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT)) {
14578 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
14579 VT, dl: SDLoc(N), Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(),
14580 Mask: Ld->getMask(), Src0: Ld->getPassThru(), MemVT: ExtVT, MMO: Ld->getMemOperand(),
14581 AM: Ld->getAddressingMode(), ISD::SEXTLOAD, IsExpanding: Ld->isExpandingLoad());
14582 CombineTo(N, Res: ExtMaskedLoad);
14583 CombineTo(N: N0.getNode(), Res0: ExtMaskedLoad, Res1: ExtMaskedLoad.getValue(R: 1));
14584 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14585 }
14586 }
14587
14588 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
14589 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
14590 if (SDValue(GN0, 0).hasOneUse() &&
14591 ExtVT == GN0->getMemoryVT() &&
14592 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(GN0, 0))) {
14593 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
14594 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
14595
14596 SDValue ExtLoad = DAG.getMaskedGather(
14597 DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
14598 GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
14599
14600 CombineTo(N, Res: ExtLoad);
14601 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14602 AddToWorklist(N: ExtLoad.getNode());
14603 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14604 }
14605 }
14606
14607 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
14608 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
14609 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
14610 N1: N0.getOperand(i: 1), DemandHighBits: false))
14611 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: BSwap, N2: N1);
14612 }
14613
14614 // Fold (iM_signext_inreg
14615 // (extract_subvector (zext|anyext|sext iN_v to _) _)
14616 // from iN)
14617 // -> (extract_subvector (signext iN_v to iM))
14618 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
14619 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
14620 SDValue InnerExt = N0.getOperand(i: 0);
14621 EVT InnerExtVT = InnerExt->getValueType(ResNo: 0);
14622 SDValue Extendee = InnerExt->getOperand(Num: 0);
14623
14624 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
14625 (!LegalOperations ||
14626 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT: InnerExtVT))) {
14627 SDValue SignExtExtendee =
14628 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT: InnerExtVT, Operand: Extendee);
14629 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT, N1: SignExtExtendee,
14630 N2: N0.getOperand(i: 1));
14631 }
14632 }
14633
14634 return SDValue();
14635}
14636
14637static SDValue foldExtendVectorInregToExtendOfSubvector(
14638 SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
14639 bool LegalOperations) {
14640 unsigned InregOpcode = N->getOpcode();
14641 unsigned Opcode = DAG.getOpcode_EXTEND(Opcode: InregOpcode);
14642
14643 SDValue Src = N->getOperand(Num: 0);
14644 EVT VT = N->getValueType(ResNo: 0);
14645 EVT SrcVT = EVT::getVectorVT(Context&: *DAG.getContext(),
14646 VT: Src.getValueType().getVectorElementType(),
14647 EC: VT.getVectorElementCount());
14648
14649 assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
14650 "Expected EXTEND_VECTOR_INREG dag node in input!");
14651
14652 // Profitability check: our operand must be a one-use CONCAT_VECTORS.
14653 // FIXME: one-use check may be overly restrictive
14654 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
14655 return SDValue();
14656
14657 // Profitability check: we must be extending exactly one of its operands.
14658 // FIXME: this is probably overly restrictive.
14659 Src = Src.getOperand(i: 0);
14660 if (Src.getValueType() != SrcVT)
14661 return SDValue();
14662
14663 if (LegalOperations && !TLI.isOperationLegal(Op: Opcode, VT))
14664 return SDValue();
14665
14666 return DAG.getNode(Opcode, DL, VT, Operand: Src);
14667}
14668
14669SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14670 SDValue N0 = N->getOperand(Num: 0);
14671 EVT VT = N->getValueType(ResNo: 0);
14672 SDLoc DL(N);
14673
14674 if (N0.isUndef()) {
14675 // aext_vector_inreg(undef) = undef because the top bits are undefined.
14676 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
14677 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
14678 ? DAG.getUNDEF(VT)
14679 : DAG.getConstant(Val: 0, DL, VT);
14680 }
14681
14682 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14683 return Res;
14684
14685 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
14686 return SDValue(N, 0);
14687
14688 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
14689 LegalOperations))
14690 return R;
14691
14692 return SDValue();
14693}
14694
14695SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14696 SDValue N0 = N->getOperand(Num: 0);
14697 EVT VT = N->getValueType(ResNo: 0);
14698 EVT SrcVT = N0.getValueType();
14699 bool isLE = DAG.getDataLayout().isLittleEndian();
14700
14701 // trunc(undef) = undef
14702 if (N0.isUndef())
14703 return DAG.getUNDEF(VT);
14704
14705 // fold (truncate (truncate x)) -> (truncate x)
14706 if (N0.getOpcode() == ISD::TRUNCATE)
14707 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
14708
14709 // fold (truncate c1) -> c1
14710 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Ops: {N0}))
14711 return C;
14712
14713 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
14714 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
14715 N0.getOpcode() == ISD::SIGN_EXTEND ||
14716 N0.getOpcode() == ISD::ANY_EXTEND) {
14717 // if the source is smaller than the dest, we still need an extend.
14718 if (N0.getOperand(i: 0).getValueType().bitsLT(VT))
14719 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
14720 // if the source is larger than the dest, then we just need the truncate.
14721 if (N0.getOperand(i: 0).getValueType().bitsGT(VT))
14722 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
14723 // if the source and dest are the same type, we can drop both the extend
14724 // and the truncate.
14725 return N0.getOperand(i: 0);
14726 }
14727
14728 // Try to narrow a truncate-of-sext_in_reg to the destination type:
14729 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
14730 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14731 N0.hasOneUse()) {
14732 SDValue X = N0.getOperand(i: 0);
14733 SDValue ExtVal = N0.getOperand(i: 1);
14734 EVT ExtVT = cast<VTSDNode>(Val&: ExtVal)->getVT();
14735 if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(TruncVT: VT, VT: SrcVT, ExtVT)) {
14736 SDValue TrX = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Operand: X);
14737 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: TrX, N2: ExtVal);
14738 }
14739 }
14740
14741 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
14742 if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
14743 return SDValue();
14744
14745 // Fold extract-and-trunc into a narrow extract. For example:
14746 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
14747 // i32 y = TRUNCATE(i64 x)
14748 // -- becomes --
14749 // v16i8 b = BITCAST (v2i64 val)
14750 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
14751 //
14752 // Note: We only run this optimization after type legalization (which often
14753 // creates this pattern) and before operation legalization after which
14754 // we need to be more careful about the vector instructions that we generate.
14755 if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
14756 LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
14757 EVT VecTy = N0.getOperand(i: 0).getValueType();
14758 EVT ExTy = N0.getValueType();
14759 EVT TrTy = N->getValueType(ResNo: 0);
14760
14761 auto EltCnt = VecTy.getVectorElementCount();
14762 unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
14763 auto NewEltCnt = EltCnt * SizeRatio;
14764
14765 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: TrTy, EC: NewEltCnt);
14766 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
14767
14768 SDValue EltNo = N0->getOperand(Num: 1);
14769 if (isa<ConstantSDNode>(Val: EltNo) && isTypeLegal(VT: NVT)) {
14770 int Elt = EltNo->getAsZExtVal();
14771 int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
14772
14773 SDLoc DL(N);
14774 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: TrTy,
14775 N1: DAG.getBitcast(VT: NVT, V: N0.getOperand(i: 0)),
14776 N2: DAG.getVectorIdxConstant(Val: Index, DL));
14777 }
14778 }
14779
14780 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
14781 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
14782 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SELECT, VT: SrcVT)) &&
14783 TLI.isTruncateFree(FromVT: SrcVT, ToVT: VT)) {
14784 SDLoc SL(N0);
14785 SDValue Cond = N0.getOperand(i: 0);
14786 SDValue TruncOp0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 1));
14787 SDValue TruncOp1 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 2));
14788 return DAG.getNode(Opcode: ISD::SELECT, DL: SDLoc(N), VT, N1: Cond, N2: TruncOp0, N3: TruncOp1);
14789 }
14790 }
14791
14792 // trunc (shl x, K) -> shl (trunc x), K, if K < VT.getScalarSizeInBits()
14793 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14794 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
14795 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
14796 SDValue Amt = N0.getOperand(i: 1);
14797 KnownBits Known = DAG.computeKnownBits(Op: Amt);
14798 unsigned Size = VT.getScalarSizeInBits();
14799 if (Known.countMaxActiveBits() <= Log2_32(Value: Size)) {
14800 SDLoc SL(N);
14801 EVT AmtVT = TLI.getShiftAmountTy(LHSTy: VT, DL: DAG.getDataLayout());
14802
14803 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 0));
14804 if (AmtVT != Amt.getValueType()) {
14805 Amt = DAG.getZExtOrTrunc(Op: Amt, DL: SL, VT: AmtVT);
14806 AddToWorklist(N: Amt.getNode());
14807 }
14808 return DAG.getNode(Opcode: ISD::SHL, DL: SL, VT, N1: Trunc, N2: Amt);
14809 }
14810 }
14811
14812 if (SDValue V = foldSubToUSubSat(DstVT: VT, N: N0.getNode()))
14813 return V;
14814
14815 if (SDValue ABD = foldABSToABD(N))
14816 return ABD;
14817
14818 // Attempt to pre-truncate BUILD_VECTOR sources.
14819 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
14820 N0.hasOneUse() &&
14821 TLI.isTruncateFree(FromVT: SrcVT.getScalarType(), ToVT: VT.getScalarType()) &&
14822 // Avoid creating illegal types if running after type legalizer.
14823 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType()))) {
14824 SDLoc DL(N);
14825 EVT SVT = VT.getScalarType();
14826 SmallVector<SDValue, 8> TruncOps;
14827 for (const SDValue &Op : N0->op_values()) {
14828 SDValue TruncOp = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: Op);
14829 TruncOps.push_back(Elt: TruncOp);
14830 }
14831 return DAG.getBuildVector(VT, DL, Ops: TruncOps);
14832 }
14833
14834 // trunc (splat_vector x) -> splat_vector (trunc x)
14835 if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
14836 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType())) &&
14837 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT))) {
14838 SDLoc DL(N);
14839 EVT SVT = VT.getScalarType();
14840 return DAG.getSplatVector(
14841 VT, DL, Op: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: N0->getOperand(Num: 0)));
14842 }
14843
14844 // Fold a series of buildvector, bitcast, and truncate if possible.
14845 // For example fold
14846 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
14847 // (2xi32 (buildvector x, y)).
14848 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
14849 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
14850 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR &&
14851 N0.getOperand(i: 0).hasOneUse()) {
14852 SDValue BuildVect = N0.getOperand(i: 0);
14853 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
14854 EVT TruncVecEltTy = VT.getVectorElementType();
14855
14856 // Check that the element types match.
14857 if (BuildVectEltTy == TruncVecEltTy) {
14858 // Now we only need to compute the offset of the truncated elements.
14859 unsigned BuildVecNumElts = BuildVect.getNumOperands();
14860 unsigned TruncVecNumElts = VT.getVectorNumElements();
14861 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
14862
14863 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
14864 "Invalid number of elements");
14865
14866 SmallVector<SDValue, 8> Opnds;
14867 for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
14868 Opnds.push_back(Elt: BuildVect.getOperand(i));
14869
14870 return DAG.getBuildVector(VT, DL: SDLoc(N), Ops: Opnds);
14871 }
14872 }
14873
14874 // fold (truncate (load x)) -> (smaller load x)
14875 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
14876 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
14877 if (SDValue Reduced = reduceLoadWidth(N))
14878 return Reduced;
14879
14880 // Handle the case where the truncated result is at least as wide as the
14881 // loaded type.
14882 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N: N0.getNode())) {
14883 auto *LN0 = cast<LoadSDNode>(Val&: N0);
14884 if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
14885 SDValue NewLoad = DAG.getExtLoad(
14886 ExtType: LN0->getExtensionType(), dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
14887 Ptr: LN0->getBasePtr(), MemVT: LN0->getMemoryVT(), MMO: LN0->getMemOperand());
14888 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLoad.getValue(R: 1));
14889 return NewLoad;
14890 }
14891 }
14892 }
14893
14894 // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
14895 // where ... are all 'undef'.
14896 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
14897 SmallVector<EVT, 8> VTs;
14898 SDValue V;
14899 unsigned Idx = 0;
14900 unsigned NumDefs = 0;
14901
14902 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
14903 SDValue X = N0.getOperand(i);
14904 if (!X.isUndef()) {
14905 V = X;
14906 Idx = i;
14907 NumDefs++;
14908 }
14909 // Stop if more than one member is non-undef.
14910 if (NumDefs > 1)
14911 break;
14912
14913 VTs.push_back(Elt: EVT::getVectorVT(Context&: *DAG.getContext(),
14914 VT: VT.getVectorElementType(),
14915 EC: X.getValueType().getVectorElementCount()));
14916 }
14917
14918 if (NumDefs == 0)
14919 return DAG.getUNDEF(VT);
14920
14921 if (NumDefs == 1) {
14922 assert(V.getNode() && "The single defined operand is empty!");
14923 SmallVector<SDValue, 8> Opnds;
14924 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
14925 if (i != Idx) {
14926 Opnds.push_back(Elt: DAG.getUNDEF(VT: VTs[i]));
14927 continue;
14928 }
14929 SDValue NV = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(V), VT: VTs[i], Operand: V);
14930 AddToWorklist(N: NV.getNode());
14931 Opnds.push_back(Elt: NV);
14932 }
14933 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops: Opnds);
14934 }
14935 }
14936
14937 // Fold truncate of a bitcast of a vector to an extract of the low vector
14938 // element.
14939 //
14940 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
14941 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
14942 SDValue VecSrc = N0.getOperand(i: 0);
14943 EVT VecSrcVT = VecSrc.getValueType();
14944 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
14945 (!LegalOperations ||
14946 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecSrcVT))) {
14947 SDLoc SL(N);
14948
14949 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
14950 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SL, VT, N1: VecSrc,
14951 N2: DAG.getVectorIdxConstant(Val: Idx, DL: SL));
14952 }
14953 }
14954
14955 // Simplify the operands using demanded-bits information.
14956 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
14957 return SDValue(N, 0);
14958
14959 // fold (truncate (extract_subvector(ext x))) ->
14960 // (extract_subvector x)
14961 // TODO: This can be generalized to cover cases where the truncate and extract
14962 // do not fully cancel each other out.
14963 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
14964 SDValue N00 = N0.getOperand(i: 0);
14965 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
14966 N00.getOpcode() == ISD::ZERO_EXTEND ||
14967 N00.getOpcode() == ISD::ANY_EXTEND) {
14968 if (N00.getOperand(i: 0)->getValueType(ResNo: 0).getVectorElementType() ==
14969 VT.getVectorElementType())
14970 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N0->getOperand(Num: 0)), VT,
14971 N1: N00.getOperand(i: 0), N2: N0.getOperand(i: 1));
14972 }
14973 }
14974
14975 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
14976 return NewVSel;
14977
14978 // Narrow a suitable binary operation with a non-opaque constant operand by
14979 // moving it ahead of the truncate. This is limited to pre-legalization
14980 // because targets may prefer a wider type during later combines and invert
14981 // this transform.
14982 switch (N0.getOpcode()) {
14983 case ISD::ADD:
14984 case ISD::SUB:
14985 case ISD::MUL:
14986 case ISD::AND:
14987 case ISD::OR:
14988 case ISD::XOR:
14989 if (!LegalOperations && N0.hasOneUse() &&
14990 (isConstantOrConstantVector(N: N0.getOperand(i: 0), NoOpaques: true) ||
14991 isConstantOrConstantVector(N: N0.getOperand(i: 1), NoOpaques: true))) {
14992 // TODO: We already restricted this to pre-legalization, but for vectors
14993 // we are extra cautious to not create an unsupported operation.
14994 // Target-specific changes are likely needed to avoid regressions here.
14995 if (VT.isScalarInteger() || TLI.isOperationLegal(Op: N0.getOpcode(), VT)) {
14996 SDLoc DL(N);
14997 SDValue NarrowL = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
14998 SDValue NarrowR = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
14999 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NarrowL, N2: NarrowR);
15000 }
15001 }
15002 break;
15003 case ISD::ADDE:
15004 case ISD::UADDO_CARRY:
15005 // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
15006 // (trunc uaddo_carry(X, Y, Carry)) ->
15007 // (uaddo_carry trunc(X), trunc(Y), Carry)
15008 // This is only valid when the adde's carry is not used.
15009 // We only do this for uaddo_carry before operation legalization.
15010 if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
15011 TLI.isOperationLegal(Op: N0.getOpcode(), VT)) &&
15012 N0.hasOneUse() && !N0->hasAnyUseOfValue(Value: 1)) {
15013 SDLoc DL(N);
15014 SDValue X = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15015 SDValue Y = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
15016 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: N0->getValueType(ResNo: 1));
15017 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: VTs, N1: X, N2: Y, N3: N0.getOperand(i: 2));
15018 }
15019 break;
15020 case ISD::USUBSAT:
15021 // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
15022 // enough to know that the upper bits are zero, we must also ensure that we
15023 // don't introduce an extra truncate.
15024 if (!LegalOperations && N0.hasOneUse() &&
15025 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
15026 N0.getOperand(i: 0).getOperand(i: 0).getScalarValueSizeInBits() <=
15027 VT.getScalarSizeInBits() &&
15028 hasOperation(Opcode: N0.getOpcode(), VT)) {
15029 return getTruncatedUSUBSAT(DstVT: VT, SrcVT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15030 DAG, DL: SDLoc(N));
15031 }
15032 break;
15033 }
15034
15035 return SDValue();
15036}
15037
15038static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
15039 SDValue Elt = N->getOperand(Num: i);
15040 if (Elt.getOpcode() != ISD::MERGE_VALUES)
15041 return Elt.getNode();
15042 return Elt.getOperand(i: Elt.getResNo()).getNode();
15043}
15044
15045/// build_pair (load, load) -> load
15046/// if load locations are consecutive.
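/// For example, on a little-endian target (illustrative; p is a placeholder
/// pointer): (i64 build_pair (i32 load p), (i32 load p+4)) -> (i64 load p),
/// provided both loads are simple, single-use, and the wide load is allowed.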
15047SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
15048 assert(N->getOpcode() == ISD::BUILD_PAIR);
15049
15050 auto *LD1 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 0));
15051 auto *LD2 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 1));
15052
15053 // A BUILD_PAIR always has the least significant part in elt 0 and the
15054 // most significant part in elt 1, so when combining into one large load we
15055 // need to consider the endianness.
15056 if (DAG.getDataLayout().isBigEndian())
15057 std::swap(a&: LD1, b&: LD2);
15058
15059 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(N: LD1) || !ISD::isNON_EXTLoad(N: LD2) ||
15060 !LD1->hasOneUse() || !LD2->hasOneUse() ||
15061 LD1->getAddressSpace() != LD2->getAddressSpace())
15062 return SDValue();
15063
15064 unsigned LD1Fast = 0;
15065 EVT LD1VT = LD1->getValueType(ResNo: 0);
15066 unsigned LD1Bytes = LD1VT.getStoreSize();
15067 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::LOAD, VT)) &&
15068 DAG.areNonVolatileConsecutiveLoads(LD: LD2, Base: LD1, Bytes: LD1Bytes, Dist: 1) &&
15069 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
15070 MMO: *LD1->getMemOperand(), Fast: &LD1Fast) && LD1Fast)
15071 return DAG.getLoad(VT, dl: SDLoc(N), Chain: LD1->getChain(), Ptr: LD1->getBasePtr(),
15072 PtrInfo: LD1->getPointerInfo(), Alignment: LD1->getAlign());
15073
15074 return SDValue();
15075}
15076
15077static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
15078 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
15079 // and Lo parts; on big-endian machines it doesn't.
15080 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
15081}
15082
15083SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
15084 const TargetLowering &TLI) {
15085 // If this is not a bitcast to an FP type or if the target doesn't have
15086 // IEEE754-compliant FP logic, we're done.
15087 EVT VT = N->getValueType(ResNo: 0);
15088 SDValue N0 = N->getOperand(Num: 0);
15089 EVT SourceVT = N0.getValueType();
15090
15091 if (!VT.isFloatingPoint())
15092 return SDValue();
15093
15094 // TODO: Handle cases where the integer constant is a different scalar
15095 // bitwidth than the FP.
15096 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
15097 return SDValue();
15098
15099 unsigned FPOpcode;
15100 APInt SignMask;
15101 switch (N0.getOpcode()) {
15102 case ISD::AND:
15103 FPOpcode = ISD::FABS;
15104 SignMask = ~APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15105 break;
15106 case ISD::XOR:
15107 FPOpcode = ISD::FNEG;
15108 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15109 break;
15110 case ISD::OR:
15111 FPOpcode = ISD::FABS;
15112 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15113 break;
15114 default:
15115 return SDValue();
15116 }
15117
15118 if (LegalOperations && !TLI.isOperationLegal(Op: FPOpcode, VT))
15119 return SDValue();
15120
15121 // This needs to be the inverse of logic in foldSignChangeInBitcast.
15122 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
15123 // removing this would require more changes.
15124 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
15125 if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(i: 0).getValueType() == VT)
15126 return true;
15127
15128 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
15129 };
15130
15131 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
15132 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
15133 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
15134 // fneg (fabs X)
15135 SDValue LogicOp0 = N0.getOperand(i: 0);
15136 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N: N0.getOperand(i: 1), AllowUndefs: true);
15137 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
15138 IsBitCastOrFree(LogicOp0, VT)) {
15139 SDValue CastOp0 = DAG.getNode(Opcode: ISD::BITCAST, DL: SDLoc(N), VT, Operand: LogicOp0);
15140 SDValue FPOp = DAG.getNode(Opcode: FPOpcode, DL: SDLoc(N), VT, Operand: CastOp0);
15141 NumFPLogicOpsConv++;
15142 if (N0.getOpcode() == ISD::OR)
15143 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: FPOp);
15144 return FPOp;
15145 }
15146
15147 return SDValue();
15148}
15149
15150SDValue DAGCombiner::visitBITCAST(SDNode *N) {
15151 SDValue N0 = N->getOperand(Num: 0);
15152 EVT VT = N->getValueType(ResNo: 0);
15153
15154 if (N0.isUndef())
15155 return DAG.getUNDEF(VT);
15156
15157 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
15158 // Only do this before legalize types, unless both types are integer and the
15159 // scalar type is legal. Only do this before legalize ops, since the target
15160 // may be depending on the bitcast.
15161 // First check to see if this is all constant.
15162 // TODO: Support FP bitcasts after legalize types.
15163 if (VT.isVector() &&
15164 (!LegalTypes ||
15165 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
15166 TLI.isTypeLegal(VT: VT.getVectorElementType()))) &&
15167 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
15168 cast<BuildVectorSDNode>(Val&: N0)->isConstant())
15169 return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
15170 VT.getVectorElementType());
15171
15172 // If the input is a constant, let getNode fold it.
15173 if (isIntOrFPConstant(V: N0)) {
15174 // If we can't allow illegal operations, we need to check that this is just
15175 // an fp -> int or int -> fp conversion and that the resulting operation
15176 // will be legal.
15177 if (!LegalOperations ||
15178 (isa<ConstantSDNode>(Val: N0) && VT.isFloatingPoint() && !VT.isVector() &&
15179 TLI.isOperationLegal(Op: ISD::ConstantFP, VT)) ||
15180 (isa<ConstantFPSDNode>(Val: N0) && VT.isInteger() && !VT.isVector() &&
15181 TLI.isOperationLegal(Op: ISD::Constant, VT))) {
15182 SDValue C = DAG.getBitcast(VT, V: N0);
15183 if (C.getNode() != N)
15184 return C;
15185 }
15186 }
15187
15188 // (conv (conv x, t1), t2) -> (conv x, t2)
15189 if (N0.getOpcode() == ISD::BITCAST)
15190 return DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15191
15192 // fold (conv (logicop (conv x), c)) -> (logicop x, (conv c))
15193 // iff the current bitwise logicop type isn't legal
15194 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && VT.isInteger() &&
15195 !TLI.isTypeLegal(VT: N0.getOperand(i: 0).getValueType())) {
15196 auto IsFreeBitcast = [VT](SDValue V) {
15197 return (V.getOpcode() == ISD::BITCAST &&
15198 V.getOperand(i: 0).getValueType() == VT) ||
15199 (ISD::isBuildVectorOfConstantSDNodes(N: V.getNode()) &&
15200 V->hasOneUse());
15201 };
15202 if (IsFreeBitcast(N0.getOperand(i: 0)) && IsFreeBitcast(N0.getOperand(i: 1)))
15203 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT,
15204 N1: DAG.getBitcast(VT, V: N0.getOperand(i: 0)),
15205 N2: DAG.getBitcast(VT, V: N0.getOperand(i: 1)));
15206 }
15207
15208 // fold (conv (load x)) -> (load (conv*)x)
15209 // If the resultant load doesn't need a higher alignment than the original!
15210 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
15211 // Do not remove the cast if the types differ in endian layout.
15212 TLI.hasBigEndianPartOrdering(VT: N0.getValueType(), DL: DAG.getDataLayout()) ==
15213 TLI.hasBigEndianPartOrdering(VT, DL: DAG.getDataLayout()) &&
15214 // If the load is volatile, we only want to change the load type if the
15215 // resulting load is legal. Otherwise we might increase the number of
15216 // memory accesses. We don't care if the original type was legal or not
15217 // as we assume software couldn't rely on the number of accesses of an
15218 // illegal type.
15219 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) ||
15220 TLI.isOperationLegal(Op: ISD::LOAD, VT))) {
15221 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15222
15223 if (TLI.isLoadBitCastBeneficial(LoadVT: N0.getValueType(), BitcastVT: VT, DAG,
15224 MMO: *LN0->getMemOperand())) {
15225 SDValue Load =
15226 DAG.getLoad(VT, dl: SDLoc(N), Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
15227 MMO: LN0->getMemOperand());
15228 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
15229 return Load;
15230 }
15231 }
15232
15233 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
15234 return V;
15235
15236 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15237 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15238 //
15239 // For ppc_fp128:
15240 // fold (bitcast (fneg x)) ->
15241 // flipbit = signbit
15242 // (xor (bitcast x) (build_pair flipbit, flipbit))
15243 //
15244 // fold (bitcast (fabs x)) ->
15245 // flipbit = (and (extract_element (bitcast x), 0), signbit)
15246 // (xor (bitcast x) (build_pair flipbit, flipbit))
15247 // This often reduces constant pool loads.
15248 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(VT: N0.getValueType())) ||
15249 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(VT: N0.getValueType()))) &&
15250 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
15251 !N0.getValueType().isVector()) {
15252 SDValue NewConv = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15253 AddToWorklist(N: NewConv.getNode());
15254
15255 SDLoc DL(N);
15256 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15257 assert(VT.getSizeInBits() == 128);
15258 SDValue SignBit = DAG.getConstant(
15259 APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
15260 SDValue FlipBit;
15261 if (N0.getOpcode() == ISD::FNEG) {
15262 FlipBit = SignBit;
15263 AddToWorklist(N: FlipBit.getNode());
15264 } else {
15265 assert(N0.getOpcode() == ISD::FABS);
15266 SDValue Hi =
15267 DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
15268 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15269 SDLoc(NewConv)));
15270 AddToWorklist(N: Hi.getNode());
15271 FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
15272 AddToWorklist(N: FlipBit.getNode());
15273 }
15274 SDValue FlipBits =
15275 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
15276 AddToWorklist(N: FlipBits.getNode());
15277 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: NewConv, N2: FlipBits);
15278 }
15279 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
15280 if (N0.getOpcode() == ISD::FNEG)
15281 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
15282 N1: NewConv, N2: DAG.getConstant(Val: SignBit, DL, VT));
15283 assert(N0.getOpcode() == ISD::FABS);
15284 return DAG.getNode(Opcode: ISD::AND, DL, VT,
15285 N1: NewConv, N2: DAG.getConstant(Val: ~SignBit, DL, VT));
15286 }
15287
15288 // fold (bitconvert (fcopysign cst, x)) ->
15289 // (or (and (bitconvert x), sign), (and cst, (not sign)))
15290 // Note that we don't handle (copysign x, cst) because this can always be
15291 // folded to an fneg or fabs.
15292 //
15293 // For ppc_fp128:
15294 // fold (bitcast (fcopysign cst, x)) ->
15295 // flipbit = (and (extract_element
15296 // (xor (bitcast cst), (bitcast x)), 0),
15297 // signbit)
15298 // (xor (bitcast cst) (build_pair flipbit, flipbit))
15299 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
15300 isa<ConstantFPSDNode>(Val: N0.getOperand(i: 0)) && VT.isInteger() &&
15301 !VT.isVector()) {
15302 unsigned OrigXWidth = N0.getOperand(i: 1).getValueSizeInBits();
15303 EVT IntXVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OrigXWidth);
15304 if (isTypeLegal(VT: IntXVT)) {
15305 SDValue X = DAG.getBitcast(VT: IntXVT, V: N0.getOperand(i: 1));
15306 AddToWorklist(N: X.getNode());
15307
15308 // If X has a different width than the result/lhs, sext it or truncate it.
15309 unsigned VTWidth = VT.getSizeInBits();
15310 if (OrigXWidth < VTWidth) {
15311 X = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: X);
15312 AddToWorklist(N: X.getNode());
15313 } else if (OrigXWidth > VTWidth) {
15314 // To get the sign bit in the right place, we have to shift it right
15315 // before truncating.
15316 SDLoc DL(X);
15317 X = DAG.getNode(Opcode: ISD::SRL, DL,
15318 VT: X.getValueType(), N1: X,
15319 N2: DAG.getConstant(Val: OrigXWidth-VTWidth, DL,
15320 VT: X.getValueType()));
15321 AddToWorklist(N: X.getNode());
15322 X = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(X), VT, Operand: X);
15323 AddToWorklist(N: X.getNode());
15324 }
15325
15326 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15327 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2);
15328 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15329 AddToWorklist(N: Cst.getNode());
15330 SDValue X = DAG.getBitcast(VT, V: N0.getOperand(i: 1));
15331 AddToWorklist(N: X.getNode());
15332 SDValue XorResult = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT, N1: Cst, N2: X);
15333 AddToWorklist(N: XorResult.getNode());
15334 SDValue XorResult64 = DAG.getNode(
15335 ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
15336 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15337 SDLoc(XorResult)));
15338 AddToWorklist(N: XorResult64.getNode());
15339 SDValue FlipBit =
15340 DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
15341 DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
15342 AddToWorklist(N: FlipBit.getNode());
15343 SDValue FlipBits =
15344 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
15345 AddToWorklist(N: FlipBits.getNode());
15346 return DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N), VT, N1: Cst, N2: FlipBits);
15347 }
15348 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
15349 X = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(X), VT,
15350 N1: X, N2: DAG.getConstant(Val: SignBit, DL: SDLoc(X), VT));
15351 AddToWorklist(N: X.getNode());
15352
15353 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15354 Cst = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Cst), VT,
15355 N1: Cst, N2: DAG.getConstant(Val: ~SignBit, DL: SDLoc(Cst), VT));
15356 AddToWorklist(N: Cst.getNode());
15357
15358 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: Cst);
15359 }
15360 }
15361
15362 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
15363 if (N0.getOpcode() == ISD::BUILD_PAIR)
15364 if (SDValue CombineLD = CombineConsecutiveLoads(N: N0.getNode(), VT))
15365 return CombineLD;
15366
15367 // Remove double bitcasts from shuffles - this is often a legacy of
15368 // XformToShuffleWithZero being used to combine bitmaskings (of
15369 // float vectors bitcast to integer vectors) into shuffles.
15370 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
15371 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
15372 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
15373 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
15374 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
15375 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val&: N0);
15376
15377 // If operands are a bitcast, peek through if it casts the original VT.
15378 // If operands are a constant, just bitcast back to original VT.
15379 auto PeekThroughBitcast = [&](SDValue Op) {
15380 if (Op.getOpcode() == ISD::BITCAST &&
15381 Op.getOperand(i: 0).getValueType() == VT)
15382 return SDValue(Op.getOperand(i: 0));
15383 if (Op.isUndef() || isAnyConstantBuildVector(V: Op))
15384 return DAG.getBitcast(VT, V: Op);
15385 return SDValue();
15386 };
15387
15388 // FIXME: If either input vector is bitcast, try to convert the shuffle to
15389 // the result type of this bitcast. This would eliminate at least one
15390 // bitcast. See the transform in InstCombine.
15391 SDValue SV0 = PeekThroughBitcast(N0->getOperand(Num: 0));
15392 SDValue SV1 = PeekThroughBitcast(N0->getOperand(Num: 1));
15393 if (!(SV0 && SV1))
15394 return SDValue();
15395
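    // Scale the original shuffle mask to the wider element count, e.g. a
    // v2i64 mask <1,0> viewed as v4i32 becomes <2,3,0,1> (illustrative).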
15396 int MaskScale =
15397 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
15398 SmallVector<int, 8> NewMask;
15399 for (int M : SVN->getMask())
15400 for (int i = 0; i != MaskScale; ++i)
15401 NewMask.push_back(Elt: M < 0 ? -1 : M * MaskScale + i);
15402
15403 SDValue LegalShuffle =
15404 TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: SV0, N1: SV1, Mask: NewMask, DAG);
15405 if (LegalShuffle)
15406 return LegalShuffle;
15407 }
15408
15409 return SDValue();
15410}
15411
15412SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
15413 EVT VT = N->getValueType(ResNo: 0);
15414 return CombineConsecutiveLoads(N, VT);
15415}
15416
15417SDValue DAGCombiner::visitFREEZE(SDNode *N) {
15418 SDValue N0 = N->getOperand(Num: 0);
15419
15420 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op: N0, /*PoisonOnly*/ false))
15421 return N0;
15422
15423 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
15424 // Try to push freeze through instructions that propagate but don't produce
15425 // poison as far as possible. If the operand of the freeze meets three
15426 // conditions: 1) it has one use, 2) it does not produce poison, and 3) all
15427 // but one of its operands are guaranteed non-poison (or it is a BUILD_VECTOR
15428 // or similar), then push the freeze through to those maybe-poison operands.
15429 // NOTE: we will strip poison-generating flags, so ignore them here.
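 // For example (illustrative): freeze (add x, C) -> add (freeze x), C, since
 // ADD cannot itself create poison once its nsw/nuw flags are stripped and
 // the constant operand is already guaranteed non-poison.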
15430 if (DAG.canCreateUndefOrPoison(Op: N0, /*PoisonOnly*/ false,
15431 /*ConsiderFlags*/ false) ||
15432 N0->getNumValues() != 1 || !N0->hasOneUse())
15433 return SDValue();
15434
15435 bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR ||
15436 N0.getOpcode() == ISD::BUILD_PAIR ||
15437 N0.getOpcode() == ISD::CONCAT_VECTORS;
15438
15439 SmallSetVector<SDValue, 8> MaybePoisonOperands;
15440 for (SDValue Op : N0->ops()) {
15441 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
15442 /*Depth*/ 1))
15443 continue;
15444 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
15445 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(X: Op);
15446 if (!HadMaybePoisonOperands)
15447 continue;
15448 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
15449 // Multiple maybe-poison ops when not allowed - bail out.
15450 return SDValue();
15451 }
15452 }
15453 // NOTE: the whole op may still not be guaranteed to be free of undef or
15454 // poison because it could create undef or poison due to its
15455 // poison-generating flags. So not finding any maybe-poison operands is fine.
15456
15457 for (SDValue MaybePoisonOperand : MaybePoisonOperands) {
15458 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
15459 if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
15460 continue;
15461 // First, freeze each offending operand.
15462 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(V: MaybePoisonOperand);
15463 // Then, change all other uses of unfrozen operand to use frozen operand.
15464 DAG.ReplaceAllUsesOfValueWith(From: MaybePoisonOperand, To: FrozenMaybePoisonOperand);
15465 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
15466 FrozenMaybePoisonOperand.getOperand(i: 0) == FrozenMaybePoisonOperand) {
15467 // But, that also updated the use in the freeze we just created, thus
15468 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
15469 DAG.UpdateNodeOperands(N: FrozenMaybePoisonOperand.getNode(),
15470 Op: MaybePoisonOperand);
15471 }
15472 }
15473
15474 // This node has been merged with another.
15475 if (N->getOpcode() == ISD::DELETED_NODE)
15476 return SDValue(N, 0);
15477
15478 // The whole node may have been updated, so the value we were holding
15479 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
15480 N0 = N->getOperand(Num: 0);
15481
15482 // Finally, recreate the node; its operands were updated to use
15483 // frozen operands, so we just need to use its "original" operands.
15484 SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
15485 // Special-handle ISD::UNDEF, each single one of them can be its own thing.
15486 for (SDValue &Op : Ops) {
15487 if (Op.getOpcode() == ISD::UNDEF)
15488 Op = DAG.getFreeze(V: Op);
15489 }
15490 // NOTE: this strips poison generating flags.
15491 SDValue R = DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N0), VTList: N0->getVTList(), Ops);
15492 assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
15493 "Can't create node that may be undef/poison!");
15494 return R;
15495}
15496
15497/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
15498/// operands. DstEltVT indicates the destination element value type.
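/// For example (same-size elements, illustrative):
/// (v2i32 bitcast of (build_vector f32 1.0, f32 2.0)) becomes
/// (build_vector i32 0x3F800000, i32 0x40000000).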
15499SDValue DAGCombiner::
15500ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
15501 EVT SrcEltVT = BV->getValueType(ResNo: 0).getVectorElementType();
15502
15503 // If this is already the right type, we're done.
15504 if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
15505
15506 unsigned SrcBitSize = SrcEltVT.getSizeInBits();
15507 unsigned DstBitSize = DstEltVT.getSizeInBits();
15508
15509 // If this is a conversion of N elements of one type to N elements of another
15510 // type, convert each element. This handles FP<->INT cases.
15511 if (SrcBitSize == DstBitSize) {
15512 SmallVector<SDValue, 8> Ops;
15513 for (SDValue Op : BV->op_values()) {
15514 // If the vector element type is not legal, the BUILD_VECTOR operands
15515 // are promoted and implicitly truncated. Make that explicit here.
15516 if (Op.getValueType() != SrcEltVT)
15517 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(BV), VT: SrcEltVT, Operand: Op);
15518 Ops.push_back(Elt: DAG.getBitcast(VT: DstEltVT, V: Op));
15519 AddToWorklist(N: Ops.back().getNode());
15520 }
15521 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT,
15522 NumElements: BV->getValueType(ResNo: 0).getVectorNumElements());
15523 return DAG.getBuildVector(VT, DL: SDLoc(BV), Ops);
15524 }
15525
15526 // Otherwise, we're growing or shrinking the elements. To avoid having to
15527 // handle annoying details of growing/shrinking FP values, we convert them to
15528 // int first.
15529 if (SrcEltVT.isFloatingPoint()) {
15530 // Convert the input float vector to an int vector whose elements are the
15531 // same size.
15532 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SrcEltVT.getSizeInBits());
15533 BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: IntVT).getNode();
15534 SrcEltVT = IntVT;
15535 }
15536
15537 // Now we know the input is an integer vector. If the output is an FP type,
15538 // convert to integers of the destination size first, then bitcast to FP.
15539 if (DstEltVT.isFloatingPoint()) {
15540 EVT TmpVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: DstEltVT.getSizeInBits());
15541 SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: TmpVT).getNode();
15542
15543 // Next, convert to FP elements of the same size.
15544 return ConstantFoldBITCASTofBUILD_VECTOR(BV: Tmp, DstEltVT);
15545 }
15546
15547 // Okay, we know the src/dst types are both integers of differing types.
15548 assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
15549
15550 // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
15551 // BuildVectorSDNode?
15552 auto *BVN = cast<BuildVectorSDNode>(Val: BV);
15553
15554 // Extract the constant raw bit data.
15555 BitVector UndefElements;
15556 SmallVector<APInt> RawBits;
15557 bool IsLE = DAG.getDataLayout().isLittleEndian();
15558 if (!BVN->getConstantRawBits(IsLittleEndian: IsLE, DstEltSizeInBits: DstBitSize, RawBitElements&: RawBits, UndefElements))
15559 return SDValue();
15560
15561 SDLoc DL(BV);
15562 SmallVector<SDValue, 8> Ops;
15563 for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
15564 if (UndefElements[I])
15565 Ops.push_back(Elt: DAG.getUNDEF(VT: DstEltVT));
15566 else
15567 Ops.push_back(Elt: DAG.getConstant(Val: RawBits[I], DL, VT: DstEltVT));
15568 }
15569
15570 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT, NumElements: Ops.size());
15571 return DAG.getBuildVector(VT, DL, Ops);
15572}
15573
15574// Returns true if floating point contraction is allowed on the FMUL-SDValue
15575// `N`
15576static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
15577 assert(N.getOpcode() == ISD::FMUL);
15578
15579 return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
15580 N->getFlags().hasAllowContract();
15581}
15582
15583 // Returns true if `N` can assume no infinities are involved in its computation.
15584static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
15585 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
15586}
15587
15588/// Try to perform FMA combining on a given FADD node.
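/// For example: (fadd (fmul x, y), z) -> (fma x, y, z) when contraction is
/// allowed and an FMA/FMAD form is available.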
15589template <class MatchContextClass>
15590SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
15591 SDValue N0 = N->getOperand(Num: 0);
15592 SDValue N1 = N->getOperand(Num: 1);
15593 EVT VT = N->getValueType(ResNo: 0);
15594 SDLoc SL(N);
15595 MatchContextClass matcher(DAG, TLI, N);
15596 const TargetOptions &Options = DAG.getTarget().Options;
15597
15598 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15599
15600 // Floating-point multiply-add with intermediate rounding.
15601 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15602 // FIXME: Add VP_FMAD opcode.
15603 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15604
15605 // Floating-point multiply-add without intermediate rounding.
15606 bool HasFMA =
15607 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
15608 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15609
15610 // No valid opcode, do not combine.
15611 if (!HasFMAD && !HasFMA)
15612 return SDValue();
15613
15614 bool CanReassociate =
15615 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15616 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15617 Options.UnsafeFPMath || HasFMAD);
15618 // If the addition is not contractable, do not combine.
15619 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15620 return SDValue();
15621
15622 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
15623 // beneficial. It does not reduce latency. It increases register pressure. It
15624 // replaces an fadd with an fma which is a more complex instruction, so is
15625 // likely to have a larger encoding, use more functional units, etc.
15626 if (N0 == N1)
15627 return SDValue();
15628
15629 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15630 return SDValue();
15631
15632 // Always prefer FMAD to FMA for precision.
15633 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15634 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15635
15636 auto isFusedOp = [&](SDValue N) {
15637 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
15638 };
15639
15640 // Is the node an FMUL and contractable either due to global flags or
15641 // SDNodeFlags.
15642 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15643 if (!matcher.match(N, ISD::FMUL))
15644 return false;
15645 return AllowFusionGlobally || N->getFlags().hasAllowContract();
15646 };
15647 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
15648 // prefer to fold the multiply with fewer uses.
15649 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
15650 if (N0->use_size() > N1->use_size())
15651 std::swap(a&: N0, b&: N1);
15652 }
15653
15654 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
15655 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
15656 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0),
15657 N0.getOperand(i: 1), N1);
15658 }
15659
15660 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
15661 // Note: Commutes FADD operands.
15662 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
15663 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(i: 0),
15664 N1.getOperand(i: 1), N0);
15665 }
15666
15667 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
15668 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
15669 // This also works with nested fma instructions:
15670 // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
15671 // fma A, B, (fma C, D, (fma E, F, G))
15672 // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
15673 // fma A, B, (fma C, D, (fma E, F, G)).
15674 // This requires reassociation because it changes the order of operations.
15675 if (CanReassociate) {
15676 SDValue FMA, E;
15677 if (isFusedOp(N0) && N0.hasOneUse()) {
15678 FMA = N0;
15679 E = N1;
15680 } else if (isFusedOp(N1) && N1.hasOneUse()) {
15681 FMA = N1;
15682 E = N0;
15683 }
15684
15685 SDValue TmpFMA = FMA;
15686 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
15687 SDValue FMul = TmpFMA->getOperand(Num: 2);
15688 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
15689 SDValue C = FMul.getOperand(i: 0);
15690 SDValue D = FMul.getOperand(i: 1);
15691 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
15692 DAG.ReplaceAllUsesOfValueWith(From: FMul, To: CDE);
15693 // Replacing the inner FMul could cause the outer FMA to be simplified
15694 // away.
15695 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
15696 }
15697
15698 TmpFMA = TmpFMA->getOperand(Num: 2);
15699 }
15700 }
15701
15702 // Look through FP_EXTEND nodes to do more combining.
15703
15704 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
15705 if (matcher.match(N0, ISD::FP_EXTEND)) {
15706 SDValue N00 = N0.getOperand(i: 0);
15707 if (isContractableFMUL(N00) &&
15708 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15709 SrcVT: N00.getValueType())) {
15710 return matcher.getNode(
15711 PreferredFusedOpcode, SL, VT,
15712 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
15713 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)), N1);
15714 }
15715 }
15716
15717 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
15718 // Note: Commutes FADD operands.
15719 if (matcher.match(N1, ISD::FP_EXTEND)) {
15720 SDValue N10 = N1.getOperand(i: 0);
15721 if (isContractableFMUL(N10) &&
15722 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15723 SrcVT: N10.getValueType())) {
15724 return matcher.getNode(
15725 PreferredFusedOpcode, SL, VT,
15726 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0)),
15727 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
15728 }
15729 }
15730
15731 // More folding opportunities when target permits.
15732 if (Aggressive) {
15733 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
15734 // -> (fma x, y, (fma (fpext u), (fpext v), z))
15735 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
15736 SDValue Z) {
15737 return matcher.getNode(
15738 PreferredFusedOpcode, SL, VT, X, Y,
15739 matcher.getNode(PreferredFusedOpcode, SL, VT,
15740 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
15741 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
15742 };
15743 if (isFusedOp(N0)) {
15744 SDValue N02 = N0.getOperand(i: 2);
15745 if (matcher.match(N02, ISD::FP_EXTEND)) {
15746 SDValue N020 = N02.getOperand(i: 0);
15747 if (isContractableFMUL(N020) &&
15748 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15749 SrcVT: N020.getValueType())) {
15750 return FoldFAddFMAFPExtFMul(N0.getOperand(i: 0), N0.getOperand(i: 1),
15751 N020.getOperand(i: 0), N020.getOperand(i: 1),
15752 N1);
15753 }
15754 }
15755 }
15756
15757 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
15758 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
15759 // FIXME: This turns two single-precision and one double-precision
15760 // operation into two double-precision operations, which might not be
15761 // interesting for all targets, especially GPUs.
15762 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
15763 SDValue Z) {
15764 return matcher.getNode(
15765 PreferredFusedOpcode, SL, VT,
15766 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
15767 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
15768 matcher.getNode(PreferredFusedOpcode, SL, VT,
15769 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
15770 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
15771 };
15772 if (N0.getOpcode() == ISD::FP_EXTEND) {
15773 SDValue N00 = N0.getOperand(i: 0);
15774 if (isFusedOp(N00)) {
15775 SDValue N002 = N00.getOperand(i: 2);
15776 if (isContractableFMUL(N002) &&
15777 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15778 SrcVT: N00.getValueType())) {
15779 return FoldFAddFPExtFMAFMul(N00.getOperand(i: 0), N00.getOperand(i: 1),
15780 N002.getOperand(i: 0), N002.getOperand(i: 1),
15781 N1);
15782 }
15783 }
15784 }
15785
15786 // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
15787 // -> (fma y, z, (fma (fpext u), (fpext v), x))
15788 if (isFusedOp(N1)) {
15789 SDValue N12 = N1.getOperand(i: 2);
15790 if (N12.getOpcode() == ISD::FP_EXTEND) {
15791 SDValue N120 = N12.getOperand(i: 0);
15792 if (isContractableFMUL(N120) &&
15793 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15794 SrcVT: N120.getValueType())) {
15795 return FoldFAddFMAFPExtFMul(N1.getOperand(i: 0), N1.getOperand(i: 1),
15796 N120.getOperand(i: 0), N120.getOperand(i: 1),
15797 N0);
15798 }
15799 }
15800 }
15801
15802 // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
15803 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
15804 // FIXME: This turns two single-precision and one double-precision
15805 // operation into two double-precision operations, which might not be
15806 // interesting for all targets, especially GPUs.
15807 if (N1.getOpcode() == ISD::FP_EXTEND) {
15808 SDValue N10 = N1.getOperand(i: 0);
15809 if (isFusedOp(N10)) {
15810 SDValue N102 = N10.getOperand(i: 2);
15811 if (isContractableFMUL(N102) &&
15812 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15813 SrcVT: N10.getValueType())) {
15814 return FoldFAddFPExtFMAFMul(N10.getOperand(i: 0), N10.getOperand(i: 1),
15815 N102.getOperand(i: 0), N102.getOperand(i: 1),
15816 N0);
15817 }
15818 }
15819 }
15820 }
15821
15822 return SDValue();
15823}
15824
15825/// Try to perform FMA combining on a given FSUB node.
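/// For example: (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) when contraction
/// is allowed and an FMA/FMAD form is available.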
15826template <class MatchContextClass>
15827SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
15828 SDValue N0 = N->getOperand(Num: 0);
15829 SDValue N1 = N->getOperand(Num: 1);
15830 EVT VT = N->getValueType(ResNo: 0);
15831 SDLoc SL(N);
15832 MatchContextClass matcher(DAG, TLI, N);
15833 const TargetOptions &Options = DAG.getTarget().Options;
15834
15835 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15836
15837 // Floating-point multiply-add with intermediate rounding.
15838 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15839 // FIXME: Add VP_FMAD opcode.
15840 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15841
15842 // Floating-point multiply-add without intermediate rounding.
15843 bool HasFMA =
15844 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
15845 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15846
15847 // No valid opcode, do not combine.
15848 if (!HasFMAD && !HasFMA)
15849 return SDValue();
15850
15851 const SDNodeFlags Flags = N->getFlags();
15852 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15853 Options.UnsafeFPMath || HasFMAD);
15854
15855 // If the subtraction is not contractable, do not combine.
15856 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15857 return SDValue();
15858
15859 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15860 return SDValue();
15861
15862 // Always prefer FMAD to FMA for precision.
15863 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15864 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15865 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
15866
15867 // Is the node an FMUL and contractable either due to global flags or
15868 // SDNodeFlags.
15869 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15870 if (!matcher.match(N, ISD::FMUL))
15871 return false;
15872 return AllowFusionGlobally || N->getFlags().hasAllowContract();
15873 };
15874
15875 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
15876 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
15877 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
15878 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(i: 0),
15879 XY.getOperand(i: 1),
15880 matcher.getNode(ISD::FNEG, SL, VT, Z));
15881 }
15882 return SDValue();
15883 };
15884
15885 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
15886 // Note: Commutes FSUB operands.
15887 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
15888 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
15889 return matcher.getNode(
15890 PreferredFusedOpcode, SL, VT,
15891 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(i: 0)),
15892 YZ.getOperand(i: 1), X);
15893 }
15894 return SDValue();
15895 };
15896
15897 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
15898 // prefer to fold the multiply with fewer uses.
15899 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
15900 (N0->use_size() > N1->use_size())) {
15901 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
15902 if (SDValue V = tryToFoldXSubYZ(N0, N1))
15903 return V;
15904 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
15905 if (SDValue V = tryToFoldXYSubZ(N0, N1))
15906 return V;
15907 } else {
15908 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
15909 if (SDValue V = tryToFoldXYSubZ(N0, N1))
15910 return V;
15911 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
15912 if (SDValue V = tryToFoldXSubYZ(N0, N1))
15913 return V;
15914 }
15915
15916 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
15917 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(i: 0)) &&
15918 (Aggressive || (N0->hasOneUse() && N0.getOperand(i: 0).hasOneUse()))) {
15919 SDValue N00 = N0.getOperand(i: 0).getOperand(i: 0);
15920 SDValue N01 = N0.getOperand(i: 0).getOperand(i: 1);
15921 return matcher.getNode(PreferredFusedOpcode, SL, VT,
15922 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
15923 matcher.getNode(ISD::FNEG, SL, VT, N1));
15924 }
15925
15926 // Look through FP_EXTEND nodes to do more combining.
15927
15928 // fold (fsub (fpext (fmul x, y)), z)
15929 // -> (fma (fpext x), (fpext y), (fneg z))
15930 if (matcher.match(N0, ISD::FP_EXTEND)) {
15931 SDValue N00 = N0.getOperand(i: 0);
15932 if (isContractableFMUL(N00) &&
15933 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15934 SrcVT: N00.getValueType())) {
15935 return matcher.getNode(
15936 PreferredFusedOpcode, SL, VT,
15937 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
15938 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
15939 matcher.getNode(ISD::FNEG, SL, VT, N1));
15940 }
15941 }
15942
15943 // fold (fsub x, (fpext (fmul y, z)))
15944 // -> (fma (fneg (fpext y)), (fpext z), x)
15945 // Note: Commutes FSUB operands.
15946 if (matcher.match(N1, ISD::FP_EXTEND)) {
15947 SDValue N10 = N1.getOperand(i: 0);
15948 if (isContractableFMUL(N10) &&
15949 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15950 SrcVT: N10.getValueType())) {
15951 return matcher.getNode(
15952 PreferredFusedOpcode, SL, VT,
15953 matcher.getNode(
15954 ISD::FNEG, SL, VT,
15955 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0))),
15956 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
15957 }
15958 }
15959
15960 // fold (fsub (fpext (fneg (fmul x, y))), z)
15961 // -> (fneg (fma (fpext x), (fpext y), z))
15962 // Note: This could be removed with appropriate canonicalization of the
15963 // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
15964 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
15965 // us from implementing the canonicalization in visitFSUB.
15966 if (matcher.match(N0, ISD::FP_EXTEND)) {
15967 SDValue N00 = N0.getOperand(i: 0);
15968 if (matcher.match(N00, ISD::FNEG)) {
15969 SDValue N000 = N00.getOperand(i: 0);
15970 if (isContractableFMUL(N000) &&
15971 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15972 SrcVT: N00.getValueType())) {
15973 return matcher.getNode(
15974 ISD::FNEG, SL, VT,
15975 matcher.getNode(
15976 PreferredFusedOpcode, SL, VT,
15977 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
15978 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
15979 N1));
15980 }
15981 }
15982 }
15983
15984 // fold (fsub (fneg (fpext (fmul x, y))), z)
15985 // -> (fneg (fma (fpext x), (fpext y), z))
15986 // Note: This could be removed with appropriate canonicalization of the
15987 // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
15988 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
15989 // us from implementing the canonicalization in visitFSUB.
15990 if (matcher.match(N0, ISD::FNEG)) {
15991 SDValue N00 = N0.getOperand(i: 0);
15992 if (matcher.match(N00, ISD::FP_EXTEND)) {
15993 SDValue N000 = N00.getOperand(i: 0);
15994 if (isContractableFMUL(N000) &&
15995 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15996 SrcVT: N000.getValueType())) {
15997 return matcher.getNode(
15998 ISD::FNEG, SL, VT,
15999 matcher.getNode(
16000 PreferredFusedOpcode, SL, VT,
16001 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
16002 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
16003 N1));
16004 }
16005 }
16006 }
16007
16008 auto isReassociable = [&Options](SDNode *N) {
16009 return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16010 };
16011
16012 auto isContractableAndReassociableFMUL = [&isContractableFMUL,
16013 &isReassociable](SDValue N) {
16014 return isContractableFMUL(N) && isReassociable(N.getNode());
16015 };
16016
16017 auto isFusedOp = [&](SDValue N) {
16018 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16019 };
16020
16021 // More folding opportunities when target permits.
16022 if (Aggressive && isReassociable(N)) {
16023 bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
16024 // fold (fsub (fma x, y, (fmul u, v)), z)
16025 // -> (fma x, y, (fma u, v, (fneg z)))
16026 if (CanFuse && isFusedOp(N0) &&
16027 isContractableAndReassociableFMUL(N0.getOperand(i: 2)) &&
16028 N0->hasOneUse() && N0.getOperand(i: 2)->hasOneUse()) {
16029 return matcher.getNode(
16030 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
16031 matcher.getNode(PreferredFusedOpcode, SL, VT,
16032 N0.getOperand(i: 2).getOperand(i: 0),
16033 N0.getOperand(i: 2).getOperand(i: 1),
16034 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16035 }
16036
16037 // fold (fsub x, (fma y, z, (fmul u, v)))
16038 // -> (fma (fneg y), z, (fma (fneg u), v, x))
16039 if (CanFuse && isFusedOp(N1) &&
16040 isContractableAndReassociableFMUL(N1.getOperand(i: 2)) &&
16041 N1->hasOneUse() && NoSignedZero) {
16042 SDValue N20 = N1.getOperand(i: 2).getOperand(i: 0);
16043 SDValue N21 = N1.getOperand(i: 2).getOperand(i: 1);
16044 return matcher.getNode(
16045 PreferredFusedOpcode, SL, VT,
16046 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
16047 N1.getOperand(i: 1),
16048 matcher.getNode(PreferredFusedOpcode, SL, VT,
16049 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
16050 }
16051
16052 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
16053 // -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
16054 if (isFusedOp(N0) && N0->hasOneUse()) {
16055 SDValue N02 = N0.getOperand(i: 2);
16056 if (matcher.match(N02, ISD::FP_EXTEND)) {
16057 SDValue N020 = N02.getOperand(i: 0);
16058 if (isContractableAndReassociableFMUL(N020) &&
16059 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16060 SrcVT: N020.getValueType())) {
16061 return matcher.getNode(
16062 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
16063 matcher.getNode(
16064 PreferredFusedOpcode, SL, VT,
16065 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 0)),
16066 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 1)),
16067 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16068 }
16069 }
16070 }
16071
16072 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
16073 // -> (fma (fpext x), (fpext y),
16074 // (fma (fpext u), (fpext v), (fneg z)))
16075 // FIXME: This turns two single-precision and one double-precision
16076 // operation into two double-precision operations, which might not be
16077 // interesting for all targets, especially GPUs.
16078 if (matcher.match(N0, ISD::FP_EXTEND)) {
16079 SDValue N00 = N0.getOperand(i: 0);
16080 if (isFusedOp(N00)) {
16081 SDValue N002 = N00.getOperand(i: 2);
16082 if (isContractableAndReassociableFMUL(N002) &&
16083 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16084 SrcVT: N00.getValueType())) {
16085 return matcher.getNode(
16086 PreferredFusedOpcode, SL, VT,
16087 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
16088 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
16089 matcher.getNode(
16090 PreferredFusedOpcode, SL, VT,
16091 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 0)),
16092 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 1)),
16093 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16094 }
16095 }
16096 }
16097
16098 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
16099 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
16100 if (isFusedOp(N1) && matcher.match(N1.getOperand(i: 2), ISD::FP_EXTEND) &&
16101 N1->hasOneUse()) {
16102 SDValue N120 = N1.getOperand(i: 2).getOperand(i: 0);
16103 if (isContractableAndReassociableFMUL(N120) &&
16104 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16105 SrcVT: N120.getValueType())) {
16106 SDValue N1200 = N120.getOperand(i: 0);
16107 SDValue N1201 = N120.getOperand(i: 1);
16108 return matcher.getNode(
16109 PreferredFusedOpcode, SL, VT,
16110 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
16111 N1.getOperand(i: 1),
16112 matcher.getNode(
16113 PreferredFusedOpcode, SL, VT,
16114 matcher.getNode(ISD::FNEG, SL, VT,
16115 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
16116 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
16117 }
16118 }
16119
16120 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
16121 // -> (fma (fneg (fpext y)), (fpext z),
16122 // (fma (fneg (fpext u)), (fpext v), x))
16123 // FIXME: This turns two single-precision and one double-precision
16124 // operation into two double-precision operations, which might not be
16125 // interesting for all targets, especially GPUs.
16126 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(i: 0))) {
16127 SDValue CvtSrc = N1.getOperand(i: 0);
16128 SDValue N100 = CvtSrc.getOperand(i: 0);
16129 SDValue N101 = CvtSrc.getOperand(i: 1);
16130 SDValue N102 = CvtSrc.getOperand(i: 2);
16131 if (isContractableAndReassociableFMUL(N102) &&
16132 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16133 SrcVT: CvtSrc.getValueType())) {
16134 SDValue N1020 = N102.getOperand(i: 0);
16135 SDValue N1021 = N102.getOperand(i: 1);
16136 return matcher.getNode(
16137 PreferredFusedOpcode, SL, VT,
16138 matcher.getNode(ISD::FNEG, SL, VT,
16139 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
16140 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
16141 matcher.getNode(
16142 PreferredFusedOpcode, SL, VT,
16143 matcher.getNode(ISD::FNEG, SL, VT,
16144 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
16145 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
16146 }
16147 }
16148 }
16149
16150 return SDValue();
16151}
16152
16153/// Try to perform FMA combining on a given FMUL node based on the distributive
16154/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
16155/// subtraction instead of addition).
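/// For example: (fmul (fadd x, 1.0), y) -> (fma x, y, y).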
16156SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
16157 SDValue N0 = N->getOperand(Num: 0);
16158 SDValue N1 = N->getOperand(Num: 1);
16159 EVT VT = N->getValueType(ResNo: 0);
16160 SDLoc SL(N);
16161
16162 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
16163
16164 const TargetOptions &Options = DAG.getTarget().Options;
16165
16166 // The transforms below are incorrect when x == 0 and y == inf, because the
16167 // intermediate multiplication produces a nan.
16168 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
16169 if (!hasNoInfs(Options, N: FAdd))
16170 return SDValue();
16171
16172 // Floating-point multiply-add without intermediate rounding.
16173 bool HasFMA =
16174 isContractableFMUL(Options, N: SDValue(N, 0)) &&
16175 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
16176 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FMA, VT));
16177
16178 // Floating-point multiply-add with intermediate rounding. This can result
16179 // in a less precise result due to the changed rounding order.
16180 bool HasFMAD = Options.UnsafeFPMath &&
16181 (LegalOperations && TLI.isFMADLegal(DAG, N));
16182
16183 // No valid opcode, do not combine.
16184 if (!HasFMAD && !HasFMA)
16185 return SDValue();
16186
16187 // Always prefer FMAD to FMA for precision.
16188 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16189 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16190
16191 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
16192 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
16193 auto FuseFADD = [&](SDValue X, SDValue Y) {
16194 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
16195 if (auto *C = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
16196 if (C->isExactlyValue(V: +1.0))
16197 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16198 N3: Y);
16199 if (C->isExactlyValue(V: -1.0))
16200 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16201 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16202 }
16203 }
16204 return SDValue();
16205 };
16206
16207 if (SDValue FMA = FuseFADD(N0, N1))
16208 return FMA;
16209 if (SDValue FMA = FuseFADD(N1, N0))
16210 return FMA;
16211
16212 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
16213 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
16214 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
16215 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
16216 auto FuseFSUB = [&](SDValue X, SDValue Y) {
16217 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
16218 if (auto *C0 = isConstOrConstSplatFP(N: X.getOperand(i: 0), AllowUndefs: true)) {
16219 if (C0->isExactlyValue(V: +1.0))
16220 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
16221 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
16222 N3: Y);
16223 if (C0->isExactlyValue(V: -1.0))
16224 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
16225 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
16226 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16227 }
16228 if (auto *C1 = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
16229 if (C1->isExactlyValue(V: +1.0))
16230 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16231 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16232 if (C1->isExactlyValue(V: -1.0))
16233 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16234 N3: Y);
16235 }
16236 }
16237 return SDValue();
16238 };
16239
16240 if (SDValue FMA = FuseFSUB(N0, N1))
16241 return FMA;
16242 if (SDValue FMA = FuseFSUB(N1, N0))
16243 return FMA;
16244
16245 return SDValue();
16246}
16247
16248SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
16249 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16250
16251 // FADD -> FMA combines:
16252 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
16253 if (Fused.getOpcode() != ISD::DELETED_NODE)
16254 AddToWorklist(N: Fused.getNode());
16255 return Fused;
16256 }
16257 return SDValue();
16258}
16259
16260SDValue DAGCombiner::visitFADD(SDNode *N) {
16261 SDValue N0 = N->getOperand(Num: 0);
16262 SDValue N1 = N->getOperand(Num: 1);
16263 SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N0);
16264 SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N1);
16265 EVT VT = N->getValueType(ResNo: 0);
16266 SDLoc DL(N);
16267 const TargetOptions &Options = DAG.getTarget().Options;
16268 SDNodeFlags Flags = N->getFlags();
16269 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16270
16271 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16272 return R;
16273
16274 // fold (fadd c1, c2) -> c1 + c2
16275 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FADD, DL, VT, Ops: {N0, N1}))
16276 return C;
16277
16278 // canonicalize constant to RHS
16279 if (N0CFP && !N1CFP)
16280 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1, N2: N0);
16281
16282 // fold vector ops
16283 if (VT.isVector())
16284 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16285 return FoldedVOp;
16286
16287 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
16288 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16289 if (N1C && N1C->isZero())
16290 if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
16291 return N0;
16292
16293 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16294 return NewSel;
16295
16296 // fold (fadd A, (fneg B)) -> (fsub A, B)
16297 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
16298 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16299 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16300 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: NegN1);
16301
16302 // fold (fadd (fneg A), B) -> (fsub B, A)
16303 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
16304 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16305 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16306 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: NegN0);
16307
16308 auto isFMulNegTwo = [](SDValue FMul) {
16309 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
16310 return false;
16311 auto *C = isConstOrConstSplatFP(N: FMul.getOperand(i: 1), AllowUndefs: true);
16312 return C && C->isExactlyValue(V: -2.0);
16313 };
16314
16315 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
16316 if (isFMulNegTwo(N0)) {
16317 SDValue B = N0.getOperand(i: 0);
16318 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
16319 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: Add);
16320 }
16321 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
16322 if (isFMulNegTwo(N1)) {
16323 SDValue B = N1.getOperand(i: 0);
16324 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
16325 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Add);
16326 }
16327
  // No FP constant should be created after legalization as the Instruction
  // Selection pass has a hard time dealing with FP constants.
16330 bool AllowNewConst = (Level < AfterLegalizeDAG);
16331
16332 // If nnan is enabled, fold lots of things.
16333 if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
16334 // If allowed, fold (fadd (fneg x), x) -> 0.0
16335 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(i: 0) == N1)
16336 return DAG.getConstantFP(Val: 0.0, DL, VT);
16337
16338 // If allowed, fold (fadd x, (fneg x)) -> 0.0
16339 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(i: 0) == N0)
16340 return DAG.getConstantFP(Val: 0.0, DL, VT);
16341 }
16342
16343 // If 'unsafe math' or reassoc and nsz, fold lots of things.
16344 // TODO: break out portions of the transformations below for which Unsafe is
16345 // considered and which do not require both nsz and reassoc
16346 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16347 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16348 AllowNewConst) {
16349 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
16350 if (N1CFP && N0.getOpcode() == ISD::FADD &&
16351 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
16352 SDValue NewC = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
16353 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
16354 }
16355
    // We can fold chains of FADDs of the same value into multiplications.
16357 // This transform is not safe in general because we are reducing the number
16358 // of rounding steps.
16359 if (TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) && !N0CFP && !N1CFP) {
16360 if (N0.getOpcode() == ISD::FMUL) {
16361 SDNode *CFP00 =
16362 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
16363 SDNode *CFP01 =
16364 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1));
16365
16366 // (fadd (fmul x, c), x) -> (fmul x, c+1)
16367 if (CFP01 && !CFP00 && N0.getOperand(i: 0) == N1) {
16368 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
16369 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
16370 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: NewCFP);
16371 }
16372
16373 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
16374 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
16375 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16376 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
16377 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
16378 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
16379 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewCFP);
16380 }
16381 }
16382
16383 if (N1.getOpcode() == ISD::FMUL) {
16384 SDNode *CFP10 =
16385 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
16386 SDNode *CFP11 =
16387 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 1));
16388
16389 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
16390 if (CFP11 && !CFP10 && N1.getOperand(i: 0) == N0) {
16391 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
16392 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
16393 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: NewCFP);
16394 }
16395
16396 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
16397 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
16398 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16399 N1.getOperand(i: 0) == N0.getOperand(i: 0)) {
16400 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
16401 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
16402 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N1.getOperand(i: 0), N2: NewCFP);
16403 }
16404 }
16405
16406 if (N0.getOpcode() == ISD::FADD) {
16407 SDNode *CFP00 =
16408 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
16409 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
16410 if (!CFP00 && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16411 (N0.getOperand(i: 0) == N1)) {
16412 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1,
16413 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
16414 }
16415 }
16416
16417 if (N1.getOpcode() == ISD::FADD) {
16418 SDNode *CFP10 =
16419 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
16420 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
16421 if (!CFP10 && N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16422 N1.getOperand(i: 0) == N0) {
16423 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
16424 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
16425 }
16426 }
16427
16428 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
16429 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
16430 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16431 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16432 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
16433 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0),
16434 N2: DAG.getConstantFP(Val: 4.0, DL, VT));
16435 }
16436 }
16437
16438 // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
16439 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FADD, Opc: ISD::FADD, DL,
16440 VT, N0, N1, Flags))
16441 return SD;
16442 } // enable-unsafe-fp-math
16443
16444 // FADD -> FMA combines:
16445 if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
16446 if (Fused.getOpcode() != ISD::DELETED_NODE)
16447 AddToWorklist(N: Fused.getNode());
16448 return Fused;
16449 }
16450 return SDValue();
16451}
16452
16453SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
16454 SDValue Chain = N->getOperand(Num: 0);
16455 SDValue N0 = N->getOperand(Num: 1);
16456 SDValue N1 = N->getOperand(Num: 2);
16457 EVT VT = N->getValueType(ResNo: 0);
16458 EVT ChainVT = N->getValueType(ResNo: 1);
16459 SDLoc DL(N);
16460 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16461
16462 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
16463 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
16464 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16465 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
16466 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
16467 Ops: {Chain, N0, NegN1});
16468 }
16469
16470 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
16471 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
16472 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16473 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
16474 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
16475 Ops: {Chain, N1, NegN0});
16476 }
16477 return SDValue();
16478}
16479
16480SDValue DAGCombiner::visitFSUB(SDNode *N) {
16481 SDValue N0 = N->getOperand(Num: 0);
16482 SDValue N1 = N->getOperand(Num: 1);
16483 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, AllowUndefs: true);
16484 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16485 EVT VT = N->getValueType(ResNo: 0);
16486 SDLoc DL(N);
16487 const TargetOptions &Options = DAG.getTarget().Options;
16488 const SDNodeFlags Flags = N->getFlags();
16489 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16490
16491 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16492 return R;
16493
16494 // fold (fsub c1, c2) -> c1-c2
16495 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FSUB, DL, VT, Ops: {N0, N1}))
16496 return C;
16497
16498 // fold vector ops
16499 if (VT.isVector())
16500 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16501 return FoldedVOp;
16502
16503 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16504 return NewSel;
16505
16506 // (fsub A, 0) -> A
16507 if (N1CFP && N1CFP->isZero()) {
16508 if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
16509 Flags.hasNoSignedZeros()) {
16510 return N0;
16511 }
16512 }
16513
16514 if (N0 == N1) {
16515 // (fsub x, x) -> 0.0
16516 if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
16517 return DAG.getConstantFP(Val: 0.0f, DL, VT);
16518 }
16519
16520 // (fsub -0.0, N1) -> -N1
16521 if (N0CFP && N0CFP->isZero()) {
16522 if (N0CFP->isNegative() ||
16523 (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
16524 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
16525 // flushed to zero, unless all users treat denorms as zero (DAZ).
16526 // FIXME: This transform will change the sign of a NaN and the behavior
16527 // of a signaling NaN. It is only valid when a NoNaN flag is present.
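      // For example, if outputs are flushed to zero and X is a small positive
      // denormal, FSUB(-0.0, X) produces a zero while FNEG(X) merely flips
      // the sign bit and keeps the denormal value.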
16528 DenormalMode DenormMode = DAG.getDenormalMode(VT);
16529 if (DenormMode == DenormalMode::getIEEE()) {
16530 if (SDValue NegN1 =
16531 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16532 return NegN1;
16533 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
16534 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1);
16535 }
16536 }
16537 }
16538
16539 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16540 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16541 N1.getOpcode() == ISD::FADD) {
16542 // X - (X + Y) -> -Y
16543 if (N0 == N1->getOperand(Num: 0))
16544 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 1));
16545 // X - (Y + X) -> -Y
16546 if (N0 == N1->getOperand(Num: 1))
16547 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 0));
16548 }
16549
16550 // fold (fsub A, (fneg B)) -> (fadd A, B)
16551 if (SDValue NegN1 =
16552 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16553 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: NegN1);
16554
16555 // FSUB -> FMA combines:
16556 if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
16557 AddToWorklist(N: Fused.getNode());
16558 return Fused;
16559 }
16560
16561 return SDValue();
16562}
16563
16564// Transform IEEE Floats:
16565// (fmul C, (uitofp Pow2))
16566// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
16567// (fdiv C, (uitofp Pow2))
16568// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
16569//
// The rationale is that fmul/fdiv by a power of 2 just changes the exponent,
// so there is no need for more than an add/sub.
16572//
16573// This is valid under the following circumstances:
16574// 1) We are dealing with IEEE floats
16575// 2) C is normal
16576// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
// TODO: Much of this could also be used for generating `ldexp` on targets
// that prefer it.
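// For illustration (IEEE binary32, 23-bit stored mantissa): C = 3.0f is
// 0x40400000; multiplying by Pow2 = 4 (Log2 = 2) adds 2 << 23 = 0x01000000
// to the integer view, giving 0x41400000 == 12.0f. The FDIV case subtracts
// the shifted Log2 instead.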
16579SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
16580 EVT VT = N->getValueType(ResNo: 0);
16581 SDValue ConstOp, Pow2Op;
16582
16583 std::optional<int> Mantissa;
16584 auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
16585 if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
16586 return false;
16587
16588 ConstOp = peekThroughBitcasts(V: N->getOperand(Num: ConstOpIdx));
16589 Pow2Op = N->getOperand(Num: 1 - ConstOpIdx);
16590 if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
16591 (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
16592 !DAG.computeKnownBits(Op: Pow2Op).isNonNegative()))
16593 return false;
16594
16595 Pow2Op = Pow2Op.getOperand(i: 0);
16596
16597 // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
16598 // TODO: We could use knownbits to make this bound more precise.
16599 int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
16600
16601 auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
16602 if (CFP == nullptr)
16603 return false;
16604
16605 const APFloat &APF = CFP->getValueAPF();
16606
      // Make sure we have a normal/IEEE constant.
16608 if (!APF.isNormal() || !APF.isIEEE())
16609 return false;
16610
      // Make sure the float's exponent is within the bounds for which this
      // transform produces a bitwise-equal value.
16613 int CurExp = ilogb(Arg: APF);
16614 // FMul by pow2 will only increase exponent.
16615 int MinExp =
16616 N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16617 // FDiv by pow2 will only decrease exponent.
16618 int MaxExp =
16619 N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16620 if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16621 MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16622 return false;
16623
16624 // Finally make sure we actually know the mantissa for the float type.
16625 int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16626 if (!Mantissa)
16627 Mantissa = ThisMantissa;
16628
16629 return *Mantissa == ThisMantissa && ThisMantissa > 0;
16630 };
16631
16632 // TODO: We may be able to include undefs.
16633 return ISD::matchUnaryFpPredicate(Op: ConstOp, Match: IsFPConstValid);
16634 };
16635
16636 if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
16637 return SDValue();
16638
16639 if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, FPConst: ConstOp, IntPow2: Pow2Op))
16640 return SDValue();
16641
16642 // Get log2 after all other checks have taken place. This is because
16643 // BuildLogBase2 may create a new node.
16644 SDLoc DL(N);
16645 // Get Log2 type with same bitwidth as the float type (VT).
16646 EVT NewIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VT.getScalarSizeInBits());
16647 if (VT.isVector())
16648 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewIntVT,
16649 EC: VT.getVectorElementCount());
16650
16651 SDValue Log2 = BuildLogBase2(V: Pow2Op, DL, KnownNeverZero: DAG.isKnownNeverZero(Op: Pow2Op),
16652 /*InexpensiveOnly*/ true, OutVT: NewIntVT);
16653 if (!Log2)
16654 return SDValue();
16655
16656 // Perform actual transform.
16657 SDValue MantissaShiftCnt =
16658 DAG.getConstant(Val: *Mantissa, DL, VT: getShiftAmountTy(LHSTy: NewIntVT));
  // TODO: Sometimes Log2 is of the form `(X + C)`. `(X + C) << C1` should fold
  // to `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could handle that here by looking through the casts.
16662 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT: NewIntVT, N1: Log2, N2: MantissaShiftCnt);
16663 SDValue ResAsInt =
16664 DAG.getNode(Opcode: N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
16665 VT: NewIntVT, N1: DAG.getBitcast(VT: NewIntVT, V: ConstOp), N2: Shift);
16666 SDValue ResAsFP = DAG.getBitcast(VT, V: ResAsInt);
16667 return ResAsFP;
16668}
16669
16670SDValue DAGCombiner::visitFMUL(SDNode *N) {
16671 SDValue N0 = N->getOperand(Num: 0);
16672 SDValue N1 = N->getOperand(Num: 1);
16673 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16674 EVT VT = N->getValueType(ResNo: 0);
16675 SDLoc DL(N);
16676 const TargetOptions &Options = DAG.getTarget().Options;
16677 const SDNodeFlags Flags = N->getFlags();
16678 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16679
16680 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16681 return R;
16682
16683 // fold (fmul c1, c2) -> c1*c2
16684 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMUL, DL, VT, Ops: {N0, N1}))
16685 return C;
16686
16687 // canonicalize constant to RHS
16688 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
16689 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
16690 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: N0);
16691
16692 // fold vector ops
16693 if (VT.isVector())
16694 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16695 return FoldedVOp;
16696
16697 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16698 return NewSel;
16699
16700 if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
16701 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
16702 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16703 N0.getOpcode() == ISD::FMUL) {
16704 SDValue N00 = N0.getOperand(i: 0);
16705 SDValue N01 = N0.getOperand(i: 1);
16706 // Avoid an infinite loop by making sure that N00 is not a constant
16707 // (the inner multiply has not been constant folded yet).
16708 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N01) &&
16709 !DAG.isConstantFPBuildVectorOrConstantFP(N: N00)) {
16710 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N01, N2: N1);
16711 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N00, N2: MulConsts);
16712 }
16713 }
16714
16715 // Match a special-case: we convert X * 2.0 into fadd.
16716 // fmul (fadd X, X), C -> fmul X, 2.0 * C
16717 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
16718 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
16719 const SDValue Two = DAG.getConstantFP(Val: 2.0, DL, VT);
16720 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Two, N2: N1);
16721 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: MulConsts);
16722 }
16723
16724 // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
16725 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FMUL, Opc: ISD::FMUL, DL,
16726 VT, N0, N1, Flags))
16727 return SD;
16728 }
16729
16730 // fold (fmul X, 2.0) -> (fadd X, X)
16731 if (N1CFP && N1CFP->isExactlyValue(V: +2.0))
16732 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: N0);
16733
16734 // fold (fmul X, -1.0) -> (fsub -0.0, X)
16735 if (N1CFP && N1CFP->isExactlyValue(V: -1.0)) {
16736 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FSUB, VT)) {
16737 return DAG.getNode(Opcode: ISD::FSUB, DL, VT,
16738 N1: DAG.getConstantFP(Val: -0.0, DL, VT), N2: N0, Flags);
16739 }
16740 }
16741
16742 // -N0 * -N1 --> N0 * N1
16743 TargetLowering::NegatibleCost CostN0 =
16744 TargetLowering::NegatibleCost::Expensive;
16745 TargetLowering::NegatibleCost CostN1 =
16746 TargetLowering::NegatibleCost::Expensive;
16747 SDValue NegN0 =
16748 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
16749 if (NegN0) {
16750 HandleSDNode NegN0Handle(NegN0);
16751 SDValue NegN1 =
16752 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
16753 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
16754 CostN1 == TargetLowering::NegatibleCost::Cheaper))
16755 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: NegN0, N2: NegN1);
16756 }
16757
16758 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
16759 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
16760 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
16761 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
16762 TLI.isOperationLegal(Op: ISD::FABS, VT)) {
16763 SDValue Select = N0, X = N1;
16764 if (Select.getOpcode() != ISD::SELECT)
16765 std::swap(a&: Select, b&: X);
16766
16767 SDValue Cond = Select.getOperand(i: 0);
16768 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 1));
16769 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 2));
16770
16771 if (TrueOpnd && FalseOpnd &&
16772 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(i: 0) == X &&
16773 isa<ConstantFPSDNode>(Val: Cond.getOperand(i: 1)) &&
16774 cast<ConstantFPSDNode>(Val: Cond.getOperand(i: 1))->isExactlyValue(V: 0.0)) {
16775 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
16776 switch (CC) {
16777 default: break;
16778 case ISD::SETOLT:
16779 case ISD::SETULT:
16780 case ISD::SETOLE:
16781 case ISD::SETULE:
16782 case ISD::SETLT:
16783 case ISD::SETLE:
16784 std::swap(a&: TrueOpnd, b&: FalseOpnd);
16785 [[fallthrough]];
16786 case ISD::SETOGT:
16787 case ISD::SETUGT:
16788 case ISD::SETOGE:
16789 case ISD::SETUGE:
16790 case ISD::SETGT:
16791 case ISD::SETGE:
16792 if (TrueOpnd->isExactlyValue(V: -1.0) && FalseOpnd->isExactlyValue(V: 1.0) &&
16793 TLI.isOperationLegal(Op: ISD::FNEG, VT))
16794 return DAG.getNode(Opcode: ISD::FNEG, DL, VT,
16795 Operand: DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X));
16796 if (TrueOpnd->isExactlyValue(V: 1.0) && FalseOpnd->isExactlyValue(V: -1.0))
16797 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X);
16798
16799 break;
16800 }
16801 }
16802 }
16803
16804 // FMUL -> FMA combines:
16805 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
16806 AddToWorklist(N: Fused.getNode());
16807 return Fused;
16808 }
16809
16810 // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
16811 // able to run.
16812 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
16813 return R;
16814
16815 return SDValue();
16816}
16817
16818template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
16819 SDValue N0 = N->getOperand(Num: 0);
16820 SDValue N1 = N->getOperand(Num: 1);
16821 SDValue N2 = N->getOperand(Num: 2);
16822 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Val&: N0);
16823 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1);
16824 EVT VT = N->getValueType(ResNo: 0);
16825 SDLoc DL(N);
16826 const TargetOptions &Options = DAG.getTarget().Options;
16827 // FMA nodes have flags that propagate to the created nodes.
16828 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16829 MatchContextClass matcher(DAG, TLI, N);
16830
16831 bool CanReassociate =
16832 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16833
16834 // Constant fold FMA.
16835 if (isa<ConstantFPSDNode>(Val: N0) &&
16836 isa<ConstantFPSDNode>(Val: N1) &&
16837 isa<ConstantFPSDNode>(Val: N2)) {
16838 return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
16839 }
16840
16841 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
16842 TargetLowering::NegatibleCost CostN0 =
16843 TargetLowering::NegatibleCost::Expensive;
16844 TargetLowering::NegatibleCost CostN1 =
16845 TargetLowering::NegatibleCost::Expensive;
16846 SDValue NegN0 =
16847 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
16848 if (NegN0) {
16849 HandleSDNode NegN0Handle(NegN0);
16850 SDValue NegN1 =
16851 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
16852 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
16853 CostN1 == TargetLowering::NegatibleCost::Cheaper))
16854 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
16855 }
16856
16857 // FIXME: use fast math flags instead of Options.UnsafeFPMath
16858 if (Options.UnsafeFPMath) {
16859 if (N0CFP && N0CFP->isZero())
16860 return N2;
16861 if (N1CFP && N1CFP->isZero())
16862 return N2;
16863 }
16864
16865 // FIXME: Support splat of constant.
16866 if (N0CFP && N0CFP->isExactlyValue(V: 1.0))
16867 return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
16868 if (N1CFP && N1CFP->isExactlyValue(V: 1.0))
16869 return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
16870
16871 // Canonicalize (fma c, x, y) -> (fma x, c, y)
16872 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
16873 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
16874 return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
16875
16876 if (CanReassociate) {
16877 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
16878 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(i: 0) &&
16879 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16880 DAG.isConstantFPBuildVectorOrConstantFP(N: N2.getOperand(i: 1))) {
16881 return matcher.getNode(
16882 ISD::FMUL, DL, VT, N0,
16883 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(i: 1)));
16884 }
16885
16886 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
16887 if (matcher.match(N0, ISD::FMUL) &&
16888 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16889 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
16890 return matcher.getNode(
16891 ISD::FMA, DL, VT, N0.getOperand(i: 0),
16892 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(i: 1)), N2);
16893 }
16894 }
16895
  // (fma x, 1, y) -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
16897 // FIXME: Support splat of constant.
16898 if (N1CFP) {
16899 if (N1CFP->isExactlyValue(V: 1.0))
16900 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
16901
16902 if (N1CFP->isExactlyValue(V: -1.0) &&
16903 (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))) {
16904 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
16905 AddToWorklist(N: RHSNeg.getNode());
16906 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
16907 }
16908
    // fma (fneg x), K, y -> fma x, -K, y
16910 if (matcher.match(N0, ISD::FNEG) &&
16911 (TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
16912 (N1.hasOneUse() &&
16913 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
16914 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(i: 0),
16915 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
16916 }
16917 }
16918
16919 // FIXME: Support splat of constant.
16920 if (CanReassociate) {
16921 // (fma x, c, x) -> (fmul x, (c+1))
16922 if (N1CFP && N0 == N2) {
16923 return matcher.getNode(ISD::FMUL, DL, VT, N0,
16924 matcher.getNode(ISD::FADD, DL, VT, N1,
16925 DAG.getConstantFP(Val: 1.0, DL, VT)));
16926 }
16927
16928 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
16929 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(i: 0) == N0) {
16930 return matcher.getNode(ISD::FMUL, DL, VT, N0,
16931 matcher.getNode(ISD::FADD, DL, VT, N1,
16932 DAG.getConstantFP(Val: -1.0, DL, VT)));
16933 }
16934 }
16935
16936 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
16937 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
16938 if (!TLI.isFNegFree(VT))
16939 if (SDValue Neg = TLI.getCheaperNegatedExpression(
16940 Op: SDValue(N, 0), DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16941 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
16942 return SDValue();
16943}
16944
16945SDValue DAGCombiner::visitFMAD(SDNode *N) {
16946 SDValue N0 = N->getOperand(Num: 0);
16947 SDValue N1 = N->getOperand(Num: 1);
16948 SDValue N2 = N->getOperand(Num: 2);
16949 EVT VT = N->getValueType(ResNo: 0);
16950 SDLoc DL(N);
16951
16952 // Constant fold FMAD.
16953 if (isa<ConstantFPSDNode>(Val: N0) && isa<ConstantFPSDNode>(Val: N1) &&
16954 isa<ConstantFPSDNode>(Val: N2))
16955 return DAG.getNode(Opcode: ISD::FMAD, DL, VT, N1: N0, N2: N1, N3: N2);
16956
16957 return SDValue();
16958}
16959
16960// Combine multiple FDIVs with the same divisor into multiple FMULs by the
16961// reciprocal.
16962// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
// Notice that this is not always beneficial. One reason is that different
// targets may have different costs for FDIV and FMUL, so sometimes the cost
// of two FDIVs may be lower than the cost of one FDIV and two FMULs. Another
// reason is that the critical path is increased from "one FDIV" to "one FDIV
// + one FMUL".
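// As a sketch, with a target MinUses threshold of 2:
//   x = a / d; y = b / d;
// becomes
//   r = 1.0 / d; x = a * r; y = b * r;
// i.e. two divisions become one division plus two multiplications. For a
// splat vector divisor, each use is additionally scaled by the element count
// when checking the threshold.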
16967SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
16968 // TODO: Limit this transform based on optsize/minsize - it always creates at
16969 // least 1 extra instruction. But the perf win may be substantial enough
16970 // that only minsize should restrict this.
16971 bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
16972 const SDNodeFlags Flags = N->getFlags();
16973 if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
16974 return SDValue();
16975
16976 // Skip if current node is a reciprocal/fneg-reciprocal.
16977 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
16978 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, /* AllowUndefs */ true);
16979 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
16980 return SDValue();
16981
16982 // Exit early if the target does not want this transform or if there can't
16983 // possibly be enough uses of the divisor to make the transform worthwhile.
16984 unsigned MinUses = TLI.combineRepeatedFPDivisors();
16985
16986 // For splat vectors, scale the number of uses by the splat factor. If we can
16987 // convert the division into a scalar op, that will likely be much faster.
16988 unsigned NumElts = 1;
16989 EVT VT = N->getValueType(ResNo: 0);
16990 if (VT.isVector() && DAG.isSplatValue(V: N1))
16991 NumElts = VT.getVectorMinNumElements();
16992
16993 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
16994 return SDValue();
16995
16996 // Find all FDIV users of the same divisor.
16997 // Use a set because duplicates may be present in the user list.
16998 SetVector<SDNode *> Users;
16999 for (auto *U : N1->uses()) {
17000 if (U->getOpcode() == ISD::FDIV && U->getOperand(Num: 1) == N1) {
17001 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
17002 if (U->getOperand(Num: 1).getOpcode() == ISD::FSQRT &&
17003 U->getOperand(Num: 0) == U->getOperand(Num: 1).getOperand(i: 0) &&
17004 U->getFlags().hasAllowReassociation() &&
17005 U->getFlags().hasNoSignedZeros())
17006 continue;
17007
17008 // This division is eligible for optimization only if global unsafe math
17009 // is enabled or if this division allows reciprocal formation.
17010 if (UnsafeMath || U->getFlags().hasAllowReciprocal())
17011 Users.insert(X: U);
17012 }
17013 }
17014
17015 // Now that we have the actual number of divisor uses, make sure it meets
17016 // the minimum threshold specified by the target.
17017 if ((Users.size() * NumElts) < MinUses)
17018 return SDValue();
17019
17020 SDLoc DL(N);
17021 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
17022 SDValue Reciprocal = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: FPOne, N2: N1, Flags);
17023
17024 // Dividend / Divisor -> Dividend * Reciprocal
17025 for (auto *U : Users) {
17026 SDValue Dividend = U->getOperand(Num: 0);
17027 if (Dividend != FPOne) {
17028 SDValue NewNode = DAG.getNode(Opcode: ISD::FMUL, DL: SDLoc(U), VT, N1: Dividend,
17029 N2: Reciprocal, Flags);
17030 CombineTo(N: U, Res: NewNode);
17031 } else if (U != Reciprocal.getNode()) {
17032 // In the absence of fast-math-flags, this user node is always the
17033 // same node as Reciprocal, but with FMF they may be different nodes.
17034 CombineTo(N: U, Res: Reciprocal);
17035 }
17036 }
17037 return SDValue(N, 0); // N was replaced.
17038}
17039
17040SDValue DAGCombiner::visitFDIV(SDNode *N) {
17041 SDValue N0 = N->getOperand(Num: 0);
17042 SDValue N1 = N->getOperand(Num: 1);
17043 EVT VT = N->getValueType(ResNo: 0);
17044 SDLoc DL(N);
17045 const TargetOptions &Options = DAG.getTarget().Options;
17046 SDNodeFlags Flags = N->getFlags();
17047 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17048
17049 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17050 return R;
17051
17052 // fold (fdiv c1, c2) -> c1/c2
17053 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FDIV, DL, VT, Ops: {N0, N1}))
17054 return C;
17055
17056 // fold vector ops
17057 if (VT.isVector())
17058 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17059 return FoldedVOp;
17060
17061 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17062 return NewSel;
17063
17064 if (SDValue V = combineRepeatedFPDivisors(N))
17065 return V;
17066
17067 if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
17068 // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
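    // For example, 1.0/3.0 is not exactly representable, so X/3.0 and
    // X*(1.0/3.0) may differ in the last bit; that is why this path is only
    // taken under the reciprocal/unsafe flags checked above.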
17069 if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1)) {
17070 // Compute the reciprocal 1.0 / c2.
17071 const APFloat &N1APF = N1CFP->getValueAPF();
17072 APFloat Recip(N1APF.getSemantics(), 1); // 1.0
17073 APFloat::opStatus st = Recip.divide(RHS: N1APF, RM: APFloat::rmNearestTiesToEven);
17074 // Only do the transform if the reciprocal is a legal fp immediate that
17075 // isn't too nasty (eg NaN, denormal, ...).
17076 if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
17077 (!LegalOperations ||
17078 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
17079 // backend)... we should handle this gracefully after Legalize.
17080 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
17081 TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
17082 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
17083 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
17084 N2: DAG.getConstantFP(Val: Recip, DL, VT));
17085 }
17086
17087 // If this FDIV is part of a reciprocal square root, it may be folded
17088 // into a target-specific square root estimate instruction.
17089 if (N1.getOpcode() == ISD::FSQRT) {
17090 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0), Flags))
17091 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17092 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
17093 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17094 if (SDValue RV =
17095 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
17096 RV = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N1), VT, Operand: RV);
17097 AddToWorklist(N: RV.getNode());
17098 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17099 }
17100 } else if (N1.getOpcode() == ISD::FP_ROUND &&
17101 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17102 if (SDValue RV =
17103 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
17104 RV = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N1), VT, N1: RV, N2: N1.getOperand(i: 1));
17105 AddToWorklist(N: RV.getNode());
17106 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17107 }
17108 } else if (N1.getOpcode() == ISD::FMUL) {
17109 // Look through an FMUL. Even though this won't remove the FDIV directly,
17110 // it's still worthwhile to get rid of the FSQRT if possible.
17111 SDValue Sqrt, Y;
17112 if (N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17113 Sqrt = N1.getOperand(i: 0);
17114 Y = N1.getOperand(i: 1);
17115 } else if (N1.getOperand(i: 1).getOpcode() == ISD::FSQRT) {
17116 Sqrt = N1.getOperand(i: 1);
17117 Y = N1.getOperand(i: 0);
17118 }
17119 if (Sqrt.getNode()) {
17120 // If the other multiply operand is known positive, pull it into the
17121 // sqrt. That will eliminate the division if we convert to an estimate.
17122 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
17123 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
17124 SDValue A;
17125 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
17126 A = Y.getOperand(i: 0);
17127 else if (Y == Sqrt.getOperand(i: 0))
17128 A = Y;
17129 if (A) {
17130 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
17131 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
17132 SDValue AA = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: A, N2: A);
17133 SDValue AAZ =
17134 DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AA, N2: Sqrt.getOperand(i: 0));
17135 if (SDValue Rsqrt = buildRsqrtEstimate(Op: AAZ, Flags))
17136 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Rsqrt);
17137
17138 // Estimate creation failed. Clean up speculatively created nodes.
17139 recursivelyDeleteUnusedNodes(N: AAZ.getNode());
17140 }
17141 }
17142
17143 // We found a FSQRT, so try to make this fold:
17144 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
17145 if (SDValue Rsqrt = buildRsqrtEstimate(Op: Sqrt.getOperand(i: 0), Flags)) {
17146 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N1), VT, N1: Rsqrt, N2: Y);
17147 AddToWorklist(N: Div.getNode());
17148 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Div);
17149 }
17150 }
17151 }
17152
17153 // Fold into a reciprocal estimate and multiply instead of a real divide.
17154 if (Options.NoInfsFPMath || Flags.hasNoInfs())
17155 if (SDValue RV = BuildDivEstimate(N: N0, Op: N1, Flags))
17156 return RV;
17157 }
17158
17159 // Fold X/Sqrt(X) -> Sqrt(X)
17160 if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
17161 (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
17162 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(i: 0))
17163 return N1;
17164
17165 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
17166 TargetLowering::NegatibleCost CostN0 =
17167 TargetLowering::NegatibleCost::Expensive;
17168 TargetLowering::NegatibleCost CostN1 =
17169 TargetLowering::NegatibleCost::Expensive;
17170 SDValue NegN0 =
17171 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
17172 if (NegN0) {
17173 HandleSDNode NegN0Handle(NegN0);
17174 SDValue NegN1 =
17175 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
17176 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17177 CostN1 == TargetLowering::NegatibleCost::Cheaper))
17178 return DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N), VT, N1: NegN0, N2: NegN1);
17179 }
17180
17181 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17182 return R;
17183
17184 return SDValue();
17185}
17186
17187SDValue DAGCombiner::visitFREM(SDNode *N) {
17188 SDValue N0 = N->getOperand(Num: 0);
17189 SDValue N1 = N->getOperand(Num: 1);
17190 EVT VT = N->getValueType(ResNo: 0);
17191 SDNodeFlags Flags = N->getFlags();
17192 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17193
17194 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17195 return R;
17196
17197 // fold (frem c1, c2) -> fmod(c1,c2)
17198 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FREM, DL: SDLoc(N), VT, Ops: {N0, N1}))
17199 return C;
17200
17201 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17202 return NewSel;
17203
17204 return SDValue();
17205}
17206
17207SDValue DAGCombiner::visitFSQRT(SDNode *N) {
17208 SDNodeFlags Flags = N->getFlags();
17209 const TargetOptions &Options = DAG.getTarget().Options;
17210
17211 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
17212 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
17213 if (!Flags.hasApproximateFuncs() ||
17214 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
17215 return SDValue();
17216
17217 SDValue N0 = N->getOperand(Num: 0);
17218 if (TLI.isFsqrtCheap(X: N0, DAG))
17219 return SDValue();
17220
17221 // FSQRT nodes have flags that propagate to the created nodes.
17222 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
17223 // transform the fdiv, we may produce a sub-optimal estimate sequence
17224 // because the reciprocal calculation may not have to filter out a
17225 // 0.0 input.
17226 return buildSqrtEstimate(Op: N0, Flags);
17227}
17228
17229/// copysign(x, fp_extend(y)) -> copysign(x, y)
17230/// copysign(x, fp_round(y)) -> copysign(x, y)
/// Operands to the function are the types of X and Y, respectively.
17232static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
17233 // Always fold no-op FP casts.
17234 if (XTy == YTy)
17235 return true;
17236
17237 // Do not optimize out type conversion of f128 type yet.
17238 // For some targets like x86_64, configuration is changed to keep one f128
17239 // value in one SSE register, but instruction selection cannot handle
17240 // FCOPYSIGN on SSE registers yet.
17241 if (YTy == MVT::f128)
17242 return false;
17243
17244 return !YTy.isVector() || EnableVectorFCopySignExtendRound;
17245}
17246
17247static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
17248 SDValue N1 = N->getOperand(Num: 1);
17249 if (N1.getOpcode() != ISD::FP_EXTEND &&
17250 N1.getOpcode() != ISD::FP_ROUND)
17251 return false;
17252 EVT N1VT = N1->getValueType(ResNo: 0);
17253 EVT N1Op0VT = N1->getOperand(Num: 0).getValueType();
17254 return CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: N1VT, YTy: N1Op0VT);
17255}
17256
17257SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
17258 SDValue N0 = N->getOperand(Num: 0);
17259 SDValue N1 = N->getOperand(Num: 1);
17260 EVT VT = N->getValueType(ResNo: 0);
17261
17262 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
17263 if (SDValue C =
17264 DAG.FoldConstantArithmetic(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, Ops: {N0, N1}))
17265 return C;
17266
17267 if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N->getOperand(Num: 1))) {
17268 const APFloat &V = N1C->getValueAPF();
17269 // copysign(x, c1) -> fabs(x) iff ispos(c1)
17270 // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
17271 if (!V.isNegative()) {
17272 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FABS, VT))
17273 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
17274 } else {
17275 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
17276 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT,
17277 Operand: DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N0), VT, Operand: N0));
17278 }
17279 }
17280
17281 // copysign(fabs(x), y) -> copysign(x, y)
17282 // copysign(fneg(x), y) -> copysign(x, y)
17283 // copysign(copysign(x,z), y) -> copysign(x, y)
17284 if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
17285 N0.getOpcode() == ISD::FCOPYSIGN)
17286 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0), N2: N1);
17287
17288 // copysign(x, abs(y)) -> abs(x)
17289 if (N1.getOpcode() == ISD::FABS)
17290 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
17291
17292 // copysign(x, copysign(y,z)) -> copysign(x, z)
17293 if (N1.getOpcode() == ISD::FCOPYSIGN)
17294 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, N1: N0, N2: N1.getOperand(i: 1));
17295
17296 // copysign(x, fp_extend(y)) -> copysign(x, y)
17297 // copysign(x, fp_round(y)) -> copysign(x, y)
17298 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
17299 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, N1: N0, N2: N1.getOperand(i: 0));
17300
17301 return SDValue();
17302}
17303
17304SDValue DAGCombiner::visitFPOW(SDNode *N) {
17305 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N: N->getOperand(Num: 1));
17306 if (!ExponentC)
17307 return SDValue();
17308 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17309
17310 // Try to convert x ** (1/3) into cube root.
17311 // TODO: Handle the various flavors of long double.
17312 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
17313 // Some range near 1/3 should be fine.
17314 EVT VT = N->getValueType(ResNo: 0);
17315 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
17316 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
17317 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
17318 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
    // pow(-val, 1/3) = nan; cbrt(-val) = -cbrt(val) (a negative number).
17320 // For regular numbers, rounding may cause the results to differ.
17321 // Therefore, we require { nsz ninf nnan afn } for this transform.
17322 // TODO: We could select out the special cases if we don't have nsz/ninf.
17323 SDNodeFlags Flags = N->getFlags();
17324 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
17325 !Flags.hasApproximateFuncs())
17326 return SDValue();
17327
17328 // Do not create a cbrt() libcall if the target does not have it, and do not
17329 // turn a pow that has lowering support into a cbrt() libcall.
17330 if (!DAG.getLibInfo().has(F: LibFunc_cbrt) ||
17331 (!DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FPOW, VT) &&
17332 DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FCBRT, VT)))
17333 return SDValue();
17334
17335 return DAG.getNode(Opcode: ISD::FCBRT, DL: SDLoc(N), VT, Operand: N->getOperand(Num: 0));
17336 }
17337
17338 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
17339 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
17340 // TODO: This could be extended (using a target hook) to handle smaller
17341 // power-of-2 fractional exponents.
17342 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(V: 0.25);
17343 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(V: 0.75);
17344 if (ExponentIs025 || ExponentIs075) {
17345 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
17346 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
17347 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
17348 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
17349 // For regular numbers, rounding may cause the results to differ.
17350 // Therefore, we require { nsz ninf afn } for this transform.
17351 // TODO: We could select out the special cases if we don't have nsz/ninf.
17352 SDNodeFlags Flags = N->getFlags();
17353
17354 // We only need no signed zeros for the 0.25 case.
17355 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
17356 !Flags.hasApproximateFuncs())
17357 return SDValue();
17358
17359 // Don't double the number of libcalls. We are trying to inline fast code.
17360 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(Op: ISD::FSQRT, VT))
17361 return SDValue();
17362
17363 // Assume that libcalls are the smallest code.
17364 // TODO: This restriction should probably be lifted for vectors.
17365 if (ForCodeSize)
17366 return SDValue();
17367
17368 // pow(X, 0.25) --> sqrt(sqrt(X))
17369 SDLoc DL(N);
17370 SDValue Sqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: N->getOperand(Num: 0));
17371 SDValue SqrtSqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: Sqrt);
17372 if (ExponentIs025)
17373 return SqrtSqrt;
17374 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
17375 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Sqrt, N2: SqrtSqrt);
17376 }
17377
17378 return SDValue();
17379}
17380
17381static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
17382 const TargetLowering &TLI) {
17383 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
17384 // replacing casts with a libcall. We also must be allowed to ignore -0.0
  // because FTRUNC will return -0.0 for inputs in (-1.0, -0.0), but using
  // integer conversions would return +0.0.
17387 // FIXME: We should be able to use node-level FMF here.
17388 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
17389 EVT VT = N->getValueType(ResNo: 0);
17390 if (!TLI.isOperationLegal(Op: ISD::FTRUNC, VT) ||
17391 !DAG.getTarget().Options.NoSignedZerosFPMath)
17392 return SDValue();
17393
17394 // fptosi/fptoui round towards zero, so converting from FP to integer and
17395 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
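  // For example, sitofp (fptosi -3.75) yields -3.0, which matches
  // ftrunc(-3.75); only the sign of zero can differ, which is why the
  // no-signed-zeros check above is required.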
17396 SDValue N0 = N->getOperand(Num: 0);
17397 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
17398 N0.getOperand(i: 0).getValueType() == VT)
17399 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17400
17401 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
17402 N0.getOperand(i: 0).getValueType() == VT)
17403 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17404
17405 return SDValue();
17406}
17407
17408SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
17409 SDValue N0 = N->getOperand(Num: 0);
17410 EVT VT = N->getValueType(ResNo: 0);
17411 EVT OpVT = N0.getValueType();
17412
17413 // [us]itofp(undef) = 0, because the result value is bounded.
17414 if (N0.isUndef())
17415 return DAG.getConstantFP(Val: 0.0, DL: SDLoc(N), VT);
17416
17417 // fold (sint_to_fp c1) -> c1fp
17418 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
17419 // ...but only if the target supports immediate floating-point values
17420 (!LegalOperations ||
17421 TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
17422 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17423
17424 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
17425 // but UINT_TO_FP is legal on this target, try to convert.
17426 if (!hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT) &&
17427 hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT)) {
17428 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
17429 if (DAG.SignBitIsZero(Op: N0))
17430 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17431 }
17432
17433 // The next optimizations are desirable only if SELECT_CC can be lowered.
17434 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
17435 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
17436 !VT.isVector() &&
17437 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17438 SDLoc DL(N);
17439 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: -1.0, DL, VT),
17440 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17441 }
17442
17443 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
17444 // (select (setcc x, y, cc), 1.0, 0.0)
17445 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
17446 N0.getOperand(i: 0).getOpcode() == ISD::SETCC && !VT.isVector() &&
17447 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17448 SDLoc DL(N);
17449 return DAG.getSelect(DL, VT, Cond: N0.getOperand(i: 0),
17450 LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
17451 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17452 }
17453
17454 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17455 return FTrunc;
17456
17457 return SDValue();
17458}
17459
17460SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
17461 SDValue N0 = N->getOperand(Num: 0);
17462 EVT VT = N->getValueType(ResNo: 0);
17463 EVT OpVT = N0.getValueType();
17464
17465 // [us]itofp(undef) = 0, because the result value is bounded.
17466 if (N0.isUndef())
17467 return DAG.getConstantFP(Val: 0.0, DL: SDLoc(N), VT);
17468
17469 // fold (uint_to_fp c1) -> c1fp
17470 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
17471 // ...but only if the target supports immediate floating-point values
17472 (!LegalOperations ||
17473 TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
17474 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17475
17476 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
17477 // but SINT_TO_FP is legal on this target, try to convert.
17478 if (!hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT) &&
17479 hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT)) {
17480 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
17481 if (DAG.SignBitIsZero(Op: N0))
17482 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17483 }
17484
17485 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
17486 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
17487 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17488 SDLoc DL(N);
17489 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
17490 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17491 }
17492
17493 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17494 return FTrunc;
17495
17496 return SDValue();
17497}
17498
// Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
17500static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
17501 SDValue N0 = N->getOperand(Num: 0);
17502 EVT VT = N->getValueType(ResNo: 0);
17503
17504 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
17505 return SDValue();
17506
17507 SDValue Src = N0.getOperand(i: 0);
17508 EVT SrcVT = Src.getValueType();
17509 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
17510 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
17511
17512 // We can safely assume the conversion won't overflow the output range,
17513 // because (for example) (uint8_t)18293.f is undefined behavior.
17514
17515 // Since we can assume the conversion won't overflow, our decision as to
17516 // whether the input will fit in the float should depend on the minimum
17517 // of the input range and output range.
17518
17519 // This means this is also safe for a signed input and unsigned output, since
17520 // a negative input would lead to undefined behavior.
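  // For illustration (IEEE binary32, 24 bits of precision): an i16 value
  // converted to f32 and back to i32 folds to a sign_extend, because the at
  // most 15 value bits of the signed i16 input always fit exactly in the
  // 24-bit significand, so the round trip through f32 is lossless.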
17521 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
17522 unsigned OutputSize = (int)VT.getScalarSizeInBits();
17523 unsigned ActualSize = std::min(a: InputSize, b: OutputSize);
17524 const fltSemantics &sem = DAG.EVTToAPFloatSemantics(VT: N0.getValueType());
17525
17526 // We can only fold away the float conversion if the input range can be
17527 // represented exactly in the float range.
17528 if (APFloat::semanticsPrecision(sem) >= ActualSize) {
17529 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
17530 unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
17531 : ISD::ZERO_EXTEND;
17532 return DAG.getNode(Opcode: ExtOp, DL: SDLoc(N), VT, Operand: Src);
17533 }
17534 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
17535 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Operand: Src);
17536 return DAG.getBitcast(VT, V: Src);
17537 }
17538 return SDValue();
17539}
17540
17541SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
17542 SDValue N0 = N->getOperand(Num: 0);
17543 EVT VT = N->getValueType(ResNo: 0);
17544
17545 // fold (fp_to_sint undef) -> undef
17546 if (N0.isUndef())
17547 return DAG.getUNDEF(VT);
17548
17549 // fold (fp_to_sint c1fp) -> c1
17550 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17551 return DAG.getNode(Opcode: ISD::FP_TO_SINT, DL: SDLoc(N), VT, Operand: N0);
17552
17553 return FoldIntToFPToInt(N, DAG);
17554}
17555
17556SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
17557 SDValue N0 = N->getOperand(Num: 0);
17558 EVT VT = N->getValueType(ResNo: 0);
17559
17560 // fold (fp_to_uint undef) -> undef
17561 if (N0.isUndef())
17562 return DAG.getUNDEF(VT);
17563
17564 // fold (fp_to_uint c1fp) -> c1
17565 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17566 return DAG.getNode(Opcode: ISD::FP_TO_UINT, DL: SDLoc(N), VT, Operand: N0);
17567
17568 return FoldIntToFPToInt(N, DAG);
17569}
17570
17571SDValue DAGCombiner::visitXRINT(SDNode *N) {
17572 SDValue N0 = N->getOperand(Num: 0);
17573 EVT VT = N->getValueType(ResNo: 0);
17574
17575 // fold (lrint|llrint undef) -> undef
17576 if (N0.isUndef())
17577 return DAG.getUNDEF(VT);
17578
17579 // fold (lrint|llrint c1fp) -> c1
17580 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17581 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Operand: N0);
17582
17583 return SDValue();
17584}
17585
17586SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
17587 SDValue N0 = N->getOperand(Num: 0);
17588 SDValue N1 = N->getOperand(Num: 1);
17589 EVT VT = N->getValueType(ResNo: 0);
17590
17591 // fold (fp_round c1fp) -> c1fp
17592 if (SDValue C =
17593 DAG.FoldConstantArithmetic(Opcode: ISD::FP_ROUND, DL: SDLoc(N), VT, Ops: {N0, N1}))
17594 return C;
17595
17596 // fold (fp_round (fp_extend x)) -> x
17597 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(i: 0).getValueType())
17598 return N0.getOperand(i: 0);
17599
17600 // fold (fp_round (fp_round x)) -> (fp_round x)
17601 if (N0.getOpcode() == ISD::FP_ROUND) {
17602 const bool NIsTrunc = N->getConstantOperandVal(Num: 1) == 1;
17603 const bool N0IsTrunc = N0.getConstantOperandVal(i: 1) == 1;
17604
17605 // Avoid folding legal fp_rounds into non-legal ones.
17606 if (!hasOperation(Opcode: ISD::FP_ROUND, VT))
17607 return SDValue();
17608
17609 // Skip this folding if it results in an fp_round from f80 to f16.
17610 //
17611 // f80 to f16 always generates an expensive (and as yet, unimplemented)
17612 // libcall to __truncxfhf2 instead of selecting native f16 conversion
17613 // instructions from f32 or f64. Moreover, the first (value-preserving)
17614 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
17615 // x86.
17616 if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
17617 return SDValue();
17618
17619 // If the first fp_round isn't a value preserving truncation, it might
17620 // introduce a tie in the second fp_round, that wouldn't occur in the
17621 // single-step fp_round we want to fold to.
17622 // In other words, double rounding isn't the same as rounding.
17623 // Also, this is a value preserving truncation iff both fp_round's are.
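    // As a tiny-format illustration of the hazard: rounding 1.3125 directly
    // to a 2-bit significand gives 1.5, but rounding it first to a 3-bit
    // significand gives exactly 1.25, which then ties to even and becomes
    // 1.0 -- a different result.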
17624 if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
17625 SDLoc DL(N);
17626 return DAG.getNode(
17627 Opcode: ISD::FP_ROUND, DL, VT, N1: N0.getOperand(i: 0),
17628 N2: DAG.getIntPtrConstant(Val: NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
17629 }
17630 }
17631
17632 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
17633 // Note: From a legality perspective, this is a two step transform. First,
17634 // we duplicate the fp_round to the arguments of the copysign, then we
17635 // eliminate the fp_round on Y. The second step requires an additional
17636 // predicate to match the implementation above.
17637 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17638 CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: VT,
17639 YTy: N0.getValueType())) {
17640 SDValue Tmp = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT,
17641 N1: N0.getOperand(i: 0), N2: N1);
17642 AddToWorklist(N: Tmp.getNode());
17643 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT,
17644 N1: Tmp, N2: N0.getOperand(i: 1));
17645 }
17646
17647 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
17648 return NewVSel;
17649
17650 return SDValue();
17651}
17652
17653SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
17654 SDValue N0 = N->getOperand(Num: 0);
17655 EVT VT = N->getValueType(ResNo: 0);
17656
17657 if (VT.isVector())
17658 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL: SDLoc(N)))
17659 return FoldedVOp;
17660
17661  // If this is fp_round(fp_extend), don't fold it; allow ourselves to be folded.
17662 if (N->hasOneUse() &&
17663 N->use_begin()->getOpcode() == ISD::FP_ROUND)
17664 return SDValue();
17665
17666 // fold (fp_extend c1fp) -> c1fp
17667 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17668 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N), VT, Operand: N0);
17669
17670 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
17671 if (N0.getOpcode() == ISD::FP16_TO_FP &&
17672 TLI.getOperationAction(Op: ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
17673 return DAG.getNode(Opcode: ISD::FP16_TO_FP, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17674
17675  // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
17676  // value of X.
17677 if (N0.getOpcode() == ISD::FP_ROUND
17678 && N0.getConstantOperandVal(i: 1) == 1) {
17679 SDValue In = N0.getOperand(i: 0);
17680 if (In.getValueType() == VT) return In;
17681 if (VT.bitsLT(VT: In.getValueType()))
17682 return DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N), VT,
17683 N1: In, N2: N0.getOperand(i: 1));
17684 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N), VT, Operand: In);
17685 }
17686
17687 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
17688 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
17689 TLI.isLoadExtLegalOrCustom(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
17690 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
17691 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: SDLoc(N), VT,
17692 Chain: LN0->getChain(),
17693 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
17694 MMO: LN0->getMemOperand());
17695 CombineTo(N, Res: ExtLoad);
17696 CombineTo(
17697 N: N0.getNode(),
17698 Res0: DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT: N0.getValueType(), N1: ExtLoad,
17699 N2: DAG.getIntPtrConstant(Val: 1, DL: SDLoc(N0), /*isTarget=*/true)),
17700 Res1: ExtLoad.getValue(R: 1));
17701 return SDValue(N, 0); // Return N so it doesn't get rechecked!
17702 }
17703
17704 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
17705 return NewVSel;
17706
17707 return SDValue();
17708}
17709
17710SDValue DAGCombiner::visitFCEIL(SDNode *N) {
17711 SDValue N0 = N->getOperand(Num: 0);
17712 EVT VT = N->getValueType(ResNo: 0);
17713
17714 // fold (fceil c1) -> fceil(c1)
17715 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17716 return DAG.getNode(Opcode: ISD::FCEIL, DL: SDLoc(N), VT, Operand: N0);
17717
17718 return SDValue();
17719}
17720
17721SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
17722 SDValue N0 = N->getOperand(Num: 0);
17723 EVT VT = N->getValueType(ResNo: 0);
17724
17725 // fold (ftrunc c1) -> ftrunc(c1)
17726 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17727 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0);
17728
17729 // fold ftrunc (known rounded int x) -> x
17730  // ftrunc is a part of the fptosi/fptoui expansion on some targets, so it is
17731  // likely generated when extracting an integer from a rounded FP value.
17732 switch (N0.getOpcode()) {
17733 default: break;
17734 case ISD::FRINT:
17735 case ISD::FTRUNC:
17736 case ISD::FNEARBYINT:
17737 case ISD::FROUNDEVEN:
17738 case ISD::FFLOOR:
17739 case ISD::FCEIL:
17740 return N0;
17741 }
17742
17743 return SDValue();
17744}
17745
17746SDValue DAGCombiner::visitFFREXP(SDNode *N) {
17747 SDValue N0 = N->getOperand(Num: 0);
17748
17749 // fold (ffrexp c1) -> ffrexp(c1)
17750 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17751 return DAG.getNode(Opcode: ISD::FFREXP, DL: SDLoc(N), VTList: N->getVTList(), N: N0);
17752 return SDValue();
17753}
17754
17755SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
17756 SDValue N0 = N->getOperand(Num: 0);
17757 EVT VT = N->getValueType(ResNo: 0);
17758
17759 // fold (ffloor c1) -> ffloor(c1)
17760 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17761 return DAG.getNode(Opcode: ISD::FFLOOR, DL: SDLoc(N), VT, Operand: N0);
17762
17763 return SDValue();
17764}
17765
17766SDValue DAGCombiner::visitFNEG(SDNode *N) {
17767 SDValue N0 = N->getOperand(Num: 0);
17768 EVT VT = N->getValueType(ResNo: 0);
17769 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17770
17771 // Constant fold FNEG.
17772 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17773 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: N0);
17774
17775 if (SDValue NegN0 =
17776 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17777 return NegN0;
17778
17779 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
17780 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
17781 // know it was called from a context with a nsz flag if the input fsub does
17782 // not.
17783 if (N0.getOpcode() == ISD::FSUB &&
17784 (DAG.getTarget().Options.NoSignedZerosFPMath ||
17785 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
17786 return DAG.getNode(Opcode: ISD::FSUB, DL: SDLoc(N), VT, N1: N0.getOperand(i: 1),
17787 N2: N0.getOperand(i: 0));
17788 }
17789
17790 if (SDValue Cast = foldSignChangeInBitcast(N))
17791 return Cast;
17792
17793 return SDValue();
17794}
17795
17796SDValue DAGCombiner::visitFMinMax(SDNode *N) {
17797 SDValue N0 = N->getOperand(Num: 0);
17798 SDValue N1 = N->getOperand(Num: 1);
17799 EVT VT = N->getValueType(ResNo: 0);
17800 const SDNodeFlags Flags = N->getFlags();
17801 unsigned Opc = N->getOpcode();
17802 bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
17803 bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
17804 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17805
17806 // Constant fold.
17807 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: Opc, DL: SDLoc(N), VT, Ops: {N0, N1}))
17808 return C;
17809
17810 // Canonicalize to constant on RHS.
17811 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
17812 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
17813 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0);
17814
17815 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1)) {
17816 const APFloat &AF = N1CFP->getValueAPF();
17817
17818 // minnum(X, nan) -> X
17819 // maxnum(X, nan) -> X
17820 // minimum(X, nan) -> nan
17821 // maximum(X, nan) -> nan
17822 if (AF.isNaN())
17823 return PropagatesNaN ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
17824
17825 // In the following folds, inf can be replaced with the largest finite
17826 // float, if the ninf flag is set.
17827 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
17828 // minnum(X, -inf) -> -inf
17829 // maxnum(X, +inf) -> +inf
17830 // minimum(X, -inf) -> -inf if nnan
17831 // maximum(X, +inf) -> +inf if nnan
17832 if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
17833 return N->getOperand(Num: 1);
17834
17835 // minnum(X, +inf) -> X if nnan
17836 // maxnum(X, -inf) -> X if nnan
17837 // minimum(X, +inf) -> X
17838 // maximum(X, -inf) -> X
17839 if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
17840 return N->getOperand(Num: 0);
17841 }
17842 }
17843
17844 if (SDValue SD = reassociateReduction(
17845 RedOpc: PropagatesNaN
17846 ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
17847 : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
17848 Opc, DL: SDLoc(N), VT, N0, N1, Flags))
17849 return SD;
17850
17851 return SDValue();
17852}
17853
17854SDValue DAGCombiner::visitFABS(SDNode *N) {
17855 SDValue N0 = N->getOperand(Num: 0);
17856 EVT VT = N->getValueType(ResNo: 0);
17857
17858 // fold (fabs c1) -> fabs(c1)
17859 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17860 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
17861
17862 // fold (fabs (fabs x)) -> (fabs x)
17863 if (N0.getOpcode() == ISD::FABS)
17864 return N->getOperand(Num: 0);
17865
17866 // fold (fabs (fneg x)) -> (fabs x)
17867 // fold (fabs (fcopysign x, y)) -> (fabs x)
17868 if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
17869 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17870
17871 if (SDValue Cast = foldSignChangeInBitcast(N))
17872 return Cast;
17873
17874 return SDValue();
17875}
17876
17877SDValue DAGCombiner::visitBRCOND(SDNode *N) {
17878 SDValue Chain = N->getOperand(Num: 0);
17879 SDValue N1 = N->getOperand(Num: 1);
17880 SDValue N2 = N->getOperand(Num: 2);
17881
17882 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
17883 // nondeterministic jumps).
17884 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
17885 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
17886 N1->getOperand(0), N2);
17887 }
17888
17889 // Variant of the previous fold where there is a SETCC in between:
17890 // BRCOND(SETCC(FREEZE(X), CONST, Cond))
17891 // =>
17892 // BRCOND(FREEZE(SETCC(X, CONST, Cond)))
17893 // =>
17894 // BRCOND(SETCC(X, CONST, Cond))
17895 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
17896 // isn't equivalent to true or false.
17897 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
17898 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
17899 if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
17900 SDValue S0 = N1->getOperand(Num: 0), S1 = N1->getOperand(Num: 1);
17901 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N1->getOperand(Num: 2))->get();
17902 ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(Val&: S0);
17903 ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(Val&: S1);
17904 bool Updated = false;
17905
17906 // Is 'X Cond C' always true or false?
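    // For example, (X ult 0) is always false and (X uge 0) is always true.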
17907 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
17908 bool False = (Cond == ISD::SETULT && C->isZero()) ||
17909 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
17910 (Cond == ISD::SETUGT && C->isAllOnes()) ||
17911 (Cond == ISD::SETGT && C->isMaxSignedValue());
17912 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
17913 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
17914 (Cond == ISD::SETUGE && C->isZero()) ||
17915 (Cond == ISD::SETGE && C->isMinSignedValue());
17916 return True || False;
17917 };
17918
17919 if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
17920 if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
17921 S0 = S0->getOperand(Num: 0);
17922 Updated = true;
17923 }
17924 }
17925 if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
17926 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Operation: Cond), S0C)) {
17927 S1 = S1->getOperand(Num: 0);
17928 Updated = true;
17929 }
17930 }
17931
17932 if (Updated)
17933 return DAG.getNode(
17934 ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
17935 DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
17936 }
17937
17938 // If N is a constant we could fold this into a fallthrough or unconditional
17939 // branch. However that doesn't happen very often in normal code, because
17940 // Instcombine/SimplifyCFG should have handled the available opportunities.
17941 // If we did this folding here, it would be necessary to update the
17942 // MachineBasicBlock CFG, which is awkward.
17943
17944 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
17945 // on the target.
17946 if (N1.getOpcode() == ISD::SETCC &&
17947 TLI.isOperationLegalOrCustom(Op: ISD::BR_CC,
17948 VT: N1.getOperand(i: 0).getValueType())) {
17949 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
17950 Chain, N1.getOperand(2),
17951 N1.getOperand(0), N1.getOperand(1), N2);
17952 }
17953
17954 if (N1.hasOneUse()) {
17955 // rebuildSetCC calls visitXor which may change the Chain when there is a
17956 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
17957 HandleSDNode ChainHandle(Chain);
17958 if (SDValue NewN1 = rebuildSetCC(N1))
17959 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
17960 ChainHandle.getValue(), NewN1, N2);
17961 }
17962
17963 return SDValue();
17964}
17965
17966SDValue DAGCombiner::rebuildSetCC(SDValue N) {
17967 if (N.getOpcode() == ISD::SRL ||
17968 (N.getOpcode() == ISD::TRUNCATE &&
17969 (N.getOperand(i: 0).hasOneUse() &&
17970 N.getOperand(i: 0).getOpcode() == ISD::SRL))) {
17971    // Look past the truncate.
17972 if (N.getOpcode() == ISD::TRUNCATE)
17973 N = N.getOperand(i: 0);
17974
17975 // Match this pattern so that we can generate simpler code:
17976 //
17977 // %a = ...
17978 // %b = and i32 %a, 2
17979 // %c = srl i32 %b, 1
17980 // brcond i32 %c ...
17981 //
17982 // into
17983 //
17984 // %a = ...
17985 // %b = and i32 %a, 2
17986 // %c = setcc eq %b, 0
17987 // brcond %c ...
17988 //
17989 // This applies only when the AND constant value has one bit set and the
17990 // SRL constant is equal to the log2 of the AND constant. The back-end is
17991 // smart enough to convert the result into a TEST/JMP sequence.
17992 SDValue Op0 = N.getOperand(i: 0);
17993 SDValue Op1 = N.getOperand(i: 1);
17994
17995 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
17996 SDValue AndOp1 = Op0.getOperand(i: 1);
17997
17998 if (AndOp1.getOpcode() == ISD::Constant) {
17999 const APInt &AndConst = AndOp1->getAsAPIntVal();
18000
18001 if (AndConst.isPowerOf2() &&
18002 Op1->getAsAPIntVal() == AndConst.logBase2()) {
18003 SDLoc DL(N);
18004 return DAG.getSetCC(DL, VT: getSetCCResultType(VT: Op0.getValueType()),
18005 LHS: Op0, RHS: DAG.getConstant(Val: 0, DL, VT: Op0.getValueType()),
18006 Cond: ISD::SETNE);
18007 }
18008 }
18009 }
18010 }
18011
18012 // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
18013 // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
18014 if (N.getOpcode() == ISD::XOR) {
18015 // Because we may call this on a speculatively constructed
18016 // SimplifiedSetCC Node, we need to simplify this node first.
18017 // Ideally this should be folded into SimplifySetCC and not
18018 // here. For now, grab a handle to N so we don't lose it from
18019    // replacements internal to the visit.
18020 HandleSDNode XORHandle(N);
18021 while (N.getOpcode() == ISD::XOR) {
18022 SDValue Tmp = visitXOR(N: N.getNode());
18023 // No simplification done.
18024 if (!Tmp.getNode())
18025 break;
18026      // Returning N is a form of in-visit replacement that may invalidate
18027      // N. Grab the value from the handle.
18028 if (Tmp.getNode() == N.getNode())
18029 N = XORHandle.getValue();
18030 else // Node simplified. Try simplifying again.
18031 N = Tmp;
18032 }
18033
18034 if (N.getOpcode() != ISD::XOR)
18035 return N;
18036
18037 SDValue Op0 = N->getOperand(Num: 0);
18038 SDValue Op1 = N->getOperand(Num: 1);
18039
18040 if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
18041 bool Equal = false;
18042 // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
18043 if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
18044 Op0.getValueType() == MVT::i1) {
18045 N = Op0;
18046 Op0 = N->getOperand(Num: 0);
18047 Op1 = N->getOperand(Num: 1);
18048 Equal = true;
18049 }
18050
18051 EVT SetCCVT = N.getValueType();
18052 if (LegalTypes)
18053 SetCCVT = getSetCCResultType(VT: SetCCVT);
18054 // Replace the uses of XOR with SETCC
18055 return DAG.getSetCC(DL: SDLoc(N), VT: SetCCVT, LHS: Op0, RHS: Op1,
18056 Cond: Equal ? ISD::SETEQ : ISD::SETNE);
18057 }
18058 }
18059
18060 return SDValue();
18061}
18062
18063// Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
18064//
18065SDValue DAGCombiner::visitBR_CC(SDNode *N) {
18066 CondCodeSDNode *CC = cast<CondCodeSDNode>(Val: N->getOperand(Num: 1));
18067 SDValue CondLHS = N->getOperand(Num: 2), CondRHS = N->getOperand(Num: 3);
18068
18069 // If N is a constant we could fold this into a fallthrough or unconditional
18070 // branch. However that doesn't happen very often in normal code, because
18071 // Instcombine/SimplifyCFG should have handled the available opportunities.
18072 // If we did this folding here, it would be necessary to update the
18073 // MachineBasicBlock CFG, which is awkward.
18074
18075 // Use SimplifySetCC to simplify SETCC's.
18076 SDValue Simp = SimplifySetCC(VT: getSetCCResultType(VT: CondLHS.getValueType()),
18077 N0: CondLHS, N1: CondRHS, Cond: CC->get(), DL: SDLoc(N),
18078 foldBooleans: false);
18079 if (Simp.getNode()) AddToWorklist(N: Simp.getNode());
18080
18081 // fold to a simpler setcc
18082 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
18083 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
18084 N->getOperand(0), Simp.getOperand(2),
18085 Simp.getOperand(0), Simp.getOperand(1),
18086 N->getOperand(4));
18087
18088 return SDValue();
18089}
18090
18091static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
18092 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
18093 const TargetLowering &TLI) {
18094 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: N)) {
18095 if (LD->isIndexed())
18096 return false;
18097 EVT VT = LD->getMemoryVT();
18098 if (!TLI.isIndexedLoadLegal(IdxMode: Inc, VT) && !TLI.isIndexedLoadLegal(IdxMode: Dec, VT))
18099 return false;
18100 Ptr = LD->getBasePtr();
18101 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: N)) {
18102 if (ST->isIndexed())
18103 return false;
18104 EVT VT = ST->getMemoryVT();
18105 if (!TLI.isIndexedStoreLegal(IdxMode: Inc, VT) && !TLI.isIndexedStoreLegal(IdxMode: Dec, VT))
18106 return false;
18107 Ptr = ST->getBasePtr();
18108 IsLoad = false;
18109 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: N)) {
18110 if (LD->isIndexed())
18111 return false;
18112 EVT VT = LD->getMemoryVT();
18113 if (!TLI.isIndexedMaskedLoadLegal(IdxMode: Inc, VT) &&
18114 !TLI.isIndexedMaskedLoadLegal(IdxMode: Dec, VT))
18115 return false;
18116 Ptr = LD->getBasePtr();
18117 IsMasked = true;
18118 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: N)) {
18119 if (ST->isIndexed())
18120 return false;
18121 EVT VT = ST->getMemoryVT();
18122 if (!TLI.isIndexedMaskedStoreLegal(IdxMode: Inc, VT) &&
18123 !TLI.isIndexedMaskedStoreLegal(IdxMode: Dec, VT))
18124 return false;
18125 Ptr = ST->getBasePtr();
18126 IsLoad = false;
18127 IsMasked = true;
18128 } else {
18129 return false;
18130 }
18131 return true;
18132}
18133
18134/// Try turning a load/store into a pre-indexed load/store when the base
18135/// pointer is an add or subtract and it has other uses besides the load/store.
18136/// After the transformation, the new indexed load/store has effectively folded
18137/// the add/subtract in and all of its other uses are redirected to the
18138/// new load/store.
18139bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
18140 if (Level < AfterLegalizeDAG)
18141 return false;
18142
18143 bool IsLoad = true;
18144 bool IsMasked = false;
18145 SDValue Ptr;
18146 if (!getCombineLoadStoreParts(N, Inc: ISD::PRE_INC, Dec: ISD::PRE_DEC, IsLoad, IsMasked,
18147 Ptr, TLI))
18148 return false;
18149
18150 // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
18151 // out. There is no reason to make this a preinc/predec.
18152 if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
18153 Ptr->hasOneUse())
18154 return false;
18155
18156 // Ask the target to do addressing mode selection.
18157 SDValue BasePtr;
18158 SDValue Offset;
18159 ISD::MemIndexedMode AM = ISD::UNINDEXED;
18160 if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
18161 return false;
18162
18163 // Backends without true r+i pre-indexed forms may need to pass a
18164 // constant base with a variable offset so that constant coercion
18165 // will work with the patterns in canonical form.
18166 bool Swapped = false;
18167 if (isa<ConstantSDNode>(Val: BasePtr)) {
18168 std::swap(a&: BasePtr, b&: Offset);
18169 Swapped = true;
18170 }
18171
18172  // Don't create an indexed load / store with zero offset.
18173 if (isNullConstant(V: Offset))
18174 return false;
18175
18176 // Try turning it into a pre-indexed load / store except when:
18177 // 1) The new base ptr is a frame index.
18178 // 2) If N is a store and the new base ptr is either the same as or is a
18179 // predecessor of the value being stored.
18180 // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
18181 // that would create a cycle.
18182 // 4) All uses are load / store ops that use it as old base ptr.
18183
18184 // Check #1. Preinc'ing a frame index would require copying the stack pointer
18185 // (plus the implicit offset) to a register to preinc anyway.
18186 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
18187 return false;
18188
18189 // Check #2.
18190 if (!IsLoad) {
18191 SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(Val: N)->getValue()
18192 : cast<StoreSDNode>(Val: N)->getValue();
18193
18194 // Would require a copy.
18195 if (Val == BasePtr)
18196 return false;
18197
18198 // Would create a cycle.
18199 if (Val == Ptr || Ptr->isPredecessorOf(N: Val.getNode()))
18200 return false;
18201 }
18202
18203 // Caches for hasPredecessorHelper.
18204 SmallPtrSet<const SDNode *, 32> Visited;
18205 SmallVector<const SDNode *, 16> Worklist;
18206 Worklist.push_back(Elt: N);
18207
18208 // If the offset is a constant, there may be other adds of constants that
18209 // can be folded with this one. We should do this to avoid having to keep
18210 // a copy of the original base pointer.
18211 SmallVector<SDNode *, 16> OtherUses;
18212 constexpr unsigned int MaxSteps = 8192;
18213 if (isa<ConstantSDNode>(Val: Offset))
18214 for (SDNode::use_iterator UI = BasePtr->use_begin(),
18215 UE = BasePtr->use_end();
18216 UI != UE; ++UI) {
18217 SDUse &Use = UI.getUse();
18218 // Skip the use that is Ptr and uses of other results from BasePtr's
18219 // node (important for nodes that return multiple results).
18220 if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
18221 continue;
18222
18223 if (SDNode::hasPredecessorHelper(N: Use.getUser(), Visited, Worklist,
18224 MaxSteps))
18225 continue;
18226
18227 if (Use.getUser()->getOpcode() != ISD::ADD &&
18228 Use.getUser()->getOpcode() != ISD::SUB) {
18229 OtherUses.clear();
18230 break;
18231 }
18232
18233 SDValue Op1 = Use.getUser()->getOperand(Num: (UI.getOperandNo() + 1) & 1);
18234 if (!isa<ConstantSDNode>(Val: Op1)) {
18235 OtherUses.clear();
18236 break;
18237 }
18238
18239 // FIXME: In some cases, we can be smarter about this.
18240 if (Op1.getValueType() != Offset.getValueType()) {
18241 OtherUses.clear();
18242 break;
18243 }
18244
18245 OtherUses.push_back(Elt: Use.getUser());
18246 }
18247
18248 if (Swapped)
18249 std::swap(a&: BasePtr, b&: Offset);
18250
18251 // Now check for #3 and #4.
18252 bool RealUse = false;
18253
18254 for (SDNode *Use : Ptr->uses()) {
18255 if (Use == N)
18256 continue;
18257 if (SDNode::hasPredecessorHelper(N: Use, Visited, Worklist, MaxSteps))
18258 return false;
18259
18260    // If Ptr may be folded into the addressing mode of the other use, then
18261    // it's not profitable to do this transformation.
18262 if (!canFoldInAddressingMode(N: Ptr.getNode(), Use, DAG, TLI))
18263 RealUse = true;
18264 }
18265
18266 if (!RealUse)
18267 return false;
18268
18269 SDValue Result;
18270 if (!IsMasked) {
18271 if (IsLoad)
18272 Result = DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
18273 else
18274 Result =
18275 DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
18276 } else {
18277 if (IsLoad)
18278 Result = DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18279 Offset, AM);
18280 else
18281 Result = DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18282 Offset, AM);
18283 }
18284 ++PreIndexedNodes;
18285 ++NodesCombined;
18286 LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
18287 Result.dump(&DAG); dbgs() << '\n');
18288 WorklistRemover DeadNodes(*this);
18289 if (IsLoad) {
18290 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
18291 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
18292 } else {
18293 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
18294 }
18295
18296 // Finally, since the node is now dead, remove it from the graph.
18297 deleteAndRecombine(N);
18298
18299 if (Swapped)
18300 std::swap(a&: BasePtr, b&: Offset);
18301
18302 // Replace other uses of BasePtr that can be updated to use Ptr
18303 for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
18304 unsigned OffsetIdx = 1;
18305 if (OtherUses[i]->getOperand(Num: OffsetIdx).getNode() == BasePtr.getNode())
18306 OffsetIdx = 0;
18307 assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
18308 BasePtr.getNode() && "Expected BasePtr operand");
18309
18310 // We need to replace ptr0 in the following expression:
18311 // x0 * offset0 + y0 * ptr0 = t0
18312 // knowing that
18313 // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
18314 //
18315 // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
18316 // indexed load/store and the expression that needs to be re-written.
18317 //
18318 // Therefore, we have:
18319    // t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
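    //
    // For example, if the indexed store is t1 = ptr0 + offset1 (x1 = y1 = 1)
    // and the other use is t0 = ptr0 - offset0 (x0 = -1, y0 = 1), this gives
    // t0 = (-offset0 - offset1) + t1.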
18320
18321 auto *CN = cast<ConstantSDNode>(Val: OtherUses[i]->getOperand(Num: OffsetIdx));
18322 const APInt &Offset0 = CN->getAPIntValue();
18323 const APInt &Offset1 = Offset->getAsAPIntVal();
18324 int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
18325 int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
18326 int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
18327 int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
18328
18329 unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
18330
18331 APInt CNV = Offset0;
18332 if (X0 < 0) CNV = -CNV;
18333 if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
18334 else CNV = CNV - Offset1;
18335
18336 SDLoc DL(OtherUses[i]);
18337
18338 // We can now generate the new expression.
18339 SDValue NewOp1 = DAG.getConstant(Val: CNV, DL, VT: CN->getValueType(ResNo: 0));
18340 SDValue NewOp2 = Result.getValue(R: IsLoad ? 1 : 0);
18341
18342 SDValue NewUse = DAG.getNode(Opcode,
18343 DL,
18344 VT: OtherUses[i]->getValueType(ResNo: 0), N1: NewOp1, N2: NewOp2);
18345 DAG.ReplaceAllUsesOfValueWith(From: SDValue(OtherUses[i], 0), To: NewUse);
18346 deleteAndRecombine(N: OtherUses[i]);
18347 }
18348
18349 // Replace the uses of Ptr with uses of the updated base value.
18350 DAG.ReplaceAllUsesOfValueWith(From: Ptr, To: Result.getValue(R: IsLoad ? 1 : 0));
18351 deleteAndRecombine(N: Ptr.getNode());
18352 AddToWorklist(N: Result.getNode());
18353
18354 return true;
18355}
18356
18357static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
18358 SDValue &BasePtr, SDValue &Offset,
18359 ISD::MemIndexedMode &AM,
18360 SelectionDAG &DAG,
18361 const TargetLowering &TLI) {
18362 if (PtrUse == N ||
18363 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
18364 return false;
18365
18366 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
18367 return false;
18368
18369  // Don't create an indexed load / store with zero offset.
18370 if (isNullConstant(V: Offset))
18371 return false;
18372
18373 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
18374 return false;
18375
18376 SmallPtrSet<const SDNode *, 32> Visited;
18377 for (SDNode *Use : BasePtr->uses()) {
18378 if (Use == Ptr.getNode())
18379 continue;
18380
18381    // Bail if there's a later user which could perform the indexing instead.
18382 if (isa<MemSDNode>(Val: Use)) {
18383 bool IsLoad = true;
18384 bool IsMasked = false;
18385 SDValue OtherPtr;
18386 if (getCombineLoadStoreParts(N: Use, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
18387 IsMasked, Ptr&: OtherPtr, TLI)) {
18388 SmallVector<const SDNode *, 2> Worklist;
18389 Worklist.push_back(Elt: Use);
18390 if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
18391 return false;
18392 }
18393 }
18394
18395 // If all the uses are load / store addresses, then don't do the
18396 // transformation.
18397 if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
18398 for (SDNode *UseUse : Use->uses())
18399 if (canFoldInAddressingMode(N: Use, Use: UseUse, DAG, TLI))
18400 return false;
18401 }
18402 }
18403 return true;
18404}
18405
18406static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
18407 bool &IsMasked, SDValue &Ptr,
18408 SDValue &BasePtr, SDValue &Offset,
18409 ISD::MemIndexedMode &AM,
18410 SelectionDAG &DAG,
18411 const TargetLowering &TLI) {
18412 if (!getCombineLoadStoreParts(N, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
18413 IsMasked, Ptr, TLI) ||
18414 Ptr->hasOneUse())
18415 return nullptr;
18416
18417 // Try turning it into a post-indexed load / store except when
18418 // 1) All uses are load / store ops that use it as base ptr (and
18419  //    it may be folded into the addressing mode).
18420 // 2) Op must be independent of N, i.e. Op is neither a predecessor
18421 // nor a successor of N. Otherwise, if Op is folded that would
18422 // create a cycle.
18423 for (SDNode *Op : Ptr->uses()) {
18424 // Check for #1.
18425 if (!shouldCombineToPostInc(N, Ptr, PtrUse: Op, BasePtr, Offset, AM, DAG, TLI))
18426 continue;
18427
18428 // Check for #2.
18429 SmallPtrSet<const SDNode *, 32> Visited;
18430 SmallVector<const SDNode *, 8> Worklist;
18431 constexpr unsigned int MaxSteps = 8192;
18432 // Ptr is predecessor to both N and Op.
18433 Visited.insert(Ptr: Ptr.getNode());
18434 Worklist.push_back(Elt: N);
18435 Worklist.push_back(Elt: Op);
18436 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
18437 !SDNode::hasPredecessorHelper(N: Op, Visited, Worklist, MaxSteps))
18438 return Op;
18439 }
18440 return nullptr;
18441}
18442
18443/// Try to combine a load/store with an add/sub of the base pointer node into
18444/// a post-indexed load/store. The transformation effectively folds the
18445/// add/subtract into the new indexed load/store, and all of its uses are
18446/// redirected to the new load/store.
18447bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
18448 if (Level < AfterLegalizeDAG)
18449 return false;
18450
18451 bool IsLoad = true;
18452 bool IsMasked = false;
18453 SDValue Ptr;
18454 SDValue BasePtr;
18455 SDValue Offset;
18456 ISD::MemIndexedMode AM = ISD::UNINDEXED;
18457 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
18458 Offset, AM, DAG, TLI);
18459 if (!Op)
18460 return false;
18461
18462 SDValue Result;
18463 if (!IsMasked)
18464 Result = IsLoad ? DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18465 Offset, AM)
18466 : DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
18467 Base: BasePtr, Offset, AM);
18468 else
18469 Result = IsLoad ? DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N),
18470 Base: BasePtr, Offset, AM)
18471 : DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
18472 Base: BasePtr, Offset, AM);
18473 ++PostIndexedNodes;
18474 ++NodesCombined;
18475 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
18476 Result.dump(&DAG); dbgs() << '\n');
18477 WorklistRemover DeadNodes(*this);
18478 if (IsLoad) {
18479 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
18480 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
18481 } else {
18482 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
18483 }
18484
18485 // Finally, since the node is now dead, remove it from the graph.
18486 deleteAndRecombine(N);
18487
18488  // Replace the uses of Op with uses of the updated base value.
18489 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Op, 0),
18490 To: Result.getValue(R: IsLoad ? 1 : 0));
18491 deleteAndRecombine(N: Op);
18492 return true;
18493}
18494
18495/// Return the base-pointer arithmetic from an indexed \p LD.
18496SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
18497 ISD::MemIndexedMode AM = LD->getAddressingMode();
18498 assert(AM != ISD::UNINDEXED);
18499 SDValue BP = LD->getOperand(Num: 1);
18500 SDValue Inc = LD->getOperand(Num: 2);
18501
18502 // Some backends use TargetConstants for load offsets, but don't expect
18503 // TargetConstants in general ADD nodes. We can convert these constants into
18504 // regular Constants (if the constant is not opaque).
18505 assert((Inc.getOpcode() != ISD::TargetConstant ||
18506 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
18507 "Cannot split out indexing using opaque target constants");
18508 if (Inc.getOpcode() == ISD::TargetConstant) {
18509 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Val&: Inc);
18510 Inc = DAG.getConstant(Val: *ConstInc->getConstantIntValue(), DL: SDLoc(Inc),
18511 VT: ConstInc->getValueType(ResNo: 0));
18512 }
18513
18514 unsigned Opc =
18515 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
18516 return DAG.getNode(Opcode: Opc, DL: SDLoc(LD), VT: BP.getSimpleValueType(), N1: BP, N2: Inc);
18517}
18518
18519static inline ElementCount numVectorEltsOrZero(EVT T) {
18520 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(MinVal: 0);
18521}
18522
18523bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
18524 EVT STType = Val.getValueType();
18525 EVT STMemType = ST->getMemoryVT();
18526 if (STType == STMemType)
18527 return true;
18528 if (isTypeLegal(VT: STMemType))
18529 return false; // fail.
18530 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
18531 TLI.isOperationLegal(Op: ISD::FTRUNC, VT: STMemType)) {
18532 Val = DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(ST), VT: STMemType, Operand: Val);
18533 return true;
18534 }
18535 if (numVectorEltsOrZero(T: STType) == numVectorEltsOrZero(T: STMemType) &&
18536 STType.isInteger() && STMemType.isInteger()) {
18537 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ST), VT: STMemType, Operand: Val);
18538 return true;
18539 }
18540 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
18541 Val = DAG.getBitcast(VT: STMemType, V: Val);
18542 return true;
18543 }
18544 return false; // fail.
18545}
18546
18547bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
18548 EVT LDMemType = LD->getMemoryVT();
18549 EVT LDType = LD->getValueType(ResNo: 0);
18550 assert(Val.getValueType() == LDMemType &&
18551 "Attempting to extend value of non-matching type");
18552 if (LDType == LDMemType)
18553 return true;
18554 if (LDMemType.isInteger() && LDType.isInteger()) {
18555 switch (LD->getExtensionType()) {
18556 case ISD::NON_EXTLOAD:
18557 Val = DAG.getBitcast(VT: LDType, V: Val);
18558 return true;
18559 case ISD::EXTLOAD:
18560 Val = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18561 return true;
18562 case ISD::SEXTLOAD:
18563 Val = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18564 return true;
18565 case ISD::ZEXTLOAD:
18566 Val = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18567 return true;
18568 }
18569 }
18570 return false;
18571}
18572
18573StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
18574 int64_t &Offset) {
18575 SDValue Chain = LD->getOperand(Num: 0);
18576
18577 // Look through CALLSEQ_START.
18578 if (Chain.getOpcode() == ISD::CALLSEQ_START)
18579 Chain = Chain->getOperand(Num: 0);
18580
18581 StoreSDNode *ST = nullptr;
18582 SmallVector<SDValue, 8> Aliases;
18583 if (Chain.getOpcode() == ISD::TokenFactor) {
18584 // Look for unique store within the TokenFactor.
18585 for (SDValue Op : Chain->ops()) {
18586 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Op.getNode());
18587 if (!Store)
18588 continue;
18589 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
18590 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
18591 if (!BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
18592 continue;
18593 // Make sure the store is not aliased with any nodes in TokenFactor.
18594 GatherAllAliases(N: Store, OriginalChain: Chain, Aliases);
18595 if (Aliases.empty() ||
18596 (Aliases.size() == 1 && Aliases.front().getNode() == Store))
18597 ST = Store;
18598 break;
18599 }
18600 } else {
18601 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Chain.getNode());
18602 if (Store) {
18603 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
18604 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
18605 if (BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
18606 ST = Store;
18607 }
18608 }
18609
18610 return ST;
18611}
18612
18613SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
18614 if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
18615 return SDValue();
18616 SDValue Chain = LD->getOperand(Num: 0);
18617 int64_t Offset;
18618
18619 StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
18620 // TODO: Relax this restriction for unordered atomics (see D66309)
18621 if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
18622 return SDValue();
18623
18624 EVT LDType = LD->getValueType(ResNo: 0);
18625 EVT LDMemType = LD->getMemoryVT();
18626 EVT STMemType = ST->getMemoryVT();
18627 EVT STType = ST->getValue().getValueType();
18628
18629 // There are two cases to consider here:
18630 // 1. The store is fixed width and the load is scalable. In this case we
18631 // don't know at compile time if the store completely envelops the load
18632 // so we abandon the optimisation.
18633 // 2. The store is scalable and the load is fixed width. We could
18634 // potentially support a limited number of cases here, but there has been
18635 // no cost-benefit analysis to prove it's worth it.
18636 bool LdStScalable = LDMemType.isScalableVT();
18637 if (LdStScalable != STMemType.isScalableVT())
18638 return SDValue();
18639
18640 // If we are dealing with scalable vectors on a big endian platform the
18641 // calculation of offsets below becomes trickier, since we do not know at
18642 // compile time the absolute size of the vector. Until we've done more
18643 // analysis on big-endian platforms it seems better to bail out for now.
18644 if (LdStScalable && DAG.getDataLayout().isBigEndian())
18645 return SDValue();
18646
18647  // Normalize for endianness. After this, Offset=0 will denote that the least
18648  // significant bit in the loaded value maps to the least significant bit in
18649  // the stored value. With Offset=n (for n > 0) the loaded value starts at the
18650  // n-th least significant byte of the stored value.
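  // For example, an i32 store feeding an i8 load at the same address has a raw
  // Offset of 0; on a big-endian target that byte is the most significant one,
  // so the normalized Offset becomes (32 - 8) / 8 - 0 = 3.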
18651 int64_t OrigOffset = Offset;
18652 if (DAG.getDataLayout().isBigEndian())
18653 Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
18654 (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
18655 8 -
18656 Offset;
18657
18658  // Check that the stored value covers all bits that are loaded.
18659 bool STCoversLD;
18660
18661 TypeSize LdMemSize = LDMemType.getSizeInBits();
18662 TypeSize StMemSize = STMemType.getSizeInBits();
18663 if (LdStScalable)
18664 STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
18665 else
18666 STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
18667 StMemSize.getFixedValue());
18668
18669 auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
18670 if (LD->isIndexed()) {
18671 // Cannot handle opaque target constants and we must respect the user's
18672 // request not to split indexes from loads.
18673 if (!canSplitIdx(LD))
18674 return SDValue();
18675 SDValue Idx = SplitIndexingFromLoad(LD);
18676 SDValue Ops[] = {Val, Idx, Chain};
18677 return CombineTo(N: LD, To: Ops, NumTo: 3);
18678 }
18679 return CombineTo(N: LD, Res0: Val, Res1: Chain);
18680 };
18681
18682 if (!STCoversLD)
18683 return SDValue();
18684
18685 // Memory as copy space (potentially masked).
18686 if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
18687 // Simple case: Direct non-truncating forwarding
18688 if (LDType.getSizeInBits() == LdMemSize)
18689 return ReplaceLd(LD, ST->getValue(), Chain);
18690 // Can we model the truncate and extension with an and mask?
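    // For example, an i32 value stored with a truncating i8 store and reloaded
    // with a zero- or any-extending i8 load can be forwarded as
    // (and StoredVal, 0xff).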
18691 if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
18692 !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
18693 // Mask to size of LDMemType
18694 auto Mask =
18695 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: STType.getFixedSizeInBits(),
18696 loBitsSet: StMemSize.getFixedValue()),
18697 DL: SDLoc(ST), VT: STType);
18698 auto Val = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(LD), VT: LDType, N1: ST->getValue(), N2: Mask);
18699 return ReplaceLd(LD, Val, Chain);
18700 }
18701 }
18702
18703  // Handle some big-endian cases that would have Offset 0 (and be handled
18704  // above) on a little-endian target.
18705 SDValue Val = ST->getValue();
18706 if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
18707 if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
18708 !LDType.isVector() && isTypeLegal(VT: STType) &&
18709 TLI.isOperationLegal(Op: ISD::SRL, VT: STType)) {
18710 Val = DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(LD), VT: STType, N1: Val,
18711 N2: DAG.getConstant(Val: Offset * 8, DL: SDLoc(LD), VT: STType));
18712 Offset = 0;
18713 }
18714 }
18715
18716 // TODO: Deal with nonzero offset.
18717 if (LD->getBasePtr().isUndef() || Offset != 0)
18718 return SDValue();
18719  // Model necessary truncations / extensions.
18720  // Truncate the value to the stored memory size.
18721 do {
18722 if (!getTruncatedStoreValue(ST, Val))
18723 continue;
18724 if (!isTypeLegal(VT: LDMemType))
18725 continue;
18726 if (STMemType != LDMemType) {
18727 // TODO: Support vectors? This requires extract_subvector/bitcast.
18728 if (!STMemType.isVector() && !LDMemType.isVector() &&
18729 STMemType.isInteger() && LDMemType.isInteger())
18730 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LD), VT: LDMemType, Operand: Val);
18731 else
18732 continue;
18733 }
18734 if (!extendLoadedValueToExtension(LD, Val))
18735 continue;
18736 return ReplaceLd(LD, Val, Chain);
18737 } while (false);
18738
18739 // On failure, cleanup dead nodes we may have created.
18740 if (Val->use_empty())
18741 deleteAndRecombine(N: Val.getNode());
18742 return SDValue();
18743}
18744
18745SDValue DAGCombiner::visitLOAD(SDNode *N) {
18746 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
18747 SDValue Chain = LD->getChain();
18748 SDValue Ptr = LD->getBasePtr();
18749
18750  // If the load is not volatile and there are no uses of the loaded value (and
18751 // the updated indexed value in case of indexed loads), change uses of the
18752 // chain value into uses of the chain input (i.e. delete the dead load).
18753 // TODO: Allow this for unordered atomics (see D66309)
18754 if (LD->isSimple()) {
18755 if (N->getValueType(1) == MVT::Other) {
18756 // Unindexed loads.
18757 if (!N->hasAnyUseOfValue(Value: 0)) {
18758 // It's not safe to use the two value CombineTo variant here. e.g.
18759 // v1, chain2 = load chain1, loc
18760 // v2, chain3 = load chain2, loc
18761 // v3 = add v2, c
18762 // Now we replace use of chain2 with chain1. This makes the second load
18763 // isomorphic to the one we are deleting, and thus makes this load live.
18764 LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
18765 dbgs() << "\nWith chain: "; Chain.dump(&DAG);
18766 dbgs() << "\n");
18767 WorklistRemover DeadNodes(*this);
18768 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
18769 AddUsersToWorklist(N: Chain.getNode());
18770 if (N->use_empty())
18771 deleteAndRecombine(N);
18772
18773 return SDValue(N, 0); // Return N so it doesn't get rechecked!
18774 }
18775 } else {
18776 // Indexed loads.
18777 assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
18778
18779 // If this load has an opaque TargetConstant offset, then we cannot split
18780 // the indexing into an add/sub directly (that TargetConstant may not be
18781 // valid for a different type of node, and we cannot convert an opaque
18782 // target constant into a regular constant).
18783 bool CanSplitIdx = canSplitIdx(LD);
18784
18785 if (!N->hasAnyUseOfValue(Value: 0) && (CanSplitIdx || !N->hasAnyUseOfValue(Value: 1))) {
18786 SDValue Undef = DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
18787 SDValue Index;
18788 if (N->hasAnyUseOfValue(Value: 1) && CanSplitIdx) {
18789 Index = SplitIndexingFromLoad(LD);
18790 // Try to fold the base pointer arithmetic into subsequent loads and
18791 // stores.
18792 AddUsersToWorklist(N);
18793 } else
18794 Index = DAG.getUNDEF(VT: N->getValueType(ResNo: 1));
18795 LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
18796 dbgs() << "\nWith: "; Undef.dump(&DAG);
18797 dbgs() << " and 2 other values\n");
18798 WorklistRemover DeadNodes(*this);
18799 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Undef);
18800 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Index);
18801 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 2), To: Chain);
18802 deleteAndRecombine(N);
18803 return SDValue(N, 0); // Return N so it doesn't get rechecked!
18804 }
18805 }
18806 }
18807
18808 // If this load is directly stored, replace the load value with the stored
18809 // value.
18810 if (auto V = ForwardStoreValueToDirectLoad(LD))
18811 return V;
18812
18813 // Try to infer better alignment information than the load already has.
18814 if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
18815 !LD->isAtomic()) {
18816 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
18817 if (*Alignment > LD->getAlign() &&
18818 isAligned(Lhs: *Alignment, SizeInBytes: LD->getSrcValueOffset())) {
18819 SDValue NewLoad = DAG.getExtLoad(
18820 ExtType: LD->getExtensionType(), dl: SDLoc(N), VT: LD->getValueType(ResNo: 0), Chain, Ptr,
18821 PtrInfo: LD->getPointerInfo(), MemVT: LD->getMemoryVT(), Alignment: *Alignment,
18822 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
18823 // NewLoad will always be N as we are only refining the alignment
18824 assert(NewLoad.getNode() == N);
18825 (void)NewLoad;
18826 }
18827 }
18828 }
18829
18830 if (LD->isUnindexed()) {
18831 // Walk up chain skipping non-aliasing memory nodes.
18832 SDValue BetterChain = FindBetterChain(N: LD, Chain);
18833
18834 // If there is a better chain.
18835 if (Chain != BetterChain) {
18836 SDValue ReplLoad;
18837
18838      // Replace the chain to avoid the dependency on the old chain.
18839 if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
18840 ReplLoad = DAG.getLoad(VT: N->getValueType(ResNo: 0), dl: SDLoc(LD),
18841 Chain: BetterChain, Ptr, MMO: LD->getMemOperand());
18842 } else {
18843 ReplLoad = DAG.getExtLoad(ExtType: LD->getExtensionType(), dl: SDLoc(LD),
18844 VT: LD->getValueType(ResNo: 0),
18845 Chain: BetterChain, Ptr, MemVT: LD->getMemoryVT(),
18846 MMO: LD->getMemOperand());
18847 }
18848
18849 // Create token factor to keep old chain connected.
18850 SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
18851 MVT::Other, Chain, ReplLoad.getValue(1));
18852
18853 // Replace uses with load result and token factor
18854 return CombineTo(N, Res0: ReplLoad.getValue(R: 0), Res1: Token);
18855 }
18856 }
18857
18858 // Try transforming N to an indexed load.
18859 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
18860 return SDValue(N, 0);
18861
18862 // Try to slice up N to more direct loads if the slices are mapped to
18863 // different register banks or pairing can take place.
18864 if (SliceUpLoad(N))
18865 return SDValue(N, 0);
18866
18867 return SDValue();
18868}
18869
18870namespace {
18871
18872/// Helper structure used to slice a load into smaller loads.
18873/// Basically a slice is obtained from the following sequence:
18874/// Origin = load Ty1, Base
18875/// Shift = srl Ty1 Origin, CstTy Amount
18876/// Inst = trunc Shift to Ty2
18877///
18878/// Then, it will be rewritten into:
18879/// Slice = load SliceTy, Base + SliceOffset
18880/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
18881///
18882/// SliceTy is deduced from the number of bits that are actually used to
18883/// build Inst.
18884struct LoadedSlice {
18885 /// Helper structure used to compute the cost of a slice.
18886 struct Cost {
18887 /// Are we optimizing for code size.
18888 bool ForCodeSize = false;
18889
18890    /// Various costs.
18891 unsigned Loads = 0;
18892 unsigned Truncates = 0;
18893 unsigned CrossRegisterBanksCopies = 0;
18894 unsigned ZExts = 0;
18895 unsigned Shift = 0;
18896
18897 explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
18898
18899 /// Get the cost of one isolated slice.
18900 Cost(const LoadedSlice &LS, bool ForCodeSize)
18901 : ForCodeSize(ForCodeSize), Loads(1) {
18902 EVT TruncType = LS.Inst->getValueType(ResNo: 0);
18903 EVT LoadedType = LS.getLoadedType();
18904 if (TruncType != LoadedType &&
18905 !LS.DAG->getTargetLoweringInfo().isZExtFree(FromTy: LoadedType, ToTy: TruncType))
18906 ZExts = 1;
18907 }
18908
18909 /// Account for slicing gain in the current cost.
18910    /// Slicing provides a few gains, like removing a shift or a
18911    /// truncate. This method allows growing the cost of the original
18912 /// load with the gain from this slice.
18913 void addSliceGain(const LoadedSlice &LS) {
18914 // Each slice saves a truncate.
18915 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
18916 if (!TLI.isTruncateFree(Val: LS.Inst->getOperand(Num: 0), VT2: LS.Inst->getValueType(ResNo: 0)))
18917 ++Truncates;
18918 // If there is a shift amount, this slice gets rid of it.
18919 if (LS.Shift)
18920 ++Shift;
18921 // If this slice can merge a cross register bank copy, account for it.
18922 if (LS.canMergeExpensiveCrossRegisterBankCopy())
18923 ++CrossRegisterBanksCopies;
18924 }
18925
18926 Cost &operator+=(const Cost &RHS) {
18927 Loads += RHS.Loads;
18928 Truncates += RHS.Truncates;
18929 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
18930 ZExts += RHS.ZExts;
18931 Shift += RHS.Shift;
18932 return *this;
18933 }
18934
18935 bool operator==(const Cost &RHS) const {
18936 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
18937 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
18938 ZExts == RHS.ZExts && Shift == RHS.Shift;
18939 }
18940
18941 bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
18942
18943 bool operator<(const Cost &RHS) const {
18944 // Assume cross register banks copies are as expensive as loads.
18945 // FIXME: Do we want some more target hooks?
18946 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
18947 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
18948 // Unless we are optimizing for code size, consider the
18949 // expensive operation first.
18950 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
18951 return ExpensiveOpsLHS < ExpensiveOpsRHS;
18952 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
18953 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
18954 }
18955
18956 bool operator>(const Cost &RHS) const { return RHS < *this; }
18957
18958 bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
18959
18960 bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
18961 };
18962
18963  // The last instruction that represents the slice. This should be a
18964 // truncate instruction.
18965 SDNode *Inst;
18966
18967 // The original load instruction.
18968 LoadSDNode *Origin;
18969
18970 // The right shift amount in bits from the original load.
18971 unsigned Shift;
18972
18973  // The DAG from which Origin came.
18974 // This is used to get some contextual information about legal types, etc.
18975 SelectionDAG *DAG;
18976
18977 LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
18978 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
18979 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
18980
18981 /// Get the bits used in a chunk of bits \p BitWidth large.
18982 /// \return Result is \p BitWidth and has used bits set to 1 and
18983 /// not used bits set to 0.
18984 APInt getUsedBits() const {
18985 // Reproduce the trunc(lshr) sequence:
18986 // - Start from the truncated value.
18987 // - Zero extend to the desired bit width.
18988 // - Shift left.
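    //
    // For example, an i8 truncate of an i32 load with Shift == 16 yields
    // UsedBits == 0x00ff0000.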
18989 assert(Origin && "No original load to compare against.");
18990 unsigned BitWidth = Origin->getValueSizeInBits(ResNo: 0);
18991 assert(Inst && "This slice is not bound to an instruction");
18992 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
18993 "Extracted slice is bigger than the whole type!");
18994 APInt UsedBits(Inst->getValueSizeInBits(ResNo: 0), 0);
18995 UsedBits.setAllBits();
18996 UsedBits = UsedBits.zext(width: BitWidth);
18997 UsedBits <<= Shift;
18998 return UsedBits;
18999 }
19000
19001 /// Get the size of the slice to be loaded in bytes.
19002 unsigned getLoadedSize() const {
19003 unsigned SliceSize = getUsedBits().popcount();
19004 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
19005 return SliceSize / 8;
19006 }
19007
19008 /// Get the type that will be loaded for this slice.
19009 /// Note: This may not be the final type for the slice.
19010 EVT getLoadedType() const {
19011 assert(DAG && "Missing context");
19012 LLVMContext &Ctxt = *DAG->getContext();
19013 return EVT::getIntegerVT(Context&: Ctxt, BitWidth: getLoadedSize() * 8);
19014 }
19015
19016 /// Get the alignment of the load used for this slice.
19017 Align getAlign() const {
19018 Align Alignment = Origin->getAlign();
19019 uint64_t Offset = getOffsetFromBase();
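    // For example, an 8-byte aligned origin load with a slice at byte offset 4
    // yields a 4-byte aligned slice load.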
19020 if (Offset != 0)
19021 Alignment = commonAlignment(A: Alignment, Offset: Alignment.value() + Offset);
19022 return Alignment;
19023 }
19024
19025 /// Check if this slice can be rewritten with legal operations.
19026 bool isLegal() const {
19027 // An invalid slice is not legal.
19028 if (!Origin || !Inst || !DAG)
19029 return false;
19030
19031    // Offsets are for indexed loads only; we do not handle that.
19032 if (!Origin->getOffset().isUndef())
19033 return false;
19034
19035 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
19036
19037 // Check that the type is legal.
19038 EVT SliceType = getLoadedType();
19039 if (!TLI.isTypeLegal(VT: SliceType))
19040 return false;
19041
19042 // Check that the load is legal for this type.
19043 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: SliceType))
19044 return false;
19045
19046 // Check that the offset can be computed.
19047 // 1. Check its type.
19048 EVT PtrType = Origin->getBasePtr().getValueType();
19049 if (PtrType == MVT::Untyped || PtrType.isExtended())
19050 return false;
19051
19052 // 2. Check that it fits in the immediate.
19053 if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
19054 return false;
19055
19056 // 3. Check that the computation is legal.
19057 if (!TLI.isOperationLegal(Op: ISD::ADD, VT: PtrType))
19058 return false;
19059
19060 // Check that the zext is legal if it needs one.
19061 EVT TruncateType = Inst->getValueType(ResNo: 0);
19062 if (TruncateType != SliceType &&
19063 !TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: TruncateType))
19064 return false;
19065
19066 return true;
19067 }
19068
19069 /// Get the offset in bytes of this slice in the original chunk of
19070 /// bits.
19071 /// \pre DAG != nullptr.
19072 uint64_t getOffsetFromBase() const {
19073 assert(DAG && "Missing context.");
19074 bool IsBigEndian = DAG->getDataLayout().isBigEndian();
19075 assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
19076 uint64_t Offset = Shift / 8;
19077 unsigned TySizeInBytes = Origin->getValueSizeInBits(ResNo: 0) / 8;
19078 assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
19079 "The size of the original loaded type is not a multiple of a"
19080 " byte.");
19081 // If Offset is bigger than TySizeInBytes, it means we are loading all
19082 // zeros. This should have been optimized away earlier in the process.
19083 assert(TySizeInBytes > Offset &&
19084 "Invalid shift amount for given loaded size");
19085 if (IsBigEndian)
19086 Offset = TySizeInBytes - Offset - getLoadedSize();
19087 return Offset;
19088 }
19089
19090 /// Generate the sequence of instructions to load the slice
19091 /// represented by this object and redirect the uses of this slice to
19092 /// this new sequence of instructions.
19093 /// \pre this->Inst && this->Origin are valid Instructions and this
19094 /// object passed the legal check: LoadedSlice::isLegal returned true.
19095 /// \return The last instruction of the sequence used to load the slice.
19096 SDValue loadSlice() const {
19097 assert(Inst && Origin && "Unable to replace a non-existing slice.");
19098 const SDValue &OldBaseAddr = Origin->getBasePtr();
19099 SDValue BaseAddr = OldBaseAddr;
19100 // Get the offset in that chunk of bytes w.r.t. the endianness.
19101 int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
19102 assert(Offset >= 0 && "Offset too big to fit in int64_t!");
19103 if (Offset) {
19104 // BaseAddr = BaseAddr + Offset.
19105 EVT ArithType = BaseAddr.getValueType();
19106 SDLoc DL(Origin);
19107 BaseAddr = DAG->getNode(Opcode: ISD::ADD, DL, VT: ArithType, N1: BaseAddr,
19108 N2: DAG->getConstant(Val: Offset, DL, VT: ArithType));
19109 }
19110
19111 // Create the type of the loaded slice according to its size.
19112 EVT SliceType = getLoadedType();
19113
19114 // Create the load for the slice.
19115 SDValue LastInst =
19116 DAG->getLoad(VT: SliceType, dl: SDLoc(Origin), Chain: Origin->getChain(), Ptr: BaseAddr,
19117 PtrInfo: Origin->getPointerInfo().getWithOffset(O: Offset), Alignment: getAlign(),
19118 MMOFlags: Origin->getMemOperand()->getFlags());
19119 // If the final type is not the same as the loaded type, this means that
19120 // we have to pad with zero. Create a zero extend for that.
19121 EVT FinalType = Inst->getValueType(ResNo: 0);
19122 if (SliceType != FinalType)
19123 LastInst =
19124 DAG->getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LastInst), VT: FinalType, Operand: LastInst);
19125 return LastInst;
19126 }
19127
19128 /// Check if this slice can be merged with an expensive cross register
19129 /// bank copy. E.g.,
19130 /// i = load i32
19131 /// f = bitcast i32 i to float
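/// If the slice can instead be loaded directly into the destination register
/// bank (e.g. as a plain f32 load), the potentially expensive cross-bank copy
/// disappears, which is what this predicate tries to detect.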
19132 bool canMergeExpensiveCrossRegisterBankCopy() const {
19133 if (!Inst || !Inst->hasOneUse())
19134 return false;
19135 SDNode *Use = *Inst->use_begin();
19136 if (Use->getOpcode() != ISD::BITCAST)
19137 return false;
19138 assert(DAG && "Missing context");
19139 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
19140 EVT ResVT = Use->getValueType(ResNo: 0);
19141 const TargetRegisterClass *ResRC =
19142 TLI.getRegClassFor(VT: ResVT.getSimpleVT(), isDivergent: Use->isDivergent());
19143 const TargetRegisterClass *ArgRC =
19144 TLI.getRegClassFor(VT: Use->getOperand(Num: 0).getValueType().getSimpleVT(),
19145 isDivergent: Use->getOperand(Num: 0)->isDivergent());
19146 if (ArgRC == ResRC || !TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
19147 return false;
19148
19149 // At this point, we know that we perform a cross-register-bank copy.
19150 // Check if it is expensive.
19151 const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
19152 // Assume bitcasts are cheap unless the two register classes do not
19153 // explicitly share a common subclass.
19154 if (!TRI || TRI->getCommonSubClass(A: ArgRC, B: ResRC))
19155 return false;
19156
19157 // Check if it will be merged with the load.
19158 // 1. Check the alignment / fast memory access constraint.
19159 unsigned IsFast = 0;
19160 if (!TLI.allowsMemoryAccess(Context&: *DAG->getContext(), DL: DAG->getDataLayout(), VT: ResVT,
19161 AddrSpace: Origin->getAddressSpace(), Alignment: getAlign(),
19162 Flags: Origin->getMemOperand()->getFlags(), Fast: &IsFast) ||
19163 !IsFast)
19164 return false;
19165
19166 // 2. Check that the load is a legal operation for that type.
19167 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
19168 return false;
19169
19170 // 3. Check that we do not have a zext in the way.
19171 if (Inst->getValueType(ResNo: 0) != getLoadedType())
19172 return false;
19173
19174 return true;
19175 }
19176};
19177
19178} // end anonymous namespace
19179
19180/// Check that all bits set in \p UsedBits form a dense region, i.e.,
19181/// \p UsedBits looks like 0..0 1..1 0..0.
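/// For example, 0x00FF0000 is dense, while 0x00FF00FF is not.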
19182static bool areUsedBitsDense(const APInt &UsedBits) {
19183 // If all the bits are one, this is dense!
19184 if (UsedBits.isAllOnes())
19185 return true;
19186
19187 // Get rid of the unused bits on the right.
19188 APInt NarrowedUsedBits = UsedBits.lshr(shiftAmt: UsedBits.countr_zero());
19189 // Get rid of the unused bits on the left.
19190 if (NarrowedUsedBits.countl_zero())
19191 NarrowedUsedBits = NarrowedUsedBits.trunc(width: NarrowedUsedBits.getActiveBits());
19192 // Check that the chunk of bits is completely used.
19193 return NarrowedUsedBits.isAllOnes();
19194}
19195
19196/// Check whether or not \p First and \p Second are next to each other
19197/// in memory. This means that there is no hole between the bits loaded
19198/// by \p First and the bits loaded by \p Second.
19199static bool areSlicesNextToEachOther(const LoadedSlice &First,
19200 const LoadedSlice &Second) {
19201 assert(First.Origin == Second.Origin && First.Origin &&
19202 "Unable to match different memory origins.");
19203 APInt UsedBits = First.getUsedBits();
19204 assert((UsedBits & Second.getUsedBits()) == 0 &&
19205 "Slices are not supposed to overlap.");
19206 UsedBits |= Second.getUsedBits();
19207 return areUsedBitsDense(UsedBits);
19208}
19209
19210/// Adjust the \p GlobalLSCost according to the target's
19211/// pairing capabilities and the layout of the slices.
19212/// \pre \p GlobalLSCost should account for at least as many loads as
19213/// there are slices in \p LoadedSlices.
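/// For instance, on a target that provides paired loads for the slice type,
/// two adjacent slices of the same type may be folded into one paired load;
/// this is modelled here by decrementing the load count in \p GlobalLSCost.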
19214static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
19215 LoadedSlice::Cost &GlobalLSCost) {
19216 unsigned NumberOfSlices = LoadedSlices.size();
19217 // If there are fewer than 2 elements, no pairing is possible.
19218 if (NumberOfSlices < 2)
19219 return;
19220
19221 // Sort the slices so that elements that are likely to be next to each
19222 // other in memory are next to each other in the list.
19223 llvm::sort(C&: LoadedSlices, Comp: [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
19224 assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
19225 return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
19226 });
19227 const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
19228 // First (resp. Second) is the first (resp. second) potential candidate
19229 // to be placed in a paired load.
19230 const LoadedSlice *First = nullptr;
19231 const LoadedSlice *Second = nullptr;
19232 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
19233 // Set the beginning of the pair.
19234 First = Second) {
19235 Second = &LoadedSlices[CurrSlice];
19236
19237 // If First is NULL, it means we start a new pair.
19238 // Get to the next slice.
19239 if (!First)
19240 continue;
19241
19242 EVT LoadedType = First->getLoadedType();
19243
19244 // If the types of the slices are different, we cannot pair them.
19245 if (LoadedType != Second->getLoadedType())
19246 continue;
19247
19248 // Check if the target supplies paired loads for this type.
19249 Align RequiredAlignment;
19250 if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
19251 // move to the next pair, this type is hopeless.
19252 Second = nullptr;
19253 continue;
19254 }
19255 // Check if we meet the alignment requirement.
19256 if (First->getAlign() < RequiredAlignment)
19257 continue;
19258
19259 // Check that both loads are next to each other in memory.
19260 if (!areSlicesNextToEachOther(First: *First, Second: *Second))
19261 continue;
19262
19263 assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
19264 --GlobalLSCost.Loads;
19265 // Move to the next pair.
19266 Second = nullptr;
19267 }
19268}
19269
19270/// Check the profitability of all involved LoadedSlices.
19271/// Currently, it is considered profitable if there are exactly two
19272/// involved slices (1) which are (2) next to each other in memory, and
19273/// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
19274///
19275/// Note: The order of the elements in \p LoadedSlices may be modified, but not
19276/// the elements themselves.
19277///
19278/// FIXME: When the cost model is mature enough, we can relax
19279/// constraints (1) and (2).
19280static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
19281 const APInt &UsedBits, bool ForCodeSize) {
19282 unsigned NumberOfSlices = LoadedSlices.size();
19283 if (StressLoadSlicing)
19284 return NumberOfSlices > 1;
19285
19286 // Check (1).
19287 if (NumberOfSlices != 2)
19288 return false;
19289
19290 // Check (2).
19291 if (!areUsedBitsDense(UsedBits))
19292 return false;
19293
19294 // Check (3).
19295 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
19296 // The original code has one big load.
19297 OrigCost.Loads = 1;
19298 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
19299 const LoadedSlice &LS = LoadedSlices[CurrSlice];
19300 // Accumulate the cost of all the slices.
19301 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
19302 GlobalSlicingCost += SliceCost;
19303
19304 // Account, in the original configuration's cost, for the gain obtained
19305 // with the current slice.
19306 OrigCost.addSliceGain(LS);
19307 }
19308
19309 // If the target supports paired load, adjust the cost accordingly.
19310 adjustCostForPairing(LoadedSlices, GlobalLSCost&: GlobalSlicingCost);
19311 return OrigCost > GlobalSlicingCost;
19312}
19313
19314/// If the given load, \p LI, is used only by trunc or trunc(lshr)
19315/// operations, split it in the various pieces being extracted.
19316///
19317/// This sort of thing is introduced by SROA.
19318/// This slicing takes care not to insert overlapping loads.
19319/// \pre LI is a simple load (i.e., not an atomic or volatile load).
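/// As an illustration (assuming a little-endian target and legal i32 loads),
///   x = load i64 %p
///   a = trunc x to i32
///   b = trunc (srl x, 32) to i32
/// may be rewritten as two independent loads:
///   a = load i32 %p
///   b = load i32 (%p + 4)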
19320bool DAGCombiner::SliceUpLoad(SDNode *N) {
19321 if (Level < AfterLegalizeDAG)
19322 return false;
19323
19324 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
19325 if (!LD->isSimple() || !ISD::isNormalLoad(N: LD) ||
19326 !LD->getValueType(ResNo: 0).isInteger())
19327 return false;
19328
19329 // The algorithm to split up a load of a scalable vector into individual
19330 // elements currently requires knowing the length of the loaded type,
19331 // so will need adjusting to work on scalable vectors.
19332 if (LD->getValueType(ResNo: 0).isScalableVector())
19333 return false;
19334
19335 // Keep track of already used bits to detect overlapping values.
19336 // In that case, we will just abort the transformation.
19337 APInt UsedBits(LD->getValueSizeInBits(ResNo: 0), 0);
19338
19339 SmallVector<LoadedSlice, 4> LoadedSlices;
19340
19341 // Check if this load is used as several smaller chunks of bits.
19342 // Basically, look for uses in trunc or trunc(lshr) and record a new chain
19343 // of computation for each trunc.
19344 for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
19345 UI != UIEnd; ++UI) {
19346 // Skip the uses of the chain.
19347 if (UI.getUse().getResNo() != 0)
19348 continue;
19349
19350 SDNode *User = *UI;
19351 unsigned Shift = 0;
19352
19353 // Check if this is a trunc(lshr).
19354 if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
19355 isa<ConstantSDNode>(Val: User->getOperand(Num: 1))) {
19356 Shift = User->getConstantOperandVal(Num: 1);
19357 User = *User->use_begin();
19358 }
19359
19360 // At this point, User is a TRUNCATE iff we encountered a trunc or
19361 // trunc(lshr) pattern.
19362 if (User->getOpcode() != ISD::TRUNCATE)
19363 return false;
19364
19365 // The width of the type must be a power of 2 and at least 8 bits.
19366 // Otherwise the load cannot be represented in LLVM IR.
19367 // Moreover, if the shift amount is not a multiple of 8 bits, the slice
19368 // would straddle byte boundaries. We do not support that.
19369 unsigned Width = User->getValueSizeInBits(ResNo: 0);
19370 if (Width < 8 || !isPowerOf2_32(Value: Width) || (Shift & 0x7))
19371 return false;
19372
19373 // Build the slice for this chain of computations.
19374 LoadedSlice LS(User, LD, Shift, &DAG);
19375 APInt CurrentUsedBits = LS.getUsedBits();
19376
19377 // Check if this slice overlaps with another.
19378 if ((CurrentUsedBits & UsedBits) != 0)
19379 return false;
19380 // Update the bits used globally.
19381 UsedBits |= CurrentUsedBits;
19382
19383 // Check if the new slice would be legal.
19384 if (!LS.isLegal())
19385 return false;
19386
19387 // Record the slice.
19388 LoadedSlices.push_back(Elt: LS);
19389 }
19390
19391 // Abort slicing if it does not seem to be profitable.
19392 if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
19393 return false;
19394
19395 ++SlicedLoads;
19396
19397 // Rewrite each chain to use an independent load.
19398 // By construction, each chain can be represented by a unique load.
19399
19400 // Prepare the argument for the new token factor for all the slices.
19401 SmallVector<SDValue, 8> ArgChains;
19402 for (const LoadedSlice &LS : LoadedSlices) {
19403 SDValue SliceInst = LS.loadSlice();
19404 CombineTo(N: LS.Inst, Res: SliceInst, AddTo: true);
19405 if (SliceInst.getOpcode() != ISD::LOAD)
19406 SliceInst = SliceInst.getOperand(i: 0);
19407 assert(SliceInst->getOpcode() == ISD::LOAD &&
19408 "It takes more than a zext to get to the loaded slice!!");
19409 ArgChains.push_back(Elt: SliceInst.getValue(R: 1));
19410 }
19411
19412 SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
19413 ArgChains);
19414 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
19415 AddToWorklist(N: Chain.getNode());
19416 return true;
19417}
19418
19419/// Check to see if V is (and (load ptr), imm), where the load has
19420/// specific bytes cleared out. If so, return the number of bytes being
19421/// masked out and the byte offset (shift amount) at which they start.
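/// For example, (and (load i32 p), 0xFFFF00FF) clears byte 1 of the loaded
/// value, so the result would be {1, 1}: one byte masked out, starting at
/// byte offset 1.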
19422static std::pair<unsigned, unsigned>
19423CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
19424 std::pair<unsigned, unsigned> Result(0, 0);
19425
19426 // Check for the structure we're looking for.
19427 if (V->getOpcode() != ISD::AND ||
19428 !isa<ConstantSDNode>(Val: V->getOperand(Num: 1)) ||
19429 !ISD::isNormalLoad(N: V->getOperand(Num: 0).getNode()))
19430 return Result;
19431
19432 // Check the chain and pointer.
19433 LoadSDNode *LD = cast<LoadSDNode>(Val: V->getOperand(Num: 0));
19434 if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
19435
19436 // This only handles simple types.
19437 if (V.getValueType() != MVT::i16 &&
19438 V.getValueType() != MVT::i32 &&
19439 V.getValueType() != MVT::i64)
19440 return Result;
19441
19442 // Check the constant mask. Invert it so that the bits being masked out are
19443 // 0 and the bits being kept are 1. Use getSExtValue so that leading bits
19444 // follow the sign bit for uniformity.
19445 uint64_t NotMask = ~cast<ConstantSDNode>(Val: V->getOperand(Num: 1))->getSExtValue();
19446 unsigned NotMaskLZ = llvm::countl_zero(Val: NotMask);
19447 if (NotMaskLZ & 7) return Result; // Must be multiple of a byte.
19448 unsigned NotMaskTZ = llvm::countr_zero(Val: NotMask);
19449 if (NotMaskTZ & 7) return Result; // Must be multiple of a byte.
19450 if (NotMaskLZ == 64) return Result; // All zero mask.
19451
19452 // See if we have a continuous run of bits. If so, we have 0*1+0*
19453 if (llvm::countr_one(Value: NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
19454 return Result;
19455
19456 // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
19457 if (V.getValueType() != MVT::i64 && NotMaskLZ)
19458 NotMaskLZ -= 64-V.getValueSizeInBits();
19459
19460 unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
19461 switch (MaskedBytes) {
19462 case 1:
19463 case 2:
19464 case 4: break;
19465 default: return Result; // All one mask, or 5-byte mask.
19466 }
19467
19468 // Verify that the masked region starts at a byte offset that is a multiple
19469 // of its width, so that the narrowed access is naturally aligned.
19470 if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
19471
19472 // For narrowing to be valid, the load must be the memory operation that
19473 // immediately precedes the store.
19474 if (LD == Chain.getNode())
19475 ; // ok.
19476 else if (Chain->getOpcode() == ISD::TokenFactor &&
19477 SDValue(LD, 1).hasOneUse()) {
19478 // LD has only 1 chain use, so there are no indirect dependencies.
19479 if (!LD->isOperandOf(N: Chain.getNode()))
19480 return Result;
19481 } else
19482 return Result; // Fail.
19483
19484 Result.first = MaskedBytes;
19485 Result.second = NotMaskTZ/8;
19486 return Result;
19487}
19488
19489/// Check to see if IVal is something that provides a value as specified by
19490/// MaskInfo. If so, replace the specified store with a narrower store of
19491/// truncated IVal.
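/// For instance, with MaskInfo == {1, 1} on a little-endian target, storing
/// (or (and (load p), 0xFFFF00FF), IVal) can become a one-byte store of
/// (trunc (srl IVal, 8)) at address p + 1, provided IVal is zero outside
/// byte 1.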
19492static SDValue
19493ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
19494 SDValue IVal, StoreSDNode *St,
19495 DAGCombiner *DC) {
19496 unsigned NumBytes = MaskInfo.first;
19497 unsigned ByteShift = MaskInfo.second;
19498 SelectionDAG &DAG = DC->getDAG();
19499
19500 // Check to see if IVal is all zeros in the part being masked in by the 'or'
19501 // that uses this. If not, this is not a replacement.
19502 APInt Mask = ~APInt::getBitsSet(numBits: IVal.getValueSizeInBits(),
19503 loBit: ByteShift*8, hiBit: (ByteShift+NumBytes)*8);
19504 if (!DAG.MaskedValueIsZero(Op: IVal, Mask)) return SDValue();
19505
19506 // Check that it is legal on the target to do this. It is legal if the new
19507 // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
19508 // legalization. If the source type is legal, but the store type isn't, see
19509 // if we can use a truncating store.
19510 MVT VT = MVT::getIntegerVT(BitWidth: NumBytes * 8);
19511 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19512 bool UseTruncStore;
19513 if (DC->isTypeLegal(VT))
19514 UseTruncStore = false;
19515 else if (TLI.isTypeLegal(VT: IVal.getValueType()) &&
19516 TLI.isTruncStoreLegal(ValVT: IVal.getValueType(), MemVT: VT))
19517 UseTruncStore = true;
19518 else
19519 return SDValue();
19520
19521 // Can't do this for indexed stores.
19522 if (St->isIndexed())
19523 return SDValue();
19524
19525 // Check that the target doesn't think this is a bad idea.
19526 if (St->getMemOperand() &&
19527 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
19528 MMO: *St->getMemOperand()))
19529 return SDValue();
19530
19531 // Okay, we can do this! Replace the 'St' store with a store of IVal that is
19532 // shifted by ByteShift and truncated down to NumBytes.
19533 if (ByteShift) {
19534 SDLoc DL(IVal);
19535 IVal = DAG.getNode(Opcode: ISD::SRL, DL, VT: IVal.getValueType(), N1: IVal,
19536 N2: DAG.getConstant(Val: ByteShift*8, DL,
19537 VT: DC->getShiftAmountTy(LHSTy: IVal.getValueType())));
19538 }
19539
19540 // Figure out the offset for the store and the alignment of the access.
19541 unsigned StOffset;
19542 if (DAG.getDataLayout().isLittleEndian())
19543 StOffset = ByteShift;
19544 else
19545 StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
19546
19547 SDValue Ptr = St->getBasePtr();
19548 if (StOffset) {
19549 SDLoc DL(IVal);
19550 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: StOffset), DL);
19551 }
19552
19553 ++OpsNarrowed;
19554 if (UseTruncStore)
19555 return DAG.getTruncStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
19556 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset),
19557 SVT: VT, Alignment: St->getOriginalAlign());
19558
19559 // Truncate down to the new size.
19560 IVal = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(IVal), VT, Operand: IVal);
19561
19562 return DAG
19563 .getStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
19564 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset),
19565 Alignment: St->getOriginalAlign());
19566}
19567
19568/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
19569/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
19570/// narrowing the load and store if it would end up being a win for performance
19571/// or code size.
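/// For example, on a little-endian target,
///   store (or (load i32 p), 0x00FF0000), p
/// may be narrowed to an i8 load/or/store of 0xFF at address p + 2, since
/// only byte 2 is affected.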
19572SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
19573 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
19574 if (!ST->isSimple())
19575 return SDValue();
19576
19577 SDValue Chain = ST->getChain();
19578 SDValue Value = ST->getValue();
19579 SDValue Ptr = ST->getBasePtr();
19580 EVT VT = Value.getValueType();
19581
19582 if (ST->isTruncatingStore() || VT.isVector())
19583 return SDValue();
19584
19585 unsigned Opc = Value.getOpcode();
19586
19587 if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
19588 !Value.hasOneUse())
19589 return SDValue();
19590
19591 // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
19592 // is a byte mask indicating a consecutive number of bytes, check to see if
19593 // Y is known to provide just those bytes. If so, we try to replace the
19594 // load + or + store sequence with a single (narrower) store, which makes
19595 // the load dead.
19596 if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
19597 std::pair<unsigned, unsigned> MaskedLoad;
19598 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 0), Ptr, Chain);
19599 if (MaskedLoad.first)
19600 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
19601 IVal: Value.getOperand(i: 1), St: ST,DC: this))
19602 return NewST;
19603
19604 // Or is commutative, so try swapping X and Y.
19605 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 1), Ptr, Chain);
19606 if (MaskedLoad.first)
19607 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
19608 IVal: Value.getOperand(i: 0), St: ST,DC: this))
19609 return NewST;
19610 }
19611
19612 if (!EnableReduceLoadOpStoreWidth)
19613 return SDValue();
19614
19615 if (Value.getOperand(i: 1).getOpcode() != ISD::Constant)
19616 return SDValue();
19617
19618 SDValue N0 = Value.getOperand(i: 0);
19619 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
19620 Chain == SDValue(N0.getNode(), 1)) {
19621 LoadSDNode *LD = cast<LoadSDNode>(Val&: N0);
19622 if (LD->getBasePtr() != Ptr ||
19623 LD->getPointerInfo().getAddrSpace() !=
19624 ST->getPointerInfo().getAddrSpace())
19625 return SDValue();
19626
19627 // Find the type to narrow the load / op / store to.
19628 SDValue N1 = Value.getOperand(i: 1);
19629 unsigned BitWidth = N1.getValueSizeInBits();
19630 APInt Imm = N1->getAsAPIntVal();
19631 if (Opc == ISD::AND)
19632 Imm ^= APInt::getAllOnes(numBits: BitWidth);
19633 if (Imm == 0 || Imm.isAllOnes())
19634 return SDValue();
19635 unsigned ShAmt = Imm.countr_zero();
19636 unsigned MSB = BitWidth - Imm.countl_zero() - 1;
19637 unsigned NewBW = NextPowerOf2(A: MSB - ShAmt);
19638 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
19639 // The narrowing should be profitable, the load/store operation should be
19640 // legal (or custom) and the store size should be equal to the NewVT width.
19641 while (NewBW < BitWidth &&
19642 (NewVT.getStoreSizeInBits() != NewBW ||
19643 !TLI.isOperationLegalOrCustom(Op: Opc, VT: NewVT) ||
19644 !TLI.isNarrowingProfitable(SrcVT: VT, DestVT: NewVT))) {
19645 NewBW = NextPowerOf2(A: NewBW);
19646 NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
19647 }
19648 if (NewBW >= BitWidth)
19649 return SDValue();
19650
19651 // If the lowest changed bit does not start at a NewBW-bit boundary,
19652 // start at the previous boundary.
19653 if (ShAmt % NewBW)
19654 ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
19655 APInt Mask = APInt::getBitsSet(numBits: BitWidth, loBit: ShAmt,
19656 hiBit: std::min(a: BitWidth, b: ShAmt + NewBW));
19657 if ((Imm & Mask) == Imm) {
19658 APInt NewImm = (Imm & Mask).lshr(shiftAmt: ShAmt).trunc(width: NewBW);
19659 if (Opc == ISD::AND)
19660 NewImm ^= APInt::getAllOnes(numBits: NewBW);
19661 uint64_t PtrOff = ShAmt / 8;
19662 // For big-endian targets, we need to adjust the pointer offset so that we
19663 // load the correct bytes.
19664 if (DAG.getDataLayout().isBigEndian())
19665 PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
19666
19667 unsigned IsFast = 0;
19668 Align NewAlign = commonAlignment(A: LD->getAlign(), Offset: PtrOff);
19669 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: NewVT,
19670 AddrSpace: LD->getAddressSpace(), Alignment: NewAlign,
19671 Flags: LD->getMemOperand()->getFlags(), Fast: &IsFast) ||
19672 !IsFast)
19673 return SDValue();
19674
19675 SDValue NewPtr =
19676 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: PtrOff), DL: SDLoc(LD));
19677 SDValue NewLD =
19678 DAG.getLoad(VT: NewVT, dl: SDLoc(N0), Chain: LD->getChain(), Ptr: NewPtr,
19679 PtrInfo: LD->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
19680 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
19681 SDValue NewVal = DAG.getNode(Opcode: Opc, DL: SDLoc(Value), VT: NewVT, N1: NewLD,
19682 N2: DAG.getConstant(Val: NewImm, DL: SDLoc(Value),
19683 VT: NewVT));
19684 SDValue NewST =
19685 DAG.getStore(Chain, dl: SDLoc(N), Val: NewVal, Ptr: NewPtr,
19686 PtrInfo: ST->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign);
19687
19688 AddToWorklist(N: NewPtr.getNode());
19689 AddToWorklist(N: NewLD.getNode());
19690 AddToWorklist(N: NewVal.getNode());
19691 WorklistRemover DeadNodes(*this);
19692 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLD.getValue(R: 1));
19693 ++OpsNarrowed;
19694 return NewST;
19695 }
19696 }
19697
19698 return SDValue();
19699}
19700
19701/// For a given floating point load / store pair, if the load value isn't used
19702/// by any other operations, then consider transforming the pair to integer
19703/// load / store operations if the target deems the transformation profitable.
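/// For example, an f64 value that is only loaded and immediately stored again
/// may be moved as an i64 instead, avoiding a round trip through the FP
/// register file on targets where that is more expensive.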
19704SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
19705 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
19706 SDValue Value = ST->getValue();
19707 if (ISD::isNormalStore(N: ST) && ISD::isNormalLoad(N: Value.getNode()) &&
19708 Value.hasOneUse()) {
19709 LoadSDNode *LD = cast<LoadSDNode>(Val&: Value);
19710 EVT VT = LD->getMemoryVT();
19711 if (!VT.isFloatingPoint() ||
19712 VT != ST->getMemoryVT() ||
19713 LD->isNonTemporal() ||
19714 ST->isNonTemporal() ||
19715 LD->getPointerInfo().getAddrSpace() != 0 ||
19716 ST->getPointerInfo().getAddrSpace() != 0)
19717 return SDValue();
19718
19719 TypeSize VTSize = VT.getSizeInBits();
19720
19721 // We don't know the size of scalable types at compile time so we cannot
19722 // create an integer of the equivalent size.
19723 if (VTSize.isScalable())
19724 return SDValue();
19725
19726 unsigned FastLD = 0, FastST = 0;
19727 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VTSize.getFixedValue());
19728 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: IntVT) ||
19729 !TLI.isOperationLegal(Op: ISD::STORE, VT: IntVT) ||
19730 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
19731 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
19732 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
19733 MMO: *LD->getMemOperand(), Fast: &FastLD) ||
19734 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
19735 MMO: *ST->getMemOperand(), Fast: &FastST) ||
19736 !FastLD || !FastST)
19737 return SDValue();
19738
19739 SDValue NewLD =
19740 DAG.getLoad(VT: IntVT, dl: SDLoc(Value), Chain: LD->getChain(), Ptr: LD->getBasePtr(),
19741 PtrInfo: LD->getPointerInfo(), Alignment: LD->getAlign());
19742
19743 SDValue NewST =
19744 DAG.getStore(Chain: ST->getChain(), dl: SDLoc(N), Val: NewLD, Ptr: ST->getBasePtr(),
19745 PtrInfo: ST->getPointerInfo(), Alignment: ST->getAlign());
19746
19747 AddToWorklist(N: NewLD.getNode());
19748 AddToWorklist(N: NewST.getNode());
19749 WorklistRemover DeadNodes(*this);
19750 DAG.ReplaceAllUsesOfValueWith(From: Value.getValue(R: 1), To: NewLD.getValue(R: 1));
19751 ++LdStFP2Int;
19752 return NewST;
19753 }
19754
19755 return SDValue();
19756}
19757
19758// This is a helper function for visitMUL to check the profitability
19759// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
19760// MulNode is the original multiply, AddNode is (add x, c1),
19761// and ConstNode is c2.
19762//
19763// If the (add x, c1) has multiple uses, we could increase
19764// the number of adds if we make this transformation.
19765// It would only be worth doing this if we can remove a
19766// multiply in the process. Check for that here.
19767// To illustrate:
19768// (A + c1) * c3
19769// (A + c2) * c3
19770// We're checking for cases where we have common "c3 * A" expressions.
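// For instance, with c1 = 1, c2 = 2 and c3 = 10, both expressions become
// (A * 10) + 10 and (A * 10) + 20 after the fold, so the A * 10 multiply can
// be shared.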
19771bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
19772 SDValue ConstNode) {
19773 APInt Val;
19774
19775 // If the add only has one use, and the target thinks the folding is
19776 // profitable or does not lead to worse code, this would be OK to do.
19777 if (AddNode->hasOneUse() &&
19778 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
19779 return true;
19780
19781 // Walk all the users of the constant with which we're multiplying.
19782 for (SDNode *Use : ConstNode->uses()) {
19783 if (Use == MulNode) // This use is the one we're on right now. Skip it.
19784 continue;
19785
19786 if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
19787 SDNode *OtherOp;
19788 SDNode *MulVar = AddNode.getOperand(i: 0).getNode();
19789
19790 // OtherOp is what we're multiplying against the constant.
19791 if (Use->getOperand(Num: 0) == ConstNode)
19792 OtherOp = Use->getOperand(Num: 1).getNode();
19793 else
19794 OtherOp = Use->getOperand(Num: 0).getNode();
19795
19796 // Check to see if multiply is with the same operand of our "add".
19797 //
19798 // ConstNode = CONST
19799 // Use = ConstNode * A <-- visiting Use. OtherOp is A.
19800 // ...
19801 // AddNode = (A + c1) <-- MulVar is A.
19802 // = AddNode * ConstNode <-- current visiting instruction.
19803 //
19804 // If we make this transformation, we will have a common
19805 // multiply (ConstNode * A) that we can save.
19806 if (OtherOp == MulVar)
19807 return true;
19808
19809 // Now check to see if a future expansion will give us a common
19810 // multiply.
19811 //
19812 // ConstNode = CONST
19813 // AddNode = (A + c1)
19814 // ... = AddNode * ConstNode <-- current visiting instruction.
19815 // ...
19816 // OtherOp = (A + c2)
19817 // Use = OtherOp * ConstNode <-- visiting Use.
19818 //
19819 // If we make this transformation, we will have a common
19820 // multiply (CONST * A) after we also do the same transformation
19821 // to the "Use" instruction.
19822 if (OtherOp->getOpcode() == ISD::ADD &&
19823 DAG.isConstantIntBuildVectorOrConstantInt(N: OtherOp->getOperand(Num: 1)) &&
19824 OtherOp->getOperand(Num: 0).getNode() == MulVar)
19825 return true;
19826 }
19827 }
19828
19829 // Didn't find a case where this would be profitable.
19830 return false;
19831}
19832
19833SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
19834 unsigned NumStores) {
19835 SmallVector<SDValue, 8> Chains;
19836 SmallPtrSet<const SDNode *, 8> Visited;
19837 SDLoc StoreDL(StoreNodes[0].MemNode);
19838
19839 for (unsigned i = 0; i < NumStores; ++i) {
19840 Visited.insert(Ptr: StoreNodes[i].MemNode);
19841 }
19842
19843 // don't include nodes that are children or repeated nodes.
19844 for (unsigned i = 0; i < NumStores; ++i) {
19845 if (Visited.insert(Ptr: StoreNodes[i].MemNode->getChain().getNode()).second)
19846 Chains.push_back(Elt: StoreNodes[i].MemNode->getChain());
19847 }
19848
19849 assert(!Chains.empty() && "Chain should have generated a chain");
19850 return DAG.getTokenFactor(DL: StoreDL, Vals&: Chains);
19851}
19852
19853bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
19854 const Value *UnderlyingObj = nullptr;
19855 for (const auto &MemOp : StoreNodes) {
19856 const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
19857 // A pseudo value like a stack frame has its own frame index and size; we
19858 // should not use the first store's frame index for other frames.
19859 if (MMO->getPseudoValue())
19860 return false;
19861
19862 if (!MMO->getValue())
19863 return false;
19864
19865 const Value *Obj = getUnderlyingObject(V: MMO->getValue());
19866
19867 if (UnderlyingObj && UnderlyingObj != Obj)
19868 return false;
19869
19870 if (!UnderlyingObj)
19871 UnderlyingObj = Obj;
19872 }
19873
19874 return true;
19875}
19876
19877bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
19878 SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
19879 bool IsConstantSrc, bool UseVector, bool UseTrunc) {
19880 // Make sure we have something to merge.
19881 if (NumStores < 2)
19882 return false;
19883
19884 assert((!UseTrunc || !UseVector) &&
19885 "This optimization cannot emit a vector truncating store");
19886
19887 // The latest Node in the DAG.
19888 SDLoc DL(StoreNodes[0].MemNode);
19889
19890 TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
19891 unsigned SizeInBits = NumStores * ElementSizeBits;
19892 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
19893
19894 std::optional<MachineMemOperand::Flags> Flags;
19895 AAMDNodes AAInfo;
19896 for (unsigned I = 0; I != NumStores; ++I) {
19897 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
19898 if (!Flags) {
19899 Flags = St->getMemOperand()->getFlags();
19900 AAInfo = St->getAAInfo();
19901 continue;
19902 }
19903 // Skip merging if there's an inconsistent flag.
19904 if (Flags != St->getMemOperand()->getFlags())
19905 return false;
19906 // Concatenate AA metadata.
19907 AAInfo = AAInfo.concat(Other: St->getAAInfo());
19908 }
19909
19910 EVT StoreTy;
19911 if (UseVector) {
19912 unsigned Elts = NumStores * NumMemElts;
19913 // Get the type for the merged vector store.
19914 StoreTy = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
19915 } else
19916 StoreTy = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SizeInBits);
19917
19918 SDValue StoredVal;
19919 if (UseVector) {
19920 if (IsConstantSrc) {
19921 SmallVector<SDValue, 8> BuildVector;
19922 for (unsigned I = 0; I != NumStores; ++I) {
19923 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
19924 SDValue Val = St->getValue();
19925 // If constant is of the wrong type, convert it now. This comes up
19926 // when one of our stores was truncating.
19927 if (MemVT != Val.getValueType()) {
19928 Val = peekThroughBitcasts(V: Val);
19929 // Deal with constants of wrong size.
19930 if (ElementSizeBits != Val.getValueSizeInBits()) {
19931 auto *C = dyn_cast<ConstantSDNode>(Val);
19932 if (!C)
19933 // Not clear how to truncate FP values.
19934 // TODO: Handle truncation of build_vector constants
19935 return false;
19936
19937 EVT IntMemVT =
19938 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemVT.getSizeInBits());
19939 Val = DAG.getConstant(Val: C->getAPIntValue()
19940 .zextOrTrunc(width: Val.getValueSizeInBits())
19941 .zextOrTrunc(width: ElementSizeBits),
19942 DL: SDLoc(C), VT: IntMemVT);
19943 }
19944 // Make sure the correctly sized value has the correct (memory) type.
19945 Val = DAG.getBitcast(VT: MemVT, V: Val);
19946 }
19947 BuildVector.push_back(Elt: Val);
19948 }
19949 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
19950 : ISD::BUILD_VECTOR,
19951 DL, VT: StoreTy, Ops: BuildVector);
19952 } else {
19953 SmallVector<SDValue, 8> Ops;
19954 for (unsigned i = 0; i < NumStores; ++i) {
19955 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
19956 SDValue Val = peekThroughBitcasts(V: St->getValue());
19957 // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
19958 // type MemVT. If the underlying value is not the correct
19959 // type, but it is an extraction of an appropriate vector we
19960 // can recast Val to be of the correct type. This may require
19961 // converting between EXTRACT_VECTOR_ELT and
19962 // EXTRACT_SUBVECTOR.
19963 if ((MemVT != Val.getValueType()) &&
19964 (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
19965 Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
19966 EVT MemVTScalarTy = MemVT.getScalarType();
19967 // We may need to add a bitcast here to get types to line up.
19968 if (MemVTScalarTy != Val.getValueType().getScalarType()) {
19969 Val = DAG.getBitcast(VT: MemVT, V: Val);
19970 } else if (MemVT.isVector() &&
19971 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
19972 Val = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MemVT, Operand: Val);
19973 } else {
19974 unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
19975 : ISD::EXTRACT_VECTOR_ELT;
19976 SDValue Vec = Val.getOperand(i: 0);
19977 SDValue Idx = Val.getOperand(i: 1);
19978 Val = DAG.getNode(Opcode: OpC, DL: SDLoc(Val), VT: MemVT, N1: Vec, N2: Idx);
19979 }
19980 }
19981 Ops.push_back(Elt: Val);
19982 }
19983
19984 // Build the extracted vector elements back into a vector.
19985 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
19986 : ISD::BUILD_VECTOR,
19987 DL, VT: StoreTy, Ops);
19988 }
19989 } else {
19990 // We should always use a vector store when merging extracted vector
19991 // elements, so this path implies a store of constants.
19992 assert(IsConstantSrc && "Merged vector elements should use vector store");
19993
19994 APInt StoreInt(SizeInBits, 0);
19995
19996 // Construct a single integer constant which is made of the smaller
19997 // constant inputs.
19998 bool IsLE = DAG.getDataLayout().isLittleEndian();
19999 for (unsigned i = 0; i < NumStores; ++i) {
20000 unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
20001 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[Idx].MemNode);
20002
20003 SDValue Val = St->getValue();
20004 Val = peekThroughBitcasts(V: Val);
20005 StoreInt <<= ElementSizeBits;
20006 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
20007 StoreInt |= C->getAPIntValue()
20008 .zextOrTrunc(width: ElementSizeBits)
20009 .zextOrTrunc(width: SizeInBits);
20010 } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
20011 StoreInt |= C->getValueAPF()
20012 .bitcastToAPInt()
20013 .zextOrTrunc(width: ElementSizeBits)
20014 .zextOrTrunc(width: SizeInBits);
20015 // If fp truncation is necessary give up for now.
20016 if (MemVT.getSizeInBits() != ElementSizeBits)
20017 return false;
20018 } else if (ISD::isBuildVectorOfConstantSDNodes(N: Val.getNode()) ||
20019 ISD::isBuildVectorOfConstantFPSDNodes(N: Val.getNode())) {
20020 // Not yet handled
20021 return false;
20022 } else {
20023 llvm_unreachable("Invalid constant element type");
20024 }
20025 }
20026
20027 // Create the new Load and Store operations.
20028 StoredVal = DAG.getConstant(Val: StoreInt, DL, VT: StoreTy);
20029 }
20030
20031 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20032 SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
20033 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
20034
20035 // Make sure we use a truncating store if that is necessary for legality.
20036 // When generating the new widened store, if the first store's pointer info
20037 // cannot be reused, discard the pointer info except for the address space,
20038 // because the widened store can no longer be represented by the original
20039 // pointer info, which describes the narrower memory object.
20040 SDValue NewStore;
20041 if (!UseTrunc) {
20042 NewStore = DAG.getStore(
20043 Chain: NewChain, dl: DL, Val: StoredVal, Ptr: FirstInChain->getBasePtr(),
20044 PtrInfo: CanReusePtrInfo
20045 ? FirstInChain->getPointerInfo()
20046 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20047 Alignment: FirstInChain->getAlign(), MMOFlags: *Flags, AAInfo);
20048 } else { // Must be realized as a trunc store
20049 EVT LegalizedStoredValTy =
20050 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: StoredVal.getValueType());
20051 unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
20052 ConstantSDNode *C = cast<ConstantSDNode>(Val&: StoredVal);
20053 SDValue ExtendedStoreVal =
20054 DAG.getConstant(Val: C->getAPIntValue().zextOrTrunc(width: LegalizedStoreSize), DL,
20055 VT: LegalizedStoredValTy);
20056 NewStore = DAG.getTruncStore(
20057 Chain: NewChain, dl: DL, Val: ExtendedStoreVal, Ptr: FirstInChain->getBasePtr(),
20058 PtrInfo: CanReusePtrInfo
20059 ? FirstInChain->getPointerInfo()
20060 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20061 SVT: StoredVal.getValueType() /*TVT*/, Alignment: FirstInChain->getAlign(), MMOFlags: *Flags,
20062 AAInfo);
20063 }
20064
20065 // Replace all merged stores with the new store.
20066 for (unsigned i = 0; i < NumStores; ++i)
20067 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
20068
20069 AddToWorklist(N: NewChain.getNode());
20070 return true;
20071}
20072
20073void DAGCombiner::getStoreMergeCandidates(
20074 StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
20075 SDNode *&RootNode) {
20076 // This holds the base pointer, index, and the offset in bytes from the base
20077 // pointer. We must have a base and an offset. Do not handle stores to undef
20078 // base pointers.
20079 BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
20080 if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
20081 return;
20082
20083 SDValue Val = peekThroughBitcasts(V: St->getValue());
20084 StoreSource StoreSrc = getStoreSource(StoreVal: Val);
20085 assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
20086
20087 // Match on loadbaseptr if relevant.
20088 EVT MemVT = St->getMemoryVT();
20089 BaseIndexOffset LBasePtr;
20090 EVT LoadVT;
20091 if (StoreSrc == StoreSource::Load) {
20092 auto *Ld = cast<LoadSDNode>(Val);
20093 LBasePtr = BaseIndexOffset::match(N: Ld, DAG);
20094 LoadVT = Ld->getMemoryVT();
20095 // Load and store should be the same type.
20096 if (MemVT != LoadVT)
20097 return;
20098 // Loads must only have one use.
20099 if (!Ld->hasNUsesOfValue(NUses: 1, Value: 0))
20100 return;
20101 // The memory operands must not be volatile/indexed/atomic.
20102 // TODO: May be able to relax for unordered atomics (see D66309)
20103 if (!Ld->isSimple() || Ld->isIndexed())
20104 return;
20105 }
20106 auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
20107 int64_t &Offset) -> bool {
20108 // The memory operands must not be volatile/indexed/atomic.
20109 // TODO: May be able to relax for unordered atomics (see D66309)
20110 if (!Other->isSimple() || Other->isIndexed())
20111 return false;
20112 // Don't mix temporal stores with non-temporal stores.
20113 if (St->isNonTemporal() != Other->isNonTemporal())
20114 return false;
20115 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *St, NodeY: *Other))
20116 return false;
20117 SDValue OtherBC = peekThroughBitcasts(V: Other->getValue());
20118 // Allow merging constants of different types as integers.
20119 bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(VT: Other->getMemoryVT())
20120 : Other->getMemoryVT() != MemVT;
20121 switch (StoreSrc) {
20122 case StoreSource::Load: {
20123 if (NoTypeMatch)
20124 return false;
20125 // The Load's Base Ptr must also match.
20126 auto *OtherLd = dyn_cast<LoadSDNode>(Val&: OtherBC);
20127 if (!OtherLd)
20128 return false;
20129 BaseIndexOffset LPtr = BaseIndexOffset::match(N: OtherLd, DAG);
20130 if (LoadVT != OtherLd->getMemoryVT())
20131 return false;
20132 // Loads must only have one use.
20133 if (!OtherLd->hasNUsesOfValue(NUses: 1, Value: 0))
20134 return false;
20135 // The memory operands must not be volatile/indexed/atomic.
20136 // TODO: May be able to relax for unordered atomics (see D66309)
20137 if (!OtherLd->isSimple() || OtherLd->isIndexed())
20138 return false;
20139 // Don't mix temporal loads with non-temporal loads.
20140 if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
20141 return false;
20142 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *cast<LoadSDNode>(Val),
20143 NodeY: *OtherLd))
20144 return false;
20145 if (!(LBasePtr.equalBaseIndex(Other: LPtr, DAG)))
20146 return false;
20147 break;
20148 }
20149 case StoreSource::Constant:
20150 if (NoTypeMatch)
20151 return false;
20152 if (getStoreSource(StoreVal: OtherBC) != StoreSource::Constant)
20153 return false;
20154 break;
20155 case StoreSource::Extract:
20156 // Do not merge truncated stores here.
20157 if (Other->isTruncatingStore())
20158 return false;
20159 if (!MemVT.bitsEq(VT: OtherBC.getValueType()))
20160 return false;
20161 if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20162 OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
20163 return false;
20164 break;
20165 default:
20166 llvm_unreachable("Unhandled store source for merging");
20167 }
20168 Ptr = BaseIndexOffset::match(N: Other, DAG);
20169 return (BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset));
20170 };
20171
20172 // Check if the pair of StoreNode and RootNode has already bailed out of the
20173 // dependence check more times than the limit allows.
20174 auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
20175 SDNode *RootNode) -> bool {
20176 auto RootCount = StoreRootCountMap.find(Val: StoreNode);
20177 return RootCount != StoreRootCountMap.end() &&
20178 RootCount->second.first == RootNode &&
20179 RootCount->second.second > StoreMergeDependenceLimit;
20180 };
20181
20182 auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
20183 // This must be a chain use.
20184 if (UseIter.getOperandNo() != 0)
20185 return;
20186 if (auto *OtherStore = dyn_cast<StoreSDNode>(Val: *UseIter)) {
20187 BaseIndexOffset Ptr;
20188 int64_t PtrDiff;
20189 if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
20190 !OverLimitInDependenceCheck(OtherStore, RootNode))
20191 StoreNodes.push_back(Elt: MemOpLink(OtherStore, PtrDiff));
20192 }
20193 };
20194
20195 // We are looking for a root node which is an ancestor to all mergeable
20196 // stores. We search up through a load, to our root and then down
20197 // through all children. For instance we will find Store{1,2,3} if
20198 // St is Store1, Store2, or Store3 where the root is not a load,
20199 // which is always true for non-volatile ops. TODO: Expand
20200 // the search to find all valid candidates through multiple layers of loads.
20201 //
20202 // Root
20203 // |-------|-------|
20204 // Load Load Store3
20205 // | |
20206 // Store1 Store2
20207 //
20208 // FIXME: We should be able to climb and
20209 // descend TokenFactors to find candidates as well.
20210
20211 RootNode = St->getChain().getNode();
20212
20213 unsigned NumNodesExplored = 0;
20214 const unsigned MaxSearchNodes = 1024;
20215 if (auto *Ldn = dyn_cast<LoadSDNode>(Val: RootNode)) {
20216 RootNode = Ldn->getChain().getNode();
20217 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20218 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
20219 if (I.getOperandNo() == 0 && isa<LoadSDNode>(Val: *I)) { // walk down chain
20220 for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
20221 TryToAddCandidate(I2);
20222 }
20223 // Check stores that depend on the root (e.g. Store 3 in the chart above).
20224 if (I.getOperandNo() == 0 && isa<StoreSDNode>(Val: *I)) {
20225 TryToAddCandidate(I);
20226 }
20227 }
20228 } else {
20229 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20230 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
20231 TryToAddCandidate(I);
20232 }
20233}
20234
20235// We need to check that merging these stores does not cause a loop in the
20236// DAG. Any store candidate may depend on another candidate indirectly through
20237// its operands. Check in parallel by searching up from operands of candidates.
20238bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
20239 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
20240 SDNode *RootNode) {
20241 // FIXME: We should be able to truncate a full search of
20242 // predecessors by doing a BFS and keeping tabs on the originating
20243 // stores from which worklist nodes come, in a similar way to
20244 // TokenFactor simplification.
20245
20246 SmallPtrSet<const SDNode *, 32> Visited;
20247 SmallVector<const SDNode *, 8> Worklist;
20248
20249 // RootNode is a predecessor to all candidates so we need not search
20250 // past it. Add RootNode (peeking through TokenFactors). Do not count
20251 // these towards size check.
20252
20253 Worklist.push_back(Elt: RootNode);
20254 while (!Worklist.empty()) {
20255 auto N = Worklist.pop_back_val();
20256 if (!Visited.insert(Ptr: N).second)
20257 continue; // Already present in Visited.
20258 if (N->getOpcode() == ISD::TokenFactor) {
20259 for (SDValue Op : N->ops())
20260 Worklist.push_back(Elt: Op.getNode());
20261 }
20262 }
20263
20264 // Don't count pruning nodes towards max.
20265 unsigned int Max = 1024 + Visited.size();
20266 // Search Ops of store candidates.
20267 for (unsigned i = 0; i < NumStores; ++i) {
20268 SDNode *N = StoreNodes[i].MemNode;
20269 // Of the 4 Store Operands:
20270 // * Chain (Op 0) -> We have already considered these
20271 // in candidate selection, but only by following the
20272 // chain dependencies. We could still have a chain
20273 // dependency to a load, that has a non-chain dep to
20274 // another load, that depends on a store, etc. So it is
20275 // possible to have dependencies that consist of a mix
20276 // of chain and non-chain deps, and we need to include
20277 // chain operands in the analysis here.
20278 // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
20279 // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
20280 // but aren't necessarily from the same base node, so
20281 // cycles are possible (e.g. via indexed store).
20282 // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
20283 // non-indexed stores). Not constant on all targets (e.g. ARM)
20284 // and so can participate in a cycle.
20285 for (unsigned j = 0; j < N->getNumOperands(); ++j)
20286 Worklist.push_back(Elt: N->getOperand(Num: j).getNode());
20287 }
20288 // Search through DAG. We can stop early if we find a store node.
20289 for (unsigned i = 0; i < NumStores; ++i)
20290 if (SDNode::hasPredecessorHelper(N: StoreNodes[i].MemNode, Visited, Worklist,
20291 MaxSteps: Max)) {
20292 // If the search bails out, record the StoreNode and RootNode in the
20293 // StoreRootCountMap. If we have seen the pair more times than the limit,
20294 // we won't add the StoreNode into the StoreNodes set again.
20295 if (Visited.size() >= Max) {
20296 auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
20297 if (RootCount.first == RootNode)
20298 RootCount.second++;
20299 else
20300 RootCount = {RootNode, 1};
20301 }
20302 return false;
20303 }
20304 return true;
20305}
20306
20307unsigned
20308DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
20309 int64_t ElementSizeBytes) const {
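// For example, assuming the candidates are already sorted by offset from the
// base, offsets {0, 4, 8, 20} with ElementSizeBytes == 4 yield three
// consecutive stores (those at offsets 0, 4 and 8).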
20310 while (true) {
20311 // Find a store past the width of the first store.
20312 size_t StartIdx = 0;
20313 while ((StartIdx + 1 < StoreNodes.size()) &&
20314 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
20315 StoreNodes[StartIdx + 1].OffsetFromBase)
20316 ++StartIdx;
20317
20318 // Bail if we don't have enough candidates to merge.
20319 if (StartIdx + 1 >= StoreNodes.size())
20320 return 0;
20321
20322 // Trim stores that overlapped with the first store.
20323 if (StartIdx)
20324 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + StartIdx);
20325
20326 // Scan the memory operations on the chain and find the first
20327 // non-consecutive store memory address.
20328 unsigned NumConsecutiveStores = 1;
20329 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
20330 // Check that the addresses are consecutive starting from the second
20331 // element in the list of stores.
20332 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
20333 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
20334 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20335 break;
20336 NumConsecutiveStores = i + 1;
20337 }
20338 if (NumConsecutiveStores > 1)
20339 return NumConsecutiveStores;
20340
20341 // There are no consecutive stores at the start of the list.
20342 // Remove the first store and try again.
20343 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 1);
20344 }
20345}
20346
20347bool DAGCombiner::tryStoreMergeOfConstants(
20348 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20349 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
20350 LLVMContext &Context = *DAG.getContext();
20351 const DataLayout &DL = DAG.getDataLayout();
20352 int64_t ElementSizeBytes = MemVT.getStoreSize();
20353 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20354 bool MadeChange = false;
20355
20356 // Store the constants into memory as one consecutive store.
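// For instance, four adjacent i8 stores of 0x01, 0x02, 0x03 and 0x04 may be
// merged into a single i32 store of 0x04030201 on a little-endian target,
// provided the wider type and access are legal and fast.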
20357 while (NumConsecutiveStores >= 2) {
20358 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20359 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20360 Align FirstStoreAlign = FirstInChain->getAlign();
20361 unsigned LastLegalType = 1;
20362 unsigned LastLegalVectorType = 1;
20363 bool LastIntegerTrunc = false;
20364 bool NonZero = false;
20365 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
20366 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20367 StoreSDNode *ST = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20368 SDValue StoredVal = ST->getValue();
20369 bool IsElementZero = false;
20370 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val&: StoredVal))
20371 IsElementZero = C->isZero();
20372 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val&: StoredVal))
20373 IsElementZero = C->getConstantFPValue()->isNullValue();
20374 else if (ISD::isBuildVectorAllZeros(N: StoredVal.getNode()))
20375 IsElementZero = true;
20376 if (IsElementZero) {
20377 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
20378 FirstZeroAfterNonZero = i;
20379 }
20380 NonZero |= !IsElementZero;
20381
20382 // Find a legal type for the constant store.
20383 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20384 EVT StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20385 unsigned IsFast = 0;
20386
20387 // Break early when size is too large to be legal.
20388 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20389 break;
20390
20391 if (TLI.isTypeLegal(VT: StoreTy) &&
20392 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20393 MF: DAG.getMachineFunction()) &&
20394 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20395 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20396 IsFast) {
20397 LastIntegerTrunc = false;
20398 LastLegalType = i + 1;
20399 // Or check whether a truncstore is legal.
20400 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
20401 TargetLowering::TypePromoteInteger) {
20402 EVT LegalizedStoredValTy =
20403 TLI.getTypeToTransformTo(Context, VT: StoredVal.getValueType());
20404 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20405 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
20406 MF: DAG.getMachineFunction()) &&
20407 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20408 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20409 IsFast) {
20410 LastIntegerTrunc = true;
20411 LastLegalType = i + 1;
20412 }
20413 }
20414
20415 // We only use vectors if the target allows it and the function is not
20416 // marked with the noimplicitfloat attribute.
20417 if (TLI.storeOfVectorConstantIsCheap(IsZero: !NonZero, MemVT, NumElem: i + 1, AddrSpace: FirstStoreAS) &&
20418 AllowVectors) {
20419 // Find a legal type for the vector store.
20420 unsigned Elts = (i + 1) * NumMemElts;
20421 EVT Ty = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20422 if (TLI.isTypeLegal(VT: Ty) && TLI.isTypeLegal(VT: MemVT) &&
20423 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
20424 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
20425 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20426 IsFast)
20427 LastLegalVectorType = i + 1;
20428 }
20429 }
20430
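    // Prefer the vector type only if it lets us merge strictly more stores
    // than the widest legal integer type; on a tie the integer type is used.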
20431 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
20432 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
20433 bool UseTrunc = LastIntegerTrunc && !UseVector;
20434
    // Check if we found a legal type (integer or vector) that creates a
    // meaningful merge.
20437 if (NumElem < 2) {
20438 // We know that candidate stores are in order and of correct
20439 // shape. While there is no mergeable sequence from the
20440 // beginning one may start later in the sequence. The only
20441 // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment has
20443 // improved or we've dropped a non-zero value. Drop as many
20444 // candidates as we can here.
20445 unsigned NumSkip = 1;
20446 while ((NumSkip < NumConsecutiveStores) &&
20447 (NumSkip < FirstZeroAfterNonZero) &&
20448 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20449 NumSkip++;
20450
20451 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20452 NumConsecutiveStores -= NumSkip;
20453 continue;
20454 }
20455
20456 // Check that we can merge these candidates without causing a cycle.
20457 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
20458 RootNode)) {
20459 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20460 NumConsecutiveStores -= NumElem;
20461 continue;
20462 }
20463
20464 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStores: NumElem,
20465 /*IsConstantSrc*/ true,
20466 UseVector, UseTrunc);
20467
20468 // Remove merged stores for next iteration.
20469 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20470 NumConsecutiveStores -= NumElem;
20471 }
20472 return MadeChange;
20473}
20474
20475bool DAGCombiner::tryStoreMergeOfExtracts(
20476 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20477 EVT MemVT, SDNode *RootNode) {
20478 LLVMContext &Context = *DAG.getContext();
20479 const DataLayout &DL = DAG.getDataLayout();
20480 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20481 bool MadeChange = false;
20482
20483 // Loop on Consecutive Stores on success.
20484 while (NumConsecutiveStores >= 2) {
20485 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20486 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20487 Align FirstStoreAlign = FirstInChain->getAlign();
20488 unsigned NumStoresToMerge = 1;
20489 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20490 // Find a legal type for the vector store.
20491 unsigned Elts = (i + 1) * NumMemElts;
20492 EVT Ty = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
20493 unsigned IsFast = 0;
20494
20495 // Break early when size is too large to be legal.
20496 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
20497 break;
20498
20499 if (TLI.isTypeLegal(VT: Ty) &&
20500 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
20501 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
20502 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20503 IsFast)
20504 NumStoresToMerge = i + 1;
20505 }
20506
    // Check if we found a legal vector type that creates a meaningful
    // merge.
20509 if (NumStoresToMerge < 2) {
20510 // We know that candidate stores are in order and of correct
20511 // shape. While there is no mergeable sequence from the
20512 // beginning one may start later in the sequence. The only
20513 // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment has
20515 // improved. Drop as many candidates as we can here.
20516 unsigned NumSkip = 1;
20517 while ((NumSkip < NumConsecutiveStores) &&
20518 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20519 NumSkip++;
20520
20521 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20522 NumConsecutiveStores -= NumSkip;
20523 continue;
20524 }
20525
20526 // Check that we can merge these candidates without causing a cycle.
20527 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumStoresToMerge,
20528 RootNode)) {
20529 StoreNodes.erase(CS: StoreNodes.begin(),
20530 CE: StoreNodes.begin() + NumStoresToMerge);
20531 NumConsecutiveStores -= NumStoresToMerge;
20532 continue;
20533 }
20534
20535 MadeChange |= mergeStoresOfConstantsOrVecElts(
20536 StoreNodes, MemVT, NumStores: NumStoresToMerge, /*IsConstantSrc*/ false,
20537 /*UseVector*/ true, /*UseTrunc*/ false);
20538
20539 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumStoresToMerge);
20540 NumConsecutiveStores -= NumStoresToMerge;
20541 }
20542 return MadeChange;
20543}
20544
20545bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
20546 unsigned NumConsecutiveStores, EVT MemVT,
20547 SDNode *RootNode, bool AllowVectors,
20548 bool IsNonTemporalStore,
20549 bool IsNonTemporalLoad) {
20550 LLVMContext &Context = *DAG.getContext();
20551 const DataLayout &DL = DAG.getDataLayout();
20552 int64_t ElementSizeBytes = MemVT.getStoreSize();
20553 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20554 bool MadeChange = false;
20555
20556 // Look for load nodes which are used by the stored values.
20557 SmallVector<MemOpLink, 8> LoadNodes;
20558
  // Find acceptable loads. Loads need to have the same chain (token factor),
  // must not be zext loads, volatile, or indexed, and they must be
  // consecutive.
20561 BaseIndexOffset LdBasePtr;
20562
20563 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20564 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20565 SDValue Val = peekThroughBitcasts(V: St->getValue());
20566 LoadSDNode *Ld = cast<LoadSDNode>(Val);
20567
20568 BaseIndexOffset LdPtr = BaseIndexOffset::match(N: Ld, DAG);
20569 // If this is not the first ptr that we check.
20570 int64_t LdOffset = 0;
20571 if (LdBasePtr.getBase().getNode()) {
20572 // The base ptr must be the same.
20573 if (!LdBasePtr.equalBaseIndex(Other: LdPtr, DAG, Off&: LdOffset))
20574 break;
20575 } else {
20576 // Check that all other base pointers are the same as this one.
20577 LdBasePtr = LdPtr;
20578 }
20579
20580 // We found a potential memory operand to merge.
20581 LoadNodes.push_back(Elt: MemOpLink(Ld, LdOffset));
20582 }
20583
20584 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
20585 Align RequiredAlignment;
20586 bool NeedRotate = false;
20587 if (LoadNodes.size() == 2) {
20588 // If we have load/store pair instructions and we only have two values,
20589 // don't bother merging.
20590 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
20591 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
20592 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 2);
20593 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + 2);
20594 break;
20595 }
20596 // If the loads are reversed, see if we can rotate the halves into place.
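      // For example, with 32-bit elements, the pair q[0] = p[1]; q[1] = p[0]
      // can be implemented as one 64-bit load from p, rotated by 32 bits to
      // swap the two halves, and stored to q with one 64-bit store.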
20597 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
20598 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
20599 EVT PairVT = EVT::getIntegerVT(Context, BitWidth: ElementSizeBytes * 8 * 2);
20600 if (Offset0 - Offset1 == ElementSizeBytes &&
20601 (hasOperation(Opcode: ISD::ROTL, VT: PairVT) ||
20602 hasOperation(Opcode: ISD::ROTR, VT: PairVT))) {
20603 std::swap(a&: LoadNodes[0], b&: LoadNodes[1]);
20604 NeedRotate = true;
20605 }
20606 }
20607 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20608 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20609 Align FirstStoreAlign = FirstInChain->getAlign();
20610 LoadSDNode *FirstLoad = cast<LoadSDNode>(Val: LoadNodes[0].MemNode);
20611
20612 // Scan the memory operations on the chain and find the first
20613 // non-consecutive load memory address. These variables hold the index in
20614 // the store node array.
20615
20616 unsigned LastConsecutiveLoad = 1;
20617
    // These variables refer to a size (number of stores), not an index into
    // the array.
20619 unsigned LastLegalVectorType = 1;
20620 unsigned LastLegalIntegerType = 1;
20621 bool isDereferenceable = true;
20622 bool DoIntegerTruncate = false;
20623 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
20624 SDValue LoadChain = FirstLoad->getChain();
20625 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
20626 // All loads must share the same chain.
20627 if (LoadNodes[i].MemNode->getChain() != LoadChain)
20628 break;
20629
20630 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
20631 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20632 break;
20633 LastConsecutiveLoad = i;
20634
20635 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
20636 isDereferenceable = false;
20637
20638 // Find a legal type for the vector store.
20639 unsigned Elts = (i + 1) * NumMemElts;
20640 EVT StoreTy = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20641
20642 // Break early when size is too large to be legal.
20643 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20644 break;
20645
20646 unsigned IsFastSt = 0;
20647 unsigned IsFastLd = 0;
20648 // Don't try vector types if we need a rotate. We may still fail the
20649 // legality checks for the integer type, but we can't handle the rotate
20650 // case with vectors.
20651 // FIXME: We could use a shuffle in place of the rotate.
20652 if (!NeedRotate && TLI.isTypeLegal(VT: StoreTy) &&
20653 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20654 MF: DAG.getMachineFunction()) &&
20655 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20656 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20657 IsFastSt &&
20658 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20659 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20660 IsFastLd) {
20661 LastLegalVectorType = i + 1;
20662 }
20663
20664 // Find a legal type for the integer store.
20665 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20666 StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20667 if (TLI.isTypeLegal(VT: StoreTy) &&
20668 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20669 MF: DAG.getMachineFunction()) &&
20670 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20671 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20672 IsFastSt &&
20673 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20674 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20675 IsFastLd) {
20676 LastLegalIntegerType = i + 1;
20677 DoIntegerTruncate = false;
20678 // Or check whether a truncstore and extload is legal.
20679 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
20680 TargetLowering::TypePromoteInteger) {
20681 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, VT: StoreTy);
20682 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20683 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
20684 MF: DAG.getMachineFunction()) &&
20685 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20686 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20687 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20688 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20689 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20690 IsFastSt &&
20691 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20692 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20693 IsFastLd) {
20694 LastLegalIntegerType = i + 1;
20695 DoIntegerTruncate = true;
20696 }
20697 }
20698 }
20699
20700 // Only use vector types if the vector type is larger than the integer
20701 // type. If they are the same, use integers.
20702 bool UseVectorTy =
20703 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
20704 unsigned LastLegalType =
20705 std::max(a: LastLegalVectorType, b: LastLegalIntegerType);
20706
    // We add +1 here because LastConsecutiveLoad is an index (a position in
    // the array) while NumElem is a count of elements.
20709 unsigned NumElem = std::min(a: NumConsecutiveStores, b: LastConsecutiveLoad + 1);
20710 NumElem = std::min(a: LastLegalType, b: NumElem);
20711 Align FirstLoadAlign = FirstLoad->getAlign();
20712
20713 if (NumElem < 2) {
20714 // We know that candidate stores are in order and of correct
20715 // shape. While there is no mergeable sequence from the
20716 // beginning one may start later in the sequence. The only
20717 // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment of either
      // the load or store has improved. Drop as many candidates as we
20720 // can here.
20721 unsigned NumSkip = 1;
20722 while ((NumSkip < LoadNodes.size()) &&
20723 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
20724 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20725 NumSkip++;
20726 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20727 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumSkip);
20728 NumConsecutiveStores -= NumSkip;
20729 continue;
20730 }
20731
20732 // Check that we can merge these candidates without causing a cycle.
20733 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
20734 RootNode)) {
20735 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20736 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
20737 NumConsecutiveStores -= NumElem;
20738 continue;
20739 }
20740
20741 // Find if it is better to use vectors or integers to load and store
20742 // to memory.
20743 EVT JointMemOpVT;
20744 if (UseVectorTy) {
20745 // Find a legal type for the vector store.
20746 unsigned Elts = NumElem * NumMemElts;
20747 JointMemOpVT = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20748 } else {
20749 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
20750 JointMemOpVT = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20751 }
20752
20753 SDLoc LoadDL(LoadNodes[0].MemNode);
20754 SDLoc StoreDL(StoreNodes[0].MemNode);
20755
20756 // The merged loads are required to have the same incoming chain, so
20757 // using the first's chain is acceptable.
20758
20759 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumStores: NumElem);
20760 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
20761 AddToWorklist(N: NewStoreChain.getNode());
20762
20763 MachineMemOperand::Flags LdMMOFlags =
20764 isDereferenceable ? MachineMemOperand::MODereferenceable
20765 : MachineMemOperand::MONone;
20766 if (IsNonTemporalLoad)
20767 LdMMOFlags |= MachineMemOperand::MONonTemporal;
20768
20769 LdMMOFlags |= TLI.getTargetMMOFlags(Node: *FirstLoad);
20770
20771 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
20772 ? MachineMemOperand::MONonTemporal
20773 : MachineMemOperand::MONone;
20774
20775 StMMOFlags |= TLI.getTargetMMOFlags(Node: *StoreNodes[0].MemNode);
20776
20777 SDValue NewLoad, NewStore;
20778 if (UseVectorTy || !DoIntegerTruncate) {
20779 NewLoad = DAG.getLoad(
20780 VT: JointMemOpVT, dl: LoadDL, Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
20781 PtrInfo: FirstLoad->getPointerInfo(), Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
20782 SDValue StoreOp = NewLoad;
20783 if (NeedRotate) {
20784 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
20785 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
20786 "Unexpected type for rotate-able load pair");
20787 SDValue RotAmt =
20788 DAG.getShiftAmountConstant(Val: LoadWidth / 2, VT: JointMemOpVT, DL: LoadDL);
20789 // Target can convert to the identical ROTR if it does not have ROTL.
20790 StoreOp = DAG.getNode(Opcode: ISD::ROTL, DL: LoadDL, VT: JointMemOpVT, N1: NewLoad, N2: RotAmt);
20791 }
20792 NewStore = DAG.getStore(
20793 Chain: NewStoreChain, dl: StoreDL, Val: StoreOp, Ptr: FirstInChain->getBasePtr(),
20794 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
20795 : MachinePointerInfo(FirstStoreAS),
20796 Alignment: FirstStoreAlign, MMOFlags: StMMOFlags);
20797 } else { // This must be the truncstore/extload case
20798 EVT ExtendedTy =
20799 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: JointMemOpVT);
20800 NewLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: LoadDL, VT: ExtendedTy,
20801 Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
20802 PtrInfo: FirstLoad->getPointerInfo(), MemVT: JointMemOpVT,
20803 Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
20804 NewStore = DAG.getTruncStore(
20805 Chain: NewStoreChain, dl: StoreDL, Val: NewLoad, Ptr: FirstInChain->getBasePtr(),
20806 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
20807 : MachinePointerInfo(FirstStoreAS),
20808 SVT: JointMemOpVT, Alignment: FirstInChain->getAlign(),
20809 MMOFlags: FirstInChain->getMemOperand()->getFlags());
20810 }
20811
20812 // Transfer chain users from old loads to the new load.
20813 for (unsigned i = 0; i < NumElem; ++i) {
20814 LoadSDNode *Ld = cast<LoadSDNode>(Val: LoadNodes[i].MemNode);
20815 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1),
20816 To: SDValue(NewLoad.getNode(), 1));
20817 }
20818
20819 // Replace all stores with the new store. Recursively remove corresponding
20820 // values if they are no longer used.
20821 for (unsigned i = 0; i < NumElem; ++i) {
20822 SDValue Val = StoreNodes[i].MemNode->getOperand(Num: 1);
20823 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
20824 if (Val->use_empty())
20825 recursivelyDeleteUnusedNodes(N: Val.getNode());
20826 }
20827
20828 MadeChange = true;
20829 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20830 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
20831 NumConsecutiveStores -= NumElem;
20832 }
20833 return MadeChange;
20834}
20835
20836bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
20837 if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
20838 return false;
20839
20840 // TODO: Extend this function to merge stores of scalable vectors.
20841 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
20842 // store since we know <vscale x 16 x i8> is exactly twice as large as
20843 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
20844 EVT MemVT = St->getMemoryVT();
20845 if (MemVT.isScalableVT())
20846 return false;
20847 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
20848 return false;
20849
20850 // This function cannot currently deal with non-byte-sized memory sizes.
20851 int64_t ElementSizeBytes = MemVT.getStoreSize();
20852 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
20853 return false;
20854
20855 // Do not bother looking at stored values that are not constants, loads, or
20856 // extracted vector elements.
20857 SDValue StoredVal = peekThroughBitcasts(V: St->getValue());
20858 const StoreSource StoreSrc = getStoreSource(StoreVal: StoredVal);
20859 if (StoreSrc == StoreSource::Unknown)
20860 return false;
20861
20862 SmallVector<MemOpLink, 8> StoreNodes;
20863 SDNode *RootNode;
  // Find potential store merge candidates by searching through the chain
  // sub-DAG.
20865 getStoreMergeCandidates(St, StoreNodes, RootNode);
20866
20867 // Check if there is anything to merge.
20868 if (StoreNodes.size() < 2)
20869 return false;
20870
20871 // Sort the memory operands according to their distance from the
20872 // base pointer.
20873 llvm::sort(C&: StoreNodes, Comp: [](MemOpLink LHS, MemOpLink RHS) {
20874 return LHS.OffsetFromBase < RHS.OffsetFromBase;
20875 });
20876
20877 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
20878 Attribute::NoImplicitFloat);
20879 bool IsNonTemporalStore = St->isNonTemporal();
20880 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
20881 cast<LoadSDNode>(Val&: StoredVal)->isNonTemporal();
20882
  // Store merging attempts to merge the lowest-addressed stores first. This
  // generally works out because, if a merge succeeds, the remaining stores
  // are checked again after the first collection of stores is merged.
  // However, if a non-mergeable store is found first, e.g., {p[-2], p[0],
  // p[1], p[2], p[3]}, we would fail and miss the subsequent mergeable
  // cases. To prevent this, we prune such stores from the front of
  // StoreNodes here.
20890 bool MadeChange = false;
20891 while (StoreNodes.size() > 1) {
20892 unsigned NumConsecutiveStores =
20893 getConsecutiveStores(StoreNodes, ElementSizeBytes);
20894 // There are no more stores in the list to examine.
20895 if (NumConsecutiveStores == 0)
20896 return MadeChange;
20897
20898 // We have at least 2 consecutive stores. Try to merge them.
20899 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
20900 switch (StoreSrc) {
20901 case StoreSource::Constant:
20902 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
20903 MemVT, RootNode, AllowVectors);
20904 break;
20905
20906 case StoreSource::Extract:
20907 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
20908 MemVT, RootNode);
20909 break;
20910
20911 case StoreSource::Load:
20912 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
20913 MemVT, RootNode, AllowVectors,
20914 IsNonTemporalStore, IsNonTemporalLoad);
20915 break;
20916
20917 default:
20918 llvm_unreachable("Unhandled store source type");
20919 }
20920 }
20921 return MadeChange;
20922}
20923
20924SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
20925 SDLoc SL(ST);
20926 SDValue ReplStore;
20927
20928 // Replace the chain to avoid dependency.
20929 if (ST->isTruncatingStore()) {
20930 ReplStore = DAG.getTruncStore(Chain: BetterChain, dl: SL, Val: ST->getValue(),
20931 Ptr: ST->getBasePtr(), SVT: ST->getMemoryVT(),
20932 MMO: ST->getMemOperand());
20933 } else {
20934 ReplStore = DAG.getStore(Chain: BetterChain, dl: SL, Val: ST->getValue(), Ptr: ST->getBasePtr(),
20935 MMO: ST->getMemOperand());
20936 }
20937
20938 // Create token to keep both nodes around.
20939 SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
20940 MVT::Other, ST->getChain(), ReplStore);
20941
20942 // Make sure the new and old chains are cleaned up.
20943 AddToWorklist(N: Token.getNode());
20944
20945 // Don't add users to work list.
20946 return CombineTo(N: ST, Res: Token, AddTo: false);
20947}
20948
20949SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
20950 SDValue Value = ST->getValue();
20951 if (Value.getOpcode() == ISD::TargetConstantFP)
20952 return SDValue();
20953
20954 if (!ISD::isNormalStore(N: ST))
20955 return SDValue();
20956
20957 SDLoc DL(ST);
20958
20959 SDValue Chain = ST->getChain();
20960 SDValue Ptr = ST->getBasePtr();
20961
20962 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Val&: Value);
20963
20964 // NOTE: If the original store is volatile, this transform must not increase
20965 // the number of stores. For example, on x86-32 an f64 can be stored in one
20966 // processor operation but an i64 (which is not legal) requires two. So the
20967 // transform should not be done in this case.
20968
20969 SDValue Tmp;
20970 switch (CFP->getSimpleValueType(ResNo: 0).SimpleTy) {
20971 default:
20972 llvm_unreachable("Unknown FP type");
20973 case MVT::f16: // We don't do this for these yet.
20974 case MVT::bf16:
20975 case MVT::f80:
20976 case MVT::f128:
20977 case MVT::ppcf128:
20978 return SDValue();
20979 case MVT::f32:
20980 if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
20981 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
20982 Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
20983 bitcastToAPInt().getZExtValue(), SDLoc(CFP),
20984 MVT::i32);
20985 return DAG.getStore(Chain, dl: DL, Val: Tmp, Ptr, MMO: ST->getMemOperand());
20986 }
20987
20988 return SDValue();
20989 case MVT::f64:
20990 if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
20991 ST->isSimple()) ||
20992 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
20993 Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
20994 getZExtValue(), SDLoc(CFP), MVT::i64);
20995 return DAG.getStore(Chain, dl: DL, Val: Tmp,
20996 Ptr, MMO: ST->getMemOperand());
20997 }
20998
20999 if (ST->isSimple() && TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32) &&
21000 !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
21001 // Many FP stores are not made apparent until after legalize, e.g. for
21002 // argument passing. Since this is so common, custom legalize the
21003 // 64-bit integer store into two 32-bit stores.
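      // For example, storing the f64 constant 1.0 (bit pattern
      // 0x3FF0000000000000) becomes a store of 0x00000000 at Ptr and a store
      // of 0x3FF00000 at Ptr+4, with the two halves swapped below on
      // big-endian targets.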
21004 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
21005 SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
21006 SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
21007 if (DAG.getDataLayout().isBigEndian())
21008 std::swap(a&: Lo, b&: Hi);
21009
21010 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21011 AAMDNodes AAInfo = ST->getAAInfo();
21012
21013 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
21014 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21015 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: 4), DL);
21016 SDValue St1 = DAG.getStore(Chain, dl: DL, Val: Hi, Ptr,
21017 PtrInfo: ST->getPointerInfo().getWithOffset(O: 4),
21018 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21019 return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
21020 St0, St1);
21021 }
21022
21023 return SDValue();
21024 }
21025}
21026
21027// (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
21028//
// If a store stores a load from the same address with a single element
// inserted into it, and nothing on the chain in between can modify that
// memory, then the wide vector store is redundant and can be replaced with
// just a store of the single inserted scalar element.
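//
// For example, for a v4i32 load/store:
//   (store (insert_vector_elt (load p), x, 1), p) --> (store x, p+4)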
21032SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
21033 SDLoc DL(ST);
21034 SDValue Value = ST->getValue();
21035 SDValue Ptr = ST->getBasePtr();
21036 SDValue Chain = ST->getChain();
21037 if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
21038 return SDValue();
21039
21040 SDValue Elt = Value.getOperand(i: 1);
21041 SDValue Idx = Value.getOperand(i: 2);
21042
21043 // If the element isn't byte sized or is implicitly truncated then we can't
21044 // compute an offset.
21045 EVT EltVT = Elt.getValueType();
21046 if (!EltVT.isByteSized() ||
21047 EltVT != Value.getOperand(i: 0).getValueType().getVectorElementType())
21048 return SDValue();
21049
21050 auto *Ld = dyn_cast<LoadSDNode>(Val: Value.getOperand(i: 0));
21051 if (!Ld || Ld->getBasePtr() != Ptr ||
21052 ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
21053 !ISD::isNormalStore(N: ST) ||
21054 Ld->getAddressSpace() != ST->getAddressSpace() ||
21055 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1)))
21056 return SDValue();
21057
21058 unsigned IsFast;
21059 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
21060 VT: Elt.getValueType(), AddrSpace: ST->getAddressSpace(),
21061 Alignment: ST->getAlign(), Flags: ST->getMemOperand()->getFlags(),
21062 Fast: &IsFast) ||
21063 !IsFast)
21064 return SDValue();
21065
21066 MachinePointerInfo PointerInfo(ST->getAddressSpace());
21067
21068 // If the offset is a known constant then try to recover the pointer
21069 // info
21070 SDValue NewPtr;
21071 if (auto *CIdx = dyn_cast<ConstantSDNode>(Val&: Idx)) {
21072 unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
21073 NewPtr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: COffset), DL);
21074 PointerInfo = ST->getPointerInfo().getWithOffset(O: COffset);
21075 } else {
21076 NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: Ptr, VecVT: Value.getValueType(), Index: Idx);
21077 }
21078
21079 return DAG.getStore(Chain, dl: DL, Val: Elt, Ptr: NewPtr, PtrInfo: PointerInfo, Alignment: ST->getAlign(),
21080 MMOFlags: ST->getMemOperand()->getFlags());
21081}
21082
21083SDValue DAGCombiner::visitSTORE(SDNode *N) {
21084 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
21085 SDValue Chain = ST->getChain();
21086 SDValue Value = ST->getValue();
21087 SDValue Ptr = ST->getBasePtr();
21088
21089 // If this is a store of a bit convert, store the input value if the
21090 // resultant store does not need a higher alignment than the original.
21091 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
21092 ST->isUnindexed()) {
21093 EVT SVT = Value.getOperand(i: 0).getValueType();
21094 // If the store is volatile, we only want to change the store type if the
21095 // resulting store is legal. Otherwise we might increase the number of
21096 // memory accesses. We don't care if the original type was legal or not
21097 // as we assume software couldn't rely on the number of accesses of an
21098 // illegal type.
21099 // TODO: May be able to relax for unordered atomics (see D66309)
21100 if (((!LegalOperations && ST->isSimple()) ||
21101 TLI.isOperationLegal(Op: ISD::STORE, VT: SVT)) &&
21102 TLI.isStoreBitCastBeneficial(StoreVT: Value.getValueType(), BitcastVT: SVT,
21103 DAG, MMO: *ST->getMemOperand())) {
21104 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
21105 MMO: ST->getMemOperand());
21106 }
21107 }
21108
21109 // Turn 'store undef, Ptr' -> nothing.
21110 if (Value.isUndef() && ST->isUnindexed())
21111 return Chain;
21112
21113 // Try to infer better alignment information than the store already has.
21114 if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
21115 !ST->isAtomic()) {
21116 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
21117 if (*Alignment > ST->getAlign() &&
21118 isAligned(Lhs: *Alignment, SizeInBytes: ST->getSrcValueOffset())) {
21119 SDValue NewStore =
21120 DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value, Ptr, PtrInfo: ST->getPointerInfo(),
21121 SVT: ST->getMemoryVT(), Alignment: *Alignment,
21122 MMOFlags: ST->getMemOperand()->getFlags(), AAInfo: ST->getAAInfo());
21123 // NewStore will always be N as we are only refining the alignment
21124 assert(NewStore.getNode() == N);
21125 (void)NewStore;
21126 }
21127 }
21128 }
21129
21130 // Try transforming a pair floating point load / store ops to integer
21131 // load / store ops.
21132 if (SDValue NewST = TransformFPLoadStorePair(N))
21133 return NewST;
21134
21135 // Try transforming several stores into STORE (BSWAP).
21136 if (SDValue Store = mergeTruncStores(N: ST))
21137 return Store;
21138
21139 if (ST->isUnindexed()) {
21140 // Walk up chain skipping non-aliasing memory nodes, on this store and any
21141 // adjacent stores.
21142 if (findBetterNeighborChains(St: ST)) {
21143 // replaceStoreChain uses CombineTo, which handled all of the worklist
21144 // manipulation. Return the original node to not do anything else.
21145 return SDValue(ST, 0);
21146 }
21147 Chain = ST->getChain();
21148 }
21149
21150 // FIXME: is there such a thing as a truncating indexed store?
21151 if (ST->isTruncatingStore() && ST->isUnindexed() &&
21152 Value.getValueType().isInteger() &&
21153 (!isa<ConstantSDNode>(Val: Value) ||
21154 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
21155 // Convert a truncating store of a extension into a standard store.
21156 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
21157 Value.getOpcode() == ISD::SIGN_EXTEND ||
21158 Value.getOpcode() == ISD::ANY_EXTEND) &&
21159 Value.getOperand(i: 0).getValueType() == ST->getMemoryVT() &&
21160 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: ST->getMemoryVT()))
21161 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
21162 MMO: ST->getMemOperand());
21163
21164 APInt TruncDemandedBits =
21165 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
21166 loBitsSet: ST->getMemoryVT().getScalarSizeInBits());
21167
21168 // See if we can simplify the operation with SimplifyDemandedBits, which
21169 // only works if the value has a single use.
21170 AddToWorklist(N: Value.getNode());
21171 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been
      // merged with another node (i.e. N was not deleted).
      // SimplifyDemandedBits will add Value's node back to the worklist if
      // necessary, but we also need to re-visit the Store node itself.
21176 if (N->getOpcode() != ISD::DELETED_NODE)
21177 AddToWorklist(N);
21178 return SDValue(N, 0);
21179 }
21180
21181 // Otherwise, see if we can simplify the input to this truncstore with
21182 // knowledge that only the low bits are being used. For example:
21183 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
21184 if (SDValue Shorter =
21185 TLI.SimplifyMultipleUseDemandedBits(Op: Value, DemandedBits: TruncDemandedBits, DAG))
21186 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr, SVT: ST->getMemoryVT(),
21187 MMO: ST->getMemOperand());
21188
21189 // If we're storing a truncated constant, see if we can simplify it.
21190 // TODO: Move this to targetShrinkDemandedConstant?
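    // For example, a truncating i8 store of the i32 constant 0x12345678 only
    // needs the low byte, so the constant can be shrunk to 0x78.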
21191 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Value))
21192 if (!Cst->isOpaque()) {
21193 const APInt &CValue = Cst->getAPIntValue();
21194 APInt NewVal = CValue & TruncDemandedBits;
21195 if (NewVal != CValue) {
21196 SDValue Shorter =
21197 DAG.getConstant(Val: NewVal, DL: SDLoc(N), VT: Value.getValueType());
21198 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr,
21199 SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
21200 }
21201 }
21202 }
21203
21204 // If this is a load followed by a store to the same location, then the store
21205 // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
21206 // TODO: Add big-endian truncate support with test coverage.
21207 // TODO: Can relax for unordered atomics (see D66309)
21208 SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
21209 ? peekThroughTruncates(V: Value)
21210 : Value;
21211 if (auto *Ld = dyn_cast<LoadSDNode>(Val&: TruncVal)) {
21212 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
21213 ST->isUnindexed() && ST->isSimple() &&
21214 Ld->getAddressSpace() == ST->getAddressSpace() &&
21215 // There can't be any side effects between the load and store, such as
21216 // a call or store.
21217 Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1))) {
21218 // The store is dead, remove it.
21219 return Chain;
21220 }
21221 }
21222
21223 // Try scalarizing vector stores of loads where we only change one element
21224 if (SDValue NewST = replaceStoreOfInsertLoad(ST))
21225 return NewST;
21226
21227 // TODO: Can relax for unordered atomics (see D66309)
21228 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Val&: Chain)) {
21229 if (ST->isUnindexed() && ST->isSimple() &&
21230 ST1->isUnindexed() && ST1->isSimple()) {
21231 if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
21232 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
21233 ST->getAddressSpace() == ST1->getAddressSpace()) {
21234 // If this is a store followed by a store with the same value to the
21235 // same location, then the store is dead/noop.
21236 return Chain;
21237 }
21238
21239 if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
21240 !ST1->getBasePtr().isUndef() &&
21241 ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If one of the two stores has a scalable vector type with a smaller
        // minimum size and the other is a larger fixed-size store, we cannot
        // allow removal of the scalable store, because its actual size is
        // not known until run time.
21246 if (ST->getMemoryVT().isScalableVector() ||
21247 ST1->getMemoryVT().isScalableVector()) {
21248 if (ST1->getBasePtr() == Ptr &&
21249 TypeSize::isKnownLE(LHS: ST1->getMemoryVT().getStoreSize(),
21250 RHS: ST->getMemoryVT().getStoreSize())) {
21251 CombineTo(N: ST1, Res: ST1->getChain());
21252 return SDValue(N, 0);
21253 }
21254 } else {
21255 const BaseIndexOffset STBase = BaseIndexOffset::match(N: ST, DAG);
21256 const BaseIndexOffset ChainBase = BaseIndexOffset::match(N: ST1, DAG);
          // If the preceding store writes only to a subset of the current
          // store's location and no other node is chained to that store, we
          // can effectively drop the preceding store. Do not remove stores
          // to undef as they may be used as data sinks.
21261 if (STBase.contains(DAG, BitSize: ST->getMemoryVT().getFixedSizeInBits(),
21262 Other: ChainBase,
21263 OtherBitSize: ST1->getMemoryVT().getFixedSizeInBits())) {
21264 CombineTo(N: ST1, Res: ST1->getChain());
21265 return SDValue(N, 0);
21266 }
21267 }
21268 }
21269 }
21270 }
21271
21272 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
21273 // truncating store. We can do this even if this is already a truncstore.
21274 if ((Value.getOpcode() == ISD::FP_ROUND ||
21275 Value.getOpcode() == ISD::TRUNCATE) &&
21276 Value->hasOneUse() && ST->isUnindexed() &&
21277 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
21278 MemVT: ST->getMemoryVT(), LegalOnly: LegalOperations)) {
21279 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0),
21280 Ptr, SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
21281 }
21282
21283 // Always perform this optimization before types are legal. If the target
21284 // prefers, also try this after legalization to catch stores that were created
21285 // by intrinsics or other nodes.
21286 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(MemVT: ST->getMemoryVT()))) {
21287 while (true) {
21288 // There can be multiple store sequences on the same chain.
21289 // Keep trying to merge store sequences until we are unable to do so
21290 // or until we merge the last store on the chain.
21291 bool Changed = mergeConsecutiveStores(St: ST);
21292 if (!Changed) break;
21293 // Return N as merge only uses CombineTo and no worklist clean
21294 // up is necessary.
21295 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(Val: N))
21296 return SDValue(N, 0);
21297 }
21298 }
21299
21300 // Try transforming N to an indexed store.
21301 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
21302 return SDValue(N, 0);
21303
  // Turn 'store float 1.0, Ptr' -> 'store int 0x3F800000, Ptr'
21305 //
21306 // Make sure to do this only after attempting to merge stores in order to
21307 // avoid changing the types of some subset of stores due to visit order,
21308 // preventing their merging.
21309 if (isa<ConstantFPSDNode>(Val: ST->getValue())) {
21310 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
21311 return NewSt;
21312 }
21313
21314 if (SDValue NewSt = splitMergedValStore(ST))
21315 return NewSt;
21316
21317 return ReduceLoadOpStoreWidth(N);
21318}
21319
21320SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
21321 const auto *LifetimeEnd = cast<LifetimeSDNode>(Val: N);
21322 if (!LifetimeEnd->hasOffset())
21323 return SDValue();
21324
21325 const BaseIndexOffset LifetimeEndBase(N->getOperand(Num: 1), SDValue(),
21326 LifetimeEnd->getOffset(), false);
21327
21328 // We walk up the chains to find stores.
21329 SmallVector<SDValue, 8> Chains = {N->getOperand(Num: 0)};
21330 while (!Chains.empty()) {
21331 SDValue Chain = Chains.pop_back_val();
21332 if (!Chain.hasOneUse())
21333 continue;
21334 switch (Chain.getOpcode()) {
21335 case ISD::TokenFactor:
21336 for (unsigned Nops = Chain.getNumOperands(); Nops;)
21337 Chains.push_back(Elt: Chain.getOperand(i: --Nops));
21338 break;
21339 case ISD::LIFETIME_START:
21340 case ISD::LIFETIME_END:
21341 // We can forward past any lifetime start/end that can be proven not to
21342 // alias the node.
21343 if (!mayAlias(Op0: Chain.getNode(), Op1: N))
21344 Chains.push_back(Elt: Chain.getOperand(i: 0));
21345 break;
21346 case ISD::STORE: {
21347 StoreSDNode *ST = dyn_cast<StoreSDNode>(Val&: Chain);
21348 // TODO: Can relax for unordered atomics (see D66309)
21349 if (!ST->isSimple() || ST->isIndexed())
21350 continue;
21351 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
21352 // The bounds of a scalable store are not known until runtime, so this
21353 // store cannot be elided.
21354 if (StoreSize.isScalable())
21355 continue;
21356 const BaseIndexOffset StoreBase = BaseIndexOffset::match(N: ST, DAG);
21357 // If we store purely within object bounds just before its lifetime ends,
21358 // we can remove the store.
21359 if (LifetimeEndBase.contains(DAG, BitSize: LifetimeEnd->getSize() * 8, Other: StoreBase,
21360 OtherBitSize: StoreSize.getFixedValue() * 8)) {
21361 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
21362 dbgs() << "\nwithin LIFETIME_END of : ";
21363 LifetimeEndBase.dump(); dbgs() << "\n");
21364 CombineTo(N: ST, Res: ST->getChain());
21365 return SDValue(N, 0);
21366 }
21367 }
21368 }
21369 }
21370 return SDValue();
21371}
21372
21373/// For the instruction sequence of store below, F and I values
21374/// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
21376/// which can remove the bitwise instructions or sink them to colder places.
21377///
21378/// (store (or (zext (bitcast F to i32) to i64),
21379/// (shl (zext I to i64), 32)), addr) -->
21380/// (store F, addr) and (store I, addr+4)
21381///
/// Similarly, splitting other merged stores can also be beneficial, like:
21383/// For pair of {i32, i32}, i64 store --> two i32 stores.
21384/// For pair of {i32, i16}, i64 store --> two i32 stores.
21385/// For pair of {i16, i16}, i32 store --> two i16 stores.
21386/// For pair of {i16, i8}, i32 store --> two i16 stores.
21387/// For pair of {i8, i8}, i16 store --> two i8 stores.
21388///
21389/// We allow each target to determine specifically which kind of splitting is
21390/// supported.
21391///
/// The store patterns are commonly seen from the simple code snippet below
/// if only std::make_pair(...) is SROA-transformed before being inlined into
/// hoo().
21394/// void goo(const std::pair<int, float> &);
21395/// hoo() {
21396/// ...
21397/// goo(std::make_pair(tmp, ftmp));
21398/// ...
21399/// }
21400///
21401SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
21402 if (OptLevel == CodeGenOptLevel::None)
21403 return SDValue();
21404
21405 // Can't change the number of memory accesses for a volatile store or break
21406 // atomicity for an atomic one.
21407 if (!ST->isSimple())
21408 return SDValue();
21409
21410 SDValue Val = ST->getValue();
21411 SDLoc DL(ST);
21412
21413 // Match OR operand.
21414 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
21415 return SDValue();
21416
21417 // Match SHL operand and get Lower and Higher parts of Val.
21418 SDValue Op1 = Val.getOperand(i: 0);
21419 SDValue Op2 = Val.getOperand(i: 1);
21420 SDValue Lo, Hi;
21421 if (Op1.getOpcode() != ISD::SHL) {
21422 std::swap(a&: Op1, b&: Op2);
21423 if (Op1.getOpcode() != ISD::SHL)
21424 return SDValue();
21425 }
21426 Lo = Op2;
21427 Hi = Op1.getOperand(i: 0);
21428 if (!Op1.hasOneUse())
21429 return SDValue();
21430
21431 // Match shift amount to HalfValBitSize.
21432 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
21433 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Val: Op1.getOperand(i: 1));
21434 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
21435 return SDValue();
21436
  // Lo and Hi must be zero-extended from integer types whose size is less
  // than or equal to HalfValBitSize.
21439 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
21440 !Lo.getOperand(i: 0).getValueType().isScalarInteger() ||
21441 Lo.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize ||
21442 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
21443 !Hi.getOperand(i: 0).getValueType().isScalarInteger() ||
21444 Hi.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize)
21445 return SDValue();
21446
21447 // Use the EVT of low and high parts before bitcast as the input
21448 // of target query.
21449 EVT LowTy = (Lo.getOperand(i: 0).getOpcode() == ISD::BITCAST)
21450 ? Lo.getOperand(i: 0).getValueType()
21451 : Lo.getValueType();
21452 EVT HighTy = (Hi.getOperand(i: 0).getOpcode() == ISD::BITCAST)
21453 ? Hi.getOperand(i: 0).getValueType()
21454 : Hi.getValueType();
21455 if (!TLI.isMultiStoresCheaperThanBitsMerge(LTy: LowTy, HTy: HighTy))
21456 return SDValue();
21457
21458 // Start to split store.
21459 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21460 AAMDNodes AAInfo = ST->getAAInfo();
21461
21462 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
21463 EVT VT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: HalfValBitSize);
21464 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Lo.getOperand(i: 0));
21465 Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Hi.getOperand(i: 0));
21466
21467 SDValue Chain = ST->getChain();
21468 SDValue Ptr = ST->getBasePtr();
21469 // Lower value store.
21470 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
21471 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21472 Ptr =
21473 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: HalfValBitSize / 8), DL);
21474 // Higher value store.
21475 SDValue St1 = DAG.getStore(
21476 Chain: St0, dl: DL, Val: Hi, Ptr, PtrInfo: ST->getPointerInfo().getWithOffset(O: HalfValBitSize / 8),
21477 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21478 return St1;
21479}
21480
21481// Merge an insertion into an existing shuffle:
21482// (insert_vector_elt (vector_shuffle X, Y, Mask),
21483// .(extract_vector_elt X, N), InsIndex)
21484// --> (vector_shuffle X, Y, NewMask)
21485// and variations where shuffle operands may be CONCAT_VECTORS.
21486static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
21487 SmallVectorImpl<int> &NewMask, SDValue Elt,
21488 unsigned InsIndex) {
21489 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21490 !isa<ConstantSDNode>(Val: Elt.getOperand(i: 1)))
21491 return false;
21492
21493 // Vec's operand 0 is using indices from 0 to N-1 and
21494 // operand 1 from N to 2N - 1, where N is the number of
21495 // elements in the vectors.
21496 SDValue InsertVal0 = Elt.getOperand(i: 0);
21497 int ElementOffset = -1;
21498
21499 // We explore the inputs of the shuffle in order to see if we find the
21500 // source of the extract_vector_elt. If so, we can use it to modify the
21501 // shuffle rather than perform an insert_vector_elt.
21502 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
21503 ArgWorkList.emplace_back(Args: Mask.size(), Args&: Y);
21504 ArgWorkList.emplace_back(Args: 0, Args&: X);
21505
21506 while (!ArgWorkList.empty()) {
21507 int ArgOffset;
21508 SDValue ArgVal;
21509 std::tie(args&: ArgOffset, args&: ArgVal) = ArgWorkList.pop_back_val();
21510
21511 if (ArgVal == InsertVal0) {
21512 ElementOffset = ArgOffset;
21513 break;
21514 }
21515
21516 // Peek through concat_vector.
21517 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
21518 int CurrentArgOffset =
21519 ArgOffset + ArgVal.getValueType().getVectorNumElements();
21520 int Step = ArgVal.getOperand(i: 0).getValueType().getVectorNumElements();
21521 for (SDValue Op : reverse(C: ArgVal->ops())) {
21522 CurrentArgOffset -= Step;
21523 ArgWorkList.emplace_back(Args&: CurrentArgOffset, Args&: Op);
21524 }
21525
21526 // Make sure we went through all the elements and did not screw up index
21527 // computation.
21528 assert(CurrentArgOffset == ArgOffset);
21529 }
21530 }
21531
21532 // If we failed to find a match, see if we can replace an UNDEF shuffle
21533 // operand.
21534 if (ElementOffset == -1) {
21535 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
21536 return false;
21537 ElementOffset = Mask.size();
21538 Y = InsertVal0;
21539 }
21540
21541 NewMask.assign(in_start: Mask.begin(), in_end: Mask.end());
21542 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(i: 1);
21543 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
21544 "NewMask[InsIndex] is out of bound");
21545 return true;
21546}
21547
21548// Merge an insertion into an existing shuffle:
21549// (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
21550// InsIndex)
21551// --> (vector_shuffle X, Y) and variations where shuffle operands may be
21552// CONCAT_VECTORS.
21553SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
  assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
         "Expected insert_vector_elt");
21556 SDValue InsertVal = N->getOperand(Num: 1);
21557 SDValue Vec = N->getOperand(Num: 0);
21558
21559 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: Vec);
21560 if (!SVN || !Vec.hasOneUse())
21561 return SDValue();
21562
21563 ArrayRef<int> Mask = SVN->getMask();
21564 SDValue X = Vec.getOperand(i: 0);
21565 SDValue Y = Vec.getOperand(i: 1);
21566
21567 SmallVector<int, 16> NewMask(Mask);
21568 if (mergeEltWithShuffle(X, Y, Mask, NewMask, Elt: InsertVal, InsIndex)) {
21569 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
21570 VT: Vec.getValueType(), DL: SDLoc(N), N0: X, N1: Y, Mask: NewMask, DAG);
21571 if (LegalShuffle)
21572 return LegalShuffle;
21573 }
21574
21575 return SDValue();
21576}
21577
21578// Convert a disguised subvector insertion into a shuffle:
21579// insert_vector_elt V, (bitcast X from vector type), IdxC -->
21580// bitcast(shuffle (bitcast V), (extended X), Mask)
21581// Note: We do not use an insert_subvector node because that requires a
21582// legal subvector type.
21583SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
  assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
         "Expected insert_vector_elt");
21586 SDValue InsertVal = N->getOperand(Num: 1);
21587
21588 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
21589 !InsertVal.getOperand(i: 0).getValueType().isVector())
21590 return SDValue();
21591
21592 SDValue SubVec = InsertVal.getOperand(i: 0);
21593 SDValue DestVec = N->getOperand(Num: 0);
21594 EVT SubVecVT = SubVec.getValueType();
21595 EVT VT = DestVec.getValueType();
21596 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
  // If the source only has a single vector element, the cost of creating and
  // adding it to a vector is likely to exceed the cost of an
  // insert_vector_elt.
21599 if (NumSrcElts == 1)
21600 return SDValue();
21601 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
21602 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
21603
21604 // Step 1: Create a shuffle mask that implements this insert operation. The
21605 // vector that we are inserting into will be operand 0 of the shuffle, so
21606 // those elements are just 'i'. The inserted subvector is in the first
21607 // positions of operand 1 of the shuffle. Example:
21608 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
21609 SmallVector<int, 16> Mask(NumMaskVals);
21610 for (unsigned i = 0; i != NumMaskVals; ++i) {
21611 if (i / NumSrcElts == InsIndex)
21612 Mask[i] = (i % NumSrcElts) + NumMaskVals;
21613 else
21614 Mask[i] = i;
21615 }
21616
21617 // Bail out if the target can not handle the shuffle we want to create.
21618 EVT SubVecEltVT = SubVecVT.getVectorElementType();
21619 EVT ShufVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SubVecEltVT, NumElements: NumMaskVals);
21620 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
21621 return SDValue();
21622
21623 // Step 2: Create a wide vector from the inserted source vector by appending
21624 // undefined elements. This is the same size as our destination vector.
21625 SDLoc DL(N);
21626 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(VT: SubVecVT));
21627 ConcatOps[0] = SubVec;
21628 SDValue PaddedSubV = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ShufVT, Ops: ConcatOps);
21629
21630 // Step 3: Shuffle in the padded subvector.
21631 SDValue DestVecBC = DAG.getBitcast(VT: ShufVT, V: DestVec);
21632 SDValue Shuf = DAG.getVectorShuffle(VT: ShufVT, dl: DL, N1: DestVecBC, N2: PaddedSubV, Mask);
21633 AddToWorklist(N: PaddedSubV.getNode());
21634 AddToWorklist(N: DestVecBC.getNode());
21635 AddToWorklist(N: Shuf.getNode());
21636 return DAG.getBitcast(VT, V: Shuf);
21637}
21638
21639// Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
21640// possible and the new load will be quick. We use more loads but less shuffles
21641// and inserts.
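// For example, with v4i32: if the vector operand is a load from p shuffled by
// <u,0,1,2> and the inserted scalar is a load from p-4, the whole node can
// become a single v4i32 load from p-4.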
21642SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
21643 EVT VT = N->getValueType(ResNo: 0);
21644
  // InsIndex is expected to be the first or the last lane.
21646 if (!VT.isFixedLengthVector() ||
21647 (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
21648 return SDValue();
21649
21650 // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
21651 // depending on the InsIndex.
21652 auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: 0));
21653 SDValue Scalar = N->getOperand(Num: 1);
21654 if (!Shuffle || !all_of(Range: enumerate(First: Shuffle->getMask()), P: [&](auto P) {
21655 return InsIndex == P.index() || P.value() < 0 ||
21656 (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
21657 (InsIndex == VT.getVectorNumElements() - 1 &&
21658 P.value() == (int)P.index() + 1);
21659 }))
21660 return SDValue();
21661
21662 // We optionally skip over an extend so long as both loads are extended in the
21663 // same way from the same type.
21664 unsigned Extend = 0;
21665 if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
21666 Scalar.getOpcode() == ISD::SIGN_EXTEND ||
21667 Scalar.getOpcode() == ISD::ANY_EXTEND) {
21668 Extend = Scalar.getOpcode();
21669 Scalar = Scalar.getOperand(i: 0);
21670 }
21671
21672 auto *ScalarLoad = dyn_cast<LoadSDNode>(Val&: Scalar);
21673 if (!ScalarLoad)
21674 return SDValue();
21675
21676 SDValue Vec = Shuffle->getOperand(Num: 0);
21677 if (Extend) {
21678 if (Vec.getOpcode() != Extend)
21679 return SDValue();
21680 Vec = Vec.getOperand(i: 0);
21681 }
21682 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: Vec);
21683 if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
21684 return SDValue();
21685
21686 int EltSize = ScalarLoad->getValueType(ResNo: 0).getScalarSizeInBits();
21687 if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
21688 !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
21689 ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
21690 ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
21691 return SDValue();
21692
  // Check that the offset between the pointers is such that the two loads
  // together form a single contiguous load.
21695 if (InsIndex == 0) {
21696 if (!DAG.areNonVolatileConsecutiveLoads(LD: ScalarLoad, Base: VecLoad, Bytes: EltSize / 8,
21697 Dist: -1))
21698 return SDValue();
21699 } else {
21700 if (!DAG.areNonVolatileConsecutiveLoads(
21701 LD: VecLoad, Base: ScalarLoad, Bytes: VT.getVectorNumElements() * EltSize / 8, Dist: -1))
21702 return SDValue();
21703 }
21704
21705 // And that the new unaligned load will be fast.
21706 unsigned IsFast = 0;
21707 Align NewAlign = commonAlignment(A: VecLoad->getAlign(), Offset: EltSize / 8);
21708 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
21709 VT: Vec.getValueType(), AddrSpace: VecLoad->getAddressSpace(),
21710 Alignment: NewAlign, Flags: VecLoad->getMemOperand()->getFlags(),
21711 Fast: &IsFast) ||
21712 !IsFast)
21713 return SDValue();
21714
21715 // Calculate the new Ptr and create the new load.
21716 SDLoc DL(N);
21717 SDValue Ptr = ScalarLoad->getBasePtr();
21718 if (InsIndex != 0)
21719 Ptr = DAG.getNode(Opcode: ISD::ADD, DL, VT: Ptr.getValueType(), N1: VecLoad->getBasePtr(),
21720 N2: DAG.getConstant(Val: EltSize / 8, DL, VT: Ptr.getValueType()));
21721 MachinePointerInfo PtrInfo =
21722 InsIndex == 0 ? ScalarLoad->getPointerInfo()
21723 : VecLoad->getPointerInfo().getWithOffset(O: EltSize / 8);
21724
21725 SDValue Load = DAG.getLoad(VT: VecLoad->getValueType(ResNo: 0), dl: DL,
21726 Chain: ScalarLoad->getChain(), Ptr, PtrInfo, Alignment: NewAlign);
21727 DAG.makeEquivalentMemoryOrdering(OldLoad: ScalarLoad, NewMemOp: Load.getValue(R: 1));
21728 DAG.makeEquivalentMemoryOrdering(OldLoad: VecLoad, NewMemOp: Load.getValue(R: 1));
21729 return Extend ? DAG.getNode(Opcode: Extend, DL, VT, Operand: Load) : Load;
21730}
21731
21732SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
21733 SDValue InVec = N->getOperand(Num: 0);
21734 SDValue InVal = N->getOperand(Num: 1);
21735 SDValue EltNo = N->getOperand(Num: 2);
21736 SDLoc DL(N);
21737
21738 EVT VT = InVec.getValueType();
21739 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: EltNo);
21740
21741 // Inserting into an out-of-bounds element is undefined.
21742 if (IndexC && VT.isFixedLengthVector() &&
21743 IndexC->getZExtValue() >= VT.getVectorNumElements())
21744 return DAG.getUNDEF(VT);
21745
21746 // Remove redundant insertions:
21747 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
21748 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21749 InVec == InVal.getOperand(i: 0) && EltNo == InVal.getOperand(i: 1))
21750 return InVec;
21751
21752 if (!IndexC) {
21753 // If this is a variable insert into an undef vector, it might be better to splat:
21754 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
21755 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
21756 return DAG.getSplat(VT, DL, Op: InVal);
21757 return SDValue();
21758 }
21759
21760 if (VT.isScalableVector())
21761 return SDValue();
21762
21763 unsigned NumElts = VT.getVectorNumElements();
21764
21765 // We must know which element is being inserted for folds below here.
21766 unsigned Elt = IndexC->getZExtValue();
21767
21768 // Handle <1 x ???> vector insertion special cases.
21769 if (NumElts == 1) {
21770 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
21771 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21772 InVal.getOperand(i: 0).getValueType() == VT &&
21773 isNullConstant(V: InVal.getOperand(i: 1)))
21774 return InVal.getOperand(i: 0);
21775 }
21776
21777 // Canonicalize insert_vector_elt dag nodes.
21778 // Example:
21779 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
21780 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
21781 //
21782 // Do this only if the child insert_vector_elt node has one use; also
21783 // do this only if indices are both constants and Idx1 < Idx0.
21784 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
21785 && isa<ConstantSDNode>(Val: InVec.getOperand(i: 2))) {
21786 unsigned OtherElt = InVec.getConstantOperandVal(i: 2);
21787 if (Elt < OtherElt) {
21788 // Swap nodes.
21789 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL, VT,
21790 N1: InVec.getOperand(i: 0), N2: InVal, N3: EltNo);
21791 AddToWorklist(N: NewOp.getNode());
21792 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(InVec.getNode()),
21793 VT, N1: NewOp, N2: InVec.getOperand(i: 1), N3: InVec.getOperand(i: 2));
21794 }
21795 }
21796
21797 if (SDValue Shuf = mergeInsertEltWithShuffle(N, InsIndex: Elt))
21798 return Shuf;
21799
21800 if (SDValue Shuf = combineInsertEltToShuffle(N, InsIndex: Elt))
21801 return Shuf;
21802
21803 if (SDValue Shuf = combineInsertEltToLoad(N, InsIndex: Elt))
21804 return Shuf;
21805
21806 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
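// For illustration only (a hypothetical v4i32 case; a and b are placeholder
// scalars):
//   (insert_vector_elt (insert_vector_elt undef, a, 0), b, 1)
//     --> BUILD_VECTOR a, b, undef, undef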
21807 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) {
21808 // <1 x ???> vector - we don't need to recurse.
21809 if (NumElts == 1)
21810 return DAG.getBuildVector(VT, DL, Ops: {InVal});
21811
21812 // If we haven't already collected the element, insert into the op list.
21813 EVT MaxEltVT = InVal.getValueType();
21814 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
21815 unsigned Idx) {
21816 if (!Ops[Idx]) {
21817 Ops[Idx] = Elt;
21818 if (VT.isInteger()) {
21819 EVT EltVT = Elt.getValueType();
21820 MaxEltVT = MaxEltVT.bitsGE(VT: EltVT) ? MaxEltVT : EltVT;
21821 }
21822 }
21823 };
21824
21825 // Ensure all the operands are the same value type, fill any missing
21826 // operands with UNDEF and create the BUILD_VECTOR.
21827 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
21828 assert(Ops.size() == NumElts && "Unexpected vector size");
21829 for (SDValue &Op : Ops) {
21830 if (Op)
21831 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, VT: MaxEltVT) : Op;
21832 else
21833 Op = DAG.getUNDEF(VT: MaxEltVT);
21834 }
21835 return DAG.getBuildVector(VT, DL, Ops);
21836 };
21837
21838 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
21839 Ops[Elt] = InVal;
21840
21841 // Recurse up an INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
21842 for (SDValue CurVec = InVec; CurVec;) {
21843 // UNDEF - build new BUILD_VECTOR from already inserted operands.
21844 if (CurVec.isUndef())
21845 return CanonicalizeBuildVector(Ops);
21846
21847 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
21848 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
21849 for (unsigned I = 0; I != NumElts; ++I)
21850 AddBuildVectorOp(Ops, CurVec.getOperand(i: I), I);
21851 return CanonicalizeBuildVector(Ops);
21852 }
21853
21854 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
21855 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
21856 AddBuildVectorOp(Ops, CurVec.getOperand(i: 0), 0);
21857 return CanonicalizeBuildVector(Ops);
21858 }
21859
21860 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
21861 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
21862 if (auto *CurIdx = dyn_cast<ConstantSDNode>(Val: CurVec.getOperand(i: 2)))
21863 if (CurIdx->getAPIntValue().ult(RHS: NumElts)) {
21864 unsigned Idx = CurIdx->getZExtValue();
21865 AddBuildVectorOp(Ops, CurVec.getOperand(i: 1), Idx);
21866
21867 // Found entire BUILD_VECTOR.
21868 if (all_of(Range&: Ops, P: [](SDValue Op) { return !!Op; }))
21869 return CanonicalizeBuildVector(Ops);
21870
21871 CurVec = CurVec->getOperand(Num: 0);
21872 continue;
21873 }
21874
21875 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
21876 // update the shuffle mask (and second operand if we started with unary
21877 // shuffle) and create a new legal shuffle.
21878 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
21879 auto *SVN = cast<ShuffleVectorSDNode>(Val&: CurVec);
21880 SDValue LHS = SVN->getOperand(Num: 0);
21881 SDValue RHS = SVN->getOperand(Num: 1);
21882 SmallVector<int, 16> Mask(SVN->getMask());
21883 bool Merged = true;
21884 for (auto I : enumerate(First&: Ops)) {
21885 SDValue &Op = I.value();
21886 if (Op) {
21887 SmallVector<int, 16> NewMask;
21888 if (!mergeEltWithShuffle(X&: LHS, Y&: RHS, Mask, NewMask, Elt: Op, InsIndex: I.index())) {
21889 Merged = false;
21890 break;
21891 }
21892 Mask = std::move(NewMask);
21893 }
21894 }
21895 if (Merged)
21896 if (SDValue NewShuffle =
21897 TLI.buildLegalVectorShuffle(VT, DL, N0: LHS, N1: RHS, Mask, DAG))
21898 return NewShuffle;
21899 }
21900
21901 // If all insertions are zero value, try to convert to AND mask.
21902 // TODO: Do this for -1 with OR mask?
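// For illustration only (a hypothetical v4i32 case; x is a placeholder):
//   (insert_vector_elt (insert_vector_elt x, 0, 1), 0, 3)
//     --> (and x, <-1, 0, -1, 0>)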
21903 if (!LegalOperations && llvm::isNullConstant(V: InVal) &&
21904 all_of(Range&: Ops, P: [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
21905 count_if(Range&: Ops, P: [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
21906 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: MaxEltVT);
21907 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: MaxEltVT);
21908 SmallVector<SDValue, 8> Mask(NumElts);
21909 for (unsigned I = 0; I != NumElts; ++I)
21910 Mask[I] = Ops[I] ? Zero : AllOnes;
21911 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CurVec,
21912 N2: DAG.getBuildVector(VT, DL, Ops: Mask));
21913 }
21914
21915 // Failed to find a match in the chain - bail.
21916 break;
21917 }
21918
21919 // See if we can fill in the missing constant elements as zeros.
21920 // TODO: Should we do this for any constant?
21921 APInt DemandedZeroElts = APInt::getZero(numBits: NumElts);
21922 for (unsigned I = 0; I != NumElts; ++I)
21923 if (!Ops[I])
21924 DemandedZeroElts.setBit(I);
21925
21926 if (DAG.MaskedVectorIsZero(Op: InVec, DemandedElts: DemandedZeroElts)) {
21927 SDValue Zero = VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT: MaxEltVT)
21928 : DAG.getConstantFP(Val: 0, DL, VT: MaxEltVT);
21929 for (unsigned I = 0; I != NumElts; ++I)
21930 if (!Ops[I])
21931 Ops[I] = Zero;
21932
21933 return CanonicalizeBuildVector(Ops);
21934 }
21935 }
21936
21937 return SDValue();
21938}
21939
21940SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
21941 SDValue EltNo,
21942 LoadSDNode *OriginalLoad) {
21943 assert(OriginalLoad->isSimple());
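// A sketch of the transform performed below (assuming a constant index and a
// byte-sized element type; @A is a placeholder address):
//   (extract_vector_elt (v4i32 load @A), 2) --> (i32 load @A+8)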
21944
21945 EVT ResultVT = EVE->getValueType(ResNo: 0);
21946 EVT VecEltVT = InVecVT.getVectorElementType();
21947
21948 // If the vector element type is not a multiple of a byte then we are unable
21949 // to correctly compute an address to load only the extracted element as a
21950 // scalar.
21951 if (!VecEltVT.isByteSized())
21952 return SDValue();
21953
21954 ISD::LoadExtType ExtTy =
21955 ResultVT.bitsGT(VT: VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
21956 if (!TLI.isOperationLegalOrCustom(Op: ISD::LOAD, VT: VecEltVT) ||
21957 !TLI.shouldReduceLoadWidth(Load: OriginalLoad, ExtTy, NewVT: VecEltVT))
21958 return SDValue();
21959
21960 Align Alignment = OriginalLoad->getAlign();
21961 MachinePointerInfo MPI;
21962 SDLoc DL(EVE);
21963 if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(Val&: EltNo)) {
21964 int Elt = ConstEltNo->getZExtValue();
21965 unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
21966 MPI = OriginalLoad->getPointerInfo().getWithOffset(O: PtrOff);
21967 Alignment = commonAlignment(A: Alignment, Offset: PtrOff);
21968 } else {
21969 // Discard the pointer info except the address space because the memory
21970 // operand can't represent this new access since the offset is variable.
21971 MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
21972 Alignment = commonAlignment(A: Alignment, Offset: VecEltVT.getSizeInBits() / 8);
21973 }
21974
21975 unsigned IsFast = 0;
21976 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: VecEltVT,
21977 AddrSpace: OriginalLoad->getAddressSpace(), Alignment,
21978 Flags: OriginalLoad->getMemOperand()->getFlags(),
21979 Fast: &IsFast) ||
21980 !IsFast)
21981 return SDValue();
21982
21983 SDValue NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: OriginalLoad->getBasePtr(),
21984 VecVT: InVecVT, Index: EltNo);
21985
21986 // We are replacing a vector load with a scalar load. The new load must have
21987 // identical memory op ordering to the original.
21988 SDValue Load;
21989 if (ResultVT.bitsGT(VT: VecEltVT)) {
21990 // If the result type of vextract is wider than the load, then issue an
21991 // extending load instead.
21992 ISD::LoadExtType ExtType =
21993 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ResultVT, MemVT: VecEltVT) ? ISD::ZEXTLOAD
21994 : ISD::EXTLOAD;
21995 Load = DAG.getExtLoad(ExtType, dl: DL, VT: ResultVT, Chain: OriginalLoad->getChain(),
21996 Ptr: NewPtr, PtrInfo: MPI, MemVT: VecEltVT, Alignment,
21997 MMOFlags: OriginalLoad->getMemOperand()->getFlags(),
21998 AAInfo: OriginalLoad->getAAInfo());
21999 DAG.makeEquivalentMemoryOrdering(OldLoad: OriginalLoad, NewMemOp: Load);
22000 } else {
22001 // The result type is narrower or the same width as the vector element
22002 Load = DAG.getLoad(VT: VecEltVT, dl: DL, Chain: OriginalLoad->getChain(), Ptr: NewPtr, PtrInfo: MPI,
22003 Alignment, MMOFlags: OriginalLoad->getMemOperand()->getFlags(),
22004 AAInfo: OriginalLoad->getAAInfo());
22005 DAG.makeEquivalentMemoryOrdering(OldLoad: OriginalLoad, NewMemOp: Load);
22006 if (ResultVT.bitsLT(VT: VecEltVT))
22007 Load = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResultVT, Operand: Load);
22008 else
22009 Load = DAG.getBitcast(VT: ResultVT, V: Load);
22010 }
22011 ++OpsNarrowed;
22012 return Load;
22013}
22014
22015/// Transform a vector binary operation into a scalar binary operation by moving
22016/// the math/logic after an extract element of a vector.
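/// For example (a hypothetical case, assuming the target allows scalarizing
/// the binop; X and the constants are placeholders):
///   (extract_vector_elt (add X, <1,2,3,4>), 1)
///     --> (add (extract_vector_elt X, 1), 2)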
22017static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
22018 bool LegalOperations) {
22019 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22020 SDValue Vec = ExtElt->getOperand(Num: 0);
22021 SDValue Index = ExtElt->getOperand(Num: 1);
22022 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
22023 if (!IndexC || !TLI.isBinOp(Opcode: Vec.getOpcode()) || !Vec.hasOneUse() ||
22024 Vec->getNumValues() != 1)
22025 return SDValue();
22026
22027 // Targets may want to avoid this to prevent an expensive register transfer.
22028 if (!TLI.shouldScalarizeBinop(VecOp: Vec))
22029 return SDValue();
22030
22031 // Extracting an element of a vector constant is constant-folded, so this
22032 // transform is just replacing a vector op with a scalar op while moving the
22033 // extract.
22034 SDValue Op0 = Vec.getOperand(i: 0);
22035 SDValue Op1 = Vec.getOperand(i: 1);
22036 APInt SplatVal;
22037 if (isAnyConstantBuildVector(V: Op0, NoOpaques: true) ||
22038 ISD::isConstantSplatVector(N: Op0.getNode(), SplatValue&: SplatVal) ||
22039 isAnyConstantBuildVector(V: Op1, NoOpaques: true) ||
22040 ISD::isConstantSplatVector(N: Op1.getNode(), SplatValue&: SplatVal)) {
22041 // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22042 // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22043 SDLoc DL(ExtElt);
22044 EVT VT = ExtElt->getValueType(ResNo: 0);
22045 SDValue Ext0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: Op0, N2: Index);
22046 SDValue Ext1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: Op1, N2: Index);
22047 return DAG.getNode(Opcode: Vec.getOpcode(), DL, VT, N1: Ext0, N2: Ext1);
22048 }
22049
22050 return SDValue();
22051}
22052
22053 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
22054 // recursively analyse all of its users and try to model them as bit sequence
22055 // extractions. If all of them agree on the new, narrower element type, and all
22056 // of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that new element
22057 // type, do so now.
22058 // This is mainly useful to recover from legalization that scalarized the
22059 // vector as wide elements, by rebuilding it with narrower elements.
22060//
22061// Some more nodes could be modelled if that helps cover interesting patterns.
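// For illustration only (a hypothetical little-endian v4i32 case; t0..t3 are
// placeholders):
//   t0: i32 = extract_vector_elt v4i32:x, 0
//   t1: i16 = truncate t0
//   t2: i32 = srl t0, 16
//   t3: i16 = truncate t2
// If t1 and t3 only feed BUILD_VECTORs, they can be rebuilt as
// extract_vector_elt (v8i16 bitcast x) at indices 0 and 1 respectively.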
22062bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
22063 SDNode *N) {
22064 // We perform this optimization post type-legalization because
22065 // the type-legalizer often scalarizes integer-promoted vectors.
22066 // Performing this optimization before may cause legalization cycles.
22067 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22068 return false;
22069
22070 // TODO: Add support for big-endian.
22071 if (DAG.getDataLayout().isBigEndian())
22072 return false;
22073
22074 SDValue VecOp = N->getOperand(Num: 0);
22075 EVT VecVT = VecOp.getValueType();
22076 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
22077
22078 // We must start with a constant extraction index.
22079 auto *IndexC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
22080 if (!IndexC)
22081 return false;
22082
22083 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
22084 "Original ISD::EXTRACT_VECTOR_ELT is undefined?");
22085
22086 // TODO: deal with the case of implicit anyext of the extraction.
22087 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22088 EVT ScalarVT = N->getValueType(ResNo: 0);
22089 if (VecVT.getScalarType() != ScalarVT)
22090 return false;
22091
22092 // TODO: deal with the cases other than everything being integer-typed.
22093 if (!ScalarVT.isScalarInteger())
22094 return false;
22095
22096 struct Entry {
22097 SDNode *Producer;
22098
22099 // Which bits of VecOp does it contain?
22100 unsigned BitPos;
22101 int NumBits;
22102 // NOTE: the actual width of \p Producer may be wider than NumBits!
22103
22104 Entry(Entry &&) = default;
22105 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
22106 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
22107
22108 Entry() = delete;
22109 Entry(const Entry &) = delete;
22110 Entry &operator=(const Entry &) = delete;
22111 Entry &operator=(Entry &&) = delete;
22112 };
22113 SmallVector<Entry, 32> Worklist;
22114 SmallVector<Entry, 32> Leafs;
22115
22116 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
22117 Worklist.emplace_back(Args&: N, /*BitPos=*/Args: VecEltBitWidth * IndexC->getZExtValue(),
22118 /*NumBits=*/Args&: VecEltBitWidth);
22119
22120 while (!Worklist.empty()) {
22121 Entry E = Worklist.pop_back_val();
22122 // Does the node not even use any of the VecOp bits?
22123 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
22124 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
22125 return false; // Let the other combines clean this up first.
22126 // Did we fail to model any of the users of the Producer?
22127 bool ProducerIsLeaf = false;
22128 // Look at each user of this Producer.
22129 for (SDNode *User : E.Producer->uses()) {
22130 switch (User->getOpcode()) {
22131 // TODO: support ISD::BITCAST
22132 // TODO: support ISD::ANY_EXTEND
22133 // TODO: support ISD::ZERO_EXTEND
22134 // TODO: support ISD::SIGN_EXTEND
22135 case ISD::TRUNCATE:
22136 // Truncation simply means we keep position, but extract fewer bits.
22137 Worklist.emplace_back(Args&: User, Args&: E.BitPos,
22138 /*NumBits=*/Args: User->getValueSizeInBits(ResNo: 0));
22139 break;
22140 // TODO: support ISD::SRA
22141 // TODO: support ISD::SHL
22142 case ISD::SRL:
22143 // We should be shifting the Producer by a constant amount.
22144 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
22145 User->getOperand(Num: 0).getNode() == E.Producer && ShAmtC) {
22146 // Logical right-shift means that we start extraction later,
22147 // but stop it at the same position we did previously.
22148 unsigned ShAmt = ShAmtC->getZExtValue();
22149 Worklist.emplace_back(Args&: User, Args: E.BitPos + ShAmt, Args: E.NumBits - ShAmt);
22150 break;
22151 }
22152 [[fallthrough]];
22153 default:
22154 // We cannot model this user of the Producer.
22155 // Which means the current Producer will be an ISD::EXTRACT_VECTOR_ELT.
22156 ProducerIsLeaf = true;
22157 // Profitability check: all users that we can not model
22158 // must be ISD::BUILD_VECTOR's.
22159 if (User->getOpcode() != ISD::BUILD_VECTOR)
22160 return false;
22161 break;
22162 }
22163 }
22164 if (ProducerIsLeaf)
22165 Leafs.emplace_back(Args: std::move(E));
22166 }
22167
22168 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
22169
22170 // If we are still at the same element granularity, give up.
22171 if (NewVecEltBitWidth == VecEltBitWidth)
22172 return false;
22173
22174 // The vector width must be a multiple of the new element width.
22175 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
22176 return false;
22177
22178 // All leafs must agree on the new element width.
22179 // All leafs must not expect any "padding" bits on top of that width.
22180 // All leafs must start extraction from a multiple of that width.
22181 if (!all_of(Range&: Leafs, P: [NewVecEltBitWidth](const Entry &E) {
22182 return (unsigned)E.NumBits == NewVecEltBitWidth &&
22183 E.Producer->getValueSizeInBits(ResNo: 0) == NewVecEltBitWidth &&
22184 E.BitPos % NewVecEltBitWidth == 0;
22185 }))
22186 return false;
22187
22188 EVT NewScalarVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewVecEltBitWidth);
22189 EVT NewVecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarVT,
22190 NumElements: VecVT.getSizeInBits() / NewVecEltBitWidth);
22191
22192 if (LegalTypes &&
22193 !(TLI.isTypeLegal(VT: NewScalarVT) && TLI.isTypeLegal(VT: NewVecVT)))
22194 return false;
22195
22196 if (LegalOperations &&
22197 !(TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: NewVecVT) &&
22198 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: NewVecVT)))
22199 return false;
22200
22201 SDValue NewVecOp = DAG.getBitcast(VT: NewVecVT, V: VecOp);
22202 for (const Entry &E : Leafs) {
22203 SDLoc DL(E.Producer);
22204 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
22205 assert(NewIndex < NewVecVT.getVectorNumElements() &&
22206 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
22207 SDValue V = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: NewScalarVT, N1: NewVecOp,
22208 N2: DAG.getVectorIdxConstant(Val: NewIndex, DL));
22209 CombineTo(N: E.Producer, Res: V);
22210 }
22211
22212 return true;
22213}
22214
22215SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
22216 SDValue VecOp = N->getOperand(Num: 0);
22217 SDValue Index = N->getOperand(Num: 1);
22218 EVT ScalarVT = N->getValueType(ResNo: 0);
22219 EVT VecVT = VecOp.getValueType();
22220 if (VecOp.isUndef())
22221 return DAG.getUNDEF(VT: ScalarVT);
22222
22223 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
22224 //
22225 // This only really matters if the index is non-constant since other combines
22226 // on the constant elements already work.
22227 SDLoc DL(N);
22228 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
22229 Index == VecOp.getOperand(i: 2)) {
22230 SDValue Elt = VecOp.getOperand(i: 1);
22231 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Op: Elt, DL, VT: ScalarVT) : Elt;
22232 }
22233
22234 // (vextract (scalar_to_vector val), 0) -> val
22235 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22236 // Only 0'th element of SCALAR_TO_VECTOR is defined.
22237 if (DAG.isKnownNeverZero(Op: Index))
22238 return DAG.getUNDEF(VT: ScalarVT);
22239
22240 // Check if the result type doesn't match the inserted element type.
22241 // The inserted element and extracted element may have mismatched bitwidth.
22242 // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted value.
22243 SDValue InOp = VecOp.getOperand(i: 0);
22244 if (InOp.getValueType() != ScalarVT) {
22245 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22246 if (InOp.getValueType().bitsGT(VT: ScalarVT))
22247 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ScalarVT, Operand: InOp);
22248 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: ScalarVT, Operand: InOp);
22249 }
22250 return InOp;
22251 }
22252
22253 // extract_vector_elt of out-of-bounds element -> UNDEF
22254 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
22255 if (IndexC && VecVT.isFixedLengthVector() &&
22256 IndexC->getAPIntValue().uge(RHS: VecVT.getVectorNumElements()))
22257 return DAG.getUNDEF(VT: ScalarVT);
22258
22259 // extract_vector_elt(freeze(x)), idx -> freeze(extract_vector_elt(x)), idx
22260 if (VecOp.hasOneUse() && VecOp.getOpcode() == ISD::FREEZE) {
22261 return DAG.getFreeze(V: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT,
22262 N1: VecOp.getOperand(i: 0), N2: Index));
22263 }
22264
22265 // extract_vector_elt (build_vector x, y), 1 -> y
22266 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
22267 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
22268 TLI.isTypeLegal(VT: VecVT)) {
22269 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
22270 VecVT.isFixedLengthVector()) &&
22271 "BUILD_VECTOR used for scalable vectors");
22272 unsigned IndexVal =
22273 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
22274 SDValue Elt = VecOp.getOperand(i: IndexVal);
22275 EVT InEltVT = Elt.getValueType();
22276
22277 if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
22278 isNullConstant(V: Elt)) {
22279 // Sometimes build_vector's scalar input types do not match result type.
22280 if (ScalarVT == InEltVT)
22281 return Elt;
22282
22283 // TODO: It may be useful to truncate if free if the build_vector
22284 // implicitly converts.
22285 }
22286 }
22287
22288 if (SDValue BO = scalarizeExtractedBinop(ExtElt: N, DAG, LegalOperations))
22289 return BO;
22290
22291 if (VecVT.isScalableVector())
22292 return SDValue();
22293
22294 // All the code from this point onwards assumes fixed width vectors, but it's
22295 // possible that some of the combinations could be made to work for scalable
22296 // vectors too.
22297 unsigned NumElts = VecVT.getVectorNumElements();
22298 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22299
22300 // See if the extracted element is constant, in which case fold it if it's
22301 // a legal fp immediate.
22302 if (IndexC && ScalarVT.isFloatingPoint()) {
22303 APInt EltMask = APInt::getOneBitSet(numBits: NumElts, BitNo: IndexC->getZExtValue());
22304 KnownBits KnownElt = DAG.computeKnownBits(Op: VecOp, DemandedElts: EltMask);
22305 if (KnownElt.isConstant()) {
22306 APFloat CstFP =
22307 APFloat(DAG.EVTToAPFloatSemantics(VT: ScalarVT), KnownElt.getConstant());
22308 if (TLI.isFPImmLegal(CstFP, ScalarVT))
22309 return DAG.getConstantFP(Val: CstFP, DL, VT: ScalarVT);
22310 }
22311 }
22312
22313 // TODO: These transforms should not require the 'hasOneUse' restriction, but
22314 // there are regressions on multiple targets without it. We can end up with a
22315 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
22316 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
22317 VecOp.hasOneUse()) {
22318 // The vector index of the LSBs of the source depends on the endianness.
22319 bool IsLE = DAG.getDataLayout().isLittleEndian();
22320 unsigned ExtractIndex = IndexC->getZExtValue();
22321 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
22322 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
22323 SDValue BCSrc = VecOp.getOperand(i: 0);
22324 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
22325 return DAG.getAnyExtOrTrunc(Op: BCSrc, DL, VT: ScalarVT);
22326
22327 if (LegalTypes && BCSrc.getValueType().isInteger() &&
22328 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22329 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
22330 // trunc i64 X to i32
22331 SDValue X = BCSrc.getOperand(i: 0);
22332 assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
22333 "Extract element and scalar to vector can't change element type "
22334 "from FP to integer.");
22335 unsigned XBitWidth = X.getValueSizeInBits();
22336 BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
22337
22338 // An extract element return value type can be wider than its vector
22339 // operand element type. In that case, the high bits are undefined, so
22340 // it's possible that we may need to extend rather than truncate.
22341 if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
22342 assert(XBitWidth % VecEltBitWidth == 0 &&
22343 "Scalar bitwidth must be a multiple of vector element bitwidth");
22344 return DAG.getAnyExtOrTrunc(Op: X, DL, VT: ScalarVT);
22345 }
22346 }
22347 }
22348
22349 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
22350 // We only perform this optimization before the op legalization phase because
22351 // we may introduce new vector instructions which are not backed by TD
22352 // patterns. For example, on AVX, extracting elements from a wide vector
22353 // without going through extract_subvector would need such patterns. However,
22354 // if we can find an underlying scalar value, then we can always use that.
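// For illustration only (a hypothetical v4i32 case; t1 and t2 are
// placeholders):
//   (extract_vector_elt (vector_shuffle<2,u,u,u> t1, t2), 0)
//     --> (extract_vector_elt t1, 2)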
22355 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
22356 auto *Shuf = cast<ShuffleVectorSDNode>(Val&: VecOp);
22357 // Find the new index to extract from.
22358 int OrigElt = Shuf->getMaskElt(Idx: IndexC->getZExtValue());
22359
22360 // Extracting an undef index is undef.
22361 if (OrigElt == -1)
22362 return DAG.getUNDEF(VT: ScalarVT);
22363
22364 // Select the right vector half to extract from.
22365 SDValue SVInVec;
22366 if (OrigElt < (int)NumElts) {
22367 SVInVec = VecOp.getOperand(i: 0);
22368 } else {
22369 SVInVec = VecOp.getOperand(i: 1);
22370 OrigElt -= NumElts;
22371 }
22372
22373 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
22374 SDValue InOp = SVInVec.getOperand(i: OrigElt);
22375 if (InOp.getValueType() != ScalarVT) {
22376 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22377 InOp = DAG.getSExtOrTrunc(Op: InOp, DL, VT: ScalarVT);
22378 }
22379
22380 return InOp;
22381 }
22382
22383 // FIXME: We should handle recursing on other vector shuffles and
22384 // scalar_to_vector here as well.
22385
22386 if (!LegalOperations ||
22387 // FIXME: Should really be just isOperationLegalOrCustom.
22388 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecVT) ||
22389 TLI.isOperationExpand(Op: ISD::VECTOR_SHUFFLE, VT: VecVT)) {
22390 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: SVInVec,
22391 N2: DAG.getVectorIdxConstant(Val: OrigElt, DL));
22392 }
22393 }
22394
22395 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
22396 // simplify it based on the (valid) extraction indices.
22397 if (llvm::all_of(Range: VecOp->uses(), P: [&](SDNode *Use) {
22398 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22399 Use->getOperand(Num: 0) == VecOp &&
22400 isa<ConstantSDNode>(Val: Use->getOperand(Num: 1));
22401 })) {
22402 APInt DemandedElts = APInt::getZero(numBits: NumElts);
22403 for (SDNode *Use : VecOp->uses()) {
22404 auto *CstElt = cast<ConstantSDNode>(Val: Use->getOperand(Num: 1));
22405 if (CstElt->getAPIntValue().ult(RHS: NumElts))
22406 DemandedElts.setBit(CstElt->getZExtValue());
22407 }
22408 if (SimplifyDemandedVectorElts(Op: VecOp, DemandedElts, AssumeSingleUse: true)) {
22409 // We simplified the vector operand of this extract element. If this
22410 // extract is not dead, visit it again so it is folded properly.
22411 if (N->getOpcode() != ISD::DELETED_NODE)
22412 AddToWorklist(N);
22413 return SDValue(N, 0);
22414 }
22415 APInt DemandedBits = APInt::getAllOnes(numBits: VecEltBitWidth);
22416 if (SimplifyDemandedBits(Op: VecOp, DemandedBits, DemandedElts, AssumeSingleUse: true)) {
22417 // We simplified the vector operand of this extract element. If this
22418 // extract is not dead, visit it again so it is folded properly.
22419 if (N->getOpcode() != ISD::DELETED_NODE)
22420 AddToWorklist(N);
22421 return SDValue(N, 0);
22422 }
22423 }
22424
22425 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
22426 return SDValue(N, 0);
22427
22428 // Everything under here is trying to match an extract of a loaded value.
22429 // If the result of the load has to be truncated, then it's not necessarily
22430 // profitable.
22431 bool BCNumEltsChanged = false;
22432 EVT ExtVT = VecVT.getVectorElementType();
22433 EVT LVT = ExtVT;
22434 if (ScalarVT.bitsLT(VT: LVT) && !TLI.isTruncateFree(FromVT: LVT, ToVT: ScalarVT))
22435 return SDValue();
22436
22437 if (VecOp.getOpcode() == ISD::BITCAST) {
22438 // Don't duplicate a load with other uses.
22439 if (!VecOp.hasOneUse())
22440 return SDValue();
22441
22442 EVT BCVT = VecOp.getOperand(i: 0).getValueType();
22443 if (!BCVT.isVector() || ExtVT.bitsGT(VT: BCVT.getVectorElementType()))
22444 return SDValue();
22445 if (NumElts != BCVT.getVectorNumElements())
22446 BCNumEltsChanged = true;
22447 VecOp = VecOp.getOperand(i: 0);
22448 ExtVT = BCVT.getVectorElementType();
22449 }
22450
22451 // extract (vector load $addr), i --> load $addr + i * size
22452 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
22453 ISD::isNormalLoad(N: VecOp.getNode()) &&
22454 !Index->hasPredecessor(N: VecOp.getNode())) {
22455 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: VecOp);
22456 if (VecLoad && VecLoad->isSimple())
22457 return scalarizeExtractedVectorLoad(EVE: N, InVecVT: VecVT, EltNo: Index, OriginalLoad: VecLoad);
22458 }
22459
22460 // Perform only after legalization to ensure build_vector / vector_shuffle
22461 // optimizations have already been done.
22462 if (!LegalOperations || !IndexC)
22463 return SDValue();
22464
22465 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
22466 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
22467 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
22468 int Elt = IndexC->getZExtValue();
22469 LoadSDNode *LN0 = nullptr;
22470 if (ISD::isNormalLoad(N: VecOp.getNode())) {
22471 LN0 = cast<LoadSDNode>(Val&: VecOp);
22472 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22473 VecOp.getOperand(i: 0).getValueType() == ExtVT &&
22474 ISD::isNormalLoad(N: VecOp.getOperand(i: 0).getNode())) {
22475 // Don't duplicate a load with other uses.
22476 if (!VecOp.hasOneUse())
22477 return SDValue();
22478
22479 LN0 = cast<LoadSDNode>(Val: VecOp.getOperand(i: 0));
22480 }
22481 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(Val&: VecOp)) {
22482 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
22483 // =>
22484 // (load $addr+1*size)
22485
22486 // Don't duplicate a load with other uses.
22487 if (!VecOp.hasOneUse())
22488 return SDValue();
22489
22490 // If the bit convert changed the number of elements, it is unsafe
22491 // to examine the mask.
22492 if (BCNumEltsChanged)
22493 return SDValue();
22494
22495 // Select the input vector, guarding against an out-of-range extract index.
22496 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Idx: Elt);
22497 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(i: 0) : VecOp.getOperand(i: 1);
22498
22499 if (VecOp.getOpcode() == ISD::BITCAST) {
22500 // Don't duplicate a load with other uses.
22501 if (!VecOp.hasOneUse())
22502 return SDValue();
22503
22504 VecOp = VecOp.getOperand(i: 0);
22505 }
22506 if (ISD::isNormalLoad(N: VecOp.getNode())) {
22507 LN0 = cast<LoadSDNode>(Val&: VecOp);
22508 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
22509 Index = DAG.getConstant(Val: Elt, DL, VT: Index.getValueType());
22510 }
22511 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
22512 VecVT.getVectorElementType() == ScalarVT &&
22513 (!LegalTypes ||
22514 TLI.isTypeLegal(
22515 VT: VecOp.getOperand(i: 0).getValueType().getVectorElementType()))) {
22516 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
22517 // -> extract_vector_elt a, 0
22518 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
22519 // -> extract_vector_elt a, 1
22520 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
22521 // -> extract_vector_elt b, 0
22522 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
22523 // -> extract_vector_elt b, 1
22524 SDLoc SL(N);
22525 EVT ConcatVT = VecOp.getOperand(i: 0).getValueType();
22526 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
22527 SDValue NewIdx = DAG.getConstant(Val: Elt % ConcatNumElts, DL: SL,
22528 VT: Index.getValueType());
22529
22530 SDValue ConcatOp = VecOp.getOperand(i: Elt / ConcatNumElts);
22531 SDValue Elt = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SL,
22532 VT: ConcatVT.getVectorElementType(),
22533 N1: ConcatOp, N2: NewIdx);
22534 return DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: ScalarVT, Operand: Elt);
22535 }
22536
22537 // Make sure we found a non-volatile load and the extractelement is
22538 // the only use.
22539 if (!LN0 || !LN0->hasNUsesOfValue(NUses: 1, Value: 0) || !LN0->isSimple())
22540 return SDValue();
22541
22542 // If Idx was -1 above, Elt is going to be -1, so just return undef.
22543 if (Elt == -1)
22544 return DAG.getUNDEF(VT: LVT);
22545
22546 return scalarizeExtractedVectorLoad(EVE: N, InVecVT: VecVT, EltNo: Index, OriginalLoad: LN0);
22547}
22548
22549// Simplify (build_vec (ext )) to (bitcast (build_vec ))
22550SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
22551 // We perform this optimization post type-legalization because
22552 // the type-legalizer often scalarizes integer-promoted vectors.
22553 // Performing this optimization before may create bit-casts which
22554 // will be type-legalized to complex code sequences.
22555 // We perform this optimization only before the operation legalizer because we
22556 // may introduce illegal operations.
22557 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22558 return SDValue();
22559
22560 unsigned NumInScalars = N->getNumOperands();
22561 SDLoc DL(N);
22562 EVT VT = N->getValueType(ResNo: 0);
22563
22564 // Check to see if this is a BUILD_VECTOR of a bunch of values
22565 // which come from any_extend or zero_extend nodes. If so, we can create
22566 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
22567 // optimizations. We do not handle sign-extend because we can't fill the sign
22568 // using shuffles.
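// For illustration only (a hypothetical little-endian case; a..d are i16
// placeholders):
//   BUILD_VECTOR v4i32 (zext a), (zext b), (zext c), (zext d)
//     --> bitcast (BUILD_VECTOR v8i16 a, 0, b, 0, c, 0, d, 0) to v4i32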
22569 EVT SourceType = MVT::Other;
22570 bool AllAnyExt = true;
22571
22572 for (unsigned i = 0; i != NumInScalars; ++i) {
22573 SDValue In = N->getOperand(Num: i);
22574 // Ignore undef inputs.
22575 if (In.isUndef()) continue;
22576
22577 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
22578 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
22579
22580 // Abort if the element is not an extension.
22581 if (!ZeroExt && !AnyExt) {
22582 SourceType = MVT::Other;
22583 break;
22584 }
22585
22586 // The input is a ZeroExt or AnyExt. Check the original type.
22587 EVT InTy = In.getOperand(i: 0).getValueType();
22588
22589 // Check that all of the widened source types are the same.
22590 if (SourceType == MVT::Other)
22591 // First time.
22592 SourceType = InTy;
22593 else if (InTy != SourceType) {
22594 // Multiple incoming source types. Abort.
22595 SourceType = MVT::Other;
22596 break;
22597 }
22598
22599 // Check if all of the extends are ANY_EXTENDs.
22600 AllAnyExt &= AnyExt;
22601 }
22602
22603 // In order to have valid types, all of the inputs must be extended from the
22604 // same source type and all of the inputs must be any or zero extend.
22605 // Scalar sizes must be a power of two.
22606 EVT OutScalarTy = VT.getScalarType();
22607 bool ValidTypes =
22608 SourceType != MVT::Other &&
22609 llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) &&
22610 llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits());
22611
22612 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
22613 // turn into a single shuffle instruction.
22614 if (!ValidTypes)
22615 return SDValue();
22616
22617 // If we already have a splat buildvector, then don't fold it if it means
22618 // introducing zeros.
22619 if (!AllAnyExt && DAG.isSplatValue(V: SDValue(N, 0), /*AllowUndefs*/ true))
22620 return SDValue();
22621
22622 bool isLE = DAG.getDataLayout().isLittleEndian();
22623 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
22624 assert(ElemRatio > 1 && "Invalid element size ratio");
22625 SDValue Filler = AllAnyExt ? DAG.getUNDEF(VT: SourceType):
22626 DAG.getConstant(Val: 0, DL, VT: SourceType);
22627
22628 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
22629 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
22630
22631 // Populate the new build_vector
22632 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22633 SDValue Cast = N->getOperand(Num: i);
22634 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
22635 Cast.getOpcode() == ISD::ZERO_EXTEND ||
22636 Cast.isUndef()) && "Invalid cast opcode");
22637 SDValue In;
22638 if (Cast.isUndef())
22639 In = DAG.getUNDEF(VT: SourceType);
22640 else
22641 In = Cast->getOperand(Num: 0);
22642 unsigned Index = isLE ? (i * ElemRatio) :
22643 (i * ElemRatio + (ElemRatio - 1));
22644
22645 assert(Index < Ops.size() && "Invalid index");
22646 Ops[Index] = In;
22647 }
22648
22649 // The type of the new BUILD_VECTOR node.
22650 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SourceType, NumElements: NewBVElems);
22651 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
22652 "Invalid vector size");
22653 // Check if the new vector type is legal.
22654 if (!isTypeLegal(VT: VecVT) ||
22655 (!TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: VecVT) &&
22656 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)))
22657 return SDValue();
22658
22659 // Make the new BUILD_VECTOR.
22660 SDValue BV = DAG.getBuildVector(VT: VecVT, DL, Ops);
22661
22662 // The new BUILD_VECTOR node has the potential to be further optimized.
22663 AddToWorklist(N: BV.getNode());
22664 // Bitcast to the desired type.
22665 return DAG.getBitcast(VT, V: BV);
22666}
22667
22668// Simplify (build_vec (trunc $1)
22669// (trunc (srl $1 half-width))
22670// (trunc (srl $1 (2 * half-width))))
22671// to (bitcast $1)
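// For illustration only (a hypothetical little-endian case; x is an i64
// placeholder):
//   BUILD_VECTOR v4i16 (trunc x), (trunc (srl x, 16)),
//                      (trunc (srl x, 32)), (trunc (srl x, 48))
//     --> v4i16 (bitcast x)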
22672SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
22673 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
22674
22675 EVT VT = N->getValueType(ResNo: 0);
22676
22677 // Don't run this before LegalizeTypes if VT is legal.
22678 // Targets may have other preferences.
22679 if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
22680 return SDValue();
22681
22682 // Only for little endian
22683 if (!DAG.getDataLayout().isLittleEndian())
22684 return SDValue();
22685
22686 SDLoc DL(N);
22687 EVT OutScalarTy = VT.getScalarType();
22688 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
22689
22690 // Only for power of two types to be sure that bitcast works well
22691 if (!isPowerOf2_64(Value: ScalarTypeBitsize))
22692 return SDValue();
22693
22694 unsigned NumInScalars = N->getNumOperands();
22695
22696 // Look through bitcasts
22697 auto PeekThroughBitcast = [](SDValue Op) {
22698 if (Op.getOpcode() == ISD::BITCAST)
22699 return Op.getOperand(i: 0);
22700 return Op;
22701 };
22702
22703 // The source value where all the parts are extracted.
22704 SDValue Src;
22705 for (unsigned i = 0; i != NumInScalars; ++i) {
22706 SDValue In = PeekThroughBitcast(N->getOperand(Num: i));
22707 // Ignore undef inputs.
22708 if (In.isUndef()) continue;
22709
22710 if (In.getOpcode() != ISD::TRUNCATE)
22711 return SDValue();
22712
22713 In = PeekThroughBitcast(In.getOperand(i: 0));
22714
22715 if (In.getOpcode() != ISD::SRL) {
22716 // For now, only handle build_vec without shuffling; handle shifts here in the
22717 // future.
22718 if (i != 0)
22719 return SDValue();
22720
22721 Src = In;
22722 } else {
22723 // In is SRL
22724 SDValue part = PeekThroughBitcast(In.getOperand(i: 0));
22725
22726 if (!Src) {
22727 Src = part;
22728 } else if (Src != part) {
22729 // Vector parts do not stem from the same variable
22730 return SDValue();
22731 }
22732
22733 SDValue ShiftAmtVal = In.getOperand(i: 1);
22734 if (!isa<ConstantSDNode>(Val: ShiftAmtVal))
22735 return SDValue();
22736
22737 uint64_t ShiftAmt = In.getConstantOperandVal(i: 1);
22738
22739 // The extracted value is not extracted at the right position
22740 if (ShiftAmt != i * ScalarTypeBitsize)
22741 return SDValue();
22742 }
22743 }
22744
22745 // Only cast if the size is the same
22746 if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
22747 return SDValue();
22748
22749 return DAG.getBitcast(VT, V: Src);
22750}
22751
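// Helper for reduceBuildVecToShuffle (a summary of the logic below, not an
// exhaustive contract): shuffle the pair of input vectors selected by LeftIdx
// so that the lanes of the BUILD_VECTOR whose VectorMask entry is LeftIdx or
// LeftIdx+1 receive the elements they extract, padding or splitting the inputs
// as needed to make the shuffle type match the result type.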
22752SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
22753 ArrayRef<int> VectorMask,
22754 SDValue VecIn1, SDValue VecIn2,
22755 unsigned LeftIdx, bool DidSplitVec) {
22756 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL);
22757
22758 EVT VT = N->getValueType(ResNo: 0);
22759 EVT InVT1 = VecIn1.getValueType();
22760 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
22761
22762 unsigned NumElems = VT.getVectorNumElements();
22763 unsigned ShuffleNumElems = NumElems;
22764
22765 // If we artificially split a vector in two already, then the offsets in the
22766 // operands will all be based off of VecIn1, even those in VecIn2.
22767 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
22768
22769 uint64_t VTSize = VT.getFixedSizeInBits();
22770 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
22771 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
22772
22773 assert(InVT2Size <= InVT1Size &&
22774 "Inputs must be sorted to be in non-increasing vector size order.");
22775
22776 // We can't generate a shuffle node with mismatched input and output types.
22777 // Try to make the types match the type of the output.
22778 if (InVT1 != VT || InVT2 != VT) {
22779 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
22780 // If the output vector length is a multiple of both input lengths,
22781 // we can concatenate them and pad the rest with undefs.
22782 unsigned NumConcats = VTSize / InVT1Size;
22783 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
22784 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(VT: InVT1));
22785 ConcatOps[0] = VecIn1;
22786 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(VT: InVT1);
22787 VecIn1 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
22788 VecIn2 = SDValue();
22789 } else if (InVT1Size == VTSize * 2) {
22790 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems))
22791 return SDValue();
22792
22793 if (!VecIn2.getNode()) {
22794 // If we only have one input vector, and it's twice the size of the
22795 // output, split it in two.
22796 VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1,
22797 N2: DAG.getVectorIdxConstant(Val: NumElems, DL));
22798 VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1, N2: ZeroIdx);
22799 // Since we now have shorter input vectors, adjust the offset of the
22800 // second vector's start.
22801 Vec2Offset = NumElems;
22802 } else {
22803 assert(InVT2Size <= InVT1Size &&
22804 "Second input is not going to be larger than the first one.");
22805
22806 // VecIn1 is wider than the output, and we have another, possibly
22807 // smaller input. Pad the smaller input with undefs, shuffle at the
22808 // input vector width, and extract the output.
22809 // The shuffle type is different than VT, so check legality again.
22810 if (LegalOperations &&
22811 !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
22812 return SDValue();
22813
22814 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
22815 // lower it back into a BUILD_VECTOR. So if the inserted type is
22816 // illegal, don't even try.
22817 if (InVT1 != InVT2) {
22818 if (!TLI.isTypeLegal(VT: InVT2))
22819 return SDValue();
22820 VecIn2 = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: InVT1,
22821 N1: DAG.getUNDEF(VT: InVT1), N2: VecIn2, N3: ZeroIdx);
22822 }
22823 ShuffleNumElems = NumElems * 2;
22824 }
22825 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
22826 SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(VT: InVT2));
22827 ConcatOps[0] = VecIn2;
22828 VecIn2 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
22829 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
22830 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems) ||
22831 !TLI.isTypeLegal(VT: InVT1) || !TLI.isTypeLegal(VT: InVT2))
22832 return SDValue();
22833 // If the dest vector has fewer than two elements, then using a shuffle and
22834 // extract from larger regs will cost even more.
22835 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
22836 return SDValue();
22837 assert(InVT2Size <= InVT1Size &&
22838 "Second input is not going to be larger than the first one.");
22839
22840 // VecIn1 is wider than the output, and we have another, possibly
22841 // smaller input. Pad the smaller input with undefs, shuffle at the
22842 // input vector width, and extract the output.
22843 // The shuffle type is different than VT, so check legality again.
22844 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
22845 return SDValue();
22846
22847 if (InVT1 != InVT2) {
22848 VecIn2 = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: InVT1,
22849 N1: DAG.getUNDEF(VT: InVT1), N2: VecIn2, N3: ZeroIdx);
22850 }
22851 ShuffleNumElems = InVT1Size / VTSize * NumElems;
22852 } else {
22853 // TODO: Support cases where the length mismatch isn't exactly by a
22854 // factor of 2.
22855 // TODO: Move this check upwards, so that if we have bad type
22856 // mismatches, we don't create any DAG nodes.
22857 return SDValue();
22858 }
22859 }
22860
22861 // Initialize mask to undef.
22862 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
22863
22864 // Only need to run up to the number of elements actually used, not the
22865 // total number of elements in the shuffle - if we are shuffling a wider
22866 // vector, the high lanes should be set to undef.
22867 for (unsigned i = 0; i != NumElems; ++i) {
22868 if (VectorMask[i] <= 0)
22869 continue;
22870
22871 unsigned ExtIndex = N->getOperand(Num: i).getConstantOperandVal(i: 1);
22872 if (VectorMask[i] == (int)LeftIdx) {
22873 Mask[i] = ExtIndex;
22874 } else if (VectorMask[i] == (int)LeftIdx + 1) {
22875 Mask[i] = Vec2Offset + ExtIndex;
22876 }
22877 }
22878
22879 // The type of the input vectors may have changed above.
22880 InVT1 = VecIn1.getValueType();
22881
22882 // If we already have a VecIn2, it should have the same type as VecIn1.
22883 // If we don't, get an undef/zero vector of the appropriate type.
22884 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT: InVT1);
22885 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
22886
22887 SDValue Shuffle = DAG.getVectorShuffle(VT: InVT1, dl: DL, N1: VecIn1, N2: VecIn2, Mask);
22888 if (ShuffleNumElems > NumElems)
22889 Shuffle = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: Shuffle, N2: ZeroIdx);
22890
22891 return Shuffle;
22892}
22893
22894static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
22895 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
22896
22897 // First, determine where the build vector is not undef.
22898 // TODO: We could extend this to handle zero elements as well as undefs.
22899 int NumBVOps = BV->getNumOperands();
22900 int ZextElt = -1;
22901 for (int i = 0; i != NumBVOps; ++i) {
22902 SDValue Op = BV->getOperand(Num: i);
22903 if (Op.isUndef())
22904 continue;
22905 if (ZextElt == -1)
22906 ZextElt = i;
22907 else
22908 return SDValue();
22909 }
22910 // Bail out if there's no non-undef element.
22911 if (ZextElt == -1)
22912 return SDValue();
22913
22914 // The build vector contains some number of undef elements and exactly
22915 // one other element. That other element must be a zero-extended scalar
22916 // extracted from a vector at a constant index to turn this into a shuffle.
22917 // Also, require that the build vector does not implicitly truncate/extend
22918 // its elements.
22919 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
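// For illustration only (a hypothetical little-endian case; V is a v4i32
// placeholder):
//   BUILD_VECTOR v2i64 undef, (zext (extract_vector_elt V, 3) to i64)
//     --> bitcast (vector_shuffle<u,u,3,4> V, zero) to v2i64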
22920 EVT VT = BV->getValueType(ResNo: 0);
22921 SDValue Zext = BV->getOperand(Num: ZextElt);
22922 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
22923 Zext.getOperand(i: 0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22924 !isa<ConstantSDNode>(Val: Zext.getOperand(i: 0).getOperand(i: 1)) ||
22925 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
22926 return SDValue();
22927
22928 // The zero-extend must be a multiple of the source size, and we must be
22929 // building a vector of the same size as the source of the extract element.
22930 SDValue Extract = Zext.getOperand(i: 0);
22931 unsigned DestSize = Zext.getValueSizeInBits();
22932 unsigned SrcSize = Extract.getValueSizeInBits();
22933 if (DestSize % SrcSize != 0 ||
22934 Extract.getOperand(i: 0).getValueSizeInBits() != VT.getSizeInBits())
22935 return SDValue();
22936
22937 // Create a shuffle mask that will combine the extracted element with zeros
22938 // and undefs.
22939 int ZextRatio = DestSize / SrcSize;
22940 int NumMaskElts = NumBVOps * ZextRatio;
22941 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
22942 for (int i = 0; i != NumMaskElts; ++i) {
22943 if (i / ZextRatio == ZextElt) {
22944 // The low bits of the (potentially translated) extracted element map to
22945 // the source vector. The high bits map to zero. We will use a zero vector
22946 // as the 2nd source operand of the shuffle, so use the 1st element of
22947 // that vector (mask value is number-of-elements) for the high bits.
22948 int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
22949 ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(i: 1)
22950 : NumMaskElts;
22951 }
22952
22953 // Undef elements of the build vector remain undef because we initialize
22954 // the shuffle mask with -1.
22955 }
22956
22957 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
22958 // bitcast (shuffle V, ZeroVec, VectorMask)
22959 SDLoc DL(BV);
22960 EVT VecVT = Extract.getOperand(i: 0).getValueType();
22961 SDValue ZeroVec = DAG.getConstant(Val: 0, DL, VT: VecVT);
22962 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22963 SDValue Shuf = TLI.buildLegalVectorShuffle(VT: VecVT, DL, N0: Extract.getOperand(i: 0),
22964 N1: ZeroVec, Mask: ShufMask, DAG);
22965 if (!Shuf)
22966 return SDValue();
22967 return DAG.getBitcast(VT, V: Shuf);
22968}
22969
22970// FIXME: promote to STLExtras.
22971template <typename R, typename T>
22972static auto getFirstIndexOf(R &&Range, const T &Val) {
22973 auto I = find(Range, Val);
22974 if (I == Range.end())
22975 return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
22976 return std::distance(Range.begin(), I);
22977}
22978
22979// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
22980// operations. If the types of the vectors we're extracting from allow it,
22981// turn this into a vector_shuffle node.
22982SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
22983 SDLoc DL(N);
22984 EVT VT = N->getValueType(ResNo: 0);
22985
22986 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
22987 if (!isTypeLegal(VT))
22988 return SDValue();
22989
22990 if (SDValue V = reduceBuildVecToShuffleWithZero(BV: N, DAG))
22991 return V;
22992
22993 // May only combine to shuffle after legalize if shuffle is legal.
22994 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT))
22995 return SDValue();
22996
22997 bool UsesZeroVector = false;
22998 unsigned NumElems = N->getNumOperands();
22999
23000 // Record, for each element of the newly built vector, which input vector
23001 // that element comes from. -1 stands for undef, 0 for the zero vector,
23002 // and positive values for the input vectors.
23003 // VectorMask maps each element to its vector number, and VecIn maps vector
23004 // numbers to their initial SDValues.
23005
23006 SmallVector<int, 8> VectorMask(NumElems, -1);
23007 SmallVector<SDValue, 8> VecIn;
23008 VecIn.push_back(Elt: SDValue());
23009
23010 for (unsigned i = 0; i != NumElems; ++i) {
23011 SDValue Op = N->getOperand(Num: i);
23012
23013 if (Op.isUndef())
23014 continue;
23015
23016 // See if we can use a blend with a zero vector.
23017 // TODO: Should we generalize this to a blend with an arbitrary constant
23018 // vector?
23019 if (isNullConstant(V: Op) || isNullFPConstant(V: Op)) {
23020 UsesZeroVector = true;
23021 VectorMask[i] = 0;
23022 continue;
23023 }
23024
23025 // Not an undef or zero. If the input is something other than an
23026 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
23027 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23028 !isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))
23029 return SDValue();
23030 SDValue ExtractedFromVec = Op.getOperand(i: 0);
23031
23032 if (ExtractedFromVec.getValueType().isScalableVector())
23033 return SDValue();
23034
23035 const APInt &ExtractIdx = Op.getConstantOperandAPInt(i: 1);
23036 if (ExtractIdx.uge(RHS: ExtractedFromVec.getValueType().getVectorNumElements()))
23037 return SDValue();
23038
23039 // All inputs must have the same element type as the output.
23040 if (VT.getVectorElementType() !=
23041 ExtractedFromVec.getValueType().getVectorElementType())
23042 return SDValue();
23043
23044 // Have we seen this input vector before?
23045 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
23046 // a map back from SDValues to numbers isn't worth it.
23047 int Idx = getFirstIndexOf(Range&: VecIn, Val: ExtractedFromVec);
23048 if (Idx == -1) { // A new source vector?
23049 Idx = VecIn.size();
23050 VecIn.push_back(Elt: ExtractedFromVec);
23051 }
23052
23053 VectorMask[i] = Idx;
23054 }
23055
23056 // If we didn't find at least one input vector, bail out.
23057 if (VecIn.size() < 2)
23058 return SDValue();
23059
23060 // If all the operands of the BUILD_VECTOR extract from the same
23061 // vector, then split the vector efficiently based on the maximum
23062 // vector access index and adjust the VectorMask and
23063 // VecIn accordingly.
23064 bool DidSplitVec = false;
23065 if (VecIn.size() == 2) {
23066 unsigned MaxIndex = 0;
23067 unsigned NearestPow2 = 0;
23068 SDValue Vec = VecIn.back();
23069 EVT InVT = Vec.getValueType();
23070 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
23071
23072 for (unsigned i = 0; i < NumElems; i++) {
23073 if (VectorMask[i] <= 0)
23074 continue;
23075 unsigned Index = N->getOperand(Num: i).getConstantOperandVal(i: 1);
23076 IndexVec[i] = Index;
23077 MaxIndex = std::max(a: MaxIndex, b: Index);
23078 }
23079
23080 NearestPow2 = PowerOf2Ceil(A: MaxIndex);
23081 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
23082 NumElems * 2 < NearestPow2) {
23083 unsigned SplitSize = NearestPow2 / 2;
23084 EVT SplitVT = EVT::getVectorVT(Context&: *DAG.getContext(),
23085 VT: InVT.getVectorElementType(), NumElements: SplitSize);
23086 if (TLI.isTypeLegal(VT: SplitVT) &&
23087 SplitSize + SplitVT.getVectorNumElements() <=
23088 InVT.getVectorNumElements()) {
23089 SDValue VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
23090 N2: DAG.getVectorIdxConstant(Val: SplitSize, DL));
23091 SDValue VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
23092 N2: DAG.getVectorIdxConstant(Val: 0, DL));
23093 VecIn.pop_back();
23094 VecIn.push_back(Elt: VecIn1);
23095 VecIn.push_back(Elt: VecIn2);
23096 DidSplitVec = true;
23097
23098 for (unsigned i = 0; i < NumElems; i++) {
23099 if (VectorMask[i] <= 0)
23100 continue;
23101 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
23102 }
23103 }
23104 }
23105 }
23106
23107 // Sort input vectors by decreasing vector element count,
23108 // while preserving the relative order of equally-sized vectors.
23109 // Note that we keep the first "implicit" zero vector as-is.
23110 SmallVector<SDValue, 8> SortedVecIn(VecIn);
23111 llvm::stable_sort(Range: MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
23112 C: [](const SDValue &a, const SDValue &b) {
23113 return a.getValueType().getVectorNumElements() >
23114 b.getValueType().getVectorNumElements();
23115 });
23116
23117 // We now also need to rebuild the VectorMask, because it referenced element
23118 // order in VecIn, and we just sorted them.
23119 for (int &SourceVectorIndex : VectorMask) {
23120 if (SourceVectorIndex <= 0)
23121 continue;
23122 unsigned Idx = getFirstIndexOf(Range&: SortedVecIn, Val: VecIn[SourceVectorIndex]);
23123 assert(Idx > 0 && Idx < SortedVecIn.size() &&
23124 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
23125 SourceVectorIndex = Idx;
23126 }
23127
23128 VecIn = std::move(SortedVecIn);
23129
23130 // TODO: Should this fire if some of the input vectors have an illegal type
23131 // (like it does now), or should we let legalization run its course first?
23132
23133 // Shuffle phase:
23134 // Take pairs of vectors, and shuffle them so that the result has elements
23135 // from these vectors in the correct places.
23136 // For example, given:
23137 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
23138 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
23139 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
23140 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
23141 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
23142 // We will generate:
23143 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
23144 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
23145 SmallVector<SDValue, 4> Shuffles;
23146 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
23147 unsigned LeftIdx = 2 * In + 1;
23148 SDValue VecLeft = VecIn[LeftIdx];
23149 SDValue VecRight =
23150 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
23151
23152 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecIn1: VecLeft,
23153 VecIn2: VecRight, LeftIdx, DidSplitVec))
23154 Shuffles.push_back(Elt: Shuffle);
23155 else
23156 return SDValue();
23157 }
23158
23159 // If we need the zero vector as an "ingredient" in the blend tree, add it
23160 // to the list of shuffles.
23161 if (UsesZeroVector)
23162 Shuffles.push_back(Elt: VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT)
23163 : DAG.getConstantFP(Val: 0.0, DL, VT));
23164
23165 // If we only have one shuffle, we're done.
23166 if (Shuffles.size() == 1)
23167 return Shuffles[0];
23168
23169 // Update the vector mask to point to the post-shuffle vectors.
23170 for (int &Vec : VectorMask)
23171 if (Vec == 0)
23172 Vec = Shuffles.size() - 1;
23173 else
23174 Vec = (Vec - 1) / 2;
23175
23176 // More than one shuffle. Generate a binary tree of blends, e.g. if from
23177 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
23178 // generate:
23179 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
23180 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
23181 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
23182 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
23183 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
23184 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
23185 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
23186
23187 // Make sure the initial size of the shuffle list is even.
23188 if (Shuffles.size() % 2)
23189 Shuffles.push_back(Elt: DAG.getUNDEF(VT));
23190
23191 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
23192 if (CurSize % 2) {
23193 Shuffles[CurSize] = DAG.getUNDEF(VT);
23194 CurSize++;
23195 }
23196 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
23197 int Left = 2 * In;
23198 int Right = 2 * In + 1;
23199 SmallVector<int, 8> Mask(NumElems, -1);
23200 SDValue L = Shuffles[Left];
23201 ArrayRef<int> LMask;
23202 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
23203 L.use_empty() && L.getOperand(i: 1).isUndef() &&
23204 L.getOperand(i: 0).getValueType() == L.getValueType();
23205 if (IsLeftShuffle) {
23206 LMask = cast<ShuffleVectorSDNode>(Val: L.getNode())->getMask();
23207 L = L.getOperand(i: 0);
23208 }
23209 SDValue R = Shuffles[Right];
23210 ArrayRef<int> RMask;
23211 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
23212 R.use_empty() && R.getOperand(i: 1).isUndef() &&
23213 R.getOperand(i: 0).getValueType() == R.getValueType();
23214 if (IsRightShuffle) {
23215 RMask = cast<ShuffleVectorSDNode>(Val: R.getNode())->getMask();
23216 R = R.getOperand(i: 0);
23217 }
23218 for (unsigned I = 0; I != NumElems; ++I) {
23219 if (VectorMask[I] == Left) {
23220 Mask[I] = I;
23221 if (IsLeftShuffle)
23222 Mask[I] = LMask[I];
23223 VectorMask[I] = In;
23224 } else if (VectorMask[I] == Right) {
23225 Mask[I] = I + NumElems;
23226 if (IsRightShuffle)
23227 Mask[I] = RMask[I] + NumElems;
23228 VectorMask[I] = In;
23229 }
23230 }
23231
23232 Shuffles[In] = DAG.getVectorShuffle(VT, dl: DL, N1: L, N2: R, Mask);
23233 }
23234 }
23235 return Shuffles[0];
23236}
23237
23238 // Try to turn a build vector of zero extends of extract vector elts into a
23239 // vector zero extend and possibly an extract subvector.
23240// TODO: Support sign extend?
23241// TODO: Allow undef elements?
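// Illustrative example (assuming all of the types involved are legal):
//   (v2i64 build_vector (zero_extend (extract_vector_elt X:v4i32, 0)),
//                       (zero_extend (extract_vector_elt X:v4i32, 1)))
//     --> (v2i64 zero_extend (v2i32 extract_subvector X, 0))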
23242SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
23243 if (LegalOperations)
23244 return SDValue();
23245
23246 EVT VT = N->getValueType(ResNo: 0);
23247
23248 bool FoundZeroExtend = false;
23249 SDValue Op0 = N->getOperand(Num: 0);
23250 auto checkElem = [&](SDValue Op) -> int64_t {
23251 unsigned Opc = Op.getOpcode();
23252 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
23253 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
23254 Op.getOperand(i: 0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23255 Op0.getOperand(i: 0).getOperand(i: 0) == Op.getOperand(i: 0).getOperand(i: 0))
23256 if (auto *C = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 0).getOperand(i: 1)))
23257 return C->getZExtValue();
23258 return -1;
23259 };
23260
23261 // Make sure the first element matches
23262 // (zext (extract_vector_elt X, C))
23263 // Offset must be a constant multiple of the
23264 // known-minimum vector length of the result type.
23265 int64_t Offset = checkElem(Op0);
23266 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
23267 return SDValue();
23268
23269 unsigned NumElems = N->getNumOperands();
23270 SDValue In = Op0.getOperand(i: 0).getOperand(i: 0);
23271 EVT InSVT = In.getValueType().getScalarType();
23272 EVT InVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: InSVT, NumElements: NumElems);
23273
23274 // Don't create an illegal input type after type legalization.
23275 if (LegalTypes && !TLI.isTypeLegal(VT: InVT))
23276 return SDValue();
23277
23278 // Ensure all the elements come from the same vector and are adjacent.
23279 for (unsigned i = 1; i != NumElems; ++i) {
23280 if ((Offset + i) != checkElem(N->getOperand(Num: i)))
23281 return SDValue();
23282 }
23283
23284 SDLoc DL(N);
23285 In = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: InVT, N1: In,
23286 N2: Op0.getOperand(i: 0).getOperand(i: 1));
23287 return DAG.getNode(Opcode: FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
23288 VT, Operand: In);
23289}
23290
23291 // If this is a very simple BUILD_VECTOR with first element being a ZERO_EXTEND,
23292 // and all other elements being constant zeros, granularize the BUILD_VECTOR's
23293 // element width, absorbing the ZERO_EXTEND, turning it into a constant zero op.
23294 // This pattern can appear during legalization.
23295//
23296// NOTE: This can be generalized to allow more than a single
23297 // non-constant-zero op, UNDEF's, and to be KnownBits-based.
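// A sketch of the intended transform (illustrative; assumes a little-endian
// target where i64 and v4i32 are legal):
//   (v2i64 build_vector (i64 zero_extend (i32 X)), (i64 0))
//     --> (v2i64 bitcast (v4i32 build_vector X, 0, 0, 0))
// (modulo a trunc-of-zext pair that is expected to fold away afterwards).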
23298SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
23299 // Don't run this after legalization. Targets may have other preferences.
23300 if (Level >= AfterLegalizeDAG)
23301 return SDValue();
23302
23303 // FIXME: support big-endian.
23304 if (DAG.getDataLayout().isBigEndian())
23305 return SDValue();
23306
23307 EVT VT = N->getValueType(ResNo: 0);
23308 EVT OpVT = N->getOperand(Num: 0).getValueType();
23309 assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
23310
23311 EVT OpIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
23312
23313 if (!TLI.isTypeLegal(VT: OpIntVT) ||
23314 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: OpIntVT)))
23315 return SDValue();
23316
23317 unsigned EltBitwidth = VT.getScalarSizeInBits();
23318 // NOTE: the actual width of operands may be wider than that!
23319
23320 // Analyze all operands of this BUILD_VECTOR. What is the largest number of
23321 // active bits they all have? We'll want to truncate them all to that width.
23322 unsigned ActiveBits = 0;
23323 APInt KnownZeroOps(VT.getVectorNumElements(), 0);
23324 for (auto I : enumerate(First: N->ops())) {
23325 SDValue Op = I.value();
23326 // FIXME: support UNDEF elements?
23327 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Op)) {
23328 unsigned OpActiveBits =
23329 Cst->getAPIntValue().trunc(width: EltBitwidth).getActiveBits();
23330 if (OpActiveBits == 0) {
23331 KnownZeroOps.setBit(I.index());
23332 continue;
23333 }
23334 // Profitability check: don't allow non-zero constant operands.
23335 return SDValue();
23336 }
23337 // Profitability check: there must only be a single non-zero operand,
23338 // and it must be the first operand of the BUILD_VECTOR.
23339 if (I.index() != 0)
23340 return SDValue();
23341 // The operand must be a zero-extension itself.
23342 // FIXME: this could be generalized to known leading zeros check.
23343 if (Op.getOpcode() != ISD::ZERO_EXTEND)
23344 return SDValue();
23345 unsigned CurrActiveBits =
23346 Op.getOperand(i: 0).getValueSizeInBits().getFixedValue();
23347 assert(!ActiveBits && "Already encountered non-constant-zero operand?");
23348 ActiveBits = CurrActiveBits;
23349 // We want to at least halve the element size.
23350 if (2 * ActiveBits > EltBitwidth)
23351 return SDValue();
23352 }
23353
23354 // This BUILD_VECTOR must have at least one non-constant-zero operand.
23355 if (ActiveBits == 0)
23356 return SDValue();
23357
23358 // We have EltBitwidth bits, and the *minimal* chunk size is ActiveBits;
23359 // into how many chunks can we split our element width?
23360 EVT NewScalarIntVT, NewIntVT;
23361 std::optional<unsigned> Factor;
23362 // We can split the element into at least two chunks, but not into more
23363 // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
23364 // for which the element width is a multiple of it,
23365 // and the resulting types/operations on that chunk width are legal.
23366 assert(2 * ActiveBits <= EltBitwidth &&
23367 "We know that half or less bits of the element are active.");
23368 for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
23369 if (EltBitwidth % Scale != 0)
23370 continue;
23371 unsigned ChunkBitwidth = EltBitwidth / Scale;
23372 assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
23373 NewScalarIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ChunkBitwidth);
23374 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarIntVT,
23375 NumElements: Scale * N->getNumOperands());
23376 if (!TLI.isTypeLegal(VT: NewScalarIntVT) || !TLI.isTypeLegal(VT: NewIntVT) ||
23377 (LegalOperations &&
23378 !(TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT: NewScalarIntVT) &&
23379 TLI.isOperationLegalOrCustom(Op: ISD::BUILD_VECTOR, VT: NewIntVT))))
23380 continue;
23381 Factor = Scale;
23382 break;
23383 }
23384 if (!Factor)
23385 return SDValue();
23386
23387 SDLoc DL(N);
23388 SDValue ZeroOp = DAG.getConstant(Val: 0, DL, VT: NewScalarIntVT);
23389
23390 // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
23391 SmallVector<SDValue, 16> NewOps;
23392 NewOps.reserve(N: NewIntVT.getVectorNumElements());
23393 for (auto I : enumerate(First: N->ops())) {
23394 SDValue Op = I.value();
23395 assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
23396 unsigned SrcOpIdx = I.index();
23397 if (KnownZeroOps[SrcOpIdx]) {
23398 NewOps.append(NumInputs: *Factor, Elt: ZeroOp);
23399 continue;
23400 }
23401 Op = DAG.getBitcast(VT: OpIntVT, V: Op);
23402 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: NewScalarIntVT, Operand: Op);
23403 NewOps.emplace_back(Args&: Op);
23404 NewOps.append(NumInputs: *Factor - 1, Elt: ZeroOp);
23405 }
23406 assert(NewOps.size() == NewIntVT.getVectorNumElements());
23407 SDValue NewBV = DAG.getBuildVector(VT: NewIntVT, DL, Ops: NewOps);
23408 NewBV = DAG.getBitcast(VT, V: NewBV);
23409 return NewBV;
23410}
23411
23412SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
23413 EVT VT = N->getValueType(ResNo: 0);
23414
23415 // A vector built entirely of undefs is undef.
23416 if (ISD::allOperandsUndef(N))
23417 return DAG.getUNDEF(VT);
23418
23419 // If this is a splat of a bitcast from another vector, change to a
23420 // concat_vector.
23421 // For example:
23422 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
23423 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
23424 //
23425 // If X is a build_vector itself, the concat can become a larger build_vector.
23426 // TODO: Maybe this is useful for non-splat too?
23427 if (!LegalOperations) {
23428 if (SDValue Splat = cast<BuildVectorSDNode>(Val: N)->getSplatValue()) {
23429 Splat = peekThroughBitcasts(V: Splat);
23430 EVT SrcVT = Splat.getValueType();
23431 if (SrcVT.isVector()) {
23432 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
23433 EVT NewVT = EVT::getVectorVT(Context&: *DAG.getContext(),
23434 VT: SrcVT.getVectorElementType(), NumElements: NumElts);
23435 if (!LegalTypes || TLI.isTypeLegal(VT: NewVT)) {
23436 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
23437 SDValue Concat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N),
23438 VT: NewVT, Ops);
23439 return DAG.getBitcast(VT, V: Concat);
23440 }
23441 }
23442 }
23443 }
23444
23445 // Check if we can express the BUILD_VECTOR via a subvector extract.
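// For example (illustrative):
//   (v2i32 build_vector (extract_vector_elt X:v8i32, 2),
//                       (extract_vector_elt X:v8i32, 3))
//     --> (v2i32 extract_subvector X, 2)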
23446 if (!LegalTypes && (N->getNumOperands() > 1)) {
23447 SDValue Op0 = N->getOperand(Num: 0);
23448 auto checkElem = [&](SDValue Op) -> uint64_t {
23449 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
23450 (Op0.getOperand(i: 0) == Op.getOperand(i: 0)))
23451 if (auto CNode = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1)))
23452 return CNode->getZExtValue();
23453 return -1;
23454 };
23455
23456 int Offset = checkElem(Op0);
23457 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
23458 if (Offset + i != checkElem(N->getOperand(Num: i))) {
23459 Offset = -1;
23460 break;
23461 }
23462 }
23463
23464 if ((Offset == 0) &&
23465 (Op0.getOperand(i: 0).getValueType() == N->getValueType(ResNo: 0)))
23466 return Op0.getOperand(i: 0);
23467 if ((Offset != -1) &&
23468 ((Offset % N->getValueType(ResNo: 0).getVectorNumElements()) ==
23469 0)) // IDX must be multiple of output size.
23470 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: N->getValueType(ResNo: 0),
23471 N1: Op0.getOperand(i: 0), N2: Op0.getOperand(i: 1));
23472 }
23473
23474 if (SDValue V = convertBuildVecZextToZext(N))
23475 return V;
23476
23477 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
23478 return V;
23479
23480 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
23481 return V;
23482
23483 if (SDValue V = reduceBuildVecTruncToBitCast(N))
23484 return V;
23485
23486 if (SDValue V = reduceBuildVecToShuffle(N))
23487 return V;
23488
23489 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
23490 // Do this late as some of the above may replace the splat.
23491 if (TLI.getOperationAction(Op: ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
23492 if (SDValue V = cast<BuildVectorSDNode>(Val: N)->getSplatValue()) {
23493 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
23494 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: V);
23495 }
23496
23497 return SDValue();
23498}
23499
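// Fold a CONCAT_VECTORS whose operands are all bitcasts from scalars (or
// UNDEF) into a BUILD_VECTOR of those scalars. For example (illustrative;
// this fold only fires when the operand vector type is not legal):
//   (v4i32 concat_vectors (v2i32 bitcast (i64 X)), (v2i32 bitcast (i64 Y)))
//     --> (v4i32 bitcast (v2i64 build_vector X, Y))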
23500static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
23501 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23502 EVT OpVT = N->getOperand(Num: 0).getValueType();
23503
23504 // If the operands are legal vectors, leave them alone.
23505 if (TLI.isTypeLegal(VT: OpVT) || OpVT.isScalableVector())
23506 return SDValue();
23507
23508 SDLoc DL(N);
23509 EVT VT = N->getValueType(ResNo: 0);
23510 SmallVector<SDValue, 8> Ops;
23511
23512 EVT SVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
23513 SDValue ScalarUndef = DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT);
23514
23515 // Keep track of what we encounter.
23516 bool AnyInteger = false;
23517 bool AnyFP = false;
23518 for (const SDValue &Op : N->ops()) {
23519 if (ISD::BITCAST == Op.getOpcode() &&
23520 !Op.getOperand(i: 0).getValueType().isVector())
23521 Ops.push_back(Elt: Op.getOperand(i: 0));
23522 else if (ISD::UNDEF == Op.getOpcode())
23523 Ops.push_back(Elt: ScalarUndef);
23524 else
23525 return SDValue();
23526
23527 // Note whether we encounter an integer or floating point scalar.
23528 // If it's neither, bail out, it could be something weird like x86mmx.
23529 EVT LastOpVT = Ops.back().getValueType();
23530 if (LastOpVT.isFloatingPoint())
23531 AnyFP = true;
23532 else if (LastOpVT.isInteger())
23533 AnyInteger = true;
23534 else
23535 return SDValue();
23536 }
23537
23538 // If any of the operands is a floating point scalar bitcast to a vector,
23539 // use floating point types throughout, and bitcast everything.
23540 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
23541 if (AnyFP) {
23542 SVT = EVT::getFloatingPointVT(BitWidth: OpVT.getSizeInBits());
23543 ScalarUndef = DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT);
23544 if (AnyInteger) {
23545 for (SDValue &Op : Ops) {
23546 if (Op.getValueType() == SVT)
23547 continue;
23548 if (Op.isUndef())
23549 Op = ScalarUndef;
23550 else
23551 Op = DAG.getBitcast(VT: SVT, V: Op);
23552 }
23553 }
23554 }
23555
23556 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SVT,
23557 NumElements: VT.getSizeInBits() / SVT.getSizeInBits());
23558 return DAG.getBitcast(VT, V: DAG.getBuildVector(VT: VecVT, DL, Ops));
23559}
23560
23561// Attempt to merge nested concat_vectors/undefs.
23562// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
23563// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
23564static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
23565 SelectionDAG &DAG) {
23566 EVT VT = N->getValueType(ResNo: 0);
23567
23568 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
23569 EVT SubVT;
23570 SDValue FirstConcat;
23571 for (const SDValue &Op : N->ops()) {
23572 if (Op.isUndef())
23573 continue;
23574 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
23575 return SDValue();
23576 if (!FirstConcat) {
23577 SubVT = Op.getOperand(i: 0).getValueType();
23578 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT: SubVT))
23579 return SDValue();
23580 FirstConcat = Op;
23581 continue;
23582 }
23583 if (SubVT != Op.getOperand(i: 0).getValueType())
23584 return SDValue();
23585 }
23586 assert(FirstConcat && "Concat of all-undefs found");
23587
23588 SmallVector<SDValue> ConcatOps;
23589 for (const SDValue &Op : N->ops()) {
23590 if (Op.isUndef()) {
23591 ConcatOps.append(NumInputs: FirstConcat->getNumOperands(), Elt: DAG.getUNDEF(VT: SubVT));
23592 continue;
23593 }
23594 ConcatOps.append(in_start: Op->op_begin(), in_end: Op->op_end());
23595 }
23596 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops: ConcatOps);
23597}
23598
23599// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
23600// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
23601// most two distinct vectors the same size as the result, attempt to turn this
23602// into a legal shuffle.
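// For example (illustrative, with v2i32 subvectors of v4i32 sources):
//   (v4i32 concat_vectors (v2i32 extract_subvector A:v4i32, 0),
//                         (v2i32 extract_subvector B:v4i32, 2))
//     --> (v4i32 vector_shuffle<0,1,6,7> A, B)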
23603static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
23604 EVT VT = N->getValueType(ResNo: 0);
23605 EVT OpVT = N->getOperand(Num: 0).getValueType();
23606
23607 // We currently can't generate an appropriate shuffle for a scalable vector.
23608 if (VT.isScalableVector())
23609 return SDValue();
23610
23611 int NumElts = VT.getVectorNumElements();
23612 int NumOpElts = OpVT.getVectorNumElements();
23613
23614 SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
23615 SmallVector<int, 8> Mask;
23616
23617 for (SDValue Op : N->ops()) {
23618 Op = peekThroughBitcasts(V: Op);
23619
23620 // UNDEF nodes convert to UNDEF shuffle mask values.
23621 if (Op.isUndef()) {
23622 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
23623 continue;
23624 }
23625
23626 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
23627 return SDValue();
23628
23629 // What vector are we extracting the subvector from and at what index?
23630 SDValue ExtVec = Op.getOperand(i: 0);
23631 int ExtIdx = Op.getConstantOperandVal(i: 1);
23632
23633 // We want the EVT of the original extraction to correctly scale the
23634 // extraction index.
23635 EVT ExtVT = ExtVec.getValueType();
23636 ExtVec = peekThroughBitcasts(V: ExtVec);
23637
23638 // UNDEF nodes convert to UNDEF shuffle mask values.
23639 if (ExtVec.isUndef()) {
23640 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
23641 continue;
23642 }
23643
23644 // Ensure that we are extracting a subvector from a vector the same
23645 // size as the result.
23646 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
23647 return SDValue();
23648
23649 // Scale the subvector index to account for any bitcast.
23650 int NumExtElts = ExtVT.getVectorNumElements();
23651 if (0 == (NumExtElts % NumElts))
23652 ExtIdx /= (NumExtElts / NumElts);
23653 else if (0 == (NumElts % NumExtElts))
23654 ExtIdx *= (NumElts / NumExtElts);
23655 else
23656 return SDValue();
23657
23658 // At most we can reference 2 inputs in the final shuffle.
23659 if (SV0.isUndef() || SV0 == ExtVec) {
23660 SV0 = ExtVec;
23661 for (int i = 0; i != NumOpElts; ++i)
23662 Mask.push_back(Elt: i + ExtIdx);
23663 } else if (SV1.isUndef() || SV1 == ExtVec) {
23664 SV1 = ExtVec;
23665 for (int i = 0; i != NumOpElts; ++i)
23666 Mask.push_back(Elt: i + ExtIdx + NumElts);
23667 } else {
23668 return SDValue();
23669 }
23670 }
23671
23672 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23673 return TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: DAG.getBitcast(VT, V: SV0),
23674 N1: DAG.getBitcast(VT, V: SV1), Mask, DAG);
23675}
23676
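// Fold a concatenation of identical int<->fp casts into one wider cast.
// For example (illustrative; the wider cast must be legal or custom):
//   (v8f32 concat_vectors (v4f32 sint_to_fp X:v4i32), (v4f32 sint_to_fp Y:v4i32))
//     --> (v8f32 sint_to_fp (v8i32 concat_vectors X, Y))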
23677static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
23678 unsigned CastOpcode = N->getOperand(Num: 0).getOpcode();
23679 switch (CastOpcode) {
23680 case ISD::SINT_TO_FP:
23681 case ISD::UINT_TO_FP:
23682 case ISD::FP_TO_SINT:
23683 case ISD::FP_TO_UINT:
23684 // TODO: Allow more opcodes?
23685 // case ISD::BITCAST:
23686 // case ISD::TRUNCATE:
23687 // case ISD::ZERO_EXTEND:
23688 // case ISD::SIGN_EXTEND:
23689 // case ISD::FP_EXTEND:
23690 break;
23691 default:
23692 return SDValue();
23693 }
23694
23695 EVT SrcVT = N->getOperand(Num: 0).getOperand(i: 0).getValueType();
23696 if (!SrcVT.isVector())
23697 return SDValue();
23698
23699 // All operands of the concat must be the same kind of cast from the same
23700 // source type.
23701 SmallVector<SDValue, 4> SrcOps;
23702 for (SDValue Op : N->ops()) {
23703 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
23704 Op.getOperand(i: 0).getValueType() != SrcVT)
23705 return SDValue();
23706 SrcOps.push_back(Elt: Op.getOperand(i: 0));
23707 }
23708
23709 // The wider cast must be supported by the target. This is unusual because
23710 // the operation support type parameter depends on the opcode. In addition,
23711 // check the other type in the cast to make sure this is really legal.
23712 EVT VT = N->getValueType(ResNo: 0);
23713 EVT SrcEltVT = SrcVT.getVectorElementType();
23714 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
23715 EVT ConcatSrcVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcEltVT, EC: NumElts);
23716 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23717 switch (CastOpcode) {
23718 case ISD::SINT_TO_FP:
23719 case ISD::UINT_TO_FP:
23720 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT: ConcatSrcVT) ||
23721 !TLI.isTypeLegal(VT))
23722 return SDValue();
23723 break;
23724 case ISD::FP_TO_SINT:
23725 case ISD::FP_TO_UINT:
23726 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT) ||
23727 !TLI.isTypeLegal(VT: ConcatSrcVT))
23728 return SDValue();
23729 break;
23730 default:
23731 llvm_unreachable("Unexpected cast opcode");
23732 }
23733
23734 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
23735 SDLoc DL(N);
23736 SDValue NewConcat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ConcatSrcVT, Ops: SrcOps);
23737 return DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: NewConcat);
23738}
23739
23740// See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
23741// the operands is a SHUFFLE_VECTOR, and all other operands are also operands
23742// to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR.
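// For example (a sketch; the resulting mask must be legal for the target):
//   t0: v4i32 = vector_shuffle<1,0,3,2> t1, undef
//   (v8i32 concat_vectors t0, t1)
//     --> (v8i32 vector_shuffle<1,0,3,2,0,1,2,3> (concat_vectors t1, undef),
//                                                undef)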
23743static SDValue combineConcatVectorOfShuffleAndItsOperands(
23744 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
23745 bool LegalOperations) {
23746 EVT VT = N->getValueType(ResNo: 0);
23747 EVT OpVT = N->getOperand(Num: 0).getValueType();
23748 if (VT.isScalableVector())
23749 return SDValue();
23750
23751 // For now, only allow simple 2-operand concatenations.
23752 if (N->getNumOperands() != 2)
23753 return SDValue();
23754
23755 // Don't create illegal types/shuffles when not allowed to.
23756 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
23757 (LegalOperations &&
23758 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT)))
23759 return SDValue();
23760
23761 // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
23762 // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
23763 // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
23764 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
23765 // (4) and for now, the SHUFFLE_VECTOR must be unary.
23766 ShuffleVectorSDNode *SVN = nullptr;
23767 for (SDValue Op : N->ops()) {
23768 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Val&: Op);
23769 CurSVN && CurSVN->getOperand(Num: 1).isUndef() && N->isOnlyUserOf(N: CurSVN) &&
23770 all_of(Range: N->ops(), P: [CurSVN](SDValue Op) {
23771 // FIXME: can we allow UNDEF operands?
23772 return !Op.isUndef() &&
23773 (Op.getNode() == CurSVN || is_contained(Range: CurSVN->ops(), Element: Op));
23774 })) {
23775 SVN = CurSVN;
23776 break;
23777 }
23778 }
23779 if (!SVN)
23780 return SDValue();
23781
23782 // We are going to pad the shuffle operands, so any index that was picking
23783 // from the second operand must be adjusted.
23784 SmallVector<int, 16> AdjustedMask;
23785 AdjustedMask.reserve(N: SVN->getMask().size());
23786 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
23787 append_range(C&: AdjustedMask, R: SVN->getMask());
23788
23789 // Identity masks for the operands of the (padded) shuffle.
23790 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
23791 MutableArrayRef<int> FirstShufOpIdentityMask =
23792 MutableArrayRef<int>(IdentityMask)
23793 .take_front(N: OpVT.getVectorNumElements());
23794 MutableArrayRef<int> SecondShufOpIdentityMask =
23795 MutableArrayRef<int>(IdentityMask).take_back(N: OpVT.getVectorNumElements());
23796 std::iota(first: FirstShufOpIdentityMask.begin(), last: FirstShufOpIdentityMask.end(), value: 0);
23797 std::iota(first: SecondShufOpIdentityMask.begin(), last: SecondShufOpIdentityMask.end(),
23798 value: VT.getVectorNumElements());
23799
23800 // New combined shuffle mask.
23801 SmallVector<int, 32> Mask;
23802 Mask.reserve(N: VT.getVectorNumElements());
23803 for (SDValue Op : N->ops()) {
23804 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
23805 if (Op.getNode() == SVN) {
23806 append_range(C&: Mask, R&: AdjustedMask);
23807 continue;
23808 }
23809 if (Op == SVN->getOperand(Num: 0)) {
23810 append_range(C&: Mask, R&: FirstShufOpIdentityMask);
23811 continue;
23812 }
23813 if (Op == SVN->getOperand(Num: 1)) {
23814 append_range(C&: Mask, R&: SecondShufOpIdentityMask);
23815 continue;
23816 }
23817 llvm_unreachable("Unexpected operand!");
23818 }
23819
23820 // Don't create illegal shuffle masks.
23821 if (!TLI.isShuffleMaskLegal(Mask, VT))
23822 return SDValue();
23823
23824 // Pad the shuffle operands with UNDEF.
23825 SDLoc dl(N);
23826 std::array<SDValue, 2> ShufOps;
23827 for (auto I : zip(t: SVN->ops(), u&: ShufOps)) {
23828 SDValue ShufOp = std::get<0>(t&: I);
23829 SDValue &NewShufOp = std::get<1>(t&: I);
23830 if (ShufOp.isUndef())
23831 NewShufOp = DAG.getUNDEF(VT);
23832 else {
23833 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
23834 DAG.getUNDEF(VT: OpVT));
23835 ShufOpParts[0] = ShufOp;
23836 NewShufOp = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: dl, VT, Ops: ShufOpParts);
23837 }
23838 }
23839 // Finally, create the new wide shuffle.
23840 return DAG.getVectorShuffle(VT, dl, N1: ShufOps[0], N2: ShufOps[1], Mask);
23841}
23842
23843SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
23844 // If we only have one input vector, we don't need to do any concatenation.
23845 if (N->getNumOperands() == 1)
23846 return N->getOperand(Num: 0);
23847
23848 // Check if all of the operands are undefs.
23849 EVT VT = N->getValueType(ResNo: 0);
23850 if (ISD::allOperandsUndef(N))
23851 return DAG.getUNDEF(VT);
23852
23853 // Optimize concat_vectors where all but the first of the vectors are undef.
23854 if (all_of(Range: drop_begin(RangeOrContainer: N->ops()),
23855 P: [](const SDValue &Op) { return Op.isUndef(); })) {
23856 SDValue In = N->getOperand(Num: 0);
23857 assert(In.getValueType().isVector() && "Must concat vectors");
23858
23859 // If the input is a concat_vectors, just make a larger concat by padding
23860 // with smaller undefs.
23861 //
23862 // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
23863 // here could cause an infinite loop. That legalizing happens when LegalDAG
23864 // is true and the input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
23865 // scalable.
23866 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
23867 !(LegalDAG && In.getValueType().isScalableVector())) {
23868 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
23869 SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
23870 Ops.resize(N: NumOps, NV: DAG.getUNDEF(VT: Ops[0].getValueType()));
23871 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
23872 }
23873
23874 SDValue Scalar = peekThroughOneUseBitcasts(V: In);
23875
23876 // concat_vectors(scalar_to_vector(scalar), undef) ->
23877 // scalar_to_vector(scalar)
23878 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23879 Scalar.hasOneUse()) {
23880 EVT SVT = Scalar.getValueType().getVectorElementType();
23881 if (SVT == Scalar.getOperand(i: 0).getValueType())
23882 Scalar = Scalar.getOperand(i: 0);
23883 }
23884
23885 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
23886 if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
23887 // If the bitcast type isn't legal, it might be a trunc of a legal type;
23888 // look through the trunc so we can still do the transform:
23889 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
23890 if (Scalar->getOpcode() == ISD::TRUNCATE &&
23891 !TLI.isTypeLegal(VT: Scalar.getValueType()) &&
23892 TLI.isTypeLegal(VT: Scalar->getOperand(Num: 0).getValueType()))
23893 Scalar = Scalar->getOperand(Num: 0);
23894
23895 EVT SclTy = Scalar.getValueType();
23896
23897 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
23898 return SDValue();
23899
23900 // Bail out if the vector size is not a multiple of the scalar size.
23901 if (VT.getSizeInBits() % SclTy.getSizeInBits())
23902 return SDValue();
23903
23904 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
23905 if (VNTNumElms < 2)
23906 return SDValue();
23907
23908 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SclTy, NumElements: VNTNumElms);
23909 if (!TLI.isTypeLegal(VT: NVT) || !TLI.isTypeLegal(VT: Scalar.getValueType()))
23910 return SDValue();
23911
23912 SDValue Res = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT: NVT, Operand: Scalar);
23913 return DAG.getBitcast(VT, V: Res);
23914 }
23915 }
23916
23917 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
23918 // We have already tested above for an UNDEF only concatenation.
23919 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
23920 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
23921 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
23922 return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
23923 };
23924 if (llvm::all_of(Range: N->ops(), P: IsBuildVectorOrUndef)) {
23925 SmallVector<SDValue, 8> Opnds;
23926 EVT SVT = VT.getScalarType();
23927
23928 EVT MinVT = SVT;
23929 if (!SVT.isFloatingPoint()) {
23930 // If the BUILD_VECTORs are built from integers, they may have different
23931 // operand types. Get the smallest type and truncate all operands to it.
23932 bool FoundMinVT = false;
23933 for (const SDValue &Op : N->ops())
23934 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
23935 EVT OpSVT = Op.getOperand(i: 0).getValueType();
23936 MinVT = (!FoundMinVT || OpSVT.bitsLE(VT: MinVT)) ? OpSVT : MinVT;
23937 FoundMinVT = true;
23938 }
23939 assert(FoundMinVT && "Concat vector type mismatch");
23940 }
23941
23942 for (const SDValue &Op : N->ops()) {
23943 EVT OpVT = Op.getValueType();
23944 unsigned NumElts = OpVT.getVectorNumElements();
23945
23946 if (ISD::UNDEF == Op.getOpcode())
23947 Opnds.append(NumInputs: NumElts, Elt: DAG.getUNDEF(VT: MinVT));
23948
23949 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
23950 if (SVT.isFloatingPoint()) {
23951 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
23952 Opnds.append(in_start: Op->op_begin(), in_end: Op->op_begin() + NumElts);
23953 } else {
23954 for (unsigned i = 0; i != NumElts; ++i)
23955 Opnds.push_back(
23956 Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT: MinVT, Operand: Op.getOperand(i)));
23957 }
23958 }
23959 }
23960
23961 assert(VT.getVectorNumElements() == Opnds.size() &&
23962 "Concat vector type mismatch");
23963 return DAG.getBuildVector(VT, DL: SDLoc(N), Ops: Opnds);
23964 }
23965
23966 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
23967 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
23968 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
23969 return V;
23970
23971 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
23972 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
23973 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
23974 return V;
23975
23976 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
23977 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
23978 return V;
23979 }
23980
23981 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
23982 return V;
23983
23984 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
23985 N, DAG, TLI, LegalTypes, LegalOperations))
23986 return V;
23987
23988 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
23989 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
23990 // operands and look for CONCAT operations that place the incoming vectors
23991 // at the exact same location.
23992 //
23993 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
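// For example (illustrative):
//   (v8i32 concat_vectors (v4i32 extract_subvector X:v8i32, 0),
//                         (v4i32 extract_subvector X:v8i32, 4))
//     --> X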
23994 SDValue SingleSource = SDValue();
23995 unsigned PartNumElem =
23996 N->getOperand(Num: 0).getValueType().getVectorMinNumElements();
23997
23998 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
23999 SDValue Op = N->getOperand(Num: i);
24000
24001 if (Op.isUndef())
24002 continue;
24003
24004 // Check if this is the identity extract:
24005 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
24006 return SDValue();
24007
24008 // Find the single incoming vector for the extract_subvector.
24009 if (SingleSource.getNode()) {
24010 if (Op.getOperand(i: 0) != SingleSource)
24011 return SDValue();
24012 } else {
24013 SingleSource = Op.getOperand(i: 0);
24014
24015 // Check the source type is the same as the type of the result.
24016 // If not, this concat may extend the vector, so we cannot
24017 // optimize it away.
24018 if (SingleSource.getValueType() != N->getValueType(ResNo: 0))
24019 return SDValue();
24020 }
24021
24022 // Check that we are reading from the identity index.
24023 unsigned IdentityIndex = i * PartNumElem;
24024 if (Op.getConstantOperandAPInt(i: 1) != IdentityIndex)
24025 return SDValue();
24026 }
24027
24028 if (SingleSource.getNode())
24029 return SingleSource;
24030
24031 return SDValue();
24032}
24033
24034// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
24035// if the subvector can be sourced for free.
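// For instance (illustrative): with SubVT = v4i32 and Index = 4,
//   (v8i32 insert_subvector Y, (v4i32 S), 4)    yields S, and
//   (v8i32 concat_vectors (v4i32 A), (v4i32 B)) yields B.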
24036static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
24037 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
24038 V.getOperand(i: 1).getValueType() == SubVT && V.getOperand(i: 2) == Index) {
24039 return V.getOperand(i: 1);
24040 }
24041 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
24042 if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
24043 V.getOperand(i: 0).getValueType() == SubVT &&
24044 (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
24045 uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
24046 return V.getOperand(i: SubIdx);
24047 }
24048 return SDValue();
24049}
24050
24051static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
24052 SelectionDAG &DAG,
24053 bool LegalOperations) {
24054 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24055 SDValue BinOp = Extract->getOperand(Num: 0);
24056 unsigned BinOpcode = BinOp.getOpcode();
24057 if (!TLI.isBinOp(Opcode: BinOpcode) || BinOp->getNumValues() != 1)
24058 return SDValue();
24059
24060 EVT VecVT = BinOp.getValueType();
24061 SDValue Bop0 = BinOp.getOperand(i: 0), Bop1 = BinOp.getOperand(i: 1);
24062 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
24063 return SDValue();
24064
24065 SDValue Index = Extract->getOperand(Num: 1);
24066 EVT SubVT = Extract->getValueType(ResNo: 0);
24067 if (!TLI.isOperationLegalOrCustom(Op: BinOpcode, VT: SubVT, LegalOnly: LegalOperations))
24068 return SDValue();
24069
24070 SDValue Sub0 = getSubVectorSrc(V: Bop0, Index, SubVT);
24071 SDValue Sub1 = getSubVectorSrc(V: Bop1, Index, SubVT);
24072
24073 // TODO: We could handle the case where only 1 operand is being inserted by
24074 // creating an extract of the other operand, but that requires checking
24075 // number of uses and/or costs.
24076 if (!Sub0 || !Sub1)
24077 return SDValue();
24078
24079 // We are inserting both operands of the wide binop only to extract back
24080 // to the narrow vector size. Eliminate all of the insert/extract:
24081 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
24082 return DAG.getNode(Opcode: BinOpcode, DL: SDLoc(Extract), VT: SubVT, N1: Sub0, N2: Sub1,
24083 Flags: BinOp->getFlags());
24084}
24085
24086/// If we are extracting a subvector produced by a wide binary operator try
24087/// to use a narrow binary operator and/or avoid concatenation and extraction.
24088static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
24089 bool LegalOperations) {
24090 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
24091 // some of these bailouts with other transforms.
24092
24093 if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
24094 return V;
24095
24096 // The extract index must be a constant, so we can map it to a concat operand.
24097 auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Val: Extract->getOperand(Num: 1));
24098 if (!ExtractIndexC)
24099 return SDValue();
24100
24101 // We are looking for an optionally bitcasted wide vector binary operator
24102 // feeding an extract subvector.
24103 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24104 SDValue BinOp = peekThroughBitcasts(V: Extract->getOperand(Num: 0));
24105 unsigned BOpcode = BinOp.getOpcode();
24106 if (!TLI.isBinOp(Opcode: BOpcode) || BinOp->getNumValues() != 1)
24107 return SDValue();
24108
24109 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
24110 // reduced to the unary fneg when it is visited, and we probably want to deal
24111 // with fneg in a target-specific way.
24112 if (BOpcode == ISD::FSUB) {
24113 auto *C = isConstOrConstSplatFP(N: BinOp.getOperand(i: 0), /*AllowUndefs*/ true);
24114 if (C && C->getValueAPF().isNegZero())
24115 return SDValue();
24116 }
24117
24118 // The binop must be a vector type, so we can extract some fraction of it.
24119 EVT WideBVT = BinOp.getValueType();
24120 // The optimisations below currently assume we are dealing with fixed length
24121 // vectors. It is possible to add support for scalable vectors, but at the
24122 // moment we've done no analysis to prove whether they are profitable or not.
24123 if (!WideBVT.isFixedLengthVector())
24124 return SDValue();
24125
24126 EVT VT = Extract->getValueType(ResNo: 0);
24127 unsigned ExtractIndex = ExtractIndexC->getZExtValue();
24128 assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
24129 "Extract index is not a multiple of the vector length.");
24130
24131 // Bail out if this is not a proper multiple width extraction.
24132 unsigned WideWidth = WideBVT.getSizeInBits();
24133 unsigned NarrowWidth = VT.getSizeInBits();
24134 if (WideWidth % NarrowWidth != 0)
24135 return SDValue();
24136
24137 // Bail out if we are extracting a fraction of a single operation. This can
24138 // occur because we potentially looked through a bitcast of the binop.
24139 unsigned NarrowingRatio = WideWidth / NarrowWidth;
24140 unsigned WideNumElts = WideBVT.getVectorNumElements();
24141 if (WideNumElts % NarrowingRatio != 0)
24142 return SDValue();
24143
24144 // Bail out if the target does not support a narrower version of the binop.
24145 EVT NarrowBVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: WideBVT.getScalarType(),
24146 NumElements: WideNumElts / NarrowingRatio);
24147 if (!TLI.isOperationLegalOrCustomOrPromote(Op: BOpcode, VT: NarrowBVT,
24148 LegalOnly: LegalOperations))
24149 return SDValue();
24150
24151 // If extraction is cheap, we don't need to look at the binop operands
24152 // for concat ops. The narrow binop alone makes this transform profitable.
24153 // We can't just reuse the original extract index operand because we may have
24154 // bitcasted.
24155 unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
24156 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
24157 if (TLI.isExtractSubvectorCheap(ResVT: NarrowBVT, SrcVT: WideBVT, Index: ExtBOIdx) &&
24158 BinOp.hasOneUse() && Extract->getOperand(Num: 0)->hasOneUse()) {
24159 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
24160 SDLoc DL(Extract);
24161 SDValue NewExtIndex = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
24162 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24163 N1: BinOp.getOperand(i: 0), N2: NewExtIndex);
24164 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24165 N1: BinOp.getOperand(i: 1), N2: NewExtIndex);
24166 SDValue NarrowBinOp =
24167 DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y, Flags: BinOp->getFlags());
24168 return DAG.getBitcast(VT, V: NarrowBinOp);
24169 }
24170
24171 // Only handle the case where we are doubling and then halving. A larger ratio
24172 // may require more than two narrow binops to replace the wide binop.
24173 if (NarrowingRatio != 2)
24174 return SDValue();
24175
24176 // TODO: The motivating case for this transform is an x86 AVX1 target. That
24177 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
24178 // flavors, but no other 256-bit integer support. This could be extended to
24179 // handle any binop, but that may require fixing/adding other folds to avoid
24180 // codegen regressions.
24181 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
24182 return SDValue();
24183
24184 // We need at least one concatenation operation of a binop operand to make
24185 // this transform worthwhile. The concat must double the input vector sizes.
24186 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
24187 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
24188 return V.getOperand(i: ConcatOpNum);
24189 return SDValue();
24190 };
24191 SDValue SubVecL = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 0)));
24192 SDValue SubVecR = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 1)));
24193
24194 if (SubVecL || SubVecR) {
24195 // If a binop operand was not the result of a concat, we must extract a
24196 // half-sized operand for our new narrow binop:
24197 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
24198 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
24199 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
24200 SDLoc DL(Extract);
24201 SDValue IndexC = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
24202 SDValue X = SubVecL ? DAG.getBitcast(VT: NarrowBVT, V: SubVecL)
24203 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24204 N1: BinOp.getOperand(i: 0), N2: IndexC);
24205
24206 SDValue Y = SubVecR ? DAG.getBitcast(VT: NarrowBVT, V: SubVecR)
24207 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24208 N1: BinOp.getOperand(i: 1), N2: IndexC);
24209
24210 SDValue NarrowBinOp = DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y);
24211 return DAG.getBitcast(VT, V: NarrowBinOp);
24212 }
24213
24214 return SDValue();
24215}
24216
24217/// If we are extracting a subvector from a wide vector load, convert to a
24218/// narrow load to eliminate the extraction:
24219/// (extract_subvector (load wide vector)) --> (load narrow vector)
24220static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
24221 // TODO: Add support for big-endian. The offset calculation must be adjusted.
24222 if (DAG.getDataLayout().isBigEndian())
24223 return SDValue();
24224
24225 auto *Ld = dyn_cast<LoadSDNode>(Val: Extract->getOperand(Num: 0));
24226 if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
24227 return SDValue();
24228
24229 // Allow targets to opt-out.
24230 EVT VT = Extract->getValueType(ResNo: 0);
24231
24232 // We can only create byte sized loads.
24233 if (!VT.isByteSized())
24234 return SDValue();
24235
24236 unsigned Index = Extract->getConstantOperandVal(Num: 1);
24237 unsigned NumElts = VT.getVectorMinNumElements();
24238 // A fixed length vector being extracted from a scalable vector
24239 // may not be any *smaller* than the scalable one.
24240 if (Index == 0 && NumElts >= Ld->getValueType(ResNo: 0).getVectorMinNumElements())
24241 return SDValue();
24242
24243 // The definition of EXTRACT_SUBVECTOR states that the index must be a
24244 // multiple of the minimum number of elements in the result type.
24245 assert(Index % NumElts == 0 && "The extract subvector index is not a "
24246 "multiple of the result's element count");
24247
24248 // It's fine to use TypeSize here as we know the offset will not be negative.
24249 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
24250
24251 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24252 if (!TLI.shouldReduceLoadWidth(Load: Ld, ExtTy: Ld->getExtensionType(), NewVT: VT))
24253 return SDValue();
24254
24255 // The narrow load will be offset from the base address of the old load if
24256 // we are extracting from something besides index 0 (little-endian).
24257 SDLoc DL(Extract);
24258
24259 // TODO: Use "BaseIndexOffset" to make this more effective.
24260 SDValue NewAddr = DAG.getMemBasePlusOffset(Base: Ld->getBasePtr(), Offset, DL);
24261
24262 uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(T: VT.getStoreSize());
24263 MachineFunction &MF = DAG.getMachineFunction();
24264 MachineMemOperand *MMO;
24265 if (Offset.isScalable()) {
24266 MachinePointerInfo MPI =
24267 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
24268 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), PtrInfo: MPI, Size: StoreSize);
24269 } else
24270 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), Offset: Offset.getFixedValue(),
24271 Size: StoreSize);
24272
24273 SDValue NewLd = DAG.getLoad(VT, dl: DL, Chain: Ld->getChain(), Ptr: NewAddr, MMO);
24274 DAG.makeEquivalentMemoryOrdering(OldLoad: Ld, NewMemOp: NewLd);
24275 return NewLd;
24276}
24277
24278/// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
24279/// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
24280/// EXTRACT_SUBVECTOR(Op?, ?),
24281/// Mask'))
24282/// iff it is legal and profitable to do so. Notably, the trimmed mask
24283/// (containing only the elements that are extracted)
24284/// must reference at most two subvectors.
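/// For example (a sketch; the narrow shuffle must be legal and profitable):
///   (v4i32 extract_subvector
///        (v8i32 vector_shuffle<0,8,1,9,4,12,5,13> A, B), 0)
///     --> (v4i32 vector_shuffle<0,4,1,5> (extract_subvector A, 0),
///                                        (extract_subvector B, 0))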
24285static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
24286 SelectionDAG &DAG,
24287 const TargetLowering &TLI,
24288 bool LegalOperations) {
24289 assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24290 "Must only be called on EXTRACT_SUBVECTOR's");
24291
24292 SDValue N0 = N->getOperand(Num: 0);
24293
24294 // Only deal with non-scalable vectors.
24295 EVT NarrowVT = N->getValueType(ResNo: 0);
24296 EVT WideVT = N0.getValueType();
24297 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
24298 return SDValue();
24299
24300 // The operand must be a shufflevector.
24301 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
24302 if (!WideShuffleVector)
24303 return SDValue();
24304
24305 // The old shuffle needs to go away.
24306 if (!WideShuffleVector->hasOneUse())
24307 return SDValue();
24308
24309 // And the narrow shufflevector that we'll form must be legal.
24310 if (LegalOperations &&
24311 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: NarrowVT))
24312 return SDValue();
24313
24314 uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(Num: 1);
24315 int NumEltsExtracted = NarrowVT.getVectorNumElements();
24316 assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
24317 "Extract index is not a multiple of the output vector length.");
24318
24319 int WideNumElts = WideVT.getVectorNumElements();
24320
24321 SmallVector<int, 16> NewMask;
24322 NewMask.reserve(N: NumEltsExtracted);
24323 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
24324 DemandedSubvectors;
24325
24326 // Try to decode the wide mask into narrow mask from at most two subvectors.
24327 for (int M : WideShuffleVector->getMask().slice(N: FirstExtractedEltIdx,
24328 M: NumEltsExtracted)) {
24329 assert((M >= -1) && (M < (2 * WideNumElts)) &&
24330 "Out-of-bounds shuffle mask?");
24331
24332 if (M < 0) {
24333 // Does not depend on operands, does not require adjustment.
24334 NewMask.emplace_back(Args&: M);
24335 continue;
24336 }
24337
24338 // From which operand of the shuffle does this shuffle mask element pick?
24339 int WideShufOpIdx = M / WideNumElts;
24340 // Which element of that operand is picked?
24341 int OpEltIdx = M % WideNumElts;
24342
24343 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
24344 "Shuffle mask vector decomposition failure.");
24345
24346 // And which NumEltsExtracted-sized subvector of that operand is that?
24347 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
24348 // And which element within that subvector of that operand is that?
24349 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
24350
24351 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
24352 "Shuffle mask subvector decomposition failure.");
24353
24354 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
24355 WideShufOpIdx * WideNumElts) == M &&
24356 "Shuffle mask full decomposition failure.");
24357
24358 SDValue Op = WideShuffleVector->getOperand(Num: WideShufOpIdx);
24359
24360 if (Op.isUndef()) {
24361 // Picking from an undef operand. Let's adjust mask instead.
24362 NewMask.emplace_back(Args: -1);
24363 continue;
24364 }
24365
24366 const std::pair<SDValue, int> DemandedSubvector =
24367 std::make_pair(x&: Op, y&: OpSubvecIdx);
24368
24369 if (DemandedSubvectors.insert(X: DemandedSubvector)) {
24370 if (DemandedSubvectors.size() > 2)
24371 return SDValue(); // We can't handle more than two subvectors.
24372 // How many elements into the WideVT does this subvector start?
24373 int Index = NumEltsExtracted * OpSubvecIdx;
24374 // Bail out if the extraction isn't going to be cheap.
24375 if (!TLI.isExtractSubvectorCheap(ResVT: NarrowVT, SrcVT: WideVT, Index))
24376 return SDValue();
24377 }
24378
24379 // Ok, but from which operand of the new shuffle will this element pick?
24380 int NewOpIdx =
24381 getFirstIndexOf(Range: DemandedSubvectors.getArrayRef(), Val: DemandedSubvector);
24382 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
24383
24384 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
24385 NewMask.emplace_back(Args&: AdjM);
24386 }
24387 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
24388 assert(DemandedSubvectors.size() <= 2 &&
24389 "Should have ended up demanding at most two subvectors.");
24390
24391 // Did we discover that the shuffle does not actually depend on operands?
24392 if (DemandedSubvectors.empty())
24393 return DAG.getUNDEF(VT: NarrowVT);
24394
24395 // Profitability check: only deal with extractions from the first subvector
24396 // unless the mask becomes an identity mask.
24397 if (!ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts: NewMask.size()) ||
24398 any_of(Range&: NewMask, P: [](int M) { return M < 0; }))
24399 for (auto &DemandedSubvector : DemandedSubvectors)
24400 if (DemandedSubvector.second != 0)
24401 return SDValue();
24402
24403 // We still perform the exact same EXTRACT_SUBVECTOR, just on different
24404 // operand[s]/index[es], so there is no point in checking its legality.
24405
24406 // Do not turn a legal shuffle into an illegal one.
24407 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
24408 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
24409 return SDValue();
24410
24411 SDLoc DL(N);
24412
24413 SmallVector<SDValue, 2> NewOps;
24414 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
24415 &DemandedSubvector : DemandedSubvectors) {
24416 // How many elements into the WideVT does this subvector start?
24417 int Index = NumEltsExtracted * DemandedSubvector.second;
24418 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index, DL);
24419 NewOps.emplace_back(Args: DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowVT,
24420 N1: DemandedSubvector.first, N2: IndexC));
24421 }
24422 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
24423 "Should end up with either one or two ops");
24424
24425 // If we ended up with only one operand, pad with an undef.
24426 if (NewOps.size() == 1)
24427 NewOps.emplace_back(Args: DAG.getUNDEF(VT: NarrowVT));
24428
24429 return DAG.getVectorShuffle(VT: NarrowVT, dl: DL, N1: NewOps[0], N2: NewOps[1], Mask: NewMask);
24430}
24431
24432SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
24433 EVT NVT = N->getValueType(ResNo: 0);
24434 SDValue V = N->getOperand(Num: 0);
24435 uint64_t ExtIdx = N->getConstantOperandVal(Num: 1);
24436
24437 // Extract from UNDEF is UNDEF.
24438 if (V.isUndef())
24439 return DAG.getUNDEF(VT: NVT);
24440
24441 if (TLI.isOperationLegalOrCustomOrPromote(Op: ISD::LOAD, VT: NVT))
24442 if (SDValue NarrowLoad = narrowExtractedVectorLoad(Extract: N, DAG))
24443 return NarrowLoad;
24444
24445 // Combine an extract of an extract into a single extract_subvector.
24446 // ext (ext X, C), 0 --> ext X, C
24447 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
24448 if (TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: V.getOperand(i: 0).getValueType(),
24449 Index: V.getConstantOperandVal(i: 1)) &&
24450 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NVT)) {
24451 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: NVT, N1: V.getOperand(i: 0),
24452 N2: V.getOperand(i: 1));
24453 }
24454 }
24455
24456 // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
24457 if (V.getOpcode() == ISD::SPLAT_VECTOR)
24458 if (DAG.isConstantValueOfAnyType(N: V.getOperand(i: 0)) || V.hasOneUse())
24459 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT: NVT))
24460 return DAG.getSplatVector(VT: NVT, DL: SDLoc(N), Op: V.getOperand(i: 0));
24461
24462 // Try to move vector bitcast after extract_subv by scaling extraction index:
24463 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
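// For example (an illustrative case, assuming the narrower extract is legal):
// extracting v2i32 at index 2 from (v4i32 bitcast (v8i16 X)) gives
// SrcDestRatio = 8/4 = 2, so it becomes bitcast (v4i16 extract_subv X, 4).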
24464 if (V.getOpcode() == ISD::BITCAST &&
24465 V.getOperand(i: 0).getValueType().isVector() &&
24466 (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))) {
24467 SDValue SrcOp = V.getOperand(i: 0);
24468 EVT SrcVT = SrcOp.getValueType();
24469 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
24470 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
24471 if ((SrcNumElts % DestNumElts) == 0) {
24472 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
24473 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
24474 EVT NewExtVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcVT.getScalarType(),
24475 EC: NewExtEC);
24476 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
24477 SDLoc DL(N);
24478 SDValue NewIndex = DAG.getVectorIdxConstant(Val: ExtIdx * SrcDestRatio, DL);
24479 SDValue NewExtract = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
24480 N1: V.getOperand(i: 0), N2: NewIndex);
24481 return DAG.getBitcast(VT: NVT, V: NewExtract);
24482 }
24483 }
24484 if ((DestNumElts % SrcNumElts) == 0) {
24485 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
24486 if (NVT.getVectorElementCount().isKnownMultipleOf(RHS: DestSrcRatio)) {
24487 ElementCount NewExtEC =
24488 NVT.getVectorElementCount().divideCoefficientBy(RHS: DestSrcRatio);
24489 EVT ScalarVT = SrcVT.getScalarType();
24490 if ((ExtIdx % DestSrcRatio) == 0) {
24491 SDLoc DL(N);
24492 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
24493 EVT NewExtVT =
24494 EVT::getVectorVT(Context&: *DAG.getContext(), VT: ScalarVT, EC: NewExtEC);
24495 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
24496 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
24497 SDValue NewExtract =
24498 DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
24499 N1: V.getOperand(i: 0), N2: NewIndex);
24500 return DAG.getBitcast(VT: NVT, V: NewExtract);
24501 }
24502 if (NewExtEC.isScalar() &&
24503 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: ScalarVT)) {
24504 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
24505 SDValue NewExtract =
24506 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT,
24507 N1: V.getOperand(i: 0), N2: NewIndex);
24508 return DAG.getBitcast(VT: NVT, V: NewExtract);
24509 }
24510 }
24511 }
24512 }
24513 }
24514
24515 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
24516 unsigned ExtNumElts = NVT.getVectorMinNumElements();
24517 EVT ConcatSrcVT = V.getOperand(i: 0).getValueType();
24518 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
24519 "Concat and extract subvector do not change element type");
24520 assert((ExtIdx % ExtNumElts) == 0 &&
24521 "Extract index is not a multiple of the input vector length.");
24522
24523 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
24524 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
24525
24526 // If the concatenated source types match this extract, it's a direct
24527 // simplification:
24528 // extract_subvec (concat V1, V2, ...), i --> Vi
24529 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
24530 return V.getOperand(i: ConcatOpIdx);
24531
24532 // If the length of the concatenated source vectors is a multiple of this
24533 // extract's length, then extract a fraction of one of those source vectors
24534 // directly from a concat operand. Example:
24535 // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
24536 // v2i8 extract_subvec v8i8 Y, 6
24537 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
24538 ConcatSrcNumElts % ExtNumElts == 0) {
24539 SDLoc DL(N);
24540 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
24541 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
24542 "Trying to extract from >1 concat operand?");
24543 assert(NewExtIdx % ExtNumElts == 0 &&
24544 "Extract index is not a multiple of the input vector length.");
24545 SDValue NewIndexC = DAG.getVectorIdxConstant(Val: NewExtIdx, DL);
24546 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
24547 N1: V.getOperand(i: ConcatOpIdx), N2: NewIndexC);
24548 }
24549 }
24550
24551 if (SDValue V =
24552 foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
24553 return V;
24554
24555 V = peekThroughBitcasts(V);
24556
24557 // If the input is a build vector, try to make a smaller build vector.
24558 if (V.getOpcode() == ISD::BUILD_VECTOR) {
24559 EVT InVT = V.getValueType();
24560 unsigned ExtractSize = NVT.getSizeInBits();
24561 unsigned EltSize = InVT.getScalarSizeInBits();
24562 // Only do this if we won't split any elements.
24563 if (ExtractSize % EltSize == 0) {
24564 unsigned NumElems = ExtractSize / EltSize;
24565 EVT EltVT = InVT.getVectorElementType();
24566 EVT ExtractVT =
24567 NumElems == 1 ? EltVT
24568 : EVT::getVectorVT(Context&: *DAG.getContext(), VT: EltVT, NumElements: NumElems);
24569 if ((Level < AfterLegalizeDAG ||
24570 (NumElems == 1 ||
24571 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: ExtractVT))) &&
24572 (!LegalTypes || TLI.isTypeLegal(VT: ExtractVT))) {
24573 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
24574
24575 if (NumElems == 1) {
24576 SDValue Src = V->getOperand(Num: IdxVal);
24577 if (EltVT != Src.getValueType())
24578 Src = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT: EltVT, Operand: Src);
24579 return DAG.getBitcast(VT: NVT, V: Src);
24580 }
24581
24582 // Extract the pieces from the original build_vector.
24583 SDValue BuildVec = DAG.getBuildVector(VT: ExtractVT, DL: SDLoc(N),
24584 Ops: V->ops().slice(N: IdxVal, M: NumElems));
24585 return DAG.getBitcast(VT: NVT, V: BuildVec);
24586 }
24587 }
24588 }
24589
24590 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24591 // Handle only simple case where vector being inserted and vector
24592 // being extracted are of same size.
24593 EVT SmallVT = V.getOperand(i: 1).getValueType();
24594 if (!NVT.bitsEq(VT: SmallVT))
24595 return SDValue();
24596
24597 // Combine:
24598 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
24599 // Into:
24600 // indices are equal or bit offsets are equal => V2
24601 // otherwise => (extract_subvec V1, ExtIdx)
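// E.g. (illustrative, same element sizes): if V2 is v4i32 and InsIdx == 4,
// then ExtIdx == 4 yields V2 directly, while ExtIdx == 0 becomes
// (extract_subvec V1, 0).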
24602 uint64_t InsIdx = V.getConstantOperandVal(i: 2);
24603 if (InsIdx * SmallVT.getScalarSizeInBits() ==
24604 ExtIdx * NVT.getScalarSizeInBits()) {
24605 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))
24606 return SDValue();
24607
24608 return DAG.getBitcast(VT: NVT, V: V.getOperand(i: 1));
24609 }
24610 return DAG.getNode(
24611 Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: NVT,
24612 N1: DAG.getBitcast(VT: N->getOperand(Num: 0).getValueType(), V: V.getOperand(i: 0)),
24613 N2: N->getOperand(Num: 1));
24614 }
24615
24616 if (SDValue NarrowBOp = narrowExtractedVectorBinOp(Extract: N, DAG, LegalOperations))
24617 return NarrowBOp;
24618
24619 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
24620 return SDValue(N, 0);
24621
24622 return SDValue();
24623}
24624
24625/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
24626/// followed by concatenation. Narrow vector ops may have better performance
24627/// than wide ops, and this can unlock further narrowing of other vector ops.
24628/// Targets can invert this transform later if it is not profitable.
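/// For example (illustrative, v8i32 with legal half-width masks):
///   shuffle (concat X, undef), (concat Y, undef), <0,1,8,9,2,3,10,11>
///     --> concat (shuffle X, Y, <0,1,4,5>), (shuffle X, Y, <2,3,6,7>)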
24629static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
24630 SelectionDAG &DAG) {
24631 SDValue N0 = Shuf->getOperand(Num: 0), N1 = Shuf->getOperand(Num: 1);
24632 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
24633 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
24634 !N0.getOperand(i: 1).isUndef() || !N1.getOperand(i: 1).isUndef())
24635 return SDValue();
24636
24637 // Split the wide shuffle mask into halves. Any mask element that is accessing
24638 // operand 1 is offset down to account for narrowing of the vectors.
24639 ArrayRef<int> Mask = Shuf->getMask();
24640 EVT VT = Shuf->getValueType(ResNo: 0);
24641 unsigned NumElts = VT.getVectorNumElements();
24642 unsigned HalfNumElts = NumElts / 2;
24643 SmallVector<int, 16> Mask0(HalfNumElts, -1);
24644 SmallVector<int, 16> Mask1(HalfNumElts, -1);
24645 for (unsigned i = 0; i != NumElts; ++i) {
24646 if (Mask[i] == -1)
24647 continue;
24648 // If we reference the upper (undef) subvector then the element is undef.
24649 if ((Mask[i] % NumElts) >= HalfNumElts)
24650 continue;
24651 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
24652 if (i < HalfNumElts)
24653 Mask0[i] = M;
24654 else
24655 Mask1[i - HalfNumElts] = M;
24656 }
24657
24658 // Ask the target if this is a valid transform.
24659 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24660 EVT HalfVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: VT.getScalarType(),
24661 NumElements: HalfNumElts);
24662 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
24663 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
24664 return SDValue();
24665
24666 // shuffle (concat X, undef), (concat Y, undef), Mask -->
24667 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
24668 SDValue X = N0.getOperand(i: 0), Y = N1.getOperand(i: 0);
24669 SDLoc DL(Shuf);
24670 SDValue Shuf0 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask0);
24671 SDValue Shuf1 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask1);
24672 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, N1: Shuf0, N2: Shuf1);
24673}
24674
24675 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
24676 // or turn a shuffle of a single concat into a simpler shuffle then concat.
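// For example (illustrative): a v8i32 shuffle of concat(A,B) and concat(C,D)
// with mask <4,5,6,7,12,13,14,15> becomes concat(B,D).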
24677static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
24678 EVT VT = N->getValueType(ResNo: 0);
24679 unsigned NumElts = VT.getVectorNumElements();
24680
24681 SDValue N0 = N->getOperand(Num: 0);
24682 SDValue N1 = N->getOperand(Num: 1);
24683 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
24684 ArrayRef<int> Mask = SVN->getMask();
24685
24686 SmallVector<SDValue, 4> Ops;
24687 EVT ConcatVT = N0.getOperand(i: 0).getValueType();
24688 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
24689 unsigned NumConcats = NumElts / NumElemsPerConcat;
24690
24691 auto IsUndefMaskElt = [](int i) { return i == -1; };
24692
24693 // Special case: shuffle(concat(A,B)) can be more efficiently represented
24694 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
24695 // half vector elements.
24696 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
24697 llvm::all_of(Range: Mask.slice(N: NumElemsPerConcat, M: NumElemsPerConcat),
24698 P: IsUndefMaskElt)) {
24699 N0 = DAG.getVectorShuffle(VT: ConcatVT, dl: SDLoc(N), N1: N0.getOperand(i: 0),
24700 N2: N0.getOperand(i: 1),
24701 Mask: Mask.slice(N: 0, M: NumElemsPerConcat));
24702 N1 = DAG.getUNDEF(VT: ConcatVT);
24703 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, N1: N0, N2: N1);
24704 }
24705
24706 // Look at every vector that's inserted. We're looking for exact
24707 // subvector-sized copies from a concatenated vector.
24708 for (unsigned I = 0; I != NumConcats; ++I) {
24709 unsigned Begin = I * NumElemsPerConcat;
24710 ArrayRef<int> SubMask = Mask.slice(N: Begin, M: NumElemsPerConcat);
24711
24712 // Make sure we're dealing with a copy.
24713 if (llvm::all_of(Range&: SubMask, P: IsUndefMaskElt)) {
24714 Ops.push_back(Elt: DAG.getUNDEF(VT: ConcatVT));
24715 continue;
24716 }
24717
24718 int OpIdx = -1;
24719 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
24720 if (IsUndefMaskElt(SubMask[i]))
24721 continue;
24722 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
24723 return SDValue();
24724 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
24725 if (0 <= OpIdx && EltOpIdx != OpIdx)
24726 return SDValue();
24727 OpIdx = EltOpIdx;
24728 }
24729 assert(0 <= OpIdx && "Unknown concat_vectors op");
24730
24731 if (OpIdx < (int)N0.getNumOperands())
24732 Ops.push_back(Elt: N0.getOperand(i: OpIdx));
24733 else
24734 Ops.push_back(Elt: N1.getOperand(i: OpIdx - N0.getNumOperands()));
24735 }
24736
24737 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
24738}
24739
24740// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
24741// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
24742//
24743// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
24744// a simplification in some sense, but it isn't appropriate in general: some
24745// BUILD_VECTORs are substantially cheaper than others. The general case
24746// of a BUILD_VECTOR requires inserting each element individually (or
24747// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
24748// all constants is a single constant pool load. A BUILD_VECTOR where each
24749// element is identical is a splat. A BUILD_VECTOR where most of the operands
24750// are undef lowers to a small number of element insertions.
24751//
24752// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
24753// We don't fold shuffles where one side is a non-zero constant, and we don't
24754// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
24755// non-constant operands. This seems to work out reasonably well in practice.
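// For example (illustrative): shuffling two non-constant BUILD_VECTORs
// (a,b,c,d) and (e,f,g,h) with mask <0,4,1,5> yields BUILD_VECTOR(a,e,b,f),
// since no non-constant operand is duplicated.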
24756static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
24757 SelectionDAG &DAG,
24758 const TargetLowering &TLI) {
24759 EVT VT = SVN->getValueType(ResNo: 0);
24760 unsigned NumElts = VT.getVectorNumElements();
24761 SDValue N0 = SVN->getOperand(Num: 0);
24762 SDValue N1 = SVN->getOperand(Num: 1);
24763
24764 if (!N0->hasOneUse())
24765 return SDValue();
24766
24767 // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
24768 // discussed above.
24769 if (!N1.isUndef()) {
24770 if (!N1->hasOneUse())
24771 return SDValue();
24772
24773 bool N0AnyConst = isAnyConstantBuildVector(V: N0);
24774 bool N1AnyConst = isAnyConstantBuildVector(V: N1);
24775 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N: N0.getNode()))
24776 return SDValue();
24777 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N: N1.getNode()))
24778 return SDValue();
24779 }
24780
24781 // If both inputs are splats of the same value then we can safely merge this
24782 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
24783 bool IsSplat = false;
24784 auto *BV0 = dyn_cast<BuildVectorSDNode>(Val&: N0);
24785 auto *BV1 = dyn_cast<BuildVectorSDNode>(Val&: N1);
24786 if (BV0 && BV1)
24787 if (SDValue Splat0 = BV0->getSplatValue())
24788 IsSplat = (Splat0 == BV1->getSplatValue());
24789
24790 SmallVector<SDValue, 8> Ops;
24791 SmallSet<SDValue, 16> DuplicateOps;
24792 for (int M : SVN->getMask()) {
24793 SDValue Op = DAG.getUNDEF(VT: VT.getScalarType());
24794 if (M >= 0) {
24795 int Idx = M < (int)NumElts ? M : M - NumElts;
24796 SDValue &S = (M < (int)NumElts ? N0 : N1);
24797 if (S.getOpcode() == ISD::BUILD_VECTOR) {
24798 Op = S.getOperand(i: Idx);
24799 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
24800 SDValue Op0 = S.getOperand(i: 0);
24801 Op = Idx == 0 ? Op0 : DAG.getUNDEF(VT: Op0.getValueType());
24802 } else {
24803 // Operand can't be combined - bail out.
24804 return SDValue();
24805 }
24806 }
24807
24808 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
24809 // generating a splat; semantically, this is fine, but it's likely to
24810 // generate low-quality code if the target can't reconstruct an appropriate
24811 // shuffle.
24812 if (!Op.isUndef() && !isIntOrFPConstant(V: Op))
24813 if (!IsSplat && !DuplicateOps.insert(V: Op).second)
24814 return SDValue();
24815
24816 Ops.push_back(Elt: Op);
24817 }
24818
24819 // BUILD_VECTOR requires all inputs to be of the same type; find the
24820 // maximum type and extend them all.
24821 EVT SVT = VT.getScalarType();
24822 if (SVT.isInteger())
24823 for (SDValue &Op : Ops)
24824 SVT = (SVT.bitsLT(VT: Op.getValueType()) ? Op.getValueType() : SVT);
24825 if (SVT != VT.getScalarType())
24826 for (SDValue &Op : Ops)
24827 Op = Op.isUndef() ? DAG.getUNDEF(VT: SVT)
24828 : (TLI.isZExtFree(FromTy: Op.getValueType(), ToTy: SVT)
24829 ? DAG.getZExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT)
24830 : DAG.getSExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT));
24831 return DAG.getBuildVector(VT, DL: SDLoc(SVN), Ops);
24832}
24833
24834// Match shuffles that can be converted to *_vector_extend_in_reg.
24835// This is often generated during legalization.
24836// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
24837// and returns the EVT to which the extension should be performed.
24838// NOTE: this assumes that the src is the first operand of the shuffle.
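// E.g. (illustrative) for v8i16, a match at Scale == 2 yields v4i32 and a
// match at Scale == 4 yields v2i64; only power-of-2 scales strictly smaller
// than NumElts are tried (see the FIXME below).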
24839static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
24840 unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
24841 SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
24842 bool LegalOperations) {
24843 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24844
24845 // TODO Add support for big-endian when we have a test case.
24846 if (!VT.isInteger() || IsBigEndian)
24847 return std::nullopt;
24848
24849 unsigned NumElts = VT.getVectorNumElements();
24850 unsigned EltSizeInBits = VT.getScalarSizeInBits();
24851
24852 // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
24853 // power-of-2 extensions as they are the most likely.
24854 // FIXME: should try the Scale == NumElts case too.
24855 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
24856 // The vector width must be a multiple of Scale.
24857 if (NumElts % Scale != 0)
24858 continue;
24859
24860 EVT OutSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits * Scale);
24861 EVT OutVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: OutSVT, NumElements: NumElts / Scale);
24862
24863 if ((LegalTypes && !TLI.isTypeLegal(VT: OutVT)) ||
24864 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: Opcode, VT: OutVT)))
24865 continue;
24866
24867 if (Match(Scale))
24868 return OutVT;
24869 }
24870
24871 return std::nullopt;
24872}
24873
24874// Match shuffles that can be converted to any_vector_extend_in_reg.
24875// This is often generated during legalization.
24876// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
24877static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
24878 SelectionDAG &DAG,
24879 const TargetLowering &TLI,
24880 bool LegalOperations) {
24881 EVT VT = SVN->getValueType(ResNo: 0);
24882 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24883
24884 // TODO Add support for big-endian when we have a test case.
24885 if (!VT.isInteger() || IsBigEndian)
24886 return SDValue();
24887
24888 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
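// Likewise, e.g. v8i16 shuffle<0,-1,-1,-1,1,-1,-1,-1> matches with Scale == 4
// as (v2i64 anyextend_vector_inreg(v8i16)).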
24889 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
24890 Mask = SVN->getMask()](unsigned Scale) {
24891 for (unsigned i = 0; i != NumElts; ++i) {
24892 if (Mask[i] < 0)
24893 continue;
24894 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
24895 continue;
24896 return false;
24897 }
24898 return true;
24899 };
24900
24901 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
24902 SDValue N0 = SVN->getOperand(Num: 0);
24903 // Never create an illegal type. Only create unsupported operations if we
24904 // are pre-legalization.
24905 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
24906 Opcode, VT, Match: isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
24907 if (!OutVT)
24908 return SDValue();
24909 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT, Operand: N0));
24910}
24911
24912// Match shuffles that can be converted to zero_extend_vector_inreg.
24913// This is often generated during legalization.
24914// e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
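// Here 'z' denotes a mask element known to pick a zero and 'u' an undef
// element; the zeroable elements are discovered below and locally marked with
// the -2 sentinel.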
24915static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
24916 SelectionDAG &DAG,
24917 const TargetLowering &TLI,
24918 bool LegalOperations) {
24919 bool LegalTypes = true;
24920 EVT VT = SVN->getValueType(ResNo: 0);
24921 assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
24922 unsigned NumElts = VT.getVectorNumElements();
24923 unsigned EltSizeInBits = VT.getScalarSizeInBits();
24924
24925 // TODO: add support for big-endian when we have a test case.
24926 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24927 if (!VT.isInteger() || IsBigEndian)
24928 return SDValue();
24929
24930 SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
24931 auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
24932 for (int &Indice : Mask) {
24933 if (Indice < 0)
24934 continue;
24935 int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
24936 int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
24937 Fn(Indice, OpIdx, OpEltIdx);
24938 }
24939 };
24940
24941 // Which elements of which operand does this shuffle demand?
24942 std::array<APInt, 2> OpsDemandedElts;
24943 for (APInt &OpDemandedElts : OpsDemandedElts)
24944 OpDemandedElts = APInt::getZero(numBits: NumElts);
24945 ForEachDecomposedIndice(
24946 [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
24947 OpsDemandedElts[OpIdx].setBit(OpEltIdx);
24948 });
24949
24950 // Element-wise(!), which of these demanded elements are known to be zero?
24951 std::array<APInt, 2> OpsKnownZeroElts;
24952 for (auto I : zip(t: SVN->ops(), u&: OpsDemandedElts, args&: OpsKnownZeroElts))
24953 std::get<2>(t&: I) =
24954 DAG.computeVectorKnownZeroElements(Op: std::get<0>(t&: I), DemandedElts: std::get<1>(t&: I));
24955
24956 // Manifest zeroable element knowledge in the shuffle mask.
24957 // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG;
24958 // this is a local invention, but it won't leak into the DAG.
24959 // FIXME: should we not manifest them, but just check when matching?
24960 bool HadZeroableElts = false;
24961 ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
24962 int &Indice, int OpIdx, int OpEltIdx) {
24963 if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
24964 Indice = -2; // Zeroable element.
24965 HadZeroableElts = true;
24966 }
24967 });
24968
24969 // Don't proceed unless we've refined at least one zeroable mask index.
24970 // If we didn't, then we are still trying to match the same shuffle mask
24971 // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
24972 // and evidently failed. Proceeding will lead to endless combine loops.
24973 if (!HadZeroableElts)
24974 return SDValue();
24975
24976 // The shuffle may be more fine-grained than we want. Widen elements first.
24977 // FIXME: should we do this before manifesting zeroable shuffle mask indices?
24978 SmallVector<int, 16> ScaledMask;
24979 getShuffleMaskWithWidestElts(Mask, ScaledMask);
24980 assert(Mask.size() >= ScaledMask.size() &&
24981 Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
24982 int Prescale = Mask.size() / ScaledMask.size();
24983
24984 NumElts = ScaledMask.size();
24985 EltSizeInBits *= Prescale;
24986
24987 EVT PrescaledVT = EVT::getVectorVT(
24988 Context&: *DAG.getContext(), VT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits),
24989 NumElements: NumElts);
24990
24991 if (LegalTypes && !TLI.isTypeLegal(VT: PrescaledVT) && TLI.isTypeLegal(VT))
24992 return SDValue();
24993
24994 // For example,
24995 // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
24996 // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
24997 auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
24998 assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
24999 "Unexpected mask scaling factor.");
25000 ArrayRef<int> Mask = ScaledMask;
25001 for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
25002 SrcElt != NumSrcElts; ++SrcElt) {
25003 // Analyze the shuffle mask in Scale-sized chunks.
25004 ArrayRef<int> MaskChunk = Mask.take_front(N: Scale);
25005 assert(MaskChunk.size() == Scale && "Unexpected mask size.");
25006 Mask = Mask.drop_front(N: MaskChunk.size());
25007 // The first index in this chunk must be SrcElt, but not zero!
25008 // FIXME: undef should be fine, but that results in more-defined result.
25009 if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
25010 return false;
25011 // The rest of the indices in this chunk must be zeros.
25012 // FIXME: undef should be fine, but that results in more-defined result.
25013 if (!all_of(Range: MaskChunk.drop_front(N: 1),
25014 P: [](int Indice) { return Indice == -2; }))
25015 return false;
25016 }
25017 assert(Mask.empty() && "Did not process the whole mask?");
25018 return true;
25019 };
25020
25021 unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
25022 for (bool Commuted : {false, true}) {
25023 SDValue Op = SVN->getOperand(Num: !Commuted ? 0 : 1);
25024 if (Commuted)
25025 ShuffleVectorSDNode::commuteMask(Mask: ScaledMask);
25026 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
25027 Opcode, VT: PrescaledVT, Match: isZeroExtend, DAG, TLI, LegalTypes,
25028 LegalOperations);
25029 if (OutVT)
25030 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT,
25031 Operand: DAG.getBitcast(VT: PrescaledVT, V: Op)));
25032 }
25033 return SDValue();
25034}
25035
25036// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
25037// each source element of a large type into the lowest elements of a smaller
25038// destination type. This is often generated during legalization.
25039 // If the source node itself was a '*_extend_vector_inreg' node then we
25040 // should be able to remove it.
25041static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
25042 SelectionDAG &DAG) {
25043 EVT VT = SVN->getValueType(ResNo: 0);
25044 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25045
25046 // TODO Add support for big-endian when we have a test case.
25047 if (!VT.isInteger() || IsBigEndian)
25048 return SDValue();
25049
25050 SDValue N0 = peekThroughBitcasts(V: SVN->getOperand(Num: 0));
25051
25052 unsigned Opcode = N0.getOpcode();
25053 if (!ISD::isExtVecInRegOpcode(Opcode))
25054 return SDValue();
25055
25056 SDValue N00 = N0.getOperand(i: 0);
25057 ArrayRef<int> Mask = SVN->getMask();
25058 unsigned NumElts = VT.getVectorNumElements();
25059 unsigned EltSizeInBits = VT.getScalarSizeInBits();
25060 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
25061 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
25062
25063 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
25064 return SDValue();
25065 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
25066
25067 // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
25068 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
25069 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
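// E.g. (illustrative) if N0 is (v4i32 *_extend_vector_inreg(v8i16 N00)), then
// ExtScale == 32/16 == 2, the v8i16 mask <0,2,4,6,-1,-1,-1,-1> satisfies
// isTruncate(2), and the whole shuffle folds to N00.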
25070 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
25071 for (unsigned i = 0; i != NumElts; ++i) {
25072 if (Mask[i] < 0)
25073 continue;
25074 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
25075 continue;
25076 return false;
25077 }
25078 return true;
25079 };
25080
25081 // At the moment we just handle the case where we've truncated back to the
25082 // same size as before the extension.
25083 // TODO: handle more extension/truncation cases as cases arise.
25084 if (EltSizeInBits != ExtSrcSizeInBits)
25085 return SDValue();
25086
25087 // We can remove *extend_vector_inreg only if the truncation happens at
25088 // the same scale as the extension.
25089 if (isTruncate(ExtScale))
25090 return DAG.getBitcast(VT, V: N00);
25091
25092 return SDValue();
25093}
25094
25095// Combine shuffles of splat-shuffles of the form:
25096// shuffle (shuffle V, undef, splat-mask), undef, M
25097// If splat-mask contains undef elements, we need to be careful about
25098 // introducing undefs in the folded mask which are not the result of composing
25099// the masks of the shuffles.
25100static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
25101 SelectionDAG &DAG) {
25102 EVT VT = Shuf->getValueType(ResNo: 0);
25103 unsigned NumElts = VT.getVectorNumElements();
25104
25105 if (!Shuf->getOperand(Num: 1).isUndef())
25106 return SDValue();
25107
25108 // See if this unary non-splat shuffle actually *is* a splat shuffle in
25109 // disguise, with all demanded elements being identical.
25110 // FIXME: this can be done per-operand.
25111 if (!Shuf->isSplat()) {
25112 APInt DemandedElts(NumElts, 0);
25113 for (int Idx : Shuf->getMask()) {
25114 if (Idx < 0)
25115 continue; // Ignore sentinel indices.
25116 assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
25117 DemandedElts.setBit(Idx);
25118 }
25119 assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
25120 APInt UndefElts;
25121 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), DemandedElts, UndefElts)) {
25122 // Even if all demanded elements are splat, some of them could be undef.
25123 // Which lowest demanded element is *not* known-undef?
25124 std::optional<unsigned> MinNonUndefIdx;
25125 for (int Idx : Shuf->getMask()) {
25126 if (Idx < 0 || UndefElts[Idx])
25127 continue; // Ignore sentinel indices, and undef elements.
25128 MinNonUndefIdx = std::min<unsigned>(a: Idx, b: MinNonUndefIdx.value_or(u: ~0U));
25129 }
25130 if (!MinNonUndefIdx)
25131 return DAG.getUNDEF(VT); // All undef - result is undef.
25132 assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
25133 SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
25134 Shuf->getMask().end());
25135 for (int &Idx : SplatMask) {
25136 if (Idx < 0)
25137 continue; // Passthrough sentinel indices.
25138 // Otherwise, just pick the lowest demanded non-undef element.
25139 // Or sentinel undef, if we know we'd pick a known-undef element.
25140 Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
25141 }
25142 assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
25143 return DAG.getVectorShuffle(VT, dl: SDLoc(Shuf), N1: Shuf->getOperand(Num: 0),
25144 N2: Shuf->getOperand(Num: 1), Mask: SplatMask);
25145 }
25146 }
25147
25148 // If the inner operand is a known splat with no undefs, just return that directly.
25149 // TODO: Create DemandedElts mask from Shuf's mask.
25150 // TODO: Allow undef elements and merge with the shuffle code below.
25151 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), /*AllowUndefs*/ false))
25152 return Shuf->getOperand(Num: 0);
25153
25154 auto *Splat = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
25155 if (!Splat || !Splat->isSplat())
25156 return SDValue();
25157
25158 ArrayRef<int> ShufMask = Shuf->getMask();
25159 ArrayRef<int> SplatMask = Splat->getMask();
25160 assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
25161
25162 // Prefer simplifying to the splat-shuffle, if possible. This is legal if
25163 // every undef mask element in the splat-shuffle has a corresponding undef
25164 // element in the user-shuffle's mask or if the composition of mask elements
25165 // would result in undef.
25166 // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
25167 // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
25168 // In this case it is not legal to simplify to the splat-shuffle because we
25169 // may be exposing an undef element at index 1 to the users of the shuffle,
25170 // an undef which was not there before the combine.
25171 // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
25172 // In this case the composition of masks yields SplatMask, so it's ok to
25173 // simplify to the splat-shuffle.
25174 // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
25175 // In this case the composed mask includes all undef elements of SplatMask
25176 // and in addition sets element zero to undef. It is safe to simplify to
25177 // the splat-shuffle.
25178 auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
25179 ArrayRef<int> SplatMask) {
25180 for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
25181 if (UserMask[i] != -1 && SplatMask[i] == -1 &&
25182 SplatMask[UserMask[i]] != -1)
25183 return false;
25184 return true;
25185 };
25186 if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
25187 return Shuf->getOperand(Num: 0);
25188
25189 // Create a new shuffle with a mask that is composed of the two shuffles'
25190 // masks.
25191 SmallVector<int, 32> NewMask;
25192 for (int Idx : ShufMask)
25193 NewMask.push_back(Elt: Idx == -1 ? -1 : SplatMask[Idx]);
25194
25195 return DAG.getVectorShuffle(VT: Splat->getValueType(ResNo: 0), dl: SDLoc(Splat),
25196 N1: Splat->getOperand(Num: 0), N2: Splat->getOperand(Num: 1),
25197 Mask: NewMask);
25198}
25199
25200// Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
25201// the mask can be treated as a larger type.
25202static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
25203 SelectionDAG &DAG,
25204 const TargetLowering &TLI,
25205 bool LegalOperations) {
25206 SDValue Op0 = SVN->getOperand(Num: 0);
25207 SDValue Op1 = SVN->getOperand(Num: 1);
25208 EVT VT = SVN->getValueType(ResNo: 0);
25209 if (Op0.getOpcode() != ISD::BITCAST)
25210 return SDValue();
25211 EVT InVT = Op0.getOperand(i: 0).getValueType();
25212 if (!InVT.isVector() ||
25213 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
25214 Op1.getOperand(i: 0).getValueType() != InVT)))
25215 return SDValue();
25216 if (isAnyConstantBuildVector(V: Op0.getOperand(i: 0)) &&
25217 (Op1.isUndef() || isAnyConstantBuildVector(V: Op1.getOperand(i: 0))))
25218 return SDValue();
25219
25220 int VTLanes = VT.getVectorNumElements();
25221 int InLanes = InVT.getVectorNumElements();
25222 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
25223 (LegalOperations &&
25224 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: InVT)))
25225 return SDValue();
25226 int Factor = VTLanes / InLanes;
25227
25228 // Check that each group of lanes in the mask is either undef or makes a
25229 // valid mask for the wider lane type.
25230 ArrayRef<int> Mask = SVN->getMask();
25231 SmallVector<int> NewMask;
25232 if (!widenShuffleMaskElts(Scale: Factor, Mask, ScaledMask&: NewMask))
25233 return SDValue();
25234
25235 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
25236 return SDValue();
25237
25238 // Create the new shuffle with the new mask and bitcast it back to the
25239 // original type.
25240 SDLoc DL(SVN);
25241 Op0 = Op0.getOperand(i: 0);
25242 Op1 = Op1.isUndef() ? DAG.getUNDEF(VT: InVT) : Op1.getOperand(i: 0);
25243 SDValue NewShuf = DAG.getVectorShuffle(VT: InVT, dl: DL, N1: Op0, N2: Op1, Mask: NewMask);
25244 return DAG.getBitcast(VT, V: NewShuf);
25245}
25246
25247/// Combine shuffle of shuffle of the form:
25248/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
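/// For example (illustrative): InnerMask <3,u,3,1> with OuterMask <0,u,2,2>
/// composes so that every used lane reads X[3], giving the splat mask
/// <3,u,3,3>.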
25249static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
25250 SelectionDAG &DAG) {
25251 if (!OuterShuf->getOperand(Num: 1).isUndef())
25252 return SDValue();
25253 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(Val: OuterShuf->getOperand(Num: 0));
25254 if (!InnerShuf || !InnerShuf->getOperand(Num: 1).isUndef())
25255 return SDValue();
25256
25257 ArrayRef<int> OuterMask = OuterShuf->getMask();
25258 ArrayRef<int> InnerMask = InnerShuf->getMask();
25259 unsigned NumElts = OuterMask.size();
25260 assert(NumElts == InnerMask.size() && "Mask length mismatch");
25261 SmallVector<int, 32> CombinedMask(NumElts, -1);
25262 int SplatIndex = -1;
25263 for (unsigned i = 0; i != NumElts; ++i) {
25264 // Undef lanes remain undef.
25265 int OuterMaskElt = OuterMask[i];
25266 if (OuterMaskElt == -1)
25267 continue;
25268
25269 // Peek through the shuffle masks to get the underlying source element.
25270 int InnerMaskElt = InnerMask[OuterMaskElt];
25271 if (InnerMaskElt == -1)
25272 continue;
25273
25274 // Initialize the splatted element.
25275 if (SplatIndex == -1)
25276 SplatIndex = InnerMaskElt;
25277
25278 // Non-matching index - this is not a splat.
25279 if (SplatIndex != InnerMaskElt)
25280 return SDValue();
25281
25282 CombinedMask[i] = InnerMaskElt;
25283 }
25284 assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
25285 getSplatIndex(CombinedMask) != -1) &&
25286 "Expected a splat mask");
25287
25288 // TODO: The transform may be a win even if the mask is not legal.
25289 EVT VT = OuterShuf->getValueType(ResNo: 0);
25290 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
25291 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
25292 return SDValue();
25293
25294 return DAG.getVectorShuffle(VT, dl: SDLoc(OuterShuf), N1: InnerShuf->getOperand(Num: 0),
25295 N2: InnerShuf->getOperand(Num: 1), Mask: CombinedMask);
25296}
25297
25298/// If the shuffle mask is taking exactly one element from the first vector
25299/// operand and passing through all other elements from the second vector
25300/// operand, return the index of the mask element that is choosing an element
25301/// from the first operand. Otherwise, return -1.
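/// For example (illustrative), with 4 elements: mask <4,5,1,7> takes only
/// element 1 of operand 0 (at mask index 2) and passes lanes 0, 1 and 3
/// through from operand 1, so 2 is returned.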
25302static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
25303 int MaskSize = Mask.size();
25304 int EltFromOp0 = -1;
25305 // TODO: This does not match if there are undef elements in the shuffle mask.
25306 // Should we ignore undefs in the shuffle mask instead? The trade-off is
25307 // removing an instruction (a shuffle), but losing the knowledge that some
25308 // vector lanes are not needed.
25309 for (int i = 0; i != MaskSize; ++i) {
25310 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
25311 // We're looking for a shuffle of exactly one element from operand 0.
25312 if (EltFromOp0 != -1)
25313 return -1;
25314 EltFromOp0 = i;
25315 } else if (Mask[i] != i + MaskSize) {
25316 // Nothing from operand 1 can change lanes.
25317 return -1;
25318 }
25319 }
25320 return EltFromOp0;
25321}
25322
25323/// If a shuffle inserts exactly one element from a source vector operand into
25324/// another vector operand and we can access the specified element as a scalar,
25325/// then we can eliminate the shuffle.
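/// For example (illustrative):
///   shuffle (insert_vector_elt V1, X, 2), V2, <4,5,2,7>
///     --> insert_vector_elt V2, X, 2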
25326static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
25327 SelectionDAG &DAG) {
25328 // First, check if we are taking one element of a vector and shuffling that
25329 // element into another vector.
25330 ArrayRef<int> Mask = Shuf->getMask();
25331 SmallVector<int, 16> CommutedMask(Mask);
25332 SDValue Op0 = Shuf->getOperand(Num: 0);
25333 SDValue Op1 = Shuf->getOperand(Num: 1);
25334 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
25335 if (ShufOp0Index == -1) {
25336 // Commute mask and check again.
25337 ShuffleVectorSDNode::commuteMask(Mask: CommutedMask);
25338 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask: CommutedMask);
25339 if (ShufOp0Index == -1)
25340 return SDValue();
25341 // Commute operands to match the commuted shuffle mask.
25342 std::swap(a&: Op0, b&: Op1);
25343 Mask = CommutedMask;
25344 }
25345
25346 // The shuffle inserts exactly one element from operand 0 into operand 1.
25347 // Now see if we can access that element as a scalar via a real insert element
25348 // instruction.
25349 // TODO: We can try harder to locate the element as a scalar. Examples: it
25350 // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
25351 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
25352 "Shuffle mask value must be from operand 0");
25353 if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
25354 return SDValue();
25355
25356 auto *InsIndexC = dyn_cast<ConstantSDNode>(Val: Op0.getOperand(i: 2));
25357 if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
25358 return SDValue();
25359
25360 // There's an existing insertelement with constant insertion index, so we
25361 // don't need to check the legality/profitability of a replacement operation
25362 // that differs at most in the constant value. The target should be able to
25363 // lower any of those in a similar way. If not, legalization will expand this
25364 // to a scalar-to-vector plus shuffle.
25365 //
25366 // Note that the shuffle may move the scalar from the position that the insert
25367 // element used. Therefore, our new insert element occurs at the shuffle's
25368 // mask index value, not the insert's index value.
25369 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
25370 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
25371 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
25372 N1: Op1, N2: Op0.getOperand(i: 1), N3: NewInsIndex);
25373}
25374
25375/// If we have a unary shuffle of a shuffle, see if it can be folded away
25376/// completely. This has the potential to lose undef knowledge because the first
25377/// shuffle may not have an undef mask element where the second one does. So
25378/// only call this after doing simplifications based on demanded elements.
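/// E.g. (illustrative) shuf (shuf0 X, Y, <0,0,2,2>), undef, <1,1,3,3> selects
/// the same source element per lane as shuf0 already does, so shuf0 itself
/// can be returned.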
25379static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
25380 // shuf (shuf0 X, Y, Mask0), undef, Mask
25381 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
25382 if (!Shuf0 || !Shuf->getOperand(Num: 1).isUndef())
25383 return SDValue();
25384
25385 ArrayRef<int> Mask = Shuf->getMask();
25386 ArrayRef<int> Mask0 = Shuf0->getMask();
25387 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
25388 // Ignore undef elements.
25389 if (Mask[i] == -1)
25390 continue;
25391 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
25392
25393 // Is the element of the shuffle operand chosen by this shuffle the same as
25394 // the element chosen by the shuffle operand itself?
25395 if (Mask0[Mask[i]] != Mask0[i])
25396 return SDValue();
25397 }
25398 // Every element of this shuffle is identical to the result of the previous
25399 // shuffle, so we can replace this value.
25400 return Shuf->getOperand(Num: 0);
25401}
25402
25403SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
25404 EVT VT = N->getValueType(ResNo: 0);
25405 unsigned NumElts = VT.getVectorNumElements();
25406
25407 SDValue N0 = N->getOperand(Num: 0);
25408 SDValue N1 = N->getOperand(Num: 1);
25409
25410 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
25411
25412 // Canonicalize shuffle undef, undef -> undef
25413 if (N0.isUndef() && N1.isUndef())
25414 return DAG.getUNDEF(VT);
25415
25416 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
25417
25418 // Canonicalize shuffle v, v -> v, undef
25419 if (N0 == N1)
25420 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: DAG.getUNDEF(VT),
25421 Mask: createUnaryMask(Mask: SVN->getMask(), NumElts));
25422
25423 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
25424 if (N0.isUndef())
25425 return DAG.getCommutedVectorShuffle(SV: *SVN);
25426
25427 // Remove references to the RHS if it is undef.
25428 if (N1.isUndef()) {
25429 bool Changed = false;
25430 SmallVector<int, 8> NewMask;
25431 for (unsigned i = 0; i != NumElts; ++i) {
25432 int Idx = SVN->getMaskElt(Idx: i);
25433 if (Idx >= (int)NumElts) {
25434 Idx = -1;
25435 Changed = true;
25436 }
25437 NewMask.push_back(Elt: Idx);
25438 }
25439 if (Changed)
25440 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: N1, Mask: NewMask);
25441 }
25442
25443 if (SDValue InsElt = replaceShuffleOfInsert(Shuf: SVN, DAG))
25444 return InsElt;
25445
25446 // A shuffle of a single vector that is a splatted value can always be folded.
25447 if (SDValue V = combineShuffleOfSplatVal(Shuf: SVN, DAG))
25448 return V;
25449
25450 if (SDValue V = formSplatFromShuffles(OuterShuf: SVN, DAG))
25451 return V;
25452
25453 // If it is a splat, check if the argument vector is another splat or a
25454 // build_vector.
25455 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
25456 int SplatIndex = SVN->getSplatIndex();
25457 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, Index: SplatIndex) &&
25458 TLI.isBinOp(Opcode: N0.getOpcode()) && N0->getNumValues() == 1) {
25459 // splat (vector_bo L, R), Index -->
25460 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
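// E.g. (illustrative) splat (add L, R), 2 becomes the scalar add of L[2] and
// R[2], rebuilt as a vector via SCALAR_TO_VECTOR and an all-zero shuffle mask.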
25461 SDValue L = N0.getOperand(i: 0), R = N0.getOperand(i: 1);
25462 SDLoc DL(N);
25463 EVT EltVT = VT.getScalarType();
25464 SDValue Index = DAG.getVectorIdxConstant(Val: SplatIndex, DL);
25465 SDValue ExtL = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: L, N2: Index);
25466 SDValue ExtR = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: R, N2: Index);
25467 SDValue NewBO =
25468 DAG.getNode(Opcode: N0.getOpcode(), DL, VT: EltVT, N1: ExtL, N2: ExtR, Flags: N0->getFlags());
25469 SDValue Insert = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL, VT, Operand: NewBO);
25470 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
25471 return DAG.getVectorShuffle(VT, dl: DL, N1: Insert, N2: DAG.getUNDEF(VT), Mask: ZeroMask);
25472 }
25473
25474 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
25475 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
25476 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) &&
25477 N0.hasOneUse()) {
25478 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
25479 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 0));
25480
25481 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
25482 if (auto *Idx = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 2)))
25483 if (Idx->getAPIntValue() == SplatIndex)
25484 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 1));
25485
25486 // Look through a bitcast if LE and splatting lane 0, through to a
25487 // scalar_to_vector or a build_vector.
25488 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(i: 0).hasOneUse() &&
25489 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
25490 (N0.getOperand(i: 0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
25491 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR)) {
25492 EVT N00VT = N0.getOperand(i: 0).getValueType();
25493 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
25494 VT.isInteger() && N00VT.isInteger()) {
25495 EVT InVT =
25496 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: VT.getScalarType());
25497 SDValue Op = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0),
25498 DL: SDLoc(N), VT: InVT);
25499 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op);
25500 }
25501 }
25502 }
25503
25504 // If this is a bit convert that changes the element type of the vector but
25505 // not the number of vector elements, look through it. Be careful not to
25506 // look through conversions that change things like v4f32 to v2f64.
25507 SDNode *V = N0.getNode();
25508 if (V->getOpcode() == ISD::BITCAST) {
25509 SDValue ConvInput = V->getOperand(Num: 0);
25510 if (ConvInput.getValueType().isVector() &&
25511 ConvInput.getValueType().getVectorNumElements() == NumElts)
25512 V = ConvInput.getNode();
25513 }
25514
25515 if (V->getOpcode() == ISD::BUILD_VECTOR) {
25516 assert(V->getNumOperands() == NumElts &&
25517 "BUILD_VECTOR has wrong number of operands");
25518 SDValue Base;
25519 bool AllSame = true;
25520 for (unsigned i = 0; i != NumElts; ++i) {
25521 if (!V->getOperand(Num: i).isUndef()) {
25522 Base = V->getOperand(Num: i);
25523 break;
25524 }
25525 }
25526 // Splat of <u, u, u, u>, return <u, u, u, u>
25527 if (!Base.getNode())
25528 return N0;
25529 for (unsigned i = 0; i != NumElts; ++i) {
25530 if (V->getOperand(Num: i) != Base) {
25531 AllSame = false;
25532 break;
25533 }
25534 }
25535 // Splat of <x, x, x, x>, return <x, x, x, x>
25536 if (AllSame)
25537 return N0;
25538
25539 // Canonicalize any other splat as a build_vector.
25540 SDValue Splatted = V->getOperand(Num: SplatIndex);
25541 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
25542 SDValue NewBV = DAG.getBuildVector(VT: V->getValueType(ResNo: 0), DL: SDLoc(N), Ops);
25543
25544 // We may have jumped through bitcasts, so the type of the
25545 // BUILD_VECTOR may not match the type of the shuffle.
25546 if (V->getValueType(ResNo: 0) != VT)
25547 NewBV = DAG.getBitcast(VT, V: NewBV);
25548 return NewBV;
25549 }
25550 }
25551
25552 // Simplify source operands based on shuffle mask.
25553 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
25554 return SDValue(N, 0);
25555
25556 // This is intentionally placed after demanded elements simplification because
25557 // it could eliminate knowledge of undef elements created by this shuffle.
25558 if (SDValue ShufOp = simplifyShuffleOfShuffle(Shuf: SVN))
25559 return ShufOp;
25560
25561 // Match shuffles that can be converted to any_vector_extend_in_reg.
25562 if (SDValue V =
25563 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
25564 return V;
25565
25566 // Combine "truncate_vector_in_reg" style shuffles.
25567 if (SDValue V = combineTruncationShuffle(SVN, DAG))
25568 return V;
25569
25570 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
25571 Level < AfterLegalizeVectorOps &&
25572 (N1.isUndef() ||
25573 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
25574 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType()))) {
25575 if (SDValue V = partitionShuffleOfConcats(N, DAG))
25576 return V;
25577 }
25578
25579 // A shuffle of a concat of the same narrow vector can be reduced to use
25580 // only low-half elements of a concat with undef:
25581 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
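// E.g. (illustrative) for v8i32:
//   shuf (concat X, X), undef, <6,7,0,1,4,5,2,3>
//     --> shuf (concat X, undef), undef, <2,3,0,1,0,1,2,3>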
25582 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
25583 N0.getNumOperands() == 2 &&
25584 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
25585 int HalfNumElts = (int)NumElts / 2;
25586 SmallVector<int, 8> NewMask;
25587 for (unsigned i = 0; i != NumElts; ++i) {
25588 int Idx = SVN->getMaskElt(Idx: i);
25589 if (Idx >= HalfNumElts) {
25590 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
25591 Idx -= HalfNumElts;
25592 }
25593 NewMask.push_back(Elt: Idx);
25594 }
25595 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
25596 SDValue UndefVec = DAG.getUNDEF(VT: N0.getOperand(i: 0).getValueType());
25597 SDValue NewCat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT,
25598 N1: N0.getOperand(i: 0), N2: UndefVec);
25599 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: NewCat, N2: N1, Mask: NewMask);
25600 }
25601 }
25602
25603 // See if we can replace a shuffle with an insert_subvector.
25604 // e.g. v2i32 into v8i32:
25605 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
25606 // --> insert_subvector(lhs,rhs1,4).
25607 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
25608 TLI.isOperationLegalOrCustom(Op: ISD::INSERT_SUBVECTOR, VT)) {
25609 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
25610 // Ensure RHS subvectors are legal.
25611 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
25612 EVT SubVT = RHS.getOperand(i: 0).getValueType();
25613 int NumSubVecs = RHS.getNumOperands();
25614 int NumSubElts = SubVT.getVectorNumElements();
25615 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
25616 if (!TLI.isTypeLegal(VT: SubVT))
25617 return SDValue();
25618
25619 // Don't bother if we have a unary shuffle (matches undef + LHS elts).
25620 if (all_of(Range&: Mask, P: [NumElts](int M) { return M < (int)NumElts; }))
25621 return SDValue();
25622
25623 // Search [NumSubElts] spans for RHS sequence.
25624 // TODO: Can we avoid nested loops to increase performance?
25625 SmallVector<int> InsertionMask(NumElts);
25626 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
25627 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
25628 // Reset mask to identity.
25629 std::iota(first: InsertionMask.begin(), last: InsertionMask.end(), value: 0);
25630
25631 // Add subvector insertion.
25632 std::iota(first: InsertionMask.begin() + SubIdx,
25633 last: InsertionMask.begin() + SubIdx + NumSubElts,
25634 value: NumElts + (SubVec * NumSubElts));
25635
25636 // See if the shuffle mask matches the reference insertion mask.
25637 bool MatchingShuffle = true;
25638 for (int i = 0; i != (int)NumElts; ++i) {
25639 int ExpectIdx = InsertionMask[i];
25640 int ActualIdx = Mask[i];
25641 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
25642 MatchingShuffle = false;
25643 break;
25644 }
25645 }
25646
25647 if (MatchingShuffle)
25648 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: LHS,
25649 N2: RHS.getOperand(i: SubVec),
25650 N3: DAG.getVectorIdxConstant(Val: SubIdx, DL: SDLoc(N)));
25651 }
25652 }
25653 return SDValue();
25654 };
25655 ArrayRef<int> Mask = SVN->getMask();
25656 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
25657 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
25658 return InsertN1;
25659 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
25660 SmallVector<int> CommuteMask(Mask);
25661 ShuffleVectorSDNode::commuteMask(Mask: CommuteMask);
25662 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
25663 return InsertN0;
25664 }
25665 }
25666
25667 // If we're not performing a select/blend shuffle, see if we can convert the
25668 // shuffle into an AND node, where all the out-of-lane elements are known zero.
25669 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
25670 bool IsInLaneMask = true;
25671 ArrayRef<int> Mask = SVN->getMask();
25672 SmallVector<int, 16> ClearMask(NumElts, -1);
25673 APInt DemandedLHS = APInt::getZero(numBits: NumElts);
25674 APInt DemandedRHS = APInt::getZero(numBits: NumElts);
25675 for (int I = 0; I != (int)NumElts; ++I) {
25676 int M = Mask[I];
25677 if (M < 0)
25678 continue;
25679 ClearMask[I] = M == I ? I : (I + NumElts);
25680 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
25681 if (M != I) {
25682 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
25683 Demanded.setBit(M % NumElts);
25684 }
25685 }
25686 // TODO: Should we try to mask with N1 as well?
25687 if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
25688 (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(Op: N0, DemandedElts: DemandedLHS)) &&
25689 (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(Op: N1, DemandedElts: DemandedRHS))) {
25690 SDLoc DL(N);
25691 EVT IntVT = VT.changeVectorElementTypeToInteger();
25692 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
25693 // Transform the type to a legal type so that the buildvector constant
25694 // elements are not illegal. Make sure that the result is larger than the
25695 // original type, in case the value is split into two (e.g. i64->i32).
25696 if (!TLI.isTypeLegal(VT: IntSVT) && LegalTypes)
25697 IntSVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: IntSVT);
25698 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
25699 SDValue ZeroElt = DAG.getConstant(Val: 0, DL, VT: IntSVT);
25700 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, VT: IntSVT);
25701 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(VT: IntSVT));
25702 for (int I = 0; I != (int)NumElts; ++I)
25703 if (0 <= Mask[I])
25704 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
25705
25706 // See if a clear mask is legal instead of going via
25707 // XformToShuffleWithZero which loses UNDEF mask elements.
25708 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
25709 return DAG.getBitcast(
25710 VT, V: DAG.getVectorShuffle(VT: IntVT, dl: DL, N1: DAG.getBitcast(VT: IntVT, V: N0),
25711 N2: DAG.getConstant(Val: 0, DL, VT: IntVT), Mask: ClearMask));
25712
25713 if (TLI.isOperationLegalOrCustom(Op: ISD::AND, VT: IntVT))
25714 return DAG.getBitcast(
25715 VT, V: DAG.getNode(Opcode: ISD::AND, DL, VT: IntVT, N1: DAG.getBitcast(VT: IntVT, V: N0),
25716 N2: DAG.getBuildVector(VT: IntVT, DL, Ops: AndMask)));
25717 }
25718 }
25719 }
25720
25721 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
25722 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
25723 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
25724 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
25725 return Res;
25726
25727 // If this shuffle only has a single input that is a bitcasted shuffle,
25728 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
25729 // back to their original types.
25730 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
25731 N1.isUndef() && Level < AfterLegalizeVectorOps &&
25732 TLI.isTypeLegal(VT)) {
25733
25734 SDValue BC0 = peekThroughOneUseBitcasts(V: N0);
25735 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
25736 EVT SVT = VT.getScalarType();
25737 EVT InnerVT = BC0->getValueType(ResNo: 0);
25738 EVT InnerSVT = InnerVT.getScalarType();
25739
25740 // Determine which shuffle works with the smaller scalar type.
25741 EVT ScaleVT = SVT.bitsLT(VT: InnerSVT) ? VT : InnerVT;
25742 EVT ScaleSVT = ScaleVT.getScalarType();
25743
25744 if (TLI.isTypeLegal(VT: ScaleVT) &&
25745 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
25746 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
25747 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
25748 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
25749
25750 // Scale the shuffle masks to the smaller scalar type.
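// e.g. narrowing a v2i64 mask <1,0> by a scale of 2 gives the v4i32 mask
// <2,3,0,1>.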
25751 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(Val&: BC0);
25752 SmallVector<int, 8> InnerMask;
25753 SmallVector<int, 8> OuterMask;
25754 narrowShuffleMaskElts(Scale: InnerScale, Mask: InnerSVN->getMask(), ScaledMask&: InnerMask);
25755 narrowShuffleMaskElts(Scale: OuterScale, Mask: SVN->getMask(), ScaledMask&: OuterMask);
25756
25757 // Merge the shuffle masks.
25758 SmallVector<int, 8> NewMask;
25759 for (int M : OuterMask)
25760 NewMask.push_back(Elt: M < 0 ? -1 : InnerMask[M]);
25761
25762 // Test for shuffle mask legality over both commutations.
25763 SDValue SV0 = BC0->getOperand(Num: 0);
25764 SDValue SV1 = BC0->getOperand(Num: 1);
25765 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
25766 if (!LegalMask) {
25767 std::swap(a&: SV0, b&: SV1);
25768 ShuffleVectorSDNode::commuteMask(Mask: NewMask);
25769 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
25770 }
25771
25772 if (LegalMask) {
25773 SV0 = DAG.getBitcast(VT: ScaleVT, V: SV0);
25774 SV1 = DAG.getBitcast(VT: ScaleVT, V: SV1);
25775 return DAG.getBitcast(
25776 VT, V: DAG.getVectorShuffle(VT: ScaleVT, dl: SDLoc(N), N1: SV0, N2: SV1, Mask: NewMask));
25777 }
25778 }
25779 }
25780 }
25781
25782 // Match shuffles of bitcasts, so long as the mask can be treated as the
25783 // larger type.
25784 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
25785 return V;
25786
25787 // Compute the combined shuffle mask for a shuffle with SV0 as the first
25788 // operand, and SV1 as the second operand.
25789 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
25790 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
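// In the resulting Mask, elements [0, NumElts) select from SV0, elements
// [NumElts, 2*NumElts) select from SV1, and negative entries are undef.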
25791 auto MergeInnerShuffle =
25792 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
25793 ShuffleVectorSDNode *OtherSVN, SDValue N1,
25794 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
25795 SmallVectorImpl<int> &Mask) -> bool {
25796 // Don't try to fold splats; they're likely to simplify somehow, or they
25797 // might be free.
25798 if (OtherSVN->isSplat())
25799 return false;
25800
25801 SV0 = SV1 = SDValue();
25802 Mask.clear();
25803
25804 for (unsigned i = 0; i != NumElts; ++i) {
25805 int Idx = SVN->getMaskElt(Idx: i);
25806 if (Idx < 0) {
25807 // Propagate Undef.
25808 Mask.push_back(Elt: Idx);
25809 continue;
25810 }
25811
25812 if (Commute)
25813 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
25814
25815 SDValue CurrentVec;
25816 if (Idx < (int)NumElts) {
25817 // This shuffle index refers to the inner shuffle N0. Lookup the inner
25818 // shuffle mask to identify which vector is actually referenced.
25819 Idx = OtherSVN->getMaskElt(Idx);
25820 if (Idx < 0) {
25821 // Propagate Undef.
25822 Mask.push_back(Elt: Idx);
25823 continue;
25824 }
25825 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(Num: 0)
25826 : OtherSVN->getOperand(Num: 1);
25827 } else {
25828 // This shuffle index references an element within N1.
25829 CurrentVec = N1;
25830 }
25831
25832 // Simple case where 'CurrentVec' is UNDEF.
25833 if (CurrentVec.isUndef()) {
25834 Mask.push_back(Elt: -1);
25835 continue;
25836 }
25837
25838 // Canonicalize the shuffle index. We don't know yet if CurrentVec
25839 // will be the first or second operand of the combined shuffle.
25840 Idx = Idx % NumElts;
25841 if (!SV0.getNode() || SV0 == CurrentVec) {
25842 // Ok. CurrentVec is the left hand side.
25843 // Update the mask accordingly.
25844 SV0 = CurrentVec;
25845 Mask.push_back(Elt: Idx);
25846 continue;
25847 }
25848 if (!SV1.getNode() || SV1 == CurrentVec) {
25849 // Ok. CurrentVec is the right hand side.
25850 // Update the mask accordingly.
25851 SV1 = CurrentVec;
25852 Mask.push_back(Elt: Idx + NumElts);
25853 continue;
25854 }
25855
25856 // Last chance - see if the vector is another shuffle and if it
25857 // uses one of the existing candidate shuffle ops.
25858 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(Val&: CurrentVec)) {
25859 int InnerIdx = CurrentSVN->getMaskElt(Idx);
25860 if (InnerIdx < 0) {
25861 Mask.push_back(Elt: -1);
25862 continue;
25863 }
25864 SDValue InnerVec = (InnerIdx < (int)NumElts)
25865 ? CurrentSVN->getOperand(Num: 0)
25866 : CurrentSVN->getOperand(Num: 1);
25867 if (InnerVec.isUndef()) {
25868 Mask.push_back(Elt: -1);
25869 continue;
25870 }
25871 InnerIdx %= NumElts;
25872 if (InnerVec == SV0) {
25873 Mask.push_back(Elt: InnerIdx);
25874 continue;
25875 }
25876 if (InnerVec == SV1) {
25877 Mask.push_back(Elt: InnerIdx + NumElts);
25878 continue;
25879 }
25880 }
25881
25882 // Bail out if we cannot convert the shuffle pair into a single shuffle.
25883 return false;
25884 }
25885
25886 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
25887 return true;
25888
25889 // Avoid introducing shuffles with illegal mask.
25890 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
25891 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
25892 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
25893 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
25894 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
25895 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
25896 if (TLI.isShuffleMaskLegal(Mask, VT))
25897 return true;
25898
25899 std::swap(a&: SV0, b&: SV1);
25900 ShuffleVectorSDNode::commuteMask(Mask);
25901 return TLI.isShuffleMaskLegal(Mask, VT);
25902 };
25903
25904 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
25905 // Canonicalize shuffles according to rules:
25906 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
25907 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
25908 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
25909 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
25910 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
25911 // The incoming shuffle must be of the same type as the result of the
25912 // current shuffle.
25913 assert(N1->getOperand(0).getValueType() == VT &&
25914 "Shuffle types don't match");
25915
25916 SDValue SV0 = N1->getOperand(Num: 0);
25917 SDValue SV1 = N1->getOperand(Num: 1);
25918 bool HasSameOp0 = N0 == SV0;
25919 bool IsSV1Undef = SV1.isUndef();
25920 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
25921 // Commute the operands of this shuffle so merging below will trigger.
25922 return DAG.getCommutedVectorShuffle(SV: *SVN);
25923 }
25924
25925 // Canonicalize splat shuffles to the RHS to improve merging below.
25926 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
25927 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
25928 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
25929 cast<ShuffleVectorSDNode>(Val&: N0)->isSplat() &&
25930 !cast<ShuffleVectorSDNode>(Val&: N1)->isSplat()) {
25931 return DAG.getCommutedVectorShuffle(SV: *SVN);
25932 }
25933
25934 // Try to fold according to rules:
25935 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
25936 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
25937 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
25938 // Don't try to fold shuffles with illegal type.
25939 // Only fold if this shuffle is the only user of the other shuffle.
25940 // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
25941 for (int i = 0; i != 2; ++i) {
25942 if (N->getOperand(Num: i).getOpcode() == ISD::VECTOR_SHUFFLE &&
25943 N->isOnlyUserOf(N: N->getOperand(Num: i).getNode())) {
25944 // The incoming shuffle must be of the same type as the result of the
25945 // current shuffle.
25946 auto *OtherSV = cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: i));
25947 assert(OtherSV->getOperand(0).getValueType() == VT &&
25948 "Shuffle types don't match");
25949
25950 SDValue SV0, SV1;
25951 SmallVector<int, 4> Mask;
25952 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(Num: 1 - i), TLI,
25953 SV0, SV1, Mask)) {
25954 // If all indices in Mask are Undef, propagate Undef.
25955 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
25956 return DAG.getUNDEF(VT);
25957
25958 return DAG.getVectorShuffle(VT, dl: SDLoc(N),
25959 N1: SV0 ? SV0 : DAG.getUNDEF(VT),
25960 N2: SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
25961 }
25962 }
25963 }
25964
25965 // Merge shuffles through binops if we are able to merge them with at least
25966 // one other shuffle.
25967 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
25968 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
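// For example, if only the left operand merges:
// shuffle(add(shuffle(x,y), z), undef, M)
// --> add(shuffle(x,y,M'), shuffle(z,undef,M))
// where M' folds M through the inner shuffle mask.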
25969 unsigned SrcOpcode = N0.getOpcode();
25970 if (TLI.isBinOp(Opcode: SrcOpcode) && N->isOnlyUserOf(N: N0.getNode()) &&
25971 (N1.isUndef() ||
25972 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N: N1.getNode())))) {
25973 // Get binop source ops, or just pass on the undef.
25974 SDValue Op00 = N0.getOperand(i: 0);
25975 SDValue Op01 = N0.getOperand(i: 1);
25976 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(i: 0);
25977 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(i: 1);
25978 // TODO: We might be able to relax the VT check but we don't currently
25979 // have any isBinOp() that has different result/ops VTs so play safe until
25980 // we have test coverage.
25981 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
25982 Op01.getValueType() == VT && Op11.getValueType() == VT &&
25983 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
25984 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
25985 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
25986 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
25987 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
25988 SmallVectorImpl<int> &Mask, bool LeftOp,
25989 bool Commute) {
25990 SDValue InnerN = Commute ? N1 : N0;
25991 SDValue Op0 = LeftOp ? Op00 : Op01;
25992 SDValue Op1 = LeftOp ? Op10 : Op11;
25993 if (Commute)
25994 std::swap(a&: Op0, b&: Op1);
25995 // Only accept the merged shuffle if we don't introduce undef elements,
25996 // or the inner shuffle already contained undef elements.
25997 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Val&: Op0);
25998 return SVN0 && InnerN->isOnlyUserOf(N: SVN0) &&
25999 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
26000 Mask) &&
26001 (llvm::any_of(Range: SVN0->getMask(), P: [](int M) { return M < 0; }) ||
26002 llvm::none_of(Range&: Mask, P: [](int M) { return M < 0; }));
26003 };
26004
26005 // Ensure we don't increase the number of shuffles - we must merge a
26006 // shuffle from at least one of the LHS and RHS ops.
26007 bool MergedLeft = false;
26008 SDValue LeftSV0, LeftSV1;
26009 SmallVector<int, 4> LeftMask;
26010 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
26011 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
26012 MergedLeft = true;
26013 } else {
26014 LeftMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
26015 LeftSV0 = Op00, LeftSV1 = Op10;
26016 }
26017
26018 bool MergedRight = false;
26019 SDValue RightSV0, RightSV1;
26020 SmallVector<int, 4> RightMask;
26021 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
26022 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
26023 MergedRight = true;
26024 } else {
26025 RightMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
26026 RightSV0 = Op01, RightSV1 = Op11;
26027 }
26028
26029 if (MergedLeft || MergedRight) {
26030 SDLoc DL(N);
26031 SDValue LHS = DAG.getVectorShuffle(
26032 VT, dl: DL, N1: LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
26033 N2: LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), Mask: LeftMask);
26034 SDValue RHS = DAG.getVectorShuffle(
26035 VT, dl: DL, N1: RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
26036 N2: RightSV1 ? RightSV1 : DAG.getUNDEF(VT), Mask: RightMask);
26037 return DAG.getNode(Opcode: SrcOpcode, DL, VT, N1: LHS, N2: RHS);
26038 }
26039 }
26040 }
26041 }
26042
26043 if (SDValue V = foldShuffleOfConcatUndefs(Shuf: SVN, DAG))
26044 return V;
26045
26046 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
26047 // Perform this really late, because it could eliminate knowledge
26048 // of undef elements created by this shuffle.
26049 if (Level < AfterLegalizeTypes)
26050 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
26051 LegalOperations))
26052 return V;
26053
26054 return SDValue();
26055}
26056
26057SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
26058 EVT VT = N->getValueType(ResNo: 0);
26059 if (!VT.isFixedLengthVector())
26060 return SDValue();
26061
26062 // Try to convert a scalar binop with an extracted vector element to a vector
26063 // binop. This is intended to reduce potentially expensive register moves.
26064 // TODO: Check if both operands are extracted.
26065 // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
26066 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
26067 SDValue Scalar = N->getOperand(Num: 0);
26068 unsigned Opcode = Scalar.getOpcode();
26069 EVT VecEltVT = VT.getScalarType();
26070 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
26071 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
26072 Scalar.getOperand(i: 0).getValueType() == VecEltVT &&
26073 Scalar.getOperand(i: 1).getValueType() == VecEltVT &&
26074 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 0).getNode()) &&
26075 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 1).getNode()) &&
26076 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
26077 // Match an extract element and get a shuffle mask equivalent.
26078 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
26079
26080 for (int i : {0, 1}) {
26081 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
26082 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
26083 SDValue EE = Scalar.getOperand(i);
26084 auto *C = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: i ? 0 : 1));
26085 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
26086 EE.getOperand(i: 0).getValueType() == VT &&
26087 isa<ConstantSDNode>(Val: EE.getOperand(i: 1))) {
26088 // Mask = {ExtractIndex, undef, undef....}
26089 ShufMask[0] = EE.getConstantOperandVal(i: 1);
26090 // Make sure the shuffle is legal if we are crossing lanes.
26091 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
26092 SDLoc DL(N);
26093 SDValue V[] = {EE.getOperand(i: 0),
26094 DAG.getConstant(Val: C->getAPIntValue(), DL, VT)};
26095 SDValue VecBO = DAG.getNode(Opcode, DL, VT, N1: V[i], N2: V[1 - i]);
26096 return DAG.getVectorShuffle(VT, dl: DL, N1: VecBO, N2: DAG.getUNDEF(VT),
26097 Mask: ShufMask);
26098 }
26099 }
26100 }
26101 }
26102
26103 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
26104 // with a VECTOR_SHUFFLE and possible truncate.
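// e.g. s2v (extelt v4i32 V, 2) --> shuffle V, undef, <2,-1,-1,-1>
// (when the result type matches the source vector type and the mask is
// legal for the target).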
26105 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
26106 !Scalar.getOperand(i: 0).getValueType().isFixedLengthVector())
26107 return SDValue();
26108
26109 // If we have an implicit truncate, truncate here if it is legal.
26110 if (VecEltVT != Scalar.getValueType() &&
26111 Scalar.getValueType().isScalarInteger() && isTypeLegal(VT: VecEltVT)) {
26112 SDValue Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Scalar), VT: VecEltVT, Operand: Scalar);
26113 return DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT, Operand: Val);
26114 }
26115
26116 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: 1));
26117 if (!ExtIndexC)
26118 return SDValue();
26119
26120 SDValue SrcVec = Scalar.getOperand(i: 0);
26121 EVT SrcVT = SrcVec.getValueType();
26122 unsigned SrcNumElts = SrcVT.getVectorNumElements();
26123 unsigned VTNumElts = VT.getVectorNumElements();
26124 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
26125 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
26126 SmallVector<int, 8> Mask(SrcNumElts, -1);
26127 Mask[0] = ExtIndexC->getZExtValue();
26128 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
26129 VT: SrcVT, DL: SDLoc(N), N0: SrcVec, N1: DAG.getUNDEF(VT: SrcVT), Mask, DAG);
26130 if (!LegalShuffle)
26131 return SDValue();
26132
26133 // If the initial vector is the same size, the shuffle is the result.
26134 if (VT == SrcVT)
26135 return LegalShuffle;
26136
26137 // If not, shorten the shuffled vector.
26138 if (VTNumElts != SrcNumElts) {
26139 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL: SDLoc(N));
26140 EVT SubVT = EVT::getVectorVT(Context&: *DAG.getContext(),
26141 VT: SrcVT.getVectorElementType(), NumElements: VTNumElts);
26142 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: SubVT, N1: LegalShuffle,
26143 N2: ZeroIdx);
26144 }
26145 }
26146
26147 return SDValue();
26148}
26149
26150SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
26151 EVT VT = N->getValueType(ResNo: 0);
26152 SDValue N0 = N->getOperand(Num: 0);
26153 SDValue N1 = N->getOperand(Num: 1);
26154 SDValue N2 = N->getOperand(Num: 2);
26155 uint64_t InsIdx = N->getConstantOperandVal(Num: 2);
26156
26157 // If inserting an UNDEF, just return the original vector.
26158 if (N1.isUndef())
26159 return N0;
26160
26161 // If this is an insert of an extracted vector into an undef vector, we can
26162 // just use the input to the extract if the types match, and can simplify
26163 // in some cases even if they don't.
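// i.e. insert_subvector undef, (extract_subvector X, Idx), Idx --> X
// when the types match.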
26164 if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26165 N1.getOperand(i: 1) == N2) {
26166 EVT SrcVT = N1.getOperand(i: 0).getValueType();
26167 if (SrcVT == VT)
26168 return N1.getOperand(i: 0);
26169 // TODO: To remove the zero check, need to adjust the offset to
26170 // a multiple of the new src type.
26171 if (isNullConstant(V: N2) &&
26172 VT.isScalableVector() == SrcVT.isScalableVector()) {
26173 if (VT.getVectorMinNumElements() >= SrcVT.getVectorMinNumElements())
26174 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
26175 VT, N1: N0, N2: N1.getOperand(i: 0), N3: N2);
26176 else
26177 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N),
26178 VT, N1: N1.getOperand(i: 0), N2);
26179 }
26180 }
26181
26182 // Simplify scalar inserts into an undef vector:
26183 // insert_subvector undef, (splat X), N2 -> splat X
26184 if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
26185 if (DAG.isConstantValueOfAnyType(N: N1.getOperand(i: 0)) || N1.hasOneUse())
26186 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: N1.getOperand(i: 0));
26187
26188 // If we are inserting a bitcast value into an undef, with the same
26189 // number of elements, just use the bitcast input of the extract.
26190 // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
26191 // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
26192 if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
26193 N1.getOperand(i: 0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26194 N1.getOperand(i: 0).getOperand(i: 1) == N2 &&
26195 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getVectorElementCount() ==
26196 VT.getVectorElementCount() &&
26197 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getSizeInBits() ==
26198 VT.getSizeInBits()) {
26199 return DAG.getBitcast(VT, V: N1.getOperand(i: 0).getOperand(i: 0));
26200 }
26201
26202 // If both N0 and N1 are bitcast values on which insert_subvector
26203 // would make sense, pull the bitcast through.
26204 // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
26205 // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
26206 if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
26207 SDValue CN0 = N0.getOperand(i: 0);
26208 SDValue CN1 = N1.getOperand(i: 0);
26209 EVT CN0VT = CN0.getValueType();
26210 EVT CN1VT = CN1.getValueType();
26211 if (CN0VT.isVector() && CN1VT.isVector() &&
26212 CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
26213 CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
26214 SDValue NewINSERT = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
26215 VT: CN0.getValueType(), N1: CN0, N2: CN1, N3: N2);
26216 return DAG.getBitcast(VT, V: NewINSERT);
26217 }
26218 }
26219
26220 // Combine INSERT_SUBVECTORs where we are inserting to the same index.
26221 // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
26222 // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
26223 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26224 N0.getOperand(i: 1).getValueType() == N1.getValueType() &&
26225 N0.getOperand(i: 2) == N2)
26226 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
26227 N2: N1, N3: N2);
26228
26229 // Eliminate an intermediate insert into an undef vector:
26230 // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
26231 // insert_subvector undef, X, 0
26232 if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
26233 N1.getOperand(i: 0).isUndef() && isNullConstant(V: N1.getOperand(i: 2)) &&
26234 isNullConstant(V: N2))
26235 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0,
26236 N2: N1.getOperand(i: 1), N3: N2);
26237
26238 // Push subvector bitcasts to the output, adjusting the index as we go.
26239 // insert_subvector(bitcast(v), bitcast(s), c1)
26240 // -> bitcast(insert_subvector(v, s, c2))
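// e.g. for a v4i32 result:
// insert_subvector (bitcast v2i64 X), (bitcast v1i64 Y), 2
// --> bitcast (insert_subvector X, Y, 1)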
26241 if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
26242 N1.getOpcode() == ISD::BITCAST) {
26243 SDValue N0Src = peekThroughBitcasts(V: N0);
26244 SDValue N1Src = peekThroughBitcasts(V: N1);
26245 EVT N0SrcSVT = N0Src.getValueType().getScalarType();
26246 EVT N1SrcSVT = N1Src.getValueType().getScalarType();
26247 if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
26248 N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
26249 EVT NewVT;
26250 SDLoc DL(N);
26251 SDValue NewIdx;
26252 LLVMContext &Ctx = *DAG.getContext();
26253 ElementCount NumElts = VT.getVectorElementCount();
26254 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26255 if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
26256 unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
26257 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT, EC: NumElts * Scale);
26258 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx * Scale, DL);
26259 } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
26260 unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
26261 if (NumElts.isKnownMultipleOf(RHS: Scale) && (InsIdx % Scale) == 0) {
26262 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT,
26263 EC: NumElts.divideCoefficientBy(RHS: Scale));
26264 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx / Scale, DL);
26265 }
26266 }
26267 if (NewIdx && hasOperation(Opcode: ISD::INSERT_SUBVECTOR, VT: NewVT)) {
26268 SDValue Res = DAG.getBitcast(VT: NewVT, V: N0Src);
26269 Res = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: NewVT, N1: Res, N2: N1Src, N3: NewIdx);
26270 return DAG.getBitcast(VT, V: Res);
26271 }
26272 }
26273 }
26274
26275 // Canonicalize insert_subvector dag nodes.
26276 // Example:
26277 // (insert_subvector (insert_subvector A, B0, Idx0), B1, Idx1)
26278 // -> (insert_subvector (insert_subvector A, B1, Idx1), B0, Idx0), if Idx1 < Idx0
26279 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
26280 N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
26281 unsigned OtherIdx = N0.getConstantOperandVal(i: 2);
26282 if (InsIdx < OtherIdx) {
26283 // Swap nodes.
26284 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT,
26285 N1: N0.getOperand(i: 0), N2: N1, N3: N2);
26286 AddToWorklist(N: NewOp.getNode());
26287 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N0.getNode()),
26288 VT, N1: NewOp, N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
26289 }
26290 }
26291
26292 // If the input vector is a concatenation, and the insert replaces
26293 // one of the pieces, we can optimize into a single concat_vectors.
26294 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
26295 N0.getOperand(i: 0).getValueType() == N1.getValueType() &&
26296 N0.getOperand(i: 0).getValueType().isScalableVector() ==
26297 N1.getValueType().isScalableVector()) {
26298 unsigned Factor = N1.getValueType().getVectorMinNumElements();
26299 SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
26300 Ops[InsIdx / Factor] = N1;
26301 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
26302 }
26303
26304 // Simplify source operands based on insertion.
26305 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
26306 return SDValue(N, 0);
26307
26308 return SDValue();
26309}
26310
26311SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
26312 SDValue N0 = N->getOperand(Num: 0);
26313
26314 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
26315 if (N0->getOpcode() == ISD::FP16_TO_FP)
26316 return N0->getOperand(Num: 0);
26317
26318 return SDValue();
26319}
26320
26321SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
26322 auto Op = N->getOpcode();
26323 assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
26324 "opcode should be FP16_TO_FP or BF16_TO_FP.");
26325 SDValue N0 = N->getOperand(Num: 0);
26326
26327 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
26328 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26329 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
26330 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N: N0.getOperand(i: 1));
26331 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
26332 return DAG.getNode(Opcode: Op, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0.getOperand(i: 0));
26333 }
26334 }
26335
26336 return SDValue();
26337}
26338
26339SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
26340 SDValue N0 = N->getOperand(Num: 0);
26341
26342 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
26343 if (N0->getOpcode() == ISD::BF16_TO_FP)
26344 return N0->getOperand(Num: 0);
26345
26346 return SDValue();
26347}
26348
26349SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
26350 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26351 return visitFP16_TO_FP(N);
26352}
26353
26354SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
26355 SDValue N0 = N->getOperand(Num: 0);
26356 EVT VT = N0.getValueType();
26357 unsigned Opcode = N->getOpcode();
26358
26359 // VECREDUCE over 1-element vector is just an extract.
26360 if (VT.getVectorElementCount().isScalar()) {
26361 SDLoc dl(N);
26362 SDValue Res =
26363 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: VT.getVectorElementType(), N1: N0,
26364 N2: DAG.getVectorIdxConstant(Val: 0, DL: dl));
26365 if (Res.getValueType() != N->getValueType(ResNo: 0))
26366 Res = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: dl, VT: N->getValueType(ResNo: 0), Operand: Res);
26367 return Res;
26368 }
26369
26370 // On a boolean vector an and/or reduction is the same as a umin/umax
26371 // reduction. Convert them if the latter is legal while the former isn't.
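// e.g. vecreduce_and(v4i1 X) == vecreduce_umin(v4i1 X) and
// vecreduce_or(v4i1 X) == vecreduce_umax(v4i1 X).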
26372 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
26373 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
26374 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
26375 if (!TLI.isOperationLegalOrCustom(Op: Opcode, VT) &&
26376 TLI.isOperationLegalOrCustom(Op: NewOpcode, VT) &&
26377 DAG.ComputeNumSignBits(Op: N0) == VT.getScalarSizeInBits())
26378 return DAG.getNode(Opcode: NewOpcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0);
26379 }
26380
26381 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
26382 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
26383 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26384 TLI.isTypeLegal(VT: N0.getOperand(i: 1).getValueType())) {
26385 SDValue Vec = N0.getOperand(i: 0);
26386 SDValue Subvec = N0.getOperand(i: 1);
26387 if ((Opcode == ISD::VECREDUCE_OR &&
26388 (N0.getOperand(i: 0).isUndef() || isNullOrNullSplat(V: Vec))) ||
26389 (Opcode == ISD::VECREDUCE_AND &&
26390 (N0.getOperand(i: 0).isUndef() || isAllOnesOrAllOnesSplat(V: Vec))))
26391 return DAG.getNode(Opcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Subvec);
26392 }
26393
26394 return SDValue();
26395}
26396
26397SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
26398 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
26399
26400 // FSUB -> FMA combines:
26401 if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
26402 AddToWorklist(N: Fused.getNode());
26403 return Fused;
26404 }
26405 return SDValue();
26406}
26407
26408SDValue DAGCombiner::visitVPOp(SDNode *N) {
26409
26410 if (N->getOpcode() == ISD::VP_GATHER)
26411 if (SDValue SD = visitVPGATHER(N))
26412 return SD;
26413
26414 if (N->getOpcode() == ISD::VP_SCATTER)
26415 if (SDValue SD = visitVPSCATTER(N))
26416 return SD;
26417
26418 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
26419 if (SDValue SD = visitVP_STRIDED_LOAD(N))
26420 return SD;
26421
26422 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
26423 if (SDValue SD = visitVP_STRIDED_STORE(N))
26424 return SD;
26425
26426 // VP operations in which all vector elements are disabled - either by
26427 // determining that the mask is all false or that the EVL is 0 - can be
26428 // eliminated.
26429 bool AreAllEltsDisabled = false;
26430 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode: N->getOpcode()))
26431 AreAllEltsDisabled |= isNullConstant(V: N->getOperand(Num: *EVLIdx));
26432 if (auto MaskIdx = ISD::getVPMaskIdx(Opcode: N->getOpcode()))
26433 AreAllEltsDisabled |=
26434 ISD::isConstantSplatVectorAllZeros(N: N->getOperand(Num: *MaskIdx).getNode());
26435
26436 // This is the only generic VP combine we support for now.
26437 if (!AreAllEltsDisabled) {
26438 switch (N->getOpcode()) {
26439 case ISD::VP_FADD:
26440 return visitVP_FADD(N);
26441 case ISD::VP_FSUB:
26442 return visitVP_FSUB(N);
26443 case ISD::VP_FMA:
26444 return visitFMA<VPMatchContext>(N);
26445 case ISD::VP_SELECT:
26446 return visitVP_SELECT(N);
26447 }
26448 return SDValue();
26449 }
26450
26451 // Binary operations can be replaced by UNDEF.
26452 if (ISD::isVPBinaryOp(Opcode: N->getOpcode()))
26453 return DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
26454
26455 // VP Memory operations can be replaced by either the chain (stores) or the
26456 // chain + undef (loads).
26457 if (const auto *MemSD = dyn_cast<MemSDNode>(Val: N)) {
26458 if (MemSD->writeMem())
26459 return MemSD->getChain();
26460 return CombineTo(N, Res0: DAG.getUNDEF(VT: N->getValueType(ResNo: 0)), Res1: MemSD->getChain());
26461 }
26462
26463 // Reduction operations return the start operand when no elements are active.
26464 if (ISD::isVPReduction(Opcode: N->getOpcode()))
26465 return N->getOperand(Num: 0);
26466
26467 return SDValue();
26468}
26469
26470SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
26471 SDValue Chain = N->getOperand(Num: 0);
26472 SDValue Ptr = N->getOperand(Num: 1);
26473 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
26474
26475 // Check if the memory where the FP state is written is used only in a single
26476 // load operation.
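// i.e. we look for the pattern: get_fpenv_mem Tmp; X = load Tmp; store X, Dst
// and rewrite it to write the FP environment directly to Dst, removing the
// temporary copy.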
26477 LoadSDNode *LdNode = nullptr;
26478 for (auto *U : Ptr->uses()) {
26479 if (U == N)
26480 continue;
26481 if (auto *Ld = dyn_cast<LoadSDNode>(Val: U)) {
26482 if (LdNode && LdNode != Ld)
26483 return SDValue();
26484 LdNode = Ld;
26485 continue;
26486 }
26487 return SDValue();
26488 }
26489 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26490 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26491 !LdNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(N, 0)))
26492 return SDValue();
26493
26494 // Check if the loaded value is used only in a store operation.
26495 StoreSDNode *StNode = nullptr;
26496 for (auto I = LdNode->use_begin(), E = LdNode->use_end(); I != E; ++I) {
26497 SDUse &U = I.getUse();
26498 if (U.getResNo() == 0) {
26499 if (auto *St = dyn_cast<StoreSDNode>(Val: U.getUser())) {
26500 if (StNode)
26501 return SDValue();
26502 StNode = St;
26503 } else {
26504 return SDValue();
26505 }
26506 }
26507 }
26508 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26509 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26510 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
26511 return SDValue();
26512
26513 // Create new node GET_FPENV_MEM, which uses the store address to write FP
26514 // environment.
26515 SDValue Res = DAG.getGetFPEnv(Chain, dl: SDLoc(N), Ptr: StNode->getBasePtr(), MemVT,
26516 MMO: StNode->getMemOperand());
26517 CombineTo(N: StNode, Res, AddTo: false);
26518 return Res;
26519}
26520
26521SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
26522 SDValue Chain = N->getOperand(Num: 0);
26523 SDValue Ptr = N->getOperand(Num: 1);
26524 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
26525
26526 // Check if the address of the FP state is otherwise used only in a single store operation.
26527 StoreSDNode *StNode = nullptr;
26528 for (auto *U : Ptr->uses()) {
26529 if (U == N)
26530 continue;
26531 if (auto *St = dyn_cast<StoreSDNode>(Val: U)) {
26532 if (StNode && StNode != St)
26533 return SDValue();
26534 StNode = St;
26535 continue;
26536 }
26537 return SDValue();
26538 }
26539 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26540 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26541 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(StNode, 0)))
26542 return SDValue();
26543
26544 // Check if the stored value is loaded from some location and the loaded
26545 // value is used only in the store operation.
26546 SDValue StValue = StNode->getValue();
26547 auto *LdNode = dyn_cast<LoadSDNode>(Val&: StValue);
26548 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26549 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26550 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
26551 return SDValue();
26552
26553 // Create new node SET_FPENV_MEM, which uses the load address to read FP
26554 // environment.
26555 SDValue Res =
26556 DAG.getSetFPEnv(Chain: LdNode->getChain(), dl: SDLoc(N), Ptr: LdNode->getBasePtr(), MemVT,
26557 MMO: LdNode->getMemOperand());
26558 return Res;
26559}
26560
26561/// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
26562/// with the destination vector and a zero vector.
26563/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
26564/// vector_shuffle V, Zero, <0, 4, 2, 4>
26565SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
26566 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
26567
26568 EVT VT = N->getValueType(ResNo: 0);
26569 SDValue LHS = N->getOperand(Num: 0);
26570 SDValue RHS = peekThroughBitcasts(V: N->getOperand(Num: 1));
26571 SDLoc DL(N);
26572
26573 // Make sure we're not running after operation legalization where it
26574 // may have custom lowered the vector shuffles.
26575 if (LegalOperations)
26576 return SDValue();
26577
26578 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
26579 return SDValue();
26580
26581 EVT RVT = RHS.getValueType();
26582 unsigned NumElts = RHS.getNumOperands();
26583
26584 // Attempt to create a valid clear mask by splitting the mask into
26585 // sub-elements and checking that each is either all zeros or all ones -
26586 // suitable for shuffle masking.
26587 auto BuildClearMask = [&](int Split) {
26588 int NumSubElts = NumElts * Split;
26589 int NumSubBits = RVT.getScalarSizeInBits() / Split;
26590
26591 SmallVector<int, 8> Indices;
26592 for (int i = 0; i != NumSubElts; ++i) {
26593 int EltIdx = i / Split;
26594 int SubIdx = i % Split;
26595 SDValue Elt = RHS.getOperand(i: EltIdx);
26596 // X & undef --> 0 (not undef). So this lane must be converted to choose
26597 // from the zero constant vector (same as if the element had all 0-bits).
26598 if (Elt.isUndef()) {
26599 Indices.push_back(Elt: i + NumSubElts);
26600 continue;
26601 }
26602
26603 APInt Bits;
26604 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Elt))
26605 Bits = Cst->getAPIntValue();
26606 else if (auto *CstFP = dyn_cast<ConstantFPSDNode>(Val&: Elt))
26607 Bits = CstFP->getValueAPF().bitcastToAPInt();
26608 else
26609 return SDValue();
26610
26611 // Extract the sub element from the constant bit mask.
26612 if (DAG.getDataLayout().isBigEndian())
26613 Bits = Bits.extractBits(numBits: NumSubBits, bitPosition: (Split - SubIdx - 1) * NumSubBits);
26614 else
26615 Bits = Bits.extractBits(numBits: NumSubBits, bitPosition: SubIdx * NumSubBits);
26616
26617 if (Bits.isAllOnes())
26618 Indices.push_back(Elt: i);
26619 else if (Bits == 0)
26620 Indices.push_back(Elt: i + NumSubElts);
26621 else
26622 return SDValue();
26623 }
26624
26625 // Let's see if the target supports this vector_shuffle.
26626 EVT ClearSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumSubBits);
26627 EVT ClearVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: ClearSVT, NumElements: NumSubElts);
26628 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
26629 return SDValue();
26630
26631 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: ClearVT);
26632 return DAG.getBitcast(VT, V: DAG.getVectorShuffle(VT: ClearVT, dl: DL,
26633 N1: DAG.getBitcast(VT: ClearVT, V: LHS),
26634 N2: Zero, Mask: Indices));
26635 };
26636
26637 // Determine maximum split level (byte level masking).
26638 int MaxSplit = 1;
26639 if (RVT.getScalarSizeInBits() % 8 == 0)
26640 MaxSplit = RVT.getScalarSizeInBits() / 8;
26641
26642 for (int Split = 1; Split <= MaxSplit; ++Split)
26643 if (RVT.getScalarSizeInBits() % Split == 0)
26644 if (SDValue S = BuildClearMask(Split))
26645 return S;
26646
26647 return SDValue();
26648}
26649
26650/// If a vector binop is performed on splat values, it may be profitable to
26651/// extract, scalarize, and insert/splat.
26652static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
26653 const SDLoc &DL) {
26654 SDValue N0 = N->getOperand(Num: 0);
26655 SDValue N1 = N->getOperand(Num: 1);
26656 unsigned Opcode = N->getOpcode();
26657 EVT VT = N->getValueType(ResNo: 0);
26658 EVT EltVT = VT.getVectorElementType();
26659 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26660
26661 // TODO: Remove/replace the extract cost check? If the elements are available
26662 // as scalars, then there may be no extract cost. Should we ask if
26663 // inserting a scalar back into a vector is cheap instead?
26664 int Index0, Index1;
26665 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
26666 SDValue Src1 = DAG.getSplatSourceVector(V: N1, SplatIndex&: Index1);
26667 // Extracting an element from a splat_vector should be free.
26668 // TODO: use DAG.isSplatValue instead?
26669 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
26670 N1.getOpcode() == ISD::SPLAT_VECTOR;
26671 if (!Src0 || !Src1 || Index0 != Index1 ||
26672 Src0.getValueType().getVectorElementType() != EltVT ||
26673 Src1.getValueType().getVectorElementType() != EltVT ||
26674 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index: Index0)) ||
26675 !TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT))
26676 return SDValue();
26677
26678 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
26679 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src0, N2: IndexC);
26680 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src1, N2: IndexC);
26681 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags());
26682
26683 // If all lanes but 1 are undefined, no need to splat the scalar result.
26684 // TODO: Keep track of undefs and use that info in the general case.
26685 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
26686 count_if(Range: N0->ops(), P: [](SDValue V) { return !V.isUndef(); }) == 1 &&
26687 count_if(Range: N1->ops(), P: [](SDValue V) { return !V.isUndef(); }) == 1) {
26688 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
26689 // build_vec ..undef, (bo X, Y), undef...
26690 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(VT: EltVT));
26691 Ops[Index0] = ScalarBO;
26692 return DAG.getBuildVector(VT, DL, Ops);
26693 }
26694
26695 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
26696 return DAG.getSplat(VT, DL, Op: ScalarBO);
26697}
26698
26699/// Visit a vector cast operation, like FP_EXTEND.
26700SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
26701 EVT VT = N->getValueType(ResNo: 0);
26702 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
26703 EVT EltVT = VT.getVectorElementType();
26704 unsigned Opcode = N->getOpcode();
26705
26706 SDValue N0 = N->getOperand(Num: 0);
26707 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26708
26709 // TODO: promote operation might be also good here?
26710 int Index0;
26711 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
26712 if (Src0 &&
26713 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
26714 TLI.isExtractVecEltCheap(VT, Index: Index0)) &&
26715 TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT) &&
26716 TLI.preferScalarizeSplat(N)) {
26717 EVT SrcVT = N0.getValueType();
26718 EVT SrcEltVT = SrcVT.getVectorElementType();
26719 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
26720 SDValue Elt =
26721 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: SrcEltVT, N1: Src0, N2: IndexC);
26722 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, Operand: Elt, Flags: N->getFlags());
26723 if (VT.isScalableVector())
26724 return DAG.getSplatVector(VT, DL, Op: ScalarBO);
26725 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
26726 return DAG.getBuildVector(VT, DL, Ops);
26727 }
26728
26729 return SDValue();
26730}
26731
26732/// Visit a binary vector operation, like ADD.
26733SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
26734 EVT VT = N->getValueType(ResNo: 0);
26735 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
26736
26737 SDValue LHS = N->getOperand(Num: 0);
26738 SDValue RHS = N->getOperand(Num: 1);
26739 unsigned Opcode = N->getOpcode();
26740 SDNodeFlags Flags = N->getFlags();
26741
26742 // Move unary shuffles with identical masks after a vector binop:
26743 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
26744 // --> shuffle (VBinOp A, B), Undef, Mask
26745 // This does not require type legality checks because we are creating the
26746 // same types of operations that are in the original sequence. We do have to
26747 // restrict ops like integer div that have immediate UB (e.g. div-by-zero)
26748 // though. This code is adapted from the identical transform in instcombine.
26749 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
26750 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val&: LHS);
26751 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(Val&: RHS);
26752 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(RHS: Shuf1->getMask()) &&
26753 LHS.getOperand(i: 1).isUndef() && RHS.getOperand(i: 1).isUndef() &&
26754 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
26755 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS.getOperand(i: 0),
26756 N2: RHS.getOperand(i: 0), Flags);
26757 SDValue UndefV = LHS.getOperand(i: 1);
26758 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: UndefV, Mask: Shuf0->getMask());
26759 }
26760
26761 // Try to sink a splat shuffle after a binop with a uniform constant.
26762 // This is limited to cases where neither the shuffle nor the constant have
26763 // undefined elements because that could be poison-unsafe or inhibit
26764 // demanded elements analysis. It is further limited to not change a splat
26765 // of an inserted scalar because that may be optimized better by
26766 // load-folding or other target-specific behaviors.
26767 if (isConstOrConstSplat(N: RHS) && Shuf0 && all_equal(Range: Shuf0->getMask()) &&
26768 Shuf0->hasOneUse() && Shuf0->getOperand(Num: 1).isUndef() &&
26769 Shuf0->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
26770 // binop (splat X), (splat C) --> splat (binop X, C)
26771 SDValue X = Shuf0->getOperand(Num: 0);
26772 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: X, N2: RHS, Flags);
26773 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
26774 Mask: Shuf0->getMask());
26775 }
26776 if (isConstOrConstSplat(N: LHS) && Shuf1 && all_equal(Range: Shuf1->getMask()) &&
26777 Shuf1->hasOneUse() && Shuf1->getOperand(Num: 1).isUndef() &&
26778 Shuf1->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
26779 // binop (splat C), (splat X) --> splat (binop C, X)
26780 SDValue X = Shuf1->getOperand(Num: 0);
26781 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS, N2: X, Flags);
26782 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
26783 Mask: Shuf1->getMask());
26784 }
26785 }
26786
26787 // The following pattern is likely to emerge with vector reduction ops. Moving
26788 // the binary operation ahead of insertion may allow using a narrower vector
26789 // instruction that has better performance than the wide version of the op:
26790 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
26791 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(i: 0).isUndef() &&
26792 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(i: 0).isUndef() &&
26793 LHS.getOperand(i: 2) == RHS.getOperand(i: 2) &&
26794 (LHS.hasOneUse() || RHS.hasOneUse())) {
26795 SDValue X = LHS.getOperand(i: 1);
26796 SDValue Y = RHS.getOperand(i: 1);
26797 SDValue Z = LHS.getOperand(i: 2);
26798 EVT NarrowVT = X.getValueType();
26799 if (NarrowVT == Y.getValueType() &&
26800 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT,
26801 LegalOnly: LegalOperations)) {
26802 // (binop undef, undef) may not return undef, so compute that result.
26803 SDValue VecC =
26804 DAG.getNode(Opcode, DL, VT, N1: DAG.getUNDEF(VT), N2: DAG.getUNDEF(VT));
26805 SDValue NarrowBO = DAG.getNode(Opcode, DL, VT: NarrowVT, N1: X, N2: Y);
26806 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT, N1: VecC, N2: NarrowBO, N3: Z);
26807 }
26808 }
26809
26810 // Make sure all but the first op are undef or constant.
26811 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
26812 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
26813 all_of(Range: drop_begin(RangeOrContainer: Concat->ops()), P: [](const SDValue &Op) {
26814 return Op.isUndef() ||
26815 ISD::isBuildVectorOfConstantSDNodes(N: Op.getNode());
26816 });
26817 };
26818
26819 // The following pattern is likely to emerge with vector reduction ops. Moving
26820 // the binary operation ahead of the concat may allow using a narrower vector
26821 // instruction that has better performance than the wide version of the op:
26822 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
26823 // concat (VBinOp X, Y), VecC
26824 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
26825 (LHS.hasOneUse() || RHS.hasOneUse())) {
26826 EVT NarrowVT = LHS.getOperand(i: 0).getValueType();
26827 if (NarrowVT == RHS.getOperand(i: 0).getValueType() &&
26828 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT)) {
26829 unsigned NumOperands = LHS.getNumOperands();
26830 SmallVector<SDValue, 4> ConcatOps;
26831 for (unsigned i = 0; i != NumOperands; ++i) {
26832 // These will constant fold for operands 1 and up.
26833 ConcatOps.push_back(Elt: DAG.getNode(Opcode, DL, VT: NarrowVT, N1: LHS.getOperand(i),
26834 N2: RHS.getOperand(i)));
26835 }
26836
26837 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
26838 }
26839 }
26840
26841 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
26842 return V;
26843
26844 return SDValue();
26845}
26846
26847SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
26848 SDValue N2) {
26849 assert(N0.getOpcode() == ISD::SETCC &&
26850 "First argument must be a SetCC node!");
26851
26852 SDValue SCC = SimplifySelectCC(DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: N1, N3: N2,
26853 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
26854
26855 // If we got a simplified select_cc node back from SimplifySelectCC, then
26856 // break it down into a new SETCC node, and a new SELECT node, and then return
26857 // the SELECT node, since we were called with a SELECT node.
26858 if (SCC.getNode()) {
26859 // Check to see if we got a select_cc back (to turn into setcc/select).
26860 // Otherwise, just return whatever node we got back, like fabs.
26861 if (SCC.getOpcode() == ISD::SELECT_CC) {
26862 const SDNodeFlags Flags = N0->getFlags();
26863 SDValue SETCC = DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N0),
26864 VT: N0.getValueType(),
26865 N1: SCC.getOperand(i: 0), N2: SCC.getOperand(i: 1),
26866 N3: SCC.getOperand(i: 4), Flags);
26867 AddToWorklist(N: SETCC.getNode());
26868 SDValue SelectNode = DAG.getSelect(DL: SDLoc(SCC), VT: SCC.getValueType(), Cond: SETCC,
26869 LHS: SCC.getOperand(i: 2), RHS: SCC.getOperand(i: 3));
26870 SelectNode->setFlags(Flags);
26871 return SelectNode;
26872 }
26873
26874 return SCC;
26875 }
26876 return SDValue();
26877}
26878
26879/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
26880/// being selected between, see if we can simplify the select. Callers of this
26881/// should assume that TheSelect is deleted if this returns true. As such, they
26882/// should return the appropriate thing (e.g. the node) back to the top-level of
26883/// the DAG combiner loop to avoid it being looked at.
26884bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
26885 SDValue RHS) {
26886 // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
26887 // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
26888 if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(N: LHS)) {
26889 if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
26890 // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
26891 SDValue Sqrt = RHS;
26892 ISD::CondCode CC;
26893 SDValue CmpLHS;
26894 const ConstantFPSDNode *Zero = nullptr;
26895
26896 if (TheSelect->getOpcode() == ISD::SELECT_CC) {
26897 CC = cast<CondCodeSDNode>(Val: TheSelect->getOperand(Num: 4))->get();
26898 CmpLHS = TheSelect->getOperand(Num: 0);
26899 Zero = isConstOrConstSplatFP(N: TheSelect->getOperand(Num: 1));
26900 } else {
26901 // SELECT or VSELECT
26902 SDValue Cmp = TheSelect->getOperand(Num: 0);
26903 if (Cmp.getOpcode() == ISD::SETCC) {
26904 CC = cast<CondCodeSDNode>(Val: Cmp.getOperand(i: 2))->get();
26905 CmpLHS = Cmp.getOperand(i: 0);
26906 Zero = isConstOrConstSplatFP(N: Cmp.getOperand(i: 1));
26907 }
26908 }
26909 if (Zero && Zero->isZero() &&
26910 Sqrt.getOperand(i: 0) == CmpLHS && (CC == ISD::SETOLT ||
26911 CC == ISD::SETULT || CC == ISD::SETLT)) {
26912 // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
26913 CombineTo(N: TheSelect, Res: Sqrt);
26914 return true;
26915 }
26916 }
26917 }
26918 // Cannot simplify select with vector condition
26919 if (TheSelect->getOperand(Num: 0).getValueType().isVector()) return false;
26920
26921 // If this is a select from two identical things, try to pull the operation
26922 // through the select.
26923 if (LHS.getOpcode() != RHS.getOpcode() ||
26924 !LHS.hasOneUse() || !RHS.hasOneUse())
26925 return false;
26926
26927 // If this is a load and the token chain is identical, replace the select
26928 // of two loads with a load through a select of the address to load from.
26929 // This triggers in things like "select bool X, 10.0, 123.0" after the FP
26930 // constants have been dropped into the constant pool.
26931 if (LHS.getOpcode() == ISD::LOAD) {
26932 LoadSDNode *LLD = cast<LoadSDNode>(Val&: LHS);
26933 LoadSDNode *RLD = cast<LoadSDNode>(Val&: RHS);
26934
26935 // Token chains must be identical.
26936 if (LHS.getOperand(i: 0) != RHS.getOperand(i: 0) ||
26937 // Do not let this transformation reduce the number of volatile loads.
26938 // Be conservative for atomics for the moment
26939 // TODO: This does appear to be legal for unordered atomics (see D66309)
26940 !LLD->isSimple() || !RLD->isSimple() ||
26941 // FIXME: If either is a pre/post inc/dec load,
26942 // we'd need to split out the address adjustment.
26943 LLD->isIndexed() || RLD->isIndexed() ||
26944 // If this is an EXTLOAD, the VT's must match.
26945 LLD->getMemoryVT() != RLD->getMemoryVT() ||
26946 // If this is an EXTLOAD, the kind of extension must match.
26947 (LLD->getExtensionType() != RLD->getExtensionType() &&
26948 // The only exception is if one of the extensions is anyext.
26949 LLD->getExtensionType() != ISD::EXTLOAD &&
26950 RLD->getExtensionType() != ISD::EXTLOAD) ||
26951 // FIXME: this discards src value information. This is
26952 // over-conservative. It would be beneficial to be able to remember
26953 // both potential memory locations. Since we are discarding
26954 // src value info, don't do the transformation if the memory
26955 // locations are not in the default address space.
26956 LLD->getPointerInfo().getAddrSpace() != 0 ||
26957 RLD->getPointerInfo().getAddrSpace() != 0 ||
26958 // We can't produce a CMOV of a TargetFrameIndex since we won't
26959 // generate the address generation required.
26960 LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
26961 RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
26962 !TLI.isOperationLegalOrCustom(Op: TheSelect->getOpcode(),
26963 VT: LLD->getBasePtr().getValueType()))
26964 return false;
26965
26966 // The loads must not depend on one another.
26967 if (LLD->isPredecessorOf(N: RLD) || RLD->isPredecessorOf(N: LLD))
26968 return false;
26969
26970 // Check that the select condition doesn't reach either load. If so,
26971 // folding this will induce a cycle into the DAG. If not, this is safe to
26972 // xform, so create a select of the addresses.
26973
26974 SmallPtrSet<const SDNode *, 32> Visited;
26975 SmallVector<const SDNode *, 16> Worklist;
26976
26977 // Always fail if LLD and RLD are not independent. TheSelect is a
26978 // predecessor to all Nodes in question so we need not search past it.
26979
26980 Visited.insert(Ptr: TheSelect);
26981 Worklist.push_back(Elt: LLD);
26982 Worklist.push_back(Elt: RLD);
26983
26984 if (SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist) ||
26985 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist))
26986 return false;
26987
26988 SDValue Addr;
26989 if (TheSelect->getOpcode() == ISD::SELECT) {
26990 // We cannot do this optimization if any pair of {RLD, LLD} is a
26991 // predecessor to {RLD, LLD, CondNode}. As we've already compared the
26992 // Loads, we only need to check if CondNode is a successor to one of the
26993 // loads. We can further avoid this if there's no use of their chain
26994 // value.
26995 SDNode *CondNode = TheSelect->getOperand(Num: 0).getNode();
26996 Worklist.push_back(Elt: CondNode);
26997
26998 if ((LLD->hasAnyUseOfValue(Value: 1) &&
26999 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
27000 (RLD->hasAnyUseOfValue(Value: 1) &&
27001 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
27002 return false;
27003
27004 Addr = DAG.getSelect(DL: SDLoc(TheSelect),
27005 VT: LLD->getBasePtr().getValueType(),
27006 Cond: TheSelect->getOperand(Num: 0), LHS: LLD->getBasePtr(),
27007 RHS: RLD->getBasePtr());
27008 } else { // Otherwise SELECT_CC
27009       // We cannot do this optimization if either of {LLD, RLD} is a
27010       // predecessor of any of {LLD, RLD, CondLHS, CondRHS}. As we've already
27011       // compared the loads against each other, we only need to check whether
27012       // CondLHS/CondRHS is a successor of one of the loads. We can further
27013       // avoid this if there's no use of their chain value.
27014
27015 SDNode *CondLHS = TheSelect->getOperand(Num: 0).getNode();
27016 SDNode *CondRHS = TheSelect->getOperand(Num: 1).getNode();
27017 Worklist.push_back(Elt: CondLHS);
27018 Worklist.push_back(Elt: CondRHS);
27019
27020 if ((LLD->hasAnyUseOfValue(Value: 1) &&
27021 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
27022 (RLD->hasAnyUseOfValue(Value: 1) &&
27023 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
27024 return false;
27025
27026 Addr = DAG.getNode(Opcode: ISD::SELECT_CC, DL: SDLoc(TheSelect),
27027 VT: LLD->getBasePtr().getValueType(),
27028 N1: TheSelect->getOperand(Num: 0),
27029 N2: TheSelect->getOperand(Num: 1),
27030 N3: LLD->getBasePtr(), N4: RLD->getBasePtr(),
27031 N5: TheSelect->getOperand(Num: 4));
27032 }
27033
27034 SDValue Load;
27035     // It is safe to replace the two loads even if they have different
27036     // alignments, but the new load must use the minimum (smallest) alignment
27037     // of the inputs.
27038 Align Alignment = std::min(a: LLD->getAlign(), b: RLD->getAlign());
27039 MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
27040 if (!RLD->isInvariant())
27041 MMOFlags &= ~MachineMemOperand::MOInvariant;
27042 if (!RLD->isDereferenceable())
27043 MMOFlags &= ~MachineMemOperand::MODereferenceable;
27044 if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
27045 // FIXME: Discards pointer and AA info.
27046 Load = DAG.getLoad(VT: TheSelect->getValueType(ResNo: 0), dl: SDLoc(TheSelect),
27047 Chain: LLD->getChain(), Ptr: Addr, PtrInfo: MachinePointerInfo(), Alignment,
27048 MMOFlags);
27049 } else {
27050 // FIXME: Discards pointer and AA info.
27051 Load = DAG.getExtLoad(
27052 ExtType: LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
27053 : LLD->getExtensionType(),
27054 dl: SDLoc(TheSelect), VT: TheSelect->getValueType(ResNo: 0), Chain: LLD->getChain(), Ptr: Addr,
27055 PtrInfo: MachinePointerInfo(), MemVT: LLD->getMemoryVT(), Alignment, MMOFlags);
27056 }
27057
27058 // Users of the select now use the result of the load.
27059 CombineTo(N: TheSelect, Res: Load);
27060
27061 // Users of the old loads now use the new load's chain. We know the
27062 // old-load value is dead now.
27063 CombineTo(N: LHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
27064 CombineTo(N: RHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
27065 return true;
27066 }
27067
27068 return false;
27069}
27070
27071/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
27072/// bitwise 'and'.
27073SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
27074 SDValue N1, SDValue N2, SDValue N3,
27075 ISD::CondCode CC) {
27076 // If this is a select where the false operand is zero and the compare is a
27077 // check of the sign bit, see if we can perform the "gzip trick":
27078 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
27079 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
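  // A rough worked example (assuming i32 operands): "X < 0 ? A : 0" becomes
  // "(X >> 31) & A" with an arithmetic shift, because sra by 31 yields
  // all-ones when X is negative and zero otherwise. For the SETGT form the
  // mask is inverted first, which is why hasAndNot() is required below.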
27080 EVT XType = N0.getValueType();
27081 EVT AType = N2.getValueType();
27082 if (!isNullConstant(V: N3) || !XType.bitsGE(VT: AType))
27083 return SDValue();
27084
27085 // If the comparison is testing for a positive value, we have to invert
27086 // the sign bit mask, so only do that transform if the target has a bitwise
27087 // 'and not' instruction (the invert is free).
27088 if (CC == ISD::SETGT && TLI.hasAndNot(X: N2)) {
27089 // (X > -1) ? A : 0
27090 // (X > 0) ? X : 0 <-- This is canonical signed max.
27091 if (!(isAllOnesConstant(V: N1) || (isNullConstant(V: N1) && N0 == N2)))
27092 return SDValue();
27093 } else if (CC == ISD::SETLT) {
27094 // (X < 0) ? A : 0
27095 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
27096 if (!(isNullConstant(V: N1) || (isOneConstant(V: N1) && N0 == N2)))
27097 return SDValue();
27098 } else {
27099 return SDValue();
27100 }
27101
27102 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
27103 // constant.
27104 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
27105 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
27106 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
27107 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
27108 if (!TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt)) {
27109 SDValue ShiftAmt = DAG.getConstant(Val: ShCt, DL, VT: ShiftAmtTy);
27110 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT: XType, N1: N0, N2: ShiftAmt);
27111 AddToWorklist(N: Shift.getNode());
27112
27113 if (XType.bitsGT(VT: AType)) {
27114 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
27115 AddToWorklist(N: Shift.getNode());
27116 }
27117
27118 if (CC == ISD::SETGT)
27119 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
27120
27121 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
27122 }
27123 }
27124
27125 unsigned ShCt = XType.getSizeInBits() - 1;
27126 if (TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt))
27127 return SDValue();
27128
27129 SDValue ShiftAmt = DAG.getConstant(Val: ShCt, DL, VT: ShiftAmtTy);
27130 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT: XType, N1: N0, N2: ShiftAmt);
27131 AddToWorklist(N: Shift.getNode());
27132
27133 if (XType.bitsGT(VT: AType)) {
27134 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
27135 AddToWorklist(N: Shift.getNode());
27136 }
27137
27138 if (CC == ISD::SETGT)
27139 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
27140
27141 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
27142}
27143
27144// Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
27145SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
27146 SDValue N0 = N->getOperand(Num: 0);
27147 SDValue N1 = N->getOperand(Num: 1);
27148 SDValue N2 = N->getOperand(Num: 2);
27149 SDLoc DL(N);
27150
27151 unsigned BinOpc = N1.getOpcode();
27152 if (!TLI.isBinOp(Opcode: BinOpc) || (N2.getOpcode() != BinOpc) ||
27153 (N1.getResNo() != N2.getResNo()))
27154 return SDValue();
27155
27156 // The use checks are intentionally on SDNode because we may be dealing
27157 // with opcodes that produce more than one SDValue.
27158 // TODO: Do we really need to check N0 (the condition operand of the select)?
27159 // But removing that clause could cause an infinite loop...
27160 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
27161 return SDValue();
27162
27163 // Binops may include opcodes that return multiple values, so all values
27164 // must be created/propagated from the newly created binops below.
27165 SDVTList OpVTs = N1->getVTList();
27166
27167 // Fold select(cond, binop(x, y), binop(z, y))
27168 // --> binop(select(cond, x, z), y)
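  // For example: select(cond, add(a, b), add(d, b))
  //          --> add(select(cond, a, d), b),
  // which trades one of the two adds for a select.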
27169 if (N1.getOperand(i: 1) == N2.getOperand(i: 1)) {
27170 SDValue N10 = N1.getOperand(i: 0);
27171 SDValue N20 = N2.getOperand(i: 0);
27172 SDValue NewSel = DAG.getSelect(DL, VT: N10.getValueType(), Cond: N0, LHS: N10, RHS: N20);
27173 SDValue NewBinOp = DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: NewSel, N2: N1.getOperand(i: 1));
27174 NewBinOp->setFlags(N1->getFlags());
27175 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
27176 return SDValue(NewBinOp.getNode(), N1.getResNo());
27177 }
27178
27179 // Fold select(cond, binop(x, y), binop(x, z))
27180 // --> binop(x, select(cond, y, z))
27181 if (N1.getOperand(i: 0) == N2.getOperand(i: 0)) {
27182 SDValue N11 = N1.getOperand(i: 1);
27183 SDValue N21 = N2.getOperand(i: 1);
27184 // Second op VT might be different (e.g. shift amount type)
27185 if (N11.getValueType() == N21.getValueType()) {
27186 SDValue NewSel = DAG.getSelect(DL, VT: N11.getValueType(), Cond: N0, LHS: N11, RHS: N21);
27187 SDValue NewBinOp =
27188 DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: N1.getOperand(i: 0), N2: NewSel);
27189 NewBinOp->setFlags(N1->getFlags());
27190 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
27191 return SDValue(NewBinOp.getNode(), N1.getResNo());
27192 }
27193 }
27194
27195 // TODO: Handle isCommutativeBinOp patterns as well?
27196 return SDValue();
27197}
27198
27199// Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
27200SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
27201 SDValue N0 = N->getOperand(Num: 0);
27202 EVT VT = N->getValueType(ResNo: 0);
27203 bool IsFabs = N->getOpcode() == ISD::FABS;
27204 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
27205
27206 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
27207 return SDValue();
27208
27209 SDValue Int = N0.getOperand(i: 0);
27210 EVT IntVT = Int.getValueType();
27211
27212 // The operand to cast should be integer.
27213 if (!IntVT.isInteger() || IntVT.isVector())
27214 return SDValue();
27215
27216 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
27217 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
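  // For example, with a scalar f32 bitcast from i32 (illustrative values):
  // fneg uses the sign mask 0x80000000 (xor flips the sign bit), while fabs
  // uses the inverted mask 0x7fffffff (and clears the sign bit).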
27218 APInt SignMask;
27219 if (N0.getValueType().isVector()) {
27220 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
27221 // 0x7f...) per element and splat it.
27222 SignMask = APInt::getSignMask(BitWidth: N0.getScalarValueSizeInBits());
27223 if (IsFabs)
27224 SignMask = ~SignMask;
27225 SignMask = APInt::getSplat(NewLen: IntVT.getSizeInBits(), V: SignMask);
27226 } else {
27227 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
27228 SignMask = APInt::getSignMask(BitWidth: IntVT.getSizeInBits());
27229 if (IsFabs)
27230 SignMask = ~SignMask;
27231 }
27232 SDLoc DL(N0);
27233 Int = DAG.getNode(Opcode: IsFabs ? ISD::AND : ISD::XOR, DL, VT: IntVT, N1: Int,
27234 N2: DAG.getConstant(Val: SignMask, DL, VT: IntVT));
27235 AddToWorklist(N: Int.getNode());
27236 return DAG.getBitcast(VT, V: Int);
27237}
27238
27239 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
27240/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
27241/// in it. This may be a win when the constant is not otherwise available
27242/// because it replaces two constant pool loads with one.
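/// For example (illustrative only): "x < y ? 1.0f : 2.0f" is rewritten as a
/// setcc feeding a select between two byte offsets into a two-element
/// constant-pool array that holds both constants, followed by a single load
/// of the selected element.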
27243SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
27244 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
27245 ISD::CondCode CC) {
27246 if (!TLI.reduceSelectOfFPConstantLoads(CmpOpVT: N0.getValueType()))
27247 return SDValue();
27248
27249 // If we are before legalize types, we want the other legalization to happen
27250 // first (for example, to avoid messing with soft float).
27251 auto *TV = dyn_cast<ConstantFPSDNode>(Val&: N2);
27252 auto *FV = dyn_cast<ConstantFPSDNode>(Val&: N3);
27253 EVT VT = N2.getValueType();
27254 if (!TV || !FV || !TLI.isTypeLegal(VT))
27255 return SDValue();
27256
27257 // If a constant can be materialized without loads, this does not make sense.
27258 if (TLI.getOperationAction(Op: ISD::ConstantFP, VT) == TargetLowering::Legal ||
27259 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(ResNo: 0), ForCodeSize) ||
27260 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(ResNo: 0), ForCodeSize))
27261 return SDValue();
27262
27263 // If both constants have multiple uses, then we won't need to do an extra
27264 // load. The values are likely around in registers for other users.
27265 if (!TV->hasOneUse() && !FV->hasOneUse())
27266 return SDValue();
27267
27268 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
27269 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
27270 Type *FPTy = Elts[0]->getType();
27271 const DataLayout &TD = DAG.getDataLayout();
27272
27273 // Create a ConstantArray of the two constants.
27274 Constant *CA = ConstantArray::get(T: ArrayType::get(ElementType: FPTy, NumElements: 2), V: Elts);
27275 SDValue CPIdx = DAG.getConstantPool(C: CA, VT: TLI.getPointerTy(DL: DAG.getDataLayout()),
27276 Align: TD.getPrefTypeAlign(Ty: FPTy));
27277 Align Alignment = cast<ConstantPoolSDNode>(Val&: CPIdx)->getAlign();
27278
27279 // Get offsets to the 0 and 1 elements of the array, so we can select between
27280 // them.
27281 SDValue Zero = DAG.getIntPtrConstant(Val: 0, DL);
27282 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Ty: Elts[0]->getType());
27283 SDValue One = DAG.getIntPtrConstant(Val: EltSize, DL: SDLoc(FV));
27284 SDValue Cond =
27285 DAG.getSetCC(DL, VT: getSetCCResultType(VT: N0.getValueType()), LHS: N0, RHS: N1, Cond: CC);
27286 AddToWorklist(N: Cond.getNode());
27287 SDValue CstOffset = DAG.getSelect(DL, VT: Zero.getValueType(), Cond, LHS: One, RHS: Zero);
27288 AddToWorklist(N: CstOffset.getNode());
27289 CPIdx = DAG.getNode(Opcode: ISD::ADD, DL, VT: CPIdx.getValueType(), N1: CPIdx, N2: CstOffset);
27290 AddToWorklist(N: CPIdx.getNode());
27291 return DAG.getLoad(VT: TV->getValueType(ResNo: 0), dl: DL, Chain: DAG.getEntryNode(), Ptr: CPIdx,
27292 PtrInfo: MachinePointerInfo::getConstantPool(
27293 MF&: DAG.getMachineFunction()), Alignment);
27294}
27295
27296/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
27297/// where 'cond' is the comparison specified by CC.
27298SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
27299 SDValue N2, SDValue N3, ISD::CondCode CC,
27300 bool NotExtCompare) {
27301 // (x ? y : y) -> y.
27302 if (N2 == N3) return N2;
27303
27304 EVT CmpOpVT = N0.getValueType();
27305 EVT CmpResVT = getSetCCResultType(VT: CmpOpVT);
27306 EVT VT = N2.getValueType();
27307 auto *N1C = dyn_cast<ConstantSDNode>(Val: N1.getNode());
27308 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
27309 auto *N3C = dyn_cast<ConstantSDNode>(Val: N3.getNode());
27310
27311 // Determine if the condition we're dealing with is constant.
27312 if (SDValue SCC = DAG.FoldSetCC(VT: CmpResVT, N1: N0, N2: N1, Cond: CC, dl: DL)) {
27313 AddToWorklist(N: SCC.getNode());
27314 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val&: SCC)) {
27315 // fold select_cc true, x, y -> x
27316 // fold select_cc false, x, y -> y
27317 return !(SCCC->isZero()) ? N2 : N3;
27318 }
27319 }
27320
27321 if (SDValue V =
27322 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
27323 return V;
27324
27325 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
27326 return V;
27327
27328 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
27329   // where y has a single bit set.
27330   // In plain terms, we can turn the SELECT_CC into an AND
27331 // when the condition can be materialized as an all-ones register. Any
27332 // single bit-test can be materialized as an all-ones register with
27333 // shift-left and shift-right-arith.
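  // A rough i32 example: for "(x & 4) == 0 ? 0 : A", the tested bit (bit 2)
  // is shifted into the sign position (shl by 29), then sra by 31 produces
  // either all-ones (bit was set) or zero, and the final 'and' yields A or 0.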
27334 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
27335 N0->getValueType(ResNo: 0) == VT && isNullConstant(V: N1) && isNullConstant(V: N2)) {
27336 SDValue AndLHS = N0->getOperand(Num: 0);
27337 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(Val: N0->getOperand(Num: 1));
27338 if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
27339 // Shift the tested bit over the sign bit.
27340 const APInt &AndMask = ConstAndRHS->getAPIntValue();
27341 if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
27342 unsigned ShCt = AndMask.getBitWidth() - 1;
27343 SDValue ShlAmt =
27344 DAG.getConstant(Val: AndMask.countl_zero(), DL: SDLoc(AndLHS),
27345 VT: getShiftAmountTy(LHSTy: AndLHS.getValueType()));
27346 SDValue Shl = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: AndLHS, N2: ShlAmt);
27347
27348 // Now arithmetic right shift it all the way over, so the result is
27349 // either all-ones, or zero.
27350 SDValue ShrAmt =
27351 DAG.getConstant(Val: ShCt, DL: SDLoc(Shl),
27352 VT: getShiftAmountTy(LHSTy: Shl.getValueType()));
27353 SDValue Shr = DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N0), VT, N1: Shl, N2: ShrAmt);
27354
27355 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shr, N2: N3);
27356 }
27357 }
27358 }
27359
27360 // fold select C, 16, 0 -> shl C, 4
27361 bool Fold = N2C && isNullConstant(V: N3) && N2C->getAPIntValue().isPowerOf2();
27362 bool Swap = N3C && isNullConstant(V: N2) && N3C->getAPIntValue().isPowerOf2();
27363
27364 if ((Fold || Swap) &&
27365 TLI.getBooleanContents(Type: CmpOpVT) ==
27366 TargetLowering::ZeroOrOneBooleanContent &&
27367 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: CmpOpVT))) {
27368
27369 if (Swap) {
27370 CC = ISD::getSetCCInverse(Operation: CC, Type: CmpOpVT);
27371 std::swap(a&: N2C, b&: N3C);
27372 }
27373
27374 // If the caller doesn't want us to simplify this into a zext of a compare,
27375 // don't do it.
27376 if (NotExtCompare && N2C->isOne())
27377 return SDValue();
27378
27379 SDValue Temp, SCC;
27380 // zext (setcc n0, n1)
27381 if (LegalTypes) {
27382 SCC = DAG.getSetCC(DL, VT: CmpResVT, LHS: N0, RHS: N1, Cond: CC);
27383 Temp = DAG.getZExtOrTrunc(Op: SCC, DL: SDLoc(N2), VT);
27384 } else {
27385 SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
27386 Temp = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N2), VT, Operand: SCC);
27387 }
27388
27389 AddToWorklist(N: SCC.getNode());
27390 AddToWorklist(N: Temp.getNode());
27391
27392 if (N2C->isOne())
27393 return Temp;
27394
27395 unsigned ShCt = N2C->getAPIntValue().logBase2();
27396 if (TLI.shouldAvoidTransformToShift(VT, Amount: ShCt))
27397 return SDValue();
27398
27399 // shl setcc result by log2 n2c
27400 return DAG.getNode(Opcode: ISD::SHL, DL, VT: N2.getValueType(), N1: Temp,
27401 N2: DAG.getConstant(Val: ShCt, DL: SDLoc(Temp),
27402 VT: getShiftAmountTy(LHSTy: Temp.getValueType())));
27403 }
27404
27405 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
27406 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
27407 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
27408 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
27409 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
27410 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
27411 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
27412 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
27413 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
27414 SDValue ValueOnZero = N2;
27415 SDValue Count = N3;
27416 // If the condition is NE instead of E, swap the operands.
27417 if (CC == ISD::SETNE)
27418 std::swap(a&: ValueOnZero, b&: Count);
27419 // Check if the value on zero is a constant equal to the bits in the type.
27420 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(Val&: ValueOnZero)) {
27421 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
27422 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
27423 // legal, combine to just cttz.
27424 if ((Count.getOpcode() == ISD::CTTZ ||
27425 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
27426 N0 == Count.getOperand(i: 0) &&
27427 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ, VT)))
27428 return DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N0);
27429 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
27430 // legal, combine to just ctlz.
27431 if ((Count.getOpcode() == ISD::CTLZ ||
27432 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
27433 N0 == Count.getOperand(i: 0) &&
27434 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ, VT)))
27435 return DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: N0);
27436 }
27437 }
27438 }
27439
27440 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
27441 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
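  // A rough i32 example: "X > -1 ? C : ~C" becomes "(X >> 31) ^ C", since the
  // arithmetic shift produces 0 for non-negative X (xor leaves C) and
  // all-ones for negative X (xor gives ~C).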
27442 if (!NotExtCompare && N1C && N2C && N3C &&
27443 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
27444 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
27445 (N1C->isZero() && CC == ISD::SETLT)) &&
27446 !TLI.shouldAvoidTransformToShift(VT, Amount: CmpOpVT.getScalarSizeInBits() - 1)) {
27447 SDValue ASR = DAG.getNode(
27448 Opcode: ISD::SRA, DL, VT: CmpOpVT, N1: N0,
27449 N2: DAG.getConstant(Val: CmpOpVT.getScalarSizeInBits() - 1, DL, VT: CmpOpVT));
27450 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: DAG.getSExtOrTrunc(Op: ASR, DL, VT),
27451 N2: DAG.getSExtOrTrunc(Op: CC == ISD::SETLT ? N3 : N2, DL, VT));
27452 }
27453
27454 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27455 return S;
27456 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27457 return S;
27458
27459 return SDValue();
27460}
27461
27462/// This is a stub for TargetLowering::SimplifySetCC.
27463SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
27464 ISD::CondCode Cond, const SDLoc &DL,
27465 bool foldBooleans) {
27466 TargetLowering::DAGCombinerInfo
27467 DagCombineInfo(DAG, Level, false, this);
27468 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DCI&: DagCombineInfo, dl: DL);
27469}
27470
27471/// Given an ISD::SDIV node expressing a divide by constant, return
27472/// a DAG expression to select that will generate the same value by multiplying
27473/// by a magic number.
27474/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
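/// A minimal scalar sketch of the idea (illustrative only; the actual magic
/// constants and fixups are chosen by TLI.BuildSDIV):
///   int32_t div3(int32_t n) {
///     int32_t q = (int32_t)(((int64_t)0x55555556 * n) >> 32); // high half
///     return q + ((uint32_t)n >> 31); // add 1 to correct negative inputs
///   }
/// computes n / 3 without a divide instruction.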
27475SDValue DAGCombiner::BuildSDIV(SDNode *N) {
27476   // When optimizing for minimum size, we don't want to expand a div to a mul
27477   // and a shift.
27478 if (DAG.getMachineFunction().getFunction().hasMinSize())
27479 return SDValue();
27480
27481 SmallVector<SDNode *, 8> Built;
27482 if (SDValue S = TLI.BuildSDIV(N, DAG, IsAfterLegalization: LegalOperations, Created&: Built)) {
27483 for (SDNode *N : Built)
27484 AddToWorklist(N);
27485 return S;
27486 }
27487
27488 return SDValue();
27489}
27490
27491/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
27492/// DAG expression that will generate the same value by right shifting.
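/// A scalar sketch of the usual expansion (illustrative only; assumes an
/// arithmetic right shift on signed values): for "n / 8" on i32, add a bias
/// of 7 when n is negative so the shift truncates toward zero:
///   int32_t sdiv8(int32_t n) { return (n + ((n >> 31) & 7)) >> 3; }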
27493SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
27494 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
27495 if (!C)
27496 return SDValue();
27497
27498 // Avoid division by zero.
27499 if (C->isZero())
27500 return SDValue();
27501
27502 SmallVector<SDNode *, 8> Built;
27503 if (SDValue S = TLI.BuildSDIVPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
27504 for (SDNode *N : Built)
27505 AddToWorklist(N);
27506 return S;
27507 }
27508
27509 return SDValue();
27510}
27511
27512/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
27513/// expression that will generate the same value by multiplying by a magic
27514/// number.
27515/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
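/// A minimal unsigned sketch (illustrative only; TLI.BuildUDIV picks the real
/// constants and any post-shift/fixup):
///   uint32_t udiv5(uint32_t n) {
///     return (uint32_t)(((uint64_t)0xCCCCCCCDu * n) >> 34); // n / 5
///   }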
27516SDValue DAGCombiner::BuildUDIV(SDNode *N) {
27517   // When optimizing for minimum size, we don't want to expand a div to a mul
27518   // and a shift.
27519 if (DAG.getMachineFunction().getFunction().hasMinSize())
27520 return SDValue();
27521
27522 SmallVector<SDNode *, 8> Built;
27523 if (SDValue S = TLI.BuildUDIV(N, DAG, IsAfterLegalization: LegalOperations, Created&: Built)) {
27524 for (SDNode *N : Built)
27525 AddToWorklist(N);
27526 return S;
27527 }
27528
27529 return SDValue();
27530}
27531
27532/// Given an ISD::SREM node expressing a remainder by constant power of 2,
27533/// return a DAG expression that will generate the same value.
27534SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
27535 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
27536 if (!C)
27537 return SDValue();
27538
27539 // Avoid division by zero.
27540 if (C->isZero())
27541 return SDValue();
27542
27543 SmallVector<SDNode *, 8> Built;
27544 if (SDValue S = TLI.BuildSREMPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
27545 for (SDNode *N : Built)
27546 AddToWorklist(N);
27547 return S;
27548 }
27549
27550 return SDValue();
27551}
27552
27553// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
27554//
27555 // Returns the node that represents `Log2(Op)`. This may create a new node.
27556 // If we are unable to compute `Log2(Op)`, it returns `SDValue()`.
27557//
27558// All nodes will be created at `DL` and the output will be of type `VT`.
27559//
27560 // This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
27561 // `AssumeNonZero` if this function should simply assume, rather than prove,
27562 // that `Op` is non-zero.
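//
// For example (illustrative): if `Op` is `1 << y`, then Log2(Op) is simply
// `y`, and if `Op` is a constant such as 16 the result is the constant 4.
// Callers use this to turn patterns like `x / (1 << y)` into `x >> y`.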
27563static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27564 SDValue Op, unsigned Depth,
27565 bool AssumeNonZero) {
27566 assert(VT.isInteger() && "Only integer types are supported!");
27567
27568 auto PeekThroughCastsAndTrunc = [](SDValue V) {
27569 while (true) {
27570 switch (V.getOpcode()) {
27571 case ISD::TRUNCATE:
27572 case ISD::ZERO_EXTEND:
27573 V = V.getOperand(i: 0);
27574 break;
27575 default:
27576 return V;
27577 }
27578 }
27579 };
27580
27581 if (VT.isScalableVector())
27582 return SDValue();
27583
27584 Op = PeekThroughCastsAndTrunc(Op);
27585
27586 // Helper for determining whether a value is a power-2 constant scalar or a
27587 // vector of such elements.
27588 SmallVector<APInt> Pow2Constants;
27589 auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
27590 if (C->isZero() || C->isOpaque())
27591 return false;
27592 // TODO: We may also be able to support negative powers of 2 here.
27593 if (C->getAPIntValue().isPowerOf2()) {
27594 Pow2Constants.emplace_back(Args: C->getAPIntValue());
27595 return true;
27596 }
27597 return false;
27598 };
27599
27600 if (ISD::matchUnaryPredicate(Op, Match: IsPowerOfTwo)) {
27601 if (!VT.isVector())
27602 return DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL, VT);
27603 // We need to create a build vector
27604 SmallVector<SDValue> Log2Ops;
27605 for (const APInt &Pow2 : Pow2Constants)
27606 Log2Ops.emplace_back(
27607 Args: DAG.getConstant(Val: Pow2.logBase2(), DL, VT: VT.getScalarType()));
27608 return DAG.getBuildVector(VT, DL, Ops: Log2Ops);
27609 }
27610
27611 if (Depth >= DAG.MaxRecursionDepth)
27612 return SDValue();
27613
27614 auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
27615 ToCast = PeekThroughCastsAndTrunc(ToCast);
27616 EVT CurVT = ToCast.getValueType();
27617 if (NewVT == CurVT)
27618 return ToCast;
27619
27620 if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
27621 return DAG.getBitcast(VT: NewVT, V: ToCast);
27622
27623 return DAG.getZExtOrTrunc(Op: ToCast, DL, VT: NewVT);
27624 };
27625
27626 // log2(X << Y) -> log2(X) + Y
27627 if (Op.getOpcode() == ISD::SHL) {
27628 // 1 << Y and X nuw/nsw << Y are all non-zero.
27629 if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
27630 Op->getFlags().hasNoSignedWrap() || isOneConstant(V: Op.getOperand(i: 0)))
27631 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0),
27632 Depth: Depth + 1, AssumeNonZero))
27633 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LogX,
27634 N2: CastToVT(VT, Op.getOperand(i: 1)));
27635 }
27636
27637 // c ? X : Y -> c ? Log2(X) : Log2(Y)
27638 if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
27639 Op.hasOneUse()) {
27640 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1),
27641 Depth: Depth + 1, AssumeNonZero))
27642 if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 2),
27643 Depth: Depth + 1, AssumeNonZero))
27644 return DAG.getSelect(DL, VT, Cond: Op.getOperand(i: 0), LHS: LogX, RHS: LogY);
27645 }
27646
27647 // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
27648 // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
27649 if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
27650 Op.hasOneUse()) {
27651     // Use AssumeNonZero as false here. Otherwise we can hit a case where
27652     // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
27653 if (SDValue LogX =
27654 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0), Depth: Depth + 1,
27655 /*AssumeNonZero*/ false))
27656 if (SDValue LogY =
27657 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1), Depth: Depth + 1,
27658 /*AssumeNonZero*/ false))
27659 return DAG.getNode(Opcode: Op.getOpcode(), DL, VT, N1: LogX, N2: LogY);
27660 }
27661
27662 return SDValue();
27663}
27664
27665/// Determines the LogBase2 value for a non-null input value using the
27666/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
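/// For example (i32, illustrative): V = 16 has ctlz(V) = 27, so
/// LogBase2(V) = 31 - 27 = 4.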
27667SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
27668 bool KnownNonZero, bool InexpensiveOnly,
27669 std::optional<EVT> OutVT) {
27670 EVT VT = OutVT ? *OutVT : V.getValueType();
27671 SDValue InexpensiveLogBase2 =
27672 takeInexpensiveLog2(DAG, DL, VT, Op: V, /*Depth*/ 0, AssumeNonZero: KnownNonZero);
27673 if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(Val: V))
27674 return InexpensiveLogBase2;
27675
27676 SDValue Ctlz = DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: V);
27677 SDValue Base = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
27678 SDValue LogBase2 = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Base, N2: Ctlz);
27679 return LogBase2;
27680}
27681
27682/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27683/// For the reciprocal, we need to find the zero of the function:
27684/// F(X) = 1/X - A [which has a zero at X = 1/A]
27685/// =>
27686/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
27687/// does not require additional intermediate precision]
27688/// For the last iteration, put numerator N into it to gain more precision:
27689/// Result = N X_i + X_i (N - N A X_i)
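/// A rough numeric illustration (not the exact code sequence): for A = 3 with
/// initial estimate X_0 = 0.3, one step gives X_1 = 0.3 * (2 - 3 * 0.3) = 0.33
/// and the next gives X_2 = 0.33 * (2 - 3 * 0.33) = 0.3333, converging
/// quadratically toward 1/3.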
27690SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
27691 SDNodeFlags Flags) {
27692 if (LegalDAG)
27693 return SDValue();
27694
27695 // TODO: Handle extended types?
27696 EVT VT = Op.getValueType();
27697 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
27698 VT.getScalarType() != MVT::f64)
27699 return SDValue();
27700
27701 // If estimates are explicitly disabled for this function, we're done.
27702 MachineFunction &MF = DAG.getMachineFunction();
27703 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
27704 if (Enabled == TLI.ReciprocalEstimate::Disabled)
27705 return SDValue();
27706
27707 // Estimates may be explicitly enabled for this type with a custom number of
27708 // refinement steps.
27709 int Iterations = TLI.getDivRefinementSteps(VT, MF);
27710 if (SDValue Est = TLI.getRecipEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations)) {
27711 AddToWorklist(N: Est.getNode());
27712
27713 SDLoc DL(Op);
27714 if (Iterations) {
27715 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
27716
27717 // Newton iterations: Est = Est + Est (N - Arg * Est)
27718 // If this is the last iteration, also multiply by the numerator.
27719 for (int i = 0; i < Iterations; ++i) {
27720 SDValue MulEst = Est;
27721
27722 if (i == Iterations - 1) {
27723 MulEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N, N2: Est, Flags);
27724 AddToWorklist(N: MulEst.getNode());
27725 }
27726
27727 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Op, N2: MulEst, Flags);
27728 AddToWorklist(N: NewEst.getNode());
27729
27730 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT,
27731 N1: (i == Iterations - 1 ? N : FPOne), N2: NewEst, Flags);
27732 AddToWorklist(N: NewEst.getNode());
27733
27734 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
27735 AddToWorklist(N: NewEst.getNode());
27736
27737 Est = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: MulEst, N2: NewEst, Flags);
27738 AddToWorklist(N: Est.getNode());
27739 }
27740 } else {
27741 // If no iterations are available, multiply with N.
27742 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: N, Flags);
27743 AddToWorklist(N: Est.getNode());
27744 }
27745
27746 return Est;
27747 }
27748
27749 return SDValue();
27750}
27751
27752/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27753/// For the reciprocal sqrt, we need to find the zero of the function:
27754/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
27755/// =>
27756/// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
27757/// As a result, we precompute A/2 prior to the iteration loop.
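/// A rough numeric illustration: for A = 4 with initial estimate X_0 = 0.6,
/// one step gives X_1 = 0.6 * (1.5 - 2 * 0.36) = 0.468 and the next gives
/// X_2 = 0.468 * (1.5 - 2 * 0.219) = 0.497, approaching 1/sqrt(4) = 0.5.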
27758SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
27759 unsigned Iterations,
27760 SDNodeFlags Flags, bool Reciprocal) {
27761 EVT VT = Arg.getValueType();
27762 SDLoc DL(Arg);
27763 SDValue ThreeHalves = DAG.getConstantFP(Val: 1.5, DL, VT);
27764
27765 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
27766 // this entire sequence requires only one FP constant.
27767 SDValue HalfArg = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: ThreeHalves, N2: Arg, Flags);
27768 HalfArg = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: HalfArg, N2: Arg, Flags);
27769
27770 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
27771 for (unsigned i = 0; i < Iterations; ++i) {
27772 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Est, Flags);
27773 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: HalfArg, N2: NewEst, Flags);
27774 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: ThreeHalves, N2: NewEst, Flags);
27775 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
27776 }
27777
27778 // If non-reciprocal square root is requested, multiply the result by Arg.
27779 if (!Reciprocal)
27780 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Arg, Flags);
27781
27782 return Est;
27783}
27784
27785/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27786/// For the reciprocal sqrt, we need to find the zero of the function:
27787/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
27788/// =>
27789/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
27790SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
27791 unsigned Iterations,
27792 SDNodeFlags Flags, bool Reciprocal) {
27793 EVT VT = Arg.getValueType();
27794 SDLoc DL(Arg);
27795 SDValue MinusThree = DAG.getConstantFP(Val: -3.0, DL, VT);
27796 SDValue MinusHalf = DAG.getConstantFP(Val: -0.5, DL, VT);
27797
27798 // This routine must enter the loop below to work correctly
27799 // when (Reciprocal == false).
27800 assert(Iterations > 0);
27801
27802 // Newton iterations for reciprocal square root:
27803 // E = (E * -0.5) * ((A * E) * E + -3.0)
27804 for (unsigned i = 0; i < Iterations; ++i) {
27805 SDValue AE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Arg, N2: Est, Flags);
27806 SDValue AEE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: Est, Flags);
27807 SDValue RHS = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: AEE, N2: MinusThree, Flags);
27808
27809 // When calculating a square root at the last iteration build:
27810 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
27811 // (notice a common subexpression)
27812 SDValue LHS;
27813 if (Reciprocal || (i + 1) < Iterations) {
27814 // RSQRT: LHS = (E * -0.5)
27815 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: MinusHalf, Flags);
27816 } else {
27817 // SQRT: LHS = (A * E) * -0.5
27818 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: MinusHalf, Flags);
27819 }
27820
27821 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: LHS, N2: RHS, Flags);
27822 }
27823
27824 return Est;
27825}
27826
27827/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
27828/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
27829/// Op can be zero.
27830SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
27831 bool Reciprocal) {
27832 if (LegalDAG)
27833 return SDValue();
27834
27835 // TODO: Handle extended types?
27836 EVT VT = Op.getValueType();
27837 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
27838 VT.getScalarType() != MVT::f64)
27839 return SDValue();
27840
27841 // If estimates are explicitly disabled for this function, we're done.
27842 MachineFunction &MF = DAG.getMachineFunction();
27843 int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
27844 if (Enabled == TLI.ReciprocalEstimate::Disabled)
27845 return SDValue();
27846
27847 // Estimates may be explicitly enabled for this type with a custom number of
27848 // refinement steps.
27849 int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
27850
27851 bool UseOneConstNR = false;
27852 if (SDValue Est =
27853 TLI.getSqrtEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations, UseOneConstNR,
27854 Reciprocal)) {
27855 AddToWorklist(N: Est.getNode());
27856
27857 if (Iterations > 0)
27858 Est = UseOneConstNR
27859 ? buildSqrtNROneConst(Arg: Op, Est, Iterations, Flags, Reciprocal)
27860 : buildSqrtNRTwoConst(Arg: Op, Est, Iterations, Flags, Reciprocal);
27861 if (!Reciprocal) {
27862 SDLoc DL(Op);
27863 // Try the target specific test first.
27864 SDValue Test = TLI.getSqrtInputTest(Operand: Op, DAG, Mode: DAG.getDenormalMode(VT));
27865
27866 // The estimate is now completely wrong if the input was exactly 0.0 or
27867 // possibly a denormal. Force the answer to 0.0 or value provided by
27868 // target for those cases.
27869 Est = DAG.getNode(
27870 Opcode: Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
27871 N1: Test, N2: TLI.getSqrtResultForDenormInput(Operand: Op, DAG), N3: Est);
27872 }
27873 return Est;
27874 }
27875
27876 return SDValue();
27877}
27878
27879SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
27880 return buildSqrtEstimateImpl(Op, Flags, Reciprocal: true);
27881}
27882
27883SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
27884 return buildSqrtEstimateImpl(Op, Flags, Reciprocal: false);
27885}
27886
27887/// Return true if there is any possibility that the two addresses overlap.
27888bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
27889
27890 struct MemUseCharacteristics {
27891 bool IsVolatile;
27892 bool IsAtomic;
27893 SDValue BasePtr;
27894 int64_t Offset;
27895 std::optional<int64_t> NumBytes;
27896 MachineMemOperand *MMO;
27897 };
27898
27899 auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
27900 if (const auto *LSN = dyn_cast<LSBaseSDNode>(Val: N)) {
27901 int64_t Offset = 0;
27902 if (auto *C = dyn_cast<ConstantSDNode>(Val: LSN->getOffset()))
27903 Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
27904 ? C->getSExtValue()
27905 : (LSN->getAddressingMode() == ISD::PRE_DEC)
27906 ? -1 * C->getSExtValue()
27907 : 0;
27908 uint64_t Size =
27909 MemoryLocation::getSizeOrUnknown(T: LSN->getMemoryVT().getStoreSize());
27910 return {.IsVolatile: LSN->isVolatile(),
27911 .IsAtomic: LSN->isAtomic(),
27912 .BasePtr: LSN->getBasePtr(),
27913 .Offset: Offset /*base offset*/,
27914 .NumBytes: std::optional<int64_t>(Size),
27915 .MMO: LSN->getMemOperand()};
27916 }
27917     if (const auto *LN = dyn_cast<LifetimeSDNode>(Val: N))
27918 return {.IsVolatile: false /*isVolatile*/,
27919 /*isAtomic*/ .IsAtomic: false,
27920 .BasePtr: LN->getOperand(Num: 1),
27921 .Offset: (LN->hasOffset()) ? LN->getOffset() : 0,
27922 .NumBytes: (LN->hasOffset()) ? std::optional<int64_t>(LN->getSize())
27923 : std::optional<int64_t>(),
27924 .MMO: (MachineMemOperand *)nullptr};
27925 // Default.
27926 return {.IsVolatile: false /*isvolatile*/,
27927 /*isAtomic*/ .IsAtomic: false, .BasePtr: SDValue(),
27928 .Offset: (int64_t)0 /*offset*/, .NumBytes: std::optional<int64_t>() /*size*/,
27929 .MMO: (MachineMemOperand *)nullptr};
27930 };
27931
27932 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
27933 MUC1 = getCharacteristics(Op1);
27934
27935 // If they are to the same address, then they must be aliases.
27936 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
27937 MUC0.Offset == MUC1.Offset)
27938 return true;
27939
27940 // If they are both volatile then they cannot be reordered.
27941 if (MUC0.IsVolatile && MUC1.IsVolatile)
27942 return true;
27943
27944 // Be conservative about atomics for the moment
27945 // TODO: This is way overconservative for unordered atomics (see D66309)
27946 if (MUC0.IsAtomic && MUC1.IsAtomic)
27947 return true;
27948
27949 if (MUC0.MMO && MUC1.MMO) {
27950 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
27951 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
27952 return false;
27953 }
27954
27955 // Try to prove that there is aliasing, or that there is no aliasing. Either
27956 // way, we can return now. If nothing can be proved, proceed with more tests.
27957 bool IsAlias;
27958 if (BaseIndexOffset::computeAliasing(Op0, NumBytes0: MUC0.NumBytes, Op1, NumBytes1: MUC1.NumBytes,
27959 DAG, IsAlias))
27960 return IsAlias;
27961
27962 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
27963 // either are not known.
27964 if (!MUC0.MMO || !MUC1.MMO)
27965 return true;
27966
27967   // If one operation reads from invariant memory and the other may store, they
27968   // cannot alias. These should really be checking the equivalent of mayWrite,
27969   // but it only matters for memory nodes other than load/store.
27970 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
27971 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
27972 return false;
27973
27974 // If we know required SrcValue1 and SrcValue2 have relatively large
27975 // alignment compared to the size and offset of the access, we may be able
27976 // to prove they do not alias. This check is conservative for now to catch
27977 // cases created by splitting vector types, it only works when the offsets are
27978 // multiples of the size of the data.
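  // For example (illustrative): two 4-byte accesses that share a base
  // alignment of 8 bytes at source offsets 0 and 4 occupy disjoint halves of
  // every aligned 8-byte block, so they cannot overlap.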
27979 int64_t SrcValOffset0 = MUC0.MMO->getOffset();
27980 int64_t SrcValOffset1 = MUC1.MMO->getOffset();
27981 Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
27982 Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
27983 auto &Size0 = MUC0.NumBytes;
27984 auto &Size1 = MUC1.NumBytes;
27985 if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
27986 Size0.has_value() && Size1.has_value() && *Size0 == *Size1 &&
27987 OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
27988 SrcValOffset1 % *Size1 == 0) {
27989 int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
27990 int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
27991
27992 // There is no overlap between these relatively aligned accesses of
27993 // similar size. Return no alias.
27994 if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
27995 return false;
27996 }
27997
27998 bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
27999 ? CombinerGlobalAA
28000 : DAG.getSubtarget().useAA();
28001#ifndef NDEBUG
28002 if (CombinerAAOnlyFunc.getNumOccurrences() &&
28003 CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
28004 UseAA = false;
28005#endif
28006
28007 if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 &&
28008 Size1) {
28009 // Use alias analysis information.
28010 int64_t MinOffset = std::min(a: SrcValOffset0, b: SrcValOffset1);
28011 int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
28012 int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
28013 if (AA->isNoAlias(
28014 LocA: MemoryLocation(MUC0.MMO->getValue(), Overlap0,
28015 UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
28016 LocB: MemoryLocation(MUC1.MMO->getValue(), Overlap1,
28017 UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
28018 return false;
28019 }
28020
28021 // Otherwise we have to assume they alias.
28022 return true;
28023}
28024
28025/// Walk up chain skipping non-aliasing memory nodes,
28026/// looking for aliasing nodes and adding them to the Aliases vector.
28027void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
28028 SmallVectorImpl<SDValue> &Aliases) {
28029 SmallVector<SDValue, 8> Chains; // List of chains to visit.
28030 SmallPtrSet<SDNode *, 16> Visited; // Visited node set.
28031
28032 // Get alias information for node.
28033 // TODO: relax aliasing for unordered atomics (see D66309)
28034 const bool IsLoad = isa<LoadSDNode>(Val: N) && cast<LoadSDNode>(Val: N)->isSimple();
28035
28036 // Starting off.
28037 Chains.push_back(Elt: OriginalChain);
28038 unsigned Depth = 0;
28039
28040 // Attempt to improve chain by a single step
28041 auto ImproveChain = [&](SDValue &C) -> bool {
28042 switch (C.getOpcode()) {
28043 case ISD::EntryToken:
28044 // No need to mark EntryToken.
28045 C = SDValue();
28046 return true;
28047 case ISD::LOAD:
28048 case ISD::STORE: {
28049 // Get alias information for C.
28050 // TODO: Relax aliasing for unordered atomics (see D66309)
28051 bool IsOpLoad = isa<LoadSDNode>(Val: C.getNode()) &&
28052 cast<LSBaseSDNode>(Val: C.getNode())->isSimple();
28053 if ((IsLoad && IsOpLoad) || !mayAlias(Op0: N, Op1: C.getNode())) {
28054 // Look further up the chain.
28055 C = C.getOperand(i: 0);
28056 return true;
28057 }
28058 // Alias, so stop here.
28059 return false;
28060 }
28061
28062 case ISD::CopyFromReg:
28063 // Always forward past CopyFromReg.
28064 C = C.getOperand(i: 0);
28065 return true;
28066
28067 case ISD::LIFETIME_START:
28068 case ISD::LIFETIME_END: {
28069 // We can forward past any lifetime start/end that can be proven not to
28070 // alias the memory access.
28071 if (!mayAlias(Op0: N, Op1: C.getNode())) {
28072 // Look further up the chain.
28073 C = C.getOperand(i: 0);
28074 return true;
28075 }
28076 return false;
28077 }
28078 default:
28079 return false;
28080 }
28081 };
28082
28083 // Look at each chain and determine if it is an alias. If so, add it to the
28084 // aliases list. If not, then continue up the chain looking for the next
28085 // candidate.
28086 while (!Chains.empty()) {
28087 SDValue Chain = Chains.pop_back_val();
28088
28089 // Don't bother if we've seen Chain before.
28090 if (!Visited.insert(Ptr: Chain.getNode()).second)
28091 continue;
28092
28093 // For TokenFactor nodes, look at each operand and only continue up the
28094 // chain until we reach the depth limit.
28095 //
28096 // FIXME: The depth check could be made to return the last non-aliasing
28097 // chain we found before we hit a tokenfactor rather than the original
28098 // chain.
28099 if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
28100 Aliases.clear();
28101 Aliases.push_back(Elt: OriginalChain);
28102 return;
28103 }
28104
28105 if (Chain.getOpcode() == ISD::TokenFactor) {
28106 // We have to check each of the operands of the token factor for "small"
28107 // token factors, so we queue them up. Adding the operands to the queue
28108 // (stack) in reverse order maintains the original order and increases the
28109 // likelihood that getNode will find a matching token factor (CSE.)
28110 if (Chain.getNumOperands() > 16) {
28111 Aliases.push_back(Elt: Chain);
28112 continue;
28113 }
28114 for (unsigned n = Chain.getNumOperands(); n;)
28115 Chains.push_back(Elt: Chain.getOperand(i: --n));
28116 ++Depth;
28117 continue;
28118 }
28119 // Everything else
28120 if (ImproveChain(Chain)) {
28121 // Updated Chain Found, Consider new chain if one exists.
28122 if (Chain.getNode())
28123 Chains.push_back(Elt: Chain);
28124 ++Depth;
28125 continue;
28126 }
28127 // No Improved Chain Possible, treat as Alias.
28128 Aliases.push_back(Elt: Chain);
28129 }
28130}
28131
28132/// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
28133/// (aliasing node.)
28134SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
28135 if (OptLevel == CodeGenOptLevel::None)
28136 return OldChain;
28137
28138 // Ops for replacing token factor.
28139 SmallVector<SDValue, 8> Aliases;
28140
28141 // Accumulate all the aliases to this node.
28142 GatherAllAliases(N, OriginalChain: OldChain, Aliases);
28143
28144 // If no operands then chain to entry token.
28145 if (Aliases.empty())
28146 return DAG.getEntryNode();
28147
28148 // If a single operand then chain to it. We don't need to revisit it.
28149 if (Aliases.size() == 1)
28150 return Aliases[0];
28151
28152 // Construct a custom tailored token factor.
28153 return DAG.getTokenFactor(DL: SDLoc(N), Vals&: Aliases);
28154}
28155
28156// This function tries to collect a bunch of potentially interesting
28157// nodes to improve the chains of, all at once. This might seem
28158// redundant, as this function gets called when visiting every store
28159// node, so why not let the work be done on each store as it's visited?
28160//
28161// I believe this is mainly important because mergeConsecutiveStores
28162// is unable to deal with merging stores of different sizes, so unless
28163// we improve the chains of all the potential candidates up-front
28164// before running mergeConsecutiveStores, it might only see some of
28165// the nodes that will eventually be candidates, and then not be able
28166// to go from a partially-merged state to the desired final
28167// fully-merged state.
28168
28169bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
28170 SmallVector<StoreSDNode *, 8> ChainedStores;
28171 StoreSDNode *STChain = St;
28172 // Intervals records which offsets from BaseIndex have been covered. In
28173   // the common case, every store writes to an address adjacent to the previous
28174   // one and is thus merged with the previous interval at insertion time.
28175
28176 using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
28177 IntervalMapHalfOpenInfo<int64_t>>;
28178 IMap::Allocator A;
28179 IMap Intervals(A);
28180
28181 // This holds the base pointer, index, and the offset in bytes from the base
28182 // pointer.
28183 const BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
28184
28185 // We must have a base and an offset.
28186 if (!BasePtr.getBase().getNode())
28187 return false;
28188
28189 // Do not handle stores to undef base pointers.
28190 if (BasePtr.getBase().isUndef())
28191 return false;
28192
28193 // Do not handle stores to opaque types
28194 if (St->getMemoryVT().isZeroSized())
28195 return false;
28196
28197 // BaseIndexOffset assumes that offsets are fixed-size, which
28198 // is not valid for scalable vectors where the offsets are
28199 // scaled by `vscale`, so bail out early.
28200 if (St->getMemoryVT().isScalableVT())
28201 return false;
28202
28203 // Add ST's interval.
28204 Intervals.insert(a: 0, b: (St->getMemoryVT().getSizeInBits() + 7) / 8,
28205 y: std::monostate{});
28206
28207 while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(Val: STChain->getChain())) {
28208 if (Chain->getMemoryVT().isScalableVector())
28209 return false;
28210
28211 // If the chain has more than one use, then we can't reorder the mem ops.
28212 if (!SDValue(Chain, 0)->hasOneUse())
28213 break;
28214 // TODO: Relax for unordered atomics (see D66309)
28215 if (!Chain->isSimple() || Chain->isIndexed())
28216 break;
28217
28218 // Find the base pointer and offset for this memory node.
28219 const BaseIndexOffset Ptr = BaseIndexOffset::match(N: Chain, DAG);
28220 // Check that the base pointer is the same as the original one.
28221 int64_t Offset;
28222 if (!BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset))
28223 break;
28224 int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
28225 // Make sure we don't overlap with other intervals by checking the ones to
28226 // the left or right before inserting.
28227 auto I = Intervals.find(x: Offset);
28228 // If there's a next interval, we should end before it.
28229 if (I != Intervals.end() && I.start() < (Offset + Length))
28230 break;
28231 // If there's a previous interval, we should start after it.
28232 if (I != Intervals.begin() && (--I).stop() <= Offset)
28233 break;
28234 Intervals.insert(a: Offset, b: Offset + Length, y: std::monostate{});
28235
28236 ChainedStores.push_back(Elt: Chain);
28237 STChain = Chain;
28238 }
28239
28240 // If we didn't find a chained store, exit.
28241 if (ChainedStores.empty())
28242 return false;
28243
28244 // Improve all chained stores (St and ChainedStores members) starting from
28245 // where the store chain ended and return single TokenFactor.
28246 SDValue NewChain = STChain->getChain();
28247 SmallVector<SDValue, 8> TFOps;
28248 for (unsigned I = ChainedStores.size(); I;) {
28249 StoreSDNode *S = ChainedStores[--I];
28250 SDValue BetterChain = FindBetterChain(N: S, OldChain: NewChain);
28251 S = cast<StoreSDNode>(Val: DAG.UpdateNodeOperands(
28252 N: S, Op1: BetterChain, Op2: S->getOperand(Num: 1), Op3: S->getOperand(Num: 2), Op4: S->getOperand(Num: 3)));
28253 TFOps.push_back(Elt: SDValue(S, 0));
28254 ChainedStores[I] = S;
28255 }
28256
28257 // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
28258 SDValue BetterChain = FindBetterChain(N: St, OldChain: NewChain);
28259 SDValue NewST;
28260 if (St->isTruncatingStore())
28261 NewST = DAG.getTruncStore(Chain: BetterChain, dl: SDLoc(St), Val: St->getValue(),
28262 Ptr: St->getBasePtr(), SVT: St->getMemoryVT(),
28263 MMO: St->getMemOperand());
28264 else
28265 NewST = DAG.getStore(Chain: BetterChain, dl: SDLoc(St), Val: St->getValue(),
28266 Ptr: St->getBasePtr(), MMO: St->getMemOperand());
28267
28268 TFOps.push_back(Elt: NewST);
28269
28270 // If we improved every element of TFOps, then we've lost the dependence on
28271 // NewChain to successors of St and we need to add it back to TFOps. Do so at
28272 // the beginning to keep relative order consistent with FindBetterChains.
28273 auto hasImprovedChain = [&](SDValue ST) -> bool {
28274 return ST->getOperand(Num: 0) != NewChain;
28275 };
28276 bool AddNewChain = llvm::all_of(Range&: TFOps, P: hasImprovedChain);
28277 if (AddNewChain)
28278 TFOps.insert(I: TFOps.begin(), Elt: NewChain);
28279
28280 SDValue TF = DAG.getTokenFactor(DL: SDLoc(STChain), Vals&: TFOps);
28281 CombineTo(N: St, Res: TF);
28282
28283 // Add TF and its operands to the worklist.
28284 AddToWorklist(N: TF.getNode());
28285 for (const SDValue &Op : TF->ops())
28286 AddToWorklist(N: Op.getNode());
28287 AddToWorklist(N: STChain);
28288 return true;
28289}
28290
28291bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
28292 if (OptLevel == CodeGenOptLevel::None)
28293 return false;
28294
28295 const BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
28296
28297 // We must have a base and an offset.
28298 if (!BasePtr.getBase().getNode())
28299 return false;
28300
28301 // Do not handle stores to undef base pointers.
28302 if (BasePtr.getBase().isUndef())
28303 return false;
28304
28305 // Directly improve a chain of disjoint stores starting at St.
28306 if (parallelizeChainedStores(St))
28307 return true;
28308
28309   // Improve St's Chain.
28310 SDValue BetterChain = FindBetterChain(N: St, OldChain: St->getChain());
28311 if (St->getChain() != BetterChain) {
28312 replaceStoreChain(ST: St, BetterChain);
28313 return true;
28314 }
28315 return false;
28316}
28317
28318/// This is the entry point for the file.
28319void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
28320 CodeGenOptLevel OptLevel) {
28321 /// This is the main entry point to this class.
28322 DAGCombiner(*this, AA, OptLevel).Run(AtLevel: Level);
28323}
28324
