Skip to content

Commit c129bff

Browse files
committed
extend to broadcastlike, code simplifications
1 parent b822a32 commit c129bff

File tree

2 files changed

+101
-68
lines changed

2 files changed

+101
-68
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 59 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
16961696
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
16971697
}
16981698

1699+
/// All BroadcastOps and SplatOps, and ShapeCastOps that only prepends 1s, are
1700+
/// considered 'broadcastlike'.
1701+
static bool isBroadcastLike(Operation *op) {
1702+
if (isa<BroadcastOp, SplatOp>(op))
1703+
return true;
1704+
1705+
auto shapeCast = dyn_cast<ShapeCastOp>(op);
1706+
if (!shapeCast)
1707+
return false;
1708+
1709+
VectorType srcType = shapeCast.getSourceVectorType();
1710+
ArrayRef<int64_t> srcShape = srcType.getShape();
1711+
uint64_t srcRank = srcType.getRank();
1712+
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1713+
return dstShape.size() <= srcRank && dstShape.take_back(srcRank) == srcShape;
1714+
}
1715+
16991716
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
17001717
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1701-
Operation *defOp = extractOp.getVector().getDefiningOp();
1702-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1718+
1719+
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
1720+
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
17031721
return Value();
17041722

1705-
Value source = defOp->getOperand(0);
1706-
if (extractOp.getType() == source.getType())
1707-
return source;
1708-
auto getRank = [](Type type) {
1709-
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1710-
: 0;
1711-
};
1723+
Value src = broadcastLikeOp->getOperand(0);
1724+
1725+
// Replace extract(broadcast(X)) with X
1726+
if (extractOp.getType() == src.getType())
1727+
return src;
17121728

1713-
// If splat or broadcast from a scalar, just return the source scalar.
1714-
unsigned broadcastSrcRank = getRank(source.getType());
1715-
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1716-
return source;
1729+
// Get required types and ranks in the chain
1730+
// src -> broadcastDst -> dst
1731+
auto srcType = llvm::dyn_cast<VectorType>(src.getType());
1732+
auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
1733+
unsigned srcRank = srcType ? srcType.getRank() : 0;
1734+
unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
1735+
unsigned dstRank = dstType ? dstType.getRank() : 0;
17171736

1718-
unsigned extractResultRank = getRank(extractOp.getType());
1719-
if (extractResultRank > broadcastSrcRank)
1737+
// Cannot do without the broadcast if overall the rank increases.
1738+
if (dstRank > srcRank)
17201739
return Value();
1721-
// Check that the dimension of the result haven't been broadcasted.
1722-
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1723-
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1724-
if (extractVecType && broadcastVecType &&
1725-
extractVecType.getShape() !=
1726-
broadcastVecType.getShape().take_back(extractResultRank))
1740+
1741+
assert(srcType && "src must be a vector type because of previous checks");
1742+
1743+
ArrayRef<int64_t> srcShape = srcType.getShape();
1744+
if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
17271745
return Value();
17281746

1729-
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1730-
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1747+
// Replace extract(broadcast(X)) with extract(X).
1748+
// First, determine the new extraction position.
1749+
unsigned deltaOverall = srcRank - dstRank;
1750+
unsigned deltaBroadcast = broadcastDstRank - srcRank;
17311751

1732-
// Detect all the positions that come from "dim-1" broadcasting.
1733-
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1734-
// extract position to `0` when extracting from the source operand.
1735-
llvm::SetVector<int64_t> broadcastedUnitDims =
1736-
broadcastOp.computeBroadcastedUnitDims();
1737-
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1738-
OpBuilder b(extractOp.getContext());
1739-
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1740-
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1741-
if (broadcastedUnitDims.contains(i))
1742-
extractPos[i] = b.getIndexAttr(0);
1743-
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1744-
// matching extract position when extracting from the source operand.
1745-
int64_t rankDiff = broadcastSrcRank - extractResultRank;
1746-
extractPos.erase(extractPos.begin(),
1747-
std::next(extractPos.begin(), extractPos.size() - rankDiff));
1748-
// OpBuilder is only used as a helper to build an I64ArrayAttr.
1749-
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1752+
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1753+
SmallVector<OpFoldResult> newPositions(deltaOverall);
1754+
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1755+
for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
1756+
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1757+
}
1758+
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
17501759
extractOp->setOperands(
1751-
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1760+
llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
17521761
extractOp.setStaticPosition(staticPos);
17531762
return extractOp.getResult();
17541763
}
@@ -2193,32 +2202,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21932202

21942203
LogicalResult matchAndRewrite(ExtractOp extractOp,
21952204
PatternRewriter &rewriter) const override {
2196-
Operation *defOp = extractOp.getVector().getDefiningOp();
2197-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2198-
return failure();
21992205

2200-
Value source = defOp->getOperand(0);
2201-
if (extractOp.getType() == source.getType())
2206+
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
2207+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2208+
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
22022209
return failure();
2203-
auto getRank = [](Type type) {
2204-
return llvm::isa<VectorType>(type)
2205-
? llvm::cast<VectorType>(type).getRank()
2206-
: 0;
2207-
};
2208-
unsigned broadcastSrcRank = getRank(source.getType());
2209-
unsigned extractResultRank = getRank(extractOp.getType());
2210-
// We only consider the case where the rank of the source is less than or
2211-
// equal to the rank of the extract dst. The other cases are handled in the
2212-
// folding patterns.
2213-
if (extractResultRank < broadcastSrcRank)
2214-
return failure();
2215-
// For scalar result, the input can only be a rank-0 vector, which will
2216-
// be handled by the folder.
2217-
if (extractResultRank == 0)
2210+
2211+
Value source = broadcastLikeOp->getOperand(0);
2212+
if (isBroadcastableTo(source.getType(), outType) !=
2213+
BroadcastableToResult::Success)
22182214
return failure();
22192215

2220-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2221-
extractOp, extractOp.getType(), source);
2216+
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
22222217
return success();
22232218
}
22242219
};

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -764,17 +764,27 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
764764

765765
// -----
766766

767-
// CHECK-LABEL: fold_extract_splat
767+
// CHECK-LABEL: fold_extract_scalar_from_splat
768768
// CHECK-SAME: %[[A:.*]]: f32
769769
// CHECK: return %[[A]] : f32
770-
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
770+
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
771771
%b = vector.splat %a : vector<1x2x4xf32>
772772
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
773773
return %r : f32
774774
}
775775

776776
// -----
777777

778+
// CHECK-LABEL: fold_extract_vector_from_splat
779+
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
780+
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
781+
%b = vector.splat %a : vector<1x2x4xf32>
782+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
783+
return %r : vector<4xf32>
784+
}
785+
786+
// -----
787+
778788
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
779789
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
780790
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
804814

805815
// -----
806816

817+
// Test where the shape_cast is broadcast-like.
818+
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
819+
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
820+
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
821+
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
822+
// CHECK: return %[[B]] : vector<4xf32>
823+
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
824+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
825+
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
826+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
827+
return %r : vector<4xf32>
828+
}
829+
830+
// -----
831+
807832
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
808833
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
809834
// CHECK: return %[[B]] : vector<4xf32>
@@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
831856

832857
// -----
833858

859+
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
860+
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
861+
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
862+
// CHECK: return %[[R]] : vector<1x1xf32>
863+
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
864+
-> vector<1x1xf32> {
865+
%s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
866+
%r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
867+
return %r : vector<1x1xf32>
868+
}
869+
870+
// -----
871+
834872
// CHECK-LABEL: @fold_extract_shuffle
835873
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
836874
// CHECK-NOT: vector.shuffle
@@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref(
15491587
%arg0 : tensor<?x?xf32>,
15501588
%arg1 : memref<?x?xf32>,
15511589
%v0 : vector<4x2xf32>
1552-
) -> vector<4x2xf32>
1590+
) -> vector<4x2xf32>
15531591
{
15541592
%c0 = arith.constant 0 : index
15551593
%cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
16061644
// CHECK: vector.transfer_read
16071645
func.func @negative_store_to_load_tensor_broadcast_masked(
16081646
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
1609-
-> vector<4x2x6xf32>
1647+
-> vector<4x2x6xf32>
16101648
{
16111649
%c0 = arith.constant 0 : index
16121650
%cf0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)