1//===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
10#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
11
12#include "mlir/IR/AffineMap.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/BuiltinTypeInterfaces.h"
15#include "mlir/IR/Types.h"
16#include "mlir/Support/LogicalResult.h"
17#include "llvm/Support/raw_ostream.h"
18#include <complex>
19#include <optional>
20
21namespace mlir {
22
23//===----------------------------------------------------------------------===//
24// ElementsAttr
25//===----------------------------------------------------------------------===//
26namespace detail {
27/// This class provides support for indexing into the element range of an
28/// ElementsAttr. It is used to opaquely wrap either a contiguous range, via
29/// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via
30/// `ElementsAttrIndexer::nonContiguous`, A contiguous range is an array-like
31/// range, where all of the elements are layed out sequentially in memory. A
32/// non-contiguous range implies no contiguity, and elements may even be
33/// materialized when indexing, such as the case for a mapped_range.
34struct ElementsAttrIndexer {
35public:
36 ElementsAttrIndexer()
37 : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {}
38 ElementsAttrIndexer(ElementsAttrIndexer &&rhs)
39 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
40 if (isContiguous)
41 conState = rhs.conState;
42 else
43 new (&nonConState) NonContiguousState(std::move(rhs.nonConState));
44 }
45 ElementsAttrIndexer(const ElementsAttrIndexer &rhs)
46 : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) {
47 if (isContiguous)
48 conState = rhs.conState;
49 else
50 new (&nonConState) NonContiguousState(rhs.nonConState);
51 }
52 ~ElementsAttrIndexer() {
53 if (!isContiguous)
54 nonConState.~NonContiguousState();
55 }
56
57 /// Construct an indexer for a non-contiguous range starting at the given
58 /// iterator. A non-contiguous range implies no contiguity, and elements may
59 /// even be materialized when indexing, such as the case for a mapped_range.
60 template <typename IteratorT>
61 static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) {
62 ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat);
63 new (&indexer.nonConState)
64 NonContiguousState(std::forward<IteratorT>(iterator));
65 return indexer;
66 }
67
68 // Construct an indexer for a contiguous range starting at the given element
69 // pointer. A contiguous range is an array-like range, where all of the
70 // elements are layed out sequentially in memory.
71 template <typename T>
72 static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) {
73 ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat);
74 new (&indexer.conState) ContiguousState(firstEltPtr);
75 return indexer;
76 }
77
78 /// Access the element at the given index.
79 template <typename T>
80 T at(uint64_t index) const {
81 if (isSplat)
82 index = 0;
83 return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
84 }
85
86private:
87 ElementsAttrIndexer(bool isContiguous, bool isSplat)
88 : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {}
89
90 /// This class contains all of the state necessary to index a contiguous
91 /// range.
92 class ContiguousState {
93 public:
94 ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
95
96 /// Access the element at the given index.
97 template <typename T>
98 const T &at(uint64_t index) const {
99 return *(reinterpret_cast<const T *>(firstEltPtr) + index);
100 }
101
102 private:
103 const void *firstEltPtr;
104 };
105
106 /// This class contains all of the state necessary to index a non-contiguous
107 /// range.
108 class NonContiguousState {
109 private:
110 /// This class is used to represent the abstract base of an opaque iterator.
111 /// This allows for all iterator and element types to be completely
112 /// type-erased.
113 struct OpaqueIteratorBase {
114 virtual ~OpaqueIteratorBase() = default;
115 virtual std::unique_ptr<OpaqueIteratorBase> clone() const = 0;
116 };
117 /// This class is used to represent the abstract base of an opaque iterator
118 /// that iterates over elements of type `T`. This allows for all iterator
119 /// types to be completely type-erased.
120 template <typename T>
121 struct OpaqueIteratorValueBase : public OpaqueIteratorBase {
122 virtual T at(uint64_t index) = 0;
123 };
124 /// This class is used to represent an opaque handle to an iterator of type
125 /// `IteratorT` that iterates over elements of type `T`.
126 template <typename IteratorT, typename T>
127 struct OpaqueIterator : public OpaqueIteratorValueBase<T> {
128 template <typename ItTy, typename FuncTy, typename FuncReturnTy>
129 static void isMappedIteratorTestFn(
130 llvm::mapped_iterator<ItTy, FuncTy, FuncReturnTy>) {}
131 template <typename U, typename... Args>
132 using is_mapped_iterator =
133 decltype(isMappedIteratorTestFn(std::declval<U>()));
134 template <typename U>
135 using detect_is_mapped_iterator =
136 llvm::is_detected<is_mapped_iterator, U>;
137
138 /// Access the element within the iterator at the given index.
139 template <typename ItT>
140 static std::enable_if_t<!detect_is_mapped_iterator<ItT>::value, T>
141 atImpl(ItT &&it, uint64_t index) {
142 return *std::next(it, index);
143 }
144 template <typename ItT>
145 static std::enable_if_t<detect_is_mapped_iterator<ItT>::value, T>
146 atImpl(ItT &&it, uint64_t index) {
147 // Special case mapped_iterator to avoid copying the function.
148 return it.getFunction()(*std::next(it.getCurrent(), index));
149 }
150
151 public:
152 template <typename U>
153 OpaqueIterator(U &&iterator) : iterator(std::forward<U>(iterator)) {}
154 std::unique_ptr<OpaqueIteratorBase> clone() const final {
155 return std::make_unique<OpaqueIterator<IteratorT, T>>(iterator);
156 }
157
158 /// Access the element at the given index.
159 T at(uint64_t index) final { return atImpl(iterator, index); }
160
161 private:
162 IteratorT iterator;
163 };
164
165 public:
166 /// Construct the state with the given iterator type.
167 template <typename IteratorT, typename T = typename llvm::remove_cvref_t<
168 decltype(*std::declval<IteratorT>())>>
169 NonContiguousState(IteratorT iterator)
170 : iterator(std::make_unique<OpaqueIterator<IteratorT, T>>(iterator)) {}
171 NonContiguousState(const NonContiguousState &other)
172 : iterator(other.iterator->clone()) {}
173 NonContiguousState(NonContiguousState &&other) = default;
174
175 /// Access the element at the given index.
176 template <typename T>
177 T at(uint64_t index) const {
178 auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
179 return valueIt->at(index);
180 }
181
182 /// The opaque iterator state.
183 std::unique_ptr<OpaqueIteratorBase> iterator;
184 };
185
186 /// A boolean indicating if this range is contiguous or not.
187 bool isContiguous;
188 /// A boolean indicating if this range is a splat.
189 bool isSplat;
190 /// The underlying range state.
191 union {
192 ContiguousState conState;
193 NonContiguousState nonConState;
194 };
195};
196
197/// This class implements a generic iterator for ElementsAttr.
198template <typename T>
199class ElementsAttrIterator
200 : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
201 std::random_access_iterator_tag, T,
202 std::ptrdiff_t, T, T> {
203public:
204 ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex)
205 : indexer(std::move(indexer)), index(dataIndex) {}
206
207 // Boilerplate iterator methods.
208 ptrdiff_t operator-(const ElementsAttrIterator &rhs) const {
209 return index - rhs.index;
210 }
211 bool operator==(const ElementsAttrIterator &rhs) const {
212 return index == rhs.index;
213 }
214 bool operator<(const ElementsAttrIterator &rhs) const {
215 return index < rhs.index;
216 }
217 ElementsAttrIterator &operator+=(ptrdiff_t offset) {
218 index += offset;
219 return *this;
220 }
221 ElementsAttrIterator &operator-=(ptrdiff_t offset) {
222 index -= offset;
223 return *this;
224 }
225
226 /// Return the value at the current iterator position.
227 T operator*() const { return indexer.at<T>(index); }
228
229private:
230 ElementsAttrIndexer indexer;
231 ptrdiff_t index;
232};
233
234/// This class provides iterator utilities for an ElementsAttr range.
235template <typename IteratorT>
236class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
237public:
238 using reference = typename IteratorT::reference;
239
240 ElementsAttrRange(ShapedType shapeType,
241 const llvm::iterator_range<IteratorT> &range)
242 : llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
243 ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt)
244 : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
245
246 /// Return the value at the given index.
247 reference operator[](ArrayRef<uint64_t> index) const;
248 reference operator[](uint64_t index) const {
249 return *std::next(this->begin(), index);
250 }
251
252 /// Return the size of this range.
253 size_t size() const { return llvm::size(*this); }
254
255private:
256 /// The shaped type of the parent ElementsAttr.
257 ShapedType shapeType;
258};
259
260} // namespace detail
261
262//===----------------------------------------------------------------------===//
263// MemRefLayoutAttrInterface
264//===----------------------------------------------------------------------===//
265
266namespace detail {
267
/// Verify that the affine map `m` can be used as a layout specification for a
/// memref with the given `shape`.
///
/// NOTE(review): only the declaration is visible in this header. Presumably
/// `emitError` is invoked to create the diagnostic when verification fails and
/// the returned LogicalResult reports success/failure -- confirm against the
/// definition.
LogicalResult
verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
                        function_ref<InFlightDiagnostic()> emitError);
273
274} // namespace detail
275
276} // namespace mlir
277
278//===----------------------------------------------------------------------===//
279// Tablegen Interface Declarations
280//===----------------------------------------------------------------------===//
281
282#include "mlir/IR/BuiltinAttributeInterfaces.h.inc"
283
284//===----------------------------------------------------------------------===//
285// ElementsAttr
286//===----------------------------------------------------------------------===//
287
288namespace mlir {
289namespace detail {
290/// Return the value at the given index.
291template <typename IteratorT>
292auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const
293 -> reference {
294 // Skip to the element corresponding to the flattened index.
295 return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
296}
297} // namespace detail
298
299/// Return the elements of this attribute as a value of type 'T'.
300template <typename T>
301auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
302 if (std::optional<iterator<T>> iterator = try_value_begin<T>())
303 return std::move(*iterator);
304 llvm::errs()
305 << "ElementsAttr does not provide iteration facilities for type `"
306 << llvm::getTypeName<T>() << "`, see attribute: " << *this << "\n";
307 llvm_unreachable("invalid `T` for ElementsAttr::getValues");
308}
309template <typename T>
310auto ElementsAttr::try_value_begin() const
311 -> DefaultValueCheckT<T, std::optional<iterator<T>>> {
312 FailureOr<detail::ElementsAttrIndexer> indexer =
313 getValuesImpl(TypeID::get<T>());
314 if (failed(indexer))
315 return std::nullopt;
316 return iterator<T>(std::move(*indexer), 0);
317}
318} // namespace mlir.
319
320#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H
321

source code of mlir/include/mlir/IR/BuiltinAttributeInterfaces.h