Skip to content

Commit d8730eb

Browse files
[mlir][GPU][transform] Add gpu_to_rocdl conversion pattern to transform dialect
Authored-by: Son Tuan Vu <[email protected]>
1 parent 85aa5f8 commit d8730eb

File tree

4 files changed

+55
-0
lines changed

4 files changed

+55
-0
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ def ApplyGPUSubgroupReduceToNVVMConversionPatternsOp : Op<Transform_Dialect,
5454
let assemblyFormat = "attr-dict";
5555
}
5656

57+
def ApplyGPUToROCDLConversionPatternsOp : Op<Transform_Dialect,
58+
"apply_conversion_patterns.gpu.gpu_to_rocdl",
59+
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
60+
["verifyTypeConverter"]>]> {
61+
let description = [{
62+
Collects patterns that convert GPU dialect ops to ROCDL dialect ops. These
63+
patterns require an "LLVMTypeConverter".
64+
}];
65+
let arguments = (ins StrAttr:$chipset);
66+
let assemblyFormat = [{
67+
`chipset` `=` $chipset attr-dict
68+
}];
69+
}
70+
5771
//===----------------------------------------------------------------------===//
5872
// Apply...PatternsOp
5973
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ add_mlir_dialect_library(MLIRGPUTransformOps
2424
# ConversionPatterns
2525
MLIRNVGPUToNVVM
2626
MLIRGPUToNVVMTransforms
27+
MLIRGPUToROCDLTransforms
2728
)

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
13+
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
1314
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1415
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1516
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -42,6 +43,7 @@
4243
#include "llvm/Support/Debug.h"
4344
#include "llvm/Support/ErrorHandling.h"
4445
#include "llvm/Support/InterleavedRange.h"
46+
#include "llvm/Support/LogicalResult.h"
4547
#include <type_traits>
4648

4749
using namespace mlir;
@@ -129,6 +131,42 @@ LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
129131
return success();
130132
}
131133

134+
void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns(
135+
TypeConverter &typeConverter, RewritePatternSet &patterns) {
136+
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
137+
populateGpuMemorySpaceAttributeConversions(
138+
llvmTypeConverter, [](AddressSpace space) {
139+
switch (space) {
140+
case AddressSpace::Global:
141+
return 1;
142+
case AddressSpace::Workgroup:
143+
return 3;
144+
case AddressSpace::Private:
145+
return 5;
146+
}
147+
llvm_unreachable("unknown address space enum value");
148+
return 0;
149+
});
150+
FailureOr<amdgpu::Chipset> maybeChipset =
151+
amdgpu::Chipset::parse(getChipset());
152+
assert(llvm::succeeded(maybeChipset) && "expected valid chipset");
153+
populateGpuToROCDLConversionPatterns(
154+
llvmTypeConverter, patterns, mlir::gpu::amd::Runtime::HIP, *maybeChipset);
155+
}
156+
157+
LogicalResult
158+
transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter(
159+
transform::TypeConverterBuilderOpInterface builder) {
160+
FailureOr<amdgpu::Chipset> maybeChipset =
161+
amdgpu::Chipset::parse(getChipset());
162+
if (failed(maybeChipset)) {
163+
return emitOpError("Invalid chipset name: " + getChipset());
164+
}
165+
if (builder.getTypeConverterType() != "LLVMTypeConverter")
166+
return emitOpError("expected LLVMTypeConverter");
167+
return success();
168+
}
169+
132170
//===----------------------------------------------------------------------===//
133171
// Apply...PatternsOp
134172
//===----------------------------------------------------------------------===//s

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5502,13 +5502,15 @@ cc_library(
55025502
":GPUDialect",
55035503
":GPUToGPURuntimeTransforms",
55045504
":GPUToNVVMTransforms",
5505+
":GPUToROCDLTransforms",
55055506
":GPUTransformOpsIncGen",
55065507
":GPUTransforms",
55075508
":IR",
55085509
":LLVMCommonConversion",
55095510
":MemRefDialect",
55105511
":NVGPUDialect",
55115512
":NVVMDialect",
5513+
":ROCDLDialect",
55125514
":SCFDialect",
55135515
":Support",
55145516
":TransformDialect",

0 commit comments

Comments
 (0)