1//===-- runtime/matmul.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// Implements all forms of MATMUL (Fortran 2018 16.9.124)
10//
11// There are two main entry points; one establishes a descriptor for the
12// result and allocates it, and the other expects a result descriptor that
13// points to existing storage.
14//
15// This implementation must handle all combinations of numeric types and
16// kinds (100 - 165 cases depending on the target), plus all combinations
17// of logical kinds (16). A single template undergoes many instantiations
18// to cover all of the valid possibilities.
19//
20// Places where BLAS routines could be called are marked as TODO items.
21
22#include "flang/Runtime/matmul.h"
23#include "terminator.h"
24#include "tools.h"
25#include "flang/Common/optional.h"
26#include "flang/Runtime/c-or-cpp.h"
27#include "flang/Runtime/cpp-type.h"
28#include "flang/Runtime/descriptor.h"
29#include <cstring>
30
31namespace Fortran::runtime {
32
33// Suppress the warnings about calling __host__-only std::complex operators,
34// defined in C++ STD header files, from __device__ code.
35RT_DIAG_PUSH
36RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
37
38// General accumulator for any type and stride; this is not used for
39// contiguous numeric cases.
40template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
41class Accumulator {
42public:
43 using Result = AccumulationType<RCAT, RKIND>;
44 RT_API_ATTRS Accumulator(const Descriptor &x, const Descriptor &y)
45 : x_{x}, y_{y} {}
46 RT_API_ATTRS void Accumulate(
47 const SubscriptValue xAt[], const SubscriptValue yAt[]) {
48 if constexpr (RCAT == TypeCategory::Logical) {
49 sum_ = sum_ ||
50 (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt));
51 } else {
52 sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) *
53 static_cast<Result>(*y_.Element<YT>(yAt));
54 }
55 }
56 RT_API_ATTRS Result GetResult() const { return sum_; }
57
58private:
59 const Descriptor &x_, &y_;
60 Result sum_{};
61};
62
63// Contiguous numeric matrix*matrix multiplication
64// matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols)
65// Straightforward algorithm:
66// DO 1 I = 1, NROWS
67// DO 1 J = 1, NCOLS
68// RES(I,J) = 0
69// DO 1 K = 1, N
70// 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J)
71// With loop distribution and transposition to avoid the inner sum
72// reduction and to avoid non-unit strides:
73// DO 1 I = 1, NROWS
74// DO 1 J = 1, NCOLS
75// 1 RES(I,J) = 0
76// DO 2 K = 1, N
77// DO 2 J = 1, NCOLS
78// DO 2 I = 1, NROWS
79// 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term
80template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
81 bool X_HAS_STRIDED_COLUMNS, bool Y_HAS_STRIDED_COLUMNS>
82inline RT_API_ATTRS void MatrixTimesMatrix(
83 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
84 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
85 SubscriptValue n, std::size_t xColumnByteStride = 0,
86 std::size_t yColumnByteStride = 0) {
87 using ResultType = CppTypeFor<RCAT, RKIND>;
88 std::memset(product, 0, rows * cols * sizeof *product);
89 const XT *RESTRICT xp0{x};
90 for (SubscriptValue k{0}; k < n; ++k) {
91 ResultType *RESTRICT p{product};
92 for (SubscriptValue j{0}; j < cols; ++j) {
93 const XT *RESTRICT xp{xp0};
94 ResultType yv;
95 if constexpr (!Y_HAS_STRIDED_COLUMNS) {
96 yv = static_cast<ResultType>(y[k + j * n]);
97 } else {
98 yv = static_cast<ResultType>(reinterpret_cast<const YT *>(
99 reinterpret_cast<const char *>(y) + j * yColumnByteStride)[k]);
100 }
101 for (SubscriptValue i{0}; i < rows; ++i) {
102 *p++ += static_cast<ResultType>(*xp++) * yv;
103 }
104 }
105 if constexpr (!X_HAS_STRIDED_COLUMNS) {
106 xp0 += rows;
107 } else {
108 xp0 = reinterpret_cast<const XT *>(
109 reinterpret_cast<const char *>(xp0) + xColumnByteStride);
110 }
111 }
112}
113
114RT_DIAG_POP
115
116template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
117inline RT_API_ATTRS void MatrixTimesMatrixHelper(
118 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
119 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
120 SubscriptValue n, Fortran::common::optional<std::size_t> xColumnByteStride,
121 Fortran::common::optional<std::size_t> yColumnByteStride) {
122 if (!xColumnByteStride) {
123 if (!yColumnByteStride) {
124 MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, false>(
125 product, rows, cols, x, y, n);
126 } else {
127 MatrixTimesMatrix<RCAT, RKIND, XT, YT, false, true>(
128 product, rows, cols, x, y, n, 0, *yColumnByteStride);
129 }
130 } else {
131 if (!yColumnByteStride) {
132 MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, false>(
133 product, rows, cols, x, y, n, *xColumnByteStride);
134 } else {
135 MatrixTimesMatrix<RCAT, RKIND, XT, YT, true, true>(
136 product, rows, cols, x, y, n, *xColumnByteStride, *yColumnByteStride);
137 }
138 }
139}
140
141RT_DIAG_PUSH
142RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
143
144// Contiguous numeric matrix*vector multiplication
145// matrix(rows,n) * column vector(n) -> column vector(rows)
146// Straightforward algorithm:
147// DO 1 J = 1, NROWS
148// RES(J) = 0
149// DO 1 K = 1, N
150// 1 RES(J) = RES(J) + X(J,K)*Y(K)
151// With loop distribution and transposition to avoid the inner
152// sum reduction and to avoid non-unit strides:
153// DO 1 J = 1, NROWS
154// 1 RES(J) = 0
155// DO 2 K = 1, N
156// DO 2 J = 1, NROWS
157// 2 RES(J) = RES(J) + X(J,K)*Y(K)
158template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
159 bool X_HAS_STRIDED_COLUMNS>
160inline RT_API_ATTRS void MatrixTimesVector(
161 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
162 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
163 std::size_t xColumnByteStride = 0) {
164 using ResultType = CppTypeFor<RCAT, RKIND>;
165 std::memset(product, 0, rows * sizeof *product);
166 [[maybe_unused]] const XT *RESTRICT xp0{x};
167 for (SubscriptValue k{0}; k < n; ++k) {
168 ResultType *RESTRICT p{product};
169 auto yv{static_cast<ResultType>(*y++)};
170 for (SubscriptValue j{0}; j < rows; ++j) {
171 *p++ += static_cast<ResultType>(*x++) * yv;
172 }
173 if constexpr (X_HAS_STRIDED_COLUMNS) {
174 xp0 = reinterpret_cast<const XT *>(
175 reinterpret_cast<const char *>(xp0) + xColumnByteStride);
176 x = xp0;
177 }
178 }
179}
180
181RT_DIAG_POP
182
183template <TypeCategory RCAT, int RKIND, typename XT, typename YT>
184inline RT_API_ATTRS void MatrixTimesVectorHelper(
185 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue rows,
186 SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y,
187 Fortran::common::optional<std::size_t> xColumnByteStride) {
188 if (!xColumnByteStride) {
189 MatrixTimesVector<RCAT, RKIND, XT, YT, false>(product, rows, n, x, y);
190 } else {
191 MatrixTimesVector<RCAT, RKIND, XT, YT, true>(
192 product, rows, n, x, y, *xColumnByteStride);
193 }
194}
195
196RT_DIAG_PUSH
197RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
198
199// Contiguous numeric vector*matrix multiplication
200// row vector(n) * matrix(n,cols) -> row vector(cols)
201// Straightforward algorithm:
202// DO 1 J = 1, NCOLS
203// RES(J) = 0
204// DO 1 K = 1, N
205// 1 RES(J) = RES(J) + X(K)*Y(K,J)
206// With loop distribution and transposition to avoid the inner
207// sum reduction and one non-unit stride (the other remains):
208// DO 1 J = 1, NCOLS
209// 1 RES(J) = 0
210// DO 2 K = 1, N
211// DO 2 J = 1, NCOLS
212// 2 RES(J) = RES(J) + X(K)*Y(K,J)
213template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
214 bool Y_HAS_STRIDED_COLUMNS>
215inline RT_API_ATTRS void VectorTimesMatrix(
216 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
217 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
218 std::size_t yColumnByteStride = 0) {
219 using ResultType = CppTypeFor<RCAT, RKIND>;
220 std::memset(product, 0, cols * sizeof *product);
221 for (SubscriptValue k{0}; k < n; ++k) {
222 ResultType *RESTRICT p{product};
223 auto xv{static_cast<ResultType>(*x++)};
224 const YT *RESTRICT yp{&y[k]};
225 for (SubscriptValue j{0}; j < cols; ++j) {
226 *p++ += xv * static_cast<ResultType>(*yp);
227 if constexpr (!Y_HAS_STRIDED_COLUMNS) {
228 yp += n;
229 } else {
230 yp = reinterpret_cast<const YT *>(
231 reinterpret_cast<const char *>(yp) + yColumnByteStride);
232 }
233 }
234 }
235}
236
237RT_DIAG_POP
238
239template <TypeCategory RCAT, int RKIND, typename XT, typename YT,
240 bool SPARSE_COLUMNS = false>
241inline RT_API_ATTRS void VectorTimesMatrixHelper(
242 CppTypeFor<RCAT, RKIND> *RESTRICT product, SubscriptValue n,
243 SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y,
244 Fortran::common::optional<std::size_t> yColumnByteStride) {
245 if (!yColumnByteStride) {
246 VectorTimesMatrix<RCAT, RKIND, XT, YT, false>(product, n, cols, x, y);
247 } else {
248 VectorTimesMatrix<RCAT, RKIND, XT, YT, true>(
249 product, n, cols, x, y, *yColumnByteStride);
250 }
251}
252
253RT_DIAG_PUSH
254RT_DIAG_DISABLE_CALL_HOST_FROM_DEVICE_WARN
255
// Implements an instance of MATMUL for given argument types.
//
// When IS_ALLOCATING, "result" is established and allocated here;
// otherwise it must already describe storage of the right shape.
// Crashes via "terminator" on rank/shape conformance errors.
template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT,
    typename YT>
static inline RT_API_ATTRS void DoMatmul(
    std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result,
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
  int xRank{x.rank()};
  int yRank{y.rank()};
  int resRank{xRank + yRank - 2};
  // The conforming rank combinations are 2*2 -> 2, 2*1 -> 1, and
  // 1*2 -> 1; exactly those satisfy xRank*yRank == 2*resRank.
  if (xRank * yRank != 2 * resRank) {
    terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank);
  }
  // extent[0] is the result's first (or only) extent; extent[1] is used
  // only for the matrix*matrix case.
  SubscriptValue extent[2]{
      xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(),
      resRank == 2 ? y.GetDimension(1).Extent() : 0};
  if constexpr (IS_ALLOCATING) {
    // Establish the result descriptor with lower bounds of 1 and
    // allocate its storage.
    result.Establish(
        RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable);
    for (int j{0}; j < resRank; ++j) {
      result.GetDimension(j).SetBounds(1, extent[j]);
    }
    if (int stat{result.Allocate()}) {
      terminator.Crash(
          "MATMUL: could not allocate memory for result; STAT=%d", stat);
    }
  } else {
    // Sanity-check the caller-supplied result descriptor.
    // NOTE(review): ElementBytes() of a complex descriptor is 2*RKIND, so
    // this check looks as if it could not hold for a complex result --
    // confirm whether the direct entry point is reached for complex.
    RUNTIME_CHECK(terminator, resRank == result.rank());
    RUNTIME_CHECK(
        terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND));
    RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]);
    RUNTIME_CHECK(terminator,
        resRank == 1 || result.GetDimension(1).Extent() == extent[1]);
  }
  // "n" is the contracted dimension; it must agree between the operands.
  SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
  if (n != y.GetDimension(0).Extent()) {
    terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
        static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
        static_cast<std::intmax_t>(n),
        static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
        static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
  }
  // LOGICAL results are stored through the same-sized integer type.
  using WriteResult =
      CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,
          RKIND>;
  if constexpr (RCAT != TypeCategory::Logical) {
    // Fast paths: operands whose elements are contiguous within each
    // column (IsContiguous(1)) and a contiguous result.
    if (x.IsContiguous(1) && y.IsContiguous(1) &&
        (IS_ALLOCATING || result.IsContiguous())) {
      // Contiguous numeric matrices (maybe with columns
      // separated by a stride).
      Fortran::common::optional<std::size_t> xColumnByteStride;
      if (!x.IsContiguous()) {
        // X's columns are strided; measure the byte distance between
        // the starts of two adjacent columns.
        SubscriptValue xAt[2]{};
        x.GetLowerBounds(xAt);
        xAt[1]++;
        xColumnByteStride = x.SubscriptsToByteOffset(xAt);
      }
      Fortran::common::optional<std::size_t> yColumnByteStride;
      if (!y.IsContiguous()) {
        // Y's columns are strided.
        SubscriptValue yAt[2]{};
        y.GetLowerBounds(yAt);
        yAt[1]++;
        yColumnByteStride = y.SubscriptsToByteOffset(yAt);
      }
      // Note that BLAS GEMM can be used for the strided
      // columns by setting proper leading dimension size.
      // This implies that the column stride is divisible
      // by the element size, which is usually true.
      if (resRank == 2) { // M*M -> M
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-3 SGEMM
            // TODO: try using CUTLASS for device.
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-3 DGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-3 CGEMM
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-3 ZGEMM
          }
        }
        MatrixTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], extent[1],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), n, xColumnByteStride,
            yColumnByteStride);
        return;
      } else if (xRank == 2) { // M*V -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(x,y)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(x,y)
          }
        }
        MatrixTimesVectorHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), extent[0], n,
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), xColumnByteStride);
        return;
      } else { // V*M -> V
        if (std::is_same_v<XT, YT>) {
          if constexpr (std::is_same_v<XT, float>) {
            // TODO: call BLAS-2 SGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, double>) {
            // TODO: call BLAS-2 DGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
            // TODO: call BLAS-2 CGEMV(y,x)
          } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
            // TODO: call BLAS-2 ZGEMV(y,x)
          }
        }
        VectorTimesMatrixHelper<RCAT, RKIND, XT, YT>(
            result.template OffsetElement<WriteResult>(), n, extent[0],
            x.OffsetElement<XT>(), y.OffsetElement<YT>(), yColumnByteStride);
        return;
      }
    }
  }
  // General algorithms for LOGICAL and noncontiguity
  // Subscript arrays start at the operands' lower bounds; the loops
  // below add zero-based offsets to them.
  SubscriptValue xAt[2], yAt[2], resAt[2];
  x.GetLowerBounds(xAt);
  y.GetLowerBounds(yAt);
  result.GetLowerBounds(resAt);
  if (resRank == 2) { // M*M -> M
    SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]};
    for (SubscriptValue i{0}; i < extent[0]; ++i) {
      for (SubscriptValue j{0}; j < extent[1]; ++j) {
        // RES(i,j) = sum over k of X(i,k)*Y(k,j)
        Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
        yAt[1] = y1 + j;
        for (SubscriptValue k{0}; k < n; ++k) {
          xAt[1] = x1 + k;
          yAt[0] = y0 + k;
          accumulator.Accumulate(xAt, yAt);
        }
        resAt[1] = res1 + j;
        *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      }
      ++resAt[0];
      ++xAt[0];
    }
  } else if (xRank == 2) { // M*V -> V
    SubscriptValue x1{xAt[1]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      // RES(j) = sum over k of X(j,k)*Y(k)
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[1] = x1 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++xAt[0];
    }
  } else { // V*M -> V
    SubscriptValue x0{xAt[0]}, y0{yAt[0]};
    for (SubscriptValue j{0}; j < extent[0]; ++j) {
      // RES(j) = sum over k of X(k)*Y(k,j)
      Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y};
      for (SubscriptValue k{0}; k < n; ++k) {
        xAt[0] = x0 + k;
        yAt[0] = y0 + k;
        accumulator.Accumulate(xAt, yAt);
      }
      *result.template Element<WriteResult>(resAt) = accumulator.GetResult();
      ++resAt[0];
      ++yAt[1];
    }
  }
}
428
429RT_DIAG_POP
430
// Maps the dynamic type information from the arguments' descriptors
// to the right instantiation of DoMatmul() for valid combinations of
// types. Dispatch is two-level: MM1 fixes x's category/kind, its
// nested MM2 fixes y's, and ApplyType() performs each runtime->template
// mapping.
template <bool IS_ALLOCATING> struct Matmul {
  using ResultDescriptor =
      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
  template <TypeCategory XCAT, int XKIND> struct MM1 {
    template <TypeCategory YCAT, int YKIND> struct MM2 {
      RT_API_ATTRS void operator()(ResultDescriptor &result,
          const Descriptor &x, const Descriptor &y,
          Terminator &terminator) const {
        // GetResultType() is evaluated at compile time; an invalid
        // operand pairing yields no value and falls through to Crash().
        if constexpr (constexpr auto resultType{
                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
          if constexpr (common::IsNumericTypeCategory(resultType->first) ||
              resultType->first == TypeCategory::Logical) {
            return DoMatmul<IS_ALLOCATING, resultType->first,
                resultType->second, CppTypeFor<XCAT, XKIND>,
                CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
          }
        }
        terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))",
            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
      }
    };
    // Second-level dispatch on y's dynamic category and kind.
    RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
        int yKind) const {
      ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
    }
  };
  // Entry point: decodes both operands' dynamic types, then dispatches
  // on x's category and kind.
  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
      const Descriptor &y, const char *sourceFile, int line) const {
    Terminator terminator{sourceFile, line};
    auto xCatKind{x.type().GetCategoryAndKind()};
    auto yCatKind{y.type().GetCategoryAndKind()};
    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
    ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
        x, y, terminator, yCatKind->first, yCatKind->second);
  }
};
471
472extern "C" {
473RT_EXT_API_GROUP_BEGIN
474
// MATMUL entry point that establishes and allocates the result
// descriptor itself (Matmul<true>).
void RTDEF(Matmul)(Descriptor &result, const Descriptor &x, const Descriptor &y,
    const char *sourceFile, int line) {
  Matmul<true>{}(result, x, y, sourceFile, line);
}
// MATMUL entry point whose result descriptor already points to
// existing storage of the right shape (Matmul<false>).
void RTDEF(MatmulDirect)(const Descriptor &result, const Descriptor &x,
    const Descriptor &y, const char *sourceFile, int line) {
  Matmul<false>{}(result, x, y, sourceFile, line);
}
483
484RT_EXT_API_GROUP_END
485} // extern "C"
486} // namespace Fortran::runtime
487

source code of flang/runtime/matmul.cpp