Skip to content

Commit fe3933d

Browse files
yangtetris and Yang Bai authored
[mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (#142124)
### Description This patch improves the folding efficiency of `vector.insert` and `vector.extract` operations by not returning early after successfully converting dynamic indices to static indices. This PR also renames the test pass `TestConstantFold` to `TestSingleFold` and adds comprehensive documentation explaining the single-pass folding behavior. ### Motivation Since the `OpBuilder::createOrFold` function only calls `fold` **once**, the current `fold` methods of `vector.insert` and `vector.extract` may leave the op in a state that can be folded further. For example, consider the following unfolded IR: ``` %v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32> %c0 = arith.constant 0 : index %e2 = vector.extract %v1[%c0] : f32 from vector<128xf32> ``` If we use `createOrFold` to create the `vector.extract` op, then the result will be: ``` %v1 = vector.insert %e1, %v0 [0] : f32 into vector<128xf32> %e2 = vector.extract %v1[0] : f32 from vector<128xf32> ``` But this is not the optimal result. `createOrFold` should have returned `%e1`. The reason is that `fold` returns immediately after `extractInsertFoldConstantOp`, causing the subsequent folding logic to be skipped. --------- Co-authored-by: Yang Bai <[email protected]>
1 parent 0018921 commit fe3933d

File tree

13 files changed

+86
-32
lines changed

13 files changed

+86
-32
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
20632063
if (opChange) {
20642064
op.setStaticPosition(staticPosition);
20652065
op.getOperation()->setOperands(operands);
2066+
// Return the original result to indicate an in-place folding happened.
20662067
return op.getResult();
20672068
}
20682069
return {};
@@ -2146,11 +2147,12 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21462147
return getVector();
21472148
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
21482149
return res;
2149-
// Fold `arith.constant` indices into the `vector.extract` operation. Make
2150-
// sure that patterns requiring constant indices are added after this fold.
2150+
// Fold `arith.constant` indices into the `vector.extract` operation.
2151+
// Do not stop here as this fold may enable subsequent folds that require
2152+
// constant indices.
21512153
SmallVector<Value> operands = {getVector()};
2152-
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2153-
return val;
2154+
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2155+
21542156
if (auto res = foldPoisonIndexInsertExtractOp(
21552157
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
21562158
return res;
@@ -2172,7 +2174,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21722174
return val;
21732175
if (auto val = foldScalarExtractFromFromElements(*this))
21742176
return val;
2175-
return OpFoldResult();
2177+
2178+
return inplaceFolded;
21762179
}
21772180

21782181
namespace {
@@ -3272,11 +3275,12 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
32723275
// (type mismatch).
32733276
if (getNumIndices() == 0 && getValueToStoreType() == getType())
32743277
return getValueToStore();
3275-
// Fold `arith.constant` indices into the `vector.insert` operation. Make
3276-
// sure that patterns requiring constant indices are added after this fold.
3278+
// Fold `arith.constant` indices into the `vector.insert` operation.
3279+
// Do not stop here as this fold may enable subsequent folds that require
3280+
// constant indices.
32773281
SmallVector<Value> operands = {getValueToStore(), getDest()};
3278-
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3279-
return val;
3282+
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3283+
32803284
if (auto res = foldPoisonIndexInsertExtractOp(
32813285
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
32823286
return res;
@@ -3286,7 +3290,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
32863290
return res;
32873291
}
32883292

3289-
return {};
3293+
return inplaceFolded;
32903294
}
32913295

32923296
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/constant-fold.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-constant-fold -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -test-single-fold -split-input-file %s | FileCheck %s
22

33
// CHECK-LABEL: func @affine_apply
44
func.func @affine_apply(%variable : index) -> (index, index, index) {

mlir/test/Dialect/Linalg/mesh-spmdization.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt \
2-
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
2+
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
33
// RUN: --split-input-file \
44
// RUN: %s | FileCheck %s
55

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt \
2-
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
2+
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
33
// RUN: %s | FileCheck %s
44

55
mesh.mesh @mesh_1d(shape = 2)

mlir/test/Dialect/Tensor/mesh-spmdization.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt \
2-
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
2+
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
33
// RUN: %s | FileCheck %s
44

55
mesh.mesh @mesh_1d_4(shape = 4)

mlir/test/Dialect/Tosa/constant_folding.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
1+
// RUN: mlir-opt --test-single-fold %s | FileCheck %s
22

33
// CHECK-LABEL: func @test_const
44
func.func @test_const(%arg0 : index) -> tensor<4xi32> {

mlir/test/Dialect/Vector/constant-fold.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -split-input-file -test-constant-fold | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
22

33
// CHECK-LABEL: fold_extract_transpose_negative
44
func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4xf16> {
@@ -11,3 +11,5 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
1111
%2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
1212
return %2 : vector<4x4xf16>
1313
}
14+
15+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
2+
3+
// The tests in this file verify that fold() methods can handle complex
4+
// optimization scenarios without requiring multiple folding iterations.
5+
// This is important because:
6+
//
7+
// 1. OpBuilder::createOrFold() only calls fold() once, so operations must
8+
// be fully optimized in that single call
9+
// 2. Multiple rounds of folding would incur higher performance costs,
10+
// so it's more efficient to complete all optimizations in one pass
11+
//
12+
// These tests ensure that folding implementations are robust and complete,
13+
// avoiding situations where operations are left in intermediate states
14+
// that could be further optimized.
15+
16+
// CHECK-LABEL: fold_extract_in_single_pass
17+
// CHECK-SAME: (%{{.*}}: vector<4xf16>, %[[ARG1:.+]]: f16)
18+
func.func @fold_extract_in_single_pass(%arg0: vector<4xf16>, %arg1: f16) -> f16 {
19+
%0 = vector.insert %arg1, %arg0 [1] : f16 into vector<4xf16>
20+
%c1 = arith.constant 1 : index
21+
// Verify that the fold is finished in a single pass even if the index is dynamic.
22+
%1 = vector.extract %0[%c1] : f16 from vector<4xf16>
23+
// CHECK: return %[[ARG1]] : f16
24+
return %1 : f16
25+
}
26+
27+
// -----
28+
29+
// CHECK-LABEL: fold_insert_in_single_pass
30+
func.func @fold_insert_in_single_pass() -> vector<2xf16> {
31+
%cst = arith.constant dense<0.000000e+00> : vector<2xf16>
32+
%c1 = arith.constant 1 : index
33+
%c2 = arith.constant 2.5 : f16
34+
// Verify that the fold is finished in a single pass even if the index is dynamic.
35+
// CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
36+
%0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
37+
return %0 : vector<2xf16>
38+
}

mlir/test/Transforms/constant-fold-debuginfo.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -split-input-file -test-constant-fold -mlir-print-debuginfo | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -test-single-fold -mlir-print-debuginfo | FileCheck %s
22

33
// CHECK-LABEL: func @fold_and_merge
44
func.func @fold_and_merge() -> (i32, i32) {

mlir/test/Transforms/constant-fold.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -test-constant-fold | FileCheck %s
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -test-single-fold | FileCheck %s
22

33
// -----
44

mlir/test/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ endif()
2626
add_mlir_library(MLIRTestTransforms
2727
TestCommutativityUtils.cpp
2828
TestCompositePass.cpp
29-
TestConstantFold.cpp
3029
TestControlFlowSink.cpp
3130
TestInlining.cpp
3231
TestInliningCallback.cpp
3332
TestMakeIsolatedFromAbove.cpp
33+
TestSingleFold.cpp
3434
TestTransformsOps.cpp
3535
${MLIRTestTransformsPDLSrc}
3636

mlir/test/lib/Transforms/TestConstantFold.cpp renamed to mlir/test/lib/Transforms/TestSingleFold.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
1+
//===- TestSingleFold.cpp - Pass to test single-pass folding --------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,14 +12,23 @@
1212
using namespace mlir;
1313

1414
namespace {
15-
/// Simple constant folding pass.
16-
struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
17-
public RewriterBase::Listener {
18-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold)
15+
/// Test pass for single-pass constant folding.
16+
///
17+
/// This pass tests the behavior of operations when folded exactly once. Unlike
18+
/// canonicalization passes that may apply multiple rounds of folding, this pass
19+
/// ensures that each operation is folded at most once, which is useful for
20+
/// testing scenarios where the fold implementation should handle complex cases
21+
/// without requiring multiple iterations.
22+
///
23+
/// The pass also removes dead constants after folding to clean up unused
24+
/// intermediate results.
25+
struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
26+
public RewriterBase::Listener {
27+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)
1928

20-
StringRef getArgument() const final { return "test-constant-fold"; }
29+
StringRef getArgument() const final { return "test-single-fold"; }
2130
StringRef getDescription() const final {
22-
return "Test operation constant folding";
31+
return "Test single-pass operation folding and dead constant elimination";
2332
}
2433
// All constants in the operation post folding.
2534
SmallVector<Operation *> existingConstants;
@@ -39,18 +48,19 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
3948
};
4049
} // namespace
4150

42-
void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
51+
void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
4352
// Attempt to fold the specified operation, including handling unused or
4453
// duplicated constants.
4554
(void)helper.tryToFold(op);
4655
}
4756

48-
void TestConstantFold::runOnOperation() {
57+
void TestSingleFold::runOnOperation() {
4958
existingConstants.clear();
5059

5160
// Collect and fold the operations within the operation.
5261
SmallVector<Operation *, 8> ops;
53-
getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) { ops.push_back(op); });
62+
getOperation()->walk<mlir::WalkOrder::PreOrder>(
63+
[&](Operation *op) { ops.push_back(op); });
5464

5565
// Fold the constants in reverse so that the last generated constants from
5666
// folding are at the beginning. This creates somewhat of a linear ordering to
@@ -70,6 +80,6 @@ void TestConstantFold::runOnOperation() {
7080

7181
namespace mlir {
7282
namespace test {
73-
void registerTestConstantFold() { PassRegistration<TestConstantFold>(); }
83+
void registerTestSingleFold() { PassRegistration<TestSingleFold>(); }
7484
} // namespace test
7585
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ void registerTestCfAssertPass();
8787
void registerTestCFGLoopInfoPass();
8888
void registerTestComposeSubView();
8989
void registerTestCompositePass();
90-
void registerTestConstantFold();
9190
void registerTestControlFlowSink();
9291
void registerTestConvertToSPIRVPass();
9392
void registerTestDataLayoutPropagation();
@@ -145,6 +144,7 @@ void registerTestSCFUtilsPass();
145144
void registerTestSCFWhileOpBuilderPass();
146145
void registerTestSCFWrapInZeroTripCheckPasses();
147146
void registerTestShapeMappingPass();
147+
void registerTestSingleFold();
148148
void registerTestSliceAnalysisPass();
149149
void registerTestSPIRVCPURunnerPipeline();
150150
void registerTestSPIRVFuncSignatureConversion();
@@ -233,7 +233,6 @@ void registerTestPasses() {
233233
mlir::test::registerTestCFGLoopInfoPass();
234234
mlir::test::registerTestComposeSubView();
235235
mlir::test::registerTestCompositePass();
236-
mlir::test::registerTestConstantFold();
237236
mlir::test::registerTestControlFlowSink();
238237
mlir::test::registerTestConvertToSPIRVPass();
239238
mlir::test::registerTestDataLayoutPropagation();
@@ -291,6 +290,7 @@ void registerTestPasses() {
291290
mlir::test::registerTestSCFWhileOpBuilderPass();
292291
mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
293292
mlir::test::registerTestShapeMappingPass();
293+
mlir::test::registerTestSingleFold();
294294
mlir::test::registerTestSliceAnalysisPass();
295295
mlir::test::registerTestSPIRVCPURunnerPipeline();
296296
mlir::test::registerTestSPIRVFuncSignatureConversion();

0 commit comments

Comments
 (0)