diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 98fb6075cbf32..be9839ce26339 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -50,6 +50,7 @@ namespace vector { class ContractionOp; class TransferReadOp; class TransferWriteOp; +class TransposeOp; class VectorDialect; namespace detail { @@ -171,6 +172,12 @@ SmallVector getAsValues(OpBuilder &builder, Location loc, /// `std::nullopt`. std::optional getConstantVscaleMultiplier(Value value); +/// Return true if `transpose` does not permute a pair of non-unit dims. +/// By `order preserving` we mean that the flattened versions of the input and +/// output vectors are (numerically) identical. In other words `transpose` is +/// effectively a shape cast. +bool isOrderPreserving(TransposeOp transpose); + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index f1100d5cf8b68..3344765f4818a 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Add patterns that convert operations that are semantically equivalent to +/// shape_cast, to shape_cast. Currently this includes patterns for converting +/// transpose, extract and broadcast to shape_cast. 
Examples that will be +/// converted to shape_cast are: +/// +/// ``` +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8> +/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> +/// ``` +/// +/// Note that there is no pattern for vector.extract_strided_slice, because the +/// only extract_strided_slice that is semantically equivalent to shape_cast is +/// one that has identical input and output shapes, which is already folded. +/// +/// These patterns can be useful to expose more folding opportunities by +/// creating pairs of shape_casts that cancel. +void populateConvertToShapeCastPatterns(RewritePatternSet &, + PatternBenefit = 1); + /// Initialize `typeConverter` and `conversionTarget` for vector linearization. /// This registers (1) which operations are legal and hence should not be /// linearized, (2) what converted types are (rank-1 vectors) and how to diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f9c7fb7799eb0..562fc7d6ca110 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5574,13 +5574,11 @@ LogicalResult ShapeCastOp::verify() { return success(); } -namespace { - /// Return true if `transpose` does not permute a pair of non-unit dims. /// By `order preserving` we mean that the flattened versions of the input and /// output vectors are (numerically) identical. In other words `transpose` is /// effectively a shape cast. 
-bool isOrderPreserving(TransposeOp transpose) { +bool mlir::vector::isOrderPreserving(TransposeOp transpose) { ArrayRef permutation = transpose.getPermutation(); VectorType sourceType = transpose.getSourceVectorType(); ArrayRef inShape = sourceType.getShape(); @@ -5600,8 +5598,6 @@ bool isOrderPreserving(TransposeOp transpose) { return true; } -} // namespace - OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index b94c5fce64f83..efcde8e97c0cd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2182,6 +2182,92 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { } }; +/// Rewrite an order-preserving transpose as a shape_cast. For example, +/// ``` +/// %0 = vector.transpose %arg0, [0, 2, 1] : +/// vector<2x1x2xf32> to vector<2x2x1xf32> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : +/// vector<2x1x2xf32> to vector<2x2x1xf32> +/// ``` +struct TransposeToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransposeOp transpose, + PatternRewriter &rewriter) const override { + if (!isOrderPreserving(transpose)) { + return rewriter.notifyMatchFailure( + transpose, "not order preserving, so not semantically a 'copy'"); + } + rewriter.replaceOpWithNewOp( + transpose, transpose.getType(), transpose.getVector()); + return success(); + } +}; + +/// Rewrite a broadcast that adds no elements as a shape_cast. For example, +/// ``` +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// ``` +struct BroadcastToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::BroadcastOp broadcast, + PatternRewriter &rewriter) const override { + auto 
sourceType = dyn_cast(broadcast.getSourceType()); + if (!sourceType) { + return rewriter.notifyMatchFailure( + broadcast, "source is a scalar, shape_cast doesn't support scalar"); + } + + VectorType outType = broadcast.getType(); + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(broadcast, outType, + broadcast.getSource()); + return success(); + } +}; + +/// Rewrite an extract that keeps all elements as a shape_cast. For example, +/// ``` +/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// ``` +struct ExtractToShapeCast final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + VectorType sourceType = extractOp.getSourceVectorType(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!outType) + return failure(); + + // Bail out unless every index in `position` is the constant 0: dynamic + // or poison (negative) indices cannot be converted to shape_cast. + if (llvm::any_of(extractOp.getMixedPosition(), + [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) + return failure(); + + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(extractOp, outType, + extractOp.getVector()); + return success(); + } +}; + } // namespace void mlir::vector::populateFoldArithExtensionPatterns( @@ -2285,6 +2371,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns( patterns.getContext()); } +void mlir::vector::populateConvertToShapeCastPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns + .insert( + patterns.getContext(), benefit); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir new file mode 100644 index 0000000000000..0ad6b3ff7d541 --- /dev/null +++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s + + +// CHECK-LABEL: @transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> +func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> +func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> +func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> + return %0 : vector<1x1x4xi8> +} + +// ----- + +// CHECK-LABEL: @negative_broadcast_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> + return %0 : vector<2x3x4xi8> +} + +// ----- + +// CHECK-LABEL: @extract_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// 
CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> +func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { + %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// In this example, arg1 might be negative indicating poison. +// CHECK-LABEL: @negative_extract_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { + %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index b73c40adcffa7..aa97d6fc5dc69 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks VscaleRange{vscaleMin, vscaleMax}); } }; + +struct TestConvertToShapeCast + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast) + + TestConvertToShapeCast() = default; + + StringRef getArgument() const final { return "test-convert-to-shape-cast"; } + StringRef getDescription() const final { + return "Test conversion to shape_cast of semantically equivalent ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateConvertToShapeCastPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; } // namespace namespace mlir { @@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir