diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h index 1826f5fd8ad35..b3c101e4e2f47 100644 --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -418,6 +418,14 @@ class SimpleAffineExprFlattener AffineExpr localExpr); private: + /// Flatten binary expression `expr` and add it to `result`. If `expr` is a + /// dimension, symbol or constant, we add it to appropriate index in `result`. + /// Otherwise we add it in the local variable section. `lhs` and `rhs` are the + /// LHS and RHS expressions of `expr`. + LogicalResult addExprToFlattenedList(AffineExpr expr, ArrayRef lhs, + ArrayRef rhs, + SmallVectorImpl &result); + /// Adds `localExpr`, which may be mod, ceildiv, floordiv or mod expression /// representing the affine expression corresponding to the quantifier /// introduced as the local variable corresponding to `localExpr`. If the diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index cc81f9d19aca7..5c262c1179b9c 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef flatExprs, if (flatExprs[numDims + numSymbols + it.index()] == 0) continue; AffineExpr expr = it.value(); - auto binaryExpr = dyn_cast(expr); - if (!binaryExpr) - continue; - + // A local expression cannot be a dimension, symbol or a constant -- it + // should be a binary op expression. + auto binaryExpr = cast(expr); AffineExpr lhs = binaryExpr.getLHS(); AffineExpr rhs = binaryExpr.getRHS(); if (!((isa(lhs) || isa(lhs)) && @@ -1274,6 +1273,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, operandExprStack.reserve(8); } +LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList( + AffineExpr expr, ArrayRef lhs, ArrayRef rhs, + SmallVectorImpl &result) { + if (auto constExpr = dyn_cast(expr)) { + std::fill(result.begin(), result.end(), 0); + result[getConstantIndex()] = constExpr.getValue(); + return success(); + } + if (auto dimExpr = dyn_cast(expr)) { + std::fill(result.begin(), result.end(), 0); + result[getDimStartIndex() + dimExpr.getPosition()] = 1; + return success(); + } + if (auto symExpr = dyn_cast(expr)) { + std::fill(result.begin(), result.end(), 0); + result[getSymbolStartIndex() + symExpr.getPosition()] = 1; + return success(); + } + return addLocalVariableSemiAffine(lhs, rhs, expr, result, result.size()); +} + // In pure affine t = expr * c, we multiply each coefficient of lhs with c. // // In case of semi affine multiplication expressions, t = expr * symbolic_expr, @@ -1295,7 +1315,7 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { localExprs, context); AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, localExprs, context); - return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size()); + return addExprToFlattenedList(a * b, mulLhs, rhs, lhs); } // Get the RHS constant. @@ -1347,8 +1367,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { lhs, numDims, numSymbols, localExprs, context); AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols, localExprs, context); - AffineExpr modExpr = dividendExpr % divisorExpr; - return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size()); + return addExprToFlattenedList(dividendExpr % divisorExpr, modLhs, rhs, lhs); } int64_t rhsConst = rhs[getConstantIndex()]; @@ -1482,7 +1501,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, localExprs, context); AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); - return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size()); + return addExprToFlattenedList(divExpr, divLhs, rhs, lhs); } // This is a pure affine expr; the RHS is a positive constant. diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir index 6f2737a982752..17d4ab4dfe448 100644 --- a/mlir/test/Dialect/Affine/simplify-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-structures.mlir @@ -608,3 +608,32 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor (13 mod s0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-LABEL: semi_affine_simplification_local_expr_folded_into_non_binary_expr +func.func @semi_affine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c13 = arith.constant 13 : index + // CHECK: %[[DIM:.*]] = memref.dim + %dim = memref.dim %arg0, %c0 : memref + // CHECK: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[DIM]]] + %c = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 mod (s1 + (-s1 + s3) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1, %dim] + %alloc = memref.alloc() : memref<1xindex> + affine.for %iv = 0 to 1 { + %d = affine.apply affine_map<(d0)[s1, s2] -> ((d0 - s1 + s1 * s2) * (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>(%iv)[%dim, %c1] + affine.store %d, %alloc[0] : memref<1xindex> + } + // CHECK: affine.for %[[IV:.*]] = 0 to 1 { + // CHECK-NEXT: %[[VAL:.*]] = affine.apply #[[$MAP1]](%[[IV]]) + // CHECK-NEXT: affine.store %[[VAL]], %{{.*}}[0] : memref<1xindex> + // CHECK-NEXT: } + // CHECK: %[[VAL1:.*]] = affine.load %{{.*}}[0] + %d = affine.load %alloc[0] : memref<1xindex> + // CHECK: return %[[VAL0]], %[[VAL1]] + return %c, %d : index, index +}