1//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the 'dialect' abstraction.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_DIALECT_H
14#define MLIR_IR_DIALECT_H
15
16#include "mlir/IR/DialectRegistry.h"
17#include "mlir/IR/OperationSupport.h"
18#include "mlir/Support/TypeID.h"
19
20#include <map>
21#include <tuple>
22
23namespace mlir {
24class DialectAsmParser;
25class DialectAsmPrinter;
26class DialectInterface;
27class OpBuilder;
28class Type;
29
30//===----------------------------------------------------------------------===//
31// Dialect
32//===----------------------------------------------------------------------===//
33
34/// Dialects are groups of MLIR operations, types and attributes, as well as
35/// behavior associated with the entire group. For example, hooks into other
36/// systems for constant folding, interfaces, default named types for asm
37/// printing, etc.
38///
39/// Instances of the dialect object are loaded in a specific MLIRContext.
40///
41class Dialect {
42public:
43 /// Type for a callback provided by the dialect to parse a custom operation.
44 /// This is used for the dialect to provide an alternative way to parse custom
45 /// operations, including unregistered ones.
46 using ParseOpHook =
47 function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
48
49 virtual ~Dialect();
50
51 /// Utility function that returns if the given string is a valid dialect
52 /// namespace
53 static bool isValidNamespace(StringRef str);
54
55 MLIRContext *getContext() const { return context; }
56
57 StringRef getNamespace() const { return name; }
58
59 /// Returns the unique identifier that corresponds to this dialect.
60 TypeID getTypeID() const { return dialectID; }
61
62 /// Returns true if this dialect allows for unregistered operations, i.e.
63 /// operations prefixed with the dialect namespace but not registered with
64 /// addOperation.
65 bool allowsUnknownOperations() const { return unknownOpsAllowed; }
66
67 /// Return true if this dialect allows for unregistered types, i.e., types
68 /// prefixed with the dialect namespace but not registered with addType.
69 /// These are represented with OpaqueType.
70 bool allowsUnknownTypes() const { return unknownTypesAllowed; }
71
72 /// Register dialect-wide canonicalization patterns. This method should only
73 /// be used to register canonicalization patterns that do not conceptually
74 /// belong to any single operation in the dialect. (In that case, use the op's
75 /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
76 /// be registered here.
77 virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
78
79 /// Registered hook to materialize a single constant operation from a given
80 /// attribute value with the desired resultant type. This method should use
81 /// the provided builder to create the operation without changing the
82 /// insertion position. The generated operation is expected to be constant
83 /// like, i.e. single result, zero operands, non side-effecting, etc. On
84 /// success, this hook should return the value generated to represent the
85 /// constant value. Otherwise, it should return null on failure.
86 virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
87 Type type, Location loc) {
88 return nullptr;
89 }
90
91 //===--------------------------------------------------------------------===//
92 // Parsing Hooks
93 //===--------------------------------------------------------------------===//
94
95 /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
96 /// refers to the expected type of the attribute.
97 virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
98
99 /// Print an attribute registered to this dialect. Note: The type of the
100 /// attribute need not be printed by this method as it is always printed by
101 /// the caller.
102 virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
103 llvm_unreachable("dialect has no registered attribute printing hook");
104 }
105
106 /// Parse a type registered to this dialect.
107 virtual Type parseType(DialectAsmParser &parser) const;
108
109 /// Print a type registered to this dialect.
110 virtual void printType(Type, DialectAsmPrinter &) const {
111 llvm_unreachable("dialect has no registered type printing hook");
112 }
113
114 /// Return the hook to parse an operation registered to this dialect, if any.
115 /// By default this will lookup for registered operations and return the
116 /// `parse()` method registered on the RegisteredOperationName. Dialects can
117 /// override this behavior and handle unregistered operations as well.
118 virtual std::optional<ParseOpHook>
119 getParseOperationHook(StringRef opName) const;
120
121 /// Print an operation registered to this dialect.
122 /// This hook is invoked for registered operation which don't override the
123 /// `print()` method to define their own custom assembly.
124 virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
125 getOperationPrinter(Operation *op) const;
126
127 //===--------------------------------------------------------------------===//
128 // Verification Hooks
129 //===--------------------------------------------------------------------===//
130
131 /// Verify an attribute from this dialect on the argument at 'argIndex' for
132 /// the region at 'regionIndex' on the given operation. Returns failure if
133 /// the verification failed, success otherwise. This hook may optionally be
134 /// invoked from any operation containing a region.
135 virtual LogicalResult verifyRegionArgAttribute(Operation *,
136 unsigned regionIndex,
137 unsigned argIndex,
138 NamedAttribute);
139
140 /// Verify an attribute from this dialect on the result at 'resultIndex' for
141 /// the region at 'regionIndex' on the given operation. Returns failure if
142 /// the verification failed, success otherwise. This hook may optionally be
143 /// invoked from any operation containing a region.
144 virtual LogicalResult verifyRegionResultAttribute(Operation *,
145 unsigned regionIndex,
146 unsigned resultIndex,
147 NamedAttribute);
148
149 /// Verify an attribute from this dialect on the given operation. Returns
150 /// failure if the verification failed, success otherwise.
151 virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
152 return success();
153 }
154
155 //===--------------------------------------------------------------------===//
156 // Interfaces
157 //===--------------------------------------------------------------------===//
158
159 /// Lookup an interface for the given ID if one is registered, otherwise
160 /// nullptr.
161 DialectInterface *getRegisteredInterface(TypeID interfaceID) {
162#ifndef NDEBUG
163 handleUseOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(), interfaceID);
164#endif
165
166 auto it = registeredInterfaces.find(Val: interfaceID);
167 return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
168 }
169 template <typename InterfaceT>
170 InterfaceT *getRegisteredInterface() {
171#ifndef NDEBUG
172 handleUseOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(),
173 interfaceID: InterfaceT::getInterfaceID(),
174 interfaceName: llvm::getTypeName<InterfaceT>());
175#endif
176
177 return static_cast<InterfaceT *>(
178 getRegisteredInterface(InterfaceT::getInterfaceID()));
179 }
180
181 /// Lookup an op interface for the given ID if one is registered, otherwise
182 /// nullptr.
183 virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
184 OperationName opName) {
185 return nullptr;
186 }
187 template <typename InterfaceT>
188 typename InterfaceT::Concept *
189 getRegisteredInterfaceForOp(OperationName opName) {
190 return static_cast<typename InterfaceT::Concept *>(
191 getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
192 }
193
194 /// Register a dialect interface with this dialect instance.
195 void addInterface(std::unique_ptr<DialectInterface> interface);
196
197 /// Register a set of dialect interfaces with this dialect instance.
198 template <typename... Args>
199 void addInterfaces() {
200 (addInterface(std::make_unique<Args>(this)), ...);
201 }
202 template <typename InterfaceT, typename... Args>
203 InterfaceT &addInterface(Args &&...args) {
204 InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
205 addInterface(interface: std::unique_ptr<DialectInterface>(interface));
206 return *interface;
207 }
208
209 /// Declare that the given interface will be implemented, but has a delayed
210 /// registration. The promised interface type can be an interface of any type
211 /// not just a dialect interface, i.e. it may also be an
212 /// AttributeInterface/OpInterface/TypeInterface/etc.
213 template <typename InterfaceT, typename ConcreteT>
214 void declarePromisedInterface() {
215 unresolvedPromisedInterfaces.insert(
216 {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
217 }
218
219 // Declare the same interface for multiple types.
220 // Example:
221 // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
222 template <typename InterfaceT, typename... ConcreteT>
223 void declarePromisedInterfaces() {
224 (declarePromisedInterface<InterfaceT, ConcreteT>(), ...);
225 }
226
227 /// Checks if the given interface, which is attempting to be used, is a
228 /// promised interface of this dialect that has yet to be implemented. If so,
229 /// emits a fatal error. `interfaceName` is an optional string that contains a
230 /// more user readable name for the interface (such as the class name).
231 void handleUseOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
232 TypeID interfaceID,
233 StringRef interfaceName = "") {
234 if (unresolvedPromisedInterfaces.count(
235 V: {interfaceRequestorID, interfaceID})) {
236 llvm::report_fatal_error(
237 reason: "checking for an interface (`" + interfaceName +
238 "`) that was promised by dialect '" + getNamespace() +
239 "' but never implemented. This is generally an indication "
240 "that the dialect extension implementing the interface was never "
241 "registered.");
242 }
243 }
244
245 /// Checks if the given interface, which is attempting to be attached to a
246 /// construct owned by this dialect, is a promised interface of this dialect
247 /// that has yet to be implemented. If so, it resolves the interface promise.
248 void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceRequestorID,
249 TypeID interfaceID) {
250 unresolvedPromisedInterfaces.erase(V: {interfaceRequestorID, interfaceID});
251 }
252
253 /// Checks if a promise has been made for the interface/requestor pair.
254 bool hasPromisedInterface(TypeID interfaceRequestorID,
255 TypeID interfaceID) const {
256 return unresolvedPromisedInterfaces.count(
257 V: {interfaceRequestorID, interfaceID});
258 }
259
260 /// Checks if a promise has been made for the interface/requestor pair.
261 template <typename ConcreteT, typename InterfaceT>
262 bool hasPromisedInterface() const {
263 return hasPromisedInterface(TypeID::get<ConcreteT>(),
264 InterfaceT::getInterfaceID());
265 }
266
267protected:
268 /// The constructor takes a unique namespace for this dialect as well as the
269 /// context to bind to.
270 /// Note: The namespace must not contain '.' characters.
271 /// Note: All operations belonging to this dialect must have names starting
272 /// with the namespace followed by '.'.
273 /// Example:
274 /// - "tf" for the TensorFlow ops like "tf.add".
275 Dialect(StringRef name, MLIRContext *context, TypeID id);
276
277 /// This method is used by derived classes to add their operations to the set.
278 ///
279 template <typename... Args>
280 void addOperations() {
281 // This initializer_list argument pack expansion is essentially equal to
282 // using a fold expression with a comma operator. Clang however, refuses
283 // to compile a fold expression with a depth of more than 256 by default.
284 // There seem to be no such limitations for initializer_list.
285 (void)std::initializer_list<int>{
286 0, (RegisteredOperationName::insert<Args>(*this), 0)...};
287 }
288
289 /// Register a set of type classes with this dialect.
290 template <typename... Args>
291 void addTypes() {
292 // This initializer_list argument pack expansion is essentially equal to
293 // using a fold expression with a comma operator. Clang however, refuses
294 // to compile a fold expression with a depth of more than 256 by default.
295 // There seem to be no such limitations for initializer_list.
296 (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
297 }
298
299 /// Register a type instance with this dialect.
300 /// The use of this method is in general discouraged in favor of
301 /// 'addTypes<CustomType>()'.
302 void addType(TypeID typeID, AbstractType &&typeInfo);
303
304 /// Register a set of attribute classes with this dialect.
305 template <typename... Args>
306 void addAttributes() {
307 // This initializer_list argument pack expansion is essentially equal to
308 // using a fold expression with a comma operator. Clang however, refuses
309 // to compile a fold expression with a depth of more than 256 by default.
310 // There seem to be no such limitations for initializer_list.
311 (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
312 }
313
314 /// Register an attribute instance with this dialect.
315 /// The use of this method is in general discouraged in favor of
316 /// 'addAttributes<CustomAttr>()'.
317 void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
318
319 /// Enable support for unregistered operations.
320 void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
321
322 /// Enable support for unregistered types.
323 void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
324
325private:
326 Dialect(const Dialect &) = delete;
327 void operator=(Dialect &) = delete;
328
329 /// Register an attribute instance with this dialect.
330 template <typename T>
331 void addAttribute() {
332 // Add this attribute to the dialect and register it with the uniquer.
333 addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
334 detail::AttributeUniquer::registerAttribute<T>(context);
335 }
336
337 /// Register a type instance with this dialect.
338 template <typename T>
339 void addType() {
340 // Add this type to the dialect and register it with the uniquer.
341 addType(T::getTypeID(), AbstractType::get<T>(*this));
342 detail::TypeUniquer::registerType<T>(context);
343 }
344
345 /// The namespace of this dialect.
346 StringRef name;
347
348 /// The unique identifier of the derived Op class, this is used in the context
349 /// to allow registering multiple times the same dialect.
350 TypeID dialectID;
351
352 /// This is the context that owns this Dialect object.
353 MLIRContext *context;
354
355 /// Flag that specifies whether this dialect supports unregistered operations,
356 /// i.e. operations prefixed with the dialect namespace but not registered
357 /// with addOperation.
358 bool unknownOpsAllowed = false;
359
360 /// Flag that specifies whether this dialect allows unregistered types, i.e.
361 /// types prefixed with the dialect namespace but not registered with addType.
362 /// These types are represented with OpaqueType.
363 bool unknownTypesAllowed = false;
364
365 /// A collection of registered dialect interfaces.
366 DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
367
368 /// A set of interfaces that the dialect (or its constructs, i.e.
369 /// Attributes/Operations/Types/etc.) has promised to implement, but has yet
370 /// to provide an implementation for.
371 DenseSet<std::pair<TypeID, TypeID>> unresolvedPromisedInterfaces;
372
373 friend class DialectRegistry;
374 friend void registerDialect();
375 friend class MLIRContext;
376};
377
378} // namespace mlir
379
380namespace llvm {
381/// Provide isa functionality for Dialects.
382template <typename T>
383struct isa_impl<T, ::mlir::Dialect,
384 std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
385 static inline bool doit(const ::mlir::Dialect &dialect) {
386 return mlir::TypeID::get<T>() == dialect.getTypeID();
387 }
388};
389template <typename T>
390struct isa_impl<
391 T, ::mlir::Dialect,
392 std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
393 static inline bool doit(const ::mlir::Dialect &dialect) {
394 return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
395 }
396};
397template <typename T>
398struct cast_retty_impl<T, ::mlir::Dialect *> {
399 using ret_type = T *;
400};
401template <typename T>
402struct cast_retty_impl<T, ::mlir::Dialect> {
403 using ret_type = T &;
404};
405
406template <typename T>
407struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
408 template <typename To>
409 static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
410 doitImpl(::mlir::Dialect &dialect) {
411 return static_cast<To &>(dialect);
412 }
413 template <typename To>
414 static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
415 To &>
416 doitImpl(::mlir::Dialect &dialect) {
417 return *dialect.getRegisteredInterface<To>();
418 }
419
420 static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
421};
422template <class T>
423struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
424 static auto doit(::mlir::Dialect *dialect) {
425 return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
426 *dialect);
427 }
428};
429
430} // namespace llvm
431
432#endif
433

// Source code of mlir/include/mlir/IR/Dialect.h