1//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
10#define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
11
12#include <utility>
13
14#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Linalg/Utils/Utils.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22#include "mlir/Dialect/X86Vector/Transforms.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Interfaces/TilingInterface.h"
25#include "mlir/Support/LogicalResult.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/ADT/SmallBitVector.h"
28#include "llvm/ADT/SmallSet.h"
29
30namespace mlir {
31namespace bufferization {
32class AllocTensorOp;
33class OneShotAnalysisState;
34} // namespace bufferization
35
36namespace linalg {
37
38class LinalgOp;
39
40//===----------------------------------------------------------------------===//
41// Utils.
42//===----------------------------------------------------------------------===//
43
44/// Return vector::CombiningKind for the given op.
45std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
46
47//===----------------------------------------------------------------------===//
48// Bufferization-related transforms.
49//===----------------------------------------------------------------------===//
50
51struct BufferizeToAllocationOptions {
52 enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
53 AllocOp allocOp = AllocOp::MemrefAlloc;
54
55 enum class MemcpyOp {
56 MaterializeInDestination = 0,
57 MemrefCopy = 1,
58 LinalgCopy = 2
59 };
60 MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination;
61
62 /// If set to "true", only the destination tensor operands are bufferized to
63 /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
64 /// targeted op itself.
65 bool bufferizeDestinationOnly = false;
66
67 /// If set to "true", a `memref.dealloc` operation will be emitted for each
68 /// allocated buffer. Otherwise, the memory is leaked, which is useful if
69 /// the buffer deallocation pipeline should be run after bufferization is
70 /// done.
71 bool emitDealloc = false;
72};
73
74/// Materialize a buffer allocation for the given tensor.pad op and lower the
75/// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
76/// E.g.:
77///
78/// %0 = tensor.pad low[%l] high[%h] %t ...
79///
80/// is lowered to:
81///
82/// %alloc = memref.alloc
83/// linalg.fill ... outs(%alloc)
84/// %subview = memref.subview %alloc [%l] [...] [1]
85/// bufferization.materialize_in_destination %t in %subview
86/// %0 = bufferization.to_tensor %alloc restrict writable
87///
88/// In addition to rewriting the IR as shown above, this function returns the
89/// newly allocated buffer. The `insertionPoint` parameter can be used to
90/// specify a custom insertion point for the buffer allocation.
91Value bufferizeToAllocation(RewriterBase &rewriter,
92 const BufferizeToAllocationOptions &options,
93 tensor::PadOp padOp, Attribute memorySpace = {},
94 Operation *insertionPoint = nullptr);
95
96/// Materialize a buffer allocation for the given vector.mask op and bufferize
97/// the op, including its region. E.g.:
98///
99/// %0 = vector.mask {
100/// vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
101/// } : vector<16xi1> -> tensor<?xf32>
102///
103/// is lowered to:
104///
105/// %alloc = memref.alloc
106/// bufferization.materialize_in_destination %t in %subview
107/// vector.mask {
108/// vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
109/// } : vector<16xi1>
110/// %0 = bufferization.to_tensor %alloc restrict writable
111///
112/// In addition to rewriting the IR as shown above, this function returns the
113/// newly allocated buffer. The `insertionPoint` parameter can be used to
114/// specify a custom insertion point for the buffer allocation.
115Value bufferizeToAllocation(RewriterBase &rewriter,
116 const BufferizeToAllocationOptions &options,
117 vector::MaskOp maskOp, Attribute memorySpace = {},
118 Operation *insertionPoint = nullptr);
119
120/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
121/// and lower the op to memref.alloc + memref.tensor_store.
122///
123/// In addition to rewriting the IR, this function returns the newly allocated
124/// buffer. The `insertionPoint` parameter can be used to specify a custom
125/// insertion point for the buffer allocation.
126Value bufferizeToAllocation(RewriterBase &rewriter,
127 const BufferizeToAllocationOptions &options,
128 bufferization::AllocTensorOp allocTensorOp,
129 Attribute memorySpace = {},
130 Operation *insertionPoint = nullptr);
131
132/// Bufferize the given op with tensor semantics and materialize the result in
133/// a newly allocated buffer.
134///
135/// Only bufferizable ops that bufferize to a memory write or have an
136/// aliasing OpOperand (and do not themselves bufferize to an allocation) are
137/// supported. They are bufferized using their BufferizableOpInterface
138/// implementation.
139///
140/// Selected ops that bufferize to an allocation (or need special handling) are
141/// also supported:
142/// - tensor.pad
143/// - vector.mask
144///
145/// This function returns the newly allocated buffer. The `insertionPoint`
146/// parameter can be used to specify a custom insertion point for the buffer
147/// allocation.
148Value bufferizeToAllocation(RewriterBase &rewriter,
149 const BufferizeToAllocationOptions &options,
150 Operation *op, Attribute memorySpace = {},
151 Operation *insertionPoint = nullptr);
152
153/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
154/// LinalgOp. This transforms looks for LinalgOps that have an unused output
155/// operand and an input operand that is rooted in a tensor::EmptyOp. The
156/// tensor::EmptyOp uses are replaced with the output operand and the two
157/// operands of the LinalgOp are swapped.
158///
159/// Example:
160/// %0 = tensor.empty()
161/// %1 = linalg.matmul ins(...) outs(%0)
162/// %2 = linalg.generic ins(%1) outs(%dest) {
163/// ^bb0(%in: f32, %out: f32):
164/// // out not used
165/// }
166///
167/// The IR is transformed as follows:
168/// %0 = tensor.empty()
169/// %1 = linalg.matmul ins(...) outs(%dest)
170/// %2 = linalg.generic ins(%0) outs(%1) {
171/// ^bb0(%in: f32, %out: f32):
172/// // Use %out instead of %in
173/// }
174///
175/// The "ins" operand has no uses inside the body of the LinalgOp and can be
176/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp
177/// can also fold away.
178LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(
179 RewriterBase &rewriter, Operation *op,
180 bufferization::OneShotAnalysisState &state);
181
182//===----------------------------------------------------------------------===//
183// Structs that configure the behavior of various transformations.
184//===----------------------------------------------------------------------===//
185
186using TileSizeComputationFunction =
187 std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
188
189struct LinalgTilingOptions {
190 /// Computation function that returns the tile sizes for each operation.
191 /// Delayed construction of constant tile sizes should occur to interoperate
192 /// with folding.
193 TileSizeComputationFunction tileSizeComputationFunction = nullptr;
194
195 LinalgTilingOptions &
196 setTileSizeComputationFunction(TileSizeComputationFunction fun) {
197 tileSizeComputationFunction = std::move(fun);
198 return *this;
199 }
200 /// Set the `tileSizeComputationFunction` to return the values `ts`. The
201 /// values must not fold away when tiling. Otherwise, use a more robust
202 /// `tileSizeComputationFunction`.
203 LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
204 tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
205 return *this;
206 }
207 /// Convenience function to set the `tileSizeComputationFunction` to a
208 /// function that computes tile sizes at the point they are needed. Allows
209 /// proper interaction with folding.
210 LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
211
212 /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
213 /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
214 LinalgTilingOptions &scalarizeDynamicDims();
215
216 /// The interchange vector to reorder the tiled loops.
217 SmallVector<unsigned, 4> interchangeVector = {};
218
219 LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
220 interchangeVector.assign(in_start: interchange.begin(), in_end: interchange.end());
221 return *this;
222 }
223
224 /// The type of tile loops to generate.
225 LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;
226
227 LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
228 loopType = lt;
229 return *this;
230 }
231
232 /// When specified, specifies distribution of generated tile loops to
233 /// processors.
234 std::optional<LinalgLoopDistributionOptions> distribution;
235
236 LinalgTilingOptions &
237 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
238 distribution = std::move(distributionOptions);
239 return *this;
240 }
241
242 /// Specification markers of how to distribute the `linalg.tiled_loop`.
243 SmallVector<StringRef, 2> distributionTypes = {};
244
245 LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
246 distributionTypes.assign(in_start: types.begin(), in_end: types.end());
247 return *this;
248 }
249
250 /// Peel the specified loops.
251 SmallVector<int64_t> peeledLoops;
252
253 LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
254 peeledLoops.clear();
255 peeledLoops.append(in_start: loops.begin(), in_end: loops.end());
256 return *this;
257 }
258};
259
260struct LinalgTilingAndFusionOptions {
261 /// Tile sizes used to tile the root operation.
262 SmallVector<int64_t> tileSizes;
263 LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
264 tileSizes.assign(in_start: ts.begin(), in_end: ts.end());
265 return *this;
266 }
267 /// Tile interchange used to permute the tile loops.
268 SmallVector<int64_t> tileInterchange;
269 /// When specified, specifies distribution of generated tile loops to
270 /// processors.
271 std::optional<LinalgLoopDistributionOptions> tileDistribution;
272 LinalgTilingAndFusionOptions &
273 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
274 tileDistribution = std::move(distributionOptions);
275 return *this;
276 }
277};
278
279struct LinalgPaddingOptions {
280 /// A padding value for every operand.
281 SmallVector<Attribute> paddingValues;
282 LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
283 paddingValues.assign(in_start: pv.begin(), in_end: pv.end());
284 return *this;
285 }
286 /// A list of iterator dimensions to pad.
287 SmallVector<int64_t> paddingDimensions;
288 LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
289 paddingDimensions.assign(in_start: pd.begin(), in_end: pd.end());
290 return *this;
291 }
292 /// A list of multiples to which each padding dimension should be padded to.
293 std::optional<SmallVector<int64_t>> padToMultipleOf;
294 LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) {
295 padToMultipleOf.emplace(args: m.begin(), args: m.end());
296 return *this;
297 }
298 /// A flag for every operand to mark the PadOp as nofold which enables
299 /// packing for statically shaped operands.
300 SmallVector<bool> packPaddings;
301 LinalgPaddingOptions &setPackPaddings(ArrayRef<bool> pp) {
302 packPaddings.assign(in_start: pp.begin(), in_end: pp.end());
303 return *this;
304 }
305 /// A number of loops to hoist the PadOp out for every operand.
306 SmallVector<int64_t> hoistPaddings;
307 LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
308 hoistPaddings.assign(in_start: hp.begin(), in_end: hp.end());
309 return *this;
310 }
311 /// A permutation vector for every operand used to transpose the packed
312 /// PadOp results.
313 SmallVector<SmallVector<int64_t>> transposePaddings;
314 LinalgPaddingOptions &
315 setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
316 transposePaddings.assign(in_start: tp.begin(), in_end: tp.end());
317 return *this;
318 }
319 enum class CopyBackOp : int8_t {
320 None = 0,
321 BufferizationMaterializeInDestination = 1,
322 LinalgCopy = 2
323 };
324 /// The op to be used for copying the padded result to the original
325 /// destination tensor.
326 CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination;
327 LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
328 copyBackOp = op;
329 return *this;
330 }
331};
332
333/// Callback function type used to perform the allocation for the promoted
334/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
335/// smallest constant value for the size of the buffer needed for each
336/// dimension. If that is not possible, contains the dynamic size of the
337/// subview. The call back should return the buffer to use.
338using AllocBufferCallbackFn = std::function<std::optional<Value>(
339 OpBuilder &b, memref::SubViewOp subView,
340 ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;
341
342/// Callback function type used to deallocate the buffers used to hold the
343/// promoted subview.
344using DeallocBufferCallbackFn =
345 std::function<LogicalResult(OpBuilder &b, Value buffer)>;
346
347/// Callback function type used to insert copy from original subview to
348/// subview of the promoted region for the read operands/subview of promoted
349/// region to original subview for the results. The copy has to happen from
350/// `src` to `dst`.
351using CopyCallbackFn =
352 std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
353
354struct LinalgPromotionOptions {
355 /// Indices of subViews to promote. If `std::nullopt`, try to promote all
356 /// operands.
357 std::optional<DenseSet<unsigned>> operandsToPromote;
358 LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
359 operandsToPromote = DenseSet<unsigned>();
360 operandsToPromote->insert(I: operands.begin(), E: operands.end());
361 return *this;
362 }
363 /// If ith element of `useFullTiles` is true the full view should be used
364 /// for the promoted buffer of the ith operand in `operandsToPromote`.
365 /// Otherwise the partial view will be used. The decision is defaulted to
366 /// `useFullTileBuffersDefault` when `useFullTileBuffers` is std::nullopt and
367 /// for operands missing from `useFullTileBuffers`.
368 std::optional<llvm::SmallBitVector> useFullTileBuffers;
369 LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
370 unsigned size = useFullTiles.size();
371 llvm::SmallBitVector tmp(size, false);
372 for (unsigned i = 0; i < size; ++i)
373 tmp[i] = useFullTiles[i];
374 useFullTileBuffers = tmp;
375 return *this;
376 }
377 /// If true all operands unspecified by `useFullTileBuffers` will use the
378 /// full view, otherwise the partial view.
379 bool useFullTileBuffersDefault = false;
380 LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
381 useFullTileBuffersDefault = use;
382 return *this;
383 }
384 /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment.
385 std::optional<unsigned> alignment;
386 LinalgPromotionOptions &setAlignment(unsigned align) {
387 alignment = align;
388 return *this;
389 }
390 /// Memory space of promoted buffer. If `std::nullopt` do not specify memory
391 /// space.
392 std::optional<Attribute> memorySpace;
393 LinalgPromotionOptions &setMemorySpace(Attribute memorySpc) {
394 memorySpace = memorySpc;
395 return *this;
396 }
397 /// Use alloca with the default allocation scheme.
398 bool useAlloca = false;
399 LinalgPromotionOptions &setUseAlloca(bool use) {
400 useAlloca = use;
401 return *this;
402 }
403 /// Callback function to do the allocation of the promoted buffer. If
404 /// std::nullopt, then the default allocation scheme of allocating a
405 /// memref<?xi8> buffer followed by a view operation is used.
406 std::optional<AllocBufferCallbackFn> allocationFn;
407 std::optional<DeallocBufferCallbackFn> deallocationFn;
408 LinalgPromotionOptions &
409 setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
410 DeallocBufferCallbackFn const &deallocFn) {
411 allocationFn = allocFn;
412 deallocationFn = deallocFn;
413 return *this;
414 }
415 /// Callback function to do the copy of data to and from the promoted
416 /// subview. If std::nullopt then a memref.copy is used.
417 std::optional<CopyCallbackFn> copyInFn;
418 std::optional<CopyCallbackFn> copyOutFn;
419 LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
420 CopyCallbackFn const &copyOut) {
421 copyInFn = copyIn;
422 copyOutFn = copyOut;
423 return *this;
424 }
425};
426
427/// Split Reduction options.
428struct SplitReductionOptions {
429 // Ratio used to split the reduction dimension. If the ratio is <= 1,
430 // nothing will be done.
431 int64_t ratio = 0;
432 // Index where the extra dimension is added to the intermediate tensor
433 // shape.
434 unsigned index = 0;
435 // If the inner dimension after splitting is parallel or reduction.
436 bool innerParallel = false;
437};
438
439/// Function signature to control reduction splitting. This returns
440/// `SplitReductionOptions`.
441// TODO: don't use unsigned unless doing bit manipulation.
442using ControlSplitReductionFn =
443 std::function<SplitReductionOptions(LinalgOp op)>;
444
445//===----------------------------------------------------------------------===//
446// Preconditions that ensure the corresponding transformation succeeds and can
447// be applied as a rewrite pattern.
448//===----------------------------------------------------------------------===//
449
450/// Return true if two `linalg.generic` operations with producer/consumer
451/// relationship through `fusedOperand` can be fused using elementwise op
452/// fusion.
453bool areElementwiseOpsFusable(OpOperand *fusedOperand);
454
455/// Promote memref.subviews feeding linalg-on-buffers operations.
456LogicalResult promoteSubviewsPrecondition(Operation *op,
457 LinalgPromotionOptions options);
458
459/// Return success if the operation can be vectorized.
460LogicalResult vectorizeOpPrecondition(Operation *op,
461 ArrayRef<int64_t> inputVectorSizes = {},
462 ArrayRef<bool> inputScalableVecDims = {},
463 bool vectorizeNDExtract = false,
464 bool flatten1DDepthwiseConv = false);
465
466//===----------------------------------------------------------------------===//
467// Transformations exposed as functional-style API calls.
468//===----------------------------------------------------------------------===//
469
470using LinalgLoops = SmallVector<Operation *, 4>;
471
472/// Transformation to drop unit-extent dimensions from `linalg.generic`
473/// operations.
474struct ControlDropUnitDims {
475 enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
476
477 RankReductionStrategy rankReductionStrategy =
478 RankReductionStrategy::ReassociativeReshape;
479
480 using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
481 ControlFnTy controlFn = [](Operation *op) {
482 if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
483 return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
484 }
485 if (auto padOp = dyn_cast_or_null<tensor::PadOp>(op)) {
486 return llvm::to_vector(
487 llvm::seq<unsigned>(0, padOp.getSourceType().getRank()));
488 }
489 return SmallVector<unsigned>{};
490 };
491};
492LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
493 const ControlDropUnitDims &options);
494
495/// Fuse two `linalg.generic` operations that have a producer-consumer
496/// relationship captured through `fusedOperand`. The method expects
497/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
498struct ElementwiseOpFusionResult {
499 Operation *fusedOp;
500 llvm::DenseMap<Value, Value> replacements;
501 static llvm::SmallDenseSet<int>
502 getPreservedProducerResults(GenericOp producer, GenericOp consumer);
503};
504FailureOr<ElementwiseOpFusionResult>
505fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
506
507/// Try to peel and canonicalize loop `op` and return the new result.
508/// Also applies affine_min/max bounds simplification on the fly where relevant.
509// TODO: Add support for scf.parallel and affine.for loops.
510SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
511
512/// Peel 'loops' and applies affine_min/max bounds simplification on the fly
513/// where relevant.
514void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
515
516/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
517/// to a static bounding box. The original `opToPad` is cloned and operates on
518/// the padded tensors.
519///
520/// * "options.padToMultipleOf" indicates that each padding dimension should be
521/// padded to the specified multiple.
522/// * Use "options.paddingValues" and "options.packPaddings" to set padding
523/// value and nofold attribute of the created tensor::PadOps, respectively.
524/// * The unpadded results (extracted slice of the cloned operation) are
525/// returned via `replacements`.
526/// * The tensor::PadOps are returned via `padOps`.
527/// * "options.copyBackOp" specifies the op type for copying back the unpadded
528/// result to the original destination tensor.
529LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
530 const LinalgPaddingOptions &options,
531 LinalgOp &paddedOp,
532 SmallVector<Value> &replacements,
533 SmallVector<tensor::PadOp> &padOps);
534
535namespace detail {
536
537/// Helper struct to hold the results of building a packing loop nest.
538struct PackingResult {
539 SmallVector<OpFoldResult> offsets, sizes, strides;
540 SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
541 GenericOp maybeTransposeOp;
542 tensor::PadOp hoistedPadOp;
543};
544
545/// Build the packing loop nest required to hoist `opToHoist` above
546/// `outermostEnclosingForOp`.
547/// The loop nest is built just before `outermostEnclosingForOp`.
548FailureOr<PackingResult>
549buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist,
550 scf::ForOp outermostEnclosingForOp,
551 ArrayRef<int64_t> transposeVector);
552
553} // namespace detail
554
555/// Mechanically hoist padding operations on tensors by `numLoops` into a new,
556/// generally larger tensor. This achieves packing of multiple padding ops into
557/// a larger tensor. On success, `opToHoist` is replaced by the cloned version
558/// in the packing loop so the caller can continue reasoning about the padding
559/// operation. If `transposeVector` is non-empty, hoist padding introduces a
560/// GenericOp to transpose the padded tensor before inserting it into the packed
561/// tensor. A `transposeVector` can change the storage order of the padded
562/// tensor but does not change the order of the pack or compute loops.
563///
564/// TODO: In the future, we should consider rewriting as a tensor.pack after
565/// hoisting since this abstraction is now available.
566///
567/// Example in pseudo-mlir:
568/// =======================
569///
570/// If hoistPaddingOnTensors is called with `nLoops` = 2 on the following IR.
571/// ```
572/// scf.for (%i, %j, %k)
573/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
574/// %0 = tensor.pad %st0 low[0, 0] high[...] {
575/// ^bb0( ... ):
576/// linalg.yield %pad
577/// } : tensor<?x?xf32> to tensor<4x8xf32>
578/// compute(%0)
579/// ```
580///
581/// IR resembling the following is produced:
582///
583/// ```
584/// scf.for (%i) {
585/// %packed_init = tensor.empty range(%j) : tensor<?x4x8xf32>
586/// %packed = scf.for (%k) iter_args(%p : %packed_init) {
587/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
588/// %0 = tensor.pad %st0 low[0, 0] high[...] {
589/// ^bb0( ... ):
590/// linalg.yield %pad
591/// } : tensor<?x?xf32> to tensor<4x8xf32>
592/// %1 = tensor.insert_slice %0 ...
593/// : tensor<4x8xf32> to tensor<?x4x8xf32>
594/// scf.yield %1: tensor<?x4x8xf32>
595/// } -> tensor<?x4x8xf32>
596/// scf.for (%j, %k) {
597/// %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
598/// tensor<?x4x8xf32> to tensor<4x8xf32>
599/// compute(%st0)
600/// }
601/// }
602/// ```
603FailureOr<Value>
604hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist,
605 int64_t numLoops, ArrayRef<int64_t> transposeVector,
606 tensor::PadOp &hoistedOp,
607 SmallVectorImpl<GenericOp> &transposeOps);
608/// Calls into `hoistPaddingOnTensors` with a local IRRewriter.
609FailureOr<Value>
610hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
611 ArrayRef<int64_t> transposeVector,
612 tensor::PadOp &hoistedOp,
613 SmallVectorImpl<GenericOp> &transposeOps);
614
615/// Apply padding and hoisting to `linalgOp` according to the configuration
616/// specified in `options`.
617FailureOr<LinalgOp> padAndHoistLinalgOp(RewriterBase &rewriter,
618 LinalgOp linalgOp,
619 const LinalgPaddingOptions &options);
620
621/// Split the given `op` into two parts along the given iteration space
622/// `dimension` at the specified `splitPoint`, and return the two parts.
623/// If the second part is statically known to be empty, do not create it
624/// and return nullptr instead. Error state is signalled by returning
625/// a pair of nullptrs.
626///
627/// For example, the following op:
628///
629/// linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
630/// outs(%2 : tensor<128x64xf32>)
631///
632/// split along the first dimension at position 42 will result in:
633///
634/// %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
635/// %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
636/// %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
637/// outs(%5 : tensor<42x64xf32>)
638/// %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
639///
640/// %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
641/// %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
642/// %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
643/// outs(%8 : tensor<86x64xf32>)
644/// tensor.insert_slice %5 into %6[42, 0][86, 64][1, 1]
645///
646/// Note that there is no simplification other than constant propagation applied
647/// to slice extraction and insertion.
648std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter,
649 TilingInterface op,
650 unsigned dimension,
651 OpFoldResult splitPoint);
652
653/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
654/// and permute the loop nest according to `interchangeVector`
655/// The permutation is expressed as a list of integers that specify
656/// the new ordering of the loop nest. The length of `interchangeVector`
657/// must be equal to the length of `tileSizes`.
658/// An empty vector is interpreted as the identity permutation and the
659/// transformation returns early.
660///
661/// Return a struct containing the tiled loops in the specified order
662/// and the cloned op if successful, std::nullopt otherwise.
663///
664/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
665/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
666/// integers, in the range 0..`tileSizes.size()` without duplications
667/// (i.e. `[1,1,2]` is an invalid permutation).
668struct TiledLinalgOp {
669 LinalgOp op;
670 SmallVector<Operation *, 8> loops;
671 SmallVector<Value, 4> tensorResults;
672};
673FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
674 const LinalgTilingOptions &options);
675
676/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts
677/// the index accesses of `op`. This is an in-place transformation controlled
678/// by `interchangeVector`. An empty vector is interpreted as the identity
679/// permutation and the transformation returns early.
680///
681/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
682/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
683/// integers, in the range 0..`op.rank` without duplications
684/// (i.e. `[1,1,2]` is an invalid permutation).
685///
686/// Return failure if the permutation is not valid.
687FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
688 GenericOp genericOp,
689 ArrayRef<unsigned> interchangeVector);
690
691/// Create a GenericOp from the given named operation `namedOp` and replace
692/// namedOp.
693/// Return failure if `namedOp` is a GenericOp or misses a region builder.
694FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
695 LinalgOp namedOp);
696
697/// Create a namedOp from the given GenericOp and replace the GenericOp.
698/// Currently we can specialize only trivial linalg copy operations.
699FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
700 GenericOp genericOp);
701
702/// Create a new buffer using the `allocationFn` provided. The size of this
703/// buffer is the smallest constant bounding size along each dimension that
704/// can be computed for the size of the result of `subView`. Returns the
705/// allocated buffer as `fullLocalView` and the view that matches the size of
706/// the result of subview operation as `partialLocalView`.
707struct PromotionInfo {
708 Value fullLocalView;
709 Value partialLocalView;
710};
711FailureOr<PromotionInfo>
712promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
713 const AllocBufferCallbackFn &allocationFn,
714 DataLayout &layout);
715
716/// Promote the `subViews` into a new buffer allocated at the insertion point
717/// `b`. Promotion occurs in 3 steps:
718/// 1. Create a new buffer for a full tile (i.e. not clipped at the
719/// boundary).
720/// 2. Take a full view on the buffer.
721/// 3. Take a partial slice of the full view in step 2. and copy into it.
722///
723/// Return the modified linalg op (the modification happens in place) as well
724/// as all the copy ops created.
725FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
726 const LinalgPromotionOptions &options);
727
728/// Allocate the subview in the GPU workgroup memory.
729std::optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
730 memref::SubViewOp subview,
731 ArrayRef<Value> sizeBounds,
732 DataLayout &);
733
734/// In case of GPU group memory there is no need to deallocate.
735LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/);
736
737/// Create Memref copy operations and add gpu barrier guards before and after
738/// the copy operation to ensure data integrity.
739LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst);
740
741/// Allocate the subview in the GPU private memory.
742std::optional<Value> allocateGPUPrivateMemory(OpBuilder &builder,
743 memref::SubViewOp subview,
744 ArrayRef<Value> sizeBounds,
745 DataLayout &);
746
747/// Normal copy to between src and dst.
748LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
749
750/// In case of GPU private memory there is no need to deallocate since the
751/// memory is freed when going outside of the scope.
752LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
753
754/// Emit a suitable vector form for an operation. If provided,
755/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
756/// must match the rank of the iteration space of the operation and the sizes
757/// must be smaller or equal than their counterpart interation space sizes, if
758/// static. `inputVectorShapes` also allows the vectorization of operations with
759/// dynamic shapes.
760LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
761 ArrayRef<int64_t> inputVectorSizes = {},
762 ArrayRef<bool> inputScalableVecDims = {},
763 bool vectorizeNDExtract = false,
764 bool flatten1DDepthwiseConv = false);
765
766/// Emit a suitable vector form for a Copy op with fully static shape.
767LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
768
769/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
770FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
771 LinalgOp linalgOp);
772
773/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
774FailureOr<LinalgLoops> linalgOpToParallelLoops(RewriterBase &rewriter,
775 LinalgOp linalgOp);
776
777/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
778FailureOr<LinalgLoops> linalgOpToAffineLoops(RewriterBase &rewriter,
779 LinalgOp linalgOp);
780
781/// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
782/// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument
783/// has one entry per surrounding loop. It uses zero as the convention that a
784/// particular loop is not tiled. This convention simplifies implementations
785/// by avoiding affine map manipulations. The returned ranges correspond to
786/// the loop ranges, in the proper order, that are tiled and for which new
787/// loops will be created. Also the function returns a map from loop indices
788/// of the LinalgOp to the corresponding non-empty range indices of newly
789/// created loops.
790using LoopIndexToRangeIndexMap = DenseMap<int, int>;
791std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
792makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
793 ArrayRef<OpFoldResult> allShapeSizes,
794 ArrayRef<OpFoldResult> allTileSizes);
795
796namespace detail {
797template <typename T>
798struct MultiSizeSpecificationBase {
799 /// Tile sizes.
800 T lowTileSize, highTileSize;
801 /// Number of tiles associated with each size.
802 T lowTripCount, highTripCount;
803};
804} // namespace detail
805
806/// A description of a multi-size tiling comprising tile sizes and numbers of
807/// tiles, expressed as Values which may or may not be constant. Multi-size
808/// currently means two-size.
809struct MultiSizeSpecification
810 : public detail::MultiSizeSpecificationBase<Value> {};
811struct StaticMultiSizeSpecification
812 : public detail::MultiSizeSpecificationBase<int64_t> {};
813
814/// Emits the IR computing the multi-sized tiling specification with two tile
815/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
816/// that there exist numbers of tiles with these sizes that fully cover the
817/// given iteration space `dimension` of the structured `op`.
818///
819/// The computation is as follows:
820///
821/// b = originalTripCount floordiv sizeDivisor
822/// t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
823/// d = (b + t - 1) floordiv t
824/// s = (b floordiv d) * sizeDivisor
825/// v = b % d
826/// u = d - v
827///
828/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
829/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
830///
831/// s * u + (s + sizeDivisor) * v == original size,
832/// where s mod sizeDivisor = 0.
833///
834/// Expects all values to be positive. In some cases with the target tile size
835/// sufficiently close to the dimension shape and non-unit divisor, it is
836/// impossible to compute such sizes. If `emitAssertion` is set, also emit the
837/// assertion that size computation succeeded.
838///
839/// Returns the specification consisting of both tile values and the number of
840/// tiles of each size.
841FailureOr<MultiSizeSpecification>
842computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
843 OpFoldResult targetSize, OpFoldResult divisor,
844 bool emitAssertions = true);
845FailureOr<StaticMultiSizeSpecification>
846computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
847 int64_t divisor);
848
849/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
850/// tiling by `numThreads`.
851/// If non-empty, the `mapping` is added as an attribute to the
852/// resulting `scf.forall`.
853/// Zero tile sizes indicate that the dimension is not tiled, and can be
854/// thought of as tiling by the full size of data. It is the user's
855/// responsibility to ensure that `numThreads` is a valid tiling specification
856/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
857struct ForallTilingResult {
858 Operation *tileOp;
859 Operation *tiledOp;
860};
861FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
862 TilingInterface op,
863 ArrayRef<OpFoldResult> numThreads,
864 std::optional<ArrayAttr> mapping);
865
866/// Same as `tileToForallOp`, but calculate the number of threads
867/// required using the given tileSizes.
868FailureOr<ForallTilingResult>
869tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
870 ArrayRef<OpFoldResult> tileSizes,
871 std::optional<ArrayAttr> mapping);
872
873/// Transformation information returned after reduction tiling.
874struct ForallReductionTilingResult {
875 /// The partial reduction tiled op generated.
876 Operation *parallelTiledOp;
877 /// The final reduction operation merging all the partial reductions.
878 Operation *mergeOp;
879 /// The op initializing the tensor used for partial reductions.
880 Operation *initialOp;
881 /// The `scf.forall` operation that iterate over the tiles.
882 scf::ForallOp loops;
883};
884
885/// Method to tile a reduction to parallel iterations computing partial
886/// reductions. After the loop all the partial reduction are merged into a final
887/// reduction. For example for the following sequence
888///
889/// ```mlir
890/// %0 = linalg.generic %in ["parallel", "reduction"]
891/// : tensor<7x9xf32> -> tensor<7xf32>
892/// ```
893///
894/// into:
895///
896/// ```mlir
897/// %0 = linalg.fill ... : tensor<7x4xf32>
898/// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0)
899/// -> (tensor<7x4xf32>) {
900/// %2 = tensor.extract_slice %arg3 : tensor<7x4xf32> to tensor<7xf32>
901/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
902/// %4 = linalg.generic %2, %3 ["parallel", "reduction"]
903/// : tensor<7x?xf32> -> tensor<7xf32>
904/// %5 = tensor.insert_slice %3, %arg0[0, %iv] : tensor<7x4xf32>
905/// }
906/// %6 = linalg.generic %1 ["parallel", "reduction"]
907/// : tensor<7x4xf32> -> tensor<7xf32>
908/// ```
909FailureOr<ForallReductionTilingResult>
910tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op,
911 ArrayRef<OpFoldResult> numThreads,
912 ArrayRef<OpFoldResult> tileSizes = {},
913 std::optional<ArrayAttr> mapping = std::nullopt);
914
915/// All indices returned by IndexOp should be invariant with respect to
916/// tiling. Therefore, if an operation is tiled, we have to transform the
917/// indices accordingly, i.e. offset them by the values of the corresponding
918/// induction variables that are captured implicitly in the body of the op.
919///
920/// Example. `linalg.generic` before tiling:
921///
922/// #id_2d = (i, j) -> (i, j)
923/// #pointwise_2d_trait = {
924/// indexing_maps = [#id_2d, #id_2d],
925/// iterator_types = ["parallel", "parallel"]
926/// }
927/// linalg.generic #pointwise_2d_trait %operand, %result {
928/// ^bb0(%operand_in: f32, %result_in: f32):
929/// %i = linalg.index 0 : index
930/// %j = linalg.index 1 : index
931/// <some operations that use %i, %j>
932/// }: memref<50x100xf32>, memref<50x100xf32>
933///
934/// After tiling pass with tiles sizes 10 and 25:
935///
936/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
937///
938/// %c1 = arith.constant 1 : index
939/// %c0 = arith.constant 0 : index
940/// %c25 = arith.constant 25 : index
941/// %c10 = arith.constant 10 : index
942/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
943/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
944/// scf.for %k = %c0 to operand_dim_0 step %c10 {
945/// scf.for %l = %c0 to operand_dim_1 step %c25 {
946/// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
947/// : memref<50x100xf32> to memref<?x?xf32, #strided>
948/// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
949/// : memref<50x100xf32> to memref<?x?xf32, #strided>
950/// linalg.generic pointwise_2d_trait %4, %5 {
951/// ^bb0(%operand_in: f32, %result_in: f32):
952/// %i = linalg.index 0 : index
953/// %j = linalg.index 1 : index
954/// // Indices `k` and `l` are implicitly captured in the body.
955/// %transformed_i = arith.addi %i, %k : index // index `i` is offset by
956/// %k %transformed_j = arith.addi %j, %l : index // index `j` is offset
957/// by %l
958/// // Every use of %i, %j is replaced with %transformed_i,
959/// %transformed_j <some operations that use %transformed_i,
960/// %transformed_j>
961/// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
962/// }
963/// }
964///
965/// TODO: Investigate whether mixing implicit and explicit indices
966/// does not lead to losing information.
967void transformIndexOps(RewriterBase &b, LinalgOp op,
968 SmallVectorImpl<Value> &ivs,
969 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
970
971/// Apply transformation to split the single linalg op reduction into a
972/// parallel and reduction dimension. Then create a new linalg.generic op
973/// doing the rest of the reduction. Return the new linalg op with an extra
974/// parallel dimension or failure if the transformation didn't happen.
975///
976/// Example:
977/// ```
978/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
979/// affine_map<(d0) -> ()>],
980/// iterator_types = ["reduction"]}
981/// ins(%in : tensor<32xf32>)
982/// outs(%out : tensor<f32>) {
983/// ^bb0(%arg1: f32, %arg2: f32):
984/// %y = arith.addf %arg1, %arg2 : f32
985/// linalg.yield %y : f32
986/// } -> tensor<f32>
987/// ```
988/// To:
989/// ```
990/// %cst = arith.constant 0.000000e+00 : f32
991/// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into
992/// tensor<4x8xf32> %1 = tensor.empty [4] : tensor<4xf32> %2 = linalg.fill
993/// ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> %3 =
994/// linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
995/// affine_map<(d0, d1) -> (d0)>],
996/// iterator_types = ["parallel", "reduction"]}
997/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
998/// ^bb0(%arg3: f32, %arg5: f32):
999/// %5 = arith.addf %arg3, %arg4 : f32
1000/// linalg.yield %5 : f32
1001/// } -> tensor<4xf32>
1002/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1003/// affine_map<(d0) -> ()>],
1004/// iterator_types = ["reduction"]}
1005/// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
1006/// ^bb0(%arg3: f32, %arg4: f32):
1007/// %5 = arith.addf %arg3, %arg4 : f32
1008/// linalg.yield %5 : f32
1009/// } -> tensor<f32>
1010/// ```
1011struct SplitReductionResult {
1012 Operation *initOrAlloc;
1013 FillOp fillOp;
1014 LinalgOp splitLinalgOp;
1015 LinalgOp resultCombiningLinalgOp;
1016};
1017FailureOr<SplitReductionResult>
1018splitReduction(RewriterBase &b, LinalgOp op,
1019 const ControlSplitReductionFn &controlSplitReductionFn,
1020 bool useAlloc = false);
1021
1022/// Scaling-based implementation of the split reduction transformation.
1023/// Instead of introducing an ExpandShapeOp, this rewrites a reduction
1024/// dimension `k` into `k * scale + kk`.
1025///
1026/// Example:
1027/// ```
1028/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
1029/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
1030/// ```
1031///
1032/// Is transformed to:
1033///
1034/// ```
1035/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
1036/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
1037/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
1038/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1039/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
1040/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
1041/// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32>
1042/// %cst = arith.constant 0.000000e+00 : f32
1043/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
1044/// tensor<16x32x64xf32>
1045/// %2 = tensor.empty [64, 4] : tensor<64x4xi1>
1046///
1047/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
1048/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
1049/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>,
1050/// tensor<64x4xi1>)
1051/// outs(%1 : tensor<16x32x64xf32>) {
1052/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
1053/// %5 = arith.mulf %arg3, %arg4 : f32
1054/// %6 = arith.addf %arg6, %5 : f32
1055/// linalg.yield %6 : f32
1056/// } -> tensor<16x32x64xf32>
1057///
1058/// %4 = linalg.generic {indexing_maps = [#map4, #map5],
1059/// iterator_types = ["parallel", "parallel", "reduction"]}
1060// ins(%3 : tensor<16x32x64xf32>)
1061/// outs(%C : tensor<16x32xf32>) {
1062/// ^bb0(%arg3: f32, %arg4: f32):
1063/// %5 = arith.addf %arg3, %arg4 : f32
1064/// linalg.yield %5 : f32
1065/// } -> tensor<16x32xf32>
1066///
1067/// return %4 : tensor<16x32xf32>
1068/// ```
1069FailureOr<SplitReductionResult>
1070splitReductionByScaling(RewriterBase &b, LinalgOp op,
1071 const ControlSplitReductionFn &controlSplitReductionFn,
1072 bool useAlloc = false);
1073
1074/// Return `true` if a given sequence of dimensions are contiguous in the
1075/// range of the specified indexing map.
1076bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
1077/// Return `true` if all sequences of dimensions specified in `dimSequences` are
1078/// contiguous in all the ranges of the `maps`.
1079bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
1080 ArrayRef<ReassociationIndices> dimSequences);
1081
1082struct CollapseResult {
1083 SmallVector<Value> results;
1084 LinalgOp collapsedOp;
1085};
1086
1087/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
1088/// to calling this method is that for each list in `foldedIterationDim`, the
1089/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
1090/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
1091/// When valid, the method also collapses the operands of the op. Returns
1092/// replacement values of the results of the original `linalgOp` by inserting
1093/// reshapes to get back values of compatible types.
1094FailureOr<CollapseResult>
1095collapseOpIterationDims(LinalgOp op,
1096 ArrayRef<ReassociationIndices> foldedIterationDims,
1097 RewriterBase &rewriter);
1098
1099struct LowerPackResult {
1100 tensor::PadOp padOp;
1101 tensor::ExpandShapeOp expandShapeOp;
1102 linalg::TransposeOp transposeOp;
1103};
1104
1105/// Rewrite pack as pad + reshape + transpose.
1106FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
1107 tensor::PackOp packOp);
1108
1109struct LowerUnPackOpResult {
1110 tensor::EmptyOp emptyOp;
1111 linalg::TransposeOp transposeOp;
1112 tensor::CollapseShapeOp collapseShapeOp;
1113 tensor::ExtractSliceOp extractSliceOp;
1114};
1115
1116/// Rewrite pack as empty + transpose + reshape + extract_slice.
1117FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
1118 tensor::UnPackOp unPackOp);
1119
1120/// Struct to hold the result of a `pack` call.
1121struct PackResult {
1122 SmallVector<tensor::PackOp> packOps;
1123 linalg::LinalgOp packedLinalgOp;
1124 SmallVector<tensor::UnPackOp> unPackOps;
1125};
1126/// Implement packing of a single LinalgOp by `packedSizes`.
1127/// There must be one packedSizes entry per `linalgOp` iterator.
1128/// Return the packed Linalg op on success, failure otherwise.
1129FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
1130 ArrayRef<OpFoldResult> packedSizes);
1131
1132/// Struct to hold the result of a `packTranspose` call.
1133struct PackTransposeResult {
1134 tensor::PackOp transposedPackOp;
1135 linalg::LinalgOp transposedLinalgOp;
1136 tensor::UnPackOp transposedUnPackOp;
1137};
1138/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
1139/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
1140/// Return failure if either:
1141/// 1. the `packOp` does not have the `linalgOp` as its unique use.
1142/// 2. the `maybeUnPackOp`, if specified must be a consumer of the result tied
1143/// to the unique `packOp` use.
1144/// 3. `outerPerm` (resp. `innerPerm`) must be valid permutations of
1145/// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty.
1146FailureOr<PackTransposeResult>
1147packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
1148 linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
1149 ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
1150
1151/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
1152/// and n are proper parallel dimensions and k is a proper reduction
1153/// dimension. Packing occurs by rewriting the op as a linalg.generic and
1154/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
1155/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
1156/// to reorder {m, n, k} into one of the 8 possible forms. The outer
1157/// dimensions of the operands are not permuted at this time, this is left for
1158/// future work.
1159FailureOr<PackResult>
1160packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
1161 ArrayRef<OpFoldResult> mnkPackedSizes,
1162 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
1163 ArrayRef<int64_t> mnkOrder);
1164
1165/// Rewrite tensor.from_elements to linalg.generic.
1166FailureOr<Operation *>
1167rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1168 tensor::FromElementsOp fromElementsOp);
1169
1170/// Rewrite tensor.generate to linalg.generic.
1171FailureOr<Operation *>
1172rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1173 tensor::GenerateOp generateOp);
1174
1175/// Rewrite tensor.pad to linalg.generic + tensor.insert_slice.
1176FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
1177 tensor::PadOp padOp);
1178
1179/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
1180/// and linalg.matmul.
1181///
1182/// A convolution operation can be written as a matrix-matrix multiplication by
1183/// unfolding the cross-correlation between input and filter and explicitly copy
1184/// overlapped sliding window inputs.
1185///
1186/// Consider 2D input X with single channel input and output and 2x2 filter W:
1187/// [x(0, 0) , x(0, 1) , ..., x(0, n) ]
1188/// [x(1, 0) , x(1, 1) , ..., x(1, n) ]
1189/// [. , . ,. , . ] [w(0, 0), w(0, 1)]
1190/// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)]
1191/// [. , . , ., . ]
1192/// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)]
1193///
1194/// The packed input data (img2col) is a matrix with |rows| = output spatial
1195/// size, |columns| = filter spatial size. To compute the output Y(i, j) we need
1196/// to calculate the dot product between filter window at input X(x, y)) and the
1197/// filter which will look like the following where r.h.s is the img2col matrix
1198/// and l.h.s is the flattened filter:
1199///
1200/// [x(0,0), x(0,1), x(1,0), x(1,1)]
1201/// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)]
1202/// [x(0,1), x(1,1), x(0,2), x(1,2)]
1203/// [ . , . , . , . ]
1204///
1205/// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter
1206/// and output (N, Ho, Wo, D) the convolution is the following matrix-matrix
1207/// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in
1208/// the N input. For the case where N > 1 its a batched matrix-matrix
1209/// multiplication.
1210///
1211/// On success, return both the operation that produces the img2col tensor and
1212/// the final operation of the sequence that replaces the original convolution.
1213FailureOr<std::pair<Operation *, Operation *>>
1214rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
1215
1216/// Same as the above but for Fhwc channel orderings in the filter. In this case
1217/// the matrix multiplication is actually a row-wise dot-product rather than a
1218/// row-column dot-product. This is to avoid transposing the filter matrix which
1219/// would be required for a regular matrix multiplication to produce the correct
1220/// output dimensions.
1221FailureOr<std::pair<Operation *, Operation *>>
1222rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
1223
1224/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
1225/// reduction among the input channels so each convolution can be a
1226/// matrix-vector product and by transposing both input filter so channels are
1227/// outer most the computation is a batched matrix-vector product.
1228FailureOr<std::pair<Operation *, Operation *>>
1229rewriteInIm2Col(RewriterBase &rewriter,
1230 linalg::DepthwiseConv2DNhwcHwcOp convOp);
1231
1232/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the
1233/// channels are to the left of the image shape dimensions, the position of the
1234/// contraction dimension in the resulting matmul is reversed. This swaps the
1235/// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) *
1236/// (C x Kh x Kw, Ho x Wo))
1237FailureOr<std::pair<Operation *, Operation *>>
1238rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
1239
1240/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by
1241/// materializing transpose.
1242FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1243 linalg::Conv2DNhwcFhwcOp op);
1244FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
1245 linalg::Conv2DNhwcFhwcQOp op);
1246
1247/// Convert Linalg matmul ops to transposed variants.
1248FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
1249 linalg::MatmulOp op,
1250 bool transposeLHS = true);
1251FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
1252 linalg::BatchMatmulOp op,
1253 bool transposeLHS = true);
1254
1255//===----------------------------------------------------------------------===//
1256// Rewrite patterns wrapping transformations.
1257// TODO: every single such pattern should be a close to noop wrapper around a
1258// functional-stye API call.
1259//===----------------------------------------------------------------------===//
1260
1261/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
1262/// convolution ops.
1263template <typename Conv2DOp, typename Conv1DOp>
1264struct DownscaleSizeOneWindowed2DConvolution final
1265 : public OpRewritePattern<Conv2DOp> {
1266 using OpRewritePattern<Conv2DOp>::OpRewritePattern;
1267
1268 FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
1269 PatternRewriter &rewriter) const;
1270
1271 LogicalResult matchAndRewrite(Conv2DOp convOp,
1272 PatternRewriter &rewriter) const override {
1273 return returningMatchAndRewrite(convOp, rewriter);
1274 }
1275};
1276
1277extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1278 Conv1DNwcWcfOp>;
1279extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1280 Conv1DNcwFcwOp>;

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}

  FailureOr<DepthwiseConv1DNwcWcOp>
  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                           PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};

struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
  DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DOp>(context, benefit) {}

  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
                                               PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }
};

///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// See `generalization` for more details.
//
// TODO: Automatic default pattern class that just unwraps a function
// returning FailureOr<GenericOp>.
struct LinalgGeneralizationPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<GenericOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
    return generalizeNamedOp(rewriter, op);
  }

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }
};
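
// A hedged usage sketch: registering the generalization pattern and running it
// with the greedy rewrite driver (assumes this runs inside a pass with access
// to `funcOp` and its MLIRContext `ctx`):
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgGeneralizationPattern>(ctx);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();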

/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

using OptimizeCopyFn =
    std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>;

/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
/// An `OptimizeCopyFn` can be provided to customize the optimization of the
/// copy step.
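///
/// A rough sketch of the rewrite (schematic IR; sizes, offsets and the exact
/// padding-value region are elided):
/// ```
///   %0 = tensor.pad %t ... { tensor.yield %cst }
/// ```
/// becomes
/// ```
///   %empty = tensor.empty(...)
///   %filled = linalg.fill ins(%cst) outs(%empty)
///   %0 = tensor.insert_slice %t into %filled[...] [...] [...]
/// ```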
struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  GeneralizePadOpPattern(MLIRContext *context,
                         OptimizeCopyFn optimizeCopyFn = nullptr,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        optimizeCopyFn(std::move(optimizeCopyFn)) {}
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  OptimizeCopyFn optimizeCopyFn;
  Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
/// tensor.insert_slice ops, for the case where all of the tensor::PackOp's
/// outer dimensions are 1.
struct GeneralizeOuterUnitDimsPackOpPattern
    : public OpRewritePattern<tensor::PackOp> {
  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice +
/// transpose + insert_slice ops, for the case where all of the
/// tensor::UnPackOp's outer dimensions are 1.
struct GeneralizeOuterUnitDimsUnPackOpPattern
    : public OpRewritePattern<tensor::UnPackOp> {
  using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = memref.view %alloc ...
/// %subView = subview %allocOrView ...
/// [optional] linalg.fill(%allocOrView, %cst) ...
/// ...
/// memref.copy(%in, %subView) ...
/// vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] %subView = subview %allocOrView ...
/// ...
/// vector.transfer_read %in[...], %cst ...
/// ```
/// This rewrite applies only when there is no interleaved use between
/// memref.copy and transfer_read, and no interleaved use between linalg.fill
/// and memref.copy (if linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = memref.view %alloc ...
/// %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %allocOrView[...]
/// memref.copy(%subView, %out)
/// ```
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %out[...]
/// ```
/// This rewrite applies only when there is no interleaved use between
/// transfer_write and memref.copy.
/// This is a custom rewrite to forward partial writes to
/// vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  /// A function to control pattern application and rewrite logic.
  ///
  /// The function will be given the slice op and should return:
  /// - std::nullopt: to fail the match and not apply the pattern;
  /// - true: to apply the pattern with zero slice guard;
  /// - false: to apply the pattern without zero slice guard.
  ///
  /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice
  /// guard.
  using ControlFn = std::function<std::optional<bool>(tensor::ExtractSliceOp)>;

  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
                                     ControlFn controlFn = nullptr,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override;

private:
  ControlFn controlFn;
};
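
// A hedged example of a control function for the pattern above: only swap when
// the slice source has a static shape, and request the zero-slice guard
// (illustrative only; `ctx` and `patterns` come from the surrounding code):
//
//   auto controlFn =
//       [](tensor::ExtractSliceOp sliceOp) -> std::optional<bool> {
//     if (!sliceOp.getSourceType().hasStaticShape())
//       return std::nullopt; // Fail the match.
//     return true;           // Apply with the zero-slice guard.
//   };
//   patterns.add<ExtractSliceOfPadTensorSwapPattern>(ctx, controlFn);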

//===----------------------------------------------------------------------===//
// Populate functions.
//===----------------------------------------------------------------------===//

/// Canonicalization patterns relevant to apply after tiling patterns. These
/// are applied automatically by the tiling pass but need to be applied
/// manually when tiling is called programmatically.
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
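
// A hedged sketch of applying these canonicalizations after programmatic
// tiling (assumes `ctx` and a `funcOp` produced by the tiling code):
//
//   RewritePatternSet patterns =
//       getLinalgTilingCanonicalizationPatterns(ctx);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));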

/// Linalg generalization patterns

/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);

/// Linalg convolution decomposition patterns

/// Populates patterns to decompose high-D convolution ops into low-D ones.
/// This is a step in the progressive lowering of convolution ops; afterwards,
/// the low-D convolution ops can be vectorized.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Populates patterns to transform linalg.conv_2d_xxx operations into
/// linalg.generic (for img2col packing) and linalg.matmul.
/// \see rewriteInIm2Col for more details.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that vectorize tensor.pad.
/// These patterns are meant to apply in a complementary fashion. Benefits
/// are used to encode a certain ordering of pattern application. To avoid
/// scattering magic constants throughout the code base, the patterns must be
/// added with this function. `baseBenefit` can be used to offset the benefit
/// of all tensor::PadOp vectorization patterns by a certain value.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                        PatternBenefit baseBenefit = 1);

/// Populate patterns for splitting a `LinalgOp` with multiple statements within
/// its payload into multiple `GenericOp`s that each have a single statement.
/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
/// and results from the generated decomposed ops. It defaults to `true` since
/// the core decomposition patterns rely on these clean-up patterns; it is set
/// to `false` only for testing purposes.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
                                       bool removeDeadArgsAndResults = true);

/// Populate patterns that convert non-destination-style ops to destination
/// style ops.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns);

/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// the progressive lowering of convolution ops; it assumes that high-D
/// convolution ops have already been decomposed.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
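
// A hedged sketch of the progressive lowering described above: decompose
// high-D convolutions first, then vectorize the resulting low-D ones (assumes
// `ctx` and `funcOp`; each stage is applied with the greedy rewrite driver):
//
//   RewritePatternSet decompose(ctx);
//   populateDecomposeConvolutionPatterns(decompose);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(decompose));
//
//   RewritePatternSet vectorize(ctx);
//   populateConvolutionVectorizationPatterns(vectorize);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorize));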

/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);

/// Populate patterns that are only useful in the context of sparse tensors.
void populateSparseTensorRewriting(RewritePatternSet &patterns);

/// Function type which is used to control when to stop fusion. It is expected
/// that OpOperand is not modified in the callback. The OpOperand is not marked
/// as const to allow callers to use non-const methods.
using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;

/// Patterns for fusing linalg operations on tensors.

/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpFusion);
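
// A hedged example of a fusion control callback that only fuses when the
// producer has a single use (illustrative only; `patterns` is assumed, and
// stricter profitability checks are typically added here):
//
//   ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   populateElementwiseOpsFusionPatterns(patterns, controlFn);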

/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
using ControlPropagationFn = std::function<bool(Operation *op)>;

/// Patterns to bubble up or down data layout ops across other operations.
void populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation);
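
// A hedged example of a propagation control callback that only propagates
// pack/unpack layouts across linalg.generic ops (illustrative only):
//
//   ControlPropagationFn control = [](Operation *op) {
//     return isa<linalg::GenericOp>(op);
//   };
//   populateDataLayoutPropagationPatterns(patterns, control);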

/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is effectively DCE for a linalg op.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

/// Patterns to promote inputs to outputs and remove unused inputs of
/// `linalg.generic` ops.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);

/// Function type to control generic op dimension collapsing. It is expected
/// to return an array of `ReassociationIndices` representing dimensions that
/// should be merged.
using GetCollapsableDimensionsFn =
    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;

/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
/// tensor operands when needed and expand back the result tensors.
void populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions);
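
// A hedged example of a collapsing callback that asks to merge the two
// innermost loops of every op (illustrative only; real callbacks usually check
// iterator types and indexing maps before collapsing):
//
//   GetCollapsableDimensionsFn collapseFn =
//       [](linalg::LinalgOp op) -> SmallVector<ReassociationIndices> {
//     unsigned numLoops = op.getNumLoops();
//     if (numLoops < 2)
//       return {};
//     return {ReassociationIndices{numLoops - 2, numLoops - 1}};
//   };
//   populateCollapseDimensions(patterns, collapseFn);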

/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to fold an expanding tensor.expand_shape operation with its
/// producer generic operation by collapsing the dimensions of the generic op.
void populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to constant fold Linalg operations.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                          const ControlFusionFn &controlFn);

/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
    RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors via reassociative reshape ops.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
                                        ControlDropUnitDims &options);

/// A pattern that converts init operands to input operands.
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);

/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);

/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);

/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
/// linalg.fill(%cst, tensor.extract_slice(%init)).
void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    bool useAlloc = false);

/// Patterns to convert Linalg matmul ops to transposed variants.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                     bool transposeLHS = true);

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H