Skip to content

Fix bug in visitDivExpr, visitModExpr and visitMulExpr #145290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/IR/AffineExprVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ class SimpleAffineExprFlattener
AffineExpr localExpr);

private:
/// Flatten binary expression `expr` and add it to `result`. If `expr` is a
/// dimension, symbol or constant, we add it to appropriate index in `result`.
/// Otherwise we add it in the local variable section. `lhs` and `rhs` are the
/// LHS and RHS expressions of `expr`.
LogicalResult addExprToFlattenedList(AffineExpr expr, ArrayRef<int64_t> lhs,
ArrayRef<int64_t> rhs,
SmallVectorImpl<int64_t> &result);

/// Adds `localExpr`, which may be mod, ceildiv, floordiv or mod expression
/// representing the affine expression corresponding to the quantifier
/// introduced as the local variable corresponding to `localExpr`. If the
Expand Down
35 changes: 27 additions & 8 deletions mlir/lib/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binaryExpr)
continue;

// A local expression cannot be a dimension, symbol or a constant -- it
// should be a binary op expression.
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
Expand Down Expand Up @@ -1274,6 +1273,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
operandExprStack.reserve(8);
}

LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is completely misnamed. Is this specific to semi-affine exprs? In fact, addLocalIdSemiAffine in any of its overrides doesn't even use lhs and rhs. So all of these arguments aren't making sense to me.

LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
    ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
  localExprs.push_back(localExpr);
  ++numLocals;
  // lhs and rhs are not used here; an override of this method uses them.
  return success();
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment inside the function you have pasted explains this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I don't see the override using them either.

AffineExpr expr, ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
SmallVectorImpl<int64_t> &result) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
std::fill(result.begin(), result.end(), 0);
result[getConstantIndex()] = constExpr.getValue();
return success();
}
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
std::fill(result.begin(), result.end(), 0);
result[getDimStartIndex() + dimExpr.getPosition()] = 1;
return success();
}
if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
std::fill(result.begin(), result.end(), 0);
result[getSymbolStartIndex() + symExpr.getPosition()] = 1;
return success();
}
return addLocalVariableSemiAffine(lhs, rhs, expr, result, result.size());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know that expr is semi-affine at this stage? It can just be a purely affine binary expression. This method is confusing without any further comments. At this point, all you know is that expr is an affine binary expression and you might as well cast it to that and send it to make the signature of addLocal... less confusing.

}

// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
//
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
Expand All @@ -1295,7 +1315,7 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
localExprs, context);
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
return addExprToFlattenedList(a * b, mulLhs, rhs, lhs);
}

// Get the RHS constant.
Expand Down Expand Up @@ -1347,8 +1367,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
lhs, numDims, numSymbols, localExprs, context);
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
return addExprToFlattenedList(dividendExpr % divisorExpr, modLhs, rhs, lhs);
}

int64_t rhsConst = rhs[getConstantIndex()];
Expand Down Expand Up @@ -1482,7 +1501,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
return addExprToFlattenedList(divExpr, divLhs, rhs, lhs);
}

// This is a pure affine expr; the RHS is a positive constant.
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Affine/simplify-structures.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,32 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?x
// CHECK-NEXT: return %[[C6]], %[[C7]]
return %a, %b : index, index
}

// -----

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: semi_affine_simplification_local_expr_folded_into_non_binary_expr
func.func @semi_affine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c13 = arith.constant 13 : index
// CHECK: %[[DIM:.*]] = memref.dim
%dim = memref.dim %arg0, %c0 : memref<?x?xf32>
// CHECK: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[DIM]]]
%c = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 mod (s1 + (-s1 + s3) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1, %dim]
%alloc = memref.alloc() : memref<1xindex>
affine.for %iv = 0 to 1 {
%d = affine.apply affine_map<(d0)[s1, s2] -> ((d0 - s1 + s1 * s2) * (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>(%iv)[%dim, %c1]
affine.store %d, %alloc[0] : memref<1xindex>
}
// CHECK: affine.for %[[IV:.*]] = 0 to 1 {
// CHECK-NEXT: %[[VAL:.*]] = affine.apply #[[$MAP1]](%[[IV]])
// CHECK-NEXT: affine.store %[[VAL]], %{{.*}}[0] : memref<1xindex>
// CHECK-NEXT: }
// CHECK: %[[VAL1:.*]] = affine.load %{{.*}}[0]
%d = affine.load %alloc[0] : memref<1xindex>
// CHECK: return %[[VAL0]], %[[VAL1]]
return %c, %d : index, index
}
Loading