1 | //===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// |
9 | /// \file |
10 | /// This file implements the TypeSwitch template, which mimics a switch() |
11 | /// statement whose cases are type names. |
12 | /// |
13 | //===-----------------------------------------------------------------------===/ |
14 | |
15 | #ifndef LLVM_ADT_TYPESWITCH_H |
16 | #define LLVM_ADT_TYPESWITCH_H |
17 | |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/Support/Casting.h" |
20 | #include <optional> |
21 | |
22 | namespace llvm { |
23 | namespace detail { |
24 | |
25 | template <typename DerivedT, typename T> class TypeSwitchBase { |
26 | public: |
27 | TypeSwitchBase(const T &value) : value(value) {} |
28 | TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {} |
29 | ~TypeSwitchBase() = default; |
30 | |
31 | /// TypeSwitchBase is not copyable. |
32 | TypeSwitchBase(const TypeSwitchBase &) = delete; |
33 | void operator=(const TypeSwitchBase &) = delete; |
34 | void operator=(TypeSwitchBase &&other) = delete; |
35 | |
36 | /// Invoke a case on the derived class with multiple case types. |
37 | template <typename CaseT, typename CaseT2, typename... CaseTs, |
38 | typename CallableT> |
39 | // This is marked always_inline and nodebug so it doesn't show up in stack |
40 | // traces at -O0 (or other optimization levels). Large TypeSwitch's are |
41 | // common, are equivalent to a switch, and don't add any value to stack |
42 | // traces. |
43 | LLVM_ATTRIBUTE_ALWAYS_INLINE LLVM_ATTRIBUTE_NODEBUG DerivedT & |
44 | Case(CallableT &&caseFn) { |
45 | DerivedT &derived = static_cast<DerivedT &>(*this); |
46 | return derived.template Case<CaseT>(caseFn) |
47 | .template Case<CaseT2, CaseTs...>(caseFn); |
48 | } |
49 | |
50 | /// Invoke a case on the derived class, inferring the type of the Case from |
51 | /// the first input of the given callable. |
52 | /// Note: This inference rules for this overload are very simple: strip |
53 | /// pointers and references. |
54 | template <typename CallableT> DerivedT &Case(CallableT &&caseFn) { |
55 | using Traits = function_traits<std::decay_t<CallableT>>; |
56 | using CaseT = std::remove_cv_t<std::remove_pointer_t< |
57 | std::remove_reference_t<typename Traits::template arg_t<0>>>>; |
58 | |
59 | DerivedT &derived = static_cast<DerivedT &>(*this); |
60 | return derived.template Case<CaseT>(std::forward<CallableT>(caseFn)); |
61 | } |
62 | |
63 | protected: |
64 | /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type |
65 | /// `CastT`. |
66 | template <typename ValueT, typename CastT> |
67 | using has_dyn_cast_t = |
68 | decltype(std::declval<ValueT &>().template dyn_cast<CastT>()); |
69 | |
70 | /// Attempt to dyn_cast the given `value` to `CastT`. This overload is |
71 | /// selected if `value` already has a suitable dyn_cast method. |
72 | template <typename CastT, typename ValueT> |
73 | static decltype(auto) castValue( |
74 | ValueT &&value, |
75 | std::enable_if_t<is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = |
76 | nullptr) { |
77 | return value.template dyn_cast<CastT>(); |
78 | } |
79 | |
80 | /// Attempt to dyn_cast the given `value` to `CastT`. This overload is |
81 | /// selected if llvm::dyn_cast should be used. |
82 | template <typename CastT, typename ValueT> |
83 | static decltype(auto) castValue( |
84 | ValueT &&value, |
85 | std::enable_if_t<!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = |
86 | nullptr) { |
87 | return dyn_cast<CastT>(value); |
88 | } |
89 | |
90 | /// The root value we are switching on. |
91 | const T value; |
92 | }; |
93 | } // end namespace detail |
94 | |
95 | /// This class implements a switch-like dispatch statement for a value of 'T' |
96 | /// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked |
97 | /// if the root value isa<T>, the callable is invoked with the result of |
98 | /// dyn_cast<T>() as a parameter. |
99 | /// |
100 | /// Example: |
101 | /// Operation *op = ...; |
102 | /// LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op) |
103 | /// .Case<ConstantOp>([](ConstantOp op) { ... }) |
104 | /// .Default([](Operation *op) { ... }); |
105 | /// |
106 | template <typename T, typename ResultT = void> |
107 | class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> { |
108 | public: |
109 | using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>; |
110 | using BaseT::BaseT; |
111 | using BaseT::Case; |
112 | TypeSwitch(TypeSwitch &&other) = default; |
113 | |
114 | /// Add a case on the given type. |
115 | template <typename CaseT, typename CallableT> |
116 | TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) { |
117 | if (result) |
118 | return *this; |
119 | |
120 | // Check to see if CaseT applies to 'value'. |
121 | if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) |
122 | result.emplace(caseFn(caseValue)); |
123 | return *this; |
124 | } |
125 | |
126 | /// As a default, invoke the given callable within the root value. |
127 | template <typename CallableT> |
128 | [[nodiscard]] ResultT Default(CallableT &&defaultFn) { |
129 | if (result) |
130 | return std::move(*result); |
131 | return defaultFn(this->value); |
132 | } |
133 | /// As a default, return the given value. |
134 | [[nodiscard]] ResultT Default(ResultT defaultResult) { |
135 | if (result) |
136 | return std::move(*result); |
137 | return defaultResult; |
138 | } |
139 | |
140 | [[nodiscard]] operator ResultT() { |
141 | assert(result && "Fell off the end of a type-switch" ); |
142 | return std::move(*result); |
143 | } |
144 | |
145 | private: |
146 | /// The pointer to the result of this switch statement, once known, |
147 | /// null before that. |
148 | std::optional<ResultT> result; |
149 | }; |
150 | |
151 | /// Specialization of TypeSwitch for void returning callables. |
152 | template <typename T> |
153 | class TypeSwitch<T, void> |
154 | : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> { |
155 | public: |
156 | using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>; |
157 | using BaseT::BaseT; |
158 | using BaseT::Case; |
159 | TypeSwitch(TypeSwitch &&other) = default; |
160 | |
161 | /// Add a case on the given type. |
162 | template <typename CaseT, typename CallableT> |
163 | TypeSwitch<T, void> &Case(CallableT &&caseFn) { |
164 | if (foundMatch) |
165 | return *this; |
166 | |
167 | // Check to see if any of the types apply to 'value'. |
168 | if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) { |
169 | caseFn(caseValue); |
170 | foundMatch = true; |
171 | } |
172 | return *this; |
173 | } |
174 | |
175 | /// As a default, invoke the given callable within the root value. |
176 | template <typename CallableT> void Default(CallableT &&defaultFn) { |
177 | if (!foundMatch) |
178 | defaultFn(this->value); |
179 | } |
180 | |
181 | private: |
182 | /// A flag detailing if we have already found a match. |
183 | bool foundMatch = false; |
184 | }; |
185 | } // end namespace llvm |
186 | |
187 | #endif // LLVM_ADT_TYPESWITCH_H |
188 | |