1 | //===--------- TaskDispatch.h - ORC task dispatch utils ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Task and TaskDispatcher classes, plus helpers for creating generic named tasks.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H |
14 | #define LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H |
15 | |
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>

#if LLVM_ENABLE_THREADS
#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#endif
30 | |
31 | namespace llvm { |
32 | namespace orc { |
33 | |
/// Represents an abstract task for ORC to run.
///
/// Subclasses implement run() to perform the work and printDescription() to
/// describe it. The class participates in LLVM's extensible RTTI through
/// RTTIExtends, so RTTIExtends-based subclasses declare their own static ID.
class Task : public RTTIExtends<Task, RTTIRoot> {
public:
  /// Type identifier used by the extensible RTTI machinery.
  static char ID;

  virtual ~Task() = default;

  /// Description of the task to be performed. Used for logging.
  virtual void printDescription(raw_ostream &OS) = 0;

  /// Run the task.
  virtual void run() = 0;

private:
  // Declared here and defined out of line. Presumably this follows the LLVM
  // convention of pinning the vtable to one translation unit — confirm in
  // the implementation file.
  void anchor() override;
};
50 | |
51 | /// Base class for generic tasks. |
52 | class GenericNamedTask : public RTTIExtends<GenericNamedTask, Task> { |
53 | public: |
54 | static char ID; |
55 | static const char *DefaultDescription; |
56 | }; |
57 | |
58 | /// Generic task implementation. |
59 | template <typename FnT> class GenericNamedTaskImpl : public GenericNamedTask { |
60 | public: |
61 | GenericNamedTaskImpl(FnT &&Fn, std::string DescBuffer) |
62 | : Fn(std::forward<FnT>(Fn)), Desc(DescBuffer.c_str()), |
63 | DescBuffer(std::move(DescBuffer)) {} |
64 | GenericNamedTaskImpl(FnT &&Fn, const char *Desc) |
65 | : Fn(std::forward<FnT>(Fn)), Desc(Desc) { |
66 | assert(Desc && "Description cannot be null" ); |
67 | } |
68 | void printDescription(raw_ostream &OS) override { OS << Desc; } |
69 | void run() override { Fn(); } |
70 | |
71 | private: |
72 | FnT Fn; |
73 | const char *Desc; |
74 | std::string DescBuffer; |
75 | }; |
76 | |
77 | /// Create a generic named task from a std::string description. |
78 | template <typename FnT> |
79 | std::unique_ptr<GenericNamedTask> makeGenericNamedTask(FnT &&Fn, |
80 | std::string Desc) { |
81 | return std::make_unique<GenericNamedTaskImpl<FnT>>(std::forward<FnT>(Fn), |
82 | std::move(Desc)); |
83 | } |
84 | |
85 | /// Create a generic named task from a const char * description. |
86 | template <typename FnT> |
87 | std::unique_ptr<GenericNamedTask> |
88 | makeGenericNamedTask(FnT &&Fn, const char *Desc = nullptr) { |
89 | if (!Desc) |
90 | Desc = GenericNamedTask::DefaultDescription; |
91 | return std::make_unique<GenericNamedTaskImpl<FnT>>(std::forward<FnT>(Fn), |
92 | Desc); |
93 | } |
94 | |
/// Abstract base for classes that dispatch ORC Tasks.
///
/// Implementations decide where each Task runs: for example, on the current
/// thread (InPlaceTaskDispatcher) or, when threads are enabled, via
/// DynamicThreadPoolTaskDispatcher.
class TaskDispatcher {
public:
  virtual ~TaskDispatcher();

  /// Run the given task. The dispatcher takes ownership of \p T.
  virtual void dispatch(std::unique_ptr<Task> T) = 0;

  /// Called by ExecutionSession. Waits until all tasks have completed.
  virtual void shutdown() = 0;
};
106 | |
107 | /// Runs all tasks on the current thread. |
108 | class InPlaceTaskDispatcher : public TaskDispatcher { |
109 | public: |
110 | void dispatch(std::unique_ptr<Task> T) override; |
111 | void shutdown() override; |
112 | }; |
113 | |
114 | #if LLVM_ENABLE_THREADS |
115 | |
116 | class DynamicThreadPoolTaskDispatcher : public TaskDispatcher { |
117 | public: |
118 | DynamicThreadPoolTaskDispatcher( |
119 | std::optional<size_t> MaxMaterializationThreads) |
120 | : MaxMaterializationThreads(MaxMaterializationThreads) {} |
121 | void dispatch(std::unique_ptr<Task> T) override; |
122 | void shutdown() override; |
123 | private: |
124 | std::mutex DispatchMutex; |
125 | bool Running = true; |
126 | size_t Outstanding = 0; |
127 | std::condition_variable OutstandingCV; |
128 | |
129 | std::optional<size_t> MaxMaterializationThreads; |
130 | size_t NumMaterializationThreads = 0; |
131 | std::deque<std::unique_ptr<Task>> MaterializationTaskQueue; |
132 | }; |
133 | |
134 | #endif // LLVM_ENABLE_THREADS |
135 | |
136 | } // End namespace orc |
137 | } // End namespace llvm |
138 | |
139 | #endif // LLVM_EXECUTIONENGINE_ORC_TASKDISPATCH_H |
140 | |