1//===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
/// This file provides a collection of function (or more generally, callable)
/// type erasure utilities supplementing those provided by the standard library
/// in `<functional>`.
12///
13/// It provides `unique_function`, which works like `std::function` but supports
14/// move-only callable objects and const-qualification.
15///
16/// Future plans:
17/// - Add a `function` that provides ref-qualified support, which doesn't work
18/// with `std::function`.
19/// - Provide support for specifying multiple signatures to type erase callable
20/// objects with an overload set, such as those produced by generic lambdas.
21/// - Expand to include a copyable utility that directly replaces std::function
22/// but brings the above improvements.
23///
24/// Note that LLVM's utilities are greatly simplified by not supporting
25/// allocators.
26///
27/// If the standard library ever begins to provide comparable facilities we can
28/// consider switching to those.
29///
30//===----------------------------------------------------------------------===//
31
32#ifndef LLVM_ADT_FUNCTIONEXTRAS_H
33#define LLVM_ADT_FUNCTIONEXTRAS_H
34
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/type_traits.h"
#include <cstddef>
#include <cstring>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
42
43namespace llvm {
44
45/// unique_function is a type-erasing functor similar to std::function.
46///
47/// It can hold move-only function objects, like lambdas capturing unique_ptrs.
48/// Accordingly, it is movable but not copyable.
49///
/// It supports const-qualification:
/// - unique_function<int() const> has a const operator().
///   It can only hold functions which themselves have a const operator().
/// - unique_function<int()> has a non-const operator().
///   It can hold functions with a non-const operator(), like mutable lambdas.
55template <typename FunctionT> class unique_function;
56
57namespace detail {
58
59template <typename T>
60using EnableIfTrivial =
61 std::enable_if_t<llvm::is_trivially_move_constructible<T>::value &&
62 std::is_trivially_destructible<T>::value>;
63template <typename CallableT, typename ThisT>
64using EnableUnlessSameType =
65 std::enable_if_t<!std::is_same<remove_cvref_t<CallableT>, ThisT>::value>;
66template <typename CallableT, typename Ret, typename... Params>
67using EnableIfCallable =
68 std::enable_if_t<std::is_void<Ret>::value ||
69 std::is_convertible<decltype(std::declval<CallableT>()(
70 std::declval<Params>()...)),
71 Ret>::value>;
72
73template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase {
74protected:
75 static constexpr size_t InlineStorageSize = sizeof(void *) * 3;
76
77 template <typename T, class = void>
78 struct IsSizeLessThanThresholdT : std::false_type {};
79
80 template <typename T>
81 struct IsSizeLessThanThresholdT<
82 T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {};
83
84 // Provide a type function to map parameters that won't observe extra copies
85 // or moves and which are small enough to likely pass in register to values
86 // and all other types to l-value reference types. We use this to compute the
87 // types used in our erased call utility to minimize copies and moves unless
88 // doing so would force things unnecessarily into memory.
89 //
90 // The heuristic used is related to common ABI register passing conventions.
91 // It doesn't have to be exact though, and in one way it is more strict
92 // because we want to still be able to observe either moves *or* copies.
93 template <typename T>
94 using AdjustedParamT = typename std::conditional<
95 !std::is_reference<T>::value &&
96 llvm::is_trivially_copy_constructible<T>::value &&
97 llvm::is_trivially_move_constructible<T>::value &&
98 IsSizeLessThanThresholdT<T>::value,
99 T, T &>::type;
100
101 // The type of the erased function pointer we use as a callback to dispatch to
102 // the stored callable when it is trivial to move and destroy.
103 using CallPtrT = ReturnT (*)(void *CallableAddr,
104 AdjustedParamT<ParamTs>... Params);
105 using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr);
106 using DestroyPtrT = void (*)(void *CallableAddr);
107
108 /// A struct to hold a single trivial callback with sufficient alignment for
109 /// our bitpacking.
110 struct alignas(8) TrivialCallback {
111 CallPtrT CallPtr;
112 };
113
114 /// A struct we use to aggregate three callbacks when we need full set of
115 /// operations.
116 struct alignas(8) NonTrivialCallbacks {
117 CallPtrT CallPtr;
118 MovePtrT MovePtr;
119 DestroyPtrT DestroyPtr;
120 };
121
122 // Create a pointer union between either a pointer to a static trivial call
123 // pointer in a struct or a pointer to a static struct of the call, move, and
124 // destroy pointers.
125 using CallbackPointerUnionT =
126 PointerUnion<TrivialCallback *, NonTrivialCallbacks *>;
127
128 // The main storage buffer. This will either have a pointer to out-of-line
129 // storage or an inline buffer storing the callable.
130 union StorageUnionT {
131 // For out-of-line storage we keep a pointer to the underlying storage and
132 // the size. This is enough to deallocate the memory.
133 struct OutOfLineStorageT {
134 void *StoragePtr;
135 size_t Size;
136 size_t Alignment;
137 } OutOfLineStorage;
138 static_assert(
139 sizeof(OutOfLineStorageT) <= InlineStorageSize,
140 "Should always use all of the out-of-line storage for inline storage!");
141
142 // For in-line storage, we just provide an aligned character buffer. We
143 // provide three pointers worth of storage here.
144 // This is mutable as an inlined `const unique_function<void() const>` may
145 // still modify its own mutable members.
146 mutable
147 typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type
148 InlineStorage;
149 } StorageUnion;
150
151 // A compressed pointer to either our dispatching callback or our table of
152 // dispatching callbacks and the flag for whether the callable itself is
153 // stored inline or not.
154 PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag;
155
156 bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); }
157
158 bool isTrivialCallback() const {
159 return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>();
160 }
161
162 CallPtrT getTrivialCallback() const {
163 return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr;
164 }
165
166 NonTrivialCallbacks *getNonTrivialCallbacks() const {
167 return CallbackAndInlineFlag.getPointer()
168 .template get<NonTrivialCallbacks *>();
169 }
170
171 CallPtrT getCallPtr() const {
172 return isTrivialCallback() ? getTrivialCallback()
173 : getNonTrivialCallbacks()->CallPtr;
174 }
175
176 // These three functions are only const in the narrow sense. They return
177 // mutable pointers to function state.
178 // This allows unique_function<T const>::operator() to be const, even if the
179 // underlying functor may be internally mutable.
180 //
181 // const callers must ensure they're only used in const-correct ways.
182 void *getCalleePtr() const {
183 return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
184 }
185 void *getInlineStorage() const { return &StorageUnion.InlineStorage; }
186 void *getOutOfLineStorage() const {
187 return StorageUnion.OutOfLineStorage.StoragePtr;
188 }
189
190 size_t getOutOfLineStorageSize() const {
191 return StorageUnion.OutOfLineStorage.Size;
192 }
193 size_t getOutOfLineStorageAlignment() const {
194 return StorageUnion.OutOfLineStorage.Alignment;
195 }
196
197 void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) {
198 StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment};
199 }
200
201 template <typename CalledAsT>
202 static ReturnT CallImpl(void *CallableAddr,
203 AdjustedParamT<ParamTs>... Params) {
204 auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr);
205 return Func(std::forward<ParamTs>(Params)...);
206 }
207
208 template <typename CallableT>
209 static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept {
210 new (LHSCallableAddr)
211 CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr)));
212 }
213
214 template <typename CallableT>
215 static void DestroyImpl(void *CallableAddr) noexcept {
216 reinterpret_cast<CallableT *>(CallableAddr)->~CallableT();
217 }
218
219 // The pointers to call/move/destroy functions are determined for each
220 // callable type (and called-as type, which determines the overload chosen).
221 // (definitions are out-of-line).
222
223 // By default, we need an object that contains all the different
224 // type erased behaviors needed. Create a static instance of the struct type
225 // here and each instance will contain a pointer to it.
226 // Wrap in a struct to avoid https://gcc.gnu.org/PR71954
227 template <typename CallableT, typename CalledAs, typename Enable = void>
228 struct CallbacksHolder {
229 static NonTrivialCallbacks Callbacks;
230 };
231 // See if we can create a trivial callback. We need the callable to be
232 // trivially moved and trivially destroyed so that we don't have to store
233 // type erased callbacks for those operations.
234 template <typename CallableT, typename CalledAs>
235 struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> {
236 static TrivialCallback Callbacks;
237 };
238
239 // A simple tag type so the call-as type to be passed to the constructor.
240 template <typename T> struct CalledAs {};
241
242 // Essentially the "main" unique_function constructor, but subclasses
243 // provide the qualified type to be used for the call.
244 // (We always store a T, even if the call will use a pointer to const T).
245 template <typename CallableT, typename CalledAsT>
246 UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) {
247 bool IsInlineStorage = true;
248 void *CallableAddr = getInlineStorage();
249 if (sizeof(CallableT) > InlineStorageSize ||
250 alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
251 IsInlineStorage = false;
252 // Allocate out-of-line storage. FIXME: Use an explicit alignment
253 // parameter in C++17 mode.
254 auto Size = sizeof(CallableT);
255 auto Alignment = alignof(CallableT);
256 CallableAddr = allocate_buffer(Size, Alignment);
257 setOutOfLineStorage(CallableAddr, Size, Alignment);
258 }
259
260 // Now move into the storage.
261 new (CallableAddr) CallableT(std::move(Callable));
262 CallbackAndInlineFlag.setPointerAndInt(
263 &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage);
264 }
265
266 ~UniqueFunctionBase() {
267 if (!CallbackAndInlineFlag.getPointer())
268 return;
269
270 // Cache this value so we don't re-check it after type-erased operations.
271 bool IsInlineStorage = isInlineStorage();
272
273 if (!isTrivialCallback())
274 getNonTrivialCallbacks()->DestroyPtr(
275 IsInlineStorage ? getInlineStorage() : getOutOfLineStorage());
276
277 if (!IsInlineStorage)
278 deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(),
279 getOutOfLineStorageAlignment());
280 }
281
282 UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept {
283 // Copy the callback and inline flag.
284 CallbackAndInlineFlag = RHS.CallbackAndInlineFlag;
285
286 // If the RHS is empty, just copying the above is sufficient.
287 if (!RHS)
288 return;
289
290 if (!isInlineStorage()) {
291 // The out-of-line case is easiest to move.
292 StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage;
293 } else if (isTrivialCallback()) {
294 // Move is trivial, just memcpy the bytes across.
295 memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize);
296 } else {
297 // Non-trivial move, so dispatch to a type-erased implementation.
298 getNonTrivialCallbacks()->MovePtr(getInlineStorage(),
299 RHS.getInlineStorage());
300 }
301
302 // Clear the old callback and inline flag to get back to as-if-null.
303 RHS.CallbackAndInlineFlag = {};
304
305#ifndef NDEBUG
306 // In debug builds, we also scribble across the rest of the storage.
307 memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize);
308#endif
309 }
310
311 UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept {
312 if (this == &RHS)
313 return *this;
314
315 // Because we don't try to provide any exception safety guarantees we can
316 // implement move assignment very simply by first destroying the current
317 // object and then move-constructing over top of it.
318 this->~UniqueFunctionBase();
319 new (this) UniqueFunctionBase(std::move(RHS));
320 return *this;
321 }
322
323 UniqueFunctionBase() = default;
324
325public:
326 explicit operator bool() const {
327 return (bool)CallbackAndInlineFlag.getPointer();
328 }
329};
330
// Out-of-line definitions of the static callback tables declared by
// UniqueFunctionBase::CallbacksHolder. One table exists per
// (CallableT, CalledAsT) instantiation.

// Primary template, used whenever EnableIfTrivial<CallableT> does not apply:
// the full call/move/destroy table. Calls go through CalledAsT (which may be
// `const CallableT`), while move and destroy always operate on CallableT.
template <typename R, typename... P>
template <typename CallableT, typename CalledAsT, typename Enable>
typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase<
    R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = {
    &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};

// Partial specialization for trivially move-constructible and trivially
// destructible callables: only a call pointer is stored, since such objects
// can be relocated bytewise and need no destructor call.
template <typename R, typename... P>
template <typename CallableT, typename CalledAsT>
typename UniqueFunctionBase<R, P...>::TrivialCallback
    UniqueFunctionBase<R, P...>::CallbacksHolder<
        CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{
        &CallImpl<CalledAsT>};
343
344} // namespace detail
345
346template <typename R, typename... P>
347class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> {
348 using Base = detail::UniqueFunctionBase<R, P...>;
349
350public:
351 unique_function() = default;
352 unique_function(std::nullptr_t) {}
353 unique_function(unique_function &&) = default;
354 unique_function(const unique_function &) = delete;
355 unique_function &operator=(unique_function &&) = default;
356 unique_function &operator=(const unique_function &) = delete;
357
358 template <typename CallableT>
359 unique_function(
360 CallableT Callable,
361 detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
362 detail::EnableIfCallable<CallableT, R, P...> * = nullptr)
363 : Base(std::forward<CallableT>(Callable),
364 typename Base::template CalledAs<CallableT>{}) {}
365
366 R operator()(P... Params) {
367 return this->getCallPtr()(this->getCalleePtr(), Params...);
368 }
369};
370
371template <typename R, typename... P>
372class unique_function<R(P...) const>
373 : public detail::UniqueFunctionBase<R, P...> {
374 using Base = detail::UniqueFunctionBase<R, P...>;
375
376public:
377 unique_function() = default;
378 unique_function(std::nullptr_t) {}
379 unique_function(unique_function &&) = default;
380 unique_function(const unique_function &) = delete;
381 unique_function &operator=(unique_function &&) = default;
382 unique_function &operator=(const unique_function &) = delete;
383
384 template <typename CallableT>
385 unique_function(
386 CallableT Callable,
387 detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
388 detail::EnableIfCallable<const CallableT, R, P...> * = nullptr)
389 : Base(std::forward<CallableT>(Callable),
390 typename Base::template CalledAs<const CallableT>{}) {}
391
392 R operator()(P... Params) const {
393 return this->getCallPtr()(this->getCalleePtr(), Params...);
394 }
395};
396
397} // end namespace llvm
398
399#endif // LLVM_ADT_FUNCTIONEXTRAS_H
400