diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a11dbe2589205..9461ba02dd546 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1696,59 +1696,107 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } -/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. +/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend +/// 1s, are considered to be 'broadcastlike'. +static bool isBroadcastLike(Operation *op) { + if (isa(op)) + return true; + + auto shapeCast = dyn_cast(op); + if (!shapeCast) + return false; + + // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). + // Condition 1: dst has hight rank. + // Condition 2: src shape is a suffix of dst shape. + // + // Note that checking that dst shape has a prefix of 1s is not sufficient, + // for example (2,3) -> (1,3,2) is not broadcast-like. + VectorType srcType = shapeCast.getSourceVectorType(); + ArrayRef srcShape = srcType.getShape(); + uint64_t srcRank = srcType.getRank(); + ArrayRef dstShape = shapeCast.getType().getShape(); + return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape; +} + +/// Fold extract(broadcast(X)) to either extract(X) or just X. +/// +/// Example: +/// +/// broadcast extract +/// (3, 4) --------> (2, 3, 4) ------> (4) +/// +/// becomes +/// extract +/// (3,4) ---------------------------> (4) +/// +/// +/// The variable names used in this implementation use names which correspond to +/// the above shapes as, +/// +/// - (3, 4) is `input` shape. +/// - (2, 3, 4) is `broadcast` shape. +/// - (4) is `extract` shape. +/// +/// This folding is possible when the suffix of `input` shape is the same as +/// `extract` shape. static Value foldExtractFromBroadcast(ExtractOp extractOp) { + Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) + if (!defOp || !isBroadcastLike(defOp)) return Value(); - Value source = defOp->getOperand(0); - if (extractOp.getType() == source.getType()) - return source; - auto getRank = [](Type type) { - return llvm::isa(type) ? llvm::cast(type).getRank() - : 0; - }; + Value input = defOp->getOperand(0); - // If splat or broadcast from a scalar, just return the source scalar. - unsigned broadcastSrcRank = getRank(source.getType()); - if (broadcastSrcRank == 0 && source.getType() == extractOp.getType()) - return source; + // Replace extract(broadcast(X)) with X + if (extractOp.getType() == input.getType()) + return input; - unsigned extractResultRank = getRank(extractOp.getType()); - if (extractResultRank > broadcastSrcRank) - return Value(); - // Check that the dimension of the result haven't been broadcasted. - auto extractVecType = llvm::dyn_cast(extractOp.getType()); - auto broadcastVecType = llvm::dyn_cast(source.getType()); - if (extractVecType && broadcastVecType && - extractVecType.getShape() != - broadcastVecType.getShape().take_back(extractResultRank)) + // Get required types and ranks in the chain + // input -> broadcast -> extract + auto inputType = llvm::dyn_cast(input.getType()); + auto extractType = llvm::dyn_cast(extractOp.getType()); + unsigned inputRank = inputType ? inputType.getRank() : 0; + unsigned broadcastRank = extractOp.getSourceVectorType().getRank(); + unsigned extractRank = extractType ? extractType.getRank() : 0; + + // Cannot do without the broadcast if overall the rank increases. + if (extractRank > inputRank) return Value(); - auto broadcastOp = cast(defOp); - int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank(); + // Proof by contradiction that, at this point, input is a vector. + // Suppose input is a scalar. + // ==> inputRank is 0. + // ==> extractRank is 0 (because extractRank <= inputRank). + // ==> extract is scalar (because rank-0 extraction is always scalar). + // ==> input and extract are scalar, so same type. + // ==> returned early (check same type). + // Contradiction! + assert(inputType && "input must be a vector type because of previous checks"); + ArrayRef inputShape = inputType.getShape(); + + // In the case where there is a broadcast dimension in the suffix, it is not + // possible to replace extract(broadcast(X)) with extract(X). Example: + // + // broadcast extract + // (1) --------> (3,4) ------> (4) + if (extractType && + extractType.getShape() != inputShape.take_back(extractRank)) + return Value(); - // Detect all the positions that come from "dim-1" broadcasting. - // These dimensions correspond to "dim-1" broadcasted dims; set the mathching - // extract position to `0` when extracting from the source operand. - llvm::SetVector broadcastedUnitDims = - broadcastOp.computeBroadcastedUnitDims(); - SmallVector extractPos(extractOp.getMixedPosition()); - OpBuilder b(extractOp.getContext()); - int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; - for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i) - if (broadcastedUnitDims.contains(i)) - extractPos[i] = b.getIndexAttr(0); - // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the - // matching extract position when extracting from the source operand. - int64_t rankDiff = broadcastSrcRank - extractResultRank; - extractPos.erase(extractPos.begin(), - std::next(extractPos.begin(), extractPos.size() - rankDiff)); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - auto [staticPos, dynPos] = decomposeMixedValues(extractPos); + // Replace extract(broadcast(X)) with extract(X). + // First, determine the new extraction position. + unsigned deltaOverall = inputRank - extractRank; + unsigned deltaBroadcast = broadcastRank - inputRank; + SmallVector oldPositions = extractOp.getMixedPosition(); + SmallVector newPositions(deltaOverall); + IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0); + for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) { + newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast]; + } + auto [staticPos, dynPos] = decomposeMixedValues(newPositions); extractOp->setOperands( - llvm::to_vector(llvm::concat(ValueRange(source), dynPos))); + llvm::to_vector(llvm::concat(ValueRange(input), dynPos))); extractOp.setStaticPosition(staticPos); return extractOp.getResult(); } @@ -2193,32 +2241,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { + Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) + VectorType outType = dyn_cast(extractOp.getType()); + if (!defOp || !isBroadcastLike(defOp) || !outType) return failure(); Value source = defOp->getOperand(0); - if (extractOp.getType() == source.getType()) - return failure(); - auto getRank = [](Type type) { - return llvm::isa(type) - ? llvm::cast(type).getRank() - : 0; - }; - unsigned broadcastSrcRank = getRank(source.getType()); - unsigned extractResultRank = getRank(extractOp.getType()); - // We only consider the case where the rank of the source is less than or - // equal to the rank of the extract dst. The other cases are handled in the - // folding patterns. - if (extractResultRank < broadcastSrcRank) - return failure(); - // For scalar result, the input can only be a rank-0 vector, which will - // be handled by the folder. - if (extractResultRank == 0) + if (isBroadcastableTo(source.getType(), outType) != + BroadcastableToResult::Success) return failure(); - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), source); + rewriter.replaceOpWithNewOp(extractOp, outType, source); return success(); } }; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 65b73375831da..c7d9074b853f9 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -764,10 +764,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32 // ----- -// CHECK-LABEL: fold_extract_splat +// CHECK-LABEL: fold_extract_scalar_from_splat // CHECK-SAME: %[[A:.*]]: f32 // CHECK: return %[[A]] : f32 -func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { +func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { %b = vector.splat %a : vector<1x2x4xf32> %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> return %r : f32 @@ -775,6 +775,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in // ----- +// CHECK-LABEL: fold_extract_vector_from_splat +// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32> +func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting // CHECK-SAME: %[[A:.*]]: vector<2x1xf32> // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index @@ -804,6 +814,35 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, // ----- +// Test where the shape_cast is broadcast-like. +// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank +// CHECK-SAME: %[[A:.*]]: vector<2x4xf32> +// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index +// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32> +// CHECK: return %[[B]] : vector<4xf32> +func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, + %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + +// Test where the shape_cast is not broadcast-like, even though it prepends 1s. +// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract +// CHECK-NEXT: return +func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, + %idx0 : index, %idx1 : index) -> vector<2xf32> { + %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32> + return %r : vector<2xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_to_higher_rank // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32> // CHECK: return %[[B]] : vector<4xf32> @@ -831,6 +870,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde // ----- +// CHECK-LABEL: fold_extract_broadcastlike_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<1xf32> +// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32> +// CHECK: return %[[R]] : vector<1x1xf32> +func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index) + -> vector<1x1xf32> { + %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32> + %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32> + return %r : vector<1x1xf32> +} + +// ----- + // CHECK-LABEL: @fold_extract_shuffle // CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32> // CHECK-NOT: vector.shuffle @@ -1549,7 +1601,7 @@ func.func @negative_store_to_load_tensor_memref( %arg0 : tensor, %arg1 : memref, %v0 : vector<4x2xf32> - ) -> vector<4x2xf32> + ) -> vector<4x2xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 @@ -1606,7 +1658,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor< // CHECK: vector.transfer_read func.func @negative_store_to_load_tensor_broadcast_masked( %arg0 : tensor, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>) - -> vector<4x2x6xf32> + -> vector<4x2x6xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32