[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass #144417
@@ -331,6 +331,46 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  }
};

/// This pattern transforms vector.broadcast ops to work at subgroup level.
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
    if (!layout || !layout.getSgLayout())
      return failure();
Review comment: It looks to me that the current implementation is assuming the rank of the source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast.

Review comment: What happens if the broadcast is highly N-dimensional?
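For illustration (an example added here, not from the patch): vector.broadcast also accepts rank-expanding forms, which the rank check below rejects, e.g.

  %r = vector.broadcast %v : vector<16xf32> to vector<8x16xf32>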
    // TODO: Currently only supports cases where the source and result ranks
    // are the same.
    auto srcType =
        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
    if (!srcType || srcType.getRank() != resultType.getRank())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
          op.getLoc(), newResultType, operand);
      xegpu::setLayoutAttr(newBroadcast->getResult(0),
                           layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};
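For a concrete before/after, taken from the test added below: a workgroup-level

  %b = vector.broadcast %v {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>} : vector<24x1xf32> to vector<24x8xf32>

is rewritten, per subgroup, into

  %b_sg = vector.broadcast %v_sg {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} : vector<12x1xf32> to vector<12x8xf32>

since getSgShapeAndCount([24, 8], layout) yields the per-subgroup shape [12, 8] and dropSgLayoutAndData removes the sg_layout/sg_data fields. The value names %b, %v, %b_sg, %v_sg are illustrative only.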

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
@@ -473,8 +513,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
-      patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+               WgToSgVectorBroadcastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -581,6 +621,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<vector::BroadcastOp>(
      [=](vector::BroadcastOp op) -> bool {
        return isLegal(xegpu::getLayoutAttr(op.getResult()));
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only handle elementwise mappable ops
@@ -170,6 +170,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
    gpu.return
  }

  // CHECK-LABEL: broadcast
  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
  gpu.func @broadcast(%src: memref<24x1xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
Review comment: Please add a test case with broadcast in dim 0 too. (A hypothetical sketch of such a case appears after this test function.)
    %load = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
      -> vector<24x1xf32>
    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
    %broadcast = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
      : vector<24x1xf32> to vector<24x8xf32>
    gpu.return
  }
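A sketch of what the requested dim-0 case might look like, mirroring the test above; the function name, shapes, and layouts here are hypothetical and not part of the patch:

  gpu.func @broadcast_dim0(%src: memref<1x24xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x24xf32>
      -> !xegpu.tensor_desc<1x24xf32, #xegpu.layout<sg_layout = [1, 2], sg_data = [1, 12], lane_layout = [1, 2], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<1x24xf32, #xegpu.layout<sg_layout = [1, 2], sg_data = [1, 12], lane_layout = [1, 2], lane_data = [1, 1]>>
      -> vector<1x24xf32>
    // Expected to distribute to a per-subgroup vector.broadcast
    // : vector<1x12xf32> to vector<8x12xf32>.
    %broadcast = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [1, 2], sg_data = [8, 12], lane_layout = [1, 2], lane_data = [1, 1]>}
      : vector<1x24xf32> to vector<8x24xf32>
    gpu.return
  }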

  gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
    //CHECK: [[c0:%.+]] = arith.constant 0 : index
    //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
    xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
    gpu.return
  }
}
Review comment: We probably also need to check whether the LayoutAttr of the input is broadcastable to the LayoutAttr of the output. In your test example the input LayoutAttr is #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]> and the output LayoutAttr is #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>, but what if the input LayoutAttr is #xegpu.layout<sg_layout = [2, 1], sg_data = [6, 1], lane_layout = [2, 1], lane_data = [1, 1]>? Is the lowering still valid?
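To make the suggested check concrete, here is a rough sketch of a guard that could sit in WgToSgVectorBroadcastOp::matchAndRewrite after the existing srcType/layout checks. It is not part of this patch, and the source-layout lookup plus the getSgData()/asArrayRef() accessors are assumed from the LayoutAttr API used elsewhere in the pass, not verified against it.

  // Hypothetical compatibility guard (sketch only): require the source to use
  // the same sg_layout, and the same sg_data on every dimension that is not
  // being broadcast, so each subgroup already owns the slice its broadcast
  // consumes. This would reject the sg_data = [6, 1] source from the example
  // above while accepting the sg_data = [12, 1] one.
  xegpu::LayoutAttr srcLayout = xegpu::getLayoutAttr(op.getSource());
  if (!srcLayout || !srcLayout.getSgLayout() ||
      srcLayout.getSgLayout() != layout.getSgLayout())
    return failure();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int32_t> srcData = srcLayout.getSgData().asArrayRef();
  ArrayRef<int32_t> resData = layout.getSgData().asArrayRef();
  if (srcData.size() != srcShape.size() || resData.size() != srcShape.size())
    return failure();
  for (size_t i = 0, e = srcShape.size(); i < e; ++i)
    if (srcShape[i] != 1 && srcData[i] != resData[i])
      return failure();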