1//===- Arith.h - Arith dialect ------------------------------------*- C++-*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_ARITH_IR_ARITH_H_
10#define MLIR_DIALECT_ARITH_IR_ARITH_H_
11
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/OpDefinition.h"
15#include "mlir/IR/OpImplementation.h"
16#include "mlir/Interfaces/CastInterfaces.h"
17#include "mlir/Interfaces/InferIntRangeInterface.h"
18#include "mlir/Interfaces/InferTypeOpInterface.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Interfaces/VectorInterfaces.h"
21#include "llvm/ADT/StringExtras.h"
22
23//===----------------------------------------------------------------------===//
24// ArithDialect
25//===----------------------------------------------------------------------===//
26
27#include "mlir/Dialect/Arith/IR/ArithOpsDialect.h.inc"
28
29//===----------------------------------------------------------------------===//
30// Arith Dialect Enum Attributes
31//===----------------------------------------------------------------------===//
32
33#include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc"
34#define GET_ATTRDEF_CLASSES
35#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc"
36
37//===----------------------------------------------------------------------===//
38// Arith Interfaces
39//===----------------------------------------------------------------------===//
40#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc"
41
42//===----------------------------------------------------------------------===//
43// Arith Dialect Operations
44//===----------------------------------------------------------------------===//
45
46#define GET_OP_CLASSES
47#include "mlir/Dialect/Arith/IR/ArithOps.h.inc"
48
49namespace mlir {
50namespace arith {
51
52/// Specialization of `arith.constant` op that returns an integer value.
53class ConstantIntOp : public arith::ConstantOp {
54public:
55 using arith::ConstantOp::ConstantOp;
56 static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
57
58 /// Build a constant int op that produces an integer of the specified width.
59 static void build(OpBuilder &builder, OperationState &result, int64_t value,
60 unsigned width);
61
62 /// Build a constant int op that produces an integer of the specified type,
63 /// which must be an integer type.
64 static void build(OpBuilder &builder, OperationState &result, int64_t value,
65 Type type);
66
67 inline int64_t value() {
68 return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
69 }
70
71 static bool classof(Operation *op);
72};
73
74/// Specialization of `arith.constant` op that returns a floating point value.
75class ConstantFloatOp : public arith::ConstantOp {
76public:
77 using arith::ConstantOp::ConstantOp;
78 static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
79
80 /// Build a constant float op that produces a float of the specified type.
81 static void build(OpBuilder &builder, OperationState &result,
82 const APFloat &value, FloatType type);
83
84 inline APFloat value() {
85 return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
86 }
87
88 static bool classof(Operation *op);
89};
90
91/// Specialization of `arith.constant` op that returns an integer of index type.
92class ConstantIndexOp : public arith::ConstantOp {
93public:
94 using arith::ConstantOp::ConstantOp;
95 static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
96 /// Build a constant int op that produces an index.
97 static void build(OpBuilder &builder, OperationState &result, int64_t value);
98
99 inline int64_t value() {
100 return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
101 }
102
103 static bool classof(Operation *op);
104};
105
106} // namespace arith
107} // namespace mlir
108
109//===----------------------------------------------------------------------===//
110// Utility Functions
111//===----------------------------------------------------------------------===//
112
113namespace mlir {
114namespace arith {
115
116/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
117/// comparison predicates.
118bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
119 const APInt &rhs);
120
121/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
122/// comparison predicates.
123bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
124 const APFloat &rhs);
125
126/// Returns the identity value attribute associated with an AtomicRMWKind op.
127/// `useOnlyFiniteValue` defines whether the identity value should steer away
128/// from infinity representations or anything that is not a proper finite
129/// number.
130/// E.g., The identity value for maxf is in theory `-Inf`, but if we want to
131/// stay in the finite range, it would be `BiggestRepresentableNegativeFloat`.
132/// The purpose of this boolean is to offer constants that will play nice
133/// with fast math related optimizations.
134TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
135 OpBuilder &builder, Location loc,
136 bool useOnlyFiniteValue = false);
137
138/// Return the identity numeric value associated to the give op. Return
139/// std::nullopt if there is no known neutral element.
140/// If `op` has `FastMathFlags::ninf`, only finite values will be used
141/// as neutral element.
142std::optional<TypedAttr> getNeutralElement(Operation *op);
143
144/// Returns the identity value associated with an AtomicRMWKind op.
145/// \see getIdentityValueAttr for a description of what `useOnlyFiniteValue`
146/// does.
147Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
148 Location loc, bool useOnlyFiniteValue = false);
149
150/// Returns the value obtained by applying the reduction operation kind
151/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
152Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
153 Value lhs, Value rhs);
154
155arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
156} // namespace arith
157} // namespace mlir
158
159#endif // MLIR_DIALECT_ARITH_IR_ARITH_H_
160

// Source: mlir/include/mlir/Dialect/Arith/IR/Arith.h