1 | //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_CODEGEN_PBQP_MATH_H |
10 | #define LLVM_CODEGEN_PBQP_MATH_H |
11 | |
12 | #include "llvm/ADT/Hashing.h" |
13 | #include "llvm/ADT/STLExtras.h" |
14 | #include <algorithm> |
15 | #include <cassert> |
16 | #include <functional> |
17 | #include <memory> |
18 | |
19 | namespace llvm { |
20 | namespace PBQP { |
21 | |
22 | using PBQPNum = float; |
23 | |
24 | /// PBQP Vector class. |
25 | class Vector { |
26 | friend hash_code hash_value(const Vector &); |
27 | |
28 | public: |
29 | /// Construct a PBQP vector of the given size. |
30 | explicit Vector(unsigned Length) |
31 | : Length(Length), Data(std::make_unique<PBQPNum []>(num: Length)) {} |
32 | |
33 | /// Construct a PBQP vector with initializer. |
34 | Vector(unsigned Length, PBQPNum InitVal) |
35 | : Length(Length), Data(std::make_unique<PBQPNum []>(num: Length)) { |
36 | std::fill(first: Data.get(), last: Data.get() + Length, value: InitVal); |
37 | } |
38 | |
39 | /// Copy construct a PBQP vector. |
40 | Vector(const Vector &V) |
41 | : Length(V.Length), Data(std::make_unique<PBQPNum []>(num: Length)) { |
42 | std::copy(first: V.Data.get(), last: V.Data.get() + Length, result: Data.get()); |
43 | } |
44 | |
45 | /// Move construct a PBQP vector. |
46 | Vector(Vector &&V) |
47 | : Length(V.Length), Data(std::move(V.Data)) { |
48 | V.Length = 0; |
49 | } |
50 | |
51 | /// Comparison operator. |
52 | bool operator==(const Vector &V) const { |
53 | assert(Length != 0 && Data && "Invalid vector" ); |
54 | if (Length != V.Length) |
55 | return false; |
56 | return std::equal(first1: Data.get(), last1: Data.get() + Length, first2: V.Data.get()); |
57 | } |
58 | |
59 | /// Return the length of the vector |
60 | unsigned getLength() const { |
61 | assert(Length != 0 && Data && "Invalid vector" ); |
62 | return Length; |
63 | } |
64 | |
65 | /// Element access. |
66 | PBQPNum& operator[](unsigned Index) { |
67 | assert(Length != 0 && Data && "Invalid vector" ); |
68 | assert(Index < Length && "Vector element access out of bounds." ); |
69 | return Data[Index]; |
70 | } |
71 | |
72 | /// Const element access. |
73 | const PBQPNum& operator[](unsigned Index) const { |
74 | assert(Length != 0 && Data && "Invalid vector" ); |
75 | assert(Index < Length && "Vector element access out of bounds." ); |
76 | return Data[Index]; |
77 | } |
78 | |
79 | /// Add another vector to this one. |
80 | Vector& operator+=(const Vector &V) { |
81 | assert(Length != 0 && Data && "Invalid vector" ); |
82 | assert(Length == V.Length && "Vector length mismatch." ); |
83 | std::transform(first1: Data.get(), last1: Data.get() + Length, first2: V.Data.get(), result: Data.get(), |
84 | binary_op: std::plus<PBQPNum>()); |
85 | return *this; |
86 | } |
87 | |
88 | /// Returns the index of the minimum value in this vector |
89 | unsigned minIndex() const { |
90 | assert(Length != 0 && Data && "Invalid vector" ); |
91 | return std::min_element(first: Data.get(), last: Data.get() + Length) - Data.get(); |
92 | } |
93 | |
94 | private: |
95 | unsigned Length; |
96 | std::unique_ptr<PBQPNum []> Data; |
97 | }; |
98 | |
99 | /// Return a hash_value for the given vector. |
100 | inline hash_code hash_value(const Vector &V) { |
101 | unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get()); |
102 | unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length); |
103 | return hash_combine(args: V.Length, args: hash_combine_range(first: VBegin, last: VEnd)); |
104 | } |
105 | |
106 | /// Output a textual representation of the given vector on the given |
107 | /// output stream. |
108 | template <typename OStream> |
109 | OStream& operator<<(OStream &OS, const Vector &V) { |
110 | assert((V.getLength() != 0) && "Zero-length vector badness." ); |
111 | |
112 | OS << "[ " << V[0]; |
113 | for (unsigned i = 1; i < V.getLength(); ++i) |
114 | OS << ", " << V[i]; |
115 | OS << " ]" ; |
116 | |
117 | return OS; |
118 | } |
119 | |
120 | /// PBQP Matrix class |
121 | class Matrix { |
122 | private: |
123 | friend hash_code hash_value(const Matrix &); |
124 | |
125 | public: |
126 | /// Construct a PBQP Matrix with the given dimensions. |
127 | Matrix(unsigned Rows, unsigned Cols) : |
128 | Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(num: Rows * Cols)) { |
129 | } |
130 | |
131 | /// Construct a PBQP Matrix with the given dimensions and initial |
132 | /// value. |
133 | Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal) |
134 | : Rows(Rows), Cols(Cols), |
135 | Data(std::make_unique<PBQPNum []>(num: Rows * Cols)) { |
136 | std::fill(first: Data.get(), last: Data.get() + (Rows * Cols), value: InitVal); |
137 | } |
138 | |
139 | /// Copy construct a PBQP matrix. |
140 | Matrix(const Matrix &M) |
141 | : Rows(M.Rows), Cols(M.Cols), |
142 | Data(std::make_unique<PBQPNum []>(num: Rows * Cols)) { |
143 | std::copy(first: M.Data.get(), last: M.Data.get() + (Rows * Cols), result: Data.get()); |
144 | } |
145 | |
146 | /// Move construct a PBQP matrix. |
147 | Matrix(Matrix &&M) |
148 | : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) { |
149 | M.Rows = M.Cols = 0; |
150 | } |
151 | |
152 | /// Comparison operator. |
153 | bool operator==(const Matrix &M) const { |
154 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
155 | if (Rows != M.Rows || Cols != M.Cols) |
156 | return false; |
157 | return std::equal(first1: Data.get(), last1: Data.get() + (Rows * Cols), first2: M.Data.get()); |
158 | } |
159 | |
160 | /// Return the number of rows in this matrix. |
161 | unsigned getRows() const { |
162 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
163 | return Rows; |
164 | } |
165 | |
166 | /// Return the number of cols in this matrix. |
167 | unsigned getCols() const { |
168 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
169 | return Cols; |
170 | } |
171 | |
172 | /// Matrix element access. |
173 | PBQPNum* operator[](unsigned R) { |
174 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
175 | assert(R < Rows && "Row out of bounds." ); |
176 | return Data.get() + (R * Cols); |
177 | } |
178 | |
179 | /// Matrix element access. |
180 | const PBQPNum* operator[](unsigned R) const { |
181 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
182 | assert(R < Rows && "Row out of bounds." ); |
183 | return Data.get() + (R * Cols); |
184 | } |
185 | |
186 | /// Returns the given row as a vector. |
187 | Vector getRowAsVector(unsigned R) const { |
188 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
189 | Vector V(Cols); |
190 | for (unsigned C = 0; C < Cols; ++C) |
191 | V[C] = (*this)[R][C]; |
192 | return V; |
193 | } |
194 | |
195 | /// Returns the given column as a vector. |
196 | Vector getColAsVector(unsigned C) const { |
197 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
198 | Vector V(Rows); |
199 | for (unsigned R = 0; R < Rows; ++R) |
200 | V[R] = (*this)[R][C]; |
201 | return V; |
202 | } |
203 | |
204 | /// Matrix transpose. |
205 | Matrix transpose() const { |
206 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
207 | Matrix M(Cols, Rows); |
208 | for (unsigned r = 0; r < Rows; ++r) |
209 | for (unsigned c = 0; c < Cols; ++c) |
210 | M[c][r] = (*this)[r][c]; |
211 | return M; |
212 | } |
213 | |
214 | /// Add the given matrix to this one. |
215 | Matrix& operator+=(const Matrix &M) { |
216 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
217 | assert(Rows == M.Rows && Cols == M.Cols && |
218 | "Matrix dimensions mismatch." ); |
219 | std::transform(first1: Data.get(), last1: Data.get() + (Rows * Cols), first2: M.Data.get(), |
220 | result: Data.get(), binary_op: std::plus<PBQPNum>()); |
221 | return *this; |
222 | } |
223 | |
224 | Matrix operator+(const Matrix &M) { |
225 | assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix" ); |
226 | Matrix Tmp(*this); |
227 | Tmp += M; |
228 | return Tmp; |
229 | } |
230 | |
231 | private: |
232 | unsigned Rows, Cols; |
233 | std::unique_ptr<PBQPNum []> Data; |
234 | }; |
235 | |
236 | /// Return a hash_code for the given matrix. |
237 | inline hash_code hash_value(const Matrix &M) { |
238 | unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get()); |
239 | unsigned *MEnd = |
240 | reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols)); |
241 | return hash_combine(args: M.Rows, args: M.Cols, args: hash_combine_range(first: MBegin, last: MEnd)); |
242 | } |
243 | |
244 | /// Output a textual representation of the given matrix on the given |
245 | /// output stream. |
246 | template <typename OStream> |
247 | OStream& operator<<(OStream &OS, const Matrix &M) { |
248 | assert((M.getRows() != 0) && "Zero-row matrix badness." ); |
249 | for (unsigned i = 0; i < M.getRows(); ++i) |
250 | OS << M.getRowAsVector(R: i) << "\n" ; |
251 | return OS; |
252 | } |
253 | |
254 | template <typename Metadata> |
255 | class MDVector : public Vector { |
256 | public: |
257 | MDVector(const Vector &v) : Vector(v), md(*this) {} |
258 | MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { } |
259 | |
260 | const Metadata& getMetadata() const { return md; } |
261 | |
262 | private: |
263 | Metadata md; |
264 | }; |
265 | |
266 | template <typename Metadata> |
267 | inline hash_code hash_value(const MDVector<Metadata> &V) { |
268 | return hash_value(V: static_cast<const Vector&>(V)); |
269 | } |
270 | |
271 | template <typename Metadata> |
272 | class MDMatrix : public Matrix { |
273 | public: |
274 | MDMatrix(const Matrix &m) : Matrix(m), md(*this) {} |
275 | MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { } |
276 | |
277 | const Metadata& getMetadata() const { return md; } |
278 | |
279 | private: |
280 | Metadata md; |
281 | }; |
282 | |
283 | template <typename Metadata> |
284 | inline hash_code hash_value(const MDMatrix<Metadata> &M) { |
285 | return hash_value(M: static_cast<const Matrix&>(M)); |
286 | } |
287 | |
288 | } // end namespace PBQP |
289 | } // end namespace llvm |
290 | |
291 | #endif // LLVM_CODEGEN_PBQP_MATH_H |
292 | |