From b3877b8d6c90569c89eddddca63303fa0bf3e28a Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 30 Jun 2025 09:15:48 -0700 Subject: [PATCH 1/2] extend to broadcastlike, code simplifications --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 126 ++++++++++----------- mlir/test/Dialect/Vector/canonicalize.mlir | 46 +++++++- 2 files changed, 104 insertions(+), 68 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a11dbe2589205..ed616fb4d343b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1696,59 +1696,71 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } +/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends +/// 1s, are considered 'broadcastlike'. +static bool isBroadcastLike(Operation *op) { + if (isa(op)) + return true; + + auto shapeCast = dyn_cast(op); + if (!shapeCast) + return false; + + // Check that it just prepends 1s, like (2,3) -> (1,1,2,3). + // Condition 1: dst has hight rank. + // Condition 2: src shape is a suffix of dst shape. + VectorType srcType = shapeCast.getSourceVectorType(); + ArrayRef srcShape = srcType.getShape(); + uint64_t srcRank = srcType.getRank(); + ArrayRef dstShape = shapeCast.getType().getShape(); + return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape; +} + /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. static Value foldExtractFromBroadcast(ExtractOp extractOp) { - Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) + + Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); + if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp)) return Value(); - Value source = defOp->getOperand(0); - if (extractOp.getType() == source.getType()) - return source; - auto getRank = [](Type type) { - return llvm::isa(type) ? llvm::cast(type).getRank() - : 0; - }; + Value src = broadcastLikeOp->getOperand(0); + + // Replace extract(broadcast(X)) with X + if (extractOp.getType() == src.getType()) + return src; - // If splat or broadcast from a scalar, just return the source scalar. - unsigned broadcastSrcRank = getRank(source.getType()); - if (broadcastSrcRank == 0 && source.getType() == extractOp.getType()) - return source; + // Get required types and ranks in the chain + // src -> broadcastDst -> dst + auto srcType = llvm::dyn_cast(src.getType()); + auto dstType = llvm::dyn_cast(extractOp.getType()); + unsigned srcRank = srcType ? srcType.getRank() : 0; + unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank(); + unsigned dstRank = dstType ? dstType.getRank() : 0; - unsigned extractResultRank = getRank(extractOp.getType()); - if (extractResultRank > broadcastSrcRank) + // Cannot do without the broadcast if overall the rank increases. + if (dstRank > srcRank) return Value(); - // Check that the dimension of the result haven't been broadcasted. - auto extractVecType = llvm::dyn_cast(extractOp.getType()); - auto broadcastVecType = llvm::dyn_cast(source.getType()); - if (extractVecType && broadcastVecType && - extractVecType.getShape() != - broadcastVecType.getShape().take_back(extractResultRank)) + + assert(srcType && "src must be a vector type because of previous checks"); + + ArrayRef srcShape = srcType.getShape(); + if (dstType && dstType.getShape() != srcShape.take_back(dstRank)) return Value(); - auto broadcastOp = cast(defOp); - int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank(); + // Replace extract(broadcast(X)) with extract(X). + // First, determine the new extraction position. + unsigned deltaOverall = srcRank - dstRank; + unsigned deltaBroadcast = broadcastDstRank - srcRank; - // Detect all the positions that come from "dim-1" broadcasting. - // These dimensions correspond to "dim-1" broadcasted dims; set the mathching - // extract position to `0` when extracting from the source operand. - llvm::SetVector broadcastedUnitDims = - broadcastOp.computeBroadcastedUnitDims(); - SmallVector extractPos(extractOp.getMixedPosition()); - OpBuilder b(extractOp.getContext()); - int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; - for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i) - if (broadcastedUnitDims.contains(i)) - extractPos[i] = b.getIndexAttr(0); - // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the - // matching extract position when extracting from the source operand. - int64_t rankDiff = broadcastSrcRank - extractResultRank; - extractPos.erase(extractPos.begin(), - std::next(extractPos.begin(), extractPos.size() - rankDiff)); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - auto [staticPos, dynPos] = decomposeMixedValues(extractPos); + SmallVector oldPositions = extractOp.getMixedPosition(); + SmallVector newPositions(deltaOverall); + IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0); + for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) { + newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast]; + } + auto [staticPos, dynPos] = decomposeMixedValues(newPositions); extractOp->setOperands( - llvm::to_vector(llvm::concat(ValueRange(source), dynPos))); + llvm::to_vector(llvm::concat(ValueRange(src), dynPos))); extractOp.setStaticPosition(staticPos); return extractOp.getResult(); } @@ -2193,32 +2205,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { - Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) - return failure(); - Value source = defOp->getOperand(0); - if (extractOp.getType() == source.getType()) + Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType) return failure(); - auto getRank = [](Type type) { - return llvm::isa(type) - ? llvm::cast(type).getRank() - : 0; - }; - unsigned broadcastSrcRank = getRank(source.getType()); - unsigned extractResultRank = getRank(extractOp.getType()); - // We only consider the case where the rank of the source is less than or - // equal to the rank of the extract dst. The other cases are handled in the - // folding patterns. - if (extractResultRank < broadcastSrcRank) - return failure(); - // For scalar result, the input can only be a rank-0 vector, which will - // be handled by the folder. - if (extractResultRank == 0) + + Value source = broadcastLikeOp->getOperand(0); + if (isBroadcastableTo(source.getType(), outType) != + BroadcastableToResult::Success) return failure(); - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), source); + rewriter.replaceOpWithNewOp(extractOp, outType, source); return success(); } }; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 65b73375831da..350233d1f7969 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -764,10 +764,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32 // ----- -// CHECK-LABEL: fold_extract_splat +// CHECK-LABEL: fold_extract_scalar_from_splat // CHECK-SAME: %[[A:.*]]: f32 // CHECK: return %[[A]] : f32 -func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { +func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { %b = vector.splat %a : vector<1x2x4xf32> %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> return %r : f32 @@ -775,6 +775,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in // ----- +// CHECK-LABEL: fold_extract_vector_from_splat +// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32> +func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting // CHECK-SAME: %[[A:.*]]: vector<2x1xf32> // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index @@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, // ----- +// Test where the shape_cast is broadcast-like. +// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank +// CHECK-SAME: %[[A:.*]]: vector<2x4xf32> +// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index +// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32> +// CHECK: return %[[B]] : vector<4xf32> +func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, + %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_to_higher_rank // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32> // CHECK: return %[[B]] : vector<4xf32> @@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde // ----- +// CHECK-LABEL: fold_extract_broadcastlike_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<1xf32> +// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32> +// CHECK: return %[[R]] : vector<1x1xf32> +func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index) + -> vector<1x1xf32> { + %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32> + %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32> + return %r : vector<1x1xf32> +} + +// ----- + // CHECK-LABEL: @fold_extract_shuffle // CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32> // CHECK-NOT: vector.shuffle @@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref( %arg0 : tensor, %arg1 : memref, %v0 : vector<4x2xf32> - ) -> vector<4x2xf32> + ) -> vector<4x2xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 @@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor< // CHECK: vector.transfer_read func.func @negative_store_to_load_tensor_broadcast_masked( %arg0 : tensor, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>) - -> vector<4x2x6xf32> + -> vector<4x2x6xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 From da8e03a560fc0850b280b39cb82e0df8e86a4fce Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 30 Jun 2025 11:32:03 -0700 Subject: [PATCH 2/2] improve comments, add test --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 92 +++++++++++++++------- mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++ 2 files changed, 78 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ed616fb4d343b..9461ba02dd546 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1696,8 +1696,8 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } -/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends -/// 1s, are considered 'broadcastlike'. +/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend +/// 1s, are considered to be 'broadcastlike'. static bool isBroadcastLike(Operation *op) { if (isa(op)) return true; @@ -1706,9 +1706,12 @@ static bool isBroadcastLike(Operation *op) { if (!shapeCast) return false; - // Check that it just prepends 1s, like (2,3) -> (1,1,2,3). + // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). // Condition 1: dst has hight rank. // Condition 2: src shape is a suffix of dst shape. + // + // Note that checking that dst shape has a prefix of 1s is not sufficient, + // for example (2,3) -> (1,3,2) is not broadcast-like. VectorType srcType = shapeCast.getSourceVectorType(); ArrayRef srcShape = srcType.getShape(); uint64_t srcRank = srcType.getRank(); @@ -1716,51 +1719,84 @@ static bool isBroadcastLike(Operation *op) { return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape; } -/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. +/// Fold extract(broadcast(X)) to either extract(X) or just X. +/// +/// Example: +/// +/// broadcast extract +/// (3, 4) --------> (2, 3, 4) ------> (4) +/// +/// becomes +/// extract +/// (3,4) ---------------------------> (4) +/// +/// +/// The variable names used in this implementation use names which correspond to +/// the above shapes as, +/// +/// - (3, 4) is `input` shape. +/// - (2, 3, 4) is `broadcast` shape. +/// - (4) is `extract` shape. +/// +/// This folding is possible when the suffix of `input` shape is the same as +/// `extract` shape. static Value foldExtractFromBroadcast(ExtractOp extractOp) { - Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); - if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp)) + Operation *defOp = extractOp.getVector().getDefiningOp(); + if (!defOp || !isBroadcastLike(defOp)) return Value(); - Value src = broadcastLikeOp->getOperand(0); + Value input = defOp->getOperand(0); // Replace extract(broadcast(X)) with X - if (extractOp.getType() == src.getType()) - return src; + if (extractOp.getType() == input.getType()) + return input; // Get required types and ranks in the chain - // src -> broadcastDst -> dst - auto srcType = llvm::dyn_cast(src.getType()); - auto dstType = llvm::dyn_cast(extractOp.getType()); - unsigned srcRank = srcType ? srcType.getRank() : 0; - unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank(); - unsigned dstRank = dstType ? dstType.getRank() : 0; + // input -> broadcast -> extract + auto inputType = llvm::dyn_cast(input.getType()); + auto extractType = llvm::dyn_cast(extractOp.getType()); + unsigned inputRank = inputType ? inputType.getRank() : 0; + unsigned broadcastRank = extractOp.getSourceVectorType().getRank(); + unsigned extractRank = extractType ? extractType.getRank() : 0; // Cannot do without the broadcast if overall the rank increases. - if (dstRank > srcRank) + if (extractRank > inputRank) return Value(); - assert(srcType && "src must be a vector type because of previous checks"); - - ArrayRef srcShape = srcType.getShape(); - if (dstType && dstType.getShape() != srcShape.take_back(dstRank)) + // Proof by contradiction that, at this point, input is a vector. + // Suppose input is a scalar. + // ==> inputRank is 0. + // ==> extractRank is 0 (because extractRank <= inputRank). + // ==> extract is scalar (because rank-0 extraction is always scalar). + // ==> input and extract are scalar, so same type. + // ==> returned early (check same type). + // Contradiction! + assert(inputType && "input must be a vector type because of previous checks"); + ArrayRef inputShape = inputType.getShape(); + + // In the case where there is a broadcast dimension in the suffix, it is not + // possible to replace extract(broadcast(X)) with extract(X). Example: + // + // broadcast extract + // (1) --------> (3,4) ------> (4) + if (extractType && + extractType.getShape() != inputShape.take_back(extractRank)) return Value(); // Replace extract(broadcast(X)) with extract(X). // First, determine the new extraction position. - unsigned deltaOverall = srcRank - dstRank; - unsigned deltaBroadcast = broadcastDstRank - srcRank; - + unsigned deltaOverall = inputRank - extractRank; + unsigned deltaBroadcast = broadcastRank - inputRank; SmallVector oldPositions = extractOp.getMixedPosition(); SmallVector newPositions(deltaOverall); IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0); - for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) { + for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) { newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast]; } auto [staticPos, dynPos] = decomposeMixedValues(newPositions); extractOp->setOperands( - llvm::to_vector(llvm::concat(ValueRange(src), dynPos))); + llvm::to_vector(llvm::concat(ValueRange(input), dynPos))); extractOp.setStaticPosition(staticPos); return extractOp.getResult(); } @@ -2206,12 +2242,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { - Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); + Operation *defOp = extractOp.getVector().getDefiningOp(); VectorType outType = dyn_cast(extractOp.getType()); - if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType) + if (!defOp || !isBroadcastLike(defOp) || !outType) return failure(); - Value source = broadcastLikeOp->getOperand(0); + Value source = defOp->getOperand(0); if (isBroadcastableTo(source.getType(), outType) != BroadcastableToResult::Success) return failure(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 350233d1f7969..c7d9074b853f9 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -829,6 +829,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, // ----- +// Test where the shape_cast is not broadcast-like, even though it prepends 1s. +// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract +// CHECK-NEXT: return +func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, + %idx0 : index, %idx1 : index) -> vector<2xf32> { + %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32> + return %r : vector<2xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_to_higher_rank // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32> // CHECK: return %[[B]] : vector<4xf32>