1//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TypeSwitch template, which mimics a switch()
10// statement whose cases are type names.
11//
12//===-----------------------------------------------------------------------===/
13
14#ifndef LLVM_ADT_TYPESWITCH_H
15#define LLVM_ADT_TYPESWITCH_H
16
17#include "llvm/ADT/Optional.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/Support/Casting.h"
20
21namespace llvm {
22namespace detail {
23
24template <typename DerivedT, typename T> class TypeSwitchBase {
25public:
26 TypeSwitchBase(const T &value) : value(value) {}
27 TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
28 ~TypeSwitchBase() = default;
29
30 /// TypeSwitchBase is not copyable.
31 TypeSwitchBase(const TypeSwitchBase &) = delete;
32 void operator=(const TypeSwitchBase &) = delete;
33 void operator=(TypeSwitchBase &&other) = delete;
34
35 /// Invoke a case on the derived class with multiple case types.
36 template <typename CaseT, typename CaseT2, typename... CaseTs,
37 typename CallableT>
38 DerivedT &Case(CallableT &&caseFn) {
39 DerivedT &derived = static_cast<DerivedT &>(*this);
40 return derived.template Case<CaseT>(caseFn)
41 .template Case<CaseT2, CaseTs...>(caseFn);
42 }
43
44 /// Invoke a case on the derived class, inferring the type of the Case from
45 /// the first input of the given callable.
46 /// Note: This inference rules for this overload are very simple: strip
47 /// pointers and references.
48 template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
49 using Traits = function_traits<std::decay_t<CallableT>>;
50 using CaseT = std::remove_cv_t<std::remove_pointer_t<
51 std::remove_reference_t<typename Traits::template arg_t<0>>>>;
52
53 DerivedT &derived = static_cast<DerivedT &>(*this);
54 return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
55 }
56
57protected:
58 /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
59 /// `CastT`.
60 template <typename ValueT, typename CastT>
61 using has_dyn_cast_t =
62 decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
63
64 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
65 /// selected if `value` already has a suitable dyn_cast method.
66 template <typename CastT, typename ValueT>
67 static auto castValue(
68 ValueT value,
69 typename std::enable_if_t<
70 is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
71 return value.template dyn_cast<CastT>();
72 }
73
74 /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
75 /// selected if llvm::dyn_cast should be used.
76 template <typename CastT, typename ValueT>
77 static auto castValue(
78 ValueT value,
79 typename std::enable_if_t<
80 !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
81 return dyn_cast<CastT>(value);
82 }
83
84 /// The root value we are switching on.
85 const T value;
86};
87} // end namespace detail
88
89/// This class implements a switch-like dispatch statement for a value of 'T'
90/// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
91/// if the root value isa<T>, the callable is invoked with the result of
92/// dyn_cast<T>() as a parameter.
93///
94/// Example:
95/// Operation *op = ...;
96/// LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
97/// .Case<ConstantOp>([](ConstantOp op) { ... })
98/// .Default([](Operation *op) { ... });
99///
100template <typename T, typename ResultT = void>
101class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
102public:
103 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
104 using BaseT::BaseT;
105 using BaseT::Case;
106 TypeSwitch(TypeSwitch &&other) = default;
107
108 /// Add a case on the given type.
109 template <typename CaseT, typename CallableT>
110 TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
111 if (result)
112 return *this;
113
114 // Check to see if CaseT applies to 'value'.
115 if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
116 result = caseFn(caseValue);
117 return *this;
118 }
119
120 /// As a default, invoke the given callable within the root value.
121 template <typename CallableT>
122 LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
123 if (result)
124 return std::move(*result);
125 return defaultFn(this->value);
126 }
127 /// As a default, return the given value.
128 LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
129 if (result)
130 return std::move(*result);
131 return defaultResult;
132 }
133
134 LLVM_NODISCARD
135 operator ResultT() {
136 assert(result && "Fell off the end of a type-switch");
137 return std::move(*result);
138 }
139
140private:
141 /// The pointer to the result of this switch statement, once known,
142 /// null before that.
143 Optional<ResultT> result;
144};
145
146/// Specialization of TypeSwitch for void returning callables.
147template <typename T>
148class TypeSwitch<T, void>
149 : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
150public:
151 using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
152 using BaseT::BaseT;
153 using BaseT::Case;
154 TypeSwitch(TypeSwitch &&other) = default;
155
156 /// Add a case on the given type.
157 template <typename CaseT, typename CallableT>
158 TypeSwitch<T, void> &Case(CallableT &&caseFn) {
159 if (foundMatch)
160 return *this;
161
162 // Check to see if any of the types apply to 'value'.
163 if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
164 caseFn(caseValue);
165 foundMatch = true;
166 }
167 return *this;
168 }
169
170 /// As a default, invoke the given callable within the root value.
171 template <typename CallableT> void Default(CallableT &&defaultFn) {
172 if (!foundMatch)
173 defaultFn(this->value);
174 }
175
176private:
177 /// A flag detailing if we have already found a match.
178 bool foundMatch = false;
179};
180} // end namespace llvm
181
182#endif // LLVM_ADT_TYPESWITCH_H
183