//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Replacing globalized device memory with stack memory.
// - Replacing globalized device memory with shared memory.
// - Parallel region merging.
// - Transforming generic-mode device kernels to SPMD mode.
// - Specializing the state machine for generic-mode device kernels.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"

#include <algorithm>
#include <optional>
#include <string>

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging",
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool>
    DisableInternalization("openmp-opt-disable-internalization",
                           cl::desc("Disable function internalization."),
                           cl::Hidden, cl::init(false));

static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
                                     cl::init(false), cl::Hidden);
static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptDeglobalization(
    "openmp-opt-disable-deglobalization",
    cl::desc("Disable OpenMP optimizations involving deglobalization."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptSPMDization(
    "openmp-opt-disable-spmdization",
    cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptFolding(
    "openmp-opt-disable-folding",
    cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
    "openmp-opt-disable-state-machine-rewrite",
    cl::desc("Disable OpenMP optimizations that replace the state machine."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptBarrierElimination(
    "openmp-opt-disable-barrier-elimination",
    cl::desc("Disable OpenMP optimizations that eliminate barriers."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> PrintModuleAfterOptimizations(
    "openmp-opt-print-module-after",
    cl::desc("Print the current module after OpenMP optimizations."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> PrintModuleBeforeOptimizations(
    "openmp-opt-print-module-before",
    cl::desc("Print the current module before OpenMP optimizations."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> AlwaysInlineDeviceFunctions(
    "openmp-opt-inline-device",
    cl::desc("Inline all applicable functions on the device."), cl::Hidden,
    cl::init(false));

static cl::opt<bool>
    EnableVerboseRemarks("openmp-opt-verbose-remarks",
                         cl::desc("Enables more verbose remarks."), cl::Hidden,
                         cl::init(false));

static cl::opt<unsigned>
    SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
                          cl::desc("Maximal number of attributor iterations."),
                          cl::init(256));

static cl::opt<unsigned>
    SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
                      cl::desc("Maximum amount of shared memory to use."),
                      cl::init(std::numeric_limits<unsigned>::max()));
150
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumNonOpenMPTargetRegionKernels,
          "Number of non-OpenMP target region kernels identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183
184#if !defined(NDEBUG)
185static constexpr auto TAG = "[" DEBUG_TYPE "]";
186#endif
187
188namespace KernelInfo {
189
190// struct ConfigurationEnvironmentTy {
191// uint8_t UseGenericStateMachine;
192// uint8_t MayUseNestedParallelism;
193// llvm::omp::OMPTgtExecModeFlags ExecMode;
194// int32_t MinThreads;
195// int32_t MaxThreads;
196// int32_t MinTeams;
197// int32_t MaxTeams;
198// };
199
200// struct DynamicEnvironmentTy {
201// uint16_t DebugIndentionLevel;
202// };
203
204// struct KernelEnvironmentTy {
205// ConfigurationEnvironmentTy Configuration;
206// IdentTy *Ident;
207// DynamicEnvironmentTy *DynamicEnv;
208// };
209
210#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
211 constexpr const unsigned MEMBER##Idx = IDX;
212
213KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214KERNEL_ENVIRONMENT_IDX(Ident, 1)
215
216#undef KERNEL_ENVIRONMENT_IDX
217
218#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
219 constexpr const unsigned MEMBER##Idx = IDX;
220
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228
229#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230
231#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
232 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
234 }
235
236KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238
239#undef KERNEL_ENVIRONMENT_GETTER
240
241#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
242 ConstantInt *get##MEMBER##FromKernelEnvironment( \
243 ConstantStruct *KernelEnvC) { \
244 ConstantStruct *ConfigC = \
245 getConfigurationFromKernelEnvironment(KernelEnvC); \
246 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
247 }
248
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256
257#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
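
// Illustrative use of the generated accessors (a sketch, not taken verbatim
// from elsewhere in this file; `KernelInitCB` stands for the
// __kmpc_target_init call site of a kernel): the constant environment and
// individual configuration fields can be read as follows.
//   ConstantStruct *KernelEnvC =
//       KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
//   ConstantInt *ExecModeC =
//       KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
//   ConstantInt *UseStateMachineC =
//       KernelInfo::getUseGenericStateMachineFromKernelEnvironment(KernelEnvC);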
258
GlobalVariable *
getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
  constexpr const int InitKernelEnvironmentArgNo = 0;
  return cast<GlobalVariable>(
      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
          ->stripPointerCasts());
}

ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
  GlobalVariable *KernelEnvGV =
      getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
  return cast<ConstantStruct>(KernelEnvGV->getInitializer());
}
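
// For reference, the call these helpers inspect looks roughly like the
// following in device IR (names and the trailing arguments are illustrative;
// the only property relied on above is that argument 0 is the kernel
// environment global):
//   %r = call i32 @__kmpc_target_init(ptr @kernel_environment, ...)
// where @kernel_environment is a constant global whose initializer has the
// KernelEnvironmentTy layout sketched at the top of this namespace.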
272} // namespace KernelInfo
273
274namespace {
275
276struct AAHeapToShared;
277
278struct AAICVTracker;
279
280/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281/// Attributor runs.
282struct OMPInformationCache : public InformationCache {
283 OMPInformationCache(Module &M, AnalysisGetter &AG,
284 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285 bool OpenMPPostLink)
286 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287 OpenMPPostLink(OpenMPPostLink) {
288
289 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(M&: OMPBuilder.M);
290 OMPBuilder.initialize();
291 initializeRuntimeFunctions(M);
292 initializeInternalControlVars();
293 }
294
295 /// Generic information that describes an internal control variable.
296 struct InternalControlVarInfo {
297 /// The kind, as described by InternalControlVar enum.
298 InternalControlVar Kind;
299
300 /// The name of the ICV.
301 StringRef Name;
302
303 /// Environment variable associated with this ICV.
304 StringRef EnvVarName;
305
306 /// Initial value kind.
307 ICVInitValue InitKind;
308
309 /// Initial value.
310 ConstantInt *InitValue;
311
312 /// Setter RTL function associated with this ICV.
313 RuntimeFunction Setter;
314
315 /// Getter RTL function associated with this ICV.
316 RuntimeFunction Getter;
317
    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
320 };
321
  /// Generic information that describes a runtime function.
323 struct RuntimeFunctionInfo {
324
325 /// The kind, as described by the RuntimeFunction enum.
326 RuntimeFunction Kind;
327
328 /// The name of the function.
329 StringRef Name;
330
331 /// Flag to indicate a variadic function.
332 bool IsVarArg;
333
334 /// The return type of the function.
335 Type *ReturnType;
336
337 /// The argument types of the function.
338 SmallVector<Type *, 8> ArgumentTypes;
339
340 /// The declaration if available.
341 Function *Declaration = nullptr;
342
343 /// Uses of this runtime function per function containing the use.
344 using UseVector = SmallVector<Use *, 16>;
345
346 /// Clear UsesMap for runtime function.
347 void clearUsesMap() { UsesMap.clear(); }
348
349 /// Boolean conversion that is true if the runtime function was found.
350 operator bool() const { return Declaration; }
351
352 /// Return the vector of uses in function \p F.
353 UseVector &getOrCreateUseVector(Function *F) {
354 std::shared_ptr<UseVector> &UV = UsesMap[F];
355 if (!UV)
356 UV = std::make_shared<UseVector>();
357 return *UV;
358 }
359
360 /// Return the vector of uses in function \p F or `nullptr` if there are
361 /// none.
362 const UseVector *getUseVector(Function &F) const {
363 auto I = UsesMap.find(Val: &F);
364 if (I != UsesMap.end())
365 return I->second.get();
366 return nullptr;
367 }
368
369 /// Return how many functions contain uses of this runtime function.
370 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
371
372 /// Return the number of arguments (or the minimal number for variadic
373 /// functions).
374 size_t getNumArgs() const { return ArgumentTypes.size(); }
375
376 /// Run the callback \p CB on each use and forget the use if the result is
377 /// true. The callback will be fed the function in which the use was
378 /// encountered as second argument.
379 void foreachUse(SmallVectorImpl<Function *> &SCC,
380 function_ref<bool(Use &, Function &)> CB) {
381 for (Function *F : SCC)
382 foreachUse(CB, F);
383 }
384
385 /// Run the callback \p CB on each use within the function \p F and forget
386 /// the use if the result is true.
387 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
388 SmallVector<unsigned, 8> ToBeDeleted;
389 ToBeDeleted.clear();
390
391 unsigned Idx = 0;
392 UseVector &UV = getOrCreateUseVector(F);
393
394 for (Use *U : UV) {
395 if (CB(*U, *F))
396 ToBeDeleted.push_back(Elt: Idx);
397 ++Idx;
398 }
399
400 // Remove the to-be-deleted indices in reverse order as prior
401 // modifications will not modify the smaller indices.
402 while (!ToBeDeleted.empty()) {
403 unsigned Idx = ToBeDeleted.pop_back_val();
404 UV[Idx] = UV.back();
405 UV.pop_back();
406 }
407 }
408
409 private:
410 /// Map from functions to all uses of this runtime function contained in
411 /// them.
412 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
413
414 public:
415 /// Iterators for the uses of this runtime function.
416 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
417 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
418 };
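
  // Illustrative use of RuntimeFunctionInfo (a sketch mirroring patterns used
  // later in this file): look up the info for a runtime function and visit
  // all of its call sites in the SCC.
  //   auto &RFI = RFIs[OMPRTL___kmpc_fork_call];
  //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
  //     // Inspect or rewrite the use; return true to forget it.
  //     return false;
  //   });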
419
420 /// An OpenMP-IR-Builder instance
421 OpenMPIRBuilder OMPBuilder;
422
423 /// Map from runtime function kind to the runtime function description.
424 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
425 RuntimeFunction::OMPRTL___last>
426 RFIs;
427
428 /// Map from function declarations/definitions to their runtime enum type.
429 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
430
431 /// Map from ICV kind to the ICV description.
432 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
433 InternalControlVar::ICV___last>
434 ICVs;
435
436 /// Helper to initialize all internal control variable information for those
437 /// defined in OMPKinds.def.
438 void initializeInternalControlVars() {
439#define ICV_RT_SET(_Name, RTL) \
440 { \
441 auto &ICV = ICVs[_Name]; \
442 ICV.Setter = RTL; \
443 }
444#define ICV_RT_GET(Name, RTL) \
445 { \
446 auto &ICV = ICVs[Name]; \
447 ICV.Getter = RTL; \
448 }
449#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
450 { \
451 auto &ICV = ICVs[Enum]; \
452 ICV.Name = _Name; \
453 ICV.Kind = Enum; \
454 ICV.InitKind = Init; \
455 ICV.EnvVarName = _EnvVarName; \
456 switch (ICV.InitKind) { \
457 case ICV_IMPLEMENTATION_DEFINED: \
458 ICV.InitValue = nullptr; \
459 break; \
460 case ICV_ZERO: \
461 ICV.InitValue = ConstantInt::get( \
462 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
463 break; \
464 case ICV_FALSE: \
465 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
466 break; \
467 case ICV_LAST: \
468 break; \
469 } \
470 }
471#include "llvm/Frontend/OpenMP/OMPKinds.def"
472 }
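
  // As an example (assuming the usual entry in OMPKinds.def), the nthreads
  // ICV ends up with ICVs[ICV_nthreads].Name == "nthreads", an EnvVarName of
  // "OMP_NUM_THREADS", and an implementation-defined initial value, i.e.,
  // InitValue == nullptr.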
473
474 /// Returns true if the function declaration \p F matches the runtime
475 /// function types, that is, return type \p RTFRetType, and argument types
476 /// \p RTFArgTypes.
477 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
478 SmallVector<Type *, 8> &RTFArgTypes) {
479 // TODO: We should output information to the user (under debug output
480 // and via remarks).
481
482 if (!F)
483 return false;
484 if (F->getReturnType() != RTFRetType)
485 return false;
486 if (F->arg_size() != RTFArgTypes.size())
487 return false;
488
489 auto *RTFTyIt = RTFArgTypes.begin();
490 for (Argument &Arg : F->args()) {
491 if (Arg.getType() != *RTFTyIt)
492 return false;
493
494 ++RTFTyIt;
495 }
496
497 return true;
498 }
499
500 // Helper to collect all uses of the declaration in the UsesMap.
501 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
502 unsigned NumUses = 0;
503 if (!RFI.Declaration)
504 return NumUses;
505 OMPBuilder.addAttributes(FnID: RFI.Kind, Fn&: *RFI.Declaration);
506
507 if (CollectStats) {
508 NumOpenMPRuntimeFunctionsIdentified += 1;
509 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
510 }
511
512 // TODO: We directly convert uses into proper calls and unknown uses.
513 for (Use &U : RFI.Declaration->uses()) {
514 if (Instruction *UserI = dyn_cast<Instruction>(Val: U.getUser())) {
515 if (!CGSCC || CGSCC->empty() || CGSCC->contains(key: UserI->getFunction())) {
516 RFI.getOrCreateUseVector(F: UserI->getFunction()).push_back(Elt: &U);
517 ++NumUses;
518 }
519 } else {
520 RFI.getOrCreateUseVector(F: nullptr).push_back(Elt: &U);
521 ++NumUses;
522 }
523 }
524 return NumUses;
525 }
526
527 // Helper function to recollect uses of a runtime function.
528 void recollectUsesForFunction(RuntimeFunction RTF) {
529 auto &RFI = RFIs[RTF];
530 RFI.clearUsesMap();
531 collectUses(RFI, /*CollectStats*/ false);
532 }
533
534 // Helper function to recollect uses of all runtime functions.
535 void recollectUses() {
536 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
537 recollectUsesForFunction(RTF: static_cast<RuntimeFunction>(Idx));
538 }
539
540 // Helper function to inherit the calling convention of the function callee.
541 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
542 if (Function *Fn = dyn_cast<Function>(Val: Callee.getCallee()))
543 CI->setCallingConv(Fn->getCallingConv());
544 }
545
  // Helper function to determine if it's legal to create a call to the runtime
  // functions.
  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
    // We can always emit calls if we haven't yet linked in the runtime.
    if (!OpenMPPostLink)
      return true;

    // Once the runtime has already been linked in, we cannot emit calls to
    // any undefined functions.
    for (RuntimeFunction Fn : Fns) {
      RuntimeFunctionInfo &RFI = RFIs[Fn];

      if (RFI.Declaration && RFI.Declaration->isDeclaration())
        return false;
    }
    return true;
  }
563
564 /// Helper to initialize all runtime function information for those defined
565 /// in OpenMPKinds.def.
566 void initializeRuntimeFunctions(Module &M) {
567
568 // Helper macros for handling __VA_ARGS__ in OMP_RTL
569#define OMP_TYPE(VarName, ...) \
570 Type *VarName = OMPBuilder.VarName; \
571 (void)VarName;
572
573#define OMP_ARRAY_TYPE(VarName, ...) \
574 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
575 (void)VarName##Ty; \
576 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
577 (void)VarName##PtrTy;
578
579#define OMP_FUNCTION_TYPE(VarName, ...) \
580 FunctionType *VarName = OMPBuilder.VarName; \
581 (void)VarName; \
582 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
583 (void)VarName##Ptr;
584
585#define OMP_STRUCT_TYPE(VarName, ...) \
586 StructType *VarName = OMPBuilder.VarName; \
587 (void)VarName; \
588 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
589 (void)VarName##Ptr;
590
591#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
592 { \
593 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
594 Function *F = M.getFunction(_Name); \
595 RTLFunctions.insert(F); \
596 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
597 RuntimeFunctionIDMap[F] = _Enum; \
598 auto &RFI = RFIs[_Enum]; \
599 RFI.Kind = _Enum; \
600 RFI.Name = _Name; \
601 RFI.IsVarArg = _IsVarArg; \
602 RFI.ReturnType = OMPBuilder._ReturnType; \
603 RFI.ArgumentTypes = std::move(ArgsTypes); \
604 RFI.Declaration = F; \
605 unsigned NumUses = collectUses(RFI); \
606 (void)NumUses; \
607 LLVM_DEBUG({ \
608 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
609 << " found\n"; \
610 if (RFI.Declaration) \
611 dbgs() << TAG << "-> got " << NumUses << " uses in " \
612 << RFI.getNumFunctionsWithUses() \
613 << " different functions.\n"; \
614 }); \
615 } \
616 }
617#include "llvm/Frontend/OpenMP/OMPKinds.def"
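
    // For example (the exact prototype lives in OMPKinds.def), the entry for
    // __kmpc_barrier fills in RFIs[OMPRTL___kmpc_barrier] with
    // Name == "__kmpc_barrier", its declared return and argument types, and
    // all call sites found in the module, grouped per containing function by
    // collectUses().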
618
619 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
620 // functions, except if `optnone` is present.
621 if (isOpenMPDevice(M)) {
622 for (Function &F : M) {
623 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
624 if (F.hasFnAttribute(Attribute::NoInline) &&
625 F.getName().starts_with(Prefix) &&
626 !F.hasFnAttribute(Attribute::OptimizeNone))
627 F.removeFnAttr(Attribute::NoInline);
628 }
629 }
630
631 // TODO: We should attach the attributes defined in OMPKinds.def.
632 }
633
  /// Collection of known OpenMP runtime functions.
635 DenseSet<const Function *> RTLFunctions;
636
637 /// Indicates if we have already linked in the OpenMP device library.
638 bool OpenMPPostLink = false;
639};
640
641template <typename Ty, bool InsertInvalidates = true>
642struct BooleanStateWithSetVector : public BooleanState {
643 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
644 bool insert(const Ty &Elem) {
645 if (InsertInvalidates)
646 BooleanState::indicatePessimisticFixpoint();
647 return Set.insert(Elem);
648 }
649
650 const Ty &operator[](int Idx) const { return Set[Idx]; }
651 bool operator==(const BooleanStateWithSetVector &RHS) const {
652 return BooleanState::operator==(R: RHS) && Set == RHS.Set;
653 }
654 bool operator!=(const BooleanStateWithSetVector &RHS) const {
655 return !(*this == RHS);
656 }
657
658 bool empty() const { return Set.empty(); }
659 size_t size() const { return Set.size(); }
660
661 /// "Clamp" this state with \p RHS.
662 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
663 BooleanState::operator^=(R: RHS);
664 Set.insert(RHS.Set.begin(), RHS.Set.end());
665 return *this;
666 }
667
668private:
669 /// A set to keep track of elements.
670 SetVector<Ty> Set;
671
672public:
673 typename decltype(Set)::iterator begin() { return Set.begin(); }
674 typename decltype(Set)::iterator end() { return Set.end(); }
675 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
676 typename decltype(Set)::const_iterator end() const { return Set.end(); }
677};
678
679template <typename Ty, bool InsertInvalidates = true>
680using BooleanStateWithPtrSetVector =
681 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
682
683struct KernelInfoState : AbstractState {
684 /// Flag to track if we reached a fixpoint.
685 bool IsAtFixpoint = false;
686
687 /// The parallel regions (identified by the outlined parallel functions) that
688 /// can be reached from the associated function.
689 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
690 ReachedKnownParallelRegions;
691
692 /// State to track what parallel region we might reach.
693 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
694
  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
699
700 /// The __kmpc_target_init call in this kernel, if any. If we find more than
701 /// one we abort as the kernel is malformed.
702 CallBase *KernelInitCB = nullptr;
703
  /// The constant kernel environment as taken from and passed to
  /// __kmpc_target_init.
  ConstantStruct *KernelEnvC = nullptr;
707
708 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
709 /// one we abort as the kernel is malformed.
710 CallBase *KernelDeinitCB = nullptr;
711
712 /// Flag to indicate if the associated function is a kernel entry.
713 bool IsKernelEntry = false;
714
715 /// State to track what kernel entries can reach the associated function.
716 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
717
718 /// State to indicate if we can track parallel level of the associated
719 /// function. We will give up tracking if we encounter unknown caller or the
720 /// caller is __kmpc_parallel_51.
721 BooleanStateWithSetVector<uint8_t> ParallelLevels;
722
  /// Flag that indicates if the kernel has nested parallelism.
  bool NestedParallelism = false;
725
726 /// Abstract State interface
727 ///{
728
729 KernelInfoState() = default;
730 KernelInfoState(bool BestState) {
731 if (!BestState)
732 indicatePessimisticFixpoint();
733 }
734
735 /// See AbstractState::isValidState(...)
736 bool isValidState() const override { return true; }
737
738 /// See AbstractState::isAtFixpoint(...)
739 bool isAtFixpoint() const override { return IsAtFixpoint; }
740
741 /// See AbstractState::indicatePessimisticFixpoint(...)
742 ChangeStatus indicatePessimisticFixpoint() override {
743 IsAtFixpoint = true;
744 ParallelLevels.indicatePessimisticFixpoint();
745 ReachingKernelEntries.indicatePessimisticFixpoint();
746 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
747 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
748 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
749 NestedParallelism = true;
750 return ChangeStatus::CHANGED;
751 }
752
753 /// See AbstractState::indicateOptimisticFixpoint(...)
754 ChangeStatus indicateOptimisticFixpoint() override {
755 IsAtFixpoint = true;
756 ParallelLevels.indicateOptimisticFixpoint();
757 ReachingKernelEntries.indicateOptimisticFixpoint();
758 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
759 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
760 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
761 return ChangeStatus::UNCHANGED;
762 }
763
764 /// Return the assumed state
765 KernelInfoState &getAssumed() { return *this; }
766 const KernelInfoState &getAssumed() const { return *this; }
767
768 bool operator==(const KernelInfoState &RHS) const {
769 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
770 return false;
771 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
772 return false;
773 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
774 return false;
775 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
776 return false;
777 if (ParallelLevels != RHS.ParallelLevels)
778 return false;
779 if (NestedParallelism != RHS.NestedParallelism)
780 return false;
781 return true;
782 }
783
  /// Returns true if this kernel might contain an OpenMP parallel region.
785 bool mayContainParallelRegion() {
786 return !ReachedKnownParallelRegions.empty() ||
787 !ReachedUnknownParallelRegions.empty();
788 }
789
790 /// Return empty set as the best state of potential values.
791 static KernelInfoState getBestState() { return KernelInfoState(true); }
792
793 static KernelInfoState getBestState(KernelInfoState &KIS) {
794 return getBestState();
795 }
796
797 /// Return full set as the worst state of potential values.
798 static KernelInfoState getWorstState() { return KernelInfoState(false); }
799
800 /// "Clamp" this state with \p KIS.
801 KernelInfoState operator^=(const KernelInfoState &KIS) {
802 // Do not merge two different _init and _deinit call sites.
803 if (KIS.KernelInitCB) {
804 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
805 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
806 "assumptions.");
807 KernelInitCB = KIS.KernelInitCB;
808 }
809 if (KIS.KernelDeinitCB) {
810 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
811 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
812 "assumptions.");
813 KernelDeinitCB = KIS.KernelDeinitCB;
814 }
815 if (KIS.KernelEnvC) {
816 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
817 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818 "assumptions.");
819 KernelEnvC = KIS.KernelEnvC;
820 }
821 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
822 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
823 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
824 NestedParallelism |= KIS.NestedParallelism;
825 return *this;
826 }
827
828 KernelInfoState operator&=(const KernelInfoState &KIS) {
829 return (*this ^= KIS);
830 }
831
832 ///}
833};
834
/// Used to map the values physically (in the IR) stored in an offload
/// array to a vector in memory.
837struct OffloadArray {
838 /// Physical array (in the IR).
839 AllocaInst *Array = nullptr;
840 /// Mapped values.
841 SmallVector<Value *, 8> StoredValues;
842 /// Last stores made in the offload array.
843 SmallVector<StoreInst *, 8> LastAccesses;
844
845 OffloadArray() = default;
846
847 /// Initializes the OffloadArray with the values stored in \p Array before
848 /// instruction \p Before is reached. Returns false if the initialization
849 /// fails.
850 /// This MUST be used immediately after the construction of the object.
851 bool initialize(AllocaInst &Array, Instruction &Before) {
852 if (!Array.getAllocatedType()->isArrayTy())
853 return false;
854
855 if (!getValues(Array, Before))
856 return false;
857
858 this->Array = &Array;
859 return true;
860 }
861
862 static const unsigned DeviceIDArgNum = 1;
863 static const unsigned BasePtrsArgNum = 3;
864 static const unsigned PtrsArgNum = 4;
865 static const unsigned SizesArgNum = 5;
866
867private:
868 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
869 /// \p Array, leaving StoredValues with the values stored before the
870 /// instruction \p Before is reached.
871 bool getValues(AllocaInst &Array, Instruction &Before) {
872 // Initialize container.
873 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
874 StoredValues.assign(NumElts: NumValues, Elt: nullptr);
875 LastAccesses.assign(NumElts: NumValues, Elt: nullptr);
876
877 // TODO: This assumes the instruction \p Before is in the same
878 // BasicBlock as Array. Make it general, for any control flow graph.
879 BasicBlock *BB = Array.getParent();
880 if (BB != Before.getParent())
881 return false;
882
883 const DataLayout &DL = Array.getModule()->getDataLayout();
884 const unsigned int PointerSize = DL.getPointerSize();
885
886 for (Instruction &I : *BB) {
887 if (&I == &Before)
888 break;
889
890 if (!isa<StoreInst>(Val: &I))
891 continue;
892
893 auto *S = cast<StoreInst>(Val: &I);
894 int64_t Offset = -1;
895 auto *Dst =
896 GetPointerBaseWithConstantOffset(Ptr: S->getPointerOperand(), Offset, DL);
897 if (Dst == &Array) {
898 int64_t Idx = Offset / PointerSize;
899 StoredValues[Idx] = getUnderlyingObject(V: S->getValueOperand());
900 LastAccesses[Idx] = S;
901 }
902 }
903
904 return isFilled();
905 }
906
907 /// Returns true if all values in StoredValues and
908 /// LastAccesses are not nullptrs.
909 bool isFilled() {
910 const unsigned NumValues = StoredValues.size();
911 for (unsigned I = 0; I < NumValues; ++I) {
912 if (!StoredValues[I] || !LastAccesses[I])
913 return false;
914 }
915
916 return true;
917 }
918};
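
// Illustrative IR shape that OffloadArray::initialize recognizes (a sketch;
// value names are made up):
//   %offload_baseptrs = alloca [2 x ptr]
//   %gep0 = getelementptr inbounds [2 x ptr], ptr %offload_baseptrs, i32 0, i32 0
//   store ptr %a, ptr %gep0
//   %gep1 = getelementptr inbounds [2 x ptr], ptr %offload_baseptrs, i32 0, i32 1
//   store ptr %b, ptr %gep1
//   call void @__tgt_target_data_begin_mapper(..., ptr %offload_baseptrs, ...)
// getValues() walks the block up to the runtime call and records the
// underlying objects of %a and %b in StoredValues and the two stores in
// LastAccesses.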
919
920struct OpenMPOpt {
921
922 using OptimizationRemarkGetter =
923 function_ref<OptimizationRemarkEmitter &(Function *)>;
924
925 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
926 OptimizationRemarkGetter OREGetter,
927 OMPInformationCache &OMPInfoCache, Attributor &A)
928 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
929 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
930
931 /// Check if any remarks are enabled for openmp-opt
932 bool remarksEnabled() {
933 auto &Ctx = M.getContext();
934 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
935 }
936
937 /// Run all OpenMP optimizations on the underlying SCC.
938 bool run(bool IsModulePass) {
939 if (SCC.empty())
940 return false;
941
942 bool Changed = false;
943
944 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
945 << " functions\n");
946
947 if (IsModulePass) {
948 Changed |= runAttributor(IsModulePass);
949
950 // Recollect uses, in case Attributor deleted any.
951 OMPInfoCache.recollectUses();
952
953 // TODO: This should be folded into buildCustomStateMachine.
954 Changed |= rewriteDeviceCodeStateMachine();
955
956 if (remarksEnabled())
957 analysisGlobalization();
958 } else {
959 if (PrintICVValues)
960 printICVs();
961 if (PrintOpenMPKernels)
962 printKernels();
963
964 Changed |= runAttributor(IsModulePass);
965
966 // Recollect uses, in case Attributor deleted any.
967 OMPInfoCache.recollectUses();
968
969 Changed |= deleteParallelRegions();
970
971 if (HideMemoryTransferLatency)
972 Changed |= hideMemTransfersLatency();
973 Changed |= deduplicateRuntimeCalls();
974 if (EnableParallelRegionMerging) {
975 if (mergeParallelRegions()) {
976 deduplicateRuntimeCalls();
977 Changed = true;
978 }
979 }
980 }
981
982 if (OMPInfoCache.OpenMPPostLink)
983 Changed |= removeRuntimeSymbols();
984
985 return Changed;
986 }
987
988 /// Print initial ICV values for testing.
989 /// FIXME: This should be done from the Attributor once it is added.
990 void printICVs() const {
991 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
992 ICV_proc_bind};
993
994 for (Function *F : SCC) {
995 for (auto ICV : ICVs) {
996 auto ICVInfo = OMPInfoCache.ICVs[ICV];
997 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
998 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
999 << " Value: "
1000 << (ICVInfo.InitValue
1001 ? toString(I: ICVInfo.InitValue->getValue(), Radix: 10, Signed: true)
1002 : "IMPLEMENTATION_DEFINED");
1003 };
1004
1005 emitRemark<OptimizationRemarkAnalysis>(F, RemarkName: "OpenMPICVTracker", RemarkCB&: Remark);
1006 }
1007 }
1008 }
1009
1010 /// Print OpenMP GPU kernels for testing.
1011 void printKernels() const {
1012 for (Function *F : SCC) {
1013 if (!omp::isOpenMPKernel(Fn&: *F))
1014 continue;
1015
1016 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1017 return ORA << "OpenMP GPU kernel "
1018 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1019 };
1020
1021 emitRemark<OptimizationRemarkAnalysis>(F, RemarkName: "OpenMPGPU", RemarkCB&: Remark);
1022 }
1023 }
1024
  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given, it has to be the callee or a nullptr is returned.
1027 static CallInst *getCallIfRegularCall(
1028 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1029 CallInst *CI = dyn_cast<CallInst>(Val: U.getUser());
1030 if (CI && CI->isCallee(U: &U) && !CI->hasOperandBundles() &&
1031 (!RFI ||
1032 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1033 return CI;
1034 return nullptr;
1035 }
1036
  /// Return the call if \p V is a regular call. If \p RFI is given, it has to
  /// be the callee or a nullptr is returned.
1039 static CallInst *getCallIfRegularCall(
1040 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041 CallInst *CI = dyn_cast<CallInst>(Val: &V);
1042 if (CI && !CI->hasOperandBundles() &&
1043 (!RFI ||
1044 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045 return CI;
1046 return nullptr;
1047 }
1048
1049private:
1050 /// Merge parallel regions when it is safe.
1051 bool mergeParallelRegions() {
1052 const unsigned CallbackCalleeOperand = 2;
1053 const unsigned CallbackFirstArgOperand = 3;
1054 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1055
1056 // Check if there are any __kmpc_fork_call calls to merge.
1057 OMPInformationCache::RuntimeFunctionInfo &RFI =
1058 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1059
1060 if (!RFI.Declaration)
1061 return false;
1062
1063 // Unmergable calls that prevent merging a parallel region.
1064 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1067 };
1068
1069 bool Changed = false;
1070 LoopInfo *LI = nullptr;
1071 DominatorTree *DT = nullptr;
1072
1073 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1074
1075 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1076 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1078 BasicBlock *CGEndBB =
1079 SplitBlock(Old: CGStartBB, SplitPt: &*CodeGenIP.getPoint(), DT, LI);
1080 assert(StartBB != nullptr && "StartBB should not be null");
1081 CGStartBB->getTerminator()->setSuccessor(Idx: 0, BB: StartBB);
1082 assert(EndBB != nullptr && "EndBB should not be null");
1083 EndBB->getTerminator()->setSuccessor(Idx: 0, BB: CGEndBB);
1084 };
1085
1086 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1087 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1088 ReplacementValue = &Inner;
1089 return CodeGenIP;
1090 };
1091
1092 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1093
1094 /// Create a sequential execution region within a merged parallel region,
1095 /// encapsulated in a master construct with a barrier for synchronization.
1096 auto CreateSequentialRegion = [&](Function *OuterFn,
1097 BasicBlock *OuterPredBB,
1098 Instruction *SeqStartI,
1099 Instruction *SeqEndI) {
1100 // Isolate the instructions of the sequential region to a separate
1101 // block.
1102 BasicBlock *ParentBB = SeqStartI->getParent();
1103 BasicBlock *SeqEndBB =
1104 SplitBlock(Old: ParentBB, SplitPt: SeqEndI->getNextNode(), DT, LI);
1105 BasicBlock *SeqAfterBB =
1106 SplitBlock(Old: SeqEndBB, SplitPt: &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1107 BasicBlock *SeqStartBB =
1108 SplitBlock(Old: ParentBB, SplitPt: SeqStartI, DT, LI, MSSAU: nullptr, BBName: "seq.par.merged");
1109
1110 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1111 "Expected a different CFG");
1112 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1113 ParentBB->getTerminator()->eraseFromParent();
1114
1115 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1117 BasicBlock *CGEndBB =
1118 SplitBlock(Old: CGStartBB, SplitPt: &*CodeGenIP.getPoint(), DT, LI);
1119 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1120 CGStartBB->getTerminator()->setSuccessor(Idx: 0, BB: SeqStartBB);
1121 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1122 SeqEndBB->getTerminator()->setSuccessor(Idx: 0, BB: CGEndBB);
1123 };
1124 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1125
1126 // Find outputs from the sequential region to outside users and
1127 // broadcast their values to them.
1128 for (Instruction &I : *SeqStartBB) {
1129 SmallPtrSet<Instruction *, 4> OutsideUsers;
1130 for (User *Usr : I.users()) {
1131 Instruction &UsrI = *cast<Instruction>(Val: Usr);
          // Ignore outputs to lifetime intrinsics; code extraction for the
          // merged parallel region will fix them.
1134 if (UsrI.isLifetimeStartOrEnd())
1135 continue;
1136
1137 if (UsrI.getParent() != SeqStartBB)
1138 OutsideUsers.insert(Ptr: &UsrI);
1139 }
1140
1141 if (OutsideUsers.empty())
1142 continue;
1143
1144 // Emit an alloca in the outer region to store the broadcasted
1145 // value.
1146 const DataLayout &DL = M.getDataLayout();
1147 AllocaInst *AllocaI = new AllocaInst(
1148 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1149 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1150
1151 // Emit a store instruction in the sequential BB to update the
1152 // value.
1153 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1154
1155 // Emit a load instruction and replace the use of the output value
1156 // with it.
1157 for (Instruction *UsrI : OutsideUsers) {
1158 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1159 I.getName() + ".seq.output.load",
1160 UsrI->getIterator());
1161 UsrI->replaceUsesOfWith(From: &I, To: LoadI);
1162 }
1163 }
1164
1165 OpenMPIRBuilder::LocationDescription Loc(
1166 InsertPointTy(ParentBB, ParentBB->end()), DL);
1167 InsertPointTy SeqAfterIP =
1168 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1169
1170 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1171
1172 BranchInst::Create(IfTrue: SeqAfterBB, InsertAtEnd: SeqAfterIP.getBlock());
1173
1174 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1175 << "\n");
1176 };
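
    // Conceptually (an illustrative sketch of the result, not verbatim
    // output), for two mergable fork calls separated by a sequential
    // statement S, the merged parallel region executes:
    //   __kmpc_master(...)        ; only the master thread runs S
    //   S
    //   __kmpc_end_master(...)
    //   __kmpc_barrier(...)       ; re-synchronize before the next region
    // which is why uses of __kmpc_master/__kmpc_end_master/__kmpc_barrier are
    // re-collected after merging (see the end of this function).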
1177
1178 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1179 // contained in BB and only separated by instructions that can be
1180 // redundantly executed in parallel. The block BB is split before the first
1181 // call (in MergableCIs) and after the last so the entire region we merge
1182 // into a single parallel region is contained in a single basic block
1183 // without any other instructions. We use the OpenMPIRBuilder to outline
1184 // that block and call the resulting function via __kmpc_fork_call.
1185 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1186 BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      // include an outer loop.
1189 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1190
1191 auto Remark = [&](OptimizationRemark OR) {
1192 OR << "Parallel region merged with parallel region"
1193 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1194 for (auto *CI : llvm::drop_begin(RangeOrContainer: MergableCIs)) {
1195 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1196 if (CI != MergableCIs.back())
1197 OR << ", ";
1198 }
1199 return OR << ".";
1200 };
1201
1202 emitRemark<OptimizationRemark>(I: MergableCIs.front(), RemarkName: "OMP150", RemarkCB&: Remark);
1203
1204 Function *OriginalFn = BB->getParent();
1205 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1206 << " parallel regions in " << OriginalFn->getName()
1207 << "\n");
1208
1209 // Isolate the calls to merge in a separate block.
1210 EndBB = SplitBlock(Old: BB, SplitPt: MergableCIs.back()->getNextNode(), DT, LI);
1211 BasicBlock *AfterBB =
1212 SplitBlock(Old: EndBB, SplitPt: &*EndBB->getFirstInsertionPt(), DT, LI);
1213 StartBB = SplitBlock(Old: BB, SplitPt: MergableCIs.front(), DT, LI, MSSAU: nullptr,
1214 BBName: "omp.par.merged");
1215
1216 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1217 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1218 BB->getTerminator()->eraseFromParent();
1219
1220 // Create sequential regions for sequential instructions that are
1221 // in-between mergable parallel regions.
1222 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1223 It != End; ++It) {
1224 Instruction *ForkCI = *It;
1225 Instruction *NextForkCI = *(It + 1);
1226
        // Continue if there are no in-between instructions.
1228 if (ForkCI->getNextNode() == NextForkCI)
1229 continue;
1230
1231 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1232 NextForkCI->getPrevNode());
1233 }
1234
1235 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1236 DL);
1237 IRBuilder<>::InsertPoint AllocaIP(
1238 &OriginalFn->getEntryBlock(),
1239 OriginalFn->getEntryBlock().getFirstInsertionPt());
1240 // Create the merged parallel region with default proc binding, to
1241 // avoid overriding binding settings, and without explicit cancellation.
1242 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1243 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1244 OMP_PROC_BIND_default, /* IsCancellable */ false);
1245 BranchInst::Create(IfTrue: AfterBB, InsertAtEnd: AfterIP.getBlock());
1246
1247 // Perform the actual outlining.
1248 OMPInfoCache.OMPBuilder.finalize(Fn: OriginalFn);
1249
1250 Function *OutlinedFn = MergableCIs.front()->getCaller();
1251
1252 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1253 // callbacks.
1254 SmallVector<Value *, 8> Args;
1255 for (auto *CI : MergableCIs) {
1256 Value *Callee = CI->getArgOperand(i: CallbackCalleeOperand);
1257 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1258 Args.clear();
1259 Args.push_back(Elt: OutlinedFn->getArg(i: 0));
1260 Args.push_back(Elt: OutlinedFn->getArg(i: 1));
1261 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1262 ++U)
1263 Args.push_back(Elt: CI->getArgOperand(i: U));
1264
1265 CallInst *NewCI =
1266 CallInst::Create(Ty: FT, Func: Callee, Args, NameStr: "", InsertBefore: CI->getIterator());
1267 if (CI->getDebugLoc())
1268 NewCI->setDebugLoc(CI->getDebugLoc());
1269
1270 // Forward parameter attributes from the callback to the callee.
1271 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1272 ++U)
1273 for (const Attribute &A : CI->getAttributes().getParamAttrs(ArgNo: U))
1274 NewCI->addParamAttr(
1275 ArgNo: U - (CallbackFirstArgOperand - CallbackCalleeOperand), Attr: A);
1276
1277 // Emit an explicit barrier to replace the implicit fork-join barrier.
1278 if (CI != MergableCIs.back()) {
1279 // TODO: Remove barrier if the merged parallel region includes the
1280 // 'nowait' clause.
1281 OMPInfoCache.OMPBuilder.createBarrier(
1282 InsertPointTy(NewCI->getParent(),
1283 NewCI->getNextNode()->getIterator()),
1284 OMPD_parallel);
1285 }
1286
1287 CI->eraseFromParent();
1288 }
1289
1290 assert(OutlinedFn != OriginalFn && "Outlining failed");
1291 CGUpdater.registerOutlinedFunction(OriginalFn&: *OriginalFn, NewFn&: *OutlinedFn);
1292 CGUpdater.reanalyzeFunction(Fn&: *OriginalFn);
1293
1294 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1295
1296 return true;
1297 };
1298
    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
1301 auto DetectPRsCB = [&](Use &U, Function &F) {
1302 CallInst *CI = getCallIfRegularCall(U, RFI: &RFI);
1303 BB2PRMap[CI->getParent()].insert(Ptr: CI);
1304
1305 return false;
1306 };
1307
1308 BB2PRMap.clear();
1309 RFI.foreachUse(SCC, CB: DetectPRsCB);
1310 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1311 // Find mergable parallel regions within a basic block that are
1312 // safe to merge, that is any in-between instructions can safely
1313 // execute in parallel after merging.
1314 // TODO: support merging across basic-blocks.
1315 for (auto &It : BB2PRMap) {
1316 auto &CIs = It.getSecond();
1317 if (CIs.size() < 2)
1318 continue;
1319
1320 BasicBlock *BB = It.getFirst();
1321 SmallVector<CallInst *, 4> MergableCIs;
1322
1323 /// Returns true if the instruction is mergable, false otherwise.
1324 /// A terminator instruction is unmergable by definition since merging
1325 /// works within a BB. Instructions before the mergable region are
1326 /// mergable if they are not calls to OpenMP runtime functions that may
1327 /// set different execution parameters for subsequent parallel regions.
1328 /// Instructions in-between parallel regions are mergable if they are not
1329 /// calls to any non-intrinsic function since that may call a non-mergable
1330 /// OpenMP runtime function.
1331 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1332 // We do not merge across BBs, hence return false (unmergable) if the
1333 // instruction is a terminator.
1334 if (I.isTerminator())
1335 return false;
1336
1337 if (!isa<CallInst>(Val: &I))
1338 return true;
1339
1340 CallInst *CI = cast<CallInst>(Val: &I);
1341 if (IsBeforeMergableRegion) {
1342 Function *CalledFunction = CI->getCalledFunction();
1343 if (!CalledFunction)
1344 return false;
1345 // Return false (unmergable) if the call before the parallel
1346 // region calls an explicit affinity (proc_bind) or number of
1347 // threads (num_threads) compiler-generated function. Those settings
1348 // may be incompatible with following parallel regions.
1349 // TODO: ICV tracking to detect compatibility.
1350 for (const auto &RFI : UnmergableCallsInfo) {
1351 if (CalledFunction == RFI.Declaration)
1352 return false;
1353 }
1354 } else {
1355 // Return false (unmergable) if there is a call instruction
1356 // in-between parallel regions when it is not an intrinsic. It
1357 // may call an unmergable OpenMP runtime function in its callpath.
1358 // TODO: Keep track of possible OpenMP calls in the callpath.
1359 if (!isa<IntrinsicInst>(Val: CI))
1360 return false;
1361 }
1362
1363 return true;
1364 };
1365 // Find maximal number of parallel region CIs that are safe to merge.
1366 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1367 Instruction &I = *It;
1368 ++It;
1369
1370 if (CIs.count(Ptr: &I)) {
1371 MergableCIs.push_back(Elt: cast<CallInst>(Val: &I));
1372 continue;
1373 }
1374
1375 // Continue expanding if the instruction is mergable.
1376 if (IsMergable(I, MergableCIs.empty()))
1377 continue;
1378
1379 // Forward the instruction iterator to skip the next parallel region
1380 // since there is an unmergable instruction which can affect it.
1381 for (; It != End; ++It) {
1382 Instruction &SkipI = *It;
1383 if (CIs.count(Ptr: &SkipI)) {
1384 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1385 << " due to " << I << "\n");
1386 ++It;
1387 break;
1388 }
1389 }
1390
1391 // Store mergable regions found.
1392 if (MergableCIs.size() > 1) {
1393 MergableCIsVector.push_back(Elt: MergableCIs);
1394 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1395 << " parallel regions in block " << BB->getName()
1396 << " of function " << BB->getParent()->getName()
1397 << "\n";);
1398 }
1399
1400 MergableCIs.clear();
1401 }
1402
1403 if (!MergableCIsVector.empty()) {
1404 Changed = true;
1405
1406 for (auto &MergableCIs : MergableCIsVector)
1407 Merge(MergableCIs, BB);
1408 MergableCIsVector.clear();
1409 }
1410 }
1411
1412 if (Changed) {
1413 /// Re-collect use for fork calls, emitted barrier calls, and
1414 /// any emitted master/end_master calls.
1415 OMPInfoCache.recollectUsesForFunction(RTF: OMPRTL___kmpc_fork_call);
1416 OMPInfoCache.recollectUsesForFunction(RTF: OMPRTL___kmpc_barrier);
1417 OMPInfoCache.recollectUsesForFunction(RTF: OMPRTL___kmpc_master);
1418 OMPInfoCache.recollectUsesForFunction(RTF: OMPRTL___kmpc_end_master);
1419 }
1420
1421 return Changed;
1422 }
1423
1424 /// Try to delete parallel regions if possible.
1425 bool deleteParallelRegions() {
1426 const unsigned CallbackCalleeOperand = 2;
1427
1428 OMPInformationCache::RuntimeFunctionInfo &RFI =
1429 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1430
1431 if (!RFI.Declaration)
1432 return false;
1433
1434 bool Changed = false;
1435 auto DeleteCallCB = [&](Use &U, Function &) {
1436 CallInst *CI = getCallIfRegularCall(U);
1437 if (!CI)
1438 return false;
1439 auto *Fn = dyn_cast<Function>(
1440 Val: CI->getArgOperand(i: CallbackCalleeOperand)->stripPointerCasts());
1441 if (!Fn)
1442 return false;
1443 if (!Fn->onlyReadsMemory())
1444 return false;
1445 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1446 return false;
1447
1448 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1449 << CI->getCaller()->getName() << "\n");
1450
1451 auto Remark = [&](OptimizationRemark OR) {
1452 return OR << "Removing parallel region with no side-effects.";
1453 };
1454 emitRemark<OptimizationRemark>(I: CI, RemarkName: "OMP160", RemarkCB&: Remark);
1455
1456 CGUpdater.removeCallSite(CS&: *CI);
1457 CI->eraseFromParent();
1458 Changed = true;
1459 ++NumOpenMPParallelRegionsDeleted;
1460 return true;
1461 };
1462
1463 RFI.foreachUse(SCC, CB: DeleteCallCB);
1464
1465 return Changed;
1466 }
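
  // Illustrative example (a sketch): a construct like
  //   #pragma omp parallel
  //   { /* only reads memory and always returns */ }
  // is outlined into a callback that is `readonly` and `willreturn`, so the
  // surrounding __kmpc_fork_call can be erased without changing observable
  // behavior.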
1467
1468 /// Try to eliminate runtime calls by reusing existing ones.
1469 bool deduplicateRuntimeCalls() {
1470 bool Changed = false;
1471
1472 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1473 OMPRTL_omp_get_num_threads,
1474 OMPRTL_omp_in_parallel,
1475 OMPRTL_omp_get_cancellation,
1476 OMPRTL_omp_get_supported_active_levels,
1477 OMPRTL_omp_get_level,
1478 OMPRTL_omp_get_ancestor_thread_num,
1479 OMPRTL_omp_get_team_size,
1480 OMPRTL_omp_get_active_level,
1481 OMPRTL_omp_in_final,
1482 OMPRTL_omp_get_proc_bind,
1483 OMPRTL_omp_get_num_places,
1484 OMPRTL_omp_get_num_procs,
1485 OMPRTL_omp_get_place_num,
1486 OMPRTL_omp_get_partition_num_places,
1487 OMPRTL_omp_get_partition_place_nums};
1488
1489 // Global-tid is handled separately.
1490 SmallSetVector<Value *, 16> GTIdArgs;
1491 collectGlobalThreadIdArguments(GTIdArgs);
1492 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1493 << " global thread ID arguments\n");
1494
1495 for (Function *F : SCC) {
1496 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1497 Changed |= deduplicateRuntimeCalls(
1498 F&: *F, RFI&: OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1499
1500 // __kmpc_global_thread_num is special as we can replace it with an
1501 // argument in enough cases to make it worth trying.
1502 Value *GTIdArg = nullptr;
1503 for (Argument &Arg : F->args())
1504 if (GTIdArgs.count(key: &Arg)) {
1505 GTIdArg = &Arg;
1506 break;
1507 }
1508 Changed |= deduplicateRuntimeCalls(
1509 F&: *F, RFI&: OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], ReplVal: GTIdArg);
1510 }
1511
1512 return Changed;
1513 }
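
  // Illustrative example (a sketch): two calls to omp_get_num_threads() in
  // the same function and parallel context yield the same value, so the
  // second call is replaced by the result of the first and the redundant
  // runtime call disappears.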
1514
1515 /// Tries to remove known runtime symbols that are optional from the module.
1516 bool removeRuntimeSymbols() {
1517 // The RPC client symbol is defined in `libc` and indicates that something
1518 // required an RPC server. If its users were all optimized out then we can
1519 // safely remove it.
1520 // TODO: This should be somewhere more common in the future.
1521 if (GlobalVariable *GV = M.getNamedGlobal(Name: "__llvm_libc_rpc_client")) {
1522 if (!GV->getType()->isPointerTy())
1523 return false;
1524
1525 Constant *C = GV->getInitializer();
1526 if (!C)
1527 return false;
1528
1529 // Check to see if the only user of the RPC client is the external handle.
1530 GlobalVariable *Client = dyn_cast<GlobalVariable>(Val: C->stripPointerCasts());
1531 if (!Client || Client->getNumUses() > 1 ||
1532 Client->user_back() != GV->getInitializer())
1533 return false;
1534
1535 Client->replaceAllUsesWith(V: PoisonValue::get(T: Client->getType()));
1536 Client->eraseFromParent();
1537
1538 GV->replaceAllUsesWith(V: PoisonValue::get(T: GV->getType()));
1539 GV->eraseFromParent();
1540
1541 return true;
1542 }
1543 return false;
1544 }
1545
  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
1552 bool hideMemTransfersLatency() {
1553 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1554 bool Changed = false;
1555 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1556 auto *RTCall = getCallIfRegularCall(U, RFI: &RFI);
1557 if (!RTCall)
1558 return false;
1559
1560 OffloadArray OffloadArrays[3];
1561 if (!getValuesInOffloadArrays(RuntimeCall&: *RTCall, OAs: OffloadArrays))
1562 return false;
1563
1564 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1565
1566 // TODO: Check if can be moved upwards.
1567 bool WasSplit = false;
1568 Instruction *WaitMovementPoint = canBeMovedDownwards(RuntimeCall&: *RTCall);
1569 if (WaitMovementPoint)
1570 WasSplit = splitTargetDataBeginRTC(RuntimeCall&: *RTCall, WaitMovementPoint&: *WaitMovementPoint);
1571
1572 Changed |= WasSplit;
1573 return WasSplit;
1574 };
1575 if (OMPInfoCache.runtimeFnsAvailable(
1576 Fns: {OMPRTL___tgt_target_data_begin_mapper_issue,
1577 OMPRTL___tgt_target_data_begin_mapper_wait}))
1578 RFI.foreachUse(SCC, CB: SplitMemTransfers);
1579
1580 return Changed;
1581 }
1582
1583 void analysisGlobalization() {
1584 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1585
1586 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1587 if (CallInst *CI = getCallIfRegularCall(U, RFI: &RFI)) {
1588 auto Remark = [&](OptimizationRemarkMissed ORM) {
1589 return ORM
1590 << "Found thread data sharing on the GPU. "
1591 << "Expect degraded performance due to data globalization.";
1592 };
1593 emitRemark<OptimizationRemarkMissed>(I: CI, RemarkName: "OMP112", RemarkCB&: Remark);
1594 }
1595
1596 return false;
1597 };
1598
1599 RFI.foreachUse(SCC, CB: CheckGlobalization);
1600 }
1601
1602 /// Maps the values stored in the offload arrays passed as arguments to
1603 /// \p RuntimeCall into the offload arrays in \p OAs.
1604 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1605 MutableArrayRef<OffloadArray> OAs) {
1606 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1607
1608 // A runtime call that involves memory offloading looks something like:
1609 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1610 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1611 // ...)
1612 // So, the idea is to access the allocas that allocate space for these
1613 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1614 // Therefore:
1615 // i8** %offload_baseptrs.
1616 Value *BasePtrsArg =
1617 RuntimeCall.getArgOperand(i: OffloadArray::BasePtrsArgNum);
1618 // i8** %offload_ptrs.
1619 Value *PtrsArg = RuntimeCall.getArgOperand(i: OffloadArray::PtrsArgNum);
1620 // i8** %offload_sizes.
1621 Value *SizesArg = RuntimeCall.getArgOperand(i: OffloadArray::SizesArgNum);
1622
1623 // Get values stored in **offload_baseptrs.
1624 auto *V = getUnderlyingObject(V: BasePtrsArg);
1625 if (!isa<AllocaInst>(Val: V))
1626 return false;
1627 auto *BasePtrsArray = cast<AllocaInst>(Val: V);
1628 if (!OAs[0].initialize(Array&: *BasePtrsArray, Before&: RuntimeCall))
1629 return false;
1630
1631     // Get values stored in **offload_ptrs.
1632 V = getUnderlyingObject(V: PtrsArg);
1633 if (!isa<AllocaInst>(Val: V))
1634 return false;
1635 auto *PtrsArray = cast<AllocaInst>(Val: V);
1636 if (!OAs[1].initialize(Array&: *PtrsArray, Before&: RuntimeCall))
1637 return false;
1638
1639 // Get values stored in **offload_sizes.
1640 V = getUnderlyingObject(V: SizesArg);
1641 // If it's a [constant] global array don't analyze it.
1642 if (isa<GlobalValue>(Val: V))
1643 return isa<Constant>(Val: V);
1644 if (!isa<AllocaInst>(Val: V))
1645 return false;
1646
1647 auto *SizesArray = cast<AllocaInst>(Val: V);
1648 if (!OAs[2].initialize(Array&: *SizesArray, Before&: RuntimeCall))
1649 return false;
1650
1651 return true;
1652 }
1653
1654 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1655 /// For now this is a way to test that the function getValuesInOffloadArrays
1656 /// is working properly.
1657 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1658 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1659 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1660
1661 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1662 std::string ValuesStr;
1663 raw_string_ostream Printer(ValuesStr);
1664 std::string Separator = " --- ";
1665
1666 for (auto *BP : OAs[0].StoredValues) {
1667 BP->print(O&: Printer);
1668 Printer << Separator;
1669 }
1670 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1671 ValuesStr.clear();
1672
1673 for (auto *P : OAs[1].StoredValues) {
1674 P->print(O&: Printer);
1675 Printer << Separator;
1676 }
1677 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1678 ValuesStr.clear();
1679
1680 for (auto *S : OAs[2].StoredValues) {
1681 S->print(O&: Printer);
1682 Printer << Separator;
1683 }
1684 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1685 }
1686
1687   /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
1688   /// can be moved. Returns nullptr if the movement is not possible or not worth it.
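  /// An illustrative sketch (values made up):
  ///
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   %a = add i32 %x, %y   ; side-effect free, worth moving over
  ///   store i32 %a, ptr %p  ; first instruction that may touch memory
  ///
  /// Here the returned movement point is the store.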
1689 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1690 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1691 // Make it traverse the CFG.
1692
1693 Instruction *CurrentI = &RuntimeCall;
1694 bool IsWorthIt = false;
1695 while ((CurrentI = CurrentI->getNextNode())) {
1696
1697 // TODO: Once we detect the regions to be offloaded we should use the
1698 // alias analysis manager to check if CurrentI may modify one of
1699 // the offloaded regions.
1700 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1701 if (IsWorthIt)
1702 return CurrentI;
1703
1704 return nullptr;
1705 }
1706
1707       // FIXME: For now, moving the "wait" over anything without side effects
1708       // is considered worth it.
1709 IsWorthIt = true;
1710 }
1711
1712 // Return end of BasicBlock.
1713 return RuntimeCall.getParent()->getTerminator();
1714 }
1715
1716 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1717 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1718 Instruction &WaitMovementPoint) {
1719     // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1720     // the function. It stores information about the async transfer so we can
1721     // wait on it later.
1722 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1723 Function *F = RuntimeCall.getCaller();
1724 BasicBlock &Entry = F->getEntryBlock();
1725 IRBuilder.Builder.SetInsertPoint(TheBB: &Entry,
1726 IP: Entry.getFirstNonPHIOrDbgOrAlloca());
1727 Value *Handle = IRBuilder.Builder.CreateAlloca(
1728 Ty: IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, Name: "handle");
1729 Handle =
1730 IRBuilder.Builder.CreateAddrSpaceCast(V: Handle, DestTy: IRBuilder.AsyncInfoPtr);
1731
1732 // Add "issue" runtime call declaration:
1733 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1734 // i8**, i8**, i64*, i64*)
1735 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1736 M, FnID: OMPRTL___tgt_target_data_begin_mapper_issue);
1737
1738 // Change RuntimeCall call site for its asynchronous version.
1739 SmallVector<Value *, 16> Args;
1740 for (auto &Arg : RuntimeCall.args())
1741 Args.push_back(Elt: Arg.get());
1742 Args.push_back(Elt: Handle);
1743
1744 CallInst *IssueCallsite = CallInst::Create(Func: IssueDecl, Args, /*NameStr=*/"",
1745 InsertBefore: RuntimeCall.getIterator());
1746 OMPInfoCache.setCallingConvention(Callee: IssueDecl, CI: IssueCallsite);
1747 RuntimeCall.eraseFromParent();
1748
1749 // Add "wait" runtime call declaration:
1750 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1751 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1752 M, FnID: OMPRTL___tgt_target_data_begin_mapper_wait);
1753
1754 Value *WaitParams[2] = {
1755 IssueCallsite->getArgOperand(
1756 i: OffloadArray::DeviceIDArgNum), // device_id.
1757 Handle // handle to wait on.
1758 };
1759 CallInst *WaitCallsite = CallInst::Create(
1760 Func: WaitDecl, Args: WaitParams, /*NameStr=*/"", InsertBefore: WaitMovementPoint.getIterator());
1761 OMPInfoCache.setCallingConvention(Callee: WaitDecl, CI: WaitCallsite);
1762
1763 return true;
1764 }
1765
1766 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1767 bool GlobalOnly, bool &SingleChoice) {
1768 if (CurrentIdent == NextIdent)
1769 return CurrentIdent;
1770
1771 // TODO: Figure out how to actually combine multiple debug locations. For
1772 // now we just keep an existing one if there is a single choice.
1773 if (!GlobalOnly || isa<GlobalValue>(Val: NextIdent)) {
1774 SingleChoice = !CurrentIdent;
1775 return NextIdent;
1776 }
1777 return nullptr;
1778 }
1779
1780   /// Return a `struct ident_t*` value that represents the ones used in the
1781 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1782 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1783 /// return value we create one from scratch. We also do not yet combine
1784 /// information, e.g., the source locations, see combinedIdentStruct.
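  /// For reference, a global ident argument is typically a constant like the
  /// following (layout abbreviated, name illustrative):
  ///
  ///   @ident = private unnamed_addr constant %struct.ident_t { ... }
  ///
  /// and, with \p GlobalOnly set, we prefer such a global over any locally
  /// constructed ident value.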
1785 Value *
1786 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1787 Function &F, bool GlobalOnly) {
1788 bool SingleChoice = true;
1789 Value *Ident = nullptr;
1790 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1791 CallInst *CI = getCallIfRegularCall(U, RFI: &RFI);
1792 if (!CI || &F != &Caller)
1793 return false;
1794 Ident = combinedIdentStruct(CurrentIdent: Ident, NextIdent: CI->getArgOperand(i: 0),
1795 /* GlobalOnly */ true, SingleChoice);
1796 return false;
1797 };
1798 RFI.foreachUse(SCC, CB: CombineIdentStruct);
1799
1800 if (!Ident || !SingleChoice) {
1801       // The IRBuilder uses the insertion block to get to the module; this is
1802       // unfortunate, but we work around it for now.
1803 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1804 OMPInfoCache.OMPBuilder.updateToLocation(Loc: OpenMPIRBuilder::InsertPointTy(
1805 &F.getEntryBlock(), F.getEntryBlock().begin()));
1806       // Create a fallback location if none was found.
1807 // TODO: Use the debug locations of the calls instead.
1808 uint32_t SrcLocStrSize;
1809 Constant *Loc =
1810 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1811 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr: Loc, SrcLocStrSize);
1812 }
1813 return Ident;
1814 }
1815
1816 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1817 /// \p ReplVal if given.
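  ///
  /// As an illustrative example, two calls
  ///
  ///   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
  ///   ...
  ///   %tid1 = call i32 @__kmpc_global_thread_num(ptr @ident)
  ///
  /// are collapsed into a single call whose result replaces both %tid0 and
  /// %tid1, or all calls are replaced by \p ReplVal if one is given.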
1818 bool deduplicateRuntimeCalls(Function &F,
1819 OMPInformationCache::RuntimeFunctionInfo &RFI,
1820 Value *ReplVal = nullptr) {
1821 auto *UV = RFI.getUseVector(F);
1822 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1823 return false;
1824
1825 LLVM_DEBUG(
1826 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1827 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1828
1829 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1830 cast<Argument>(ReplVal)->getParent() == &F)) &&
1831 "Unexpected replacement value!");
1832
1833 // TODO: Use dominance to find a good position instead.
1834 auto CanBeMoved = [this](CallBase &CB) {
1835 unsigned NumArgs = CB.arg_size();
1836 if (NumArgs == 0)
1837 return true;
1838 if (CB.getArgOperand(i: 0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1839 return false;
1840 for (unsigned U = 1; U < NumArgs; ++U)
1841 if (isa<Instruction>(Val: CB.getArgOperand(i: U)))
1842 return false;
1843 return true;
1844 };
1845
1846 if (!ReplVal) {
1847 auto *DT =
1848 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1849 if (!DT)
1850 return false;
1851 Instruction *IP = nullptr;
1852 for (Use *U : *UV) {
1853 if (CallInst *CI = getCallIfRegularCall(U&: *U, RFI: &RFI)) {
1854 if (IP)
1855 IP = DT->findNearestCommonDominator(I1: IP, I2: CI);
1856 else
1857 IP = CI;
1858 if (!CanBeMoved(*CI))
1859 continue;
1860 if (!ReplVal)
1861 ReplVal = CI;
1862 }
1863 }
1864 if (!ReplVal)
1865 return false;
1866 assert(IP && "Expected insertion point!");
1867 cast<Instruction>(Val: ReplVal)->moveBefore(MovePos: IP);
1868 }
1869
1870 // If we use a call as a replacement value we need to make sure the ident is
1871 // valid at the new location. For now we just pick a global one, either
1872 // existing and used by one of the calls, or created from scratch.
1873 if (CallBase *CI = dyn_cast<CallBase>(Val: ReplVal)) {
1874 if (!CI->arg_empty() &&
1875 CI->getArgOperand(i: 0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1876 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1877 /* GlobalOnly */ true);
1878 CI->setArgOperand(i: 0, v: Ident);
1879 }
1880 }
1881
1882 bool Changed = false;
1883 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1884 CallInst *CI = getCallIfRegularCall(U, RFI: &RFI);
1885 if (!CI || CI == ReplVal || &F != &Caller)
1886 return false;
1887 assert(CI->getCaller() == &F && "Unexpected call!");
1888
1889 auto Remark = [&](OptimizationRemark OR) {
1890 return OR << "OpenMP runtime call "
1891 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1892 };
1893 if (CI->getDebugLoc())
1894 emitRemark<OptimizationRemark>(I: CI, RemarkName: "OMP170", RemarkCB&: Remark);
1895 else
1896 emitRemark<OptimizationRemark>(F: &F, RemarkName: "OMP170", RemarkCB&: Remark);
1897
1898 CGUpdater.removeCallSite(CS&: *CI);
1899 CI->replaceAllUsesWith(V: ReplVal);
1900 CI->eraseFromParent();
1901 ++NumOpenMPRuntimeCallsDeduplicated;
1902 Changed = true;
1903 return true;
1904 };
1905 RFI.foreachUse(SCC, CB: ReplaceAndDeleteCB);
1906
1907 return Changed;
1908 }
1909
1910 /// Collect arguments that represent the global thread id in \p GTIdArgs.
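  ///
  /// An illustrative example (names made up): given
  ///
  ///   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
  ///   call void @helper(i32 %gtid)
  ///
  /// the first argument of the internal function @helper is collected as a
  /// GTId argument, provided every call site of @helper passes a known GTId.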
1911 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1912 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1913 // initialization. We could define an AbstractAttribute instead and
1914 // run the Attributor here once it can be run as an SCC pass.
1915
1916 // Helper to check the argument \p ArgNo at all call sites of \p F for
1917 // a GTId.
1918 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1919 if (!F.hasLocalLinkage())
1920 return false;
1921 for (Use &U : F.uses()) {
1922 if (CallInst *CI = getCallIfRegularCall(U)) {
1923 Value *ArgOp = CI->getArgOperand(i: ArgNo);
1924 if (CI == &RefCI || GTIdArgs.count(key: ArgOp) ||
1925 getCallIfRegularCall(
1926 V&: *ArgOp, RFI: &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1927 continue;
1928 }
1929 return false;
1930 }
1931 return true;
1932 };
1933
1934 // Helper to identify uses of a GTId as GTId arguments.
1935 auto AddUserArgs = [&](Value &GTId) {
1936 for (Use &U : GTId.uses())
1937 if (CallInst *CI = dyn_cast<CallInst>(Val: U.getUser()))
1938 if (CI->isArgOperand(U: &U))
1939 if (Function *Callee = CI->getCalledFunction())
1940 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1941 GTIdArgs.insert(X: Callee->getArg(i: U.getOperandNo()));
1942 };
1943
1944 // The argument users of __kmpc_global_thread_num calls are GTIds.
1945 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1946 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1947
1948 GlobThreadNumRFI.foreachUse(SCC, CB: [&](Use &U, Function &F) {
1949 if (CallInst *CI = getCallIfRegularCall(U, RFI: &GlobThreadNumRFI))
1950 AddUserArgs(*CI);
1951 return false;
1952 });
1953
1954 // Transitively search for more arguments by looking at the users of the
1955 // ones we know already. During the search the GTIdArgs vector is extended
1956     // so we cannot cache the size nor use a range-based for loop.
1957 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1958 AddUserArgs(*GTIdArgs[U]);
1959 }
1960
1961 /// Kernel (=GPU) optimizations and utility functions
1962 ///
1963 ///{{
1964
1965 /// Cache to remember the unique kernel for a function.
1966 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1967
1968 /// Find the unique kernel that will execute \p F, if any.
1969 Kernel getUniqueKernelFor(Function &F);
1970
1971 /// Find the unique kernel that will execute \p I, if any.
1972 Kernel getUniqueKernelFor(Instruction &I) {
1973 return getUniqueKernelFor(F&: *I.getFunction());
1974 }
1975
1976   /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1977   /// the cases where we can avoid taking the address of a function.
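  ///
  /// An illustrative sketch (wrapper name made up): a use such as
  ///
  ///   call void @__kmpc_parallel_51(..., ptr @par_region_wrapper, ...)
  ///
  /// is rewritten to pass a private global @par_region_wrapper.ID instead of
  /// the function address, so the state machine can compare against the ID and
  /// only direct calls to the wrapper remain.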
1978 bool rewriteDeviceCodeStateMachine();
1979
1980 ///
1981 ///}}
1982
1983 /// Emit a remark generically
1984 ///
1985 /// This template function can be used to generically emit a remark. The
1986 /// RemarkKind should be one of the following:
1987 /// - OptimizationRemark to indicate a successful optimization attempt
1988 /// - OptimizationRemarkMissed to report a failed optimization attempt
1989 /// - OptimizationRemarkAnalysis to provide additional information about an
1990 /// optimization attempt
1991 ///
1992 /// The remark is built using a callback function provided by the caller that
1993 /// takes a RemarkKind as input and returns a RemarkKind.
1994 template <typename RemarkKind, typename RemarkCallBack>
1995 void emitRemark(Instruction *I, StringRef RemarkName,
1996 RemarkCallBack &&RemarkCB) const {
1997 Function *F = I->getParent()->getParent();
1998 auto &ORE = OREGetter(F);
1999
2000 if (RemarkName.starts_with(Prefix: "OMP"))
2001 ORE.emit([&]() {
2002 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2003 << " [" << RemarkName << "]";
2004 });
2005 else
2006 ORE.emit(
2007 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2008 }
2009
2010 /// Emit a remark on a function.
2011 template <typename RemarkKind, typename RemarkCallBack>
2012 void emitRemark(Function *F, StringRef RemarkName,
2013 RemarkCallBack &&RemarkCB) const {
2014 auto &ORE = OREGetter(F);
2015
2016 if (RemarkName.starts_with(Prefix: "OMP"))
2017 ORE.emit([&]() {
2018 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2019 << " [" << RemarkName << "]";
2020 });
2021 else
2022 ORE.emit(
2023 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2024 }
2025
2026 /// The underlying module.
2027 Module &M;
2028
2029 /// The SCC we are operating on.
2030 SmallVectorImpl<Function *> &SCC;
2031
2032 /// Callback to update the call graph, the first argument is a removed call,
2033 /// the second an optional replacement call.
2034 CallGraphUpdater &CGUpdater;
2035
2036 /// Callback to get an OptimizationRemarkEmitter from a Function *
2037 OptimizationRemarkGetter OREGetter;
2038
2039   /// OpenMP-specific information cache. Also used for Attributor runs.
2040 OMPInformationCache &OMPInfoCache;
2041
2042 /// Attributor instance.
2043 Attributor &A;
2044
2045 /// Helper function to run Attributor on SCC.
2046 bool runAttributor(bool IsModulePass) {
2047 if (SCC.empty())
2048 return false;
2049
2050 registerAAs(IsModulePass);
2051
2052 ChangeStatus Changed = A.run();
2053
2054 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2055 << " functions, result: " << Changed << ".\n");
2056
2057 if (Changed == ChangeStatus::CHANGED)
2058 OMPInfoCache.invalidateAnalyses();
2059
2060 return Changed == ChangeStatus::CHANGED;
2061 }
2062
2063 void registerFoldRuntimeCall(RuntimeFunction RF);
2064
2065 /// Populate the Attributor with abstract attribute opportunities in the
2066 /// functions.
2067 void registerAAs(bool IsModulePass);
2068
2069public:
2070 /// Callback to register AAs for live functions, including internal functions
2071 /// marked live during the traversal.
2072 static void registerAAsForFunction(Attributor &A, const Function &F);
2073};
2074
2075Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2076 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2077 !OMPInfoCache.CGSCC->contains(key: &F))
2078 return nullptr;
2079
2080 // Use a scope to keep the lifetime of the CachedKernel short.
2081 {
2082 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2083 if (CachedKernel)
2084 return *CachedKernel;
2085
2086 // TODO: We should use an AA to create an (optimistic and callback
2087 // call-aware) call graph. For now we stick to simple patterns that
2088 // are less powerful, basically the worst fixpoint.
2089 if (isOpenMPKernel(Fn&: F)) {
2090 CachedKernel = Kernel(&F);
2091 return *CachedKernel;
2092 }
2093
2094 CachedKernel = nullptr;
2095 if (!F.hasLocalLinkage()) {
2096
2097 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2098 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2099 return ORA << "Potentially unknown OpenMP target region caller.";
2100 };
2101 emitRemark<OptimizationRemarkAnalysis>(F: &F, RemarkName: "OMP100", RemarkCB&: Remark);
2102
2103 return nullptr;
2104 }
2105 }
2106
2107 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2108 if (auto *Cmp = dyn_cast<ICmpInst>(Val: U.getUser())) {
2109 // Allow use in equality comparisons.
2110 if (Cmp->isEquality())
2111 return getUniqueKernelFor(I&: *Cmp);
2112 return nullptr;
2113 }
2114 if (auto *CB = dyn_cast<CallBase>(Val: U.getUser())) {
2115 // Allow direct calls.
2116 if (CB->isCallee(U: &U))
2117 return getUniqueKernelFor(I&: *CB);
2118
2119 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2120 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2121 // Allow the use in __kmpc_parallel_51 calls.
2122 if (OpenMPOpt::getCallIfRegularCall(V&: *U.getUser(), RFI: &KernelParallelRFI))
2123 return getUniqueKernelFor(I&: *CB);
2124 return nullptr;
2125 }
2126 // Disallow every other use.
2127 return nullptr;
2128 };
2129
2130 // TODO: In the future we want to track more than just a unique kernel.
2131 SmallPtrSet<Kernel, 2> PotentialKernels;
2132 OMPInformationCache::foreachUse(F, CB: [&](const Use &U) {
2133 PotentialKernels.insert(Ptr: GetUniqueKernelForUse(U));
2134 });
2135
2136 Kernel K = nullptr;
2137 if (PotentialKernels.size() == 1)
2138 K = *PotentialKernels.begin();
2139
2140 // Cache the result.
2141 UniqueKernelMap[&F] = K;
2142
2143 return K;
2144}
2145
2146bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2147 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2148 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2149
2150 bool Changed = false;
2151 if (!KernelParallelRFI)
2152 return Changed;
2153
2154 // If we have disabled state machine changes, exit
2155 if (DisableOpenMPOptStateMachineRewrite)
2156 return Changed;
2157
2158 for (Function *F : SCC) {
2159
2160     // Check if the function is used in a __kmpc_parallel_51 call at
2161     // all.
2162 bool UnknownUse = false;
2163 bool KernelParallelUse = false;
2164 unsigned NumDirectCalls = 0;
2165
2166 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2167 OMPInformationCache::foreachUse(F&: *F, CB: [&](Use &U) {
2168 if (auto *CB = dyn_cast<CallBase>(Val: U.getUser()))
2169 if (CB->isCallee(U: &U)) {
2170 ++NumDirectCalls;
2171 return;
2172 }
2173
2174 if (isa<ICmpInst>(Val: U.getUser())) {
2175 ToBeReplacedStateMachineUses.push_back(Elt: &U);
2176 return;
2177 }
2178
2179 // Find wrapper functions that represent parallel kernels.
2180 CallInst *CI =
2181 OpenMPOpt::getCallIfRegularCall(V&: *U.getUser(), RFI: &KernelParallelRFI);
2182 const unsigned int WrapperFunctionArgNo = 6;
2183 if (!KernelParallelUse && CI &&
2184 CI->getArgOperandNo(U: &U) == WrapperFunctionArgNo) {
2185 KernelParallelUse = true;
2186 ToBeReplacedStateMachineUses.push_back(Elt: &U);
2187 return;
2188 }
2189 UnknownUse = true;
2190 });
2191
2192 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2193 // use.
2194 if (!KernelParallelUse)
2195 continue;
2196
2197 // If this ever hits, we should investigate.
2198 // TODO: Checking the number of uses is not a necessary restriction and
2199 // should be lifted.
2200 if (UnknownUse || NumDirectCalls != 1 ||
2201 ToBeReplacedStateMachineUses.size() > 2) {
2202 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2203 return ORA << "Parallel region is used in "
2204 << (UnknownUse ? "unknown" : "unexpected")
2205 << " ways. Will not attempt to rewrite the state machine.";
2206 };
2207 emitRemark<OptimizationRemarkAnalysis>(F, RemarkName: "OMP101", RemarkCB&: Remark);
2208 continue;
2209 }
2210
2211 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2212 // up if the function is not called from a unique kernel.
2213 Kernel K = getUniqueKernelFor(F&: *F);
2214 if (!K) {
2215 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2216 return ORA << "Parallel region is not called from a unique kernel. "
2217 "Will not attempt to rewrite the state machine.";
2218 };
2219 emitRemark<OptimizationRemarkAnalysis>(F, RemarkName: "OMP102", RemarkCB&: Remark);
2220 continue;
2221 }
2222
2223 // We now know F is a parallel body function called only from the kernel K.
2224 // We also identified the state machine uses in which we replace the
2225 // function pointer by a new global symbol for identification purposes. This
2226 // ensures only direct calls to the function are left.
2227
2228 Module &M = *F->getParent();
2229 Type *Int8Ty = Type::getInt8Ty(C&: M.getContext());
2230
2231 auto *ID = new GlobalVariable(
2232 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2233 UndefValue::get(T: Int8Ty), F->getName() + ".ID");
2234
2235 for (Use *U : ToBeReplacedStateMachineUses)
2236 U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2237 C: ID, Ty: U->get()->getType()));
2238
2239 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2240
2241 Changed = true;
2242 }
2243
2244 return Changed;
2245}
2246
2247/// Abstract Attribute for tracking ICV values.
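/// As an illustrative example, for the nthreads ICV a sequence such as
///
///   call void @omp_set_num_threads(i32 4)
///   ...
///   %n = call i32 @omp_get_max_threads()
///
/// allows the getter to be folded to 4, as long as no intervening call may
/// change the ICV.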
2248struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2249 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2250 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2251
2252 /// Returns true if value is assumed to be tracked.
2253 bool isAssumedTracked() const { return getAssumed(); }
2254
2255 /// Returns true if value is known to be tracked.
2256 bool isKnownTracked() const { return getAssumed(); }
2257
2258   /// Create an abstract attribute view for the position \p IRP.
2259 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2260
2261 /// Return the value with which \p I can be replaced for specific \p ICV.
2262 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2263 const Instruction *I,
2264 Attributor &A) const {
2265 return std::nullopt;
2266 }
2267
2268 /// Return an assumed unique ICV value if a single candidate is found. If
2269 /// there cannot be one, return a nullptr. If it is not clear yet, return
2270 /// std::nullopt.
2271 virtual std::optional<Value *>
2272 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2273
2274 // Currently only nthreads is being tracked.
2275   // This array will only grow over time.
2276 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2277
2278 /// See AbstractAttribute::getName()
2279 const std::string getName() const override { return "AAICVTracker"; }
2280
2281 /// See AbstractAttribute::getIdAddr()
2282 const char *getIdAddr() const override { return &ID; }
2283
2284 /// This function should return true if the type of the \p AA is AAICVTracker
2285 static bool classof(const AbstractAttribute *AA) {
2286 return (AA->getIdAddr() == &ID);
2287 }
2288
2289 static const char ID;
2290};
2291
2292struct AAICVTrackerFunction : public AAICVTracker {
2293 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2294 : AAICVTracker(IRP, A) {}
2295
2296 // FIXME: come up with better string.
2297 const std::string getAsStr(Attributor *) const override {
2298 return "ICVTrackerFunction";
2299 }
2300
2301 // FIXME: come up with some stats.
2302 void trackStatistics() const override {}
2303
2304 /// We don't manifest anything for this AA.
2305 ChangeStatus manifest(Attributor &A) override {
2306 return ChangeStatus::UNCHANGED;
2307 }
2308
2309   // Map of ICVs to their values at specific program points.
2310 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2311 InternalControlVar::ICV___last>
2312 ICVReplacementValuesMap;
2313
2314 ChangeStatus updateImpl(Attributor &A) override {
2315 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2316
2317 Function *F = getAnchorScope();
2318
2319 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2320
2321 for (InternalControlVar ICV : TrackableICVs) {
2322 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2323
2324 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2325 auto TrackValues = [&](Use &U, Function &) {
2326 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2327 if (!CI)
2328 return false;
2329
2330         // FIXME: handle setters with more than one argument.
2331         // Track new value.
2332 if (ValuesMap.insert(KV: std::make_pair(x&: CI, y: CI->getArgOperand(i: 0))).second)
2333 HasChanged = ChangeStatus::CHANGED;
2334
2335 return false;
2336 };
2337
2338 auto CallCheck = [&](Instruction &I) {
2339 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2340 if (ReplVal && ValuesMap.insert(KV: std::make_pair(x: &I, y&: *ReplVal)).second)
2341 HasChanged = ChangeStatus::CHANGED;
2342
2343 return true;
2344 };
2345
2346 // Track all changes of an ICV.
2347 SetterRFI.foreachUse(CB: TrackValues, F);
2348
2349 bool UsedAssumedInformation = false;
2350 A.checkForAllInstructions(Pred: CallCheck, QueryingAA: *this, Opcodes: {Instruction::Call},
2351 UsedAssumedInformation,
2352 /* CheckBBLivenessOnly */ true);
2353
2354 /// TODO: Figure out a way to avoid adding entry in
2355 /// ICVReplacementValuesMap
2356 Instruction *Entry = &F->getEntryBlock().front();
2357 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Val: Entry))
2358 ValuesMap.insert(KV: std::make_pair(x&: Entry, y: nullptr));
2359 }
2360
2361 return HasChanged;
2362 }
2363
2364 /// Helper to check if \p I is a call and get the value for it if it is
2365 /// unique.
2366 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2367 InternalControlVar &ICV) const {
2368
2369 const auto *CB = dyn_cast<CallBase>(Val: &I);
2370 if (!CB || CB->hasFnAttr(Kind: "no_openmp") ||
2371 CB->hasFnAttr(Kind: "no_openmp_routines"))
2372 return std::nullopt;
2373
2374 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2375 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2376 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2377 Function *CalledFunction = CB->getCalledFunction();
2378
2379 // Indirect call, assume ICV changes.
2380 if (CalledFunction == nullptr)
2381 return nullptr;
2382 if (CalledFunction == GetterRFI.Declaration)
2383 return std::nullopt;
2384 if (CalledFunction == SetterRFI.Declaration) {
2385 if (ICVReplacementValuesMap[ICV].count(Val: &I))
2386 return ICVReplacementValuesMap[ICV].lookup(Val: &I);
2387
2388 return nullptr;
2389 }
2390
2391 // Since we don't know, assume it changes the ICV.
2392 if (CalledFunction->isDeclaration())
2393 return nullptr;
2394
2395 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2396 QueryingAA: *this, IRP: IRPosition::callsite_returned(CB: *CB), DepClass: DepClassTy::REQUIRED);
2397
2398 if (ICVTrackingAA->isAssumedTracked()) {
2399 std::optional<Value *> URV =
2400 ICVTrackingAA->getUniqueReplacementValue(ICV);
2401 if (!URV || (*URV && AA::isValidAtPosition(VAC: AA::ValueAndContext(**URV, I),
2402 InfoCache&: OMPInfoCache)))
2403 return URV;
2404 }
2405
2406 // If we don't know, assume it changes.
2407 return nullptr;
2408 }
2409
2410 // We don't check unique value for a function, so return std::nullopt.
2411 std::optional<Value *>
2412 getUniqueReplacementValue(InternalControlVar ICV) const override {
2413 return std::nullopt;
2414 }
2415
2416 /// Return the value with which \p I can be replaced for specific \p ICV.
2417 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2418 const Instruction *I,
2419 Attributor &A) const override {
2420 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2421 if (ValuesMap.count(Val: I))
2422 return ValuesMap.lookup(Val: I);
2423
2424 SmallVector<const Instruction *, 16> Worklist;
2425 SmallPtrSet<const Instruction *, 16> Visited;
2426 Worklist.push_back(Elt: I);
2427
2428 std::optional<Value *> ReplVal;
2429
2430 while (!Worklist.empty()) {
2431 const Instruction *CurrInst = Worklist.pop_back_val();
2432 if (!Visited.insert(Ptr: CurrInst).second)
2433 continue;
2434
2435 const BasicBlock *CurrBB = CurrInst->getParent();
2436
2437 // Go up and look for all potential setters/calls that might change the
2438 // ICV.
2439 while ((CurrInst = CurrInst->getPrevNode())) {
2440 if (ValuesMap.count(Val: CurrInst)) {
2441 std::optional<Value *> NewReplVal = ValuesMap.lookup(Val: CurrInst);
2442 // Unknown value, track new.
2443 if (!ReplVal) {
2444 ReplVal = NewReplVal;
2445 break;
2446 }
2447
2448           // If we found a different value, we can't know the ICV value anymore.
2449 if (NewReplVal)
2450 if (ReplVal != NewReplVal)
2451 return nullptr;
2452
2453 break;
2454 }
2455
2456 std::optional<Value *> NewReplVal = getValueForCall(A, I: *CurrInst, ICV);
2457 if (!NewReplVal)
2458 continue;
2459
2460 // Unknown value, track new.
2461 if (!ReplVal) {
2462 ReplVal = NewReplVal;
2463 break;
2464 }
2465
2467         // We found a different value, so we can't know the ICV value anymore.
2468 if (ReplVal != NewReplVal)
2469 return nullptr;
2470 }
2471
2472 // If we are in the same BB and we have a value, we are done.
2473 if (CurrBB == I->getParent() && ReplVal)
2474 return ReplVal;
2475
2476 // Go through all predecessors and add terminators for analysis.
2477 for (const BasicBlock *Pred : predecessors(BB: CurrBB))
2478 if (const Instruction *Terminator = Pred->getTerminator())
2479 Worklist.push_back(Elt: Terminator);
2480 }
2481
2482 return ReplVal;
2483 }
2484};
2485
2486struct AAICVTrackerFunctionReturned : AAICVTracker {
2487 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2488 : AAICVTracker(IRP, A) {}
2489
2490 // FIXME: come up with better string.
2491 const std::string getAsStr(Attributor *) const override {
2492 return "ICVTrackerFunctionReturned";
2493 }
2494
2495 // FIXME: come up with some stats.
2496 void trackStatistics() const override {}
2497
2498 /// We don't manifest anything for this AA.
2499 ChangeStatus manifest(Attributor &A) override {
2500 return ChangeStatus::UNCHANGED;
2501 }
2502
2503   // Map of ICVs to their values at specific program points.
2504 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2505 InternalControlVar::ICV___last>
2506 ICVReplacementValuesMap;
2507
2508   /// Return the unique replacement value computed for \p ICV, if any.
2509 std::optional<Value *>
2510 getUniqueReplacementValue(InternalControlVar ICV) const override {
2511 return ICVReplacementValuesMap[ICV];
2512 }
2513
2514 ChangeStatus updateImpl(Attributor &A) override {
2515 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2516 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2517 QueryingAA: *this, IRP: IRPosition::function(F: *getAnchorScope()), DepClass: DepClassTy::REQUIRED);
2518
2519 if (!ICVTrackingAA->isAssumedTracked())
2520 return indicatePessimisticFixpoint();
2521
2522 for (InternalControlVar ICV : TrackableICVs) {
2523 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2524 std::optional<Value *> UniqueICVValue;
2525
2526 auto CheckReturnInst = [&](Instruction &I) {
2527 std::optional<Value *> NewReplVal =
2528 ICVTrackingAA->getReplacementValue(ICV, I: &I, A);
2529
2530 // If we found a second ICV value there is no unique returned value.
2531 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2532 return false;
2533
2534 UniqueICVValue = NewReplVal;
2535
2536 return true;
2537 };
2538
2539 bool UsedAssumedInformation = false;
2540 if (!A.checkForAllInstructions(Pred: CheckReturnInst, QueryingAA: *this, Opcodes: {Instruction::Ret},
2541 UsedAssumedInformation,
2542 /* CheckBBLivenessOnly */ true))
2543 UniqueICVValue = nullptr;
2544
2545 if (UniqueICVValue == ReplVal)
2546 continue;
2547
2548 ReplVal = UniqueICVValue;
2549 Changed = ChangeStatus::CHANGED;
2550 }
2551
2552 return Changed;
2553 }
2554};
2555
2556struct AAICVTrackerCallSite : AAICVTracker {
2557 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2558 : AAICVTracker(IRP, A) {}
2559
2560 void initialize(Attributor &A) override {
2561 assert(getAnchorScope() && "Expected anchor function");
2562
2563 // We only initialize this AA for getters, so we need to know which ICV it
2564 // gets.
2565 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2566 for (InternalControlVar ICV : TrackableICVs) {
2567 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2568 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2569 if (Getter.Declaration == getAssociatedFunction()) {
2570 AssociatedICV = ICVInfo.Kind;
2571 return;
2572 }
2573 }
2574
2575 /// Unknown ICV.
2576 indicatePessimisticFixpoint();
2577 }
2578
2579 ChangeStatus manifest(Attributor &A) override {
2580 if (!ReplVal || !*ReplVal)
2581 return ChangeStatus::UNCHANGED;
2582
2583 A.changeAfterManifest(IRP: IRPosition::inst(I: *getCtxI()), NV&: **ReplVal);
2584 A.deleteAfterManifest(I&: *getCtxI());
2585
2586 return ChangeStatus::CHANGED;
2587 }
2588
2589 // FIXME: come up with better string.
2590 const std::string getAsStr(Attributor *) const override {
2591 return "ICVTrackerCallSite";
2592 }
2593
2594 // FIXME: come up with some stats.
2595 void trackStatistics() const override {}
2596
2597 InternalControlVar AssociatedICV;
2598 std::optional<Value *> ReplVal;
2599
2600 ChangeStatus updateImpl(Attributor &A) override {
2601 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2602 QueryingAA: *this, IRP: IRPosition::function(F: *getAnchorScope()), DepClass: DepClassTy::REQUIRED);
2603
2604 // We don't have any information, so we assume it changes the ICV.
2605 if (!ICVTrackingAA->isAssumedTracked())
2606 return indicatePessimisticFixpoint();
2607
2608 std::optional<Value *> NewReplVal =
2609 ICVTrackingAA->getReplacementValue(ICV: AssociatedICV, I: getCtxI(), A);
2610
2611 if (ReplVal == NewReplVal)
2612 return ChangeStatus::UNCHANGED;
2613
2614 ReplVal = NewReplVal;
2615 return ChangeStatus::CHANGED;
2616 }
2617
2618   /// Return the value with which the associated value can be replaced for a
2619   /// specific \p ICV.
2620 std::optional<Value *>
2621 getUniqueReplacementValue(InternalControlVar ICV) const override {
2622 return ReplVal;
2623 }
2624};
2625
2626struct AAICVTrackerCallSiteReturned : AAICVTracker {
2627 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2628 : AAICVTracker(IRP, A) {}
2629
2630 // FIXME: come up with better string.
2631 const std::string getAsStr(Attributor *) const override {
2632 return "ICVTrackerCallSiteReturned";
2633 }
2634
2635 // FIXME: come up with some stats.
2636 void trackStatistics() const override {}
2637
2638 /// We don't manifest anything for this AA.
2639 ChangeStatus manifest(Attributor &A) override {
2640 return ChangeStatus::UNCHANGED;
2641 }
2642
2643   // Map of ICVs to their values at specific program points.
2644 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2645 InternalControlVar::ICV___last>
2646 ICVReplacementValuesMap;
2647
2648   /// Return the value with which the associated value can be replaced for a
2649   /// specific \p ICV.
2650 std::optional<Value *>
2651 getUniqueReplacementValue(InternalControlVar ICV) const override {
2652 return ICVReplacementValuesMap[ICV];
2653 }
2654
2655 ChangeStatus updateImpl(Attributor &A) override {
2656 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2657 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2658 QueryingAA: *this, IRP: IRPosition::returned(F: *getAssociatedFunction()),
2659 DepClass: DepClassTy::REQUIRED);
2660
2661 // We don't have any information, so we assume it changes the ICV.
2662 if (!ICVTrackingAA->isAssumedTracked())
2663 return indicatePessimisticFixpoint();
2664
2665 for (InternalControlVar ICV : TrackableICVs) {
2666 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2667 std::optional<Value *> NewReplVal =
2668 ICVTrackingAA->getUniqueReplacementValue(ICV);
2669
2670 if (ReplVal == NewReplVal)
2671 continue;
2672
2673 ReplVal = NewReplVal;
2674 Changed = ChangeStatus::CHANGED;
2675 }
2676 return Changed;
2677 }
2678};
2679
2680 /// Determines if \p BB exits the function unconditionally itself or reaches a
2681 /// block that does so through a chain of unique successors.
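/// An illustrative CFG sketch:
///
///   bb -> bb1 -> ret        ; true for bb, the successor chain is unique
///   bb -> { bb1, bb2 }      ; false for bb, there is no unique successor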
2682static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2683 if (succ_empty(BB))
2684 return true;
2685 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2686 if (!Successor)
2687 return false;
2688 return hasFunctionEndAsUniqueSuccessor(BB: Successor);
2689}
2690
2691struct AAExecutionDomainFunction : public AAExecutionDomain {
2692 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2693 : AAExecutionDomain(IRP, A) {}
2694
2695 ~AAExecutionDomainFunction() { delete RPOT; }
2696
2697 void initialize(Attributor &A) override {
2698 Function *F = getAnchorScope();
2699 assert(F && "Expected anchor function");
2700 RPOT = new ReversePostOrderTraversal<Function *>(F);
2701 }
2702
2703 const std::string getAsStr(Attributor *) const override {
2704 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2705 for (auto &It : BEDMap) {
2706 if (!It.getFirst())
2707 continue;
2708 TotalBlocks++;
2709 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2710 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2711 It.getSecond().IsReachingAlignedBarrierOnly;
2712 }
2713 return "[AAExecutionDomain] " + std::to_string(val: InitialThreadBlocks) + "/" +
2714 std::to_string(val: AlignedBlocks) + " of " +
2715 std::to_string(val: TotalBlocks) +
2716 " executed by initial thread / aligned";
2717 }
2718
2719 /// See AbstractAttribute::trackStatistics().
2720 void trackStatistics() const override {}
2721
2722 ChangeStatus manifest(Attributor &A) override {
2723 LLVM_DEBUG({
2724 for (const BasicBlock &BB : *getAnchorScope()) {
2725 if (!isExecutedByInitialThreadOnly(BB))
2726 continue;
2727 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2728 << BB.getName() << " is executed by a single thread.\n";
2729 }
2730 });
2731
2732 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2733
2734 if (DisableOpenMPOptBarrierElimination)
2735 return Changed;
2736
2737 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2738 auto HandleAlignedBarrier = [&](CallBase *CB) {
2739 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2740 if (!ED.IsReachedFromAlignedBarrierOnly ||
2741 ED.EncounteredNonLocalSideEffect)
2742 return;
2743 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2744 return;
2745
2746 // We can remove this barrier, if it is one, or aligned barriers reaching
2747 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2748 // end should only be removed if the kernel end is their unique successor;
2749 // otherwise, they may have side-effects that aren't accounted for in the
2750 // kernel end in their other successors. If those barriers have other
2751 // barriers reaching them, those can be transitively removed as well as
2752 // long as the kernel end is also their unique successor.
2753 if (CB) {
2754 DeletedBarriers.insert(Ptr: CB);
2755 A.deleteAfterManifest(I&: *CB);
2756 ++NumBarriersEliminated;
2757 Changed = ChangeStatus::CHANGED;
2758 } else if (!ED.AlignedBarriers.empty()) {
2759 Changed = ChangeStatus::CHANGED;
2760 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2761 ED.AlignedBarriers.end());
2762 SmallSetVector<CallBase *, 16> Visited;
2763 while (!Worklist.empty()) {
2764 CallBase *LastCB = Worklist.pop_back_val();
2765 if (!Visited.insert(X: LastCB))
2766 continue;
2767 if (LastCB->getFunction() != getAnchorScope())
2768 continue;
2769 if (!hasFunctionEndAsUniqueSuccessor(BB: LastCB->getParent()))
2770 continue;
2771 if (!DeletedBarriers.count(Ptr: LastCB)) {
2772 ++NumBarriersEliminated;
2773 A.deleteAfterManifest(I&: *LastCB);
2774 continue;
2775 }
2776           // The final aligned barrier (LastCB) reaching the kernel end was
2777           // removed already. This means we can go one step further and remove
2778           // the barriers encountered right before LastCB as well.
2779 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2780 Worklist.append(in_start: LastED.AlignedBarriers.begin(),
2781 in_end: LastED.AlignedBarriers.end());
2782 }
2783 }
2784
2785 // If we actually eliminated a barrier we need to eliminate the associated
2786 // llvm.assumes as well to avoid creating UB.
2787 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2788 for (auto *AssumeCB : ED.EncounteredAssumes)
2789 A.deleteAfterManifest(I&: *AssumeCB);
2790 };
2791
2792 for (auto *CB : AlignedBarriers)
2793 HandleAlignedBarrier(CB);
2794
2795 // Handle the "kernel end barrier" for kernels too.
2796 if (omp::isOpenMPKernel(Fn&: *getAnchorScope()))
2797 HandleAlignedBarrier(nullptr);
2798
2799 return Changed;
2800 }
2801
2802 bool isNoOpFence(const FenceInst &FI) const override {
2803 return getState().isValidState() && !NonNoOpFences.count(Ptr: &FI);
2804 }
2805
2806 /// Merge barrier and assumption information from \p PredED into the successor
2807 /// \p ED.
2808 void
2809 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2810 const ExecutionDomainTy &PredED);
2811
2812 /// Merge all information from \p PredED into the successor \p ED. If
2813 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2814 /// represented by \p ED from this predecessor.
2815 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2816 const ExecutionDomainTy &PredED,
2817 bool InitialEdgeOnly = false);
2818
2819 /// Accumulate information for the entry block in \p EntryBBED.
2820 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2821
2822 /// See AbstractAttribute::updateImpl.
2823 ChangeStatus updateImpl(Attributor &A) override;
2824
2825 /// Query interface, see AAExecutionDomain
2826 ///{
2827 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2828 if (!isValidState())
2829 return false;
2830 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2831 return BEDMap.lookup(Val: &BB).IsExecutedByInitialThreadOnly;
2832 }
2833
2834 bool isExecutedInAlignedRegion(Attributor &A,
2835 const Instruction &I) const override {
2836 assert(I.getFunction() == getAnchorScope() &&
2837 "Instruction is out of scope!");
2838 if (!isValidState())
2839 return false;
2840
2841 bool ForwardIsOk = true;
2842 const Instruction *CurI;
2843
2844 // Check forward until a call or the block end is reached.
2845 CurI = &I;
2846 do {
2847 auto *CB = dyn_cast<CallBase>(Val: CurI);
2848 if (!CB)
2849 continue;
2850 if (CB != &I && AlignedBarriers.contains(key: const_cast<CallBase *>(CB)))
2851 return true;
2852 const auto &It = CEDMap.find(Val: {CB, PRE});
2853 if (It == CEDMap.end())
2854 continue;
2855 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2856 ForwardIsOk = false;
2857 break;
2858 } while ((CurI = CurI->getNextNonDebugInstruction()));
2859
2860 if (!CurI && !BEDMap.lookup(Val: I.getParent()).IsReachingAlignedBarrierOnly)
2861 ForwardIsOk = false;
2862
2863 // Check backward until a call or the block beginning is reached.
2864 CurI = &I;
2865 do {
2866 auto *CB = dyn_cast<CallBase>(Val: CurI);
2867 if (!CB)
2868 continue;
2869 if (CB != &I && AlignedBarriers.contains(key: const_cast<CallBase *>(CB)))
2870 return true;
2871 const auto &It = CEDMap.find(Val: {CB, POST});
2872 if (It == CEDMap.end())
2873 continue;
2874 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2875 break;
2876 return false;
2877 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2878
2879 // Delayed decision on the forward pass to allow aligned barrier detection
2880 // in the backwards traversal.
2881 if (!ForwardIsOk)
2882 return false;
2883
2884 if (!CurI) {
2885 const BasicBlock *BB = I.getParent();
2886 if (BB == &BB->getParent()->getEntryBlock())
2887 return BEDMap.lookup(Val: nullptr).IsReachedFromAlignedBarrierOnly;
2888 if (!llvm::all_of(Range: predecessors(BB), P: [&](const BasicBlock *PredBB) {
2889 return BEDMap.lookup(Val: PredBB).IsReachedFromAlignedBarrierOnly;
2890 })) {
2891 return false;
2892 }
2893 }
2894
2895     // On neither traversal did we find anything but aligned barriers.
2896 return true;
2897 }
2898
2899 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2900 assert(isValidState() &&
2901 "No request should be made against an invalid state!");
2902 return BEDMap.lookup(Val: &BB);
2903 }
2904 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2905 getExecutionDomain(const CallBase &CB) const override {
2906 assert(isValidState() &&
2907 "No request should be made against an invalid state!");
2908 return {CEDMap.lookup(Val: {&CB, PRE}), CEDMap.lookup(Val: {&CB, POST})};
2909 }
2910 ExecutionDomainTy getFunctionExecutionDomain() const override {
2911 assert(isValidState() &&
2912 "No request should be made against an invalid state!");
2913 return InterProceduralED;
2914 }
2915 ///}
2916
2917 // Check if the edge into the successor block contains a condition that only
2918 // lets the main thread execute it.
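  // An illustrative IR pattern that is recognized (label names made up):
  //
  //   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  //   %cmp = icmp eq i32 %tid, 0
  //   br i1 %cmp, label %main.only, label %others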
2919 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2920 BasicBlock &SuccessorBB) {
2921 if (!Edge || !Edge->isConditional())
2922 return false;
2923 if (Edge->getSuccessor(i: 0) != &SuccessorBB)
2924 return false;
2925
2926 auto *Cmp = dyn_cast<CmpInst>(Val: Edge->getCondition());
2927 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2928 return false;
2929
2930 ConstantInt *C = dyn_cast<ConstantInt>(Val: Cmp->getOperand(i_nocapture: 1));
2931 if (!C)
2932 return false;
2933
2934 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2935 if (C->isAllOnesValue()) {
2936 auto *CB = dyn_cast<CallBase>(Val: Cmp->getOperand(i_nocapture: 0));
2937 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2938 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2939 CB = CB ? OpenMPOpt::getCallIfRegularCall(V&: *CB, RFI: &RFI) : nullptr;
2940 if (!CB)
2941 return false;
2942 ConstantStruct *KernelEnvC =
2943 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB: CB);
2944 ConstantInt *ExecModeC =
2945 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2946 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2947 }
2948
2949 if (C->isZero()) {
2950 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2951 if (auto *II = dyn_cast<IntrinsicInst>(Val: Cmp->getOperand(i_nocapture: 0)))
2952 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2953 return true;
2954
2955 // Match: 0 == llvm.amdgcn.workitem.id.x()
2956 if (auto *II = dyn_cast<IntrinsicInst>(Val: Cmp->getOperand(i_nocapture: 0)))
2957 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2958 return true;
2959 }
2960
2961 return false;
2962   }
2963
2964 /// Mapping containing information about the function for other AAs.
2965 ExecutionDomainTy InterProceduralED;
2966
2967 enum Direction { PRE = 0, POST = 1 };
2968 /// Mapping containing information per block.
2969 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2970 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2971 CEDMap;
2972 SmallSetVector<CallBase *, 16> AlignedBarriers;
2973
2974 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2975
2976   /// Set \p R to \p V and report true if that changed \p R.
2977 static bool setAndRecord(bool &R, bool V) {
2978 bool Eq = (R == V);
2979 R = V;
2980 return !Eq;
2981 }
2982
2983   /// Collection of fences known to be non-no-op. All fences not in this set
2984   /// can be assumed no-op.
2985 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2986};
2987
2988void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2989 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2990 for (auto *EA : PredED.EncounteredAssumes)
2991 ED.addAssumeInst(A, AI&: *EA);
2992
2993 for (auto *AB : PredED.AlignedBarriers)
2994 ED.addAlignedBarrier(A, CB&: *AB);
2995}
2996
2997bool AAExecutionDomainFunction::mergeInPredecessor(
2998 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2999 bool InitialEdgeOnly) {
3000
3001 bool Changed = false;
3002 Changed |=
3003 setAndRecord(R&: ED.IsExecutedByInitialThreadOnly,
3004 V: InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3005 ED.IsExecutedByInitialThreadOnly));
3006
3007 Changed |= setAndRecord(R&: ED.IsReachedFromAlignedBarrierOnly,
3008 V: ED.IsReachedFromAlignedBarrierOnly &&
3009 PredED.IsReachedFromAlignedBarrierOnly);
3010 Changed |= setAndRecord(R&: ED.EncounteredNonLocalSideEffect,
3011 V: ED.EncounteredNonLocalSideEffect |
3012 PredED.EncounteredNonLocalSideEffect);
3013 // Do not track assumptions and barriers as part of Changed.
3014 if (ED.IsReachedFromAlignedBarrierOnly)
3015 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3016 else
3017 ED.clearAssumeInstAndAlignedBarriers();
3018 return Changed;
3019}
3020
3021bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3022 ExecutionDomainTy &EntryBBED) {
3023 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3024 auto PredForCallSite = [&](AbstractCallSite ACS) {
3025 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3026 QueryingAA: *this, IRP: IRPosition::function(F: *ACS.getInstruction()->getFunction()),
3027 DepClass: DepClassTy::OPTIONAL);
3028 if (!EDAA || !EDAA->getState().isValidState())
3029 return false;
3030 CallSiteEDs.emplace_back(
3031 Args: EDAA->getExecutionDomain(CB: *cast<CallBase>(Val: ACS.getInstruction())));
3032 return true;
3033 };
3034
3035 ExecutionDomainTy ExitED;
3036 bool AllCallSitesKnown;
3037 if (A.checkForAllCallSites(Pred: PredForCallSite, QueryingAA: *this,
3038 /* RequiresAllCallSites */ RequireAllCallSites: true,
3039 UsedAssumedInformation&: AllCallSitesKnown)) {
3040 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3041 mergeInPredecessor(A, ED&: EntryBBED, PredED: CSInED);
3042 ExitED.IsReachingAlignedBarrierOnly &=
3043 CSOutED.IsReachingAlignedBarrierOnly;
3044 }
3045
3046 } else {
3047 // We could not find all predecessors, so this is either a kernel or a
3048 // function with external linkage (or with some other weird uses).
3049 if (omp::isOpenMPKernel(Fn&: *getAnchorScope())) {
3050 EntryBBED.IsExecutedByInitialThreadOnly = false;
3051 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3052 EntryBBED.EncounteredNonLocalSideEffect = false;
3053 ExitED.IsReachingAlignedBarrierOnly = false;
3054 } else {
3055 EntryBBED.IsExecutedByInitialThreadOnly = false;
3056 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3057 EntryBBED.EncounteredNonLocalSideEffect = true;
3058 ExitED.IsReachingAlignedBarrierOnly = false;
3059 }
3060 }
3061
3062 bool Changed = false;
3063 auto &FnED = BEDMap[nullptr];
3064 Changed |= setAndRecord(R&: FnED.IsReachedFromAlignedBarrierOnly,
3065 V: FnED.IsReachedFromAlignedBarrierOnly &
3066 EntryBBED.IsReachedFromAlignedBarrierOnly);
3067 Changed |= setAndRecord(R&: FnED.IsReachingAlignedBarrierOnly,
3068 V: FnED.IsReachingAlignedBarrierOnly &
3069 ExitED.IsReachingAlignedBarrierOnly);
3070 Changed |= setAndRecord(R&: FnED.IsExecutedByInitialThreadOnly,
3071 V: EntryBBED.IsExecutedByInitialThreadOnly);
3072 return Changed;
3073}
3074
3075ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3076
3077 bool Changed = false;
3078
3079 // Helper to deal with an aligned barrier encountered during the forward
3080 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3081 // it was encountered.
3082 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3083 Changed |= AlignedBarriers.insert(X: &CB);
3084 // First, update the barrier ED kept in the separate CEDMap.
3085 auto &CallInED = CEDMap[{&CB, PRE}];
3086 Changed |= mergeInPredecessor(A, ED&: CallInED, PredED: ED);
3087 CallInED.IsReachingAlignedBarrierOnly = true;
3088 // Next adjust the ED we use for the traversal.
3089 ED.EncounteredNonLocalSideEffect = false;
3090 ED.IsReachedFromAlignedBarrierOnly = true;
3091 // Aligned barrier collection has to come last.
3092 ED.clearAssumeInstAndAlignedBarriers();
3093 ED.addAlignedBarrier(A, CB);
3094 auto &CallOutED = CEDMap[{&CB, POST}];
3095 Changed |= mergeInPredecessor(A, ED&: CallOutED, PredED: ED);
3096 };
3097
3098 auto *LivenessAA =
3099 A.getAAFor<AAIsDead>(QueryingAA: *this, IRP: getIRPosition(), DepClass: DepClassTy::OPTIONAL);
3100
3101 Function *F = getAnchorScope();
3102 BasicBlock &EntryBB = F->getEntryBlock();
3103 bool IsKernel = omp::isOpenMPKernel(Fn&: *F);
3104
3105 SmallVector<Instruction *> SyncInstWorklist;
3106 for (auto &RIt : *RPOT) {
3107 BasicBlock &BB = *RIt;
3108
3109 bool IsEntryBB = &BB == &EntryBB;
3110 // TODO: We use local reasoning since we don't have a divergence analysis
3111 // running as well. We could basically allow uniform branches here.
3112 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3113 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3114 ExecutionDomainTy ED;
3115 // Propagate "incoming edges" into information about this block.
3116 if (IsEntryBB) {
3117 Changed |= handleCallees(A, EntryBBED&: ED);
3118 } else {
3119 // For live non-entry blocks we only propagate
3120 // information via live edges.
3121 if (LivenessAA && LivenessAA->isAssumedDead(BB: &BB))
3122 continue;
3123
3124 for (auto *PredBB : predecessors(BB: &BB)) {
3125 if (LivenessAA && LivenessAA->isEdgeDead(From: PredBB, To: &BB))
3126 continue;
3127 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3128 A, Edge: dyn_cast<BranchInst>(Val: PredBB->getTerminator()), SuccessorBB&: BB);
3129 mergeInPredecessor(A, ED, PredED: BEDMap[PredBB], InitialEdgeOnly);
3130 }
3131 }
3132
3133 // Now we traverse the block, accumulate effects in ED and attach
3134 // information to calls.
3135 for (Instruction &I : BB) {
3136 bool UsedAssumedInformation;
3137 if (A.isAssumedDead(I, QueryingAA: *this, LivenessAA, UsedAssumedInformation,
3138 /* CheckBBLivenessOnly */ false, DepClass: DepClassTy::OPTIONAL,
3139 /* CheckForDeadStore */ true))
3140 continue;
3141
3142 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3143 // first; the former are collected, the latter are ignored.
3144 if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) {
3145 if (auto *AI = dyn_cast_or_null<AssumeInst>(Val: II)) {
3146 ED.addAssumeInst(A, AI&: *AI);
3147 continue;
3148 }
3149 // TODO: Should we also collect and delete lifetime markers?
3150 if (II->isAssumeLikeIntrinsic())
3151 continue;
3152 }
3153
3154 if (auto *FI = dyn_cast<FenceInst>(Val: &I)) {
3155 if (!ED.EncounteredNonLocalSideEffect) {
3156 // An aligned fence without non-local side-effects is a no-op.
3157 if (ED.IsReachedFromAlignedBarrierOnly)
3158 continue;
3159 // A non-aligned fence without non-local side-effects is a no-op
3160 // if the ordering only publishes non-local side-effects (or less).
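            // Illustrative example: with no prior non-local side effect, a
            // `fence release` has nothing to publish and is skipped below,
            // while a `fence acquire` or `fence seq_cst` may still order
            // later accesses and is recorded in NonNoOpFences.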
3161 switch (FI->getOrdering()) {
3162 case AtomicOrdering::NotAtomic:
3163 continue;
3164 case AtomicOrdering::Unordered:
3165 continue;
3166 case AtomicOrdering::Monotonic:
3167 continue;
3168 case AtomicOrdering::Acquire:
3169 break;
3170 case AtomicOrdering::Release:
3171 continue;
3172 case AtomicOrdering::AcquireRelease:
3173 break;
3174 case AtomicOrdering::SequentiallyConsistent:
3175 break;
3176 };
3177 }
3178 NonNoOpFences.insert(Ptr: FI);
3179 }
3180
3181 auto *CB = dyn_cast<CallBase>(Val: &I);
3182 bool IsNoSync = AA::isNoSyncInst(A, I, QueryingAA: *this);
3183 bool IsAlignedBarrier =
3184 !IsNoSync && CB &&
3185 AANoSync::isAlignedBarrier(CB: *CB, ExecutedAligned: AlignedBarrierLastInBlock);
3186
3187 AlignedBarrierLastInBlock &= IsNoSync;
3188 IsExplicitlyAligned &= IsNoSync;
3189
3190 // Next we check for calls. Aligned barriers are handled explicitly;
3191 // everything else is kept for the backward traversal and will also
3192 // affect our state.
3193 if (CB) {
3194 if (IsAlignedBarrier) {
3195 HandleAlignedBarrier(*CB, ED);
3196 AlignedBarrierLastInBlock = true;
3197 IsExplicitlyAligned = true;
3198 continue;
3199 }
3200
3201 // Check the pointer(s) of a memory intrinsic explicitly.
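          // (E.g., an llvm.memcpy whose pointer operands refer only to memory
          // that other threads cannot observe is not treated as a non-local
          // side effect; a memory intrinsic that may synchronize still ends
          // the aligned-barrier region and is queued for the backward
          // traversal.)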
3202 if (isa<MemIntrinsic>(Val: &I)) {
3203 if (!ED.EncounteredNonLocalSideEffect &&
3204 AA::isPotentiallyAffectedByBarrier(A, I, QueryingAA: *this))
3205 ED.EncounteredNonLocalSideEffect = true;
3206 if (!IsNoSync) {
3207 ED.IsReachedFromAlignedBarrierOnly = false;
3208 SyncInstWorklist.push_back(Elt: &I);
3209 }
3210 continue;
3211 }
3212
3213 // Record how we entered the call, then accumulate the effect of the
3214 // call in ED for potential use by the callee.
3215 auto &CallInED = CEDMap[{CB, PRE}];
3216 Changed |= mergeInPredecessor(A, ED&: CallInED, PredED: ED);
3217
3218 // If we have a sync-definition we can check if it starts/ends in an
3219 // aligned barrier. If we are unsure we assume any sync breaks
3220 // alignment.
3221 Function *Callee = CB->getCalledFunction();
3222 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3223 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3224 QueryingAA: *this, IRP: IRPosition::function(F: *Callee), DepClass: DepClassTy::OPTIONAL);
3225 if (EDAA && EDAA->getState().isValidState()) {
3226 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3227 ED.IsReachedFromAlignedBarrierOnly =
3228 CalleeED.IsReachedFromAlignedBarrierOnly;
3229 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3230 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3231 ED.EncounteredNonLocalSideEffect |=
3232 CalleeED.EncounteredNonLocalSideEffect;
3233 else
3234 ED.EncounteredNonLocalSideEffect =
3235 CalleeED.EncounteredNonLocalSideEffect;
3236 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3237 Changed |=
3238 setAndRecord(R&: CallInED.IsReachingAlignedBarrierOnly, V: false);
3239 SyncInstWorklist.push_back(Elt: &I);
3240 }
3241 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3242 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED: CalleeED);
3243 auto &CallOutED = CEDMap[{CB, POST}];
3244 Changed |= mergeInPredecessor(A, ED&: CallOutED, PredED: ED);
3245 continue;
3246 }
3247 }
3248 if (!IsNoSync) {
3249 ED.IsReachedFromAlignedBarrierOnly = false;
3250 Changed |= setAndRecord(R&: CallInED.IsReachingAlignedBarrierOnly, V: false);
3251 SyncInstWorklist.push_back(Elt: &I);
3252 }
3253 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3254 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3255 auto &CallOutED = CEDMap[{CB, POST}];
3256 Changed |= mergeInPredecessor(A, ED&: CallOutED, PredED: ED);
3257 }
3258
3259 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3260 continue;
3261
3262 // If we have a callee we try to use fine-grained information to
3263 // determine local side-effects.
3264 if (CB) {
3265 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3266 QueryingAA: *this, IRP: IRPosition::callsite_function(CB: *CB), DepClass: DepClassTy::OPTIONAL);
3267
3268 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3269 AAMemoryLocation::AccessKind,
3270 AAMemoryLocation::MemoryLocationsKind) {
3271 return !AA::isPotentiallyAffectedByBarrier(A, Ptrs: {Ptr}, QueryingAA: *this, CtxI: I);
3272 };
3273 if (MemAA && MemAA->getState().isValidState() &&
3274 MemAA->checkForAllAccessesToMemoryKind(
3275 Pred: AccessPred, MLK: AAMemoryLocation::ALL_LOCATIONS))
3276 continue;
3277 }
3278
3279 auto &InfoCache = A.getInfoCache();
3280 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3281 continue;
3282
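      // Loads tagged !invariant_load always observe the same value, so a
      // barrier cannot change what they read and they need no further
      // handling here.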
3283 if (auto *LI = dyn_cast<LoadInst>(Val: &I))
3284 if (LI->hasMetadata(KindID: LLVMContext::MD_invariant_load))
3285 continue;
3286
3287 if (!ED.EncounteredNonLocalSideEffect &&
3288 AA::isPotentiallyAffectedByBarrier(A, I, QueryingAA: *this))
3289 ED.EncounteredNonLocalSideEffect = true;
3290 }
3291
3292 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3293 if (!isa<UnreachableInst>(Val: BB.getTerminator()) &&
3294 !BB.getTerminator()->getNumSuccessors()) {
3295
3296 Changed |= mergeInPredecessor(A, ED&: InterProceduralED, PredED: ED);
3297
3298 auto &FnED = BEDMap[nullptr];
3299 if (IsKernel && !IsExplicitlyAligned)
3300 FnED.IsReachingAlignedBarrierOnly = false;
3301 Changed |= mergeInPredecessor(A, ED&: FnED, PredED: ED);
3302
3303 if (!FnED.IsReachingAlignedBarrierOnly) {
3304 IsEndAndNotReachingAlignedBarriersOnly = true;
3305 SyncInstWorklist.push_back(Elt: BB.getTerminator());
3306 auto &BBED = BEDMap[&BB];
3307 Changed |= setAndRecord(R&: BBED.IsReachingAlignedBarrierOnly, V: false);
3308 }
3309 }
3310
3311 ExecutionDomainTy &StoredED = BEDMap[&BB];
3312 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3313 !IsEndAndNotReachingAlignedBarriersOnly;
3314
3315 // Check if we computed anything different as part of the forward
3316 // traversal. We do not take assumptions and aligned barriers into account
3317 // as they do not influence the state we iterate. Backward traversal values
3318 // are handled later on.
3319 if (ED.IsExecutedByInitialThreadOnly !=
3320 StoredED.IsExecutedByInitialThreadOnly ||
3321 ED.IsReachedFromAlignedBarrierOnly !=
3322 StoredED.IsReachedFromAlignedBarrierOnly ||
3323 ED.EncounteredNonLocalSideEffect !=
3324 StoredED.EncounteredNonLocalSideEffect)
3325 Changed = true;
3326
3327 // Update the state with the new value.
3328 StoredED = std::move(ED);
3329 }
3330
3331 // Propagate the effects of (non-aligned) sync instructions backwards until
3332 // we hit the entry block or an aligned barrier.
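  // In effect, any non-aligned synchronization clears
  // IsReachingAlignedBarrierOnly for every call and block that can reach it,
  // stopping at aligned barriers and at calls already known not to reach an
  // aligned barrier.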
3333 SmallSetVector<BasicBlock *, 16> Visited;
3334 while (!SyncInstWorklist.empty()) {
3335 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3336 Instruction *CurInst = SyncInst;
3337 bool HitAlignedBarrierOrKnownEnd = false;
3338 while ((CurInst = CurInst->getPrevNode())) {
3339 auto *CB = dyn_cast<CallBase>(Val: CurInst);
3340 if (!CB)
3341 continue;
3342 auto &CallOutED = CEDMap[{CB, POST}];
3343 Changed |= setAndRecord(R&: CallOutED.IsReachingAlignedBarrierOnly, V: false);
3344 auto &CallInED = CEDMap[{CB, PRE}];
3345 HitAlignedBarrierOrKnownEnd =
3346 AlignedBarriers.count(key: CB) || !CallInED.IsReachingAlignedBarrierOnly;
3347 if (HitAlignedBarrierOrKnownEnd)
3348 break;
3349 Changed |= setAndRecord(R&: CallInED.IsReachingAlignedBarrierOnly, V: false);
3350 }
3351 if (HitAlignedBarrierOrKnownEnd)
3352 continue;
3353 BasicBlock *SyncBB = SyncInst->getParent();
3354 for (auto *PredBB : predecessors(BB: SyncBB)) {
3355 if (LivenessAA && LivenessAA->isEdgeDead(From: PredBB, To: SyncBB))
3356 continue;
3357 if (!Visited.insert(X: PredBB))
3358 continue;
3359 auto &PredED = BEDMap[PredBB];
3360 if (setAndRecord(R&: PredED.IsReachingAlignedBarrierOnly, V: false)) {
3361 Changed = true;
3362 SyncInstWorklist.push_back(Elt: PredBB->getTerminator());
3363 }
3364 }
3365 if (SyncBB != &EntryBB)
3366 continue;
3367 Changed |=
3368 setAndRecord(R&: InterProceduralED.IsReachingAlignedBarrierOnly, V: false);
3369 }
3370
3371 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3372}
3373
3374/// Try to replace memory allocation calls called by a single thread with a
3375/// static buffer of shared memory.
3376struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3377 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3378 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3379
3380 /// Create an abstract attribute view for the position \p IRP.
3381 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3382 Attributor &A);
3383
3384 /// Returns true if HeapToShared conversion is assumed to be possible.
3385 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3386
3387 /// Returns true if HeapToShared conversion is assumed and the CB is a
3388 /// callsite to a free operation to be removed.
3389 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3390
3391 /// See AbstractAttribute::getName().
3392 const std::string getName() const override { return "AAHeapToShared"; }
3393
3394 /// See AbstractAttribute::getIdAddr().
3395 const char *getIdAddr() const override { return &ID; }
3396
3397 /// This function should return true if the type of the \p AA is
3398 /// AAHeapToShared.
3399 static bool classof(const AbstractAttribute *AA) {
3400 return (AA->getIdAddr() == &ID);
3401 }
3402
3403 /// Unique ID (due to the unique address)
3404 static const char ID;
3405};
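// As an illustration (not taken from a specific test), the rewrite implemented
// by AAHeapToSharedFunction::manifest below turns
//
//   %p = call ptr @__kmpc_alloc_shared(i64 4)
//   ...
//   call void @__kmpc_free_shared(ptr %p, i64 4)
//
// into a static buffer in the GPU shared address space (address space 3 on the
// supported targets), assuming the allocation has a constant size and is only
// executed by the initial thread:
//
//   @p_shared = internal addrspace(3) global [4 x i8] poison
//
// All uses of %p are rewritten to (a pointer cast of) @p_shared and both
// runtime calls are deleted.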
3406
3407struct AAHeapToSharedFunction : public AAHeapToShared {
3408 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3409 : AAHeapToShared(IRP, A) {}
3410
3411 const std::string getAsStr(Attributor *) const override {
3412 return "[AAHeapToShared] " + std::to_string(val: MallocCalls.size()) +
3413 " malloc calls eligible.";
3414 }
3415
3416 /// See AbstractAttribute::trackStatistics().
3417 void trackStatistics() const override {}
3418
3419 /// This function finds free calls that will be removed by the
3420 /// HeapToShared transformation.
3421 void findPotentialRemovedFreeCalls(Attributor &A) {
3422 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3423 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3424
3425 PotentialRemovedFreeCalls.clear();
3426 // Update free call users of found malloc calls.
3427 for (CallBase *CB : MallocCalls) {
3428 SmallVector<CallBase *, 4> FreeCalls;
3429 for (auto *U : CB->users()) {
3430 CallBase *C = dyn_cast<CallBase>(Val: U);
3431 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3432 FreeCalls.push_back(Elt: C);
3433 }
3434
3435 if (FreeCalls.size() != 1)
3436 continue;
3437
3438 PotentialRemovedFreeCalls.insert(Ptr: FreeCalls.front());
3439 }
3440 }
3441
3442 void initialize(Attributor &A) override {
3443 if (DisableOpenMPOptDeglobalization) {
3444 indicatePessimisticFixpoint();
3445 return;
3446 }
3447
3448 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3449 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3450 if (!RFI.Declaration)
3451 return;
3452
3453 Attributor::SimplifictionCallbackTy SCB =
3454 [](const IRPosition &, const AbstractAttribute *,
3455 bool &) -> std::optional<Value *> { return nullptr; };
3456
3457 Function *F = getAnchorScope();
3458 for (User *U : RFI.Declaration->users())
3459 if (CallBase *CB = dyn_cast<CallBase>(Val: U)) {
3460 if (CB->getFunction() != F)
3461 continue;
3462 MallocCalls.insert(X: CB);
3463 A.registerSimplificationCallback(IRP: IRPosition::callsite_returned(CB: *CB),
3464 CB: SCB);
3465 }
3466
3467 findPotentialRemovedFreeCalls(A);
3468 }
3469
3470 bool isAssumedHeapToShared(CallBase &CB) const override {
3471 return isValidState() && MallocCalls.count(key: &CB);
3472 }
3473
3474 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3475 return isValidState() && PotentialRemovedFreeCalls.count(Ptr: &CB);
3476 }
3477
3478 ChangeStatus manifest(Attributor &A) override {
3479 if (MallocCalls.empty())
3480 return ChangeStatus::UNCHANGED;
3481
3482 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3483 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3484
3485 Function *F = getAnchorScope();
3486 auto *HS = A.lookupAAFor<AAHeapToStack>(IRP: IRPosition::function(F: *F), QueryingAA: this,
3487 DepClass: DepClassTy::OPTIONAL);
3488
3489 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3490 for (CallBase *CB : MallocCalls) {
3491 // Skip replacing this if HeapToStack has already claimed it.
3492 if (HS && HS->isAssumedHeapToStack(CB: *CB))
3493 continue;
3494
3495 // Find the unique free call to remove it.
3496 SmallVector<CallBase *, 4> FreeCalls;
3497 for (auto *U : CB->users()) {
3498 CallBase *C = dyn_cast<CallBase>(Val: U);
3499 if (C && C->getCalledFunction() == FreeCall.Declaration)
3500 FreeCalls.push_back(Elt: C);
3501 }
3502 if (FreeCalls.size() != 1)
3503 continue;
3504
3505 auto *AllocSize = cast<ConstantInt>(Val: CB->getArgOperand(i: 0));
3506
3507 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3508 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3509 << " with shared memory."
3510 << " Shared memory usage is limited to "
3511 << SharedMemoryLimit << " bytes\n");
3512 continue;
3513 }
3514
3515 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3516 << " with " << AllocSize->getZExtValue()
3517 << " bytes of shared memory\n");
3518
3519 // Create a new shared memory buffer of the same size as the allocation
3520 // and replace all the uses of the original allocation with it.
3521 Module *M = CB->getModule();
3522 Type *Int8Ty = Type::getInt8Ty(C&: M->getContext());
3523 Type *Int8ArrTy = ArrayType::get(ElementType: Int8Ty, NumElements: AllocSize->getZExtValue());
3524 auto *SharedMem = new GlobalVariable(
3525 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3526 PoisonValue::get(T: Int8ArrTy), CB->getName() + "_shared", nullptr,
3527 GlobalValue::NotThreadLocal,
3528 static_cast<unsigned>(AddressSpace::Shared));
3529 auto *NewBuffer =
3530 ConstantExpr::getPointerCast(C: SharedMem, Ty: Int8Ty->getPointerTo());
3531
3532 auto Remark = [&](OptimizationRemark OR) {
3533 return OR << "Replaced globalized variable with "
3534 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3535 << (AllocSize->isOne() ? " byte " : " bytes ")
3536 << "of shared memory.";
3537 };
3538 A.emitRemark<OptimizationRemark>(I: CB, RemarkName: "OMP111", RemarkCB&: Remark);
3539
3540 MaybeAlign Alignment = CB->getRetAlign();
3541 assert(Alignment &&
3542 "HeapToShared on allocation without alignment attribute");
3543 SharedMem->setAlignment(*Alignment);
3544
3545 A.changeAfterManifest(IRP: IRPosition::callsite_returned(CB: *CB), NV&: *NewBuffer);
3546 A.deleteAfterManifest(I&: *CB);
3547 A.deleteAfterManifest(I&: *FreeCalls.front());
3548
3549 SharedMemoryUsed += AllocSize->getZExtValue();
3550 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3551 Changed = ChangeStatus::CHANGED;
3552 }
3553
3554 return Changed;
3555 }
3556
3557 ChangeStatus updateImpl(Attributor &A) override {
3558 if (MallocCalls.empty())
3559 return indicatePessimisticFixpoint();
3560 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3561 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3562 if (!RFI.Declaration)
3563 return ChangeStatus::UNCHANGED;
3564
3565 Function *F = getAnchorScope();
3566
3567 auto NumMallocCalls = MallocCalls.size();
3568
3569 // Only consider malloc calls executed by a single thread with a constant size.
3570 for (User *U : RFI.Declaration->users()) {
3571 if (CallBase *CB = dyn_cast<CallBase>(Val: U)) {
3572 if (CB->getCaller() != F)
3573 continue;
3574 if (!MallocCalls.count(key: CB))
3575 continue;
3576 if (!isa<ConstantInt>(Val: CB->getArgOperand(i: 0))) {
3577 MallocCalls.remove(X: CB);
3578 continue;
3579 }
3580 const auto *ED = A.getAAFor<AAExecutionDomain>(
3581 QueryingAA: *this, IRP: IRPosition::function(F: *F), DepClass: DepClassTy::REQUIRED);
3582 if (!ED || !ED->isExecutedByInitialThreadOnly(I: *CB))
3583 MallocCalls.remove(X: CB);
3584 }
3585 }
3586
3587 findPotentialRemovedFreeCalls(A);
3588
3589 if (NumMallocCalls != MallocCalls.size())
3590 return ChangeStatus::CHANGED;
3591
3592 return ChangeStatus::UNCHANGED;
3593 }
3594
3595 /// Collection of all malloc calls in a function.
3596 SmallSetVector<CallBase *, 4> MallocCalls;
3597 /// Collection of potentially removed free calls in a function.
3598 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3599 /// The total amount of shared memory that has been used for HeapToShared.
3600 unsigned SharedMemoryUsed = 0;
3601};
3602
3603struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3604 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3605 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3606
3607 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3608 /// unknown callees.
3609 static bool requiresCalleeForCallBase() { return false; }
3610
3611 /// Statistics are tracked as part of manifest for now.
3612 void trackStatistics() const override {}
3613
3614 /// See AbstractAttribute::getAsStr()
3615 const std::string getAsStr(Attributor *) const override {
3616 if (!isValidState())
3617 return "<invalid>";
3618 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3619 : "generic") +
3620 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3621 : "") +
3622 std::string(" #PRs: ") +
3623 (ReachedKnownParallelRegions.isValidState()
3624 ? std::to_string(val: ReachedKnownParallelRegions.size())
3625 : "<invalid>") +
3626 ", #Unknown PRs: " +
3627 (ReachedUnknownParallelRegions.isValidState()
3628 ? std::to_string(val: ReachedUnknownParallelRegions.size())
3629 : "<invalid>") +
3630 ", #Reaching Kernels: " +
3631 (ReachingKernelEntries.isValidState()
3632 ? std::to_string(val: ReachingKernelEntries.size())
3633 : "<invalid>") +
3634 ", #ParLevels: " +
3635 (ParallelLevels.isValidState()
3636 ? std::to_string(val: ParallelLevels.size())
3637 : "<invalid>") +
3638 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3639 }
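  // For a valid state this yields strings of the form (illustrative):
  //   "generic #PRs: 1, #Unknown PRs: 0, #Reaching Kernels: 1, #ParLevels: 0,
  //   NestedPar: no"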
3640
3641 /// Create an abstract attribute view for the position \p IRP.
3642 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3643
3644 /// See AbstractAttribute::getName()
3645 const std::string getName() const override { return "AAKernelInfo"; }
3646
3647 /// See AbstractAttribute::getIdAddr()
3648 const char *getIdAddr() const override { return &ID; }
3649
3650 /// This function should return true if the type of the \p AA is AAKernelInfo
3651 static bool classof(const AbstractAttribute *AA) {
3652 return (AA->getIdAddr() == &ID);
3653 }
3654
3655 static const char ID;
3656};
3657
3658/// The function kernel info abstract attribute; basically, what can we say
3659/// about a function with regard to the KernelInfoState.
3660struct AAKernelInfoFunction : AAKernelInfo {
3661 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3662 : AAKernelInfo(IRP, A) {}
3663
3664 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3665
3666 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3667 return GuardedInstructions;
3668 }
3669
3670 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3671 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3672 Agg: KernelEnvC, Val: ConfigC, Idxs: {KernelInfo::ConfigurationIdx});
3673 assert(NewKernelEnvC && "Failed to create new kernel environment");
3674 KernelEnvC = cast<ConstantStruct>(Val: NewKernelEnvC);
3675 }
3676
3677#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3678 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3679 ConstantStruct *ConfigC = \
3680 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3681 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3682 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3683 assert(NewConfigC && "Failed to create new configuration environment"); \
3684 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3685 }
3686
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3691 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3692 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3694
3695#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
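// For reference, KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) above
// expands (roughly) to:
//
//   void setExecModeOfKernelEnvironment(ConstantInt *NewVal) {
//     ConstantStruct *ConfigC =
//         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);
//     Constant *NewConfigC = ConstantFoldInsertValueInstruction(
//         ConfigC, NewVal, {KernelInfo::ExecModeIdx});
//     assert(NewConfigC && "Failed to create new configuration environment");
//     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));
//   }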
3696
3697 /// See AbstractAttribute::initialize(...).
3698 void initialize(Attributor &A) override {
3699 // This is a high-level transform that might change the constant arguments
3700 // of the init and deinit calls. We need to tell the Attributor about this
3701 // to avoid other parts using the current constant value for simplification.
3702 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3703
3704 Function *Fn = getAnchorScope();
3705
3706 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3707 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3708 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3709 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3710
3711 // For kernels we perform more initialization work; first we find the init
3712 // and deinit calls.
3713 auto StoreCallBase = [](Use &U,
3714 OMPInformationCache::RuntimeFunctionInfo &RFI,
3715 CallBase *&Storage) {
3716 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, RFI: &RFI);
3717 assert(CB &&
3718 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3719 assert(!Storage &&
3720 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3721 Storage = CB;
3722 return false;
3723 };
3724 InitRFI.foreachUse(
3725 CB: [&](Use &U, Function &) {
3726 StoreCallBase(U, InitRFI, KernelInitCB);
3727 return false;
3728 },
3729 F: Fn);
3730 DeinitRFI.foreachUse(
3731 CB: [&](Use &U, Function &) {
3732 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3733 return false;
3734 },
3735 F: Fn);
3736
3737 // Ignore kernels without initializers such as global constructors.
3738 if (!KernelInitCB || !KernelDeinitCB)
3739 return;
3740
3741 // Add itself to the reaching kernel entries and set IsKernelEntry.
3742 ReachingKernelEntries.insert(Elem: Fn);
3743 IsKernelEntry = true;
3744
3745 KernelEnvC =
3746 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3747 GlobalVariable *KernelEnvGV =
3748 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3749
3750 Attributor::GlobalVariableSimplifictionCallbackTy
3751 KernelConfigurationSimplifyCB =
3752 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3753 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3754 if (!isAtFixpoint()) {
3755 if (!AA)
3756 return nullptr;
3757 UsedAssumedInformation = true;
3758 A.recordDependence(FromAA: *this, ToAA: *AA, DepClass: DepClassTy::OPTIONAL);
3759 }
3760 return KernelEnvC;
3761 };
3762
3763 A.registerGlobalVariableSimplificationCallback(
3764 GV: *KernelEnvGV, CB: KernelConfigurationSimplifyCB);
3765
3766 // Check if we know we are in SPMD-mode already.
3767 ConstantInt *ExecModeC =
3768 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3769 ConstantInt *AssumedExecModeC = ConstantInt::get(
3770 Ty: ExecModeC->getIntegerType(),
3771 V: ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3772 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3773 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3774 else if (DisableOpenMPOptSPMDization)
3775 // This is a generic region but SPMDization is disabled so stop
3776 // tracking.
3777 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3778 else
3779 setExecModeOfKernelEnvironment(AssumedExecModeC);
3780
3781 const Triple T(Fn->getParent()->getTargetTriple());
3782 auto *Int32Ty = Type::getInt32Ty(C&: Fn->getContext());
3783 auto [MinThreads, MaxThreads] =
3784 OpenMPIRBuilder::readThreadBoundsForKernel(T, Kernel&: *Fn);
3785 if (MinThreads)
3786 setMinThreadsOfKernelEnvironment(ConstantInt::get(Ty: Int32Ty, V: MinThreads));
3787 if (MaxThreads)
3788 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Ty: Int32Ty, V: MaxThreads));
3789 auto [MinTeams, MaxTeams] =
3790 OpenMPIRBuilder::readTeamBoundsForKernel(T, Kernel&: *Fn);
3791 if (MinTeams)
3792 setMinTeamsOfKernelEnvironment(ConstantInt::get(Ty: Int32Ty, V: MinTeams));
3793 if (MaxTeams)
3794 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Ty: Int32Ty, V: MaxTeams));
3795
3796 ConstantInt *MayUseNestedParallelismC =
3797 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3798 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3799 Ty: MayUseNestedParallelismC->getIntegerType(), V: NestedParallelism);
3800 setMayUseNestedParallelismOfKernelEnvironment(
3801 AssumedMayUseNestedParallelismC);
3802
3803 if (!DisableOpenMPOptStateMachineRewrite) {
3804 ConstantInt *UseGenericStateMachineC =
3805 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3806 KernelEnvC);
3807 ConstantInt *AssumedUseGenericStateMachineC =
3808 ConstantInt::get(Ty: UseGenericStateMachineC->getIntegerType(), V: false);
3809 setUseGenericStateMachineOfKernelEnvironment(
3810 AssumedUseGenericStateMachineC);
3811 }
3812
3813 // Register virtual uses of functions we might need to preserve.
3814 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3815 Attributor::VirtualUseCallbackTy &CB) {
3816 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3817 return;
3818 A.registerVirtualUseCallback(V: *OMPInfoCache.RFIs[RFKind].Declaration, CB);
3819 };
3820
3821 // Add a dependence to ensure updates if the state changes.
3822 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3823 const AbstractAttribute *QueryingAA) {
3824 if (QueryingAA) {
3825 A.recordDependence(FromAA: *KI, ToAA: *QueryingAA, DepClass: DepClassTy::OPTIONAL);
3826 }
3827 return true;
3828 };
3829
3830 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3831 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3832 // Whenever we create a custom state machine we will insert calls to
3833 // __kmpc_get_hardware_num_threads_in_block,
3834 // __kmpc_get_warp_size,
3835 // __kmpc_barrier_simple_generic,
3836 // __kmpc_kernel_parallel, and
3837 // __kmpc_kernel_end_parallel.
3838 // Not needed if we are on track for SPMDzation.
3839 if (SPMDCompatibilityTracker.isValidState())
3840 return AddDependence(A, this, QueryingAA);
3841 // Not needed if we can't rewrite due to an invalid state.
3842 if (!ReachedKnownParallelRegions.isValidState())
3843 return AddDependence(A, this, QueryingAA);
3844 return false;
3845 };
3846
3847 // Not needed before the device runtime has been merged in (pre-runtime merge).
3848 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3849 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3850 CustomStateMachineUseCB);
3851 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3852 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3853 CustomStateMachineUseCB);
3854 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3855 CustomStateMachineUseCB);
3856 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3857 CustomStateMachineUseCB);
3858 }
3859
3860 // If we do not perform SPMDzation we do not need the virtual uses below.
3861 if (SPMDCompatibilityTracker.isAtFixpoint())
3862 return;
3863
3864 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3865 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3866 // Whenever we perform SPMDzation we will insert
3867 // __kmpc_get_hardware_thread_id_in_block calls.
3868 if (!SPMDCompatibilityTracker.isValidState())
3869 return AddDependence(A, this, QueryingAA);
3870 return false;
3871 };
3872 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3873 HWThreadIdUseCB);
3874
3875 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3876 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3877 // Whenever we perform SPMDzation with guarding we will insert
3878 // __kmpc_barrier_simple_spmd calls. If SPMDzation failed, if there is
3879 // nothing to guard, or if there are no parallel regions, we don't need
3880 // the calls.
3881 if (!SPMDCompatibilityTracker.isValidState())
3882 return AddDependence(A, this, QueryingAA);
3883 if (SPMDCompatibilityTracker.empty())
3884 return AddDependence(A, this, QueryingAA);
3885 if (!mayContainParallelRegion())
3886 return AddDependence(A, this, QueryingAA);
3887 return false;
3888 };
3889 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3890 }
3891
3892 /// Sanitize the string \p S such that it is a suitable global symbol name.
3893 static std::string sanitizeForGlobalName(std::string S) {
3894 std::replace_if(
3895 first: S.begin(), last: S.end(),
3896 pred: [](const char C) {
3897 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3898 (C >= '0' && C <= '9') || C == '_');
3899 },
3900 new_value: '.');
3901 return S;
3902 }
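  // E.g., sanitizeForGlobalName("x bar[3]") yields "x.bar.3."; every character
  // that is not alphanumeric or '_' is replaced by '.'.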
3903
3904 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3905 /// finished now.
3906 ChangeStatus manifest(Attributor &A) override {
3907 // If we are not looking at a kernel with __kmpc_target_init and
3908 // __kmpc_target_deinit call we cannot actually manifest the information.
3909 if (!KernelInitCB || !KernelDeinitCB)
3910 return ChangeStatus::UNCHANGED;
3911
3912 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3913
3914 bool HasBuiltStateMachine = true;
3915 if (!changeToSPMDMode(A, Changed)) {
3916 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3917 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3918 else
3919 HasBuiltStateMachine = false;
3920 }
3921
3922 // We need to reset KernelEnvC if the state machine rewriting was not done.
3923 ConstantStruct *ExistingKernelEnvC =
3924 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3925 ConstantInt *OldUseGenericStateMachineVal =
3926 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3927 KernelEnvC: ExistingKernelEnvC);
3928 if (!HasBuiltStateMachine)
3929 setUseGenericStateMachineOfKernelEnvironment(
3930 OldUseGenericStateMachineVal);
3931
3932 // Finally, update the kernel environment global with KernelEnvC.
3933 GlobalVariable *KernelEnvGV =
3934 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3935 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3936 KernelEnvGV->setInitializer(KernelEnvC);
3937 Changed = ChangeStatus::CHANGED;
3938 }
3939
3940 return Changed;
3941 }
3942
3943 void insertInstructionGuardsHelper(Attributor &A) {
3944 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3945
3946 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3947 Instruction *RegionEndI) {
3948 LoopInfo *LI = nullptr;
3949 DominatorTree *DT = nullptr;
3950 MemorySSAUpdater *MSU = nullptr;
3951 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3952
3953 BasicBlock *ParentBB = RegionStartI->getParent();
3954 Function *Fn = ParentBB->getParent();
3955 Module &M = *Fn->getParent();
3956
3957 // Create all the blocks and logic.
3958 // ParentBB:
3959 // goto RegionCheckTidBB
3960 // RegionCheckTidBB:
3961 // Tid = __kmpc_hardware_thread_id()
3962 // if (Tid != 0)
3963 // goto RegionBarrierBB
3964 // RegionStartBB:
3965 // <execute instructions guarded>
3966 // goto RegionEndBB
3967 // RegionEndBB:
3968 // <store escaping values to shared mem>
3969 // goto RegionBarrierBB
3970 // RegionBarrierBB:
3971 // __kmpc_simple_barrier_spmd()
3972 // // second barrier is omitted if lacking escaping values.
3973 // <load escaping values from shared mem>
3974 // __kmpc_simple_barrier_spmd()
3975 // goto RegionExitBB
3976 // RegionExitBB:
3977 // <execute rest of instructions>
3978
3979 BasicBlock *RegionEndBB = SplitBlock(Old: ParentBB, SplitPt: RegionEndI->getNextNode(),
3980 DT, LI, MSSAU: MSU, BBName: "region.guarded.end");
3981 BasicBlock *RegionBarrierBB =
3982 SplitBlock(Old: RegionEndBB, SplitPt: &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3983 MSSAU: MSU, BBName: "region.barrier");
3984 BasicBlock *RegionExitBB =
3985 SplitBlock(Old: RegionBarrierBB, SplitPt: &*RegionBarrierBB->getFirstInsertionPt(),
3986 DT, LI, MSSAU: MSU, BBName: "region.exit");
3987 BasicBlock *RegionStartBB =
3988 SplitBlock(Old: ParentBB, SplitPt: RegionStartI, DT, LI, MSSAU: MSU, BBName: "region.guarded");
3989
3990 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3991 "Expected a different CFG");
3992
3993 BasicBlock *RegionCheckTidBB = SplitBlock(
3994 Old: ParentBB, SplitPt: ParentBB->getTerminator(), DT, LI, MSSAU: MSU, BBName: "region.check.tid");
3995
3996 // Register basic blocks with the Attributor.
3997 A.registerManifestAddedBasicBlock(BB&: *RegionEndBB);
3998 A.registerManifestAddedBasicBlock(BB&: *RegionBarrierBB);
3999 A.registerManifestAddedBasicBlock(BB&: *RegionExitBB);
4000 A.registerManifestAddedBasicBlock(BB&: *RegionStartBB);
4001 A.registerManifestAddedBasicBlock(BB&: *RegionCheckTidBB);
4002
4003 bool HasBroadcastValues = false;
4004 // Find escaping outputs from the guarded region to outside users and
4005 // broadcast their values to them.
4006 for (Instruction &I : *RegionStartBB) {
4007 SmallVector<Use *, 4> OutsideUses;
4008 for (Use &U : I.uses()) {
4009 Instruction &UsrI = *cast<Instruction>(Val: U.getUser());
4010 if (UsrI.getParent() != RegionStartBB)
4011 OutsideUses.push_back(Elt: &U);
4012 }
4013
4014 if (OutsideUses.empty())
4015 continue;
4016
4017 HasBroadcastValues = true;
4018
4019 // Emit a global variable in shared memory to store the broadcasted
4020 // value.
4021 auto *SharedMem = new GlobalVariable(
4022 M, I.getType(), /* IsConstant */ false,
4023 GlobalValue::InternalLinkage, UndefValue::get(T: I.getType()),
4024 sanitizeForGlobalName(
4025 S: (I.getName() + ".guarded.output.alloc").str()),
4026 nullptr, GlobalValue::NotThreadLocal,
4027 static_cast<unsigned>(AddressSpace::Shared));
4028
4029 // Emit a store instruction to update the value.
4030 new StoreInst(&I, SharedMem,
4031 RegionEndBB->getTerminator()->getIterator());
4032
4033 LoadInst *LoadI = new LoadInst(
4034 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4035 RegionBarrierBB->getTerminator()->getIterator());
4036
4037 // Emit a load instruction and replace uses of the output value.
4038 for (Use *U : OutsideUses)
4039 A.changeUseAfterManifest(U&: *U, NV&: *LoadI);
4040 }
4041
4042 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4043
4044 // Go to tid check BB in ParentBB.
4045 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4046 ParentBB->getTerminator()->eraseFromParent();
4047 OpenMPIRBuilder::LocationDescription Loc(
4048 InsertPointTy(ParentBB, ParentBB->end()), DL);
4049 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4050 uint32_t SrcLocStrSize;
4051 auto *SrcLocStr =
4052 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4053 Value *Ident =
4054 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4055 BranchInst::Create(IfTrue: RegionCheckTidBB, InsertAtEnd: ParentBB)->setDebugLoc(DL);
4056
4057 // Add check for Tid in RegionCheckTidBB
4058 RegionCheckTidBB->getTerminator()->eraseFromParent();
4059 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4060 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4061 OMPInfoCache.OMPBuilder.updateToLocation(Loc: LocRegionCheckTid);
4062 FunctionCallee HardwareTidFn =
4063 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4064 M, FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block);
4065 CallInst *Tid =
4066 OMPInfoCache.OMPBuilder.Builder.CreateCall(Callee: HardwareTidFn, Args: {});
4067 Tid->setDebugLoc(DL);
4068 OMPInfoCache.setCallingConvention(Callee: HardwareTidFn, CI: Tid);
4069 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Arg: Tid);
4070 OMPInfoCache.OMPBuilder.Builder
4071 .CreateCondBr(Cond: TidCheck, True: RegionStartBB, False: RegionBarrierBB)
4072 ->setDebugLoc(DL);
4073
4074 // First barrier for synchronization; it ensures the main thread has
4075 // updated the values before the workers read them.
4076 FunctionCallee BarrierFn =
4077 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4078 M, FnID: OMPRTL___kmpc_barrier_simple_spmd);
4079 OMPInfoCache.OMPBuilder.updateToLocation(Loc: InsertPointTy(
4080 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4081 CallInst *Barrier =
4082 OMPInfoCache.OMPBuilder.Builder.CreateCall(Callee: BarrierFn, Args: {Ident, Tid});
4083 Barrier->setDebugLoc(DL);
4084 OMPInfoCache.setCallingConvention(Callee: BarrierFn, CI: Barrier);
4085
4086 // Second barrier ensures workers have read broadcast values.
4087 if (HasBroadcastValues) {
4088 CallInst *Barrier =
4089 CallInst::Create(Func: BarrierFn, Args: {Ident, Tid}, NameStr: "",
4090 InsertBefore: RegionBarrierBB->getTerminator()->getIterator());
4091 Barrier->setDebugLoc(DL);
4092 OMPInfoCache.setCallingConvention(Callee: BarrierFn, CI: Barrier);
4093 }
4094 };
4095
4096 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4097 SmallPtrSet<BasicBlock *, 8> Visited;
4098 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4099 BasicBlock *BB = GuardedI->getParent();
4100 if (!Visited.insert(Ptr: BB).second)
4101 continue;
4102
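      // Sink guarded instructions whose results are unused towards the next
      // guarded instruction (skipping over side-effect-free instructions and
      // __kmpc_alloc_shared calls) so that the guarded regions formed below
      // are as contiguous as possible.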
4103 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4104 Instruction *LastEffect = nullptr;
4105 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4106 while (++IP != IPEnd) {
4107 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4108 continue;
4109 Instruction *I = &*IP;
4110 if (OpenMPOpt::getCallIfRegularCall(V&: *I, RFI: &AllocSharedRFI))
4111 continue;
4112 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(Elem: I)) {
4113 LastEffect = nullptr;
4114 continue;
4115 }
4116 if (LastEffect)
4117 Reorders.push_back(Elt: {I, LastEffect});
4118 LastEffect = &*IP;
4119 }
4120 for (auto &Reorder : Reorders)
4121 Reorder.first->moveBefore(MovePos: Reorder.second);
4122 }
4123
4124 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4125
4126 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4127 BasicBlock *BB = GuardedI->getParent();
4128 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4129 IRP: IRPosition::function(F: *GuardedI->getFunction()), QueryingAA: nullptr,
4130 DepClass: DepClassTy::NONE);
4131 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4132 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(Val: CalleeAA);
4133 // Continue if instruction is already guarded.
4134 if (CalleeAAFunction.getGuardedInstructions().contains(Ptr: GuardedI))
4135 continue;
4136
4137 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4138 for (Instruction &I : *BB) {
4139 // If instruction I needs to be guarded, update the guarded region
4140 // bounds.
4141 if (SPMDCompatibilityTracker.contains(Elem: &I)) {
4142 CalleeAAFunction.getGuardedInstructions().insert(Ptr: &I);
4143 if (GuardedRegionStart)
4144 GuardedRegionEnd = &I;
4145 else
4146 GuardedRegionStart = GuardedRegionEnd = &I;
4147
4148 continue;
4149 }
4150
4151 // Instruction I does not need guarding; store any region found so far
4152 // and reset the bounds.
4153 if (GuardedRegionStart) {
4154 GuardedRegions.push_back(
4155 Elt: std::make_pair(x&: GuardedRegionStart, y&: GuardedRegionEnd));
4156 GuardedRegionStart = nullptr;
4157 GuardedRegionEnd = nullptr;
4158 }
4159 }
4160 }
4161
4162 for (auto &GR : GuardedRegions)
4163 CreateGuardedRegion(GR.first, GR.second);
4164 }
4165
4166 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4167 // Only allow 1 thread per workgroup to continue executing the user code.
4168 //
4169 // InitCB = __kmpc_target_init(...)
4170 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4171 // if (ThreadIdInBlock != 0) return;
4172 // UserCode:
4173 // // user code
4174 //
4175 auto &Ctx = getAnchorValue().getContext();
4176 Function *Kernel = getAssociatedFunction();
4177 assert(Kernel && "Expected an associated function!");
4178
4179 // Create block for user code to branch to from initial block.
4180 BasicBlock *InitBB = KernelInitCB->getParent();
4181 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4182 I: KernelInitCB->getNextNode(), BBName: "main.thread.user_code");
4183 BasicBlock *ReturnBB =
4184 BasicBlock::Create(Context&: Ctx, Name: "exit.threads", Parent: Kernel, InsertBefore: UserCodeBB);
4185
4186 // Register blocks with attributor:
4187 A.registerManifestAddedBasicBlock(BB&: *InitBB);
4188 A.registerManifestAddedBasicBlock(BB&: *UserCodeBB);
4189 A.registerManifestAddedBasicBlock(BB&: *ReturnBB);
4190
4191 // Debug location:
4192 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4193 ReturnInst::Create(C&: Ctx, InsertAtEnd: ReturnBB)->setDebugLoc(DLoc);
4194 InitBB->getTerminator()->eraseFromParent();
4195
4196 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4197 Module &M = *Kernel->getParent();
4198 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4199 FunctionCallee ThreadIdInBlockFn =
4200 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4201 M, FnID: OMPRTL___kmpc_get_hardware_thread_id_in_block);
4202
4203 // Get thread ID in block.
4204 CallInst *ThreadIdInBlock =
4205 CallInst::Create(Func: ThreadIdInBlockFn, NameStr: "thread_id.in.block", InsertAtEnd: InitBB);
4206 OMPInfoCache.setCallingConvention(Callee: ThreadIdInBlockFn, CI: ThreadIdInBlock);
4207 ThreadIdInBlock->setDebugLoc(DLoc);
4208
4209 // Eliminate all threads in the block with ID not equal to 0:
4210 Instruction *IsMainThread =
4211 ICmpInst::Create(Op: ICmpInst::ICmp, Pred: CmpInst::ICMP_NE, S1: ThreadIdInBlock,
4212 S2: ConstantInt::get(Ty: ThreadIdInBlock->getType(), V: 0),
4213 Name: "thread.is_main", InsertAtEnd: InitBB);
4214 IsMainThread->setDebugLoc(DLoc);
4215 BranchInst::Create(IfTrue: ReturnBB, IfFalse: UserCodeBB, Cond: IsMainThread, InsertAtEnd: InitBB);
4216 }
4217
4218 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4219 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4220
4221 // We cannot change to SPMD mode if the runtime functions aren't available.
4222 if (!OMPInfoCache.runtimeFnsAvailable(
4223 Fns: {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4224 OMPRTL___kmpc_barrier_simple_spmd}))
4225 return false;
4226
4227 if (!SPMDCompatibilityTracker.isAssumed()) {
4228 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4229 if (!NonCompatibleI)
4230 continue;
4231
4232 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4233 if (auto *CB = dyn_cast<CallBase>(Val: NonCompatibleI))
4234 if (OMPInfoCache.RTLFunctions.contains(V: CB->getCalledFunction()))
4235 continue;
4236
4237 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4238 ORA << "Value has potential side effects preventing SPMD-mode "
4239 "execution";
4240 if (isa<CallBase>(Val: NonCompatibleI)) {
4241 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
4242 "the called function to override";
4243 }
4244 return ORA << ".";
4245 };
4246 A.emitRemark<OptimizationRemarkAnalysis>(I: NonCompatibleI, RemarkName: "OMP121",
4247 RemarkCB&: Remark);
4248
4249 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4250 << *NonCompatibleI << "\n");
4251 }
4252
4253 return false;
4254 }
4255
4256 // Get the actual kernel, could be the caller of the anchor scope if we have
4257 // a debug wrapper.
4258 Function *Kernel = getAnchorScope();
4259 if (Kernel->hasLocalLinkage()) {
4260 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4261 auto *CB = cast<CallBase>(Val: Kernel->user_back());
4262 Kernel = CB->getCaller();
4263 }
4264 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4265
4266 // Check if the kernel is already in SPMD mode, if so, return success.
4267 ConstantStruct *ExistingKernelEnvC =
4268 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4269 auto *ExecModeC =
4270 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC: ExistingKernelEnvC);
4271 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4272 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4273 return true;
4274
4275 // We will now unconditionally modify the IR, indicate a change.
4276 Changed = ChangeStatus::CHANGED;
4277
4278 // Do not use instruction guards when no parallel region is present inside
4279 // the target region.
4280 if (mayContainParallelRegion())
4281 insertInstructionGuardsHelper(A);
4282 else
4283 forceSingleThreadPerWorkgroupHelper(A);
4284
4285 // Adjust the global exec mode flag that tells the runtime what mode this
4286 // kernel is executed in.
4287 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4288 "Initially non-SPMD kernel has SPMD exec mode!");
4289 setExecModeOfKernelEnvironment(
4290 ConstantInt::get(Ty: ExecModeC->getIntegerType(),
4291 V: ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
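    // Note: OMP_TGT_EXEC_MODE_GENERIC_SPMD is expected to combine the GENERIC
    // and SPMD bits (see OMPDeviceConstants.h), so the kernel keeps its
    // generic marker while additionally advertising SPMD execution to the
    // runtime.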
4292
4293 ++NumOpenMPTargetRegionKernelsSPMD;
4294
4295 auto Remark = [&](OptimizationRemark OR) {
4296 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4297 };
4298 A.emitRemark<OptimizationRemark>(I: KernelInitCB, RemarkName: "OMP120", RemarkCB&: Remark);
4299 return true;
4300 };
4301
4302 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4303 // If we have disabled state machine rewrites, don't make a custom one
4304 if (DisableOpenMPOptStateMachineRewrite)
4305 return false;
4306
4307 // Don't rewrite the state machine if we are not in a valid state.
4308 if (!ReachedKnownParallelRegions.isValidState())
4309 return false;
4310
4311 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4312 if (!OMPInfoCache.runtimeFnsAvailable(
4313 Fns: {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4314 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4315 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4316 return false;
4317
4318 ConstantStruct *ExistingKernelEnvC =
4319 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4320
4321 // Check if the current configuration is non-SPMD and uses the generic
4322 // state machine. If we already have SPMD mode or a custom state machine we
4323 // do not need to go any further. If it is anything but a constant,
4324 // something is weird and we give up.
4325 ConstantInt *UseStateMachineC =
4326 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4327 KernelEnvC: ExistingKernelEnvC);
4328 ConstantInt *ModeC =
4329 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC: ExistingKernelEnvC);
4330
4331 // If we are stuck with generic mode, try to create a custom device (=GPU)
4332 // state machine which is specialized for the parallel regions that are
4333 // reachable by the kernel.
4334 if (UseStateMachineC->isZero() ||
4335 (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4336 return false;
4337
4338 Changed = ChangeStatus::CHANGED;
4339
4340 // If not SPMD mode, indicate we use a custom state machine now.
4341 setUseGenericStateMachineOfKernelEnvironment(
4342 ConstantInt::get(Ty: UseStateMachineC->getIntegerType(), V: false));
4343
4344 // If we don't actually need a state machine we are done here. This can
4345 // happen if there simply are no parallel regions. In the resulting kernel
4346 // all worker threads will simply exit right away, leaving the main thread
4347 // to do the work alone.
4348 if (!mayContainParallelRegion()) {
4349 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4350
4351 auto Remark = [&](OptimizationRemark OR) {
4352 return OR << "Removing unused state machine from generic-mode kernel.";
4353 };
4354 A.emitRemark<OptimizationRemark>(I: KernelInitCB, RemarkName: "OMP130", RemarkCB&: Remark);
4355
4356 return true;
4357 }
4358
4359 // Keep track in the statistics of our new shiny custom state machine.
4360 if (ReachedUnknownParallelRegions.empty()) {
4361 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4362
4363 auto Remark = [&](OptimizationRemark OR) {
4364 return OR << "Rewriting generic-mode kernel with a customized state "
4365 "machine.";
4366 };
4367 A.emitRemark<OptimizationRemark>(I: KernelInitCB, RemarkName: "OMP131", RemarkCB&: Remark);
4368 } else {
4369 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4370
4371 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4372 return OR << "Generic-mode kernel is executed with a customized state "
4373 "machine that requires a fallback.";
4374 };
4375 A.emitRemark<OptimizationRemarkAnalysis>(I: KernelInitCB, RemarkName: "OMP132", RemarkCB&: Remark);
4376
4377 // Tell the user why we ended up with a fallback.
4378 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4379 if (!UnknownParallelRegionCB)
4380 continue;
4381 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4382 return ORA << "Call may contain unknown parallel regions. Use "
4383 << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4384 "override.";
4385 };
4386 A.emitRemark<OptimizationRemarkAnalysis>(I: UnknownParallelRegionCB,
4387 RemarkName: "OMP133", RemarkCB&: Remark);
4388 }
4389 }
4390
4391 // Create all the blocks:
4392 //
4393 // InitCB = __kmpc_target_init(...)
4394 // BlockHwSize =
4395 // __kmpc_get_hardware_num_threads_in_block();
4396 // WarpSize = __kmpc_get_warp_size();
4397 // BlockSize = BlockHwSize - WarpSize;
4398 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4399 // if (IsWorker) {
4400 // if (InitCB >= BlockSize) return;
4401 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4402 // void *WorkFn;
4403 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4404 // if (!WorkFn) return;
4405 // SMIsActiveCheckBB: if (Active) {
4406 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4407 // ParFn0(...);
4408 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4409 // ParFn1(...);
4410 // ...
4411 // SMIfCascadeCurrentBB: else
4412 // ((WorkFnTy*)WorkFn)(...);
4413 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4414 // }
4415 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4416 // goto SMBeginBB;
4417 // }
4418 // UserCodeEntryBB: // user code
4419 // __kmpc_target_deinit(...)
4420 //
4421 auto &Ctx = getAnchorValue().getContext();
4422 Function *Kernel = getAssociatedFunction();
4423 assert(Kernel && "Expected an associated function!");
4424
4425 BasicBlock *InitBB = KernelInitCB->getParent();
4426 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4427 I: KernelInitCB->getNextNode(), BBName: "thread.user_code.check");
4428 BasicBlock *IsWorkerCheckBB =
4429 BasicBlock::Create(Context&: Ctx, Name: "is_worker_check", Parent: Kernel, InsertBefore: UserCodeEntryBB);
4430 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4431 Context&: Ctx, Name: "worker_state_machine.begin", Parent: Kernel, InsertBefore: UserCodeEntryBB);
4432 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4433 Context&: Ctx, Name: "worker_state_machine.finished", Parent: Kernel, InsertBefore: UserCodeEntryBB);
4434 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4435 Context&: Ctx, Name: "worker_state_machine.is_active.check", Parent: Kernel, InsertBefore: UserCodeEntryBB);
4436 BasicBlock *StateMachineIfCascadeCurrentBB =
4437 BasicBlock::Create(Context&: Ctx, Name: "worker_state_machine.parallel_region.check",
4438 Parent: Kernel, InsertBefore: UserCodeEntryBB);
4439 BasicBlock *StateMachineEndParallelBB =
4440 BasicBlock::Create(Context&: Ctx, Name: "worker_state_machine.parallel_region.end",
4441 Parent: Kernel, InsertBefore: UserCodeEntryBB);
4442 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4443 Context&: Ctx, Name: "worker_state_machine.done.barrier", Parent: Kernel, InsertBefore: UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(BB&: *InitBB);
4445 A.registerManifestAddedBasicBlock(BB&: *UserCodeEntryBB);
4446 A.registerManifestAddedBasicBlock(BB&: *IsWorkerCheckBB);
4447 A.registerManifestAddedBasicBlock(BB&: *StateMachineBeginBB);
4448 A.registerManifestAddedBasicBlock(BB&: *StateMachineFinishedBB);
4449 A.registerManifestAddedBasicBlock(BB&: *StateMachineIsActiveCheckBB);
4450 A.registerManifestAddedBasicBlock(BB&: *StateMachineIfCascadeCurrentBB);
4451 A.registerManifestAddedBasicBlock(BB&: *StateMachineEndParallelBB);
4452 A.registerManifestAddedBasicBlock(BB&: *StateMachineDoneBarrierBB);
4453
4454 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4455 ReturnInst::Create(C&: Ctx, InsertAtEnd: StateMachineFinishedBB)->setDebugLoc(DLoc);
4456 InitBB->getTerminator()->eraseFromParent();
4457
4458 Instruction *IsWorker =
4459 ICmpInst::Create(Op: ICmpInst::ICmp, Pred: llvm::CmpInst::ICMP_NE, S1: KernelInitCB,
4460 S2: ConstantInt::get(Ty: KernelInitCB->getType(), V: -1),
4461 Name: "thread.is_worker", InsertAtEnd: InitBB);
4462 IsWorker->setDebugLoc(DLoc);
4463 BranchInst::Create(IfTrue: IsWorkerCheckBB, IfFalse: UserCodeEntryBB, Cond: IsWorker, InsertAtEnd: InitBB);
4464
4465 Module &M = *Kernel->getParent();
4466 FunctionCallee BlockHwSizeFn =
4467 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468 M, FnID: OMPRTL___kmpc_get_hardware_num_threads_in_block);
4469 FunctionCallee WarpSizeFn =
4470 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4471 M, FnID: OMPRTL___kmpc_get_warp_size);
4472 CallInst *BlockHwSize =
4473 CallInst::Create(Func: BlockHwSizeFn, NameStr: "block.hw_size", InsertAtEnd: IsWorkerCheckBB);
4474 OMPInfoCache.setCallingConvention(Callee: BlockHwSizeFn, CI: BlockHwSize);
4475 BlockHwSize->setDebugLoc(DLoc);
4476 CallInst *WarpSize =
4477 CallInst::Create(Func: WarpSizeFn, NameStr: "warp.size", InsertAtEnd: IsWorkerCheckBB);
4478 OMPInfoCache.setCallingConvention(Callee: WarpSizeFn, CI: WarpSize);
4479 WarpSize->setDebugLoc(DLoc);
4480 Instruction *BlockSize = BinaryOperator::CreateSub(
4481 V1: BlockHwSize, V2: WarpSize, Name: "block.size", BB: IsWorkerCheckBB);
4482 BlockSize->setDebugLoc(DLoc);
4483 Instruction *IsMainOrWorker = ICmpInst::Create(
4484 Op: ICmpInst::ICmp, Pred: llvm::CmpInst::ICMP_SLT, S1: KernelInitCB, S2: BlockSize,
4485 Name: "thread.is_main_or_worker", InsertAtEnd: IsWorkerCheckBB);
4486 IsMainOrWorker->setDebugLoc(DLoc);
4487 BranchInst::Create(IfTrue: StateMachineBeginBB, IfFalse: StateMachineFinishedBB,
4488 Cond: IsMainOrWorker, InsertAtEnd: IsWorkerCheckBB);
4489
4490 // Create local storage for the work function pointer.
4491 const DataLayout &DL = M.getDataLayout();
4492 Type *VoidPtrTy = PointerType::getUnqual(C&: Ctx);
4493 Instruction *WorkFnAI =
4494 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4495 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4496 WorkFnAI->setDebugLoc(DLoc);
4497
4498 OMPInfoCache.OMPBuilder.updateToLocation(
4499 Loc: OpenMPIRBuilder::LocationDescription(
4500 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4501 StateMachineBeginBB->end()),
4502 DLoc));
4503
4504 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4505 Value *GTid = KernelInitCB;
4506
4507 FunctionCallee BarrierFn =
4508 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4509 M, FnID: OMPRTL___kmpc_barrier_simple_generic);
4510 CallInst *Barrier =
4511 CallInst::Create(Func: BarrierFn, Args: {Ident, GTid}, NameStr: "", InsertAtEnd: StateMachineBeginBB);
4512 OMPInfoCache.setCallingConvention(Callee: BarrierFn, CI: Barrier);
4513 Barrier->setDebugLoc(DLoc);
4514
4515 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4516 (unsigned int)AddressSpace::Generic) {
4517 WorkFnAI = new AddrSpaceCastInst(
4518 WorkFnAI, PointerType::get(C&: Ctx, AddressSpace: (unsigned int)AddressSpace::Generic),
4519 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4520 WorkFnAI->setDebugLoc(DLoc);
4521 }
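    // Note: on targets where allocas do not live in the generic address space
    // (e.g., AMDGPU, where the alloca address space is 5), the cast above makes
    // the work-function slot usable by __kmpc_kernel_parallel, which is handed a
    // generic pointer below.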
4522
4523 FunctionCallee KernelParallelFn =
4524 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4525 M, FnID: OMPRTL___kmpc_kernel_parallel);
4526 CallInst *IsActiveWorker = CallInst::Create(
4527 Func: KernelParallelFn, Args: {WorkFnAI}, NameStr: "worker.is_active", InsertAtEnd: StateMachineBeginBB);
4528 OMPInfoCache.setCallingConvention(Callee: KernelParallelFn, CI: IsActiveWorker);
4529 IsActiveWorker->setDebugLoc(DLoc);
4530 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4531 StateMachineBeginBB);
4532 WorkFn->setDebugLoc(DLoc);
4533
4534 FunctionType *ParallelRegionFnTy = FunctionType::get(
4535 Result: Type::getVoidTy(C&: Ctx), Params: {Type::getInt16Ty(C&: Ctx), Type::getInt32Ty(C&: Ctx)},
4536 isVarArg: false);
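    // Note: ParallelRegionFnTy models the void(i16, i32) signature of the
    // parallel region wrapper functions; the i16 argument is passed as a zero
    // constant below and the i32 argument carries the thread id.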
4537
4538 Instruction *IsDone =
4539 ICmpInst::Create(Op: ICmpInst::ICmp, Pred: llvm::CmpInst::ICMP_EQ, S1: WorkFn,
4540 S2: Constant::getNullValue(Ty: VoidPtrTy), Name: "worker.is_done",
4541 InsertAtEnd: StateMachineBeginBB);
4542 IsDone->setDebugLoc(DLoc);
4543 BranchInst::Create(IfTrue: StateMachineFinishedBB, IfFalse: StateMachineIsActiveCheckBB,
4544 Cond: IsDone, InsertAtEnd: StateMachineBeginBB)
4545 ->setDebugLoc(DLoc);
4546
4547 BranchInst::Create(IfTrue: StateMachineIfCascadeCurrentBB,
4548 IfFalse: StateMachineDoneBarrierBB, Cond: IsActiveWorker,
4549 InsertAtEnd: StateMachineIsActiveCheckBB)
4550 ->setDebugLoc(DLoc);
4551
4552 Value *ZeroArg =
4553 Constant::getNullValue(Ty: ParallelRegionFnTy->getParamType(i: 0));
4554
4555 const unsigned int WrapperFunctionArgNo = 6;
4556
4557    // Now that we have most of the CFG skeleton, it is time for the if-cascade
4558 // that checks the function pointer we got from the runtime against the
4559 // parallel regions we expect, if there are any.
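    // For two known parallel regions PR0 and PR1 the emitted cascade roughly
    // reads:
    //
    //   if (work_fn == PR0)      PR0(0, tid);
    //   else if (work_fn == PR1) PR1(0, tid);
    //   else                     ((void (*)(i16, i32))work_fn)(0, tid); // fallback
    //
    // where the indirect fallback call is only emitted if unknown parallel
    // regions might be reached (see below).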
4560 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4561 auto *CB = ReachedKnownParallelRegions[I];
4562 auto *ParallelRegion = dyn_cast<Function>(
4563 Val: CB->getArgOperand(i: WrapperFunctionArgNo)->stripPointerCasts());
4564 BasicBlock *PRExecuteBB = BasicBlock::Create(
4565 Context&: Ctx, Name: "worker_state_machine.parallel_region.execute", Parent: Kernel,
4566 InsertBefore: StateMachineEndParallelBB);
4567 CallInst::Create(Func: ParallelRegion, Args: {ZeroArg, GTid}, NameStr: "", InsertAtEnd: PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569 BranchInst::Create(IfTrue: StateMachineEndParallelBB, InsertAtEnd: PRExecuteBB)
4570 ->setDebugLoc(DLoc);
4571
4572 BasicBlock *PRNextBB =
4573 BasicBlock::Create(Context&: Ctx, Name: "worker_state_machine.parallel_region.check",
4574 Parent: Kernel, InsertBefore: StateMachineEndParallelBB);
4575 A.registerManifestAddedBasicBlock(BB&: *PRExecuteBB);
4576 A.registerManifestAddedBasicBlock(BB&: *PRNextBB);
4577
4578 // Check if we need to compare the pointer at all or if we can just
4579 // call the parallel region function.
4580 Value *IsPR;
4581 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4582 Instruction *CmpI = ICmpInst::Create(
4583 Op: ICmpInst::ICmp, Pred: llvm::CmpInst::ICMP_EQ, S1: WorkFn, S2: ParallelRegion,
4584 Name: "worker.check_parallel_region", InsertAtEnd: StateMachineIfCascadeCurrentBB);
4585 CmpI->setDebugLoc(DLoc);
4586 IsPR = CmpI;
4587 } else {
4588 IsPR = ConstantInt::getTrue(Context&: Ctx);
4589 }
4590
4591 BranchInst::Create(IfTrue: PRExecuteBB, IfFalse: PRNextBB, Cond: IsPR,
4592 InsertAtEnd: StateMachineIfCascadeCurrentBB)
4593 ->setDebugLoc(DLoc);
4594 StateMachineIfCascadeCurrentBB = PRNextBB;
4595 }
4596
4597    // At the end of the if-cascade we place an indirect call through the
4598    // function pointer in case it is needed, that is, if there can be parallel
4599    // regions we have not handled in the if-cascade above.
4600 if (!ReachedUnknownParallelRegions.empty()) {
4601 StateMachineIfCascadeCurrentBB->setName(
4602 "worker_state_machine.parallel_region.fallback.execute");
4603 CallInst::Create(Ty: ParallelRegionFnTy, Func: WorkFn, Args: {ZeroArg, GTid}, NameStr: "",
4604 InsertAtEnd: StateMachineIfCascadeCurrentBB)
4605 ->setDebugLoc(DLoc);
4606 }
4607 BranchInst::Create(IfTrue: StateMachineEndParallelBB,
4608 InsertAtEnd: StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610
4611 FunctionCallee EndParallelFn =
4612 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4613 M, FnID: OMPRTL___kmpc_kernel_end_parallel);
4614 CallInst *EndParallel =
4615 CallInst::Create(Func: EndParallelFn, Args: {}, NameStr: "", InsertAtEnd: StateMachineEndParallelBB);
4616 OMPInfoCache.setCallingConvention(Callee: EndParallelFn, CI: EndParallel);
4617 EndParallel->setDebugLoc(DLoc);
4618 BranchInst::Create(IfTrue: StateMachineDoneBarrierBB, InsertAtEnd: StateMachineEndParallelBB)
4619 ->setDebugLoc(DLoc);
4620
4621 CallInst::Create(Func: BarrierFn, Args: {Ident, GTid}, NameStr: "", InsertAtEnd: StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623 BranchInst::Create(IfTrue: StateMachineBeginBB, InsertAtEnd: StateMachineDoneBarrierBB)
4624 ->setDebugLoc(DLoc);
4625
4626 return true;
4627 }
4628
4629  /// Fixpoint iteration update function. Will be called every time a dependence
4630  /// changes its state (and once in the beginning).
4631 ChangeStatus updateImpl(Attributor &A) override {
4632 KernelInfoState StateBefore = getState();
4633
4634    // When we leave this function, this RAII object will make sure the member
4635 // KernelEnvC is updated properly depending on the state. That member is
4636 // used for simplification of values and needs to be up to date at all
4637 // times.
4638 struct UpdateKernelEnvCRAII {
4639 AAKernelInfoFunction &AA;
4640
4641 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4642
4643 ~UpdateKernelEnvCRAII() {
4644 if (!AA.KernelEnvC)
4645 return;
4646
4647 ConstantStruct *ExistingKernelEnvC =
4648 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB: AA.KernelInitCB);
4649
4650 if (!AA.isValidState()) {
4651 AA.KernelEnvC = ExistingKernelEnvC;
4652 return;
4653 }
4654
4655 if (!AA.ReachedKnownParallelRegions.isValidState())
4656 AA.setUseGenericStateMachineOfKernelEnvironment(
4657 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4658 KernelEnvC: ExistingKernelEnvC));
4659
4660 if (!AA.SPMDCompatibilityTracker.isValidState())
4661 AA.setExecModeOfKernelEnvironment(
4662 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC: ExistingKernelEnvC));
4663
4664 ConstantInt *MayUseNestedParallelismC =
4665 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4666 KernelEnvC: AA.KernelEnvC);
4667 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4668 Ty: MayUseNestedParallelismC->getIntegerType(), V: AA.NestedParallelism);
4669 AA.setMayUseNestedParallelismOfKernelEnvironment(
4670 NewMayUseNestedParallelismC);
4671 }
4672 } RAII(*this);
4673
4674 // Callback to check a read/write instruction.
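    // For example, a store through a pointer that is not known to be
    // thread-local (and not an allocation AAHeapToStack will privatize) ends up
    // in SPMDCompatibilityTracker so it can be guarded if the kernel is later
    // moved to SPMD mode.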
4675 auto CheckRWInst = [&](Instruction &I) {
4676 // We handle calls later.
4677 if (isa<CallBase>(Val: I))
4678 return true;
4679 // We only care about write effects.
4680 if (!I.mayWriteToMemory())
4681 return true;
4682 if (auto *SI = dyn_cast<StoreInst>(Val: &I)) {
4683 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4684 QueryingAA: *this, IRP: IRPosition::value(V: *SI->getPointerOperand()),
4685 DepClass: DepClassTy::OPTIONAL);
4686 auto *HS = A.getAAFor<AAHeapToStack>(
4687 QueryingAA: *this, IRP: IRPosition::function(F: *I.getFunction()),
4688 DepClass: DepClassTy::OPTIONAL);
4689 if (UnderlyingObjsAA &&
4690 UnderlyingObjsAA->forallUnderlyingObjects(Pred: [&](Value &Obj) {
4691 if (AA::isAssumedThreadLocalObject(A, Obj, QueryingAA: *this))
4692 return true;
4693 // Check for AAHeapToStack moved objects which must not be
4694 // guarded.
4695 auto *CB = dyn_cast<CallBase>(Val: &Obj);
4696 return CB && HS && HS->isAssumedHeapToStack(CB: *CB);
4697 }))
4698 return true;
4699 }
4700
4701 // Insert instruction that needs guarding.
4702 SPMDCompatibilityTracker.insert(Elem: &I);
4703 return true;
4704 };
4705
4706 bool UsedAssumedInformationInCheckRWInst = false;
4707 if (!SPMDCompatibilityTracker.isAtFixpoint())
4708 if (!A.checkForAllReadWriteInstructions(
4709 Pred: CheckRWInst, QueryingAA&: *this, UsedAssumedInformation&: UsedAssumedInformationInCheckRWInst))
4710 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4711
4712 bool UsedAssumedInformationFromReachingKernels = false;
4713 if (!IsKernelEntry) {
4714 updateParallelLevels(A);
4715
4716 bool AllReachingKernelsKnown = true;
4717 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4718 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4719
4720 if (!SPMDCompatibilityTracker.empty()) {
4721 if (!ParallelLevels.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else if (!ReachingKernelEntries.isValidState())
4724 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4725 else {
4726          // Check if all reaching kernels agree on the mode, as we cannot
4727          // guard instructions otherwise. We might not be sure about the mode,
4728          // so we cannot fix the internal SPMD-ization state either.
4729 int SPMD = 0, Generic = 0;
4730 for (auto *Kernel : ReachingKernelEntries) {
4731 auto *CBAA = A.getAAFor<AAKernelInfo>(
4732 QueryingAA: *this, IRP: IRPosition::function(F: *Kernel), DepClass: DepClassTy::OPTIONAL);
4733 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4734 CBAA->SPMDCompatibilityTracker.isAssumed())
4735 ++SPMD;
4736 else
4737 ++Generic;
4738 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4739 UsedAssumedInformationFromReachingKernels = true;
4740 }
4741 if (SPMD != 0 && Generic != 0)
4742 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4743 }
4744 }
4745 }
4746
4747 // Callback to check a call instruction.
4748 bool AllParallelRegionStatesWereFixed = true;
4749 bool AllSPMDStatesWereFixed = true;
4750 auto CheckCallInst = [&](Instruction &I) {
4751 auto &CB = cast<CallBase>(Val&: I);
4752 auto *CBAA = A.getAAFor<AAKernelInfo>(
4753 QueryingAA: *this, IRP: IRPosition::callsite_function(CB), DepClass: DepClassTy::OPTIONAL);
4754 if (!CBAA)
4755 return false;
4756 getState() ^= CBAA->getState();
4757 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4760 AllParallelRegionStatesWereFixed &=
4761 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4762 return true;
4763 };
4764
4765 bool UsedAssumedInformationInCheckCallInst = false;
4766 if (!A.checkForAllCallLikeInstructions(
4767 Pred: CheckCallInst, QueryingAA: *this, UsedAssumedInformation&: UsedAssumedInformationInCheckCallInst)) {
4768 LLVM_DEBUG(dbgs() << TAG
4769 << "Failed to visit all call-like instructions!\n";);
4770 return indicatePessimisticFixpoint();
4771 }
4772
4773 // If we haven't used any assumed information for the reached parallel
4774    // region states, we can fix them.
4775 if (!UsedAssumedInformationInCheckCallInst &&
4776 AllParallelRegionStatesWereFixed) {
4777 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4778 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4779 }
4780
4781 // If we haven't used any assumed information for the SPMD state we can fix
4782 // it.
4783 if (!UsedAssumedInformationInCheckRWInst &&
4784 !UsedAssumedInformationInCheckCallInst &&
4785 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4786 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4787
4788 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4789 : ChangeStatus::CHANGED;
4790 }
4791
4792private:
4793 /// Update info regarding reaching kernels.
4794 void updateReachingKernelEntries(Attributor &A,
4795 bool &AllReachingKernelsKnown) {
4796 auto PredCallSite = [&](AbstractCallSite ACS) {
4797 Function *Caller = ACS.getInstruction()->getFunction();
4798
4799 assert(Caller && "Caller is nullptr");
4800
4801 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4802 IRP: IRPosition::function(F: *Caller), QueryingAA: this, DepClass: DepClassTy::REQUIRED);
4803 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4804 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4805 return true;
4806 }
4807
4808      // We lost track of the caller of the associated function; any kernel
4809      // could reach it now.
4810 ReachingKernelEntries.indicatePessimisticFixpoint();
4811
4812 return true;
4813 };
4814
4815 if (!A.checkForAllCallSites(Pred: PredCallSite, QueryingAA: *this,
4816 RequireAllCallSites: true /* RequireAllCallSites */,
4817 UsedAssumedInformation&: AllReachingKernelsKnown))
4818 ReachingKernelEntries.indicatePessimisticFixpoint();
4819 }
4820
4821 /// Update info regarding parallel levels.
4822 void updateParallelLevels(Attributor &A) {
4823 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4824 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4825 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4826
4827 auto PredCallSite = [&](AbstractCallSite ACS) {
4828 Function *Caller = ACS.getInstruction()->getFunction();
4829
4830 assert(Caller && "Caller is nullptr");
4831
4832 auto *CAA =
4833 A.getOrCreateAAFor<AAKernelInfo>(IRP: IRPosition::function(F: *Caller));
4834 if (CAA && CAA->ParallelLevels.isValidState()) {
4835        // Any function that is called by `__kmpc_parallel_51` will not be
4836        // folded because the parallel level inside it is updated. Getting this
4837        // right would make the analysis depend on the runtime implementation,
4838        // and any future change to that implementation could silently
4839        // invalidate the analysis. As a consequence, we are just conservative here.
4840 if (Caller == Parallel51RFI.Declaration) {
4841 ParallelLevels.indicatePessimisticFixpoint();
4842 return true;
4843 }
4844
4845 ParallelLevels ^= CAA->ParallelLevels;
4846
4847 return true;
4848 }
4849
4850      // We lost track of the caller of the associated function; any kernel
4851      // could reach it now.
4852 ParallelLevels.indicatePessimisticFixpoint();
4853
4854 return true;
4855 };
4856
4857 bool AllCallSitesKnown = true;
4858 if (!A.checkForAllCallSites(Pred: PredCallSite, QueryingAA: *this,
4859 RequireAllCallSites: true /* RequireAllCallSites */,
4860 UsedAssumedInformation&: AllCallSitesKnown))
4861 ParallelLevels.indicatePessimisticFixpoint();
4862 }
4863};
4864
4865/// The call site kernel info abstract attribute, basically, what can we say
4866/// about a call site with regard to the KernelInfoState. For now this simply
4867/// forwards the information from the callee.
4868struct AAKernelInfoCallSite : AAKernelInfo {
4869 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4870 : AAKernelInfo(IRP, A) {}
4871
4872 /// See AbstractAttribute::initialize(...).
4873 void initialize(Attributor &A) override {
4874 AAKernelInfo::initialize(A);
4875
4876 CallBase &CB = cast<CallBase>(Val&: getAssociatedValue());
4877 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4878 QueryingAA: *this, IRP: IRPosition::callsite_function(CB), DepClass: DepClassTy::OPTIONAL);
4879
4880 // Check for SPMD-mode assumptions.
4881 if (AssumptionAA && AssumptionAA->hasAssumption(Assumption: "ompx_spmd_amenable")) {
4882 indicateOptimisticFixpoint();
4883 return;
4884 }
4885
4886    // First weed out calls we do not care about, that is, readonly/readnone
4887    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4888 // parallel region or anything else we are looking for.
4889 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(Val: CB)) {
4890 indicateOptimisticFixpoint();
4891 return;
4892 }
4893
4894    // Next we check if we know the callee. If it is a known OpenMP function,
4895    // we will handle it explicitly in the switch below. If it is not, we
4896 // will use an AAKernelInfo object on the callee to gather information and
4897 // merge that into the current state. The latter happens in the updateImpl.
4898 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4899 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4900 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Val: Callee);
4901 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4902        // Unknown callees or declarations are not analyzable; we give up.
4903 if (!Callee || !A.isFunctionIPOAmendable(F: *Callee)) {
4904
4905 // Unknown callees might contain parallel regions, except if they have
4906 // an appropriate assumption attached.
4907 if (!AssumptionAA ||
4908 !(AssumptionAA->hasAssumption(Assumption: "omp_no_openmp") ||
4909 AssumptionAA->hasAssumption(Assumption: "omp_no_parallelism")))
4910 ReachedUnknownParallelRegions.insert(Elem: &CB);
4911
4912 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4913 // idea we can run something unknown in SPMD-mode.
4914 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4915 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4916 SPMDCompatibilityTracker.insert(Elem: &CB);
4917 }
4918
4919          // We have updated the state for this unknown call properly; there
4920          // won't be any further change, so we indicate a fixpoint.
4921 indicateOptimisticFixpoint();
4922 }
4923 // If the callee is known and can be used in IPO, we will update the
4924 // state based on the callee state in updateImpl.
4925 return;
4926 }
4927 if (NumCallees > 1) {
4928 indicatePessimisticFixpoint();
4929 return;
4930 }
4931
4932 RuntimeFunction RF = It->getSecond();
4933 switch (RF) {
4934 // All the functions we know are compatible with SPMD mode.
4935 case OMPRTL___kmpc_is_spmd_exec_mode:
4936 case OMPRTL___kmpc_distribute_static_fini:
4937 case OMPRTL___kmpc_for_static_fini:
4938 case OMPRTL___kmpc_global_thread_num:
4939 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4940 case OMPRTL___kmpc_get_hardware_num_blocks:
4941 case OMPRTL___kmpc_single:
4942 case OMPRTL___kmpc_end_single:
4943 case OMPRTL___kmpc_master:
4944 case OMPRTL___kmpc_end_master:
4945 case OMPRTL___kmpc_barrier:
4946 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4947 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4948 case OMPRTL___kmpc_error:
4949 case OMPRTL___kmpc_flush:
4950 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4951 case OMPRTL___kmpc_get_warp_size:
4952 case OMPRTL_omp_get_thread_num:
4953 case OMPRTL_omp_get_num_threads:
4954 case OMPRTL_omp_get_max_threads:
4955 case OMPRTL_omp_in_parallel:
4956 case OMPRTL_omp_get_dynamic:
4957 case OMPRTL_omp_get_cancellation:
4958 case OMPRTL_omp_get_nested:
4959 case OMPRTL_omp_get_schedule:
4960 case OMPRTL_omp_get_thread_limit:
4961 case OMPRTL_omp_get_supported_active_levels:
4962 case OMPRTL_omp_get_max_active_levels:
4963 case OMPRTL_omp_get_level:
4964 case OMPRTL_omp_get_ancestor_thread_num:
4965 case OMPRTL_omp_get_team_size:
4966 case OMPRTL_omp_get_active_level:
4967 case OMPRTL_omp_in_final:
4968 case OMPRTL_omp_get_proc_bind:
4969 case OMPRTL_omp_get_num_places:
4970 case OMPRTL_omp_get_num_procs:
4971 case OMPRTL_omp_get_place_proc_ids:
4972 case OMPRTL_omp_get_place_num:
4973 case OMPRTL_omp_get_partition_num_places:
4974 case OMPRTL_omp_get_partition_place_nums:
4975 case OMPRTL_omp_get_wtime:
4976 break;
4977 case OMPRTL___kmpc_distribute_static_init_4:
4978 case OMPRTL___kmpc_distribute_static_init_4u:
4979 case OMPRTL___kmpc_distribute_static_init_8:
4980 case OMPRTL___kmpc_distribute_static_init_8u:
4981 case OMPRTL___kmpc_for_static_init_4:
4982 case OMPRTL___kmpc_for_static_init_4u:
4983 case OMPRTL___kmpc_for_static_init_8:
4984 case OMPRTL___kmpc_for_static_init_8u: {
4985 // Check the schedule and allow static schedule in SPMD mode.
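        // For example, a worksharing loop with schedule(static) reaches this
        // point with an UnorderedStatic(Chunked) schedule type and remains
        // SPMD-compatible, while ordered variants fall into the default case
        // below and mark the call as SPMD-incompatible.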
4986 unsigned ScheduleArgOpNo = 2;
4987 auto *ScheduleTypeCI =
4988 dyn_cast<ConstantInt>(Val: CB.getArgOperand(i: ScheduleArgOpNo));
4989 unsigned ScheduleTypeVal =
4990 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4991 switch (OMPScheduleType(ScheduleTypeVal)) {
4992 case OMPScheduleType::UnorderedStatic:
4993 case OMPScheduleType::UnorderedStaticChunked:
4994 case OMPScheduleType::OrderedDistribute:
4995 case OMPScheduleType::OrderedDistributeChunked:
4996 break;
4997 default:
4998 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4999 SPMDCompatibilityTracker.insert(Elem: &CB);
5000 break;
5001 };
5002 } break;
5003 case OMPRTL___kmpc_target_init:
5004 KernelInitCB = &CB;
5005 break;
5006 case OMPRTL___kmpc_target_deinit:
5007 KernelDeinitCB = &CB;
5008 break;
5009 case OMPRTL___kmpc_parallel_51:
5010 if (!handleParallel51(A, CB))
5011 indicatePessimisticFixpoint();
5012 return;
5013 case OMPRTL___kmpc_omp_task:
5014 // We do not look into tasks right now, just give up.
5015 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5016 SPMDCompatibilityTracker.insert(Elem: &CB);
5017 ReachedUnknownParallelRegions.insert(Elem: &CB);
5018 break;
5019 case OMPRTL___kmpc_alloc_shared:
5020 case OMPRTL___kmpc_free_shared:
5021 // Return without setting a fixpoint, to be resolved in updateImpl.
5022 return;
5023 default:
5024 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5025 // generally. However, they do not hide parallel regions.
5026 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5027 SPMDCompatibilityTracker.insert(Elem: &CB);
5028 break;
5029 }
5030 // All other OpenMP runtime calls will not reach parallel regions so they
5031 // can be safely ignored for now. Since it is a known OpenMP runtime call
5032 // we have now modeled all effects and there is no need for any update.
5033 indicateOptimisticFixpoint();
5034 };
5035
5036 const auto *AACE =
5037 A.getAAFor<AACallEdges>(QueryingAA: *this, IRP: getIRPosition(), DepClass: DepClassTy::OPTIONAL);
5038 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5039 CheckCallee(getAssociatedFunction(), 1);
5040 return;
5041 }
5042 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5043 for (auto *Callee : OptimisticEdges) {
5044 CheckCallee(Callee, OptimisticEdges.size());
5045 if (isAtFixpoint())
5046 break;
5047 }
5048 }
5049
5050 ChangeStatus updateImpl(Attributor &A) override {
5051 // TODO: Once we have call site specific value information we can provide
5052 // call site specific liveness information and then it makes
5053    //       sense to specialize attributes for call site arguments instead of
5054 // redirecting requests to the callee argument.
5055 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5056 KernelInfoState StateBefore = getState();
5057
5058 auto CheckCallee = [&](Function *F, int NumCallees) {
5059 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Val: F);
5060
5061 // If F is not a runtime function, propagate the AAKernelInfo of the
5062 // callee.
5063 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5064 const IRPosition &FnPos = IRPosition::function(F: *F);
5065 auto *FnAA =
5066 A.getAAFor<AAKernelInfo>(QueryingAA: *this, IRP: FnPos, DepClass: DepClassTy::REQUIRED);
5067 if (!FnAA)
5068 return indicatePessimisticFixpoint();
5069 if (getState() == FnAA->getState())
5070 return ChangeStatus::UNCHANGED;
5071 getState() = FnAA->getState();
5072 return ChangeStatus::CHANGED;
5073 }
5074 if (NumCallees > 1)
5075 return indicatePessimisticFixpoint();
5076
5077 CallBase &CB = cast<CallBase>(Val&: getAssociatedValue());
5078 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5079 if (!handleParallel51(A, CB))
5080 return indicatePessimisticFixpoint();
5081 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5082 : ChangeStatus::CHANGED;
5083 }
5084
5085      // F is a runtime function that allocates or frees memory; check
5086 // AAHeapToStack and AAHeapToShared.
5087 assert(
5088 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5089 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5090 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5091
5092 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5093 QueryingAA: *this, IRP: IRPosition::function(F: *CB.getCaller()), DepClass: DepClassTy::OPTIONAL);
5094 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5095 QueryingAA: *this, IRP: IRPosition::function(F: *CB.getCaller()), DepClass: DepClassTy::OPTIONAL);
5096
5097 RuntimeFunction RF = It->getSecond();
5098
5099 switch (RF) {
5100      // If neither HeapToStack nor HeapToShared assumes the call is removed,
5101 // assume SPMD incompatibility.
5102 case OMPRTL___kmpc_alloc_shared:
5103 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5104 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5105 SPMDCompatibilityTracker.insert(Elem: &CB);
5106 break;
5107 case OMPRTL___kmpc_free_shared:
5108 if ((!HeapToStackAA ||
5109 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5110 (!HeapToSharedAA ||
5111 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5112 SPMDCompatibilityTracker.insert(Elem: &CB);
5113 break;
5114 default:
5115 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5116 SPMDCompatibilityTracker.insert(Elem: &CB);
5117 }
5118 return ChangeStatus::CHANGED;
5119 };
5120
5121 const auto *AACE =
5122 A.getAAFor<AACallEdges>(QueryingAA: *this, IRP: getIRPosition(), DepClass: DepClassTy::OPTIONAL);
5123 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5124 if (Function *F = getAssociatedFunction())
5125 CheckCallee(F, /*NumCallees=*/1);
5126 } else {
5127 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5128 for (auto *Callee : OptimisticEdges) {
5129 CheckCallee(Callee, OptimisticEdges.size());
5130 if (isAtFixpoint())
5131 break;
5132 }
5133 }
5134
5135 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5136 : ChangeStatus::CHANGED;
5137 }
5138
5139 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5140  /// handled; if a problem occurred, false is returned.
5141 bool handleParallel51(Attributor &A, CallBase &CB) {
5142 const unsigned int NonWrapperFunctionArgNo = 5;
5143 const unsigned int WrapperFunctionArgNo = 6;
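    // In the device runtime, __kmpc_parallel_51 roughly has the signature
    //   __kmpc_parallel_51(ident, gtid, if_expr, num_threads, proc_bind,
    //                      fn, wrapper_fn, args, nargs)
    // so argument 5 is the outlined parallel region itself and argument 6 is the
    // wrapper used by the generic-mode state machine.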
5144 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5145 ? NonWrapperFunctionArgNo
5146 : WrapperFunctionArgNo;
5147
5148 auto *ParallelRegion = dyn_cast<Function>(
5149 Val: CB.getArgOperand(i: ParallelRegionOpArgNo)->stripPointerCasts());
5150 if (!ParallelRegion)
5151 return false;
5152
5153 ReachedKnownParallelRegions.insert(Elem: &CB);
5154    // Check nested parallelism.
5155 auto *FnAA = A.getAAFor<AAKernelInfo>(
5156 QueryingAA: *this, IRP: IRPosition::function(F: *ParallelRegion), DepClass: DepClassTy::OPTIONAL);
5157 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5158 !FnAA->ReachedKnownParallelRegions.empty() ||
5159 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5160 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5161 !FnAA->ReachedUnknownParallelRegions.empty();
5162 return true;
5163 }
5164};
5165
5166struct AAFoldRuntimeCall
5167 : public StateWrapper<BooleanState, AbstractAttribute> {
5168 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5169
5170 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5171
5172 /// Statistics are tracked as part of manifest for now.
5173 void trackStatistics() const override {}
5174
5175  /// Create an abstract attribute view for the position \p IRP.
5176 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5177 Attributor &A);
5178
5179 /// See AbstractAttribute::getName()
5180 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5181
5182 /// See AbstractAttribute::getIdAddr()
5183 const char *getIdAddr() const override { return &ID; }
5184
5185 /// This function should return true if the type of the \p AA is
5186 /// AAFoldRuntimeCall
5187 static bool classof(const AbstractAttribute *AA) {
5188 return (AA->getIdAddr() == &ID);
5189 }
5190
5191 static const char ID;
5192};
5193
5194struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5195 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5196 : AAFoldRuntimeCall(IRP, A) {}
5197
5198 /// See AbstractAttribute::getAsStr()
5199 const std::string getAsStr(Attributor *) const override {
5200 if (!isValidState())
5201 return "<invalid>";
5202
5203 std::string Str("simplified value: ");
5204
5205 if (!SimplifiedValue)
5206 return Str + std::string("none");
5207
5208 if (!*SimplifiedValue)
5209 return Str + std::string("nullptr");
5210
5211 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: *SimplifiedValue))
5212 return Str + std::to_string(val: CI->getSExtValue());
5213
5214 return Str + std::string("unknown");
5215 }
5216
5217 void initialize(Attributor &A) override {
5218 if (DisableOpenMPOptFolding)
5219 indicatePessimisticFixpoint();
5220
5221 Function *Callee = getAssociatedFunction();
5222
5223 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5224 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Val: Callee);
5225 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5226 "Expected a known OpenMP runtime function");
5227
5228 RFKind = It->getSecond();
5229
5230 CallBase &CB = cast<CallBase>(Val&: getAssociatedValue());
5231 A.registerSimplificationCallback(
5232 IRP: IRPosition::callsite_returned(CB),
5233 CB: [&](const IRPosition &IRP, const AbstractAttribute *AA,
5234 bool &UsedAssumedInformation) -> std::optional<Value *> {
5235 assert((isValidState() ||
5236 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5237 "Unexpected invalid state!");
5238
5239 if (!isAtFixpoint()) {
5240 UsedAssumedInformation = true;
5241 if (AA)
5242 A.recordDependence(FromAA: *this, ToAA: *AA, DepClass: DepClassTy::OPTIONAL);
5243 }
5244 return SimplifiedValue;
5245 });
5246 }
5247
5248 ChangeStatus updateImpl(Attributor &A) override {
5249 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5250 switch (RFKind) {
5251 case OMPRTL___kmpc_is_spmd_exec_mode:
5252 Changed |= foldIsSPMDExecMode(A);
5253 break;
5254 case OMPRTL___kmpc_parallel_level:
5255 Changed |= foldParallelLevel(A);
5256 break;
5257 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5258 Changed = Changed | foldKernelFnAttribute(A, Attr: "omp_target_thread_limit");
5259 break;
5260 case OMPRTL___kmpc_get_hardware_num_blocks:
5261 Changed = Changed | foldKernelFnAttribute(A, Attr: "omp_target_num_teams");
5262 break;
5263 default:
5264 llvm_unreachable("Unhandled OpenMP runtime function!");
5265 }
5266
5267 return Changed;
5268 }
5269
5270 ChangeStatus manifest(Attributor &A) override {
5271 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5272
5273 if (SimplifiedValue && *SimplifiedValue) {
5274 Instruction &I = *getCtxI();
5275 A.changeAfterManifest(IRP: IRPosition::inst(I), NV&: **SimplifiedValue);
5276 A.deleteAfterManifest(I);
5277
5278 CallBase *CB = dyn_cast<CallBase>(Val: &I);
5279 auto Remark = [&](OptimizationRemark OR) {
5280 if (auto *C = dyn_cast<ConstantInt>(Val: *SimplifiedValue))
5281 return OR << "Replacing OpenMP runtime call "
5282 << CB->getCalledFunction()->getName() << " with "
5283 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5284 return OR << "Replacing OpenMP runtime call "
5285 << CB->getCalledFunction()->getName() << ".";
5286 };
5287
5288 if (CB && EnableVerboseRemarks)
5289 A.emitRemark<OptimizationRemark>(I: CB, RemarkName: "OMP180", RemarkCB&: Remark);
5290
5291 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5292 << **SimplifiedValue << "\n");
5293
5294 Changed = ChangeStatus::CHANGED;
5295 }
5296
5297 return Changed;
5298 }
5299
5300 ChangeStatus indicatePessimisticFixpoint() override {
5301 SimplifiedValue = nullptr;
5302 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5303 }
5304
5305private:
5306 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
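  /// For example, if every kernel that can reach the associated call site is
  /// (assumed to be) executing in SPMD mode, the call folds to the i8 constant
  /// 1; if all of them are generic-mode kernels it folds to 0; a mix of both
  /// leaves the call unfolded.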
5307 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5308 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5309
5310 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5311 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5312 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5313 QueryingAA: *this, IRP: IRPosition::function(F: *getAnchorScope()), DepClass: DepClassTy::REQUIRED);
5314
5315 if (!CallerKernelInfoAA ||
5316 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5317 return indicatePessimisticFixpoint();
5318
5319 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5320 auto *AA = A.getAAFor<AAKernelInfo>(QueryingAA: *this, IRP: IRPosition::function(F: *K),
5321 DepClass: DepClassTy::REQUIRED);
5322
5323 if (!AA || !AA->isValidState()) {
5324 SimplifiedValue = nullptr;
5325 return indicatePessimisticFixpoint();
5326 }
5327
5328 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5329 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5330 ++KnownSPMDCount;
5331 else
5332 ++AssumedSPMDCount;
5333 } else {
5334 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5335 ++KnownNonSPMDCount;
5336 else
5337 ++AssumedNonSPMDCount;
5338 }
5339 }
5340
5341 if ((AssumedSPMDCount + KnownSPMDCount) &&
5342 (AssumedNonSPMDCount + KnownNonSPMDCount))
5343 return indicatePessimisticFixpoint();
5344
5345 auto &Ctx = getAnchorValue().getContext();
5346 if (KnownSPMDCount || AssumedSPMDCount) {
5347 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5348 "Expected only SPMD kernels!");
5349 // All reaching kernels are in SPMD mode. Update all function calls to
5350 // __kmpc_is_spmd_exec_mode to 1.
5351 SimplifiedValue = ConstantInt::get(Ty: Type::getInt8Ty(C&: Ctx), V: true);
5352 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5353 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5354 "Expected only non-SPMD kernels!");
5355 // All reaching kernels are in non-SPMD mode. Update all function
5356 // calls to __kmpc_is_spmd_exec_mode to 0.
5357 SimplifiedValue = ConstantInt::get(Ty: Type::getInt8Ty(C&: Ctx), V: false);
5358 } else {
5359      // The set of reaching kernels is empty, so we cannot tell whether the
5360 // associated call site can be folded. At this moment, SimplifiedValue
5361 // must be none.
5362 assert(!SimplifiedValue && "SimplifiedValue should be none");
5363 }
5364
5365 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5366 : ChangeStatus::CHANGED;
5367 }
5368
5369 /// Fold __kmpc_parallel_level into a constant if possible.
5370 ChangeStatus foldParallelLevel(Attributor &A) {
5371 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5372
5373 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5374 QueryingAA: *this, IRP: IRPosition::function(F: *getAnchorScope()), DepClass: DepClassTy::REQUIRED);
5375
5376 if (!CallerKernelInfoAA ||
5377 !CallerKernelInfoAA->ParallelLevels.isValidState())
5378 return indicatePessimisticFixpoint();
5379
5380 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5381 return indicatePessimisticFixpoint();
5382
5383 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5384 assert(!SimplifiedValue &&
5385 "SimplifiedValue should keep none at this point");
5386 return ChangeStatus::UNCHANGED;
5387 }
5388
5389 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5390 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5391 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5392 auto *AA = A.getAAFor<AAKernelInfo>(QueryingAA: *this, IRP: IRPosition::function(F: *K),
5393 DepClass: DepClassTy::REQUIRED);
5394 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5395 return indicatePessimisticFixpoint();
5396
5397 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5398 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5399 ++KnownSPMDCount;
5400 else
5401 ++AssumedSPMDCount;
5402 } else {
5403 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5404 ++KnownNonSPMDCount;
5405 else
5406 ++AssumedNonSPMDCount;
5407 }
5408 }
5409
5410 if ((AssumedSPMDCount + KnownSPMDCount) &&
5411 (AssumedNonSPMDCount + KnownNonSPMDCount))
5412 return indicatePessimisticFixpoint();
5413
5414 auto &Ctx = getAnchorValue().getContext();
5415 // If the caller can only be reached by SPMD kernel entries, the parallel
5416 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5417 // kernel entries, it is 0.
5418 if (AssumedSPMDCount || KnownSPMDCount) {
5419 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5420 "Expected only SPMD kernels!");
5421 SimplifiedValue = ConstantInt::get(Ty: Type::getInt8Ty(C&: Ctx), V: 1);
5422 } else {
5423 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5424 "Expected only non-SPMD kernels!");
5425 SimplifiedValue = ConstantInt::get(Ty: Type::getInt8Ty(C&: Ctx), V: 0);
5426 }
5427 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5428 : ChangeStatus::CHANGED;
5429 }
5430
5431 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5432    // Specialize only if all reaching kernels agree on the attribute's constant value.
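    // For example, if every reaching kernel carries "omp_target_thread_limit"=128,
    // a __kmpc_get_hardware_num_threads_in_block call can be folded to the i32
    // constant 128; a missing or mismatching value gives up instead.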
5433 int32_t CurrentAttrValue = -1;
5434 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5435
5436 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5437 QueryingAA: *this, IRP: IRPosition::function(F: *getAnchorScope()), DepClass: DepClassTy::REQUIRED);
5438
5439 if (!CallerKernelInfoAA ||
5440 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5441 return indicatePessimisticFixpoint();
5442
5443 // Iterate over the kernels that reach this function
5444 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5445 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Kind: Attr, Default: -1);
5446
5447 if (NextAttrVal == -1 ||
5448 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5449 return indicatePessimisticFixpoint();
5450 CurrentAttrValue = NextAttrVal;
5451 }
5452
5453 if (CurrentAttrValue != -1) {
5454 auto &Ctx = getAnchorValue().getContext();
5455 SimplifiedValue =
5456 ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: CurrentAttrValue);
5457 }
5458 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5459 : ChangeStatus::CHANGED;
5460 }
5461
5462 /// An optional value the associated value is assumed to fold to. That is, we
5463 /// assume the associated value (which is a call) can be replaced by this
5464 /// simplified value.
5465 std::optional<Value *> SimplifiedValue;
5466
5467 /// The runtime function kind of the callee of the associated call site.
5468 RuntimeFunction RFKind;
5469};
5470
5471} // namespace
5472
5473/// Register folding AAs for every call site of the given runtime function.
5474void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5475 auto &RFI = OMPInfoCache.RFIs[RF];
5476 RFI.foreachUse(SCC, CB: [&](Use &U, Function &F) {
5477 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, RFI: &RFI);
5478 if (!CI)
5479 return false;
5480 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5481 IRP: IRPosition::callsite_returned(CB: *CI), /* QueryingAA */ nullptr,
5482 DepClass: DepClassTy::NONE, /* ForceUpdate */ false,
5483 /* UpdateAfterInit */ false);
5484 return false;
5485 });
5486}
5487
5488void OpenMPOpt::registerAAs(bool IsModulePass) {
5489 if (SCC.empty())
5490 return;
5491
5492 if (IsModulePass) {
5493 // Ensure we create the AAKernelInfo AAs first and without triggering an
5494 // update. This will make sure we register all value simplification
5495 // callbacks before any other AA has the chance to create an AAValueSimplify
5496 // or similar.
5497 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5498 A.getOrCreateAAFor<AAKernelInfo>(
5499 IRP: IRPosition::function(F: Kernel), /* QueryingAA */ nullptr,
5500 DepClass: DepClassTy::NONE, /* ForceUpdate */ false,
5501 /* UpdateAfterInit */ false);
5502 return false;
5503 };
5504 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5505 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5506 InitRFI.foreachUse(SCC, CB: CreateKernelInfoCB);
5507
5508 registerFoldRuntimeCall(RF: OMPRTL___kmpc_is_spmd_exec_mode);
5509 registerFoldRuntimeCall(RF: OMPRTL___kmpc_parallel_level);
5510 registerFoldRuntimeCall(RF: OMPRTL___kmpc_get_hardware_num_threads_in_block);
5511 registerFoldRuntimeCall(RF: OMPRTL___kmpc_get_hardware_num_blocks);
5512 }
5513
5514 // Create CallSite AA for all Getters.
5515 if (DeduceICVValues) {
5516 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5517 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5518
5519 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5520
5521 auto CreateAA = [&](Use &U, Function &Caller) {
5522 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, RFI: &GetterRFI);
5523 if (!CI)
5524 return false;
5525
5526 auto &CB = cast<CallBase>(Val&: *CI);
5527
5528 IRPosition CBPos = IRPosition::callsite_function(CB);
5529 A.getOrCreateAAFor<AAICVTracker>(IRP: CBPos);
5530 return false;
5531 };
5532
5533 GetterRFI.foreachUse(SCC, CB: CreateAA);
5534 }
5535 }
5536
5537 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5538 // every function if there is a device kernel.
5539 if (!isOpenMPDevice(M))
5540 return;
5541
5542 for (auto *F : SCC) {
5543 if (F->isDeclaration())
5544 continue;
5545
5546    // We look at internal functions only on-demand, but if any use is not a
5547    // direct call or is outside the current set of analyzed functions, we have
5548 // to do it eagerly.
5549 if (F->hasLocalLinkage()) {
5550 if (llvm::all_of(Range: F->uses(), P: [this](const Use &U) {
5551 const auto *CB = dyn_cast<CallBase>(Val: U.getUser());
5552 return CB && CB->isCallee(U: &U) &&
5553 A.isRunOn(Fn: const_cast<Function *>(CB->getCaller()));
5554 }))
5555 continue;
5556 }
5557 registerAAsForFunction(A, F: *F);
5558 }
5559}
5560
5561void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5562 if (!DisableOpenMPOptDeglobalization)
5563 A.getOrCreateAAFor<AAHeapToShared>(IRP: IRPosition::function(F));
5564 A.getOrCreateAAFor<AAExecutionDomain>(IRP: IRPosition::function(F));
5565 if (!DisableOpenMPOptDeglobalization)
5566 A.getOrCreateAAFor<AAHeapToStack>(IRP: IRPosition::function(F));
5567 if (F.hasFnAttribute(Attribute::Convergent))
5568 A.getOrCreateAAFor<AANonConvergent>(IRP: IRPosition::function(F));
5569
5570 for (auto &I : instructions(F)) {
5571 if (auto *LI = dyn_cast<LoadInst>(Val: &I)) {
5572 bool UsedAssumedInformation = false;
5573 A.getAssumedSimplified(V: IRPosition::value(V: *LI), /* AA */ nullptr,
5574 UsedAssumedInformation, S: AA::Interprocedural);
5575 continue;
5576 }
5577 if (auto *CI = dyn_cast<CallBase>(Val: &I)) {
5578 if (CI->isIndirectCall())
5579 A.getOrCreateAAFor<AAIndirectCallInfo>(
5580 IRP: IRPosition::callsite_function(CB: *CI));
5581 }
5582 if (auto *SI = dyn_cast<StoreInst>(Val: &I)) {
5583 A.getOrCreateAAFor<AAIsDead>(IRP: IRPosition::value(V: *SI));
5584 continue;
5585 }
5586 if (auto *FI = dyn_cast<FenceInst>(Val: &I)) {
5587 A.getOrCreateAAFor<AAIsDead>(IRP: IRPosition::value(V: *FI));
5588 continue;
5589 }
5590 if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) {
5591 if (II->getIntrinsicID() == Intrinsic::assume) {
5592 A.getOrCreateAAFor<AAPotentialValues>(
5593 IRP: IRPosition::value(V: *II->getArgOperand(i: 0)));
5594 continue;
5595 }
5596 }
5597 }
5598}
5599
5600const char AAICVTracker::ID = 0;
5601const char AAKernelInfo::ID = 0;
5602const char AAExecutionDomain::ID = 0;
5603const char AAHeapToShared::ID = 0;
5604const char AAFoldRuntimeCall::ID = 0;
5605
5606AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5607 Attributor &A) {
5608 AAICVTracker *AA = nullptr;
5609 switch (IRP.getPositionKind()) {
5610 case IRPosition::IRP_INVALID:
5611 case IRPosition::IRP_FLOAT:
5612 case IRPosition::IRP_ARGUMENT:
5613 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5614 llvm_unreachable("ICVTracker can only be created for function position!");
5615 case IRPosition::IRP_RETURNED:
5616 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5617 break;
5618 case IRPosition::IRP_CALL_SITE_RETURNED:
5619 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5620 break;
5621 case IRPosition::IRP_CALL_SITE:
5622 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5623 break;
5624 case IRPosition::IRP_FUNCTION:
5625 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5626 break;
5627 }
5628
5629 return *AA;
5630}
5631
5632AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5633 Attributor &A) {
5634 AAExecutionDomainFunction *AA = nullptr;
5635 switch (IRP.getPositionKind()) {
5636 case IRPosition::IRP_INVALID:
5637 case IRPosition::IRP_FLOAT:
5638 case IRPosition::IRP_ARGUMENT:
5639 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5640 case IRPosition::IRP_RETURNED:
5641 case IRPosition::IRP_CALL_SITE_RETURNED:
5642 case IRPosition::IRP_CALL_SITE:
5643 llvm_unreachable(
5644 "AAExecutionDomain can only be created for function position!");
5645 case IRPosition::IRP_FUNCTION:
5646 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5647 break;
5648 }
5649
5650 return *AA;
5651}
5652
5653AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5654 Attributor &A) {
5655 AAHeapToSharedFunction *AA = nullptr;
5656 switch (IRP.getPositionKind()) {
5657 case IRPosition::IRP_INVALID:
5658 case IRPosition::IRP_FLOAT:
5659 case IRPosition::IRP_ARGUMENT:
5660 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5661 case IRPosition::IRP_RETURNED:
5662 case IRPosition::IRP_CALL_SITE_RETURNED:
5663 case IRPosition::IRP_CALL_SITE:
5664 llvm_unreachable(
5665 "AAHeapToShared can only be created for function position!");
5666 case IRPosition::IRP_FUNCTION:
5667 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5668 break;
5669 }
5670
5671 return *AA;
5672}
5673
5674AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5675 Attributor &A) {
5676 AAKernelInfo *AA = nullptr;
5677 switch (IRP.getPositionKind()) {
5678 case IRPosition::IRP_INVALID:
5679 case IRPosition::IRP_FLOAT:
5680 case IRPosition::IRP_ARGUMENT:
5681 case IRPosition::IRP_RETURNED:
5682 case IRPosition::IRP_CALL_SITE_RETURNED:
5683 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5684 llvm_unreachable("KernelInfo can only be created for function position!");
5685 case IRPosition::IRP_CALL_SITE:
5686 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5687 break;
5688 case IRPosition::IRP_FUNCTION:
5689 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5690 break;
5691 }
5692
5693 return *AA;
5694}
5695
5696AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5697 Attributor &A) {
5698 AAFoldRuntimeCall *AA = nullptr;
5699 switch (IRP.getPositionKind()) {
5700 case IRPosition::IRP_INVALID:
5701 case IRPosition::IRP_FLOAT:
5702 case IRPosition::IRP_ARGUMENT:
5703 case IRPosition::IRP_RETURNED:
5704 case IRPosition::IRP_FUNCTION:
5705 case IRPosition::IRP_CALL_SITE:
5706 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5707 llvm_unreachable("KernelInfo can only be created for call site position!");
5708 case IRPosition::IRP_CALL_SITE_RETURNED:
5709 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5710 break;
5711 }
5712
5713 return *AA;
5714}
5715
5716PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5717 if (!containsOpenMP(M))
5718 return PreservedAnalyses::all();
5719 if (DisableOpenMPOptimizations)
5720 return PreservedAnalyses::all();
5721
5722 FunctionAnalysisManager &FAM =
5723 AM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager();
5724 KernelSet Kernels = getDeviceKernels(M);
5725
5726 if (PrintModuleBeforeOptimizations)
5727 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5728
5729 auto IsCalled = [&](Function &F) {
5730 if (Kernels.contains(key: &F))
5731 return true;
5732 for (const User *U : F.users())
5733 if (!isa<BlockAddress>(Val: U))
5734 return true;
5735 return false;
5736 };
5737
5738 auto EmitRemark = [&](Function &F) {
5739 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
5740 ORE.emit(RemarkBuilder: [&]() {
5741 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5742 return ORA << "Could not internalize function. "
5743 << "Some optimizations may not be possible. [OMP140]";
5744 });
5745 };
5746
5747 bool Changed = false;
5748
5749 // Create internal copies of each function if this is a kernel Module. This
5750  // allows interprocedural passes to see every call edge.
5751 DenseMap<Function *, Function *> InternalizedMap;
5752 if (isOpenMPDevice(M)) {
5753 SmallPtrSet<Function *, 16> InternalizeFns;
5754 for (Function &F : M)
5755 if (!F.isDeclaration() && !Kernels.contains(key: &F) && IsCalled(F) &&
5756 !DisableInternalization) {
5757 if (Attributor::isInternalizable(F)) {
5758 InternalizeFns.insert(Ptr: &F);
5759 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5760 EmitRemark(F);
5761 }
5762 }
5763
5764 Changed |=
5765 Attributor::internalizeFunctions(FnSet&: InternalizeFns, FnMap&: InternalizedMap);
5766 }
5767
5768 // Look at every function in the Module unless it was internalized.
5769 SetVector<Function *> Functions;
5770 SmallVector<Function *, 16> SCC;
5771 for (Function &F : M)
5772 if (!F.isDeclaration() && !InternalizedMap.lookup(Val: &F)) {
5773 SCC.push_back(Elt: &F);
5774 Functions.insert(X: &F);
5775 }
5776
5777 if (SCC.empty())
5778 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5779
5780 AnalysisGetter AG(FAM);
5781
5782 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5783 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: *F);
5784 };
5785
5786 BumpPtrAllocator Allocator;
5787 CallGraphUpdater CGUpdater;
5788
5789 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5790 LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5791 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5792
5793 unsigned MaxFixpointIterations =
5794 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5795
5796 AttributorConfig AC(CGUpdater);
5797 AC.DefaultInitializeLiveInternals = false;
5798 AC.IsModulePass = true;
5799 AC.RewriteSignatures = false;
5800 AC.MaxFixpointIterations = MaxFixpointIterations;
5801 AC.OREGetter = OREGetter;
5802 AC.PassName = DEBUG_TYPE;
5803 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5804 AC.IPOAmendableCB = [](const Function &F) {
5805 return F.hasFnAttribute(Kind: "kernel");
5806 };
5807
5808 Attributor A(Functions, InfoCache, AC);
5809
5810 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5811 Changed |= OMPOpt.run(IsModulePass: true);
5812
5813 // Optionally inline device functions for potentially better performance.
5814 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5815 for (Function &F : M)
5816 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5817 !F.hasFnAttribute(Attribute::NoInline))
5818 F.addFnAttr(Attribute::AlwaysInline);
5819
5820 if (PrintModuleAfterOptimizations)
5821 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5822
5823 if (Changed)
5824 return PreservedAnalyses::none();
5825
5826 return PreservedAnalyses::all();
5827}
5828
5829PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5830 CGSCCAnalysisManager &AM,
5831 LazyCallGraph &CG,
5832 CGSCCUpdateResult &UR) {
5833 if (!containsOpenMP(M&: *C.begin()->getFunction().getParent()))
5834 return PreservedAnalyses::all();
5835 if (DisableOpenMPOptimizations)
5836 return PreservedAnalyses::all();
5837
5838 SmallVector<Function *, 16> SCC;
5839  // If there are kernels in the module, we have to run on all SCCs.
5840 for (LazyCallGraph::Node &N : C) {
5841 Function *Fn = &N.getFunction();
5842 SCC.push_back(Elt: Fn);
5843 }
5844
5845 if (SCC.empty())
5846 return PreservedAnalyses::all();
5847
5848 Module &M = *C.begin()->getFunction().getParent();
5849
5850 if (PrintModuleBeforeOptimizations)
5851 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5852
5853 KernelSet Kernels = getDeviceKernels(M);
5854
5855 FunctionAnalysisManager &FAM =
5856 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(IR&: C, ExtraArgs&: CG).getManager();
5857
5858 AnalysisGetter AG(FAM);
5859
5860 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5861 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: *F);
5862 };
5863
5864 BumpPtrAllocator Allocator;
5865 CallGraphUpdater CGUpdater;
5866 CGUpdater.initialize(LCG&: CG, SCC&: C, AM, UR);
5867
5868 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5869 LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5870 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5871 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5872 /*CGSCC*/ &Functions, PostLink);
5873
5874 unsigned MaxFixpointIterations =
5875 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5876
5877 AttributorConfig AC(CGUpdater);
5878 AC.DefaultInitializeLiveInternals = false;
5879 AC.IsModulePass = false;
5880 AC.RewriteSignatures = false;
5881 AC.MaxFixpointIterations = MaxFixpointIterations;
5882 AC.OREGetter = OREGetter;
5883 AC.PassName = DEBUG_TYPE;
5884 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5885
5886 Attributor A(Functions, InfoCache, AC);
5887
5888 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5889 bool Changed = OMPOpt.run(IsModulePass: false);
5890
5891 if (PrintModuleAfterOptimizations)
5892 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5893
5894 if (Changed)
5895 return PreservedAnalyses::none();
5896
5897 return PreservedAnalyses::all();
5898}
5899
5900bool llvm::omp::isOpenMPKernel(Function &Fn) {
5901 return Fn.hasFnAttribute(Kind: "kernel");
5902}
5903
5904KernelSet llvm::omp::getDeviceKernels(Module &M) {
5905 // TODO: Create a more cross-platform way of determining device kernels.
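  // Device kernels are currently identified via NVVM annotation metadata,
  // roughly of the form
  //   !nvvm.annotations = !{!0}
  //   !0 = !{ptr @kernel_fn, !"kernel", i32 1}
  // where operand 0 is the kernel function and operand 1 is the "kernel" tag
  // checked below.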
5906 NamedMDNode *MD = M.getNamedMetadata(Name: "nvvm.annotations");
5907 KernelSet Kernels;
5908
5909 if (!MD)
5910 return Kernels;
5911
5912 for (auto *Op : MD->operands()) {
5913 if (Op->getNumOperands() < 2)
5914 continue;
5915 MDString *KindID = dyn_cast<MDString>(Val: Op->getOperand(I: 1));
5916 if (!KindID || KindID->getString() != "kernel")
5917 continue;
5918
5919 Function *KernelFn =
5920 mdconst::dyn_extract_or_null<Function>(MD: Op->getOperand(I: 0));
5921 if (!KernelFn)
5922 continue;
5923
5924 // We are only interested in OpenMP target regions. Others, such as kernels
5925 // generated by CUDA but linked together, are not interesting to this pass.
5926 if (isOpenMPKernel(Fn&: *KernelFn)) {
5927 ++NumOpenMPTargetRegionKernels;
5928 Kernels.insert(X: KernelFn);
5929 } else
5930 ++NumNonOpenMPTargetRegionKernels;
5931 }
5932
5933 return Kernels;
5934}
5935
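// Frontends typically mark OpenMP-containing modules with a module flag, e.g.
//   !{i32 7, !"openmp", i32 51}
// (behavior "Max", key "openmp", value = OpenMP version). containsOpenMP and
// isOpenMPDevice below only test for the presence of the respective flag.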
5936bool llvm::omp::containsOpenMP(Module &M) {
5937 Metadata *MD = M.getModuleFlag(Key: "openmp");
5938 if (!MD)
5939 return false;
5940
5941 return true;
5942}
5943
5944bool llvm::omp::isOpenMPDevice(Module &M) {
5945 Metadata *MD = M.getModuleFlag(Key: "openmp-device");
5946 if (!MD)
5947 return false;
5948
5949 return true;
5950}
5951

Source code of llvm/lib/Transforms/IPO/OpenMPOpt.cpp