1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "AffineExprDetail.h"
12#include "mlir/IR/AffineExpr.h"
13#include "mlir/IR/AffineExprVisitor.h"
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/Support/MathExtras.h"
17#include "mlir/Support/TypeID.h"
18#include "llvm/ADT/STLExtras.h"
19#include <numeric>
20#include <optional>
21
22using namespace mlir;
23using namespace mlir::detail;
24
25MLIRContext *AffineExpr::getContext() const { return expr->context; }
26
27AffineExprKind AffineExpr::getKind() const { return expr->kind; }
28
29/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
30/// method to help handle lambda walk functions. Users should use the regular
31/// (non-static) `walk` method.
32template <typename WalkRetTy>
33WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
34 function_ref<WalkRetTy(AffineExpr)> callback) {
35 struct AffineExprWalker
36 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
37 function_ref<WalkRetTy(AffineExpr)> callback;
38
39 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
40 : callback(callback) {}
41
42 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
43 return callback(expr);
44 }
45 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
46 return callback(expr);
47 }
48 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
49 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
50 };
51
52 return AffineExprWalker(callback).walkPostOrder(e);
53}
54// Explicitly instantiate for the two supported return types.
55template void mlir::AffineExpr::walk(AffineExpr e,
56 function_ref<void(AffineExpr)> callback);
57template WalkResult
58mlir::AffineExpr::walk(AffineExpr e,
59 function_ref<WalkResult(AffineExpr)> callback);
60
61// Dispatch affine expression construction based on kind.
62AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
63 AffineExpr rhs) {
64 if (kind == AffineExprKind::Add)
65 return lhs + rhs;
66 if (kind == AffineExprKind::Mul)
67 return lhs * rhs;
68 if (kind == AffineExprKind::FloorDiv)
69 return lhs.floorDiv(other: rhs);
70 if (kind == AffineExprKind::CeilDiv)
71 return lhs.ceilDiv(other: rhs);
72 if (kind == AffineExprKind::Mod)
73 return lhs % rhs;
74
75 llvm_unreachable("unknown binary operation on affine expressions");
76}
77
78/// This method substitutes any uses of dimensions and symbols (e.g.
79/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
80AffineExpr
81AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
82 ArrayRef<AffineExpr> symReplacements) const {
83 switch (getKind()) {
84 case AffineExprKind::Constant:
85 return *this;
86 case AffineExprKind::DimId: {
87 unsigned dimId = llvm::cast<AffineDimExpr>(Val: *this).getPosition();
88 if (dimId >= dimReplacements.size())
89 return *this;
90 return dimReplacements[dimId];
91 }
92 case AffineExprKind::SymbolId: {
93 unsigned symId = llvm::cast<AffineSymbolExpr>(Val: *this).getPosition();
94 if (symId >= symReplacements.size())
95 return *this;
96 return symReplacements[symId];
97 }
98 case AffineExprKind::Add:
99 case AffineExprKind::Mul:
100 case AffineExprKind::FloorDiv:
101 case AffineExprKind::CeilDiv:
102 case AffineExprKind::Mod:
103 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
104 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
105 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
106 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
107 if (newLHS == lhs && newRHS == rhs)
108 return *this;
109 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
110 }
111 llvm_unreachable("Unknown AffineExpr");
112}
113
114AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
115 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
116}
117
118AffineExpr
119AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
120 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements);
121}
122
123/// Replace dims[offset ... numDims)
124/// by dims[offset + shift ... shift + numDims).
125AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
126 unsigned offset) const {
127 SmallVector<AffineExpr, 4> dims;
128 for (unsigned idx = 0; idx < offset; ++idx)
129 dims.push_back(Elt: getAffineDimExpr(position: idx, context: getContext()));
130 for (unsigned idx = offset; idx < numDims; ++idx)
131 dims.push_back(Elt: getAffineDimExpr(position: idx + shift, context: getContext()));
132 return replaceDimsAndSymbols(dimReplacements: dims, symReplacements: {});
133}
134
135/// Replace symbols[offset ... numSymbols)
136/// by symbols[offset + shift ... shift + numSymbols).
137AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
138 unsigned offset) const {
139 SmallVector<AffineExpr, 4> symbols;
140 for (unsigned idx = 0; idx < offset; ++idx)
141 symbols.push_back(Elt: getAffineSymbolExpr(position: idx, context: getContext()));
142 for (unsigned idx = offset; idx < numSymbols; ++idx)
143 symbols.push_back(Elt: getAffineSymbolExpr(position: idx + shift, context: getContext()));
144 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements: symbols);
145}
146
147/// Sparse replace method. Return the modified expression tree.
148AffineExpr
149AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
150 auto it = map.find(Val: *this);
151 if (it != map.end())
152 return it->second;
153 switch (getKind()) {
154 default:
155 return *this;
156 case AffineExprKind::Add:
157 case AffineExprKind::Mul:
158 case AffineExprKind::FloorDiv:
159 case AffineExprKind::CeilDiv:
160 case AffineExprKind::Mod:
161 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
162 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
163 auto newLHS = lhs.replace(map);
164 auto newRHS = rhs.replace(map);
165 if (newLHS == lhs && newRHS == rhs)
166 return *this;
167 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
168 }
169 llvm_unreachable("Unknown AffineExpr");
170}
171
172/// Sparse replace method. Return the modified expression tree.
173AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
174 DenseMap<AffineExpr, AffineExpr> map;
175 map.insert(KV: std::make_pair(x&: expr, y&: replacement));
176 return replace(map);
177}
178/// Returns true if this expression is made out of only symbols and
179/// constants (no dimensional identifiers).
180bool AffineExpr::isSymbolicOrConstant() const {
181 switch (getKind()) {
182 case AffineExprKind::Constant:
183 return true;
184 case AffineExprKind::DimId:
185 return false;
186 case AffineExprKind::SymbolId:
187 return true;
188
189 case AffineExprKind::Add:
190 case AffineExprKind::Mul:
191 case AffineExprKind::FloorDiv:
192 case AffineExprKind::CeilDiv:
193 case AffineExprKind::Mod: {
194 auto expr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
195 return expr.getLHS().isSymbolicOrConstant() &&
196 expr.getRHS().isSymbolicOrConstant();
197 }
198 }
199 llvm_unreachable("Unknown AffineExpr");
200}
201
202/// Returns true if this is a pure affine expression, i.e., multiplication,
203/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
204bool AffineExpr::isPureAffine() const {
205 switch (getKind()) {
206 case AffineExprKind::SymbolId:
207 case AffineExprKind::DimId:
208 case AffineExprKind::Constant:
209 return true;
210 case AffineExprKind::Add: {
211 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
212 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
213 }
214
215 case AffineExprKind::Mul: {
216 // TODO: Canonicalize the constants in binary operators to the RHS when
217 // possible, allowing this to merge into the next case.
218 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
219 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
220 (llvm::isa<AffineConstantExpr>(Val: op.getLHS()) ||
221 llvm::isa<AffineConstantExpr>(Val: op.getRHS()));
222 }
223 case AffineExprKind::FloorDiv:
224 case AffineExprKind::CeilDiv:
225 case AffineExprKind::Mod: {
226 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
227 return op.getLHS().isPureAffine() &&
228 llvm::isa<AffineConstantExpr>(Val: op.getRHS());
229 }
230 }
231 llvm_unreachable("Unknown AffineExpr");
232}
233
234// Returns the greatest known integral divisor of this affine expression.
235int64_t AffineExpr::getLargestKnownDivisor() const {
236 AffineBinaryOpExpr binExpr(nullptr);
237 switch (getKind()) {
238 case AffineExprKind::DimId:
239 [[fallthrough]];
240 case AffineExprKind::SymbolId:
241 return 1;
242 case AffineExprKind::CeilDiv:
243 [[fallthrough]];
244 case AffineExprKind::FloorDiv: {
245 // If the RHS is a constant and divides the known divisor on the LHS, the
246 // quotient is a known divisor of the expression.
247 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
248 auto rhs = llvm::dyn_cast<AffineConstantExpr>(Val: binExpr.getRHS());
249 // Leave alone undefined expressions.
250 if (rhs && rhs.getValue() != 0) {
251 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
252 if (lhsDiv % rhs.getValue() == 0)
253 return lhsDiv / rhs.getValue();
254 }
255 return 1;
256 }
257 case AffineExprKind::Constant:
258 return std::abs(i: llvm::cast<AffineConstantExpr>(Val: *this).getValue());
259 case AffineExprKind::Mul: {
260 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
261 return binExpr.getLHS().getLargestKnownDivisor() *
262 binExpr.getRHS().getLargestKnownDivisor();
263 }
264 case AffineExprKind::Add:
265 [[fallthrough]];
266 case AffineExprKind::Mod: {
267 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
268 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
269 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
270 }
271 }
272 llvm_unreachable("Unknown AffineExpr");
273}
274
275bool AffineExpr::isMultipleOf(int64_t factor) const {
276 AffineBinaryOpExpr binExpr(nullptr);
277 uint64_t l, u;
278 switch (getKind()) {
279 case AffineExprKind::SymbolId:
280 [[fallthrough]];
281 case AffineExprKind::DimId:
282 return factor * factor == 1;
283 case AffineExprKind::Constant:
284 return llvm::cast<AffineConstantExpr>(Val: *this).getValue() % factor == 0;
285 case AffineExprKind::Mul: {
286 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
287 // It's probably not worth optimizing this further (to not traverse the
288 // whole sub-tree under - it that would require a version of isMultipleOf
289 // that on a 'false' return also returns the largest known divisor).
290 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
291 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
292 (l * u) % factor == 0;
293 }
294 case AffineExprKind::Add:
295 case AffineExprKind::FloorDiv:
296 case AffineExprKind::CeilDiv:
297 case AffineExprKind::Mod: {
298 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
299 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
300 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
301 factor ==
302 0;
303 }
304 }
305 llvm_unreachable("Unknown AffineExpr");
306}
307
308bool AffineExpr::isFunctionOfDim(unsigned position) const {
309 if (getKind() == AffineExprKind::DimId) {
310 return *this == mlir::getAffineDimExpr(position, context: getContext());
311 }
312 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
313 return expr.getLHS().isFunctionOfDim(position) ||
314 expr.getRHS().isFunctionOfDim(position);
315 }
316 return false;
317}
318
319bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
320 if (getKind() == AffineExprKind::SymbolId) {
321 return *this == mlir::getAffineSymbolExpr(position, context: getContext());
322 }
323 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
324 return expr.getLHS().isFunctionOfSymbol(position) ||
325 expr.getRHS().isFunctionOfSymbol(position);
326 }
327 return false;
328}
329
330AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
331 : AffineExpr(ptr) {}
332AffineExpr AffineBinaryOpExpr::getLHS() const {
333 return static_cast<ImplType *>(expr)->lhs;
334}
335AffineExpr AffineBinaryOpExpr::getRHS() const {
336 return static_cast<ImplType *>(expr)->rhs;
337}
338
339AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
340unsigned AffineDimExpr::getPosition() const {
341 return static_cast<ImplType *>(expr)->position;
342}
343
344/// Returns true if the expression is divisible by the given symbol with
345/// position `symbolPos`. The argument `opKind` specifies here what kind of
346/// division or mod operation called this division. It helps in implementing the
347/// commutative property of the floordiv and ceildiv operations. If the argument
348///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
349/// operation, then the commutative property can be used otherwise, the floordiv
350/// operation is not divisible. The same argument holds for ceildiv operation.
351static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
352 AffineExprKind opKind) {
353 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
354 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
355 opKind == AffineExprKind::CeilDiv) &&
356 "unexpected opKind");
357 switch (expr.getKind()) {
358 case AffineExprKind::Constant:
359 return cast<AffineConstantExpr>(Val&: expr).getValue() == 0;
360 case AffineExprKind::DimId:
361 return false;
362 case AffineExprKind::SymbolId:
363 return (cast<AffineSymbolExpr>(Val&: expr).getPosition() == symbolPos);
364 // Checks divisibility by the given symbol for both operands.
365 case AffineExprKind::Add: {
366 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
367 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind) &&
368 isDivisibleBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind);
369 }
370 // Checks divisibility by the given symbol for both operands. Consider the
371 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
372 // this is a division by s1 and both the operands of modulo are divisible by
373 // s1 but it is not divisible by s1 always. The third argument is
374 // `AffineExprKind::Mod` for this reason.
375 case AffineExprKind::Mod: {
376 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
377 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos,
378 opKind: AffineExprKind::Mod) &&
379 isDivisibleBySymbol(expr: binaryExpr.getRHS(), symbolPos,
380 opKind: AffineExprKind::Mod);
381 }
382 // Checks if any of the operand divisible by the given symbol.
383 case AffineExprKind::Mul: {
384 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
385 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind) ||
386 isDivisibleBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind);
387 }
388 // Floordiv and ceildiv are divisible by the given symbol when the first
389 // operand is divisible, and the affine expression kind of the argument expr
390 // is same as the argument `opKind`. This can be inferred from commutative
391 // property of floordiv and ceildiv operations and are as follow:
392 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
393 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
394 // It will fail if operations are not same. For example:
395 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
396 case AffineExprKind::FloorDiv:
397 case AffineExprKind::CeilDiv: {
398 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
399 if (opKind != expr.getKind())
400 return false;
401 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind());
402 }
403 }
404 llvm_unreachable("Unknown AffineExpr");
405}
406
407/// Divides the given expression by the given symbol at position `symbolPos`. It
408/// considers the divisibility condition is checked before calling itself. A
409/// null expression is returned whenever the divisibility condition fails.
410static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
411 AffineExprKind opKind) {
412 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
413 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
414 opKind == AffineExprKind::CeilDiv) &&
415 "unexpected opKind");
416 switch (expr.getKind()) {
417 case AffineExprKind::Constant:
418 if (cast<AffineConstantExpr>(Val&: expr).getValue() != 0)
419 return nullptr;
420 return getAffineConstantExpr(constant: 0, context: expr.getContext());
421 case AffineExprKind::DimId:
422 return nullptr;
423 case AffineExprKind::SymbolId:
424 return getAffineConstantExpr(constant: 1, context: expr.getContext());
425 // Dividing both operands by the given symbol.
426 case AffineExprKind::Add: {
427 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
428 return getAffineBinaryOpExpr(
429 kind: expr.getKind(), lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind),
430 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind));
431 }
432 // Dividing both operands by the given symbol.
433 case AffineExprKind::Mod: {
434 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
435 return getAffineBinaryOpExpr(
436 kind: expr.getKind(),
437 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
438 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind: expr.getKind()));
439 }
440 // Dividing any of the operand by the given symbol.
441 case AffineExprKind::Mul: {
442 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
443 if (!isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind))
444 return binaryExpr.getLHS() *
445 symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind);
446 return symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind) *
447 binaryExpr.getRHS();
448 }
449 // Dividing first operand only by the given symbol.
450 case AffineExprKind::FloorDiv:
451 case AffineExprKind::CeilDiv: {
452 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
453 return getAffineBinaryOpExpr(
454 kind: expr.getKind(),
455 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
456 rhs: binaryExpr.getRHS());
457 }
458 }
459 llvm_unreachable("Unknown AffineExpr");
460}
461
462/// Populate `result` with all summand operands of given (potentially nested)
463/// addition. If the given expression is not an addition, just populate the
464/// expression itself.
465/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
466static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
467 auto addExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
468 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
469 result.push_back(Elt: expr);
470 return;
471 }
472 getSummandExprs(expr: addExpr.getLHS(), result);
473 getSummandExprs(expr: addExpr.getRHS(), result);
474}
475
476/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
477/// If so, also return the non-negated expression via `expr`.
478static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
479 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: candidate);
480 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
481 return false;
482 if (auto lhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getLHS())) {
483 if (lhs.getValue() == -1) {
484 expr = mulExpr.getRHS();
485 return true;
486 }
487 }
488 if (auto rhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getRHS())) {
489 if (rhs.getValue() == -1) {
490 expr = mulExpr.getLHS();
491 return true;
492 }
493 }
494 return false;
495}
496
497/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
498/// the fact that `lhs` contains another modulo expression that ensures that
499/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
500/// after loop peeling.
501///
502/// Example: lhs = ub - ub % step
503/// rhs = step
504/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
505static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
506 unsigned numDims, unsigned numSymbols) {
507 // TODO: Try to unify this function with `getBoundForAffineExpr`.
508 // Collect all summands in lhs.
509 SmallVector<AffineExpr> summands;
510 getSummandExprs(expr: lhs, result&: summands);
511 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
512 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
513 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
514 AffineExpr current = summands[i];
515 AffineExpr beforeNegation;
516 if (!isNegatedAffineExpr(candidate: current, expr&: beforeNegation))
517 continue;
518 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(Val&: beforeNegation);
519 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
520 continue;
521 if (innerMod.getRHS() != rhs)
522 continue;
523 // Sum all remaining summands and subtract x. If that expression can be
524 // simplified to zero, then the remaining summands and x are equal.
525 AffineExpr diff = getAffineConstantExpr(constant: 0, context: lhs.getContext());
526 for (int64_t j = 0; j < e; ++j)
527 if (i != j)
528 diff = diff + summands[j];
529 diff = diff - innerMod.getLHS();
530 diff = simplifyAffineExpr(expr: diff, numDims, numSymbols);
531 auto constExpr = dyn_cast<AffineConstantExpr>(Val&: diff);
532 if (constExpr && constExpr.getValue() == 0)
533 return true;
534 }
535 return false;
536}
537
538/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
539/// operations when the second operand simplifies to a symbol and the first
540/// operand is divisible by that symbol. It can be applied to any semi-affine
541/// expression. Returned expression can either be a semi-affine or pure affine
542/// expression.
543static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
544 unsigned numSymbols) {
545 switch (expr.getKind()) {
546 case AffineExprKind::Constant:
547 case AffineExprKind::DimId:
548 case AffineExprKind::SymbolId:
549 return expr;
550 case AffineExprKind::Add:
551 case AffineExprKind::Mul: {
552 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
553 return getAffineBinaryOpExpr(
554 kind: expr.getKind(),
555 lhs: simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols),
556 rhs: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
557 }
558 // Check if the simplification of the second operand is a symbol, and the
559 // first operand is divisible by it. If the operation is a modulo, a constant
560 // zero expression is returned. In the case of floordiv and ceildiv, the
561 // symbol from the simplification of the second operand divides the first
562 // operand. Otherwise, simplification is not possible.
563 case AffineExprKind::FloorDiv:
564 case AffineExprKind::CeilDiv:
565 case AffineExprKind::Mod: {
566 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
567 AffineExpr sLHS =
568 simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols);
569 AffineExpr sRHS =
570 simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols);
571 if (isModOfModSubtraction(lhs: sLHS, rhs: sRHS, numDims, numSymbols))
572 return getAffineConstantExpr(constant: 0, context: expr.getContext());
573 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
574 Val: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
575 if (!symbolExpr)
576 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
577 unsigned symbolPos = symbolExpr.getPosition();
578 if (!isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()))
579 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
580 if (expr.getKind() == AffineExprKind::Mod)
581 return getAffineConstantExpr(constant: 0, context: expr.getContext());
582 return symbolicDivide(expr: sLHS, symbolPos, opKind: expr.getKind());
583 }
584 }
585 llvm_unreachable("Unknown AffineExpr");
586}
587
588static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
589 MLIRContext *context) {
590 auto assignCtx = [context](AffineDimExprStorage *storage) {
591 storage->context = context;
592 };
593
594 StorageUniquer &uniquer = context->getAffineUniquer();
595 return uniquer.get<AffineDimExprStorage>(
596 initFn: assignCtx, args: static_cast<unsigned>(kind), args&: position);
597}
598
599AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
600 return getAffineDimOrSymbol(kind: AffineExprKind::DimId, position, context);
601}
602
603AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
604 : AffineExpr(ptr) {}
605unsigned AffineSymbolExpr::getPosition() const {
606 return static_cast<ImplType *>(expr)->position;
607}
608
609AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
610 return getAffineDimOrSymbol(kind: AffineExprKind::SymbolId, position, context);
611}
612
613AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
614 : AffineExpr(ptr) {}
615int64_t AffineConstantExpr::getValue() const {
616 return static_cast<ImplType *>(expr)->constant;
617}
618
619bool AffineExpr::operator==(int64_t v) const {
620 return *this == getAffineConstantExpr(constant: v, context: getContext());
621}
622
623AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
624 auto assignCtx = [context](AffineConstantExprStorage *storage) {
625 storage->context = context;
626 };
627
628 StorageUniquer &uniquer = context->getAffineUniquer();
629 return uniquer.get<AffineConstantExprStorage>(initFn: assignCtx, args&: constant);
630}
631
632SmallVector<AffineExpr>
633mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
634 MLIRContext *context) {
635 return llvm::to_vector(Range: llvm::map_range(C&: constants, F: [&](int64_t constant) {
636 return getAffineConstantExpr(constant, context);
637 }));
638}
639
640/// Simplify add expression. Return nullptr if it can't be simplified.
641static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
642 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
643 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
644 // Fold if both LHS, RHS are a constant.
645 if (lhsConst && rhsConst)
646 return getAffineConstantExpr(constant: lhsConst.getValue() + rhsConst.getValue(),
647 context: lhs.getContext());
648
649 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
650 // If only one of them is a symbolic expressions, make it the RHS.
651 if (isa<AffineConstantExpr>(Val: lhs) ||
652 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
653 return rhs + lhs;
654 }
655
656 // At this point, if there was a constant, it would be on the right.
657
658 // Addition with a zero is a noop, return the other input.
659 if (rhsConst) {
660 if (rhsConst.getValue() == 0)
661 return lhs;
662 }
663 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
664 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
665 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
666 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
667 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
668 }
669
670 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
671 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
672 // respective multiplicands.
673 std::optional<int64_t> rLhsConst, rRhsConst;
674 AffineExpr firstExpr, secondExpr;
675 AffineConstantExpr rLhsConstExpr;
676 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
677 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
678 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(Val: lBinOpExpr.getRHS()))) {
679 rLhsConst = rLhsConstExpr.getValue();
680 firstExpr = lBinOpExpr.getLHS();
681 } else {
682 rLhsConst = 1;
683 firstExpr = lhs;
684 }
685
686 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: rhs);
687 AffineConstantExpr rRhsConstExpr;
688 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
689 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(Val: rBinOpExpr.getRHS()))) {
690 rRhsConst = rRhsConstExpr.getValue();
691 secondExpr = rBinOpExpr.getLHS();
692 } else {
693 rRhsConst = 1;
694 secondExpr = rhs;
695 }
696
697 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
698 return getAffineBinaryOpExpr(
699 kind: AffineExprKind::Mul, lhs: firstExpr,
700 rhs: getAffineConstantExpr(constant: *rLhsConst + *rRhsConst, context: lhs.getContext()));
701
702 // When doing successive additions, bring constant to the right: turn (d0 + 2)
703 // + d1 into (d0 + d1) + 2.
704 if (lBin && lBin.getKind() == AffineExprKind::Add) {
705 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
706 return lBin.getLHS() + rhs + lrhs;
707 }
708 }
709
710 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
711 // q may be a constant or symbolic expression. This leads to a much more
712 // efficient form when 'c' is a power of two, and in general a more compact
713 // and readable form.
714
715 // Process '(expr floordiv c) * (-c)'.
716 if (!rBinOpExpr)
717 return nullptr;
718
719 auto lrhs = rBinOpExpr.getLHS();
720 auto rrhs = rBinOpExpr.getRHS();
721
722 AffineExpr llrhs, rlrhs;
723
724 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
725 // symbolic expression.
726 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
727 // Check rrhsConstOpExpr = -1.
728 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(Val&: rrhs);
729 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
730 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
731 // Check llrhs = expr floordiv q.
732 llrhs = lrhsBinOpExpr.getLHS();
733 // Check rlrhs = q.
734 rlrhs = lrhsBinOpExpr.getRHS();
735 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: llrhs);
736 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
737 return nullptr;
738 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
739 return lhs % rlrhs;
740 }
741
742 // Process lrhs, which is 'expr floordiv c'.
743 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
744 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
745 return nullptr;
746
747 llrhs = lrBinOpExpr.getLHS();
748 rlrhs = lrBinOpExpr.getRHS();
749
750 if (lhs == llrhs && rlrhs == -rrhs) {
751 return lhs % rlrhs;
752 }
753 return nullptr;
754}
755
756AffineExpr AffineExpr::operator+(int64_t v) const {
757 return *this + getAffineConstantExpr(constant: v, context: getContext());
758}
759AffineExpr AffineExpr::operator+(AffineExpr other) const {
760 if (auto simplified = simplifyAdd(lhs: *this, rhs: other))
761 return simplified;
762
763 StorageUniquer &uniquer = getContext()->getAffineUniquer();
764 return uniquer.get<AffineBinaryOpExprStorage>(
765 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Add), args: *this, args&: other);
766}
767
768/// Simplify a multiply expression. Return nullptr if it can't be simplified.
769static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
770 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
771 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
772
773 if (lhsConst && rhsConst)
774 return getAffineConstantExpr(constant: lhsConst.getValue() * rhsConst.getValue(),
775 context: lhs.getContext());
776
777 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
778 return nullptr;
779
780 // Canonicalize the mul expression so that the constant/symbolic term is the
781 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
782 // constant. (Note that a constant is trivially symbolic).
783 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(Val: lhs)) {
784 // At least one of them has to be symbolic.
785 return rhs * lhs;
786 }
787
788 // At this point, if there was a constant, it would be on the right.
789
790 // Multiplication with a one is a noop, return the other input.
791 if (rhsConst) {
792 if (rhsConst.getValue() == 1)
793 return lhs;
794 // Multiplication with zero.
795 if (rhsConst.getValue() == 0)
796 return rhsConst;
797 }
798
799 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
800 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
801 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
802 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
803 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
804 }
805
806 // When doing successive multiplication, bring constant to the right: turn (d0
807 // * 2) * d1 into (d0 * d1) * 2.
808 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
809 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
810 return (lBin.getLHS() * rhs) * lrhs;
811 }
812 }
813
814 return nullptr;
815}
816
817AffineExpr AffineExpr::operator*(int64_t v) const {
818 return *this * getAffineConstantExpr(constant: v, context: getContext());
819}
820AffineExpr AffineExpr::operator*(AffineExpr other) const {
821 if (auto simplified = simplifyMul(lhs: *this, rhs: other))
822 return simplified;
823
824 StorageUniquer &uniquer = getContext()->getAffineUniquer();
825 return uniquer.get<AffineBinaryOpExprStorage>(
826 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mul), args: *this, args&: other);
827}
828
829// Unary minus, delegate to operator*.
830AffineExpr AffineExpr::operator-() const {
831 return *this * getAffineConstantExpr(constant: -1, context: getContext());
832}
833
834// Delegate to operator+.
835AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
836AffineExpr AffineExpr::operator-(AffineExpr other) const {
837 return *this + (-other);
838}
839
840static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
841 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
842 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
843
844 // mlir floordiv by zero or negative numbers is undefined and preserved as is.
845 if (!rhsConst || rhsConst.getValue() < 1)
846 return nullptr;
847
848 if (lhsConst)
849 return getAffineConstantExpr(
850 constant: floorDiv(lhs: lhsConst.getValue(), rhs: rhsConst.getValue()), context: lhs.getContext());
851
852 // Fold floordiv of a multiply with a constant that is a multiple of the
853 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
854 if (rhsConst == 1)
855 return lhs;
856
857 // Simplify (expr * const) floordiv divConst when expr is known to be a
858 // multiple of divConst.
859 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
860 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
861 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
862 // rhsConst is known to be a positive constant.
863 if (lrhs.getValue() % rhsConst.getValue() == 0)
864 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
865 }
866 }
867
868 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
869 // known to be a multiple of divConst.
870 if (lBin && lBin.getKind() == AffineExprKind::Add) {
871 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
872 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
873 // rhsConst is known to be a positive constant.
874 if (llhsDiv % rhsConst.getValue() == 0 ||
875 lrhsDiv % rhsConst.getValue() == 0)
876 return lBin.getLHS().floorDiv(v: rhsConst.getValue()) +
877 lBin.getRHS().floorDiv(v: rhsConst.getValue());
878 }
879
880 return nullptr;
881}
882
883AffineExpr AffineExpr::floorDiv(uint64_t v) const {
884 return floorDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
885}
886AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
887 if (auto simplified = simplifyFloorDiv(lhs: *this, rhs: other))
888 return simplified;
889
890 StorageUniquer &uniquer = getContext()->getAffineUniquer();
891 return uniquer.get<AffineBinaryOpExprStorage>(
892 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::FloorDiv), args: *this,
893 args&: other);
894}
895
896static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
897 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
898 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
899
900 if (!rhsConst || rhsConst.getValue() < 1)
901 return nullptr;
902
903 if (lhsConst)
904 return getAffineConstantExpr(
905 constant: ceilDiv(lhs: lhsConst.getValue(), rhs: rhsConst.getValue()), context: lhs.getContext());
906
907 // Fold ceildiv of a multiply with a constant that is a multiple of the
908 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
909 if (rhsConst.getValue() == 1)
910 return lhs;
911
912 // Simplify (expr * const) ceildiv divConst when const is known to be a
913 // multiple of divConst.
914 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
915 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
916 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
917 // rhsConst is known to be a positive constant.
918 if (lrhs.getValue() % rhsConst.getValue() == 0)
919 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
920 }
921 }
922
923 return nullptr;
924}
925
926AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
927 return ceilDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
928}
929AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
930 if (auto simplified = simplifyCeilDiv(lhs: *this, rhs: other))
931 return simplified;
932
933 StorageUniquer &uniquer = getContext()->getAffineUniquer();
934 return uniquer.get<AffineBinaryOpExprStorage>(
935 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::CeilDiv), args: *this,
936 args&: other);
937}
938
939static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
940 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
941 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
942
943 // mod w.r.t zero or negative numbers is undefined and preserved as is.
944 if (!rhsConst || rhsConst.getValue() < 1)
945 return nullptr;
946
947 if (lhsConst)
948 return getAffineConstantExpr(constant: mod(lhs: lhsConst.getValue(), rhs: rhsConst.getValue()),
949 context: lhs.getContext());
950
951 // Fold modulo of an expression that is known to be a multiple of a constant
952 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
953 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
954 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
955 return getAffineConstantExpr(constant: 0, context: lhs.getContext());
956
957 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
958 // known to be a multiple of divConst.
959 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
960 if (lBin && lBin.getKind() == AffineExprKind::Add) {
961 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
962 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
963 // rhsConst is known to be a positive constant.
964 if (llhsDiv % rhsConst.getValue() == 0)
965 return lBin.getRHS() % rhsConst.getValue();
966 if (lrhsDiv % rhsConst.getValue() == 0)
967 return lBin.getLHS() % rhsConst.getValue();
968 }
969
970 // Simplify (e % a) % b to e % b when b evenly divides a
971 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
972 auto intermediate = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS());
973 if (intermediate && intermediate.getValue() >= 1 &&
974 mod(lhs: intermediate.getValue(), rhs: rhsConst.getValue()) == 0) {
975 return lBin.getLHS() % rhsConst.getValue();
976 }
977 }
978
979 return nullptr;
980}
981
982AffineExpr AffineExpr::operator%(uint64_t v) const {
983 return *this % getAffineConstantExpr(constant: v, context: getContext());
984}
985AffineExpr AffineExpr::operator%(AffineExpr other) const {
986 if (auto simplified = simplifyMod(lhs: *this, rhs: other))
987 return simplified;
988
989 StorageUniquer &uniquer = getContext()->getAffineUniquer();
990 return uniquer.get<AffineBinaryOpExprStorage>(
991 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mod), args: *this, args&: other);
992}
993
994AffineExpr AffineExpr::compose(AffineMap map) const {
995 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
996 map.getResults().end());
997 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
998}
999raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
1000 expr.print(os);
1001 return os;
1002}
1003
1004/// Constructs an affine expression from a flat ArrayRef. If there are local
1005/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1006/// products expression, `localExprs` is expected to have the AffineExpr
1007/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1008/// in the format [dims, symbols, locals, constant term].
1009AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1010 unsigned numDims,
1011 unsigned numSymbols,
1012 ArrayRef<AffineExpr> localExprs,
1013 MLIRContext *context) {
1014 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1015 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1016 "unexpected number of local expressions");
1017
1018 auto expr = getAffineConstantExpr(constant: 0, context);
1019 // Dimensions and symbols.
1020 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1021 if (flatExprs[j] == 0)
1022 continue;
1023 auto id = j < numDims ? getAffineDimExpr(position: j, context)
1024 : getAffineSymbolExpr(position: j - numDims, context);
1025 expr = expr + id * flatExprs[j];
1026 }
1027
1028 // Local identifiers.
1029 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1030 j++) {
1031 if (flatExprs[j] == 0)
1032 continue;
1033 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1034 expr = expr + term;
1035 }
1036
1037 // Constant term.
1038 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1039 if (constTerm != 0)
1040 expr = expr + constTerm;
1041 return expr;
1042}
1043
1044/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1045/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1046/// of products expression, `localExprs` is expected to have the AffineExprs for
1047/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1048/// the format [dims, symbols, locals, constant term]. The semi-affine
1049/// expression is constructed in the sorted order of dimension and symbol
1050/// position numbers. Note: local expressions/ids are used for mod, div as well
1051/// as symbolic RHS terms for terms that are not pure affine.
1052static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1053 unsigned numDims,
1054 unsigned numSymbols,
1055 ArrayRef<AffineExpr> localExprs,
1056 MLIRContext *context) {
1057 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1058
1059 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1060 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1061 "unexpected number of local expressions");
1062
1063 AffineExpr expr = getAffineConstantExpr(constant: 0, context);
1064
1065 // We design indices as a pair which help us present the semi-affine map as
1066 // sum of product where terms are sorted based on dimension or symbol
1067 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1068 // where keyA is the position number of the dimension and keyB is the
1069 // position number of the symbol. For dimensional expressions we set the index
1070 // as (position number of the dimension, -1), as we want dimensional
1071 // expressions to appear before symbolic and product of dimensional and
1072 // symbolic expressions having the dimension with the same position number.
1073 // For symbolic expression set the index as (position number of the symbol,
1074 // maximum of last dimension and symbol position) number. For example, we want
1075 // the expression we are constructing to look something like: d0 + d0 * s0 +
1076 // s0 + d1*s1 + s1.
1077
1078 // Stores the affine expression corresponding to a given index.
1079 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
1080 // Stores the constant coefficient value corresponding to a given
1081 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1082 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
1083 // Stores the indices as defined above, and later sorted to produce
1084 // the semi-affine expression in the desired form.
1085 SmallVector<std::pair<unsigned, signed>, 8> indices;
1086
1087 // Example: expression = d0 + d0 * s0 + 2 * s0.
1088 // indices = [{0,-1}, {0, 0}, {0, 1}]
1089 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1090 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1091
1092 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1093 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1094 AffineExpr expr) {
1095 assert(!llvm::is_contained(indices, index) &&
1096 "Key is already present in indices vector and overwriting will "
1097 "happen in `indexToExprMap` and `coefficients`!");
1098
1099 indices.push_back(Elt: index);
1100 coefficients.insert(KV: {index, coefficient});
1101 indexToExprMap.insert(KV: {index, expr});
1102 };
1103
1104 // Design indices for dimensional or symbolic terms, and store the indices,
1105 // constant coefficient corresponding to the indices in `coefficients` map,
1106 // and affine expression corresponding to indices in `indexToExprMap` map.
1107
1108 // Ensure we do not have duplicate keys in `indexToExpr` map.
1109 unsigned offsetSym = 0;
1110 signed offsetDim = -1;
1111 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1112 if (flatExprs[j] == 0)
1113 continue;
1114 // For symbolic expression set the index as <position number
1115 // of the symbol, max(dimCount, symCount)> number,
1116 // as we want symbolic expressions with the same positional number to
1117 // appear after dimensional expressions having the same positional number.
1118 std::pair<unsigned, signed> indexEntry(
1119 j - numDims, std::max(a: numDims, b: numSymbols) + offsetSym++);
1120 addEntry(indexEntry, flatExprs[j],
1121 getAffineSymbolExpr(position: j - numDims, context));
1122 }
1123
1124 // Denotes semi-affine product, modulo or division terms, which has been added
1125 // to the `indexToExpr` map.
1126 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1127 false);
1128 unsigned lhsPos, rhsPos;
1129 // Construct indices for product terms involving dimension, symbol or constant
1130 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1131 // the indices in `coefficients` map, and affine expression corresponding to
1132 // in indices in `indexToExprMap` map.
1133 for (const auto &it : llvm::enumerate(First&: localExprs)) {
1134 AffineExpr expr = it.value();
1135 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1136 continue;
1137 AffineExpr lhs = cast<AffineBinaryOpExpr>(Val&: expr).getLHS();
1138 AffineExpr rhs = cast<AffineBinaryOpExpr>(Val&: expr).getRHS();
1139 if (!((isa<AffineDimExpr>(Val: lhs) || isa<AffineSymbolExpr>(Val: lhs)) &&
1140 (isa<AffineDimExpr>(Val: rhs) || isa<AffineSymbolExpr>(Val: rhs) ||
1141 isa<AffineConstantExpr>(Val: rhs)))) {
1142 continue;
1143 }
1144 if (isa<AffineConstantExpr>(Val: rhs)) {
1145 // For product/modulo/division expressions, when rhs of modulo/division
1146 // expression is constant, we put 0 in place of keyB, because we want
1147 // them to appear earlier in the semi-affine expression we are
1148 // constructing. When rhs is constant, we place 0 in place of keyB.
1149 if (isa<AffineDimExpr>(Val: lhs)) {
1150 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1151 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1152 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1153 expr);
1154 } else {
1155 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1156 std::pair<unsigned, signed> indexEntry(
1157 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1158 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1159 expr);
1160 }
1161 } else if (isa<AffineDimExpr>(Val: lhs)) {
1162 // For product/modulo/division expressions having lhs as dimension and rhs
1163 // as symbol, we order the terms in the semi-affine expression based on
1164 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1165 // where keyA is the position number of the dimension and keyB is the
1166 // position number of the symbol.
1167 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1168 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1169 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1170 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1171 } else {
1172 // For product/modulo/division expressions having both lhs and rhs as
1173 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1174 // of the form dimension * symbol, where keyA is the position number of
1175 // the dimension and keyB is the position number of the symbol.
1176 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1177 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1178 std::pair<unsigned, signed> indexEntry(
1179 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1180 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1181 }
1182 addedToMap[it.index()] = true;
1183 }
1184
1185 for (unsigned j = 0; j < numDims; ++j) {
1186 if (flatExprs[j] == 0)
1187 continue;
1188 // For dimensional expressions we set the index as <position number of the
1189 // dimension, 0>, as we want dimensional expressions to appear before
1190 // symbolic ones and products of dimensional and symbolic expressions
1191 // having the dimension with the same position number.
1192 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1193 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(position: j, context));
1194 }
1195
1196 // Constructing the simplified semi-affine sum of product/division/mod
1197 // expression from the flattened form in the desired sorted order of indices
1198 // of the various individual product/division/mod expressions.
1199 llvm::sort(C&: indices);
1200 for (const std::pair<unsigned, unsigned> index : indices) {
1201 assert(indexToExprMap.lookup(index) &&
1202 "cannot find key in `indexToExprMap` map");
1203 expr = expr + indexToExprMap.lookup(Val: index) * coefficients.lookup(Val: index);
1204 }
1205
1206 // Local identifiers.
1207 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1208 j++) {
1209 // If the coefficient of the local expression is 0, continue as we need not
1210 // add it in out final expression.
1211 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1212 continue;
1213 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1214 expr = expr + term;
1215 }
1216
1217 // Constant term.
1218 int64_t constTerm = flatExprs.back();
1219 if (constTerm != 0)
1220 expr = expr + constTerm;
1221 return expr;
1222}
1223
1224SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1225 unsigned numSymbols)
1226 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1227 operandExprStack.reserve(n: 8);
1228}
1229
1230// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1231//
1232// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1233// introduce a local variable p (= expr * symbolic_expr), and the affine
1234// expression expr * symbolic_expr is added to `localExprs`.
1235LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1236 assert(operandExprStack.size() >= 2);
1237 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1238 operandExprStack.pop_back();
1239 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1240
1241 // Flatten semi-affine multiplication expressions by introducing a local
1242 // variable in place of the product; the affine expression
1243 // corresponding to the quantifier is added to `localExprs`.
1244 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1245 MLIRContext *context = expr.getContext();
1246 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1247 localExprs, context);
1248 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1249 localExprs, context);
1250 addLocalVariableSemiAffine(expr: a * b, result&: lhs, resultSize: lhs.size());
1251 return success();
1252 }
1253
1254 // Get the RHS constant.
1255 int64_t rhsConst = rhs[getConstantIndex()];
1256 for (int64_t &lhsElt : lhs)
1257 lhsElt *= rhsConst;
1258
1259 return success();
1260}
1261
1262LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1263 assert(operandExprStack.size() >= 2);
1264 const auto &rhs = operandExprStack.back();
1265 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1266 assert(lhs.size() == rhs.size());
1267 // Update the LHS in place.
1268 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1269 lhs[i] += rhs[i];
1270 }
1271 // Pop off the RHS.
1272 operandExprStack.pop_back();
1273 return success();
1274}
1275
1276//
1277// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1278//
1279// A mod expression "expr mod c" is thus flattened by introducing a new local
1280// variable q (= expr floordiv c), such that expr mod c is replaced with
1281// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1282//
1283// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1284// introduce a local variable m (= expr mod symbolic_expr), and the affine
1285// expression expr mod symbolic_expr is added to `localExprs`.
1286LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1287 assert(operandExprStack.size() >= 2);
1288
1289 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1290 operandExprStack.pop_back();
1291 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1292 MLIRContext *context = expr.getContext();
1293
1294 // Flatten semi affine modulo expressions by introducing a local
1295 // variable in place of the modulo value, and the affine expression
1296 // corresponding to the quantifier is added to `localExprs`.
1297 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1298 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1299 flatExprs: lhs, numDims, numSymbols, localExprs, context);
1300 AffineExpr divisorExpr = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1301 localExprs, context);
1302 AffineExpr modExpr = dividendExpr % divisorExpr;
1303 addLocalVariableSemiAffine(expr: modExpr, result&: lhs, resultSize: lhs.size());
1304 return success();
1305 }
1306
1307 int64_t rhsConst = rhs[getConstantIndex()];
1308 if (rhsConst <= 0)
1309 return failure();
1310
1311 // Check if the LHS expression is a multiple of modulo factor.
1312 unsigned i, e;
1313 for (i = 0, e = lhs.size(); i < e; i++)
1314 if (lhs[i] % rhsConst != 0)
1315 break;
1316 // If yes, modulo expression here simplifies to zero.
1317 if (i == lhs.size()) {
1318 std::fill(first: lhs.begin(), last: lhs.end(), value: 0);
1319 return success();
1320 }
1321
1322 // Add a local variable for the quotient, i.e., expr % c is replaced by
1323 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1324 // the GCD of expr and c.
1325 SmallVector<int64_t, 8> floorDividend(lhs);
1326 uint64_t gcd = rhsConst;
1327 for (int64_t lhsElt : lhs)
1328 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1329 // Simplify the numerator and the denominator.
1330 if (gcd != 1) {
1331 for (int64_t &floorDividendElt : floorDividend)
1332 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1333 }
1334 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1335
1336 // Construct the AffineExpr form of the floordiv to store in localExprs.
1337
1338 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1339 flatExprs: floorDividend, numDims, numSymbols, localExprs, context);
1340 AffineExpr divisorExpr = getAffineConstantExpr(constant: floorDivisor, context);
1341 AffineExpr floorDivExpr = dividendExpr.floorDiv(other: divisorExpr);
1342 int loc;
1343 if ((loc = findLocalId(localExpr: floorDivExpr)) == -1) {
1344 addLocalFloorDivId(dividend: floorDividend, divisor: floorDivisor, localExpr: floorDivExpr);
1345 // Set result at top of stack to "lhs - rhsConst * q".
1346 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1347 } else {
1348 // Reuse the existing local id.
1349 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1350 }
1351 return success();
1352}
1353
1354LogicalResult
1355SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1356 return visitDivExpr(expr, /*isCeil=*/true);
1357}
1358LogicalResult
1359SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1360 return visitDivExpr(expr, /*isCeil=*/false);
1361}
1362
1363LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1364 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1365 auto &eq = operandExprStack.back();
1366 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1367 eq[getDimStartIndex() + expr.getPosition()] = 1;
1368 return success();
1369}
1370
1371LogicalResult
1372SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1373 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1374 auto &eq = operandExprStack.back();
1375 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1376 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1377 return success();
1378}
1379
1380LogicalResult
1381SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1382 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1383 auto &eq = operandExprStack.back();
1384 eq[getConstantIndex()] = expr.getValue();
1385 return success();
1386}
1387
1388void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1389 AffineExpr expr, SmallVectorImpl<int64_t> &result,
1390 unsigned long resultSize) {
1391 assert(result.size() == resultSize &&
1392 "`result` vector passed is not of correct size");
1393 int loc;
1394 if ((loc = findLocalId(localExpr: expr)) == -1)
1395 addLocalIdSemiAffine(localExpr: expr);
1396 std::fill(first: result.begin(), last: result.end(), value: 0);
1397 if (loc == -1)
1398 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1399 else
1400 result[getLocalVarStartIndex() + loc] = 1;
1401}
1402
1403// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1404// A floordiv is thus flattened by introducing a new local variable q, and
1405// replacing that expression with 'q' while adding the constraints
1406// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1407// IntegerRelation::addLocalFloorDiv).
1408//
1409// A ceildiv is similarly flattened:
1410// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1411//
1412// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1413// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1414// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1415// `localExprs`.
1416LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1417 bool isCeil) {
1418 assert(operandExprStack.size() >= 2);
1419
1420 MLIRContext *context = expr.getContext();
1421 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1422 operandExprStack.pop_back();
1423 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1424
1425 // Flatten semi affine division expressions by introducing a local
1426 // variable in place of the quotient, and the affine expression corresponding
1427 // to the quantifier is added to `localExprs`.
1428 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1429 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1430 localExprs, context);
1431 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1432 localExprs, context);
1433 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1434 addLocalVariableSemiAffine(expr: divExpr, result&: lhs, resultSize: lhs.size());
1435 return success();
1436 }
1437
1438 // This is a pure affine expr; the RHS is a positive constant.
1439 int64_t rhsConst = rhs[getConstantIndex()];
1440 if (rhsConst <= 0)
1441 return failure();
1442
1443 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1444 // common divisors of the numerator and denominator.
1445 uint64_t gcd = std::abs(i: rhsConst);
1446 for (int64_t lhsElt : lhs)
1447 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1448 // Simplify the numerator and the denominator.
1449 if (gcd != 1) {
1450 for (int64_t &lhsElt : lhs)
1451 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1452 }
1453 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1454 // If the divisor becomes 1, the updated LHS is the result. (The
1455 // divisor can't be negative since rhsConst is positive).
1456 if (divisor == 1)
1457 return success();
1458
1459 // If the divisor cannot be simplified to one, we will have to retain
1460 // the ceil/floor expr (simplified up until here). Add an existential
1461 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1462 // by a new identifier, q.
1463 AffineExpr a =
1464 getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols, localExprs, context);
1465 AffineExpr b = getAffineConstantExpr(constant: divisor, context);
1466
1467 int loc;
1468 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1469 if ((loc = findLocalId(localExpr: divExpr)) == -1) {
1470 if (!isCeil) {
1471 SmallVector<int64_t, 8> dividend(lhs);
1472 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1473 } else {
1474 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1475 SmallVector<int64_t, 8> dividend(lhs);
1476 dividend.back() += divisor - 1;
1477 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1478 }
1479 }
1480 // Set the expression on stack to the local var introduced to capture the
1481 // result of the division (floor or ceil).
1482 std::fill(first: lhs.begin(), last: lhs.end(), value: 0);
1483 if (loc == -1)
1484 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1485 else
1486 lhs[getLocalVarStartIndex() + loc] = 1;
1487 return success();
1488}
1489
1490// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1491// The local identifier added is always a floordiv of a pure add/mul affine
1492// function of other identifiers, coefficients of which are specified in
1493// dividend and with respect to a positive constant divisor. localExpr is the
1494// simplified tree expression (AffineExpr) corresponding to the quantifier.
1495void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1496 int64_t divisor,
1497 AffineExpr localExpr) {
1498 assert(divisor > 0 && "positive constant divisor expected");
1499 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1500 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1501 localExprs.push_back(Elt: localExpr);
1502 numLocals++;
1503 // dividend and divisor are not used here; an override of this method uses it.
1504}
1505
1506void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
1507 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1508 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1509 localExprs.push_back(Elt: localExpr);
1510 ++numLocals;
1511}
1512
1513int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1514 SmallVectorImpl<AffineExpr>::iterator it;
1515 if ((it = llvm::find(Range&: localExprs, Val: localExpr)) == localExprs.end())
1516 return -1;
1517 return it - localExprs.begin();
1518}
1519
1520/// Simplify the affine expression by flattening it and reconstructing it.
1521AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1522 unsigned numSymbols) {
1523 // Simplify semi-affine expressions separately.
1524 if (!expr.isPureAffine())
1525 expr = simplifySemiAffine(expr, numDims, numSymbols);
1526
1527 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1528 // has poison expression
1529 if (failed(result: flattener.walkPostOrder(expr)))
1530 return expr;
1531 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1532 if (!expr.isPureAffine() &&
1533 expr == getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1534 localExprs: flattener.localExprs,
1535 context: expr.getContext()))
1536 return expr;
1537 AffineExpr simplifiedExpr =
1538 expr.isPureAffine()
1539 ? getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1540 localExprs: flattener.localExprs, context: expr.getContext())
1541 : getSemiAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1542 localExprs: flattener.localExprs,
1543 context: expr.getContext());
1544
1545 flattener.operandExprStack.pop_back();
1546 assert(flattener.operandExprStack.empty());
1547 return simplifiedExpr;
1548}
1549
1550std::optional<int64_t> mlir::getBoundForAffineExpr(
1551 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1552 ArrayRef<std::optional<int64_t>> constLowerBounds,
1553 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1554 // Handle divs and mods.
1555 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) {
1556 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1557 // can compute an upper bound.
1558 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1559 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1560 if (!rhsConst || rhsConst.getValue() < 1)
1561 return std::nullopt;
1562 auto bound =
1563 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1564 constLowerBounds, constUpperBounds, isUpper);
1565 if (!bound)
1566 return std::nullopt;
1567 return mlir::floorDiv(lhs: *bound, rhs: rhsConst.getValue());
1568 }
1569 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1570 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1571 if (rhsConst && rhsConst.getValue() >= 1) {
1572 auto bound =
1573 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1574 constLowerBounds, constUpperBounds, isUpper);
1575 if (!bound)
1576 return std::nullopt;
1577 return mlir::ceilDiv(lhs: *bound, rhs: rhsConst.getValue());
1578 }
1579 return std::nullopt;
1580 }
1581 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1582 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1583 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1584 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1585 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1586 if (rhsConst && rhsConst.getValue() >= 1) {
1587 int64_t rhsConstVal = rhsConst.getValue();
1588 auto lb = getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1589 constLowerBounds, constUpperBounds,
1590 /*isUpper=*/false);
1591 auto ub =
1592 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1593 constLowerBounds, constUpperBounds, isUpper);
1594 if (ub && lb &&
1595 floorDiv(lhs: *lb, rhs: rhsConstVal) == floorDiv(lhs: *ub, rhs: rhsConstVal))
1596 return isUpper ? mod(lhs: *ub, rhs: rhsConstVal) : mod(lhs: *lb, rhs: rhsConstVal);
1597 return isUpper ? rhsConstVal - 1 : 0;
1598 }
1599 }
1600 }
1601 // Flatten the expression.
1602 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1603 auto simpleResult = flattener.walkPostOrder(expr);
1604 // has poison expression
1605 if (failed(result: simpleResult))
1606 return std::nullopt;
1607 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1608 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1609 // get bound on the local expr recursively.
1610 if (flattener.numLocals > 0)
1611 return std::nullopt;
1612 int64_t bound = 0;
1613 // Substitute the constant lower or upper bound for the dimensional or
1614 // symbolic input depending on `isUpper` to determine the bound.
1615 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1616 if (flattenedExpr[i] > 0) {
1617 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1618 if (!constBound)
1619 return std::nullopt;
1620 bound += *constBound * flattenedExpr[i];
1621 } else if (flattenedExpr[i] < 0) {
1622 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1623 if (!constBound)
1624 return std::nullopt;
1625 bound += *constBound * flattenedExpr[i];
1626 }
1627 }
1628 // Constant term.
1629 bound += flattenedExpr.back();
1630 return bound;
1631}
1632

source code of mlir/lib/IR/AffineExpr.cpp