diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td index 36b579485fc04..87423c639945f 100644 --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -54,6 +54,20 @@ def ApplyGPUSubgroupReduceToNVVMConversionPatternsOp : Op]> { + let description = [{ + Collects patterns that convert GPU dialect ops to ROCDL dialect ops. These + patterns require an "LLVMTypeConverter". + }]; + let arguments = (ins StrAttr:$chipset); + let assemblyFormat = [{ + `chipset` `=` $chipset attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt index b26788f675ce5..e5cc0254f1ffe 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt @@ -24,4 +24,5 @@ add_mlir_dialect_library(MLIRGPUTransformOps # ConversionPatterns MLIRNVGPUToNVVM MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms ) diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index f7fcd99b030dd..c9e91535df946 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -17,6 +18,7 @@ #include "mlir/Dialect/GPU/TransformOps/Utils.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -40,6 +42,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" +#include "llvm/Support/LogicalResult.h" #include using namespace mlir; @@ -127,6 +130,41 @@ LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:: return success(); } +void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + auto &llvmTypeConverter = static_cast(typeConverter); + populateGpuMemorySpaceAttributeConversions( + llvmTypeConverter, [](AddressSpace space) { + switch (space) { + case AddressSpace::Global: + return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace; + case AddressSpace::Workgroup: + return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace; + case AddressSpace::Private: + return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace; + } + llvm_unreachable("unknown address space enum value"); + }); + FailureOr maybeChipset = + amdgpu::Chipset::parse(getChipset()); + assert(llvm::succeeded(maybeChipset) && "expected valid chipset"); + populateGpuToROCDLConversionPatterns( + llvmTypeConverter, patterns, mlir::gpu::amd::Runtime::HIP, *maybeChipset); +} + +LogicalResult +transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter( + transform::TypeConverterBuilderOpInterface builder) { + FailureOr maybeChipset = + amdgpu::Chipset::parse(getChipset()); + if (failed(maybeChipset)) { + return emitOpError("Invalid chipset name: " + getChipset()); + } + if (builder.getTypeConverterType() != "LLVMTypeConverter") + return emitOpError("expected LLVMTypeConverter"); + return success(); +} + //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===//s diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 633776398649f..7f59236a3bb27 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5505,6 +5505,7 @@ cc_library( ":GPUDialect", ":GPUToGPURuntimeTransforms", ":GPUToNVVMTransforms", + ":GPUToROCDLTransforms", ":GPUTransformOpsIncGen", ":GPUTransforms", ":IR", @@ -5512,6 +5513,7 @@ cc_library( ":MemRefDialect", ":NVGPUDialect", ":NVVMDialect", + ":ROCDLDialect", ":SCFDialect", ":Support", ":TransformDialect",