Skip to content

[mlir][vector] Folder: shape_cast(extract) -> extract #146368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 97 additions & 63 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,59 +1696,107 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// Return true if `op` is 'broadcastlike'. All BroadcastOps and SplatOps, as
/// well as ShapeCastOps that only prepend 1s (e.g. (2,3) -> (1,1,2,3)), are
/// considered to be broadcastlike.
static bool isBroadcastLike(Operation *op) {
  // Broadcast and splat ops are broadcastlike by definition.
  if (isa<BroadcastOp, SplatOp>(op))
    return true;

  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  if (!shapeCast)
    return false;

  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
  // Condition 1: dst rank is >= src rank.
  // Condition 2: src shape is a suffix of dst shape.
  //
  // Note that checking that dst shape has a prefix of 1s is not sufficient,
  // for example (2,3) -> (1,3,2) is not broadcast-like.
  VectorType srcType = shapeCast.getSourceVectorType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  uint64_t srcRank = srcType.getRank();
  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
}

/// Fold extract(broadcast(X)) to either extract(X) or just X.
///
/// Example:
///
/// broadcast extract
/// (3, 4) --------> (2, 3, 4) ------> (4)
///
/// becomes
/// extract
/// (3,4) ---------------------------> (4)
///
///
/// The variable names used in this implementation use names which correspond to
/// the above shapes as,
///
/// - (3, 4) is `input` shape.
/// - (2, 3, 4) is `broadcast` shape.
/// - (4) is `extract` shape.
///
/// This folding is possible when the suffix of `input` shape is the same as
/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {

Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
if (!defOp || !isBroadcastLike(defOp))
return Value();

Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
auto getRank = [](Type type) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};
Value input = defOp->getOperand(0);

// If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
return source;
// Replace extract(broadcast(X)) with X
if (extractOp.getType() == input.getType())
return input;

unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank > broadcastSrcRank)
return Value();
// Check that the dimensions of the result haven't been broadcast.
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
if (extractVecType && broadcastVecType &&
extractVecType.getShape() !=
broadcastVecType.getShape().take_back(extractResultRank))
// Get required types and ranks in the chain
// input -> broadcast -> extract
auto inputType = llvm::dyn_cast<VectorType>(input.getType());
auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
unsigned inputRank = inputType ? inputType.getRank() : 0;
unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
unsigned extractRank = extractType ? extractType.getRank() : 0;

// Cannot remove the broadcast if the overall rank increases.
if (extractRank > inputRank)
return Value();

auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
// Proof by contradiction that, at this point, input is a vector.
// Suppose input is a scalar.
// ==> inputRank is 0.
// ==> extractRank is 0 (because extractRank <= inputRank).
// ==> extract is scalar (because rank-0 extraction is always scalar).
// ==> input and extract are scalar, so same type.
// ==> returned early (check same type).
// Contradiction!
assert(inputType && "input must be a vector type because of previous checks");
ArrayRef<int64_t> inputShape = inputType.getShape();

// In the case where there is a broadcast dimension in the suffix, it is not
// possible to replace extract(broadcast(X)) with extract(X). Example:
//
// broadcast extract
// (1) --------> (3,4) ------> (4)
if (extractType &&
extractType.getShape() != inputShape.take_back(extractRank))
return Value();

// Detect all the positions that come from "dim-1" broadcasting.
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
OpBuilder b(extractOp.getContext());
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
extractPos[i] = b.getIndexAttr(0);
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
// matching extract position when extracting from the source operand.
int64_t rankDiff = broadcastSrcRank - extractResultRank;
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
// OpBuilder is only used as a helper to build an I64ArrayAttr.
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
// Replace extract(broadcast(X)) with extract(X).
// First, determine the new extraction position.
unsigned deltaOverall = inputRank - extractRank;
unsigned deltaBroadcast = broadcastRank - inputRank;
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
SmallVector<OpFoldResult> newPositions(deltaOverall);
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
}
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
Expand Down Expand Up @@ -2193,32 +2241,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {

LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {

Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!defOp || !isBroadcastLike(defOp) || !outType)
return failure();

Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return failure();
auto getRank = [](Type type) {
return llvm::isa<VectorType>(type)
? llvm::cast<VectorType>(type).getRank()
: 0;
};
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
// We only consider the case where the rank of the source is less than or
// equal to the rank of the extract dst. The other cases are handled in the
// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
// For scalar result, the input can only be a rank-0 vector, which will
// be handled by the folder.
if (extractResultRank == 0)
if (isBroadcastableTo(source.getType(), outType) !=
BroadcastableToResult::Success)
return failure();

rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
return success();
}
};
Expand Down
60 changes: 56 additions & 4 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -764,17 +764,27 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32

// -----

// CHECK-LABEL: fold_extract_splat
// CHECK-LABEL: fold_extract_scalar_from_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
// Extracting a scalar element from a splat folds to the splatted scalar.
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
  %b = vector.splat %a : vector<1x2x4xf32>
  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
  return %r : f32
}

// -----

// Extracting a vector slice from a splat canonicalizes to a direct
// scalar-to-vector broadcast of the splatted value.
// CHECK-LABEL: fold_extract_vector_from_splat
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
  %b = vector.splat %a : vector<1x2x4xf32>
  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
  return %r : vector<4xf32>
}

// -----

// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
Expand Down Expand Up @@ -804,6 +814,35 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,

// -----

// Test where the shape_cast is broadcast-like: (2,4) -> (1,2,4) only prepends
// a 1, so extract(shape_cast(%a)) folds to an extract directly from %a.
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
    %idx0 : index, %idx1 : index) -> vector<4xf32> {
  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
  return %r : vector<4xf32>
}

// -----

// Test where the shape_cast is not broadcast-like, even though it prepends 1s:
// (2,4) -> (1,4,2) also transposes the trailing dims, so the fold must not
// fire and both ops remain.
// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract
// CHECK-NEXT: return
func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
    %idx0 : index, %idx1 : index) -> vector<2xf32> {
  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
  %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
  return %r : vector<2xf32>
}

// -----

// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
Expand Down Expand Up @@ -831,6 +870,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde

// -----

// A shape_cast (1) -> (1,1,1) only prepends 1s, so it is broadcast-like and
// extract(shape_cast(%a)) canonicalizes to a broadcast of %a.
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[R]] : vector<1x1xf32>
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
    -> vector<1x1xf32> {
  %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
  %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
  return %r : vector<1x1xf32>
}

// -----

// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle
Expand Down Expand Up @@ -1549,7 +1601,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
) -> vector<4x2xf32>
) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
Expand Down Expand Up @@ -1606,7 +1658,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-> vector<4x2x6xf32>
-> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
Expand Down