1 | //===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_ATTRIBUTES_H |
10 | #define MLIR_IR_ATTRIBUTES_H |
11 | |
12 | #include "mlir/IR/AttributeSupport.h" |
13 | #include "llvm/Support/PointerLikeTypeTraits.h" |
14 | |
15 | namespace mlir { |
16 | class AsmState; |
17 | class StringAttr; |
18 | |
19 | /// Attributes are known-constant values of operations. |
20 | /// |
21 | /// Instances of the Attribute class are references to immortal key-value pairs |
22 | /// with immutable, uniqued keys owned by MLIRContext. As such, an Attribute is |
23 | /// a thin wrapper around an underlying storage pointer. Attributes are usually |
24 | /// passed by value. |
25 | class Attribute { |
26 | public: |
27 | /// Utility class for implementing attributes. |
28 | template <typename ConcreteType, typename BaseType, typename StorageType, |
29 | template <typename T> class... Traits> |
30 | using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType, |
31 | detail::AttributeUniquer, Traits...>; |
32 | |
33 | using ImplType = AttributeStorage; |
34 | using ValueType = void; |
35 | using AbstractTy = AbstractAttribute; |
36 | |
37 | constexpr Attribute() = default; |
38 | /* implicit */ Attribute(const ImplType *impl) |
39 | : impl(const_cast<ImplType *>(impl)) {} |
40 | |
41 | Attribute(const Attribute &other) = default; |
42 | Attribute &operator=(const Attribute &other) = default; |
43 | |
44 | bool operator==(Attribute other) const { return impl == other.impl; } |
45 | bool operator!=(Attribute other) const { return !(*this == other); } |
46 | explicit operator bool() const { return impl; } |
47 | |
48 | bool operator!() const { return impl == nullptr; } |
49 | |
50 | /// Casting utility functions. These are deprecated and will be removed, |
51 | /// please prefer using the `llvm` namespace variants instead. |
52 | template <typename... Tys> |
53 | bool isa() const; |
54 | template <typename... Tys> |
55 | bool isa_and_nonnull() const; |
56 | template <typename U> |
57 | U dyn_cast() const; |
58 | template <typename U> |
59 | U dyn_cast_or_null() const; |
60 | template <typename U> |
61 | U cast() const; |
62 | |
63 | /// Return a unique identifier for the concrete attribute type. This is used |
64 | /// to support dynamic type casting. |
65 | TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); } |
66 | |
67 | /// Return the context this attribute belongs to. |
68 | MLIRContext *getContext() const; |
69 | |
70 | /// Get the dialect this attribute is registered to. |
71 | Dialect &getDialect() const { |
72 | return impl->getAbstractAttribute().getDialect(); |
73 | } |
74 | |
75 | /// Print the attribute. If `elideType` is set, the attribute is printed |
76 | /// without a trailing colon type if it has one. |
77 | void print(raw_ostream &os, bool elideType = false) const; |
78 | void print(raw_ostream &os, AsmState &state, bool elideType = false) const; |
79 | void dump() const; |
80 | |
81 | /// Print the attribute without dialect wrapping. |
82 | void printStripped(raw_ostream &os) const; |
83 | void printStripped(raw_ostream &os, AsmState &state) const; |
84 | |
85 | /// Get an opaque pointer to the attribute. |
86 | const void *getAsOpaquePointer() const { return impl; } |
87 | /// Construct an attribute from the opaque pointer representation. |
88 | static Attribute getFromOpaquePointer(const void *ptr) { |
89 | return Attribute(reinterpret_cast<const ImplType *>(ptr)); |
90 | } |
91 | |
92 | friend ::llvm::hash_code hash_value(Attribute arg); |
93 | |
94 | /// Returns true if `InterfaceT` has been promised by the dialect or |
95 | /// implemented. |
96 | template <typename InterfaceT> |
97 | bool hasPromiseOrImplementsInterface() { |
98 | return dialect_extension_detail::hasPromisedInterface( |
99 | getDialect(), getTypeID(), InterfaceT::getInterfaceID()) || |
100 | mlir::isa<InterfaceT>(*this); |
101 | } |
102 | |
103 | /// Returns true if the type was registered with a particular trait. |
104 | template <template <typename T> class Trait> |
105 | bool hasTrait() { |
106 | return getAbstractAttribute().hasTrait<Trait>(); |
107 | } |
108 | |
109 | /// Return the abstract descriptor for this attribute. |
110 | const AbstractTy &getAbstractAttribute() const { |
111 | return impl->getAbstractAttribute(); |
112 | } |
113 | |
114 | /// Walk all of the immediately nested sub-attributes and sub-types. This |
115 | /// method does not recurse into sub elements. |
116 | void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn, |
117 | function_ref<void(Type)> walkTypesFn) const { |
118 | getAbstractAttribute().walkImmediateSubElements(attr: *this, walkAttrsFn, |
119 | walkTypesFn); |
120 | } |
121 | |
122 | /// Replace the immediately nested sub-attributes and sub-types with those |
123 | /// provided. The order of the provided elements is derived from the order of |
124 | /// the elements returned by the callbacks of `walkImmediateSubElements`. The |
125 | /// element at index 0 would replace the very first attribute given by |
126 | /// `walkImmediateSubElements`. On success, the new instance with the values |
127 | /// replaced is returned. If replacement fails, nullptr is returned. |
128 | auto replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, |
129 | ArrayRef<Type> replTypes) const { |
130 | return getAbstractAttribute().replaceImmediateSubElements(attr: *this, replAttrs, |
131 | replTypes); |
132 | } |
133 | |
134 | /// Walk this attribute and all attibutes/types nested within using the |
135 | /// provided walk functions. See `AttrTypeWalker` for information on the |
136 | /// supported walk function types. |
137 | template <WalkOrder Order = WalkOrder::PostOrder, typename... WalkFns> |
138 | auto walk(WalkFns &&...walkFns) { |
139 | AttrTypeWalker walker; |
140 | (walker.addWalk(std::forward<WalkFns>(walkFns)), ...); |
141 | return walker.walk<Order>(*this); |
142 | } |
143 | |
144 | /// Recursively replace all of the nested sub-attributes and sub-types using |
145 | /// the provided map functions. Returns nullptr in the case of failure. See |
146 | /// `AttrTypeReplacer` for information on the support replacement function |
147 | /// types. |
148 | template <typename... ReplacementFns> |
149 | auto replace(ReplacementFns &&...replacementFns) { |
150 | AttrTypeReplacer replacer; |
151 | (replacer.addReplacement(std::forward<ReplacementFns>(replacementFns)), |
152 | ...); |
153 | return replacer.replace(attr: *this); |
154 | } |
155 | |
156 | /// Return the internal Attribute implementation. |
157 | ImplType *getImpl() const { return impl; } |
158 | |
159 | protected: |
160 | ImplType *impl{nullptr}; |
161 | }; |
162 | |
163 | inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) { |
164 | attr.print(os); |
165 | return os; |
166 | } |
167 | |
168 | template <typename... Tys> |
169 | bool Attribute::isa() const { |
170 | return llvm::isa<Tys...>(*this); |
171 | } |
172 | |
173 | template <typename... Tys> |
174 | bool Attribute::isa_and_nonnull() const { |
175 | return llvm::isa_and_present<Tys...>(*this); |
176 | } |
177 | |
178 | template <typename U> |
179 | U Attribute::dyn_cast() const { |
180 | return llvm::dyn_cast<U>(*this); |
181 | } |
182 | |
183 | template <typename U> |
184 | U Attribute::dyn_cast_or_null() const { |
185 | return llvm::dyn_cast_if_present<U>(*this); |
186 | } |
187 | |
188 | template <typename U> |
189 | U Attribute::cast() const { |
190 | return llvm::cast<U>(*this); |
191 | } |
192 | |
193 | inline ::llvm::hash_code hash_value(Attribute arg) { |
194 | return DenseMapInfo<const Attribute::ImplType *>::getHashValue(PtrVal: arg.impl); |
195 | } |
196 | |
197 | //===----------------------------------------------------------------------===// |
198 | // NamedAttribute |
199 | //===----------------------------------------------------------------------===// |
200 | |
201 | /// NamedAttribute represents a combination of a name and an Attribute value. |
202 | class NamedAttribute { |
203 | public: |
204 | NamedAttribute(StringAttr name, Attribute value); |
205 | |
206 | /// Return the name of the attribute. |
207 | StringAttr getName() const; |
208 | |
209 | /// Return the dialect of the name of this attribute, if the name is prefixed |
210 | /// by a dialect namespace. For example, `llvm.fast_math` would return the |
211 | /// LLVM dialect (if it is loaded). Returns nullptr if the dialect isn't |
212 | /// loaded, or if the name is not prefixed by a dialect namespace. |
213 | Dialect *getNameDialect() const; |
214 | |
215 | /// Return the value of the attribute. |
216 | Attribute getValue() const { return value; } |
217 | |
218 | /// Set the name of this attribute. |
219 | void setName(StringAttr newName); |
220 | |
221 | /// Set the value of this attribute. |
222 | void setValue(Attribute newValue) { |
223 | assert(value && "expected valid attribute value" ); |
224 | value = newValue; |
225 | } |
226 | |
227 | /// Compare this attribute to the provided attribute, ordering by name. |
228 | bool operator<(const NamedAttribute &rhs) const; |
229 | /// Compare this attribute to the provided string, ordering by name. |
230 | bool operator<(StringRef rhs) const; |
231 | |
232 | bool operator==(const NamedAttribute &rhs) const { |
233 | return name == rhs.name && value == rhs.value; |
234 | } |
235 | bool operator!=(const NamedAttribute &rhs) const { return !(*this == rhs); } |
236 | |
237 | private: |
238 | NamedAttribute(Attribute name, Attribute value) : name(name), value(value) {} |
239 | |
240 | /// Allow access to internals to enable hashing. |
241 | friend ::llvm::hash_code hash_value(const NamedAttribute &arg); |
242 | friend DenseMapInfo<NamedAttribute>; |
243 | |
244 | /// The name of the attribute. This is represented as a StringAttr, but |
245 | /// type-erased to Attribute in the field. |
246 | Attribute name; |
247 | /// The value of the attribute. |
248 | Attribute value; |
249 | }; |
250 | |
251 | inline ::llvm::hash_code hash_value(const NamedAttribute &arg) { |
252 | using AttrPairT = std::pair<Attribute, Attribute>; |
253 | return DenseMapInfo<AttrPairT>::getHashValue(PairVal: AttrPairT(arg.name, arg.value)); |
254 | } |
255 | |
256 | /// Allow walking and replacing the subelements of a NamedAttribute. |
257 | template <> |
258 | struct AttrTypeSubElementHandler<NamedAttribute> { |
259 | template <typename T> |
260 | static void walk(T param, AttrTypeImmediateSubElementWalker &walker) { |
261 | walker.walk(param.getName()); |
262 | walker.walk(param.getValue()); |
263 | } |
264 | template <typename T> |
265 | static T replace(T param, AttrSubElementReplacements &attrRepls, |
266 | TypeSubElementReplacements &typeRepls) { |
267 | ArrayRef<Attribute> paramRepls = attrRepls.take_front(n: 2); |
268 | return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]); |
269 | } |
270 | }; |
271 | |
272 | //===----------------------------------------------------------------------===// |
273 | // AttributeTraitBase |
274 | //===----------------------------------------------------------------------===// |
275 | |
276 | namespace AttributeTrait { |
277 | /// This class represents the base of an attribute trait. |
278 | template <typename ConcreteType, template <typename> class TraitType> |
279 | using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>; |
280 | } // namespace AttributeTrait |
281 | |
282 | //===----------------------------------------------------------------------===// |
283 | // AttributeInterface |
284 | //===----------------------------------------------------------------------===// |
285 | |
286 | /// This class represents the base of an attribute interface. See the definition |
287 | /// of `detail::Interface` for requirements on the `Traits` type. |
288 | template <typename ConcreteType, typename Traits> |
289 | class AttributeInterface |
290 | : public detail::Interface<ConcreteType, Attribute, Traits, Attribute, |
291 | AttributeTrait::TraitBase> { |
292 | public: |
293 | using Base = AttributeInterface<ConcreteType, Traits>; |
294 | using InterfaceBase = detail::Interface<ConcreteType, Attribute, Traits, |
295 | Attribute, AttributeTrait::TraitBase>; |
296 | using InterfaceBase::InterfaceBase; |
297 | |
298 | protected: |
299 | /// Returns the impl interface instance for the given type. |
300 | static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { |
301 | #ifndef NDEBUG |
302 | // Check that the current interface isn't an unresolved promise for the |
303 | // given attribute. |
304 | dialect_extension_detail::handleUseOfUndefinedPromisedInterface( |
305 | dialect&: attr.getDialect(), interfaceRequestorID: attr.getTypeID(), interfaceID: ConcreteType::getInterfaceID(), |
306 | interfaceName: llvm::getTypeName<ConcreteType>()); |
307 | #endif |
308 | |
309 | return attr.getAbstractAttribute().getInterface<ConcreteType>(); |
310 | } |
311 | |
312 | /// Allow access to 'getInterfaceFor'. |
313 | friend InterfaceBase; |
314 | }; |
315 | |
316 | //===----------------------------------------------------------------------===// |
317 | // Core AttributeTrait |
318 | //===----------------------------------------------------------------------===// |
319 | |
320 | /// This trait is used to determine if an attribute is mutable or not. It is |
321 | /// attached on an attribute if the corresponding ImplType defines a `mutate` |
322 | /// function with proper signature. |
323 | namespace AttributeTrait { |
324 | template <typename ConcreteType> |
325 | using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>; |
326 | } // namespace AttributeTrait |
327 | |
328 | } // namespace mlir. |
329 | |
330 | namespace llvm { |
331 | |
332 | // Attribute hash just like pointers. |
333 | template <> |
334 | struct DenseMapInfo<mlir::Attribute> { |
335 | static mlir::Attribute getEmptyKey() { |
336 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
337 | return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)); |
338 | } |
339 | static mlir::Attribute getTombstoneKey() { |
340 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
341 | return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)); |
342 | } |
343 | static unsigned getHashValue(mlir::Attribute val) { |
344 | return mlir::hash_value(arg: val); |
345 | } |
346 | static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) { |
347 | return LHS == RHS; |
348 | } |
349 | }; |
350 | template <typename T> |
351 | struct DenseMapInfo< |
352 | T, std::enable_if_t<std::is_base_of<mlir::Attribute, T>::value && |
353 | !mlir::detail::IsInterface<T>::value>> |
354 | : public DenseMapInfo<mlir::Attribute> { |
355 | static T getEmptyKey() { |
356 | const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey(); |
357 | return T::getFromOpaquePointer(pointer); |
358 | } |
359 | static T getTombstoneKey() { |
360 | const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey(); |
361 | return T::getFromOpaquePointer(pointer); |
362 | } |
363 | }; |
364 | |
365 | /// Allow LLVM to steal the low bits of Attributes. |
366 | template <> |
367 | struct PointerLikeTypeTraits<mlir::Attribute> { |
368 | static inline void *getAsVoidPointer(mlir::Attribute attr) { |
369 | return const_cast<void *>(attr.getAsOpaquePointer()); |
370 | } |
371 | static inline mlir::Attribute getFromVoidPointer(void *ptr) { |
372 | return mlir::Attribute::getFromOpaquePointer(ptr); |
373 | } |
374 | static constexpr int NumLowBitsAvailable = llvm::PointerLikeTypeTraits< |
375 | mlir::AttributeStorage *>::NumLowBitsAvailable; |
376 | }; |
377 | |
378 | template <> |
379 | struct DenseMapInfo<mlir::NamedAttribute> { |
380 | static mlir::NamedAttribute getEmptyKey() { |
381 | auto emptyAttr = llvm::DenseMapInfo<mlir::Attribute>::getEmptyKey(); |
382 | return mlir::NamedAttribute(emptyAttr, emptyAttr); |
383 | } |
384 | static mlir::NamedAttribute getTombstoneKey() { |
385 | auto tombAttr = llvm::DenseMapInfo<mlir::Attribute>::getTombstoneKey(); |
386 | return mlir::NamedAttribute(tombAttr, tombAttr); |
387 | } |
388 | static unsigned getHashValue(mlir::NamedAttribute val) { |
389 | return mlir::hash_value(arg: val); |
390 | } |
391 | static bool isEqual(mlir::NamedAttribute lhs, mlir::NamedAttribute rhs) { |
392 | return lhs == rhs; |
393 | } |
394 | }; |
395 | |
396 | /// Add support for llvm style casts. We provide a cast between To and From if |
397 | /// From is mlir::Attribute or derives from it. |
398 | template <typename To, typename From> |
399 | struct CastInfo<To, From, |
400 | std::enable_if_t<std::is_same_v<mlir::Attribute, |
401 | std::remove_const_t<From>> || |
402 | std::is_base_of_v<mlir::Attribute, From>>> |
403 | : NullableValueCastFailed<To>, |
404 | DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { |
405 | /// Arguments are taken as mlir::Attribute here and not as `From`, because |
406 | /// when casting from an intermediate type of the hierarchy to one of its |
407 | /// children, the val.getTypeID() inside T::classof will use the static |
408 | /// getTypeID of the parent instead of the non-static Type::getTypeID that |
409 | /// returns the dynamic ID. This means that T::classof would end up comparing |
410 | /// the static TypeID of the children to the static TypeID of its parent, |
411 | /// making it impossible to downcast from the parent to the child. |
412 | static inline bool isPossible(mlir::Attribute ty) { |
413 | /// Return a constant true instead of a dynamic true when casting to self or |
414 | /// up the hierarchy. |
415 | if constexpr (std::is_base_of_v<To, From>) { |
416 | return true; |
417 | } else { |
418 | return To::classof(ty); |
419 | } |
420 | } |
421 | static inline To doCast(mlir::Attribute attr) { return To(attr.getImpl()); } |
422 | }; |
423 | |
424 | } // namespace llvm |
425 | |
426 | #endif |
427 | |