Skip to content

Fix bug in visitDivExpr, visitModExpr and visitMulExpr #145290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

arnab-polymage
Copy link
Contributor

@arnab-polymage arnab-polymage commented Jun 23, 2025

Whenever the result of a div, mod or mul affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.

@arnab-polymage arnab-polymage requested a review from CoTinker June 23, 2025 08:36
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 23, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

Changes

Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.


Full diff: https://github.com/llvm/llvm-project/pull/145290.diff

1 Files Affected:

  • (modified) mlir/lib/IR/AffineExpr.cpp (+14-4)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
     AffineExpr expr = it.value();
-    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
-    if (!binaryExpr)
-      continue;
-
+    assert(isa<AffineBinaryOpExpr>(expr) &&
+           "local expression cannot be a dimension, symbol or a constant -- it "
+           "should be a binary op expression");
+    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
     AffineExpr lhs = binaryExpr.getLHS();
     AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
+    if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constModExpr.getValue();
+      return success();
+    }
     return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
   }
 
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+    if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constDivExpr.getValue();
+      return success();
+    }
     return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
   }
 

@llvmbot
Copy link
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-core

Author: Arnab Dutta (arnab-polymage)

Changes

Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.


Full diff: https://github.com/llvm/llvm-project/pull/145290.diff

1 Files Affected:

  • (modified) mlir/lib/IR/AffineExpr.cpp (+14-4)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
     AffineExpr expr = it.value();
-    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
-    if (!binaryExpr)
-      continue;
-
+    assert(isa<AffineBinaryOpExpr>(expr) &&
+           "local expression cannot be a dimension, symbol or a constant -- it "
+           "should be a binary op expression");
+    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
     AffineExpr lhs = binaryExpr.getLHS();
     AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
+    if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constModExpr.getValue();
+      return success();
+    }
     return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
   }
 
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+    if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constDivExpr.getValue();
+      return success();
+    }
     return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
   }
 

Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes aren't complete or sound.

Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test cases for the scenarios exercised. Changes are completely untested!

@bondhugula
Copy link
Contributor

This is a stupid question. Will semi affine expr also have this constrain? If that's the case, the fix should be reverted.

Which constraint?

@arnab-polymage arnab-polymage requested a review from CoTinker June 27, 2025 03:31
@arnab-polymage arnab-polymage changed the title Fix bug in visitDivExpr and visitModExpr Fix bug in visitDivExpr, visitModExpr and visitMulExpr Jun 27, 2025
Whenever the result of a div or mod affine expression is a
constant expression, place the value in the constant index of the
flattened expression instead of adding it as a local expression.
@@ -418,6 +418,14 @@ class SimpleAffineExprFlattener
AffineExpr localExpr);

private:
/// Flatten binary expression `expr` and add it to `result`. If `expr` is a
/// dimension, symbol or constant, we add it to appropriate index in `result`.
/// Otherwise we add it in the local variable section. `lhs` and `rhs` are the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear here what operands of expr are! An AffineExpr doesn't have any operands. Rephrase.

@@ -1274,6 +1273,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
operandExprStack.reserve(8);
}

LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is completely misnamed. Is this specific to semi-affine exprs? In fact, addLocalIdSemiAffine in any of its overrides doesn't even use lhs and rhs. So all of these arguments aren't making sense to me.

LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
    ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
  localExprs.push_back(localExpr);
  ++numLocals;
  // lhs and rhs are not used here; an override of this method uses them.
  return success();
}

// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: semiaffine_simplification_local_expr_folded_into_non_binary_expr
func.func @semiaffine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

semi_affine

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:affine mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants