1 | //===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H |
10 | #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H |
11 | |
12 | #include "mlir/IR/AffineMap.h" |
13 | #include "mlir/IR/Attributes.h" |
14 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
15 | #include "mlir/IR/Types.h" |
16 | #include "mlir/Support/LogicalResult.h" |
17 | #include "llvm/Support/raw_ostream.h" |
18 | #include <complex> |
19 | #include <optional> |
20 | |
21 | namespace mlir { |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // ElementsAttr |
25 | //===----------------------------------------------------------------------===// |
26 | namespace detail { |
27 | /// This class provides support for indexing into the element range of an |
28 | /// ElementsAttr. It is used to opaquely wrap either a contiguous range, via |
29 | /// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via |
30 | /// `ElementsAttrIndexer::nonContiguous`, A contiguous range is an array-like |
31 | /// range, where all of the elements are layed out sequentially in memory. A |
32 | /// non-contiguous range implies no contiguity, and elements may even be |
33 | /// materialized when indexing, such as the case for a mapped_range. |
34 | struct ElementsAttrIndexer { |
35 | public: |
36 | ElementsAttrIndexer() |
37 | : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {} |
38 | ElementsAttrIndexer(ElementsAttrIndexer &&rhs) |
39 | : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) { |
40 | if (isContiguous) |
41 | conState = rhs.conState; |
42 | else |
43 | new (&nonConState) NonContiguousState(std::move(rhs.nonConState)); |
44 | } |
45 | ElementsAttrIndexer(const ElementsAttrIndexer &rhs) |
46 | : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) { |
47 | if (isContiguous) |
48 | conState = rhs.conState; |
49 | else |
50 | new (&nonConState) NonContiguousState(rhs.nonConState); |
51 | } |
52 | ~ElementsAttrIndexer() { |
53 | if (!isContiguous) |
54 | nonConState.~NonContiguousState(); |
55 | } |
56 | |
57 | /// Construct an indexer for a non-contiguous range starting at the given |
58 | /// iterator. A non-contiguous range implies no contiguity, and elements may |
59 | /// even be materialized when indexing, such as the case for a mapped_range. |
60 | template <typename IteratorT> |
61 | static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) { |
62 | ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat); |
63 | new (&indexer.nonConState) |
64 | NonContiguousState(std::forward<IteratorT>(iterator)); |
65 | return indexer; |
66 | } |
67 | |
68 | // Construct an indexer for a contiguous range starting at the given element |
69 | // pointer. A contiguous range is an array-like range, where all of the |
70 | // elements are layed out sequentially in memory. |
71 | template <typename T> |
72 | static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) { |
73 | ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat); |
74 | new (&indexer.conState) ContiguousState(firstEltPtr); |
75 | return indexer; |
76 | } |
77 | |
78 | /// Access the element at the given index. |
79 | template <typename T> |
80 | T at(uint64_t index) const { |
81 | if (isSplat) |
82 | index = 0; |
83 | return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index); |
84 | } |
85 | |
86 | private: |
87 | ElementsAttrIndexer(bool isContiguous, bool isSplat) |
88 | : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {} |
89 | |
90 | /// This class contains all of the state necessary to index a contiguous |
91 | /// range. |
92 | class ContiguousState { |
93 | public: |
94 | ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {} |
95 | |
96 | /// Access the element at the given index. |
97 | template <typename T> |
98 | const T &at(uint64_t index) const { |
99 | return *(reinterpret_cast<const T *>(firstEltPtr) + index); |
100 | } |
101 | |
102 | private: |
103 | const void *firstEltPtr; |
104 | }; |
105 | |
106 | /// This class contains all of the state necessary to index a non-contiguous |
107 | /// range. |
108 | class NonContiguousState { |
109 | private: |
110 | /// This class is used to represent the abstract base of an opaque iterator. |
111 | /// This allows for all iterator and element types to be completely |
112 | /// type-erased. |
113 | struct OpaqueIteratorBase { |
114 | virtual ~OpaqueIteratorBase() = default; |
115 | virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0; |
116 | }; |
117 | /// This class is used to represent the abstract base of an opaque iterator |
118 | /// that iterates over elements of type `T`. This allows for all iterator |
119 | /// types to be completely type-erased. |
120 | template <typename T> |
121 | struct OpaqueIteratorValueBase : public OpaqueIteratorBase { |
122 | virtual T at(uint64_t index) = 0; |
123 | }; |
124 | /// This class is used to represent an opaque handle to an iterator of type |
125 | /// `IteratorT` that iterates over elements of type `T`. |
126 | template <typename IteratorT, typename T> |
127 | struct OpaqueIterator : public OpaqueIteratorValueBase<T> { |
128 | template <typename ItTy, typename FuncTy, typename FuncReturnTy> |
129 | static void isMappedIteratorTestFn( |
130 | llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {} |
131 | template <typename U, typename... Args> |
132 | using is_mapped_iterator = |
133 | decltype(isMappedIteratorTestFn(std::declval<U>())); |
134 | template <typename U> |
135 | using detect_is_mapped_iterator = |
136 | llvm::is_detected<is_mapped_iterator, U>; |
137 | |
138 | /// Access the element within the iterator at the given index. |
139 | template <typename ItT> |
140 | static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T> |
141 | atImpl(ItT &&it, uint64_t index) { |
142 | return *std::next(it, index); |
143 | } |
144 | template <typename ItT> |
145 | static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T> |
146 | atImpl(ItT &&it, uint64_t index) { |
147 | // Special case mapped_iterator to avoid copying the function. |
148 | return it.getFunction()(*std::next(it.getCurrent(), index)); |
149 | } |
150 | |
151 | public: |
152 | template <typename U> |
153 | OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {} |
154 | std::unique_ptr<OpaqueIteratorBase> clone() const final { |
155 | return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator); |
156 | } |
157 | |
158 | /// Access the element at the given index. |
159 | T at(uint64_t index) final { return atImpl(iterator, index); } |
160 | |
161 | private: |
162 | IteratorT iterator; |
163 | }; |
164 | |
165 | public: |
166 | /// Construct the state with the given iterator type. |
167 | template <typename IteratorT, typename T = typename llvm::remove_cvref_t< |
168 | decltype(*std::declval<IteratorT>())>> |
169 | NonContiguousState(IteratorT iterator) |
170 | : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {} |
171 | NonContiguousState(const NonContiguousState &other) |
172 | : iterator(other.iterator->clone()) {} |
173 | NonContiguousState(NonContiguousState &&other) = default; |
174 | |
175 | /// Access the element at the given index. |
176 | template <typename T> |
177 | T at(uint64_t index) const { |
178 | auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get()); |
179 | return valueIt->at(index); |
180 | } |
181 | |
182 | /// The opaque iterator state. |
183 | std::unique_ptr<OpaqueIteratorBase> iterator; |
184 | }; |
185 | |
186 | /// A boolean indicating if this range is contiguous or not. |
187 | bool isContiguous; |
188 | /// A boolean indicating if this range is a splat. |
189 | bool isSplat; |
190 | /// The underlying range state. |
191 | union { |
192 | ContiguousState conState; |
193 | NonContiguousState nonConState; |
194 | }; |
195 | }; |
196 | |
197 | /// This class implements a generic iterator for ElementsAttr. |
198 | template <typename T> |
199 | class ElementsAttrIterator |
200 | : public llvm::iterator_facade_base<ElementsAttrIterator<T>, |
201 | std::random_access_iterator_tag, T, |
202 | std::ptrdiff_t, T, T> { |
203 | public: |
204 | ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex) |
205 | : indexer(std::move(indexer)), index(dataIndex) {} |
206 | |
207 | // Boilerplate iterator methods. |
208 | ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { |
209 | return index - rhs.index; |
210 | } |
211 | bool operator==(const ElementsAttrIterator &rhs) const { |
212 | return index == rhs.index; |
213 | } |
214 | bool operator<(const ElementsAttrIterator &rhs) const { |
215 | return index < rhs.index; |
216 | } |
217 | ElementsAttrIterator &operator+=(ptrdiff_t offset) { |
218 | index += offset; |
219 | return *this; |
220 | } |
221 | ElementsAttrIterator &operator-=(ptrdiff_t offset) { |
222 | index -= offset; |
223 | return *this; |
224 | } |
225 | |
226 | /// Return the value at the current iterator position. |
227 | T operator*() const { return indexer.at<T>(index); } |
228 | |
229 | private: |
230 | ElementsAttrIndexer indexer; |
231 | ptrdiff_t index; |
232 | }; |
233 | |
234 | /// This class provides iterator utilities for an ElementsAttr range. |
235 | template <typename IteratorT> |
236 | class ElementsAttrRange : public llvm::iterator_range<IteratorT> { |
237 | public: |
238 | using reference = typename IteratorT::reference; |
239 | |
240 | ElementsAttrRange(ShapedType shapeType, |
241 | const llvm::iterator_range<IteratorT> &range) |
242 | : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {} |
243 | ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt) |
244 | : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {} |
245 | |
246 | /// Return the value at the given index. |
247 | reference operator[](ArrayRef<uint64_t> index) const; |
248 | reference operator[](uint64_t index) const { |
249 | return *std::next(this->begin(), index); |
250 | } |
251 | |
252 | /// Return the size of this range. |
253 | size_t size() const { return llvm::size(*this); } |
254 | |
255 | private: |
256 | /// The shaped type of the parent ElementsAttr. |
257 | ShapedType shapeType; |
258 | }; |
259 | |
260 | } // namespace detail |
261 | |
262 | //===----------------------------------------------------------------------===// |
263 | // MemRefLayoutAttrInterface |
264 | //===----------------------------------------------------------------------===// |
265 | |
266 | namespace detail { |
267 | |
/// Verify the affine map 'm' can be used as a layout specification
/// for memref with 'shape'.
///
/// `emitError` is a callback returning an in-flight diagnostic; presumably it
/// is invoked to report why verification failed (the definition is not
/// visible from this header -- confirm against the implementation).
LogicalResult
verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
                        function_ref<InFlightDiagnostic()> emitError);
273 | |
274 | } // namespace detail |
275 | |
276 | } // namespace mlir |
277 | |
278 | //===----------------------------------------------------------------------===// |
279 | // Tablegen Interface Declarations |
280 | //===----------------------------------------------------------------------===// |
281 | |
282 | #include "mlir/IR/BuiltinAttributeInterfaces.h.inc" |
283 | |
284 | //===----------------------------------------------------------------------===// |
285 | // ElementsAttr |
286 | //===----------------------------------------------------------------------===// |
287 | |
288 | namespace mlir { |
289 | namespace detail { |
290 | /// Return the value at the given index. |
291 | template <typename IteratorT> |
292 | auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const |
293 | -> reference { |
294 | // Skip to the element corresponding to the flattened index. |
295 | return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)]; |
296 | } |
297 | } // namespace detail |
298 | |
299 | /// Return the elements of this attribute as a value of type 'T'. |
300 | template <typename T> |
301 | auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> { |
302 | if (std::optional<iterator<T>> iterator = try_value_begin<T>()) |
303 | return std::move(*iterator); |
304 | llvm::errs() |
305 | << "ElementsAttr does not provide iteration facilities for type `" |
306 | << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n" ; |
307 | llvm_unreachable("invalid `T` for ElementsAttr::getValues" ); |
308 | } |
309 | template <typename T> |
310 | auto ElementsAttr::try_value_begin() const |
311 | -> DefaultValueCheckT<T, std::optional<iterator<T>>> { |
312 | FailureOr<detail::ElementsAttrIndexer> indexer = |
313 | getValuesImpl(TypeID::get<T>()); |
314 | if (failed(indexer)) |
315 | return std::nullopt; |
316 | return iterator<T>(std::move(*indexer), 0); |
317 | } |
318 | } // namespace mlir. |
319 | |
320 | #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H |
321 | |