-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Fix bug in visitDivExpr
, visitModExpr
and visitMulExpr
#145290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: Arnab Dutta (arnab-polymage) ChangesWhenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression. Full diff: https://github.com/llvm/llvm-project/pull/145290.diff 1 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!binaryExpr)
- continue;
-
+ assert(isa<AffineBinaryOpExpr>(expr) &&
+ "local expression cannot be a dimension, symbol or a constant -- it "
+ "should be a binary op expression");
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constModExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
}
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constDivExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
}
|
@llvm/pr-subscribers-mlir-core Author: Arnab Dutta (arnab-polymage) ChangesWhenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression. Full diff: https://github.com/llvm/llvm-project/pull/145290.diff 1 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!binaryExpr)
- continue;
-
+ assert(isa<AffineBinaryOpExpr>(expr) &&
+ "local expression cannot be a dimension, symbol or a constant -- it "
+ "should be a binary op expression");
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constModExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
}
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constDivExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
}
|
f51da54
to
7c184d5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes aren't complete or sound.
7c184d5
to
75ab97d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing test cases for the scenarios exercised. Changes are completely untested!
Which constraint? |
75ab97d
to
17ce1f0
Compare
visitDivExpr
and visitModExpr
visitDivExpr
, visitModExpr
and visitMulExpr
17ce1f0
to
5fc2915
Compare
5fc2915
to
7d262f5
Compare
7d262f5
to
c2493e6
Compare
c2493e6
to
c140442
Compare
Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.
c140442
to
f7715ad
Compare
@@ -418,6 +418,14 @@ class SimpleAffineExprFlattener | |||
AffineExpr localExpr); | |||
|
|||
private: | |||
/// Flatten binary expression `expr` and add it to `result`. If `expr` is a | |||
/// dimension, symbol or constant, we add it to appropriate index in `result`. | |||
/// Otherwise we add it in the local variable section. `lhs` and `rhs` are the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unclear here what operands of expr
are! An AffineExpr doesn't have any operands. Rephrase.
@@ -1274,6 +1273,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, | |||
operandExprStack.reserve(8); | |||
} | |||
|
|||
LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is completely misnamed. Is this specific to semi-affine exprs? In fact, addLocalIdSemiAffine in any of its overrides doesn't even use lhs
and rhs
. So all of these arguments aren't making sense to me.
LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
localExprs.push_back(localExpr);
++numLocals;
// lhs and rhs are not used here; an override of this method uses them.
return success();
}
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)> | ||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)> | ||
// CHECK-LABEL: semiaffine_simplification_local_expr_folded_into_non_binary_expr | ||
func.func @semiaffine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
semi_affine
Whenever the result of a div, mod or mul affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.