diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 15b14c767b66a..a81b2e83ddefe 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op< %3, %4 = gpu.shuffle down %0, %cst1, %width : f32 ``` - For lane `k`, returns the value from lane `(k + 1) % width`. + For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is + bigger than or equal to `width`, the value is poison and `valid` is `false`. `up` example: @@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op< %5, %6 = gpu.shuffle up %0, %cst1, %width : f32 ``` - For lane `k`, returns the value from lane `(k - 1) % width`. + For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is + smaller than `0`, the value is poison and `valid` is `false`. `idx` example: diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 78e6ebb523a46..47172b9462658 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -435,26 +435,57 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( return rewriter.notifyMatchFailure( shuffleOp, "shuffle width and target subgroup size mismatch"); + assert(!adaptor.getOffset().getType().isSignedInteger() && + "shuffle offset must be a signless/unsigned integer"); + Location loc = shuffleOp.getLoc(); - Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), - shuffleOp.getLoc(), rewriter); auto scope = rewriter.getAttr(spirv::Scope::Subgroup); Value result; + Value validVal; switch (shuffleOp.getMode()) { - case gpu::ShuffleMode::XOR: + case gpu::ShuffleMode::XOR: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); + validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), + shuffleOp.getLoc(), rewriter); break; - case gpu::ShuffleMode::IDX: + } + case gpu::ShuffleMode::IDX: { result = rewriter.create( loc, scope, adaptor.getValue(), adaptor.getOffset()); + validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), + shuffleOp.getLoc(), rewriter); + break; + } + case gpu::ShuffleMode::DOWN: { + result = rewriter.create( + loc, scope, adaptor.getValue(), adaptor.getOffset()); + + Value laneId = rewriter.create(loc, widthAttr); + Value resultLaneId = + rewriter.create(loc, laneId, adaptor.getOffset()); + validVal = rewriter.create(loc, arith::CmpIPredicate::ult, + resultLaneId, adaptor.getWidth()); break; - default: - return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode"); + } + case gpu::ShuffleMode::UP: { + result = rewriter.create( + loc, scope, adaptor.getValue(), adaptor.getOffset()); + + Value laneId = rewriter.create(loc, widthAttr); + Value resultLaneId = + rewriter.create(loc, laneId, adaptor.getOffset()); + auto i32Type = rewriter.getIntegerType(32); + validVal = rewriter.create( + loc, arith::CmpIPredicate::sge, resultLaneId, + rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); + break; + } } - rewriter.replaceOp(shuffleOp, {result, trueVal}); + rewriter.replaceOp(shuffleOp, {result, validVal}); return success(); } diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir index d3d8ec0dab40f..e93f69704f25b 100644 --- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir @@ -15,8 +15,8 @@ gpu.module @kernels { // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 - // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor %[[VAL]], %[[MASK]] : f32, i32 + // CHECK: %{{.+}} = spirv.Constant true %result, %valid = gpu.shuffle xor %val, %mask, %width : f32 gpu.return } @@ -64,11 +64,78 @@ gpu.module @kernels { // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32 // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 - // CHECK: %{{.+}} = spirv.Constant true // CHECK: %{{.+}} = spirv.GroupNonUniformShuffle %[[VAL]], %[[MASK]] : f32, i32 + // CHECK: %{{.+}} = spirv.Constant true %result, %valid = gpu.shuffle idx %val, %mask, %width : f32 gpu.return } } } + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, + #spirv.resource_limits> +} { + +gpu.module @kernels { + // CHECK-LABEL: spirv.func @shuffle_down() + gpu.func @shuffle_down() kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %offset = arith.constant 4 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 + // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown %[[VAL]], %[[OFFSET]] : f32, i32 + + // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr + // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32 + // CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32 + // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32 + + %result, %valid = gpu.shuffle down %val, %offset, %width : f32 + gpu.return + } +} + +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, + #spirv.resource_limits> +} { + +gpu.module @kernels { + // CHECK-LABEL: spirv.func @shuffle_up() + gpu.func @shuffle_up() kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %offset = arith.constant 4 : i32 + %width = arith.constant 16 : i32 + %val = arith.constant 42.0 : f32 + + // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 + // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 + // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp %[[VAL]], %[[OFFSET]] : f32, i32 + + // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr + // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32 + // CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32 + // CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32 + // CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32 + + %result, %valid = gpu.shuffle up %val, %offset, %width : f32 + gpu.return + } +} + +}