From 3966b5d5773ba8f719ec0a3baae78b90f53908bd Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Thu, 1 May 2025 06:00:13 +0000
Subject: [PATCH 1/5] [mlir][gpu] Add pass for imitating unsupported types.

This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of the same bitwidth. The imitation is done by bitcasting the unsupported types to supported types of the same bitwidth; therefore, the source type and destination type must have the same bitwidth. The imitation is performed using the arith.bitcast operation.

The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V, for example: it does not support bf16, but an underlying architecture (e.g., an Intel PVC GPU) that uses SPIR-V for code generation does. Therefore, bf16 is neither a valid data type to pass to a gpu kernel nor to be used inside the kernel. To use the bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 has to be bitcast (similar to a C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs. float), so the computation cannot readily use the imitated type (i16).

Therefore, this transformation pass is intended to be used in conjunction with other transformation passes such as `EmulateUnsupportedFloats` and `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and vice versa.

Finally, the target (dialect/IR) usually provides instructions that can take advantage of the generated patterns (bf16->i16->f32, f32->bf16->i16) and convert them to the supported types. For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert it to f32 and vice versa.
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
---
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 6 +
 .../mlir/Dialect/GPU/Transforms/Passes.h | 20 +
 .../mlir/Dialect/GPU/Transforms/Passes.td | 53 +
 mlir/lib/Dialect/Arith/Utils/Utils.cpp | 25 +
 mlir/lib/Dialect/GPU/CMakeLists.txt | 5 +-
 .../Transforms/ImitateUnsupportedTypes.cpp | 915 ++++++++++++++++++
 .../GPU/imitate-unsupported-types.mlir | 141 +++
 7 files changed, 1163 insertions(+), 2 deletions(-)
 create mode 100644 mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
 create mode 100644 mlir/test/Dialect/GPU/imitate-unsupported-types.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index c0b286494996b..ef5ff54a2f470 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -146,6 +146,12 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef values,
 // Map strings to float types.
 std::optional parseFloatType(MLIRContext *ctx, StringRef name);
+// Map strings to Int types.
+std::optional parseIntType(MLIRContext *ctx, StringRef name);
+
+// Map strings to int or float types.
+std::optional parseIntOrFloatType(MLIRContext *ctx, StringRef name); + } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h index 6cd6f03253aea..0b7339a94b274 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h @@ -16,6 +16,8 @@ #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Utils/GPUUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include @@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns( RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset, PatternBenefit benefit = 1); +/// Set up a type converter to convert unsupported source types to +/// supported target types. +void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter, + ArrayRef sourceTypes, + ArrayRef targetTypes); + +/// Collect a set of pattern needed to imitate unsupported source types +/// using supported target types. +void populateImitateUnsupportedTypesConversionPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes, + DenseMap &convertedFuncTypes); + +/// Set up a dialect conversion to reject operations on unsupported +/// float types. +void configureImitateUnsupportedTypesLegality(ConversionTarget &target, + TypeConverter &typeConverter); + /// Collect all patterns to rewrite ops within the GPU dialect. inline void populateGpuRewritePatterns(RewritePatternSet &patterns) { populateGpuAllReducePatterns(patterns); diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td index 3766eb16e9429..feb1b2820abd6 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td @@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> { ]; } +def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> { + let summary = "Imitate unsupported types with supported types of same bitwidth."; + let description = [{ + This pass imitates (bitcast/reinterpret_cast) unsupported types + with supported types of same bitwidth. The imitation is done + by bitcasting the unspported types to the supported types of same bitwidth. + Therefore, the source type and destination type must have the same bitwidth. + The imitation is done by using the following operations: arith.bitcast. + + The imitation is often needed when the GPU target (dialect/IR) does not + support a certain type but the underlying architecture does. Take SPIR-V for + example, it does not support bf16, but an underlying architecture (e.g., + intel pvc gpu) that uses SPIR-V for code-generation does. + Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to + be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a + kernel parameter or inside the kernel), bf16 have to be bitcasted (similar + to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The + SPIR-V kernel can then use the imitated type (i16) in the computation. + However, i16 is not the same as bf16 (integer vs float), so the computation + can not readily use the imitated type (i16). 
+ + Therefore, this transformation pass is intended to be used in conjuction + with other transformation passes such as `EmulateUnsupportedFloats` and + `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and + vice-versa. + + Finally, usually, there are instructions available in the target + (dialect/IR) that can take advantage of these generated patterns + (bf16->i16->f32, f32->bf16->i16), and convert them to the supported + types. + For example, Intel provides SPIR-V extension ops that can + take imitated bf16 (i16) and convert them to f32 and vice-versa. + https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc + https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop + https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op + + }]; + + let options = [ + ListOption<"sourceTypeStrs", "source-types", "std::string", + "MLIR types without type support on a given target">, + ListOption<"targetTypeStrs", "target-types", "std::string", + "MLIR types to convert the unsupported source types to">, + ]; + + let dependentDialects = [ + "::mlir::gpu::GPUDialect", + "::mlir::arith::ArithDialect", + "::mlir::memref::MemRefDialect" + ]; +} + + #endif // MLIR_DIALECT_GPU_PASSES diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 6b1074e454bd5..6f2e054a34620 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -380,4 +380,29 @@ std::optional parseFloatType(MLIRContext *ctx, StringRef name) { .Default(std::nullopt); } +/// Map strings to Int types. +std::optional parseIntType(MLIRContext *ctx, StringRef name) { + Builder b(ctx); + return llvm::StringSwitch>(name) + .Case("i1", b.getIntegerType(1)) + .Case("i2", b.getIntegerType(2)) + .Case("i4", b.getIntegerType(4)) + .Case("i6", b.getIntegerType(6)) + .Case("i8", b.getIntegerType(8)) + .Case("i16", b.getIntegerType(16)) + .Case("i32", b.getIntegerType(32)) + .Case("i64", b.getIntegerType(64)) + .Case("i80", b.getIntegerType(80)) + .Case("i128", b.getIntegerType(128)) + .Default(std::nullopt); +} +/// Map strings to Int or Float types. 
+std::optional parseIntOrFloatType(MLIRContext *ctx, StringRef name) { + if (auto floatTy = parseFloatType(ctx, name)) + return *floatTy; + if (auto intTy = parseIntType(ctx, name)) + return *intTy; + return std::nullopt; +} + } // namespace mlir::arith diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt index e21fa501bae6b..6d63f0d79e7d2 100644 --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRGPUDialect MLIRMemRefDialect MLIRSideEffectInterfaces MLIRSupport - ) +) add_mlir_dialect_library(MLIRGPUTransforms Transforms/AllReduceLowering.cpp @@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRGPUTransforms Transforms/SPIRVAttachTarget.cpp Transforms/SubgroupIdRewriter.cpp Transforms/SubgroupReduceLowering.cpp + Transforms/ImitateUnsupportedTypes.cpp OBJECT @@ -76,7 +77,7 @@ add_mlir_dialect_library(MLIRGPUTransforms MLIRROCDLTarget MLIRTransformUtils MLIRVectorDialect - ) +) add_subdirectory(TransformOps) add_subdirectory(Pipelines) diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp new file mode 100644 index 0000000000000..fa7a5e74f13d8 --- /dev/null +++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp @@ -0,0 +1,915 @@ +//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation ----*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +/// \file +/// This pass imitates (bitcast/reinterpret_cast) unsupported types +/// with supported types of same bitwidth. The imitation is done +/// by bitcasting the unspported types to the supported types of same bitwidth. +/// Therefore, the source type and destination type must have the same bitwidth. +/// The imitation is done by using the following operations: arith.bitcast. +/// +/// The imitation is often needed when the GPU target (dialect/IR) does not +/// support a certain type but the underlying architecture does. Take SPIR-V for +/// example, it does not support bf16, but an underlying architecture (e.g., +/// intel pvc gpu) that uses SPIR-V for code-generation does. +/// Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to +/// be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a +/// kernel parameter or inside the kernel), bf16 have to be bitcasted (similar +/// to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The +/// SPIR-V kernel can then use the imitated type (i16) in the computation. +/// However, i16 is not the same as bf16 (integer vs float), so the computation +/// can not readily use the imitated type (i16). +/// +/// Therefore, this transformation pass is intended to be used in conjuction +/// with other transformation passes such as `EmulateUnsupportedFloats` and +/// `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and +/// vice-versa. +/// +/// Finally, usually, there are instructions available in the target +/// (dialect/IR) that can take advantage of these generated patterns +/// (bf16->i16->f32, f32->bf16->i16), and convert them to the supported +/// types. 
+/// For example, Intel provides SPIR-V extension ops that can +/// take imitated bf16 (i16) and convert them to f32 and vice-versa. +/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc +/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop +/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/GPU/Transforms/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::gpu; + +namespace mlir { +#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES +#include "mlir/Dialect/GPU/Transforms/Passes.h.inc" +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +APFloat bitcastAPIntToAPFloat(const APInt &intValue, + const llvm::fltSemantics &semantics) { + // Get the bit width of the APInt. + unsigned intBitWidth = intValue.getBitWidth(); + // Get the total bit size required for the APFloat based on the semantics. + unsigned floatBitWidth = APFloat::getSizeInBits(semantics); + // Ensure the bit widths match for a direct bitcast. + assert(intBitWidth == floatBitWidth && + "Bitwidth of APInt and APFloat must match for bitcast"); + + // Get the raw bit representation of the APInt as a byte vector. + auto intWords = intValue.getRawData(); + // Create an APFloat with the specified semantics and the raw integer bits. + APFloat floatValue(semantics, APInt(intBitWidth, *intWords)); + return floatValue; +} + +// Get FloatAttr from IntegerAttr. +FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APInt intVal = intAttr.getValue(); + auto floatVal = bitcastAPIntToAPFloat( + intVal, cast(dstType).getFloatSemantics()); + return rewriter.getFloatAttr(dstType, floatVal); +} +// Get IntegerAttr from FloatAttr. +IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + +struct RawAllocator { + RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {} + + std::variant computeTotalBytes(MemRefType srcType, + Value srcMemref) { + // Element size in bytes. + int64_t elemBitWidth = srcType.getElementTypeBitWidth(); + int64_t elemByteWidth = (elemBitWidth + 7) / 8; + + if (srcType.hasStaticShape()) { + // Static shape: compute total bytes statically. 
+ int64_t numElements = 1; + for (int64_t dim : srcType.getShape()) { + numElements *= dim; + } + return numElements * elemByteWidth; + } + + auto sizes = getSizes(srcType, srcMemref); + // Compute number of elements dynamically. + Value numElements = sizes.front(); + for (auto size : llvm::drop_begin(sizes)) + numElements = builder.create(loc, numElements, size); + Value elemSize = builder.create(loc, elemByteWidth); + + return builder.create(loc, numElements, elemSize); + } + + SmallVector getSizes(MemRefType type, Value memref) { + SmallVector sizes; + for (unsigned i = 0; i < type.getRank(); ++i) { + if (type.isDynamicDim(i)) { + sizes.push_back(builder.create(loc, memref, i)); + } else { + sizes.push_back( + builder.create(loc, type.getShape()[i])); + } + } + return sizes; + } + + SmallVector getDynamicSizes(MemRefType type, Value memref) { + SmallVector sizes; + for (unsigned i = 0; i < type.getRank(); ++i) { + if (type.isDynamicDim(i)) { + sizes.push_back(builder.create(loc, memref, i)); + } + } + return sizes; + } + + SmallVector getIdentityStrides(MemRefType type) { + SmallVector strides; + int64_t runningStride = 1; + for (int64_t dim : llvm::reverse(type.getShape())) { + strides.push_back( + builder.create(loc, runningStride)); + if (dim != ShapedType::kDynamic) + runningStride *= dim; + else + runningStride = -1; // not handling dynamic strides. + } + std::reverse(strides.begin(), strides.end()); + return strides; + } + +private: + OpBuilder &builder; + Location loc; +}; + +// Replace uses according to predicates automatically. +template +void replaceUsesWithPredicate( + OpTy originalValue, + ArrayRef, Value>> replacements, + ConversionPatternRewriter &rewriter) { + + for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) { + for (const auto &[predicate, newValue] : replacements) { + if (predicate(use)) { + use.set(newValue); + break; + } + } + } +} + +//===----------------------------------------------------------------------===// +// Convertion patterns +//===----------------------------------------------------------------------===// +namespace { + +//===----------------------------------------------------------------------===// +// FunctionOp conversion pattern +//===----------------------------------------------------------------------===// +template +struct ConvertFuncOp final : public OpConversionPattern { + ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes, + DenseMap &convertedFuncTypes) + : OpConversionPattern(context), + typeConverter(typeConverter), // Store the reference + sourceTypes(sourceTypes), targetTypes(targetTypes), + convertedFuncTypes(convertedFuncTypes) {} + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle functions a gpu.module + if (!op->template getParentOfType()) + return failure(); + FunctionType oldFuncType = op.getFunctionType(); + + // Convert function signature + TypeConverter::SignatureConversion signatureConverter( + oldFuncType.getNumInputs()); + for (const auto &argType : + llvm::enumerate(op.getFunctionType().getInputs())) { + auto convertedType = typeConverter.convertType(argType.value()); + if (!convertedType) + return failure(); + signatureConverter.addInputs(argType.index(), convertedType); + } + SmallVector newResultTypes; + for (const auto &resultType : llvm::enumerate(oldFuncType.getResults())) { + 
auto convertedType = typeConverter.convertType(resultType.value()); + if (!convertedType) + return failure(); + newResultTypes.push_back(convertedType); + } + + // Convert function signature + FunctionType newFuncType = rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), newResultTypes); + + if (!newFuncType) + return rewriter.notifyMatchFailure(op, "could not convert function " + "type"); + + // Create new GPU function with converted type + auto newFuncOp = + rewriter.create(op.getLoc(), op.getName(), newFuncType); + + newFuncOp.setVisibility(op.getVisibility()); + // Copy attributes + for (auto attr : op->getAttrs()) { + // Skip the function_type attribute since it is already set by + // the newFuncType and we don't want to overwrite it. + if (attr.getName() != op.getFunctionTypeAttrName() && + attr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(attr.getName(), attr.getValue()); + } + + newFuncOp.getRegion().getBlocks().clear(); + // Inline region approach + rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + // Convert block argument types using the type converter + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) { + return rewriter.notifyMatchFailure(op, "could not convert region " + "types"); + } + + if (!op.use_empty()) { + op.emitError("Cannot erase func: still has uses"); + } + for (Operation *user : op->getUsers()) { + user->emitRemark() << "User of function " << op.getName(); + } + rewriter.eraseOp(op); + // Add the converted function type to the map + newFuncOp.getNameAttr().getValue(); + convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType; + return success(); + } + +private: + TypeConverter &typeConverter; // Store a reference + ArrayRef sourceTypes; + ArrayRef targetTypes; + DenseMap &convertedFuncTypes; +}; + +//===----------------------------------------------------------------------===// +// CallOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertCallOp : OpConversionPattern { + ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter, + const DenseMap &convertedFuncTypes) + : OpConversionPattern(context), convertedFuncTypes(convertedFuncTypes) {} + + LogicalResult + matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto callee = op.getCalleeAttr(); + + auto it = convertedFuncTypes.find( + StringAttr::get(callee.getContext(), callee.getValue())); + if (it == convertedFuncTypes.end()) + return rewriter.notifyMatchFailure( + op, "Callee signature not converted. Perhaps the callee is not in " + "the same gpu module as the caller."); + + auto newResultTypes = it->second.getResults(); + rewriter.replaceOpWithNewOp( + op, callee.getValue(), newResultTypes, adaptor.getOperands()); + + return success(); + } + +private: + const DenseMap &convertedFuncTypes; +}; + +//===----------------------------------------------------------------------===// +// GPULaunchFuncOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertGPULaunchFuncOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::optional clusterSizeOpernads = + op.hasClusterSize() + ? 
std::optional(op.getClusterSizeOperandValues()) + : std::nullopt; + + // Create the new launch_func. + auto newOp = rewriter.create( + op.getLoc(), adaptor.getKernel(), op.getGridSizeOperandValues(), + op.getBlockSizeOperandValues(), op.getDynamicSharedMemorySize(), + adaptor.getKernelOperands(), op.getAsyncObject(), clusterSizeOpernads); + + // Copy block size and grid size attributes + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// AllocOp conversion pattern +//===----------------------------------------------------------------------===// +template +struct ConvertAllocOp : OpConversionPattern { + ConvertAllocOp(MLIRContext *ctx, TypeConverter &typeConverter) + : OpConversionPattern(ctx), typeConverter(typeConverter) {} + + LogicalResult + matchAndRewrite(AllocOp op, typename AllocOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MemRefType srcType = llvm::cast(op.getType()); + // Only supports memref types with identity layout. Since this mechanism + // requires the usage of memref.ViewOp, which requires the layout to be + // identity. + if (!srcType.getLayout().isIdentity()) + op.emitError("only memrefs with identity layout is supported"); + + auto dstType = + dyn_cast_or_null(typeConverter.convertType(srcType)); + if (!dstType || dstType == srcType) + return failure(); // No need to rewrite. + + // Helper class to allocate raw memory. + RawAllocator allocator(rewriter, loc); + + // 1. Compute total allocation size. + auto totalBytes = allocator.computeTotalBytes(srcType, op.getMemref()); + + // 2. Create raw i8 buffer. + MemRefType rawType; + if (std::holds_alternative(totalBytes)) { + // Static size. + SmallVector staticI8Shape; + staticI8Shape.push_back(std::get(totalBytes)); + rawType = MemRefType::get(staticI8Shape, rewriter.getI8Type(), {}, + srcType.getMemorySpaceAsInt()); + } else { + // Dynamic size. + rawType = MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type(), + {}, srcType.getMemorySpaceAsInt()); + } + Value rawAlloc; + + if constexpr (std::is_same_v) { + rawAlloc = + rewriter + .create( + loc, rawType, + op.getAsyncToken() ? op.getAsyncToken().getType() : nullptr, + adaptor.getAsyncDependencies(), + std::holds_alternative(totalBytes) + ? ValueRange{std::get(totalBytes)} + : ValueRange{}, + adaptor.getSymbolOperands(), op.getHostShared()) + .getResult(0); + } else { + rawAlloc = rewriter.create( + loc, rawType, + std::holds_alternative(totalBytes) + ? ValueRange{std::get(totalBytes)} + : ValueRange{}, + op.getSymbolOperands()); + } + + // 3. Create view for original type. + SmallVector dynamicSizes = + allocator.getDynamicSizes(srcType, op.getMemref()); + // Since we are using memref::ViewOp, only identity strides are supported. + SmallVector dynamicStrides = allocator.getIdentityStrides(srcType); + Value zeroOffset = rewriter.create(loc, 0); + Value originalView = rewriter.create( + loc, srcType, rawAlloc, zeroOffset, dynamicSizes); + + // 4. Create view for converted type. + Value convertedView = rewriter.create( + loc, dstType, rawAlloc, zeroOffset, dynamicSizes); + + // 5. Replace uses: + // gpu::LaunchFuncOp uses -> Replace the original AllocOp use in + // gpu::LaunchFuncOp with the view of the + // converted type. + // + // DeallocOp uses -> Replace the original AllocOp use in dealloc with + // the new AllocOp. 
+ // + // Other uses-> Replace the original AllocOp use with the view of the + // original type. + + SmallVector launchFuncUses; + SmallVector deallocUses; + SmallVector otherUses; + + for (OpOperand &use : op->getUses()) { + if (isa(use.getOwner())) { + launchFuncUses.push_back(&use); + } else if (isa(use.getOwner()) || + isa(use.getOwner())) { + deallocUses.push_back(&use); + } else { + otherUses.push_back(&use); + } + } + + for (OpOperand *use : launchFuncUses) + use->set(convertedView); + for (OpOperand *use : deallocUses) + use->set(rawAlloc); + for (OpOperand *use : otherUses) + use->set(originalView); + + // Erase the original AllocOp. + rewriter.eraseOp(op); + return success(); + } + +private: + TypeConverter &typeConverter; +}; + +//===----------------------------------------------------------------------===// +// ArithConstantOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertArithConstantOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + ConvertArithConstantOp(MLIRContext *context, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes) + : OpConversionPattern(context), + typeConverter(typeConverter), // Store the reference. + sourceTypes(sourceTypes), targetTypes(targetTypes) {} + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = op.getType(); + Type dstType = typeConverter.convertType(srcType); + if (!dstType || dstType == srcType) + return failure(); + + Attribute value = op.getValue(); + Value newConstOp = nullptr; + + // When source is IntegerAttr. + if (auto intAttr = dyn_cast(value)) { + APInt intVal = intAttr.getValue(); + if (isa(dstType)) { + auto newAttr = getFloatAttrFromIntegerAttr(intAttr, dstType, rewriter); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else if (isa(dstType)) { + auto newAttr = rewriter.getIntegerAttr(dstType, intVal); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else { + return rewriter.notifyMatchFailure( + op, "expected integer or float target type for constant"); + } + } + + // When source is FloatAttr. + else if (auto floatAttr = dyn_cast(value)) { + if (llvm::isa(dstType)) { + auto newAttr = + getIntegerAttrFromFloatAttr(floatAttr, dstType, rewriter); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else if (llvm::isa(dstType)) { + auto newAttr = rewriter.getFloatAttr(dstType, floatAttr.getValue()); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } else { + return rewriter.notifyMatchFailure( + op, "expected integer or float target type for constant"); + } + } + // Handle DenseElementsAttr. 
+ else if (auto denseAttr = dyn_cast(value)) { + Type newEltType; + if (auto shapedType = dyn_cast(dstType)) + newEltType = shapedType.getElementType(); + else + return rewriter.notifyMatchFailure( + op, "expected shaped type for dense constant"); + + SmallVector newValues; + for (Attribute attr : denseAttr.getValues()) { + if (auto intAttr = dyn_cast(attr)) { + if (llvm::isa(newEltType)) { + auto newAttr = + getFloatAttrFromIntegerAttr(intAttr, newEltType, rewriter); + newValues.push_back(newAttr); + } else if (llvm::isa(newEltType)) { + newValues.push_back( + rewriter.getIntegerAttr(newEltType, intAttr.getValue())); + } else { + return rewriter.notifyMatchFailure( + op, "unsupported target element type in dense constant"); + } + } else if (auto floatAttr = dyn_cast(attr)) { + if (llvm::isa(newEltType)) { + auto newAttr = + getIntegerAttrFromFloatAttr(floatAttr, newEltType, rewriter); + newValues.push_back(newAttr); + } else if (llvm::isa(newEltType)) + newValues.push_back( + rewriter.getFloatAttr(newEltType, floatAttr.getValue())); + else + return rewriter.notifyMatchFailure( + op, "unsupported target element type in dense constant"); + } else { + return rewriter.notifyMatchFailure( + op, "unsupported target element type in dense constant"); + } + } + + auto newAttr = + DenseElementsAttr::get(cast(dstType), newValues); + newConstOp = + rewriter.create(op.getLoc(), dstType, newAttr); + } + if (!newConstOp) + return rewriter.notifyMatchFailure( + op, "unsupported constant type for source to target conversion"); + + auto bitcastOp = + rewriter.create(op.getLoc(), srcType, newConstOp); + rewriter.replaceOp(op, bitcastOp.getResult()); + return success(); + } + +private: + TypeConverter &typeConverter; // Store a reference. + ArrayRef sourceTypes; + ArrayRef targetTypes; +}; + +//===----------------------------------------------------------------------===// +// GenericOp conversion pattern +//===----------------------------------------------------------------------===// +struct ConvertOpWithSourceType final : ConversionPattern { + ConvertOpWithSourceType(MLIRContext *context, + const TypeConverter &typeConverter, + ArrayRef sourceTypes, + ArrayRef targetTypes) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 1, context), + sourceTypes(sourceTypes), targetTypes(targetTypes) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector newResultTypes; + for (Type t : op->getResultTypes()) { + Type converted = typeConverter->convertType(t); + if (!converted) + return failure(); + newResultTypes.push_back(converted); + } + + // Clone the op manually with the converted result types + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(operands); + state.addTypes(newResultTypes); + state.addAttributes(op->getAttrs()); + + for ([[maybe_unused]] auto ®ion : op->getRegions()) + state.regions.emplace_back(); + + Operation *newOp = rewriter.create(state); + // Transfer regions and convert them + for (auto [oldRegion, newRegion] : + llvm::zip(op->getRegions(), newOp->getRegions())) { + if (!oldRegion.empty()) { + newRegion.takeBody(oldRegion); + if (failed(rewriter.convertRegionTypes(&newRegion, *typeConverter))) { + return rewriter.notifyMatchFailure(op, + "region type conversion failed"); + } + } + } + + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } + +private: + ArrayRef sourceTypes; + ArrayRef targetTypes; +}; + +} // namespace + 
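+// Usage sketch (illustrative only): the entry points declared in
+// mlir/include/mlir/Dialect/GPU/Transforms/Passes.h are typically wired up as
+// below; `ctx` (an MLIRContext *) and `module` (the ModuleOp being converted)
+// are placeholders, and the bf16 -> i16 mapping mirrors the configuration
+// exercised in the tests.
+//
+//   SmallVector<Type> srcTypes{BFloat16Type::get(ctx)};
+//   SmallVector<Type> tgtTypes{IntegerType::get(ctx, 16)};
+//   TypeConverter converter;
+//   populateImitateUnsupportedTypesTypeConverter(converter, srcTypes, tgtTypes);
+//   RewritePatternSet patterns(ctx);
+//   DenseMap<StringAttr, FunctionType> convertedFuncTypes;
+//   populateImitateUnsupportedTypesConversionPatterns(
+//       patterns, converter, srcTypes, tgtTypes, convertedFuncTypes);
+//   ConversionTarget target(*ctx);
+//   configureImitateUnsupportedTypesLegality(target, converter);
+//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
+//     return; // conversion failed
+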
+//===----------------------------------------------------------------------===// +// Type Converter +//===----------------------------------------------------------------------===// + +void mlir::populateImitateUnsupportedTypesTypeConverter( + TypeConverter &typeConverter, ArrayRef sourceTypes, + ArrayRef targetTypes) { + auto srcTypes = SmallVector(sourceTypes); + auto tgtTypes = SmallVector(targetTypes); + + assert(sourceTypes.size() == targetTypes.size() && + "Source and target types must have same size"); + + typeConverter.addConversion([srcTypes, tgtTypes](Type type) -> Type { + if (type.isIntOrIndexOrFloat()) { + for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes)) { + if (type == src) + return tgt; + } + } else if (auto memref = llvm::dyn_cast(type)) { + Type elemType = memref.getElementType(); + for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes)) { + if (elemType == src) + return MemRefType::get(memref.getShape(), tgt, memref.getLayout(), + memref.getMemorySpace()); + } + } else if (auto vec = llvm::dyn_cast(type)) { + Type elemType = vec.getElementType(); + for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes)) { + if (elemType == src) + return VectorType::get(vec.getShape(), tgt); + } + } + return type; + }); + + auto materializeCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + assert(inputs.size() == 1 && "Expected single input"); + Type inputType = inputs[0].getType(); + if (isa(resultType) && isa(inputType)) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } + if ((resultType.isIntOrIndexOrFloat() || isa(resultType)) && + (inputType.isIntOrIndexOrFloat() || isa(inputType))) { + return builder.create(loc, resultType, inputs[0]) + .getResult(); + } + return nullptr; + }; + + typeConverter.addSourceMaterialization(materializeCast); + typeConverter.addTargetMaterialization(materializeCast); +} + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void mlir::populateImitateUnsupportedTypesConversionPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, + ArrayRef sourceTypes, ArrayRef targetTypes, + DenseMap &convertedFuncTypes) { + auto ctx = patterns.getContext(); + auto srcTypes = SmallVector(sourceTypes); + auto tgtTypes = SmallVector(targetTypes); + assert(srcTypes.size() == tgtTypes.size() && + "Source and target types must have same size"); + + patterns.add(ctx, typeConverter, srcTypes, tgtTypes); + patterns.add, ConvertFuncOp>( + ctx, typeConverter, srcTypes, tgtTypes, convertedFuncTypes); + patterns.add(ctx, typeConverter, convertedFuncTypes); + patterns.add(ctx, typeConverter, srcTypes, tgtTypes); + patterns.add(ctx); + patterns.add>(ctx, typeConverter); + patterns.add>(ctx, typeConverter); +} + +//===----------------------------------------------------------------------===// +// Conversion Legality configuration +//===----------------------------------------------------------------------===// + +void mlir::configureImitateUnsupportedTypesLegality( + ConversionTarget &target, TypeConverter &typeConverter) { + target.addLegalDialect(); + target.addLegalDialect(); + // Make Memref, func dialect legal for all ops in host code + target.addDynamicallyLegalDialect([&](Operation *op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op); + else + return true; + }); + + target.addDynamicallyLegalDialect( + [&](Operation *op) { 
return typeConverter.isLegal(op); }); + + target.addDynamicallyLegalDialect([&](Operation *op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op); + else + return true; + }); + + target.addLegalOp(); + target.addLegalOp(); + // Manually mark arithmetic-performing vector instructions. + target.addLegalOp(); + target.addDynamicallyLegalOp([&](arith::ConstantOp op) { + return typeConverter.isLegal(op.getType()); + }); + target.addDynamicallyLegalOp([&](gpu::GPUFuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()); + }); + // Only convert functions and function calls in gpu.module + target.addDynamicallyLegalOp([&](func::FuncOp op) { + if (op->getParentOfType()) + return typeConverter.isSignatureLegal(op.getFunctionType()); + return true; + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + if (op->getParentOfType()) + return typeConverter.isSignatureLegal(op.getCalleeType()); + return true; + }); + + // Only convert alloc ops in gpu.module or in host functions and has a use + // in LaunchFunc + target.addDynamicallyLegalOp([&](memref::AllocOp op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op.getType()); + else { + for (auto user : op->getUsers()) { + if (isa(user)) + return typeConverter.isLegal(op.getType()); + } + } + return true; + }); + + // Mark unknown ops that are inside gpu.module, and one of its's operand is a + // memref type as dynamically legal. + target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool { + // Check if the operation is inside a gpu.module. + if (op->getParentOfType()) { + // Check if the operation has any operands of type MemRefType. + for (Value operand : op->getOperands()) { + if (isa(operand.getType())) + return typeConverter.isLegal(op); + } + // If no operands are of type MemRefType, mark it as illegal. + return true; + } + return true; // If not in gpu.module, mark it as legal. 
+ }); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +struct GpuImitateUnsupportedTypesPass + : public impl::GpuImitateUnsupportedTypesBase< + GpuImitateUnsupportedTypesPass> { + using Base::Base; + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + Operation *op = getOperation(); + + SmallVector sourceTypes; + SmallVector targetTypes; + + // Parse source types + for (StringRef sourceTypeStr : sourceTypeStrs) { + std::optional maybeSourceType = + arith::parseIntOrFloatType(ctx, sourceTypeStr); + + if (!maybeSourceType) { + emitError(UnknownLoc::get(ctx), + "could not map source type '" + sourceTypeStr + + "' to a known integer or floating-point type."); + return signalPassFailure(); + } + sourceTypes.push_back(*maybeSourceType); + } + if (sourceTypes.empty()) { + (void)emitOptionalWarning(std::nullopt, "no source types " + "specified, type " + "imitation will do " + "nothing"); + } + + // Parse target types + for (StringRef targetTypeStr : targetTypeStrs) { + std::optional maybeTargetType = + arith::parseIntOrFloatType(ctx, targetTypeStr); + + if (!maybeTargetType) { + emitError(UnknownLoc::get(ctx), + "could not map target type '" + targetTypeStr + + "' to a known integer or floating-point type"); + return signalPassFailure(); + } + targetTypes.push_back(*maybeTargetType); + + if (llvm::is_contained(sourceTypes, *maybeTargetType)) { + emitError(UnknownLoc::get(ctx), + "target type cannot be an unsupported source type"); + return signalPassFailure(); + } + } + if (targetTypes.empty()) { + (void)emitOptionalWarning( + std::nullopt, + "no target types specified, type imitation will do nothing"); + } + + // Set up the type converter + TypeConverter typeConverter; + populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes, + targetTypes); + + // Populate the conversion patterns + RewritePatternSet patterns(ctx); + DenseMap convertedFuncTypes; + populateImitateUnsupportedTypesConversionPatterns( + patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes); + + // Set up conversion target and configure the legality of the conversion + ConversionTarget target(*ctx); + configureImitateUnsupportedTypesLegality(target, typeConverter); + + // Apply the conversion + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); + + // Post-conversion validation: check for any remaining + // unrealized_conversion_cast + bool hasUnresolvedCast = false; + op->walk([&](UnrealizedConversionCastOp op) { + // Check if the cast is from a source type to a target type + for (auto [sourceType, targetType] : + llvm::zip_equal(sourceTypes, targetTypes)) { + if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType && + getElementTypeOrSelf(op.getResult(0).getType()) == targetType) { + op->emitError("unresolved unrealized_conversion_cast left in IR " + "after conversion"); + hasUnresolvedCast = true; + } + } + }); + + if (hasUnresolvedCast) { + signalPassFailure(); + } + } +}; +} // namespace diff --git a/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir new file mode 100644 index 0000000000000..8279a2e4594b1 --- /dev/null +++ b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir @@ -0,0 +1,141 @@ +// RUN: mlir-opt -verify-diagnostics -imitate-unsupported-types="source-types=bf16 target-types=i16" --canonicalize 
-split-input-file %s | FileCheck %s + +// CHECK: module @builtin_module +module @builtin_module { + // CHECK: gpu.module @gpu_func_module { + gpu.module @gpu_func_module attributes{} { + // CHECK-LABEL: gpu.func @arith_and_vector_ops + // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: memref<10x10xf32>, %[[ARG2:.*]]: vector<10x10xi16>, %[[ARG3:.*]]: memref<10x10xi16>, %[[ARG4:.*]]: vector<10x10xi16>) kernel + gpu.func @arith_and_vector_ops(%arg0: memref<10x10xbf16>, %arg1: memref<10x10xf32>, %arg2: vector<10x10xbf16>, %arg3: memref<10x10xi16>, %arg4: vector<10x10xi16>) kernel attributes {} { + + %c0 = arith.constant 0 : index + + // CHECK: %[[ARG2_CAST:.*]] = arith.bitcast %[[ARG2]] : vector<10x10xi16> to vector<10x10xbf16> + // CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: %[[BITCAST1:.*]] = arith.bitcast %[[LOAD1]] : vector<10x10xi16> to vector<10x10xbf16> + %2 = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: %[[ADDF:.*]] = arith.addf %[[BITCAST1]], %[[ARG2_CAST]] : vector<10x10xbf16> + %add = arith.addf %2, %arg2 : vector<10x10xbf16> + + // CHECK: %[[EXTF1:.*]] = arith.extf %[[BITCAST1]] : vector<10x10xbf16> to vector<10x10xf32> + %3 = arith.extf %2 : vector<10x10xbf16> to vector<10x10xf32> + + // CHECK: %[[EXTF2:.*]] = arith.extf %[[ADDF]] : vector<10x10xbf16> to vector<10x10xf32> + %4 = arith.extf %add : vector<10x10xbf16> to vector<10x10xf32> + + // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTF1]], %[[EXTF2]] : vector<10x10xf32> + %5 = arith.addf %3, %4 : vector<10x10xf32> + + // CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ADDF2]] : vector<10x10xf32> to vector<10x10xbf16> + %6 = arith.truncf %5 : vector<10x10xf32> to vector<10x10xbf16> + + // CHECK: %[[TRUNCF_CAST:.*]] = arith.bitcast %[[TRUNCF]] : vector<10x10xbf16> to vector<10x10xi16> + // CHECK: vector.store %[[TRUNCF_CAST]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + vector.store %6, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: %[[LOAD2:.*]] = vector.load %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + %7 = vector.load %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + + // CHECK: %[[ADDI:.*]] = arith.addi %[[LOAD2]], %[[ARG4]] : vector<10x10xi16> + %8 = arith.addi %7, %arg4 : vector<10x10xi16> + + // CHECK: vector.store %[[ADDI]], %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + vector.store %8, %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + + gpu.return + } + } +} + +// ----- + + +// CHECK: module @caller_callee_launch_func_module attributes {gpu.container_module} +module @caller_callee_launch_func_module attributes {gpu.container_module} { + + // CHECK: gpu.module @caller_callee_gpu_module { + gpu.module @caller_callee_gpu_module attributes{} { + + // CHECK: gpu.func @caller_func(%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: vector<10x10xi16>) kernel { + gpu.func @caller_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) kernel attributes {} { + %c0 = arith.constant 0 : index + + // CHECK: %[[CALL_RET:.*]] = func.call @callee_constant_return() : () -> vector<10x10xi16> + %func_result = func.call @callee_constant_return() : () -> vector<10x10xbf16> + + // CHECK: vector.store %[[CALL_RET]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + vector.store %func_result, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + + // CHECK: func.call @callee_func(%[[CALL_RET]]) : (vector<10x10xi16>) -> () + func.call 
@callee_func(%func_result) : (vector<10x10xbf16>) -> () + + gpu.return + } + + // CHECK: func.func @callee_constant_return() -> vector<10x10xi16> { + func.func @callee_constant_return() -> vector<10x10xbf16> { + // CHECK: arith.constant dense<16128> : vector<10x10xi16> + %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + func.return %dense_const : vector<10x10xbf16> + } + + // CHECK: func.func @callee_func(%[[ARG:.*]]: vector<10x10xi16>) { + func.func @callee_func(%arg0: vector<10x10xbf16>) { + return + } + } + + // CHECK: func.func @gpu_launch_func(%[[ARG0:.*]]: memref<10x10xbf16>, %[[ARG1:.*]]: vector<10x10xbf16>) { + func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: arith.constant dense<16128> : vector<10x10xi16> + %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + // CHECK: arith.constant dense<6.015630e-01> : vector<10x10xbf16> + %dense_const_2 = arith.constant dense<6.000000e-01> : vector<10x10xbf16> + + // CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<200xi8> + %alloc = gpu.alloc () : memref<10x10xbf16> + + vector.store %dense_const_2, %alloc[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + // CHECK: %[[VIEW:.*]] = memref.view %[[ALLOC]][%c0][] : memref<200xi8> to memref<10x10xi16> + // CHECK: gpu.launch_func @caller_callee_gpu_module::@caller_func blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%[[VIEW]] : memref<10x10xi16>, %[[CST:.*]] : vector<10x10xi16>) + gpu.launch_func @caller_callee_gpu_module::@caller_func + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%alloc: memref<10x10xbf16>, %dense_const: vector<10x10xbf16>) + return + } +} + +// ----- + +// Only support alloc ops if it is in the same region as the launch op. +// Otherwise, it will leave an unresolved unrealized_conversion_cast in the IR +// due to typeconverter materialization. +module @unsupported_module attributes {gpu.container_module} { + gpu.module @unsupported_gpu_module attributes{} { + gpu.func @kernel(%arg0: memref<10x10xbf16>) kernel attributes {} { + gpu.return + } + } + + func.func @gpu_launch_func(%arg0: memref<10x10xbf16>) { + %c1 = arith.constant 1 : index + // expected-error@+1 {{unresolved unrealized_conversion_cast left in IR after conversion}} + gpu.launch_func @unsupported_gpu_module::@kernel + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%arg0: memref<10x10xbf16>) + return + } + + func.func @main() { + %alloc = memref.alloc () : memref<10x10xbf16> + call @gpu_launch_func(%alloc) : (memref<10x10xbf16>) -> () + memref.dealloc %alloc : memref<10x10xbf16> + return + } +} + +// ----- + From 0cb8c964ec66239520665ffdd6d7a78be8ecebf7 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Thu, 1 May 2025 20:09:18 +0000 Subject: [PATCH 2/5] Address review comments. Move common pass logics to initialize() from runOnOperation(). 
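A rough sketch of the new structure (see the diff below for the exact code): option parsing and the TypeConverter setup now happen once in the pass's initialize() hook, while runOnOperation() only assembles the patterns and applies the conversion:

    LogicalResult initialize(MLIRContext *ctx) override {
      // Parse the source-types / target-types options into sourceTypes /
      // targetTypes, then build the converter once per pass instance.
      populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes,
                                                   targetTypes);
      return success();
    }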
--- .../Transforms/ImitateUnsupportedTypes.cpp | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp index fa7a5e74f13d8..3fd3c3b8ffe2a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp @@ -815,19 +815,18 @@ void mlir::configureImitateUnsupportedTypesLegality( //===----------------------------------------------------------------------===// namespace { + struct GpuImitateUnsupportedTypesPass : public impl::GpuImitateUnsupportedTypesBase< GpuImitateUnsupportedTypesPass> { using Base::Base; - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - Operation *op = getOperation(); - - SmallVector sourceTypes; - SmallVector targetTypes; + SmallVector sourceTypes; + SmallVector targetTypes; + TypeConverter typeConverter; - // Parse source types + LogicalResult initialize(MLIRContext *ctx) override { + // Parse source types. for (StringRef sourceTypeStr : sourceTypeStrs) { std::optional maybeSourceType = arith::parseIntOrFloatType(ctx, sourceTypeStr); @@ -836,7 +835,7 @@ struct GpuImitateUnsupportedTypesPass emitError(UnknownLoc::get(ctx), "could not map source type '" + sourceTypeStr + "' to a known integer or floating-point type."); - return signalPassFailure(); + return failure(); } sourceTypes.push_back(*maybeSourceType); } @@ -847,7 +846,7 @@ struct GpuImitateUnsupportedTypesPass "nothing"); } - // Parse target types + // Parse target types. for (StringRef targetTypeStr : targetTypeStrs) { std::optional maybeTargetType = arith::parseIntOrFloatType(ctx, targetTypeStr); @@ -856,14 +855,14 @@ struct GpuImitateUnsupportedTypesPass emitError(UnknownLoc::get(ctx), "could not map target type '" + targetTypeStr + "' to a known integer or floating-point type"); - return signalPassFailure(); + return failure(); } targetTypes.push_back(*maybeTargetType); if (llvm::is_contained(sourceTypes, *maybeTargetType)) { emitError(UnknownLoc::get(ctx), "target type cannot be an unsupported source type"); - return signalPassFailure(); + return failure(); } } if (targetTypes.empty()) { @@ -872,44 +871,50 @@ struct GpuImitateUnsupportedTypesPass "no target types specified, type imitation will do nothing"); } - // Set up the type converter - TypeConverter typeConverter; + if (sourceTypes.size() != targetTypes.size()) { + emitError(UnknownLoc::get(ctx), + "source and target types must have the same size"); + return failure(); + } + // Set up the type converter. populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes, targetTypes); + return success(); + } - // Populate the conversion patterns + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + Operation *op = getOperation(); + + // Populate the conversion patterns. RewritePatternSet patterns(ctx); DenseMap convertedFuncTypes; populateImitateUnsupportedTypesConversionPatterns( patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes); - // Set up conversion target and configure the legality of the conversion + // Set up conversion target and configure the legality of the conversion. ConversionTarget target(*ctx); configureImitateUnsupportedTypesLegality(target, typeConverter); - // Apply the conversion + // Apply the conversion. 
if (failed(applyPartialConversion(op, target, std::move(patterns)))) - signalPassFailure(); + return signalPassFailure(); // Post-conversion validation: check for any remaining - // unrealized_conversion_cast - bool hasUnresolvedCast = false; + // unrealized_conversion_cast. op->walk([&](UnrealizedConversionCastOp op) { - // Check if the cast is from a source type to a target type + // Check if the cast is from a source type to a target type. for (auto [sourceType, targetType] : llvm::zip_equal(sourceTypes, targetTypes)) { if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType && getElementTypeOrSelf(op.getResult(0).getType()) == targetType) { op->emitError("unresolved unrealized_conversion_cast left in IR " "after conversion"); - hasUnresolvedCast = true; + return signalPassFailure(); } } }); - - if (hasUnresolvedCast) { - signalPassFailure(); - } } }; + } // namespace From f13507ee90ff1e49867bfcbe38b5e0964b1ae0a1 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Mon, 5 May 2025 23:09:36 +0000 Subject: [PATCH 3/5] Use arith.bitcast for memref element casting. Remove the usage of memref.view op and the restrictions comes with it. Makes the pass straight forward. --- .../Transforms/ImitateUnsupportedTypes.cpp | 264 +----------------- .../GPU/imitate-unsupported-types.mlir | 175 ++++++++---- 2 files changed, 133 insertions(+), 306 deletions(-) diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp index 3fd3c3b8ffe2a..a0491cefdc0fa 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp @@ -113,94 +113,6 @@ IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, return rewriter.getIntegerAttr(dstType, intVal); } -struct RawAllocator { - RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {} - - std::variant computeTotalBytes(MemRefType srcType, - Value srcMemref) { - // Element size in bytes. - int64_t elemBitWidth = srcType.getElementTypeBitWidth(); - int64_t elemByteWidth = (elemBitWidth + 7) / 8; - - if (srcType.hasStaticShape()) { - // Static shape: compute total bytes statically. - int64_t numElements = 1; - for (int64_t dim : srcType.getShape()) { - numElements *= dim; - } - return numElements * elemByteWidth; - } - - auto sizes = getSizes(srcType, srcMemref); - // Compute number of elements dynamically. 
- Value numElements = sizes.front(); - for (auto size : llvm::drop_begin(sizes)) - numElements = builder.create(loc, numElements, size); - Value elemSize = builder.create(loc, elemByteWidth); - - return builder.create(loc, numElements, elemSize); - } - - SmallVector getSizes(MemRefType type, Value memref) { - SmallVector sizes; - for (unsigned i = 0; i < type.getRank(); ++i) { - if (type.isDynamicDim(i)) { - sizes.push_back(builder.create(loc, memref, i)); - } else { - sizes.push_back( - builder.create(loc, type.getShape()[i])); - } - } - return sizes; - } - - SmallVector getDynamicSizes(MemRefType type, Value memref) { - SmallVector sizes; - for (unsigned i = 0; i < type.getRank(); ++i) { - if (type.isDynamicDim(i)) { - sizes.push_back(builder.create(loc, memref, i)); - } - } - return sizes; - } - - SmallVector getIdentityStrides(MemRefType type) { - SmallVector strides; - int64_t runningStride = 1; - for (int64_t dim : llvm::reverse(type.getShape())) { - strides.push_back( - builder.create(loc, runningStride)); - if (dim != ShapedType::kDynamic) - runningStride *= dim; - else - runningStride = -1; // not handling dynamic strides. - } - std::reverse(strides.begin(), strides.end()); - return strides; - } - -private: - OpBuilder &builder; - Location loc; -}; - -// Replace uses according to predicates automatically. -template -void replaceUsesWithPredicate( - OpTy originalValue, - ArrayRef, Value>> replacements, - ConversionPatternRewriter &rewriter) { - - for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) { - for (const auto &[predicate, newValue] : replacements) { - if (predicate(use)) { - use.set(newValue); - break; - } - } - } -} - //===----------------------------------------------------------------------===// // Convertion patterns //===----------------------------------------------------------------------===// @@ -355,127 +267,6 @@ struct ConvertGPULaunchFuncOp : OpConversionPattern { } }; -//===----------------------------------------------------------------------===// -// AllocOp conversion pattern -//===----------------------------------------------------------------------===// -template -struct ConvertAllocOp : OpConversionPattern { - ConvertAllocOp(MLIRContext *ctx, TypeConverter &typeConverter) - : OpConversionPattern(ctx), typeConverter(typeConverter) {} - - LogicalResult - matchAndRewrite(AllocOp op, typename AllocOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MemRefType srcType = llvm::cast(op.getType()); - // Only supports memref types with identity layout. Since this mechanism - // requires the usage of memref.ViewOp, which requires the layout to be - // identity. - if (!srcType.getLayout().isIdentity()) - op.emitError("only memrefs with identity layout is supported"); - - auto dstType = - dyn_cast_or_null(typeConverter.convertType(srcType)); - if (!dstType || dstType == srcType) - return failure(); // No need to rewrite. - - // Helper class to allocate raw memory. - RawAllocator allocator(rewriter, loc); - - // 1. Compute total allocation size. - auto totalBytes = allocator.computeTotalBytes(srcType, op.getMemref()); - - // 2. Create raw i8 buffer. - MemRefType rawType; - if (std::holds_alternative(totalBytes)) { - // Static size. - SmallVector staticI8Shape; - staticI8Shape.push_back(std::get(totalBytes)); - rawType = MemRefType::get(staticI8Shape, rewriter.getI8Type(), {}, - srcType.getMemorySpaceAsInt()); - } else { - // Dynamic size. 
- rawType = MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type(), - {}, srcType.getMemorySpaceAsInt()); - } - Value rawAlloc; - - if constexpr (std::is_same_v) { - rawAlloc = - rewriter - .create( - loc, rawType, - op.getAsyncToken() ? op.getAsyncToken().getType() : nullptr, - adaptor.getAsyncDependencies(), - std::holds_alternative(totalBytes) - ? ValueRange{std::get(totalBytes)} - : ValueRange{}, - adaptor.getSymbolOperands(), op.getHostShared()) - .getResult(0); - } else { - rawAlloc = rewriter.create( - loc, rawType, - std::holds_alternative(totalBytes) - ? ValueRange{std::get(totalBytes)} - : ValueRange{}, - op.getSymbolOperands()); - } - - // 3. Create view for original type. - SmallVector dynamicSizes = - allocator.getDynamicSizes(srcType, op.getMemref()); - // Since we are using memref::ViewOp, only identity strides are supported. - SmallVector dynamicStrides = allocator.getIdentityStrides(srcType); - Value zeroOffset = rewriter.create(loc, 0); - Value originalView = rewriter.create( - loc, srcType, rawAlloc, zeroOffset, dynamicSizes); - - // 4. Create view for converted type. - Value convertedView = rewriter.create( - loc, dstType, rawAlloc, zeroOffset, dynamicSizes); - - // 5. Replace uses: - // gpu::LaunchFuncOp uses -> Replace the original AllocOp use in - // gpu::LaunchFuncOp with the view of the - // converted type. - // - // DeallocOp uses -> Replace the original AllocOp use in dealloc with - // the new AllocOp. - // - // Other uses-> Replace the original AllocOp use with the view of the - // original type. - - SmallVector launchFuncUses; - SmallVector deallocUses; - SmallVector otherUses; - - for (OpOperand &use : op->getUses()) { - if (isa(use.getOwner())) { - launchFuncUses.push_back(&use); - } else if (isa(use.getOwner()) || - isa(use.getOwner())) { - deallocUses.push_back(&use); - } else { - otherUses.push_back(&use); - } - } - - for (OpOperand *use : launchFuncUses) - use->set(convertedView); - for (OpOperand *use : deallocUses) - use->set(rawAlloc); - for (OpOperand *use : otherUses) - use->set(originalView); - - // Erase the original AllocOp. 
- rewriter.eraseOp(op); - return success(); - } - -private: - TypeConverter &typeConverter; -}; - //===----------------------------------------------------------------------===// // ArithConstantOp conversion pattern //===----------------------------------------------------------------------===// @@ -688,12 +479,10 @@ void mlir::populateImitateUnsupportedTypesTypeConverter( ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1 && "Expected single input"); Type inputType = inputs[0].getType(); - if (isa(resultType) && isa(inputType)) { - return builder.create(loc, resultType, inputs) - .getResult(0); - } - if ((resultType.isIntOrIndexOrFloat() || isa(resultType)) && - (inputType.isIntOrIndexOrFloat() || isa(inputType))) { + if ((resultType.isIntOrIndexOrFloat() || isa(resultType) || + isa(resultType)) && + (inputType.isIntOrIndexOrFloat() || isa(inputType) || + isa(inputType))) { return builder.create(loc, resultType, inputs[0]) .getResult(); } @@ -724,8 +513,6 @@ void mlir::populateImitateUnsupportedTypesConversionPatterns( patterns.add(ctx, typeConverter, convertedFuncTypes); patterns.add(ctx, typeConverter, srcTypes, tgtTypes); patterns.add(ctx); - patterns.add>(ctx, typeConverter); - patterns.add>(ctx, typeConverter); } //===----------------------------------------------------------------------===// @@ -744,8 +531,11 @@ void mlir::configureImitateUnsupportedTypesLegality( return true; }); - target.addDynamicallyLegalDialect( - [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalDialect([&](Operation *op) { + if (op->getParentOfType()) + return typeConverter.isLegal(op); + return true; + }); target.addDynamicallyLegalDialect([&](Operation *op) { if (op->getParentOfType()) @@ -755,7 +545,6 @@ void mlir::configureImitateUnsupportedTypesLegality( }); target.addLegalOp(); - target.addLegalOp(); // Manually mark arithmetic-performing vector instructions. target.addLegalOp([&](gpu::GPUFuncOp op) { return typeConverter.isSignatureLegal(op.getFunctionType()); }); + target.addDynamicallyLegalOp( + [&](gpu::LaunchFuncOp op) { return typeConverter.isLegal(op); }); // Only convert functions and function calls in gpu.module target.addDynamicallyLegalOp([&](func::FuncOp op) { if (op->getParentOfType()) @@ -779,22 +570,8 @@ void mlir::configureImitateUnsupportedTypesLegality( return true; }); - // Only convert alloc ops in gpu.module or in host functions and has a use - // in LaunchFunc - target.addDynamicallyLegalOp([&](memref::AllocOp op) { - if (op->getParentOfType()) - return typeConverter.isLegal(op.getType()); - else { - for (auto user : op->getUsers()) { - if (isa(user)) - return typeConverter.isLegal(op.getType()); - } - } - return true; - }); - - // Mark unknown ops that are inside gpu.module, and one of its's operand is a - // memref type as dynamically legal. + // Mark unknown ops that are inside gpu.module, and one of its's operand is + // a memref type as dynamically legal. target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool { // Check if the operation is inside a gpu.module. if (op->getParentOfType()) { @@ -899,21 +676,6 @@ struct GpuImitateUnsupportedTypesPass // Apply the conversion. if (failed(applyPartialConversion(op, target, std::move(patterns)))) return signalPassFailure(); - - // Post-conversion validation: check for any remaining - // unrealized_conversion_cast. - op->walk([&](UnrealizedConversionCastOp op) { - // Check if the cast is from a source type to a target type. 
- for (auto [sourceType, targetType] : - llvm::zip_equal(sourceTypes, targetTypes)) { - if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType && - getElementTypeOrSelf(op.getResult(0).getType()) == targetType) { - op->emitError("unresolved unrealized_conversion_cast left in IR " - "after conversion"); - return signalPassFailure(); - } - } - }); } }; diff --git a/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir index 8279a2e4594b1..db4d692241023 100644 --- a/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir +++ b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir @@ -2,47 +2,54 @@ // CHECK: module @builtin_module module @builtin_module { - // CHECK: gpu.module @gpu_func_module { - gpu.module @gpu_func_module attributes{} { - // CHECK-LABEL: gpu.func @arith_and_vector_ops + // CHECK: gpu.module @gpu_func_module + gpu.module @gpu_func_module { + // CHECK: gpu.func @arith_and_vector_ops // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: memref<10x10xf32>, %[[ARG2:.*]]: vector<10x10xi16>, %[[ARG3:.*]]: memref<10x10xi16>, %[[ARG4:.*]]: vector<10x10xi16>) kernel - gpu.func @arith_and_vector_ops(%arg0: memref<10x10xbf16>, %arg1: memref<10x10xf32>, %arg2: vector<10x10xbf16>, %arg3: memref<10x10xi16>, %arg4: vector<10x10xi16>) kernel attributes {} { - + gpu.func @arith_and_vector_ops( + %arg0: memref<10x10xbf16>, + %arg1: memref<10x10xf32>, + %arg2: vector<10x10xbf16>, + %arg3: memref<10x10xi16>, + %arg4: vector<10x10xi16> + ) kernel { + // CHECK: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[ARG2_CAST:.*]] = arith.bitcast %[[ARG2]] : vector<10x10xi16> to vector<10x10xbf16> - // CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> - // CHECK: %[[BITCAST1:.*]] = arith.bitcast %[[LOAD1]] : vector<10x10xi16> to vector<10x10xbf16> - %2 = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + // CHECK: %[[CAST_ARG2:.*]] = arith.bitcast %[[ARG2]] : vector<10x10xi16> to vector<10x10xbf16> + // CHECK: %[[LOAD_ARG0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: %[[CAST_LOAD:.*]] = arith.bitcast %[[LOAD_ARG0]] : vector<10x10xi16> to vector<10x10xbf16> + %0 = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> - // CHECK: %[[ADDF:.*]] = arith.addf %[[BITCAST1]], %[[ARG2_CAST]] : vector<10x10xbf16> - %add = arith.addf %2, %arg2 : vector<10x10xbf16> + // CHECK: %[[ADDF:.*]] = arith.addf %[[CAST_LOAD]], %[[CAST_ARG2]] : vector<10x10xbf16> + %1 = arith.addf %0, %arg2 : vector<10x10xbf16> - // CHECK: %[[EXTF1:.*]] = arith.extf %[[BITCAST1]] : vector<10x10xbf16> to vector<10x10xf32> - %3 = arith.extf %2 : vector<10x10xbf16> to vector<10x10xf32> + // CHECK: %[[EXT0:.*]] = arith.extf %[[CAST_LOAD]] : vector<10x10xbf16> to vector<10x10xf32> + %2 = arith.extf %0 : vector<10x10xbf16> to vector<10x10xf32> - // CHECK: %[[EXTF2:.*]] = arith.extf %[[ADDF]] : vector<10x10xbf16> to vector<10x10xf32> - %4 = arith.extf %add : vector<10x10xbf16> to vector<10x10xf32> + // CHECK: %[[EXT1:.*]] = arith.extf %[[ADDF]] : vector<10x10xbf16> to vector<10x10xf32> + %3 = arith.extf %1 : vector<10x10xbf16> to vector<10x10xf32> - // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTF1]], %[[EXTF2]] : vector<10x10xf32> - %5 = arith.addf %3, %4 : vector<10x10xf32> + // CHECK: %[[FADD:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<10x10xf32> + %4 = arith.addf %2, %3 : vector<10x10xf32> - // CHECK: 
%[[TRUNCF:.*]] = arith.truncf %[[ADDF2]] : vector<10x10xf32> to vector<10x10xbf16> - %6 = arith.truncf %5 : vector<10x10xf32> to vector<10x10xbf16> + // CHECK: %[[TRUNC:.*]] = arith.truncf %[[FADD]] : vector<10x10xf32> to vector<10x10xbf16> + %5 = arith.truncf %4 : vector<10x10xf32> to vector<10x10xbf16> - // CHECK: %[[TRUNCF_CAST:.*]] = arith.bitcast %[[TRUNCF]] : vector<10x10xbf16> to vector<10x10xi16> - // CHECK: vector.store %[[TRUNCF_CAST]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> - vector.store %6, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + // CHECK: %[[CAST_TRUNC:.*]] = arith.bitcast %[[TRUNC]] : vector<10x10xbf16> to vector<10x10xi16> + // CHECK: vector.store %[[CAST_TRUNC]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + vector.store %5, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> - // CHECK: %[[LOAD2:.*]] = vector.load %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> - %7 = vector.load %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: %[[LOAD2:.*]] = vector.load %[[ARG3]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + %6 = vector.load %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> // CHECK: %[[ADDI:.*]] = arith.addi %[[LOAD2]], %[[ARG4]] : vector<10x10xi16> - %8 = arith.addi %7, %arg4 : vector<10x10xi16> + %7 = arith.addi %6, %arg4 : vector<10x10xi16> - // CHECK: vector.store %[[ADDI]], %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> - vector.store %8, %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: vector.store %[[ADDI]], %[[ARG3]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> + vector.store %7, %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: gpu.return gpu.return } } @@ -53,30 +60,32 @@ module @builtin_module { // CHECK: module @caller_callee_launch_func_module attributes {gpu.container_module} module @caller_callee_launch_func_module attributes {gpu.container_module} { - // CHECK: gpu.module @caller_callee_gpu_module { gpu.module @caller_callee_gpu_module attributes{} { - - // CHECK: gpu.func @caller_func(%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: vector<10x10xi16>) kernel { + // CHECK: gpu.func @caller_func + // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: vector<10x10xi16>) kernel gpu.func @caller_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) kernel attributes {} { + // CHECK: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[CALL_RET:.*]] = func.call @callee_constant_return() : () -> vector<10x10xi16> + // CHECK: %[[RET:.*]] = func.call @callee_constant_return() : () -> vector<10x10xi16> %func_result = func.call @callee_constant_return() : () -> vector<10x10xbf16> - // CHECK: vector.store %[[CALL_RET]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16> + // CHECK: vector.store %[[RET]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<10x10xi16>, vector<10x10xi16> vector.store %func_result, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> - // CHECK: func.call @callee_func(%[[CALL_RET]]) : (vector<10x10xi16>) -> () + // CHECK: func.call @callee_func(%[[RET]]) : (vector<10x10xi16>) -> () func.call @callee_func(%func_result) : (vector<10x10xbf16>) -> () + // CHECK: gpu.return gpu.return } // CHECK: func.func @callee_constant_return() -> vector<10x10xi16> { func.func @callee_constant_return() -> vector<10x10xbf16> { - // CHECK: arith.constant dense<16128> : vector<10x10xi16> + // CHECK: %[[CST:.*]] = arith.constant dense<16128> : 
vector<10x10xi16> %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + // CHECK: return %[[CST]] : vector<10x10xi16> func.return %dense_const : vector<10x10xbf16> } @@ -86,21 +95,40 @@ module @caller_callee_launch_func_module attributes {gpu.container_module} { } } - // CHECK: func.func @gpu_launch_func(%[[ARG0:.*]]: memref<10x10xbf16>, %[[ARG1:.*]]: vector<10x10xbf16>) { + // CHECK: func.func @gpu_launch_func( func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) { + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - // CHECK: arith.constant dense<16128> : vector<10x10xi16> + + // Handling bf16 constants, dealing with constants for both cases: + // - not used in gpu.launch_func (no conversion) + // - used in gpu.launch_func (needs conversion to i16) + + // CHECK: %[[BF16_CONST:.*]] = arith.constant dense<5.000000e-01> : vector<10x10xbf16> + // CHECK: %[[I16_CONST:.*]] = arith.constant dense<16128> : vector<10x10xi16> %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16> - // CHECK: arith.constant dense<6.015630e-01> : vector<10x10xbf16> - %dense_const_2 = arith.constant dense<6.000000e-01> : vector<10x10xbf16> - // CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<200xi8> - %alloc = gpu.alloc () : memref<10x10xbf16> + // CHECK: %[[BF16_CONST_2:.*]] = arith.constant dense<1.500000e+00> : vector<10x10xbf16> + %dense_const_2 = arith.constant dense<1.500000e+00> : vector<10x10xbf16> + + // CHECK: %[[ADDF:.*]] = arith.addf %arg1, %[[BF16_CONST]] : vector<10x10xbf16> + %add = arith.addf %dense_const, %arg1 : vector<10x10xbf16> + + // CHECK: vector.store %[[ADDF]], %arg0[%[[C0]], %[[C0]]] : memref<10x10xbf16>, vector<10x10xbf16> + vector.store %add, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> + // CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<10x10xbf16> + %alloc = gpu.alloc () : memref<10x10xbf16> + // CHECK: %[[BITCAST:.*]] = arith.bitcast %[[ALLOC]] : memref<10x10xbf16> to memref<10x10xi16> + // CHECK: vector.store %[[BF16_CONST_2]], %[[ALLOC]][%[[C0]], %[[C0]]] : memref<10x10xbf16>, vector<10x10xbf16> vector.store %dense_const_2, %alloc[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16> - // CHECK: %[[VIEW:.*]] = memref.view %[[ALLOC]][%c0][] : memref<200xi8> to memref<10x10xi16> - // CHECK: gpu.launch_func @caller_callee_gpu_module::@caller_func blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%[[VIEW]] : memref<10x10xi16>, %[[CST:.*]] : vector<10x10xi16>) + + + // CHECK: gpu.launch_func @caller_callee_gpu_module::@caller_func + // CHECK-SAME: args(%[[BITCAST]] : memref<10x10xi16>, %[[I16_CONST]] : vector<10x10xi16>) gpu.launch_func @caller_callee_gpu_module::@caller_func blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%alloc: memref<10x10xbf16>, %dense_const: vector<10x10xbf16>) @@ -110,32 +138,69 @@ module @caller_callee_launch_func_module attributes {gpu.container_module} { // ----- -// Only support alloc ops if it is in the same region as the launch op. -// Otherwise, it will leave an unresolved unrealized_conversion_cast in the IR -// due to typeconverter materialization. 
-module @unsupported_module attributes {gpu.container_module} { - gpu.module @unsupported_gpu_module attributes{} { - gpu.func @kernel(%arg0: memref<10x10xbf16>) kernel attributes {} { + +// CHECK: #map = affine_map<(d0, d1) -> (d1, d0)> +#map = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK-LABEL: module @module_multi_level_call attributes {gpu.container_module} { +module @module_multi_level_call attributes {gpu.container_module} { + // CHECK: gpu.module @gpu_module_multi_level_call { + gpu.module @gpu_module_multi_level_call { + // CHECK: gpu.func @kernel(%[[K_ARG:.*]]: memref<10x10xi16>) kernel { + gpu.func @kernel(%arg0: memref<10x10xi16>) kernel { + // CHECK: gpu.return + gpu.return + } + + // CHECK: gpu.func @affine_memref_arg(%[[AFF_ARG:.*]]: memref<100x100xi16, #map, 2>) kernel { + gpu.func @affine_memref_arg(%arg0: memref<100x100xi16, #map, 2>) kernel { + // CHECK: gpu.return gpu.return } } - func.func @gpu_launch_func(%arg0: memref<10x10xbf16>) { + // CHECK-LABEL: func.func @gpu_launch_func + func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: memref<100x100xbf16, #map, 2>) { + // CHECK: %[[C1:.*]] = arith.constant 1 : index %c1 = arith.constant 1 : index - // expected-error@+1 {{unresolved unrealized_conversion_cast left in IR after conversion}} - gpu.launch_func @unsupported_gpu_module::@kernel + + // CHECK: %[[AFF_CAST:.*]] = arith.bitcast %[[ARG1:.*]] : memref<100x100xbf16, #map, 2> to memref<100x100xi16, #map, 2> + %0 = arith.bitcast %arg1 : memref<100x100xbf16, #map, 2> to memref<100x100xi16, #map, 2> + + // CHECK: %[[BF16_CAST:.*]] = arith.bitcast %[[ARG0:.*]] : memref<10x10xbf16> to memref<10x10xi16> + %1 = arith.bitcast %arg0 : memref<10x10xbf16> to memref<10x10xi16> + + // CHECK: gpu.launch_func @gpu_module_multi_level_call::@kernel + // CHECK-SAME: args(%[[BF16_CAST]] : memref<10x10xi16>) + gpu.launch_func @gpu_module_multi_level_call::@kernel + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%1 : memref<10x10xi16>) + + // CHECK: gpu.launch_func @gpu_module_multi_level_call::@affine_memref_arg + // CHECK-SAME: args(%[[AFF_CAST]] : memref<100x100xi16, #map, 2>) + gpu.launch_func @gpu_module_multi_level_call::@affine_memref_arg blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) - args(%arg0: memref<10x10xbf16>) + args(%0 : memref<100x100xi16, #map, 2>) + // CHECK: return return } + // CHECK-LABEL: func.func @main func.func @main() { - %alloc = memref.alloc () : memref<10x10xbf16> - call @gpu_launch_func(%alloc) : (memref<10x10xbf16>) -> () + // CHECK: %[[ALLOC0:.*]] = memref.alloc() : memref<10x10xbf16> + %alloc = memref.alloc() : memref<10x10xbf16> + // CHECK: %[[ALLOC1:.*]] = memref.alloc() : memref<100x100xbf16, #map, 2> + %alloc_0 = memref.alloc() : memref<100x100xbf16, #map, 2> + // CHECK: call @gpu_launch_func(%[[ALLOC0]], %[[ALLOC1]]) + call @gpu_launch_func(%alloc, %alloc_0) : (memref<10x10xbf16>, memref<100x100xbf16, #map, 2>) -> () + // CHECK: memref.dealloc %[[ALLOC0]] memref.dealloc %alloc : memref<10x10xbf16> + // CHECK: memref.dealloc %[[ALLOC1]] + memref.dealloc %alloc_0 : memref<100x100xbf16, #map, 2> + // CHECK: return return } } -// ----- + From d138e5d4e84e4d45232af3ab3e1abf3ddbbf6414 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Wed, 7 May 2025 18:29:32 +0000 Subject: [PATCH 4/5] Mark all ops with unsupported data types ready for conversion if not marked legal explicitly. 
---
 mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
index a0491cefdc0fa..f5c6a6eea8fa2 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
@@ -576,10 +576,7 @@ void mlir::configureImitateUnsupportedTypesLegality(
     // Check if the operation is inside a gpu.module.
     if (op->getParentOfType<gpu::GPUModuleOp>()) {
       // Check if the operation has any operands of type MemRefType.
-      for (Value operand : op->getOperands()) {
-        if (isa<MemRefType>(operand.getType()))
-          return typeConverter.isLegal(op);
-      }
+      return typeConverter.isLegal(op);
       // If no operands are of type MemRefType, mark it as illegal.
       return true;
     }

From 63336e47d191e8ba67d6bd85cdee434a9691d028 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Thu, 8 May 2025 10:41:43 +0000
Subject: [PATCH 5/5] Remove dead code.

---
 mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
index f5c6a6eea8fa2..8330214b873a2 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
@@ -575,10 +575,7 @@ void mlir::configureImitateUnsupportedTypesLegality(
   target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
     // Check if the operation is inside a gpu.module.
     if (op->getParentOfType<gpu::GPUModuleOp>()) {
-      // Check if the operation has any operands of type MemRefType.
       return typeConverter.isLegal(op);
-      // If no operands are of type MemRefType, mark it as illegal.
-      return true;
     }
     return true; // If not in gpu.module, mark it as legal.
   });
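
Usage sketch (not part of the patch series): the three hooks exposed by this change, populateImitateUnsupportedTypesTypeConverter, populateImitateUnsupportedTypesConversionPatterns, and configureImitateUnsupportedTypesLegality, can also be driven outside the bundled pass. The code below is a minimal, illustrative wiring for a standalone bf16 -> i16 imitation; the driver function name, the include set, and the key type of the convertedFuncTypes map are assumptions, while the hook names and overall call sequence follow the definitions changed above.

// Minimal sketch: drive the imitate-unsupported-types hooks over a module,
// imitating bf16 with i16. Only the hook names come from this patch series;
// the rest of the boilerplate is illustrative.
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static LogicalResult imitateBF16WithI16(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Source and target element types must have the same bitwidth.
  SmallVector<Type> sourceTypes{BFloat16Type::get(ctx)};
  SmallVector<Type> targetTypes{IntegerType::get(ctx, 16)};

  // 1. Type converter: maps bf16 (scalars, vectors, memrefs) to i16 and
  //    materializes remaining mismatches with arith.bitcast.
  TypeConverter typeConverter;
  populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes,
                                               targetTypes);

  // 2. Patterns: function signatures, calls, constants, and gpu.launch_func
  //    operands are rewritten; converted signatures are tracked in the map.
  //    (StringAttr as the map key is an assumption of this sketch.)
  DenseMap<StringAttr, FunctionType> convertedFuncTypes;
  RewritePatternSet patterns(ctx);
  populateImitateUnsupportedTypesConversionPatterns(
      patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);

  // 3. Legality: ops inside gpu.module and gpu.launch_func on the host side
  //    are constrained; other host ops are left legal.
  ConversionTarget target(*ctx);
  configureImitateUnsupportedTypesLegality(target, typeConverter);

  return applyPartialConversion(module.getOperation(), target,
                                std::move(patterns));
}

Run over IR like the second test module above, an equivalent driver would leave host-side bf16 arithmetic untouched while gpu.func signatures, kernel bodies, and gpu.launch_func operands switch to i16, with arith.bitcast inserted at the boundaries as exercised by the CHECK lines in imitate-unsupported-types.mlir.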