diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 862ed7bae1fbb..f032d76f54ece 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1691,10 +1691,36 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } +/// vector.splat, and vector.shape_cast that just prepends 1's are +/// special cases of vector.broadcast. This function returns true +/// if \p op is one of these operations. +static bool isBroadcastLike(Operation *op) { + + if (isa(op)) + return true; + + // a shape_cast which just prepends 1's is broadcast-like. + auto shapeCast = dyn_cast(op); + if (!shapeCast) + return false; + + ArrayRef dstShape = shapeCast.getType().getShape(); + ArrayRef srcShape = shapeCast.getSourceVectorType().getShape(); + + // A rank-reducing shape_cast cannot be broadcast-like. + if (srcShape.size() > dstShape.size()) + return false; + + bool isSuffix = (srcShape == dstShape.take_back(srcShape.size())); + return isSuffix; +} + /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. -static Value foldExtractFromBroadcast(ExtractOp extractOp) { +static Value foldExtractFromBroadcastLike(ExtractOp extractOp) { + Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) + + if (!defOp || !isBroadcastLike(defOp)) return Value(); Value source = defOp->getOperand(0); @@ -1721,14 +1747,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { broadcastVecType.getShape().take_back(extractResultRank)) return Value(); - auto broadcastOp = cast(defOp); - int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank(); + assert(defOp->getNumResults() == 1 && "all broadcast-like ops have 1 result"); + auto dstType = dyn_cast(defOp->getResult(0).getType()); + assert(dstType && "all broadcast-like ops have vector results"); + + int64_t broadcastDstRank = dstType.getRank(); // Detect all the positions that come from "dim-1" broadcasting. - // These dimensions correspond to "dim-1" broadcasted dims; set the mathching + // These dimensions correspond to "dim-1" broadcasted dims; set the matching // extract position to `0` when extracting from the source operand. - llvm::SetVector broadcastedUnitDims = - broadcastOp.computeBroadcastedUnitDims(); + auto broadcastedUnitDims = [&]() -> llvm::SetVector { + if (auto broadcastOp = dyn_cast(defOp)) { + return broadcastOp.computeBroadcastedUnitDims(); + } + return {}; + }(); + SmallVector extractPos(extractOp.getMixedPosition()); OpBuilder b(extractOp.getContext()); int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; @@ -2163,7 +2197,7 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { return getResult(); if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) return res; - if (auto res = foldExtractFromBroadcast(*this)) + if (auto res = foldExtractFromBroadcastLike(*this)) return res; if (auto res = foldExtractFromShuffle(*this)) return res; @@ -2181,7 +2215,7 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { namespace { -// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. +// Pattern to rewrite a ExtractOp(broadcast-like) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2189,7 +2223,8 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) + + if (!defOp || !isBroadcastLike(defOp)) return failure(); Value source = defOp->getOperand(0); @@ -2351,11 +2386,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } +/// BEFORE: +/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +struct ExtractToShapeCast final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + VectorType sourceType = extractOp.getSourceVectorType(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!outType) + return failure(); + + // Negative values in `position` indicates poison, which cannot be + // represented with a shape_cast + if (llvm::any_of(extractOp.getMixedPosition(), + [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) + return failure(); + + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(extractOp, outType, + extractOp.getVector()); + return success(); + } +}; + } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); } @@ -2867,13 +2932,36 @@ struct BroadcastFolder : public OpRewritePattern { return success(); } }; + +/// BEFORE: +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> +struct BroadcastToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::BroadcastOp broadcast, + PatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(broadcast.getSourceType()); + if (!sourceType) { + return rewriter.notifyMatchFailure( + broadcast, "source is a scalar, shape_cast doesn't support scalar"); + } + + VectorType outType = broadcast.getType(); + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(broadcast, outType, + broadcast.getSource()); + return success(); + } +}; } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - // BroadcastToShapeCast is not a default canonicalization, it is opt-in by - // calling `populateCastAwayVectorLeadingOneDimPatterns` - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// @@ -5991,10 +6079,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final } }; -/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either -/// i) Y = ShapeCast(X), or -/// ii) Y = Broadcast(X) -/// If both (i) and (ii) are possible, (i) is chosen. +/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X) class ShapeCastBroadcastFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6009,22 +6094,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { auto srcVectorType = dyn_cast(broadcastOp.getSourceType()); bool srcIsScalar = !srcVectorType; - // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). - // Example: - // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32> - // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32> - // to - // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32> - if (srcVectorType) { - if (srcVectorType.getNumElements() == - shapeCastOp.getResultVectorType().getNumElements()) { - rewriter.replaceOpWithNewOp( - shapeCastOp, shapeCastOp.getResultVectorType(), - broadcastOp.getSource()); - return success(); - } - } - // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X) // Example // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32> @@ -6233,7 +6302,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8> // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8> // - // Example of what NOT to fold: + // Example of what not to fold: // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> // if (getSourceVectorType() == getResultVectorType() && @@ -6359,32 +6428,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern { } }; -/// Folds transpose(shape_cast) into a new shape_cast. -class FoldTransposeShapeCast final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransposeOp transposeOp, - PatternRewriter &rewriter) const override { - auto shapeCastOp = - transposeOp.getVector().getDefiningOp(); - if (!shapeCastOp) - return failure(); - if (!isOrderPreserving(transposeOp)) - return failure(); - - VectorType resultType = transposeOp.getType(); - - // We don't need to check isValidShapeCast at this point, because it is - // guaranteed that merging the transpose into the the shape_cast is a valid - // shape_cast, because the transpose just inserts/removes ones. - - rewriter.replaceOpWithNewOp(transposeOp, resultType, - shapeCastOp.getSource()); - return success(); - } -}; - /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. @@ -6480,12 +6523,35 @@ class FoldTransposeBroadcast : public OpRewritePattern { } }; +/// BEFORE: +/// %0 = vector.transpose %arg0, [0, 2, 1] : +/// vector<2x1x2xf32> to vector<2x2x1xf32> +/// AFTER: +/// %0 = vector.shape_cast %arg0 : +/// vector<2x1x2xf32> to vector<2x2x1xf32> +struct TransposeToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransposeOp transpose, + PatternRewriter &rewriter) const override { + + if (!isOrderPreserving(transpose)) { + return rewriter.notifyMatchFailure( + transpose, "not order preserving, so not semantically a 'copy'"); + } + rewriter.replaceOpWithNewOp( + transpose, transpose.getType(), transpose.getVector()); + return success(); + } +}; + } // namespace void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 732e316c93381..3d12527b86283 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -11,7 +11,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern { vector::VectorTransposeLowering vectorTransposeLowering; }; -/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied -/// to 2D vectors with at least one unit dim. For example: -/// -/// Replace: -/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to -/// vector<1x4xi32> -/// with: -/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32> -/// -/// Source with leading unit dim (inverse) is also replaced. Unit dim must -/// be fixed. Non-unit dim can be scalable. -/// -/// TODO: This pattern was introduced specifically to help lower scalable -/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's -/// to cancel out) would be preferable: -/// -/// BEFORE: -/// %0 = some_op -/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32> -/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> -/// AFTER: -/// %0 = some_op -/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32> -/// -/// Given the context above, we may want to consider (re-)moving this pattern -/// at some later time. I am leaving it for now in case there are other users -/// that I am not aware of. -class Transpose2DWithUnitDimToShapeCast - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - Transpose2DWithUnitDimToShapeCast(MLIRContext *context, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit) {} - - LogicalResult matchAndRewrite(vector::TransposeOp op, - PatternRewriter &rewriter) const override { - Value input = op.getVector(); - VectorType resType = op.getResultVectorType(); - - // Set up convenience transposition table. - ArrayRef transp = op.getPermutation(); - - if (resType.getRank() == 2 && - ((resType.getShape().front() == 1 && - !resType.getScalableDims().front()) || - (resType.getShape().back() == 1 && - !resType.getScalableDims().back())) && - transp == ArrayRef({1, 0})) { - rewriter.replaceOpWithNewOp(op, resType, input); - return success(); - } - - return failure(); - } -}; - /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. /// If the strategy is Shuffle1D, it will be lowered to: /// vector.shape_cast 2D -> 1D @@ -511,8 +452,8 @@ class TransposeOp2DToShuffleLowering void mlir::vector::populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); + TransposeOp::getCanonicalizationPatterns(patterns, patterns.getContext()); + ShapeCastOp::getCanonicalizationPatterns(patterns, patterns.getContext()); patterns.add( vectorTransposeLowering, patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 6cdf576272ebc..a9a2fdccdd82f 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i // ----- -// The pass should do nothing (and not crash). -// CHECK-LABEL: @illegal_transpose_no_defining_source_op -func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> +// CHECK-LABEL: @transpose_no_defining_source_op +func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> { - // CHECK: vector.transpose + // CHECK: vector.shape_cast + // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32> %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> return %0 : vector<1x[4]xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 65b73375831da..b7c07399bc25e 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>, // ----- // CHECK-LABEL: transpose_3D_identity -// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) +// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) +// CHECK-NEXT: return [[ARG]] func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> { - // CHECK-NOT: transpose %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32> - // CHECK-NEXT: return [[ARG]] return %0 : vector<4x3x2xf32> } // ----- +// CHECK-LABEL: transpose_0D_identity +// CHECK-SAME: ([[ARG:%.*]]: vector) +// CHECK-NEXT: return [[ARG]] +func.func @transpose_0D_identity(%arg : vector) -> vector { + %0 = vector.transpose %arg, [] : vector to vector + return %0 : vector +} + +// ----- + // CHECK-LABEL: transpose_2D_sequence // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>) func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> { @@ -753,10 +762,21 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector, // ----- -// CHECK-LABEL: negative_fold_extract_broadcast +// CHECK-LABEL: negative_fold_partial_extract_broadcast +// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32> +// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32> +func.func @negative_fold_partial_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { + %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32> + %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: negative_fold_full_extract_broadcast // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32> -// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32> -func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { +// CHECK: vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32> +func.func @negative_fold_full_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32> %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32> return %r : vector<4xf32> @@ -764,10 +784,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32 // ----- -// CHECK-LABEL: fold_extract_splat +// CHECK-LABEL: fold_extract_scalar_splat // CHECK-SAME: %[[A:.*]]: f32 // CHECK: return %[[A]] : f32 -func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { +func.func @fold_extract_scalar_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { %b = vector.splat %a : vector<1x2x4xf32> %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> return %r : f32 @@ -775,12 +795,22 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in // ----- -// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting +// CHECK-LABEL: fold_extract_vector_splat +// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32> +func.func @fold_extract_vector_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: fold_extract_broadcast_21_to_124 // CHECK-SAME: %[[A:.*]]: vector<2x1xf32> // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32> // CHECK: return %[[R]] : f32 -func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, +func.func @fold_extract_broadcast_21_to_124(%a : vector<2x1xf32>, %idx : index, %idx1 : index, %idx2 : index) -> f32 { %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32> %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32> @@ -789,6 +819,20 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, // ----- +// CHECK-LABEL: fold_extract_broadcast_21_to_224 +// CHECK-SAME: %[[A:.*]]: vector<2x1xf32> +// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index +// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32> +// CHECK: return %[[R]] : f32 +func.func @fold_extract_broadcast_21_to_224(%a : vector<2x1xf32>, + %idx : index, %idx1 : index, %idx2 : index) -> f32 { + %b = vector.broadcast %a : vector<2x1xf32> to vector<2x2x4xf32> + %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<2x2x4xf32> + return %r : f32 +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_to_lower_rank // CHECK-SAME: %[[A:.*]]: vector<2x4xf32> // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index @@ -797,8 +841,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, // rank(extract_output) < rank(broadcast_input) func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, %idx0 : index, %idx1 : index) -> vector<4xf32> { - %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32> - %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32> return %r : vector<4xf32> } @@ -1549,7 +1593,7 @@ func.func @negative_store_to_load_tensor_memref( %arg0 : tensor, %arg1 : memref, %v0 : vector<4x2xf32> - ) -> vector<4x2xf32> + ) -> vector<4x2xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 @@ -1606,7 +1650,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor< // CHECK: vector.transfer_read func.func @negative_store_to_load_tensor_broadcast_masked( %arg0 : tensor, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>) - -> vector<4x2x6xf32> + -> vector<4x2x6xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 @@ -1920,12 +1964,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { // ----- -// CHECK-LABEL: func @insert_extract_to_broadcast +// CHECK-LABEL: func @insert_extract_to_shape_cast // CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) -// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32> -// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> +// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> // CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> -func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>, +func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32> %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> @@ -2277,7 +2321,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> { // CHECK-LABEL: func @shuffle_canonicalize_0d func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vector<1xi32> { - // CHECK: vector.broadcast %{{.*}} : vector to vector<1xi32> + // CHECK: vector.shape_cast %{{.*}} : vector to vector<1xi32> %shuffle = vector.shuffle %v0, %v1 [0] : vector, vector return %shuffle : vector<1xi32> } @@ -2764,9 +2808,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf // CHECK-LABEL: func.func @extract_from_broadcast func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> - - // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32> - // CHECK-NEXT: return %0 : vector<1xf32> + // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32> + // CHECK-NEXT: return %[[RES]] : vector<1xf32> %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> return %1: vector<1xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index fdab2a8918a2e..d5f96a8928770 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( // CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8> -// CHECK: return %[[EXTRACT]] : vector<2xi8> +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> +// CHECK: return %[[SC]] : vector<2xi8> func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir new file mode 100644 index 0000000000000..e249a6afcc993 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir @@ -0,0 +1,162 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// This file contains tests where there a vector.shape_cast gets canonicalized, or where a +// vector.shape_cast is the result of a canonicalization. Not all such tests must live in this file. + +// +---------------------------------------- +// Tests of BroadcastToShapeCast +// +---------------------------------------- + +// CHECK-LABEL: @broadcast_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> +func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> + return %0 : vector<1x1x4xi8> +} + +// ----- + +// broadcast can only be transformed to a shape_cast if the number of elements is +// unchanged by the broadcast +// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> + return %0 : vector<2x3x4xi8> +} + +// ----- + +// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar +// cannot be transformed to a shape_cast. +// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { + %0 = vector.broadcast %arg0 : i8 to vector<1xi8> + return %0 : vector<1xi8> +} + +// ----- + +// +---------------------------------------- +// Tests of TransposeToShapeCast +// +---------------------------------------- + +// In this test, the permutation maps the non-unit dimensions (0 and 2) as follows: +// 0 -> 0 +// 2 -> 1 +// Because 0 < 1, this permutation is order preserving and effectively a shape_cast. +// CHECK-LABEL: @transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> +func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: +// 1 -> 0 +// 2 -> 4 +// Because 0 < 4, this permutation is order preserving and effectively a shape_cast. +// CHECK-LABEL: @shape_cast_of_transpose +// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> +// CHECK: return %[[SHAPE_CAST]] +func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> { + %0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> + return %0 : vector<4x1x1x1x4xi8> +} + +// ----- + +// Scalable dimensions should be treated as non-unit dimensions. +// CHECK-LABEL: @transpose_scalable_unit +// CHECK-NOT: shape_cast +func.func @transpose_scalable_unit(%arg : vector<[1]x4xi8>) -> vector<4x[1]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> + return %0 : vector<4x[1]xi8> +} + +// ----- + +// In this test, the mapping of non-unit dimensions (1 and 2) is as follows: +// 1 -> 2 +// 2 -> 1 +// As this is not increasing (2 > 1), this transpose is not order +// preserving and cannot be treated as a shape_cast. +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<1x4x4x1xi8> { + %0 = vector.transpose %arg, [0, 2, 1, 3] + : vector<1x4x4x1xi8> to vector<1x4x4x1xi8> + return %0 : vector<1x4x4x1xi8> +} + +// ----- + +// CHECK-LABEL: @shape_cast_of_transpose_scalable +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return +func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> + %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> + return %1 : vector<[4]xi8> +} + +// ----- + +// CHECK-LABEL: @transpose_of_shape_cast_scalable +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return +func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { + %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8> + %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8> + return %1 : vector<[4]x1xi8> +} + +// ----- + +// A test where a transpose cannot be transformed to a shape_cast because it is not order +// preserving +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> +func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// +---------------------------------------- +// Tests of ExtractToShapeCast +// +---------------------------------------- + +// CHECK-LABEL: @extract_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> +func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { + %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// In this example, arg1 might be negative indicating poison. +// CHECK-LABEL: @negative_extract_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { + %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir index 5011d8b2b2ef6..97a8a9a9c2597 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -387,6 +387,66 @@ func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi return %0 : vector<2x15xi8> } +// **--------------------------------------------------------** // +// Tests where the shape_cast is equivalent to a transpose +// **--------------------------------------------------------** // + +// CHECK-LABEL: func @transpose102_1x8x8xf32 +// CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> +func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { + %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32> + return %0 : vector<8x1x8xf32> +} + +// CHECK-LABEL: func @transpose102_8x1x8xf32 +// CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> +func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { + %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32> + return %0 : vector<1x8x8xf32> +} + +// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( +// CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32> +// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32> +func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> { + // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed! + %0 = vector.shape_cast %arg0 : vector<2x1x8x4xf32> to vector<1x2x8x4xf32> + return %0 : vector<1x2x8x4xf32> +} + + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %f = transform.structured.match ops{["func.func"]} in %module_op diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 511ab70f35086..7886fba6c80c4 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor) -> vector<1xf32> { %f0 = arith.constant 0.0 : f32 // CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][] -// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector to vector<1xf32> +// CHECK: %[[V:.*]] = vector.shape_cast %[[S]] : vector to vector<1xf32> %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} : tensor, vector<1xf32> @@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor( %c0 = arith.constant 0 : index %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor - // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32> - // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32> - // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor + // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32> + // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor return %res : tensor } @@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref( %c0 = arith.constant 0 : index vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref - // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32> - // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32> - // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref + // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32> + // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref return } diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index a730f217f027d..e5dfb0645eeb9 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -23,55 +23,21 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { // CHECK-LABEL: func @transpose102_1x8x8xf32 func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { - // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK: vector.shape_cast %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } // CHECK-LABEL: func @transpose102_8x1x8xf32 func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> + // CHECK: vector.shape_cast %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( +// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> { - // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed! - // CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32> - // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32> + // CHECK: vector.shape_cast %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32> return %0 : vector<1x2x8x4xf32> } diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 38771f2593449..ba47799729f1d 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1311,8 +1311,8 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) { // CHECK-PROP-DAG: %[[THREADID:.*]] = gpu.thread_id x // CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]] // CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}] -// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32> -// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex> +// CHECK-PROP: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[GATHER]] : vector<1x64xi32> to vector<64xi32> +// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[SHAPE_CAST]] : vector<64xi32> to vector<64xindex> // CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex> // CHECK-PROP: gpu.yield %[[EXTRACTELT]] : index // CHECK-PROP: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]] @@ -1348,8 +1348,8 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1 // CHECK-PROP-LABEL: func @dont_fold_vector_broadcast( // CHECK-PROP: %[[r:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>) // CHECK-PROP: %[[some_def:.*]] = "some_def" -// CHECK-PROP: %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32> -// CHECK-PROP: gpu.yield %[[broadcast]] : vector<1x64xf32> +// CHECK-PROP: %[[shape_cast:.*]] = vector.shape_cast %[[some_def]] : vector<64xf32> to vector<1x64xf32> +// CHECK-PROP: gpu.yield %[[shape_cast]] : vector<1x64xf32> // CHECK-PROP: vector.print %[[r]] : vector<1x2xf32> func.func @dont_fold_vector_broadcast(%laneid: index) { %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {