[mlir][vector] Folder: shape_cast(extract) -> extract #146368

Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

In #140583 more shape_cast ops will appear. Specifically, broadcasts that just prepend 1s become shape_cast ops (i.e. volume-preserving broadcasts are canonicalized to shape_casts). This PR ensures that broadcast-like shape_cast ops fold at least as well as broadcast ops.

Full diff: https://github.com/llvm/llvm-project/pull/146368.diff

2 Files Affected:
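For illustration, a minimal sketch of the canonicalization from #140583 that motivates this change (the shapes here are illustrative, not taken from the patch):

// A volume-preserving broadcast, i.e. one that only prepends 1s ...
%b = vector.broadcast %v : vector<2x4xf32> to vector<1x2x4xf32>
// ... is canonicalized by #140583 to the equivalent shape_cast:
%s = vector.shape_cast %v : vector<2x4xf32> to vector<1x2x4xf32>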
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a11dbe2589205..e4da65252c6e3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
+/// All BroadcastOps and SplatOps, and ShapeCastOps that only prepends 1s, are
+/// considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+ if (isa<BroadcastOp, SplatOp>(op))
+ return true;
+
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
+ if (!shapeCast)
+ return false;
+
+ VectorType srcType = shapeCast.getSourceVectorType();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ uint64_t srcRank = srcType.getRank();
+ ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+ return dstShape.size() <= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
return Value();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
- return source;
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
+ Value src = broadcastLikeOp->getOperand(0);
+
+ // Replace extract(broadcast(X)) with X
+ if (extractOp.getType() == src.getType())
+ return src;
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
- return source;
+ // Get required types and ranks in the chain
+ // src -> broadcastDst -> dst
+ auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+ unsigned srcRank = srcType ? srcType.getRank() : 0;
+ unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+ unsigned dstRank = dstType ? dstType.getRank() : 0;
- unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank > broadcastSrcRank)
+ // Cannot do without the broadcast if overall the rank increases.
+ if (dstRank > srcRank)
return Value();
- // Check that the dimension of the result haven't been broadcasted.
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
- if (extractVecType && broadcastVecType &&
- extractVecType.getShape() !=
- broadcastVecType.getShape().take_back(extractResultRank))
+
+ assert(srcType && "src must be a vector type because of previous checks");
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
return Value();
- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
- int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+ // Replace extract(broadcast(X)) with extract(X).
+ // First, determine the new extraction position.
+ unsigned deltaOverall = srcRank - dstRank;
+ unsigned deltaBroadcast = broadcastDstRank - srcRank;
- // Detect all the positions that come from "dim-1" broadcasting.
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
- // extract position to `0` when extracting from the source operand.
- llvm::SetVector<int64_t> broadcastedUnitDims =
- broadcastOp.computeBroadcastedUnitDims();
- SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
- OpBuilder b(extractOp.getContext());
- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
- for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
- if (broadcastedUnitDims.contains(i))
- extractPos[i] = b.getIndexAttr(0);
- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
- // matching extract position when extracting from the source operand.
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
- extractPos.erase(extractPos.begin(),
- std::next(extractPos.begin(), extractPos.size() - rankDiff));
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
- auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+ SmallVector<OpFoldResult> newPositions(deltaOverall);
+ IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+ for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+ }
+ auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
- llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+ llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
@@ -2193,32 +2202,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
- return failure();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
return failure();
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type)
- ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
- unsigned broadcastSrcRank = getRank(source.getType());
- unsigned extractResultRank = getRank(extractOp.getType());
- // We only consider the case where the rank of the source is less than or
- // equal to the rank of the extract dst. The other cases are handled in the
- // folding patterns.
- if (extractResultRank < broadcastSrcRank)
- return failure();
- // For scalar result, the input can only be a rank-0 vector, which will
- // be handled by the folder.
- if (extractResultRank == 0)
+
+ Value source = broadcastLikeOp->getOperand(0);
+ if (isBroadcastableTo(source.getType(), outType) !=
+ BroadcastableToResult::Success)
return failure();
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- extractOp, extractOp.getType(), source);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..350233d1f7969 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -764,10 +764,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
@@ -775,6 +775,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// -----
+// CHECK-LABEL: fold_extract_vector_from_splat
+// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+ %b = vector.splat %a : vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// -----
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+ %idx0 : index, %idx1 : index) -> vector<4xf32> {
+ %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
@@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
// -----
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+ -> vector<1x1xf32> {
+ %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+ %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+ return %r : vector<1x1xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle
@@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
- ) -> vector<4x2xf32>
+ ) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
- -> vector<4x2x6xf32>
+ -> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
This makes sense. The old logic was quite convoluted, so replacing that with something simpler (while all tests pass) is great, thanks!
ArrayRef<int64_t> srcShape = srcType.getShape();
uint64_t srcRank = srcType.getRank();
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
return dstShape.size() <= srcRank && dstShape.take_back(srcRank) == srcShape;
Since this is looking for a shape_cast that prepends 1s, shouldn't the destination rank be larger than the source rank?

Also, since the comment above mentions "prepends 1s", I'd find this more intuitive:

llvm::all_of(dstShape.take_front(dstRank - srcRank), [](int64_t dim) { return dim == 1; })

This is a nit :)
That'd pass for (3, 2000) -> (1, 1, 2000, 3), which is not broadcast-like. The key is that it only prepends 1s.

I made this mistake in my first implementation! So it's probably worth adding a negative case, and making the comment clearer (I'm on it).
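A minimal sketch of such a negative case, using the shapes from the example above (the test name is hypothetical, and this assumes shape_cast permits this reshape, as the example implies):

// Not broadcast-like: the leading result dims are all 1s, but the trailing
// shape (2000, 3) differs from the source shape (3, 2000), so the extract
// must not fold through the shape_cast.
func.func @negative_fold_extract_shape_cast(%a : vector<3x2000xf32>,
    %idx0 : index, %idx1 : index) -> vector<2000x3xf32> {
  %b = vector.shape_cast %a : vector<3x2000xf32> to vector<1x1x2000x3xf32>
  %r = vector.extract %b[%idx0, %idx1] : vector<2000x3xf32> from vector<1x1x2000x3xf32>
  return %r : vector<2000x3xf32>
}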
> Since this is looking for a shape_cast that prepends 1s, shouldn't the destination rank be larger than the source rank?

Good spot! I introduced this error in an intermediate commit (apologies for squashing commits, I didn't expect any eyes on this until it passed CI). Fixed.
extractVecType.getShape() !=
    broadcastVecType.getShape().take_back(extractResultRank))

assert(srcType && "src must be a vector type because of previous checks");
Hm, this doesn't quite agree with:
unsigned srcRank = srcType ? srcType.getRank() : 0;
It's the intervening rank check which allows this assertion. The code is like:

if (extractOp.getType() == src.getType())
  return src;
[...]
unsigned srcRank = srcType ? srcType.getRank() : 0;
[...]
if (dstRank > srcRank)
  return Value();
[...]
assert(srcType && "src must be a vector type because of previous checks");

Suppose src is scalar at the point of assertion. Then srcRank is 0, so dstRank is 0. If dstRank is 0, then dst is scalar. If they're both scalar, we would have returned early (same types). Contradiction -- src is not scalar.

TBH this reasoning is probably too complicated, and could be replaced with an early return:

if (!srcType)
  return Value();
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
  return source;

// Get required types and ranks in the chain
// src -> broadcastDst -> dst
I guess ExtractOp is missing somewhere here?
assert(srcType && "src must be a vector type because of previous checks");

ArrayRef<int64_t> srcShape = srcType.getShape();
if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
Could you add a comment explaining what case this is?
Yes, I took the liberty to insert some code simplifications while adding this.
In #140583 more shape_cast ops will appear. Specifically, broadcasts that just prepend 1s become shape_cast ops (i.e. volume-preserving broadcasts are canonicalized to shape_casts). This PR ensures that broadcast-like shape_cast ops fold at least as well as broadcast ops.

This is done by modifying patterns that target broadcast ops to target 'broadcast-like' ops. No new patterns are added; the patterns that exist are just made to match on shape_casts where appropriate.

This PR also includes minor code simplifications: use isBroadcastableTo to simplify ExtractOpFromBroadcast, and simplify how broadcast dims are detected in foldExtractFromBroadcast. These are NFC.
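As a concrete before/after sketch of the new folding (based on the fold_extract_shape_cast_to_lower_rank test above):

// Before: extract from a shape_cast that only prepends a 1.
%s = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %s[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>

// After folding: the position indexing the prepended unit dim is dropped,
// and the extract applies directly to the shape_cast's source.
%r = vector.extract %a[%idx1] : vector<4xf32> from vector<2x4xf32>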