@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
1696
1696
llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
1697
1697
}
1698
1698
1699
+ // / All BroadcastOps and SplatOps, and ShapeCastOps that only prepends 1s, are
1700
+ // / considered 'broadcastlike'.
1701
+ static bool isBroadcastLike (Operation *op) {
1702
+ if (isa<BroadcastOp, SplatOp>(op))
1703
+ return true ;
1704
+
1705
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
1706
+ if (!shapeCast)
1707
+ return false ;
1708
+
1709
+ VectorType srcType = shapeCast.getSourceVectorType ();
1710
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1711
+ uint64_t srcRank = srcType.getRank ();
1712
+ ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
1713
+ return dstShape.size () <= srcRank && dstShape.take_back (srcRank) == srcShape;
1714
+ }
1715
+
1699
1716
// / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1700
1717
static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1701
- Operation *defOp = extractOp.getVector ().getDefiningOp ();
1702
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1718
+
1719
+ Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
1720
+ if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp))
1703
1721
return Value ();
1704
1722
1705
- Value source = defOp->getOperand (0 );
1706
- if (extractOp.getType () == source.getType ())
1707
- return source;
1708
- auto getRank = [](Type type) {
1709
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank ()
1710
- : 0 ;
1711
- };
1723
+ Value src = broadcastLikeOp->getOperand (0 );
1724
+
1725
+ // Replace extract(broadcast(X)) with X
1726
+ if (extractOp.getType () == src.getType ())
1727
+ return src;
1712
1728
1713
- // If splat or broadcast from a scalar, just return the source scalar.
1714
- unsigned broadcastSrcRank = getRank (source.getType ());
1715
- if (broadcastSrcRank == 0 && source.getType () == extractOp.getType ())
1716
- return source;
1729
+ // Get required types and ranks in the chain
1730
+ // src -> broadcastDst -> dst
1731
+ auto srcType = llvm::dyn_cast<VectorType>(src.getType ());
1732
+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1733
+ unsigned srcRank = srcType ? srcType.getRank () : 0 ;
1734
+ unsigned broadcastDstRank = extractOp.getSourceVectorType ().getRank ();
1735
+ unsigned dstRank = dstType ? dstType.getRank () : 0 ;
1717
1736
1718
- unsigned extractResultRank = getRank (extractOp. getType ());
1719
- if (extractResultRank > broadcastSrcRank )
1737
+ // Cannot do without the broadcast if overall the rank increases.
1738
+ if (dstRank > srcRank )
1720
1739
return Value ();
1721
- // Check that the dimension of the result haven't been broadcasted.
1722
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1723
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType ());
1724
- if (extractVecType && broadcastVecType &&
1725
- extractVecType.getShape () !=
1726
- broadcastVecType.getShape ().take_back (extractResultRank))
1740
+
1741
+ assert (srcType && " src must be a vector type because of previous checks" );
1742
+
1743
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1744
+ if (dstType && dstType.getShape () != srcShape.take_back (dstRank))
1727
1745
return Value ();
1728
1746
1729
- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1730
- int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
1747
+ // Replace extract(broadcast(X)) with extract(X).
1748
+ // First, determine the new extraction position.
1749
+ unsigned deltaOverall = srcRank - dstRank;
1750
+ unsigned deltaBroadcast = broadcastDstRank - srcRank;
1731
1751
1732
- // Detect all the positions that come from "dim-1" broadcasting.
1733
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1734
- // extract position to `0` when extracting from the source operand.
1735
- llvm::SetVector<int64_t > broadcastedUnitDims =
1736
- broadcastOp.computeBroadcastedUnitDims ();
1737
- SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1738
- OpBuilder b (extractOp.getContext ());
1739
- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1740
- for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
1741
- if (broadcastedUnitDims.contains (i))
1742
- extractPos[i] = b.getIndexAttr (0 );
1743
- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1744
- // matching extract position when extracting from the source operand.
1745
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
1746
- extractPos.erase (extractPos.begin (),
1747
- std::next (extractPos.begin (), extractPos.size () - rankDiff));
1748
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
1749
- auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1752
+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1753
+ SmallVector<OpFoldResult> newPositions (deltaOverall);
1754
+ IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1755
+ for (auto [i, size] : llvm::enumerate (srcShape.take_front (deltaOverall))) {
1756
+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1757
+ }
1758
+ auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
1750
1759
extractOp->setOperands (
1751
- llvm::to_vector (llvm::concat<Value>(ValueRange (source ), dynPos)));
1760
+ llvm::to_vector (llvm::concat<Value>(ValueRange (src ), dynPos)));
1752
1761
extractOp.setStaticPosition (staticPos);
1753
1762
return extractOp.getResult ();
1754
1763
}
@@ -2193,32 +2202,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2193
2202
2194
2203
LogicalResult matchAndRewrite (ExtractOp extractOp,
2195
2204
PatternRewriter &rewriter) const override {
2196
- Operation *defOp = extractOp.getVector ().getDefiningOp ();
2197
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2198
- return failure ();
2199
2205
2200
- Value source = defOp->getOperand (0 );
2201
- if (extractOp.getType () == source.getType ())
2206
+ Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
2207
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2208
+ if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp) || !outType)
2202
2209
return failure ();
2203
- auto getRank = [](Type type) {
2204
- return llvm::isa<VectorType>(type)
2205
- ? llvm::cast<VectorType>(type).getRank ()
2206
- : 0 ;
2207
- };
2208
- unsigned broadcastSrcRank = getRank (source.getType ());
2209
- unsigned extractResultRank = getRank (extractOp.getType ());
2210
- // We only consider the case where the rank of the source is less than or
2211
- // equal to the rank of the extract dst. The other cases are handled in the
2212
- // folding patterns.
2213
- if (extractResultRank < broadcastSrcRank)
2214
- return failure ();
2215
- // For scalar result, the input can only be a rank-0 vector, which will
2216
- // be handled by the folder.
2217
- if (extractResultRank == 0 )
2210
+
2211
+ Value source = broadcastLikeOp->getOperand (0 );
2212
+ if (isBroadcastableTo (source.getType (), outType) !=
2213
+ BroadcastableToResult::Success)
2218
2214
return failure ();
2219
2215
2220
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
2221
- extractOp, extractOp.getType (), source);
2216
+ rewriter.replaceOpWithNewOp <BroadcastOp>(extractOp, outType, source);
2222
2217
return success ();
2223
2218
}
2224
2219
};
0 commit comments