From 3abad601d83131e5dfc900e24a3b3725d6ea95d2 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Fri, 9 May 2025 09:33:33 +0100 Subject: [PATCH 01/10] [mlir][gpu][spirv] Add patterns for gpu.shuffle up/down Convert gpu.shuffle down %val, %offset, %width to spirv.GroupNonUniformRotateKHR %val, %offset, cluster_size(%width) Convert gpu.shuffle up %val, %offset, %width to %down_offset = arith.subi %width, %offset spirv.GroupNonUniformRotateKHR %val, %down_offset, cluster_size(%width) --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 15 ++++- mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 57 +++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 78e6ebb523a46..e705abd0f3863 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -450,8 +450,19 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); break; - default: - return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode"); + case gpu::ShuffleMode::DOWN: + result = rewriter.create( + loc, scope, adaptor.getValue(), adaptor.getOffset(), + shuffleOp.getWidth()); + break; + case gpu::ShuffleMode::UP: { + Value offsetForShuffleDown = rewriter.create( + loc, shuffleOp.getWidth(), adaptor.getOffset()); + result = rewriter.create( + loc, scope, adaptor.getValue(), offsetForShuffleDown, + shuffleOp.getWidth()); + break; + } } rewriter.replaceOp(shuffleOp, {result, trueVal}); diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir index d3d8ec0dab40f..5d7d3c81577e3 100644 --- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -72,3 +72,60 @@ gpu.module @kernels { } } + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = 
#spirv.target_env<#spirv.vce, + #spirv.resource_limits> +} { + +gpu.module @kernels { + // CHECK-LABEL: spirv.func @shuffle_down() + gpu.func @shuffle_down() kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %offset = arith.constant 4 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 + // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 + // CHECK: %{{.+}} = spirv.Constant true + // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 + %result, %valid = gpu.shuffle down %val, %offset, %width : f32 + gpu.return + } +} + +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, + #spirv.resource_limits> +} { + +gpu.module @kernels { + // CHECK-LABEL: spirv.func @shuffle_up() + gpu.func @shuffle_up() kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %offset = arith.constant 4 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 + // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 + // CHECK: %{{.+}} = spirv.Constant true + // CHECK: %[[DOWN_OFFSET:.+]] = spirv.Constant 12 : i32 + // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR %[[VAL]], %[[DOWN_OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 + %result, %valid = gpu.shuffle up %val, %offset, %width : f32 + gpu.return + } +} + +} From 4f9e602b933107405372e5570d439261e53ee8e6 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 21 May 2025 13:54:43 +0100 Subject: [PATCH 02/10] The width argument cannot exceed the subgroup limit. 
--- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index e705abd0f3863..287465e6581db 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -430,10 +430,12 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); IntegerAttr widthAttr; + // The width argument specifies the number of lanes that participate in the + // shuffle. The width value should not exceed the subgroup limit. if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() != subgroupSize) + widthAttr.getValue().getZExtValue() <= subgroupSize) return rewriter.notifyMatchFailure( - shuffleOp, "shuffle width and target subgroup size mismatch"); + shuffleOp, "shuffle width is larger than target subgroup size"); Location loc = shuffleOp.getLoc(); Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), From 2a3f62cf1dfa01f4e9b3c788341822855804fdec Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 21 May 2025 14:07:55 +0100 Subject: [PATCH 03/10] fix typo --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 287465e6581db..899bcb33bd48e 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -433,7 +433,7 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( // The width argument specifies the number of lanes that participate in the // shuffle. The width value should not exceed the subgroup limit. 
if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() <= subgroupSize) + widthAttr.getValue().getZExtValue() > subgroupSize) return rewriter.notifyMatchFailure( shuffleOp, "shuffle width is larger than target subgroup size"); From cfe6c2ddc91dcbcc11289249bccba3dc434a368d Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Thu, 22 May 2025 09:33:22 +0100 Subject: [PATCH 04/10] remove test for gpu.shuffle width != subgroup_size limit --- mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 23 -------------------- 1 file changed, 23 deletions(-) diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir index 5d7d3c81577e3..f0bf5e110915c 100644 --- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -26,29 +26,6 @@ gpu.module @kernels { // ----- -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> -} { - -gpu.module @kernels { - gpu.func @shuffle_xor() kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { - %mask = arith.constant 8 : i32 - %width = arith.constant 16 : i32 - %val = arith.constant 42.0 : f32 - - // Cannot convert due to shuffle width and target subgroup size mismatch - // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}} - %result, %valid = gpu.shuffle xor %val, %mask, %width : f32 - gpu.return - } -} - -} - -// ----- - module attributes { gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> From d4d2123957619bc9ab976d5c25d1c73327948b86 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 4 Jun 2025 10:34:41 +0100 Subject: [PATCH 05/10] Remove rotation semantic in gpu.shuffle up/down There is no such semantic in SPIRV OpGroupNonUniformShuffleUp and OpGroupNonUniformShuffleDown. In addition, there is no such semantic in NVVM shfl intrinsics. 
Refer to NVVM IR spec https://docs.nvidia.com/cuda/archive/12.2.1/nvvm-ir-spec/index.html#data-movement "If the computed source lane index j is in range, the returned i32 value will be the value of %a from lane j; otherwise, it will be the the value of %a from the current thread. If the thread corresponding to lane j is inactive, then the returned i32 value is undefined." --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 6 ++-- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 21 ++++-------- mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 32 ++++++++++++++++--- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 15b14c767b66a..4e1eccbc1030a 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op< %3, %4 = gpu.shuffle down %0, %cst1, %width : f32 ``` - For lane `k`, returns the value from lane `(k + 1) % width`. + For lane `k`, returns the value from lane `(k + cst1)`. The resulting value + is undefined if the lane is out of bounds in the subgroup. `up` example: @@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op< %5, %6 = gpu.shuffle up %0, %cst1, %width : f32 ``` - For lane `k`, returns the value from lane `(k - 1) % width`. + For lane `k`, returns the value from lane `(k - cst1)`. The resulting value + is undefined if the lane is out of bounds in the subgroup. 
`idx` example: diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 899bcb33bd48e..359bafaa457f9 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -430,12 +430,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); IntegerAttr widthAttr; - // The width argument specifies the number of lanes that participate in the - // shuffle. The width value should not exceed the subgroup limit. if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + widthAttr.getValue().getZExtValue() != subgroupSize) return rewriter.notifyMatchFailure( - shuffleOp, "shuffle width is larger than target subgroup size"); + shuffleOp, "shuffle width and target subgroup size mismatch"); Location loc = shuffleOp.getLoc(); Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), @@ -453,19 +451,14 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( loc, scope, adaptor.getValue(), adaptor.getOffset()); break; case gpu::ShuffleMode::DOWN: - result = rewriter.create( - loc, scope, adaptor.getValue(), adaptor.getOffset(), - shuffleOp.getWidth()); + result = rewriter.create( + loc, scope, adaptor.getValue(), adaptor.getOffset()); break; - case gpu::ShuffleMode::UP: { - Value offsetForShuffleDown = rewriter.create( - loc, shuffleOp.getWidth(), adaptor.getOffset()); - result = rewriter.create( - loc, scope, adaptor.getValue(), offsetForShuffleDown, - shuffleOp.getWidth()); + case gpu::ShuffleMode::UP: + result = rewriter.create( + loc, scope, adaptor.getValue(), adaptor.getOffset()); break; } - } rewriter.replaceOp(shuffleOp, {result, trueVal}); return success(); diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir index f0bf5e110915c..56877a756b7ba 100644 --- 
a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -26,6 +26,29 @@ gpu.module @kernels { // ----- +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> +} { + +gpu.module @kernels { + gpu.func @shuffle_xor() kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %mask = arith.constant 8 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // Cannot convert due to shuffle width and target subgroup size mismatch + // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}} + %result, %valid = gpu.shuffle xor %val, %mask, %width : f32 + gpu.return + } +} + +} + +// ----- + module attributes { gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> @@ -54,7 +77,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> } { @@ -70,7 +93,7 @@ gpu.module @kernels { // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.Constant true - // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown %[[VAL]], %[[OFFSET]] : f32, i32 %result, %valid = gpu.shuffle down %val, %offset, %width : f32 gpu.return } @@ -82,7 +105,7 @@ gpu.module @kernels { module attributes { gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> } { @@ -98,8 +121,7 @@ gpu.module @kernels { // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.Constant true - // CHECK: %[[DOWN_OFFSET:.+]] = 
spirv.Constant 12 : i32 - // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR %[[VAL]], %[[DOWN_OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp %[[VAL]], %[[OFFSET]] : f32, i32 %result, %valid = gpu.shuffle up %val, %offset, %width : f32 gpu.return } From 544ae11e2598b4e1a5289609540451169a44018c Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 4 Jun 2025 17:36:30 +0100 Subject: [PATCH 06/10] Refine description and set 'valid' flag according to the resulting lane ID --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 9 +-- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 67 +++++++++++++++++-- mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 39 +++++++++++ 3 files changed, 106 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 4e1eccbc1030a..e25012c11b42a 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1332,8 +1332,9 @@ def GPU_ShuffleOp : GPU_Op< %3, %4 = gpu.shuffle down %0, %cst1, %width : f32 ``` - For lane `k`, returns the value from lane `(k + cst1)`. The resulting value - is undefined if the lane is out of bounds in the subgroup. + For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is + bigger than or equal to `width`, the value is unspecified and `valid` is + `false`. `up` example: @@ -1342,8 +1343,8 @@ def GPU_ShuffleOp : GPU_Op< %5, %6 = gpu.shuffle up %0, %cst1, %width : f32 ``` - For lane `k`, returns the value from lane `(k - cst1)`. The resulting value - is undefined if the lane is out of bounds in the subgroup. + For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is + smaller than `0`, the value is unspecified and `valid` is `false`. 
`idx` example: diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 359bafaa457f9..70d792216524c 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -416,6 +416,15 @@ LogicalResult GPUBarrierConversion::matchAndRewrite( return success(); } +template +Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc, + gpu::Dimension dimension) { + Type indexType = IndexType::get(ctx); + IntegerType i32Type = IntegerType::get(ctx, 32); + Value dim = builder.create(loc, indexType, dimension); + return builder.create(loc, i32Type, dim); +} + //===----------------------------------------------------------------------===// // Shuffle //===----------------------------------------------------------------------===// @@ -436,8 +445,8 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( shuffleOp, "shuffle width and target subgroup size mismatch"); Location loc = shuffleOp.getLoc(); - Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), - shuffleOp.getLoc(), rewriter); + Value validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), + shuffleOp.getLoc(), rewriter); auto scope = rewriter.getAttr(spirv::Scope::Subgroup); Value result; @@ -450,17 +459,65 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); break; - case gpu::ShuffleMode::DOWN: + case gpu::ShuffleMode::DOWN: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); + + MLIRContext *ctx = shuffleOp.getContext(); + Value dimX = + getDimOp(rewriter, ctx, loc, gpu::Dimension::x); + Value dimY = + getDimOp(rewriter, ctx, loc, gpu::Dimension::y); + Value tidX = + getDimOp(rewriter, ctx, loc, gpu::Dimension::x); + Value tidY = + getDimOp(rewriter, ctx, loc, gpu::Dimension::y); + Value tidZ = + getDimOp(rewriter, ctx, loc, gpu::Dimension::z); + auto i32Type = 
rewriter.getIntegerType(32); + Value tmp1 = rewriter.create(loc, i32Type, tidZ, dimY); + Value tmp2 = rewriter.create(loc, i32Type, tmp1, tidY); + Value tmp3 = rewriter.create(loc, i32Type, tmp2, dimX); + Value landId = rewriter.create(loc, i32Type, tmp3, tidX); + + Value resultLandId = + rewriter.create(loc, landId, adaptor.getOffset()); + validVal = rewriter.create(loc, arith::CmpIPredicate::ult, + resultLandId, adaptor.getWidth()); break; - case gpu::ShuffleMode::UP: + } + case gpu::ShuffleMode::UP: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); + + MLIRContext *ctx = shuffleOp.getContext(); + Value dimX = + getDimOp(rewriter, ctx, loc, gpu::Dimension::x); + Value dimY = + getDimOp(rewriter, ctx, loc, gpu::Dimension::y); + Value tidX = + getDimOp(rewriter, ctx, loc, gpu::Dimension::x); + Value tidY = + getDimOp(rewriter, ctx, loc, gpu::Dimension::y); + Value tidZ = + getDimOp(rewriter, ctx, loc, gpu::Dimension::z); + auto i32Type = rewriter.getIntegerType(32); + Value tmp1 = rewriter.create(loc, i32Type, tidZ, dimY); + Value tmp2 = rewriter.create(loc, i32Type, tmp1, tidY); + Value tmp3 = rewriter.create(loc, i32Type, tmp2, dimX); + Value landId = rewriter.create(loc, i32Type, tmp3, tidX); + + Value resultLandId = + rewriter.create(loc, landId, adaptor.getOffset()); + validVal = rewriter.create( + loc, arith::CmpIPredicate::sge, resultLandId, + rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); break; } + } - rewriter.replaceOp(shuffleOp, {result, trueVal}); + rewriter.replaceOp(shuffleOp, {result, validVal}); return success(); } diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir index 56877a756b7ba..396421b7585af 100644 --- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -94,6 +94,25 @@ gpu.module @kernels { // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.Constant 
true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown %[[VAL]], %[[OFFSET]] : f32, i32 + + // CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32 + // CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32 + // CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> + // CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32> + // CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32> + // CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> + // CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32> + // CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32> + // CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> + // CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32> + // CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32> + // CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32 + // CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32 + // CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32 + // CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32 + // CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32 + // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32 + %result, %valid = gpu.shuffle down %val, %offset, %width : f32 gpu.return } @@ -122,6 +141,26 @@ gpu.module @kernels { // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp %[[VAL]], %[[OFFSET]] : f32, i32 + + // CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32 + // 
CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32 + // CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> + // CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32> + // CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32> + // CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> + // CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32> + // CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32> + // CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> + // CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32> + // CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32> + // CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32 + // CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32 + // CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32 + // CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32 + // CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32 + // CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32 + %result, %valid = gpu.shuffle up %val, %offset, %width : f32 gpu.return } From 1ae89dc96ea8afe740ff06401400f6f3162c2039 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Wed, 4 Jun 2025 19:58:52 +0100 Subject: [PATCH 07/10] refactor and update wording --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 5 +- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 54 +++++++------------ 2 files changed, 22 insertions(+), 37 deletions(-) diff --git 
a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index e25012c11b42a..a81b2e83ddefe 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1333,8 +1333,7 @@ def GPU_ShuffleOp : GPU_Op< ``` For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is - bigger than or equal to `width`, the value is unspecified and `valid` is - `false`. + bigger than or equal to `width`, the value is poison and `valid` is `false`. `up` example: @@ -1344,7 +1343,7 @@ def GPU_ShuffleOp : GPU_Op< ``` For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is - smaller than `0`, the value is unspecified and `valid` is `false`. + smaller than `0`, the value is poison and `valid` is `false`. `idx` example: diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 70d792216524c..c3bf017c56016 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -425,6 +425,21 @@ Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc, return builder.create(loc, i32Type, dim); } +Value getLaneId(OpBuilder &rewriter, MLIRContext *ctx, Location loc) { + Value dimX = getDimOp(rewriter, ctx, loc, gpu::Dimension::x); + Value dimY = getDimOp(rewriter, ctx, loc, gpu::Dimension::y); + Value tidX = getDimOp(rewriter, ctx, loc, gpu::Dimension::x); + Value tidY = getDimOp(rewriter, ctx, loc, gpu::Dimension::y); + Value tidZ = getDimOp(rewriter, ctx, loc, gpu::Dimension::z); + auto i32Type = rewriter.getIntegerType(32); + Value tmp1 = rewriter.create(loc, i32Type, tidZ, dimY); + Value tmp2 = rewriter.create(loc, i32Type, tmp1, tidY); + Value tmp3 = rewriter.create(loc, i32Type, tmp2, dimX); + Value laneId = rewriter.create(loc, i32Type, tmp3, tidX); + + return laneId; +} + //===----------------------------------------------------------------------===// // Shuffle 
//===----------------------------------------------------------------------===// @@ -464,24 +479,9 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( loc, scope, adaptor.getValue(), adaptor.getOffset()); MLIRContext *ctx = shuffleOp.getContext(); - Value dimX = - getDimOp(rewriter, ctx, loc, gpu::Dimension::x); - Value dimY = - getDimOp(rewriter, ctx, loc, gpu::Dimension::y); - Value tidX = - getDimOp(rewriter, ctx, loc, gpu::Dimension::x); - Value tidY = - getDimOp(rewriter, ctx, loc, gpu::Dimension::y); - Value tidZ = - getDimOp(rewriter, ctx, loc, gpu::Dimension::z); - auto i32Type = rewriter.getIntegerType(32); - Value tmp1 = rewriter.create(loc, i32Type, tidZ, dimY); - Value tmp2 = rewriter.create(loc, i32Type, tmp1, tidY); - Value tmp3 = rewriter.create(loc, i32Type, tmp2, dimX); - Value landId = rewriter.create(loc, i32Type, tmp3, tidX); - + Value laneId = getLaneId(rewriter, ctx, loc); Value resultLandId = - rewriter.create(loc, landId, adaptor.getOffset()); + rewriter.create(loc, laneId, adaptor.getOffset()); validVal = rewriter.create(loc, arith::CmpIPredicate::ult, resultLandId, adaptor.getWidth()); break; @@ -491,24 +491,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( loc, scope, adaptor.getValue(), adaptor.getOffset()); MLIRContext *ctx = shuffleOp.getContext(); - Value dimX = - getDimOp(rewriter, ctx, loc, gpu::Dimension::x); - Value dimY = - getDimOp(rewriter, ctx, loc, gpu::Dimension::y); - Value tidX = - getDimOp(rewriter, ctx, loc, gpu::Dimension::x); - Value tidY = - getDimOp(rewriter, ctx, loc, gpu::Dimension::y); - Value tidZ = - getDimOp(rewriter, ctx, loc, gpu::Dimension::z); - auto i32Type = rewriter.getIntegerType(32); - Value tmp1 = rewriter.create(loc, i32Type, tidZ, dimY); - Value tmp2 = rewriter.create(loc, i32Type, tmp1, tidY); - Value tmp3 = rewriter.create(loc, i32Type, tmp2, dimX); - Value landId = rewriter.create(loc, i32Type, tmp3, tidX); - + Value laneId = getLaneId(rewriter, ctx, loc); Value resultLandId = - 
rewriter.create(loc, landId, adaptor.getOffset()); + rewriter.create(loc, laneId, adaptor.getOffset()); + auto i32Type = rewriter.getIntegerType(32); validVal = rewriter.create( loc, arith::CmpIPredicate::sge, resultLandId, rewriter.create( From 6c7be291bc9c7b6a8bf05bb963f6fe90bf579431 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Fri, 6 Jun 2025 11:34:26 +0100 Subject: [PATCH 08/10] Use gpu::LaneIdOp --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 43 +++++-------------- mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 40 +++-------------- 2 files changed, 17 insertions(+), 66 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index c3bf017c56016..0ae234c0cee20 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -416,30 +416,6 @@ LogicalResult GPUBarrierConversion::matchAndRewrite( return success(); } -template -Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc, - gpu::Dimension dimension) { - Type indexType = IndexType::get(ctx); - IntegerType i32Type = IntegerType::get(ctx, 32); - Value dim = builder.create(loc, indexType, dimension); - return builder.create(loc, i32Type, dim); -} - -Value getLaneId(OpBuilder &rewriter, MLIRContext *ctx, Location loc) { - Value dimX = getDimOp(rewriter, ctx, loc, gpu::Dimension::x); - Value dimY = getDimOp(rewriter, ctx, loc, gpu::Dimension::y); - Value tidX = getDimOp(rewriter, ctx, loc, gpu::Dimension::x); - Value tidY = getDimOp(rewriter, ctx, loc, gpu::Dimension::y); - Value tidZ = getDimOp(rewriter, ctx, loc, gpu::Dimension::z); - auto i32Type = rewriter.getIntegerType(32); - Value tmp1 = rewriter.create(loc, i32Type, tidZ, dimY); - Value tmp2 = rewriter.create(loc, i32Type, tmp1, tidY); - Value tmp3 = rewriter.create(loc, i32Type, tmp2, dimX); - Value laneId = rewriter.create(loc, i32Type, tmp3, tidX); - - return laneId; -} - 
//===----------------------------------------------------------------------===// // Shuffle //===----------------------------------------------------------------------===// @@ -460,26 +436,30 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( shuffleOp, "shuffle width and target subgroup size mismatch"); Location loc = shuffleOp.getLoc(); - Value validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), - shuffleOp.getLoc(), rewriter); auto scope = rewriter.getAttr(spirv::Scope::Subgroup); Value result; + Value validVal; switch (shuffleOp.getMode()) { - case gpu::ShuffleMode::XOR: + case gpu::ShuffleMode::XOR: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); + validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), + shuffleOp.getLoc(), rewriter); break; - case gpu::ShuffleMode::IDX: + } + case gpu::ShuffleMode::IDX: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); + validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), + shuffleOp.getLoc(), rewriter); break; + } case gpu::ShuffleMode::DOWN: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); - MLIRContext *ctx = shuffleOp.getContext(); - Value laneId = getLaneId(rewriter, ctx, loc); + Value laneId = rewriter.create(loc, widthAttr); Value resultLandId = rewriter.create(loc, laneId, adaptor.getOffset()); validVal = rewriter.create(loc, arith::CmpIPredicate::ult, @@ -490,8 +470,7 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); - MLIRContext *ctx = shuffleOp.getContext(); - Value laneId = getLaneId(rewriter, ctx, loc); + Value laneId = rewriter.create(loc, widthAttr); Value resultLandId = rewriter.create(loc, laneId, adaptor.getOffset()); auto i32Type = rewriter.getIntegerType(32); diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir index 396421b7585af..e93f69704f25b 
100644 --- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -15,8 +15,8 @@ gpu.module @kernels { // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 - // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor %[[VAL]], %[[MASK]] : f32, i32 + // CHECK: %{{.+}} = spirv.Constant true %result, %valid = gpu.shuffle xor %val, %mask, %width : f32 gpu.return } @@ -64,8 +64,8 @@ gpu.module @kernels { // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 - // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffle %[[VAL]], %[[MASK]] : f32, i32 + // CHECK: %{{.+}} = spirv.Constant true %result, %valid = gpu.shuffle idx %val, %mask, %width : f32 gpu.return } @@ -92,24 +92,10 @@ gpu.module @kernels { // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 - // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown %[[VAL]], %[[OFFSET]] : f32, i32 - // CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32 - // CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> - // CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32> - // CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32> - // CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> - // CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32> - // CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32> - // CHECK: 
%__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> - // CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32> - // CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32> - // CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32 - // CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32 - // CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32 - // CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32 + // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr + // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32 // CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32 // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32 @@ -139,24 +125,10 @@ gpu.module @kernels { // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 - // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp %[[VAL]], %[[OFFSET]] : f32, i32 - // CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32 - // CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> - // CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32> - // CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32> - // CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> - // CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32> - // CHECK: %[[THREAD_Y:.+]] = 
spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32> - // CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr, Input> - // CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32> - // CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32> - // CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32 - // CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32 - // CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32 - // CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32 + // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr + // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32 // CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32 // CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32 // CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32 From 6bffc6b7518189850da2843204035524b9960d5f Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Mon, 9 Jun 2025 10:15:03 +0100 Subject: [PATCH 09/10] Fix typo and add signless/unsigned checking for offset --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 0ae234c0cee20..70ed80a84df3d 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -435,6 +435,11 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( return rewriter.notifyMatchFailure( shuffleOp, "shuffle width and target subgroup size mismatch"); + // Ensure the offset is a signless/unsigned integer. 
+ if (adaptor.getOffset().getType().isSignedInteger()) + return rewriter.notifyMatchFailure( + shuffleOp, "shuffle offset must be a signless/unsigned integer"); + Location loc = shuffleOp.getLoc(); auto scope = rewriter.getAttr(spirv::Scope::Subgroup); Value result; @@ -460,10 +465,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( loc, scope, adaptor.getValue(), adaptor.getOffset()); Value laneId = rewriter.create(loc, widthAttr); - Value resultLandId = + Value resultLaneId = rewriter.create(loc, laneId, adaptor.getOffset()); validVal = rewriter.create(loc, arith::CmpIPredicate::ult, - resultLandId, adaptor.getWidth()); + resultLaneId, adaptor.getWidth()); break; } case gpu::ShuffleMode::UP: { @@ -471,11 +476,11 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( loc, scope, adaptor.getValue(), adaptor.getOffset()); Value laneId = rewriter.create(loc, widthAttr); - Value resultLandId = + Value resultLaneId = rewriter.create(loc, laneId, adaptor.getOffset()); auto i32Type = rewriter.getIntegerType(32); validVal = rewriter.create( - loc, arith::CmpIPredicate::sge, resultLandId, + loc, arith::CmpIPredicate::sge, resultLaneId, rewriter.create( loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); break; From 47db273b540e395a7063b743d8a178fb81ceb70b Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Mon, 9 Jun 2025 19:18:47 +0100 Subject: [PATCH 10/10] Use assert instead of if checking --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 70ed80a84df3d..47172b9462658 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -435,10 +435,8 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( return rewriter.notifyMatchFailure( shuffleOp, "shuffle width and target subgroup size mismatch"); - // Ensure the offset is a signless/unsigned 
integer. - if (adaptor.getOffset().getType().isSignedInteger()) - return rewriter.notifyMatchFailure( - shuffleOp, "shuffle offset must be a signless/unsigned integer"); + assert(!adaptor.getOffset().getType().isSignedInteger() && + "shuffle offset must be a signless/unsigned integer"); Location loc = shuffleOp.getLoc(); auto scope = rewriter.getAttr(spirv::Scope::Subgroup);