
[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass #144417

Open · wants to merge 8 commits into main
49 changes: 47 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -331,6 +331,46 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

/// This pattern transforms vector.broadcast ops to work at subgroup level.
struct WgToSgVectorBroadcastOp
: public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();
chencha3 (Contributor) commented on Jun 23, 2025:

We probably also need to check whether the LayoutAttr of the input is broadcastable to the LayoutAttr of the output. In your test example the input LayoutAttr is #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]> and the output LayoutAttr is #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>, but what if the input LayoutAttr is #xegpu.layout<sg_layout = [2, 1], sg_data = [6, 1], lane_layout = [2, 1], lane_data = [1, 1]>? Is the lowering still valid?
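For illustration only, and not part of this PR's diff: one possible shape for the compatibility check suggested above is sketched below. The helper name is made up, the DenseI32ArrayAttr-returning getSgLayout()/getSgData() getters are assumptions about the XeGPU LayoutAttr API, and the sketch assumes source and result ranks already match (which the pattern checks separately).

// Sketch only (not from this PR): reject broadcasts whose source layout is
// not broadcast-compatible with the result layout. Assumes equal ranks and
// that sg_data has one entry per dimension.
static bool isSgLayoutBroadcastCompatible(xegpu::LayoutAttr srcLayout,
                                          xegpu::LayoutAttr resLayout,
                                          ArrayRef<int64_t> srcShape,
                                          ArrayRef<int64_t> resShape) {
  if (!srcLayout || !resLayout || !srcLayout.getSgLayout() ||
      !resLayout.getSgLayout() || !srcLayout.getSgData() ||
      !resLayout.getSgData())
    return false;
  // The subgroup grid must be identical on both sides.
  if (srcLayout.getSgLayout().asArrayRef() !=
      resLayout.getSgLayout().asArrayRef())
    return false;
  ArrayRef<int32_t> srcSgData = srcLayout.getSgData().asArrayRef();
  ArrayRef<int32_t> resSgData = resLayout.getSgData().asArrayRef();
  for (size_t i = 0, e = srcShape.size(); i < e; ++i) {
    if (srcShape[i] == resShape[i]) {
      // Non-broadcast dimension: per-subgroup tile sizes must agree.
      if (srcSgData[i] != resSgData[i])
        return false;
    } else if (srcShape[i] != 1 || srcSgData[i] != 1) {
      // Broadcast dimension: the source must be unit-sized both at the
      // workgroup level and per subgroup.
      return false;
    }
  }
  return true;
}

Under this sketch, the layouts in the committed test (sg_data = [12, 1] broadcast to sg_data = [12, 8]) would pass, while the hypothetical input layout with sg_data = [6, 1] would be rejected, since dimension 0 is not a broadcast dimension yet its per-subgroup size differs from the result's.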


A contributor commented:

It looks to me that the current implementation assumes the rank of the source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast. I believe this is partly due to the limitations of LayoutAttr. It would be better to add a check.
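For concreteness (illustrative only, not taken from this PR), a rank-expanding broadcast like the one below is exactly the kind of op the equal-rank check further down bails out on, so the pattern leaves such ops untouched:

// Hypothetical IR: the source rank (1) differs from the result rank (2),
// so WgToSgVectorBroadcastOp would return failure for this op.
func.func @rank_expanding(%v: vector<32xf32>) -> vector<24x32xf32> {
  %b = vector.broadcast %v : vector<32xf32> to vector<24x32xf32>
  return %b : vector<24x32xf32>
}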

A contributor commented:

What happens if the broadcast is highly N-dimensional? It's probably unlikely to end up with such IR, but I wonder whether the logic here is still safe to execute in such a case.

// TODO: Currently only supports cases where the source and result ranks
// are the same.
auto srcType =
dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
if (!srcType || srcType.getRank() != resultType.getRank())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());

SmallVector<Value> newBroadcastOps;
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, operand);
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}

rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
return success();
}
};

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
WgToSgElementwiseOp(MLIRContext *ctx)
@@ -473,8 +513,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
- patterns.getContext());
+ UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+ WgToSgVectorBroadcastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -581,6 +621,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Only handle elementwise mappable ops
19 changes: 18 additions & 1 deletion mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
gpu.return
}

// CHECK-LABEL: broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @broadcast(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
// CHECK-COUNT-3: vector.broadcast {{.*}}
// CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
gpu.return
}

gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
gpu.return
}

}
19 changes: 17 additions & 2 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}

// CHECK-LABEL: broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @broadcast(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
A contributor commented:

Please add a test case with broadcast in dim 0 too.

(An illustrative sketch of such a case appears right after this function.)

%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
gpu.return
}
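In response to the review comment above asking for a dim-0 broadcast test: the following is only a rough sketch of what such a case might look like; it is not part of this PR, and the shapes and layout attributes chosen here are assumptions. With these layouts, each subgroup would broadcast its vector<1x8xf32> slice to vector<32x8xf32>.

// Hypothetical dim-0 broadcast test (illustrative only, not in this PR).
gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
    -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 4], lane_data = [1, 1]>>
  %load = xegpu.load_nd %tdesc
    : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 4], lane_data = [1, 1]>>
    -> vector<1x32xf32>
  %broadcast = vector.broadcast %load
    {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 8], lane_layout = [1, 4], lane_data = [1, 1]>}
    : vector<1x32xf32> to vector<32x32xf32>
  gpu.return
}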

gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
gpu.return
}


}