//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// Describes the type of Parameters
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};

/// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};

/// Encapsulates information needed to describe a parameter.
///
/// The description of the parameter is not linked directly to
/// OpenMP or any other vector function description. This structure
/// is extendible to handle other paradigms that describe vector
/// functions and their parameters.
struct VFParameter {
  unsigned ParamPos;         // Parameter Position in Scalar Function.
  VFParamKind ParamKind;     // Kind of Parameter.
  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.

  // Comparison operator.
  bool operator==(const VFParameter &Other) const {
    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
                    Other.Alignment);
  }
};

/// Contains the information about the kind of vectorization
/// available.
///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
struct VFShape {
  unsigned VF;     // Vectorization factor.
  bool IsScalable; // True if the function is a scalable function.
  SmallVector<VFParameter, 8> Parameters; // List of parameter information.
  // Comparison operator.
  bool operator==(const VFShape &Other) const {
    return std::tie(VF, IsScalable, Parameters) ==
           std::tie(Other.VF, Other.IsScalable, Other.Parameters);
  }

  /// Update the parameter in position P.ParamPos to P.
  void updateParam(VFParameter P) {
    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
    Parameters[P.ParamPos] = P;
    assert(hasValidParameterList() && "Invalid parameter list");
  }

  // Retrieve the VFShape that can be used to map a (scalar) function to
  // itself, with VF = 1.
  static VFShape getScalarShape(const CallInst &CI) {
    return VFShape::get(CI, ElementCount::getFixed(1),
                        /*HasGlobalPredicate*/ false);
  }

  // Retrieve the basic vectorization shape of the function, where all
  // parameters are mapped to VFParamKind::Vector with \p EC
  // lanes. Specifies whether the function has a Global Predicate
  // argument via \p HasGlobalPred.
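  //
  // For example, a 2-lane fixed-width shape with a trailing global predicate
  // parameter could be built from a call site roughly as follows (a minimal
  // sketch; the variable name `Shape` is illustrative):
  //
  // \code
  //   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(2),
  //                                /*HasGlobalPred=*/true);
  // \endcode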
  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    SmallVector<VFParameter, 8> Parameters;
    for (unsigned I = 0; I < CI.arg_size(); ++I)
      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    if (HasGlobalPred)
      Parameters.push_back(
          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));

    return {EC.getKnownMinValue(), EC.isScalable(), Parameters};
  }
  /// Sanity check on the Parameters in the VFShape.
  bool hasValidParameterList() const;
};

/// Holds the VFShape for a specific scalar to vector function mapping.
struct VFInfo {
  VFShape Shape;          /// Classification of the vector function.
  std::string ScalarName; /// Scalar Function Name.
  std::string VectorName; /// Vector Function Name associated to this VFInfo.
  VFISAKind ISA;          /// Instruction Set Architecture.

  // Comparison operator.
  bool operator==(const VFInfo &Other) const {
    return std::tie(Shape, ScalarName, VectorName, ISA) ==
           std::tie(Other.Shape, Other.ScalarName, Other.VectorName, Other.ISA);
  }
};

namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";
/// Prefix for internal name redirection for vector functions that
/// tells the compiler to scalarize the call using the scalar name
/// of the function. For example, a mangled name like
/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
/// vectorizer to vectorize the scalar call `foo`, and to scalarize
/// it once vectorization is done.
static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
///   <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
///   _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///   https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param M -> Module used to retrieve information about the vector
/// function that cannot be retrieved from the mangled name. At the
/// moment, this parameter is needed only to retrieve the
/// Vectorization Factor of scalable vector functions from their
/// respective IR declarations.
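///
/// A minimal usage sketch (the mangled name below is illustrative, and `M` is
/// assumed to be the enclosing Module):
///
/// \code
///   Optional<VFInfo> Info =
///       VFABI::tryDemangleForVFABI("_ZGV_LLVM_N2v_foo(vector_foo)", M);
///   if (Info)
///     assert(Info->ScalarName == "foo" && Info->Shape.VF == 2);
/// \endcode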
Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName, const Module &M);

/// This routine mangles the given VectorName according to the LangRef
/// specification for the vector-function-abi-variant attribute and is specific
/// to the TLI mappings. It is the responsibility of the caller to make sure
/// that this is only used if all parameters in the vector function are of
/// vector type. The returned string holds the scalar-to-vector mapping:
///
///   _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
///
/// where:
///
/// <isa> = "_LLVM_"
/// <mask> = "N". Note: TLI does not support masked interfaces.
/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
///          field of the `VecDesc` struct. If the number of lanes is scalable
///          then 'x' is printed instead.
/// <vparams> = "v", repeated numArgs times.
/// <scalarname> = the name of the scalar function.
/// <vectorname> = the name of the vector function.
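///
/// For example, per the format above, mapping a one-argument scalar `sin` to a
/// 2-lane vector variant named `vec_sin` (an illustrative name) yields the
/// string "_ZGV_LLVM_N2v_sin(vec_sin)":
///
/// \code
///   std::string Mangled = VFABI::mangleTLIVectorName(
///       "vec_sin", "sin", /*numArgs=*/1, ElementCount::getFixed(2));
/// \endcode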
std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
                                unsigned numArgs, ElementCount VF);

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. If the CI does not contain the
/// vector-function-abi-variant attribute, we return without populating
/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
/// the presence of the attribute (see InjectTLIMappings).
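///
/// A minimal sketch of reading the variant names off a call site:
///
/// \code
///   SmallVector<std::string, 8> Variants;
///   VFABI::getVectorVariantNames(CI, Variants);
///   // Each entry is a VFABI mangled name, e.g. "_ZGV_LLVM_N2v_foo(vec_foo)".
/// \endcode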
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated to a
/// scalar CallInst.
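///
/// A typical query, sketched assuming the call \p CI has VFABI variant
/// mappings attached (see InjectTLIMappings):
///
/// \code
///   VFDatabase DB(CI);
///   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(4),
///                                /*HasGlobalPred=*/false);
///   if (Function *VecF = DB.getVectorizedFunction(Shape)) {
///     // A 4-lane vector variant of the callee is available.
///   }
/// \endcode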
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar to vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated to the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated to the rules of
  /// the Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const Optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape.hasValue() && (Shape.getValue().ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape.getValue().VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(Shape.getValue());
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated to the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};

template <typename T> class ArrayRef;
class DemandedBits;
class GetElementPtrInst;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
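///
/// For example, converting `i32` with a fixed element count of 4 yields
/// `<4 x i32>` (a sketch, with `Ctx` an in-scope LLVMContext):
///
/// \code
///   Type *VecTy = ToVectorTy(Type::getInt32Ty(Ctx), ElementCount::getFixed(4));
/// \endcode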
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// hasVectorInstrinsicScalarOpd) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx);

/// Returns the intrinsic ID for the given call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if no such mapping is found, it returns
/// Intrinsic::not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);

/// If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// If a value has only one user that is a CastInst, return it.
Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);

/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constant vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start. The mask contains \p NumInts integers and is padded
/// with \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];     // Member of index 0
///          b = A[i+1];   // Member of index 1
///          d = A[i+3];   // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;   // Member of index 0
///          A[i+1] = b;   // Member of index 1
///          A[i+2] = c;   // Member of index 2
///          A[i+3] = d;   // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is relative to the leader and it could
  /// be negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special
    // values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.find(Key) != Members.end())
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      Optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index
  ///
  /// \returns nullptr if it contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a'la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore cannot be a group of stores,
    // and it can't be a reversed access, because such groups get invalidated.
    assert(!getMember(0)->mayWriteToMemory() &&
           "Group should have been invalidated");
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
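///
/// A typical flow, sketched assuming the PSE/DT/LI/LAI analyses have already
/// been computed for the loop under analysis:
///
/// \code
///   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
///   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
///   for (auto *Group : IAI.getInterleaveGroups())
///     if (Group->requiresScalarEpilogue())
///       ; // Plan at least one scalar epilogue iteration.
/// \endcode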
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in loop will be predicated
  /// contrary to original assumption. Although we currently prevent group
  /// formation for predicated accesses, we may be able to relax this limitation
  /// in the future once we handle more complicated blocks. Returns true if any
  /// groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr does not belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Align.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    InterleaveGroups.erase(Group);
    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const ValueToValueMap &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return Dependences.find(Src) == Dependences.end() ||
           !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    auto *Deps = LAI->getDepChecker().getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
  }
};

} // llvm namespace

#endif