From 9e349f31609ccc854b2c40e36fce896dd8e3434d Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 6 May 2025 15:46:36 -0700 Subject: [PATCH 1/4] first --- .../mlir/Dialect/Vector/IR/VectorOps.h | 7 ++ .../Vector/Transforms/VectorRewritePatterns.h | 20 ++++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +- .../Vector/Transforms/VectorTransforms.cpp | 92 +++++++++++++++++++ .../Dialect/Vector/convert-to-shape-cast.mlir | 65 +++++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 22 +++++ 6 files changed, 207 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Dialect/Vector/convert-to-shape-cast.mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 98fb6075cbf32..be9839ce26339 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -50,6 +50,7 @@ namespace vector { class ContractionOp; class TransferReadOp; class TransferWriteOp; +class TransposeOp; class VectorDialect; namespace detail { @@ -171,6 +172,12 @@ SmallVector getAsValues(OpBuilder &builder, Location loc, /// `std::nullopt`. std::optional getConstantVscaleMultiplier(Value value); +/// Return true if `transpose` does not permute a pair of non-unit dims. +/// By `order preserving` we mean that the flattened versions of the input and +/// output vectors are (numerically) identical. In other words `transpose` is +/// effectively a shape cast. +bool isOrderPreserving(TransposeOp transpose); + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index f1100d5cf8b68..a6a221b2e3a67 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Add patterns that convert operations that are semantically equivalent to +/// shape_cast, to shape_cast. Currently this includes patterns for converting +/// transpose, extract and broadcast to shape_cast. Examples that will be +/// converted to shape_cast are: +/// +/// ``` +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8> +/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> +/// ``` +/// +/// Note that there is no pattern for vector.extract_strided_slice, because the +/// only extract_strided_slice that is semantically equivalent to shape_cast is +/// one that has idential input and output shapes, which is already folded. +/// +/// These patterns can be useful to expose more folding opportunities by +/// creating pairs of shape_casts that cancel. +void populateConvertToShapeCastPatterns(RewritePatternSet &, + PatternBenefit = 1); + /// Initialize `typeConverter` and `conversionTarget` for vector linearization. /// This registers (1) which operations are legal and hence should not be /// linearized, (2) what converted types are (rank-1 vectors) and how to diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f9c7fb7799eb0..11622e1da8de1 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5574,13 +5574,12 @@ LogicalResult ShapeCastOp::verify() { return success(); } -namespace { /// Return true if `transpose` does not permute a pair of non-unit dims. /// By `order preserving` we mean that the flattened versions of the input and /// output vectors are (numerically) identical. In other words `transpose` is /// effectively a shape cast. -bool isOrderPreserving(TransposeOp transpose) { +bool mlir::vector::isOrderPreserving(TransposeOp transpose) { ArrayRef permutation = transpose.getPermutation(); VectorType sourceType = transpose.getSourceVectorType(); ArrayRef inShape = sourceType.getShape(); @@ -5600,8 +5599,6 @@ bool isOrderPreserving(TransposeOp transpose) { return true; } -} // namespace - OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index b94c5fce64f83..05fc6989bf9d2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2182,6 +2182,91 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { } }; +/// For example, +/// ``` +/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to +/// vector<2x2x1xf32> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32> +/// ``` +struct TransposeToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransposeOp transpose, + PatternRewriter &rewriter) const override { + if (!isOrderPreserving(transpose)) { + return rewriter.notifyMatchFailure( + transpose, "not order preserving, so not semantically a 'copy'"); + } + rewriter.replaceOpWithNewOp( + transpose, transpose.getType(), transpose.getVector()); + return success(); + } +}; + +/// For example, +/// ``` +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// ``` +struct BroadcastToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::BroadcastOp broadcast, + PatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(broadcast.getSourceType()); + if (!sourceType) { + return rewriter.notifyMatchFailure( + broadcast, "source is a scalar, shape_cast doesn't support scalar"); + } + + VectorType outType = broadcast.getType(); + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(broadcast, outType, + broadcast.getSource()); + return success(); + } +}; + +/// For example, +/// ``` +/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// ``` +struct ExtractToShapeCast final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + VectorType sourceType = extractOp.getSourceVectorType(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!outType) + return failure(); + + // Negative values in `position` indicates poison, cannot convert to + // shape_cast + if (llvm::any_of(extractOp.getMixedPosition(), + [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) + return failure(); + + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(extractOp, outType, + extractOp.getVector()); + return success(); + } +}; + } // namespace void mlir::vector::populateFoldArithExtensionPatterns( @@ -2285,6 +2370,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns( patterns.getContext()); } +void mlir::vector::populateConvertToShapeCastPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns + .insert( + patterns.getContext(), benefit); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir new file mode 100644 index 0000000000000..483c3e73614e0 --- /dev/null +++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s + + +// CHECK-LABEL: @transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> +func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> +func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> +func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> + return %0 : vector<1x1x4xi8> +} + +// ----- + +// CHECK-LABEL: @negative_broadcast_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> + return %0 : vector<2x3x4xi8> +} + +// ----- + +// CHECK-LABEL: @extract_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> +func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { + %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// In this example, arg1 might be negative indicating poison. +// CHECK-LABEL: @negative_extract_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { + %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index b73c40adcffa7..aa97d6fc5dc69 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks VscaleRange{vscaleMin, vscaleMax}); } }; + +struct TestConvertToShapeCast + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast) + + TestConvertToShapeCast() = default; + + StringRef getArgument() const final { return "test-convert-to-shape-cast"; } + StringRef getDescription() const final { + return "Test conversion to shape_cast of semantically equivalent ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateConvertToShapeCastPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; } // namespace namespace mlir { @@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir From afb11a43ab0fea5c9ed995c23a88e5ed15c81e5f Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 6 May 2025 15:49:24 -0700 Subject: [PATCH 2/4] whitespace --- .../Dialect/Vector/Transforms/VectorRewritePatterns.h | 2 +- mlir/test/Dialect/Vector/convert-to-shape-cast.mlir | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index a6a221b2e3a67..3344765f4818a 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -422,7 +422,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns( /// one that has idential input and output shapes, which is already folded. /// /// These patterns can be useful to expose more folding opportunities by -/// creating pairs of shape_casts that cancel. +/// creating pairs of shape_casts that cancel. void populateConvertToShapeCastPatterns(RewritePatternSet &, PatternBenefit = 1); diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir index 483c3e73614e0..0ad6b3ff7d541 100644 --- a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir @@ -1,9 +1,9 @@ -// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s +// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s // CHECK-LABEL: @transpose_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] // CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> @@ -14,7 +14,7 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3 // CHECK-LABEL: @negative_transpose_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] // CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> @@ -36,7 +36,7 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { // CHECK-LABEL: @negative_broadcast_to_shape_cast // CHECK-NOT: shape_cast -// CHECK: return +// CHECK: return func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> return %0 : vector<2x3x4xi8> @@ -55,7 +55,7 @@ func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { // ----- -// In this example, arg1 might be negative indicating poison. +// In this example, arg1 might be negative indicating poison. // CHECK-LABEL: @negative_extract_to_shape_cast // CHECK-NOT: shape_cast func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { From 37107a4259f723ce9925d3923526ca8df516a2ce Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 6 May 2025 15:52:32 -0700 Subject: [PATCH 3/4] spacing --- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 05fc6989bf9d2..efcde8e97c0cd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2184,12 +2184,13 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { /// For example, /// ``` -/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to -/// vector<2x2x1xf32> +/// %0 = vector.transpose %arg0, [0, 2, 1] : +/// vector<2x1x2xf32> to vector<2x2x1xf32> /// ``` /// becomes /// ``` -/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32> +/// %0 = vector.shape_cast %arg0 : +/// vector<2x1x2xf32> to vector<2x2x1xf32> /// ``` struct TransposeToShapeCast final : public OpRewritePattern { From 6854f9262bfecd08b70903a0a272361fe779f867 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 6 May 2025 16:06:14 -0700 Subject: [PATCH 4/4] format --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 11622e1da8de1..562fc7d6ca110 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5574,7 +5574,6 @@ LogicalResult ShapeCastOp::verify() { return success(); } - /// Return true if `transpose` does not permute a pair of non-unit dims. /// By `order preserving` we mean that the flattened versions of the input and /// output vectors are (numerically) identical. In other words `transpose` is