diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e3563d10bc6f1..dbb97f230c873 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -331,6 +331,46 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    // TODO: Currently only supports cases where the source and result ranks
+    // are the same.
+    auto srcType =
+        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+    if (!srcType || srcType.getRank() != resultType.getRank())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newBroadcastOps;
+    for (auto operand : adaptor.getOperands().front()) {
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, operand);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newBroadcastOps.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+    return success();
+  }
+};
+
 // This pattern transforms elementwise ops to work at subgroup level.
 struct WgToSgElementwiseOp : public ConversionPattern {
   WgToSgElementwiseOp(MLIRContext *ctx)
@@ -473,8 +513,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
-      patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+               WgToSgVectorBroadcastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -581,6 +621,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
       [=](Operation *op) -> std::optional<bool> {
         // Only handle elementwise mappable ops
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index c6124f90e0f48..8a880068aab33 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
     gpu.return
   }
 
+  // CHECK-LABEL: broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout>
+      -> vector<24x1xf32>
+    // CHECK-COUNT-3: vector.broadcast {{.*}}
+    // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout}
+    // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+    // CHECK-NOT: vector.broadcast
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c1 = arith.constant 1 : index
     %c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
     gpu.return
   }
-
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..f60358f188e72 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout>
+      -> vector<24x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout}
+    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
     //CHECK: [[c0:%.+]] = arith.constant 0 : index
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout>
     gpu.return
   }
-
-
 }
+