1 | //===- MemRef.h - MemRef dialect --------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_ |
10 | #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_ |
11 | |
12 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
15 | #include "mlir/IR/Dialect.h" |
16 | #include "mlir/Interfaces/CallInterfaces.h" |
17 | #include "mlir/Interfaces/CastInterfaces.h" |
18 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
19 | #include "mlir/Interfaces/CopyOpInterface.h" |
20 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
21 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
22 | #include "mlir/Interfaces/ShapedOpInterfaces.h" |
23 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
24 | #include "mlir/Interfaces/ViewLikeInterface.h" |
25 | |
26 | #include <optional> |
27 | |
28 | namespace mlir { |
29 | |
30 | namespace arith { |
31 | enum class AtomicRMWKind : uint64_t; |
32 | class AtomicRMWKindAttr; |
33 | } // namespace arith |
34 | |
35 | class Location; |
36 | class OpBuilder; |
37 | |
38 | raw_ostream &operator<<(raw_ostream &os, const Range &range); |
39 | |
40 | /// Return the list of Ranges (i.e. offset, size, stride) of `op`. Each Range |
41 | /// entry contains either the dynamic value or a ConstantIndexOp constructed |
42 | /// with `b` at location `loc`. |
43 | SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op, |
44 | OpBuilder &b, Location loc); |
45 | |
46 | namespace memref { |
47 | |
48 | /// Common utility for patterns of the form "someop(memref.cast) -> someop": |
49 | /// folds the source of any memref.cast operand directly into the root `op`. |
50 | /// NOTE(review): `inner` is undocumented here; confirm its role in the definition. |
51 | LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr); |
52 | |
53 | /// Return an unranked/ranked tensor type for the given unranked/ranked memref |
54 | /// `type`. |
55 | Type getTensorTypeFromMemRefType(Type type); |
56 | |
57 | /// Finds the single dealloc operation for the given allocated value. If there |
58 | /// is more than one dealloc for `allocValue`, returns std::nullopt; otherwise |
59 | /// returns the single dealloc op if it exists, or nullptr if there is none. |
60 | std::optional<Operation *> findDealloc(Value allocValue); |
61 | |
62 | /// Return the size of dimension `dim` of the given memref `value`. |
63 | OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, |
64 | int64_t dim); |
65 | |
66 | /// Return the sizes of the dimensions of the given memref `value`. |
67 | SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc, |
68 | Value value); |
69 | |
70 | /// Create a rank-reducing SubViewOp at offsets [0 .. 0] with strides [1 .. 1] |
71 | /// and the sizes of `memref` (i.e. `memref.getSizes()`), reducing the rank of |
72 | /// `memref` to that of `targetShape`. |
73 | Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, |
74 | Value memref, |
75 | ArrayRef<int64_t> targetShape); |
76 | } // namespace memref |
77 | } // namespace mlir |
78 | |
79 | //===----------------------------------------------------------------------===// |
80 | // MemRef Dialect |
81 | //===----------------------------------------------------------------------===// |
82 | |
83 | #include "mlir/Dialect/MemRef/IR/MemRefOpsDialect.h.inc" |
84 | |
85 | //===----------------------------------------------------------------------===// |
86 | // MemRef Dialect Operations |
87 | //===----------------------------------------------------------------------===// |
88 | |
89 | #define GET_OP_CLASSES |
90 | #include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc" |
91 | |
92 | #endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_ |
93 | |