From 727c0f752f6df16b1445a61a8370271c62c9c30f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 7 May 2025 10:42:58 +0200 Subject: [PATCH 1/3] prototype experiment more --- mlir/docs/DialectConversion.md | 47 +++++---- .../mlir/Transforms/DialectConversion.h | 95 ++++++++++++++++--- .../Transforms/StructuralTypeConversions.cpp | 4 +- .../Transforms/Utils/DialectConversion.cpp | 45 ++++++++- .../test-legalize-type-conversion.mlir | 18 ++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 29 +++++- 6 files changed, 200 insertions(+), 38 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index cf577eca5b9a6..61872d10670dc 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal" type. Type conversions are specified via the `addConversion` method described below. +There are two kind of conversion functions: context-aware and context-unaware +conversions. A context-unaware conversion function converts a `Type` into a +`Type`. A context-aware conversion function converts a `Value` into a type. The +latter allows users to customize type conversion rules based on the IR. + +Note: When there is at least one context-aware type conversion function, the +result of type conversions can no longer be cached, which can increase +compilation time. Use this feature with caution! + A `materialization` describes how a list of values should be converted to a list of values with specific types. An important distinction from a `conversion` is that a `materialization` can produce IR, whereas a `conversion` @@ -287,29 +296,31 @@ Several of the available hooks are detailed below: ```c++ class TypeConverter { public: - /// Register a conversion function. A conversion function defines how a given - /// source type should be converted. A conversion function must be convertible - /// to any of the following forms(where `T` is a class derived from `Type`: - /// * Optional(T) + /// Register a conversion function. A conversion function must be convertible + /// to any of the following forms (where `T` is `Value` or a class derived + /// from `Type`, including `Type` itself): + /// + /// * std::optional(T) /// - This form represents a 1-1 type conversion. It should return nullptr - /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the - /// converter is allowed to try another conversion function to perform - /// the conversion. - /// * Optional(T, SmallVectorImpl &) + /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, + /// the converter is allowed to try another conversion function to + /// perform the conversion. + /// * std::optional(T, SmallVectorImpl &) /// - This form represents a 1-N type conversion. It should return - /// `failure` or `std::nullopt` to signify a failed conversion. If the new - /// set of types is empty, the type is removed and any usages of the + /// `failure` or `std::nullopt` to signify a failed conversion. If the + /// new set of types is empty, the type is removed and any usages of the /// existing value are expected to be removed during conversion. If /// `std::nullopt` is returned, the converter is allowed to try another /// conversion function to perform the conversion. - /// * Optional(T, SmallVectorImpl &, ArrayRef) - /// - This form represents a 1-N type conversion supporting recursive - /// types. The first two arguments and the return value are the same as - /// for the regular 1-N form. The third argument is contains is the - /// "call stack" of the recursive conversion: it contains the list of - /// types currently being converted, with the current type being the - /// last one. If it is present more than once in the list, the - /// conversion concerns a recursive type. + /// + /// Conversion functions that accept `Value` as the first argument are + /// context-aware. I.e., they can take into account IR when converting the + /// type of the given value. Context-unaware conversion functions accept + /// `Type` or a derived class as the first argument. + /// + /// Note: Context-unaware conversions are cached, but context-aware + /// conversions are not. + /// /// Note: When attempting to convert a type, e.g. via 'convertType', the /// mostly recently added conversions will be invoked first. template +#include namespace mlir { @@ -139,7 +140,8 @@ class TypeConverter { }; /// Register a conversion function. A conversion function must be convertible - /// to any of the following forms (where `T` is a class derived from `Type`): + /// to any of the following forms (where `T` is `Value` or a class derived + /// from `Type`, including `Type` itself): /// /// * std::optional(T) /// - This form represents a 1-1 type conversion. It should return nullptr @@ -154,6 +156,14 @@ class TypeConverter { /// `std::nullopt` is returned, the converter is allowed to try another /// conversion function to perform the conversion. /// + /// Conversion functions that accept `Value` as the first argument are + /// context-aware. I.e., they can take into account IR when converting the + /// type of the given value. Context-unaware conversion functions accept + /// `Type` or a derived class as the first argument. + /// + /// Note: Context-unaware conversions are cached, but context-aware + /// conversions are not. + /// /// Note: When attempting to convert a type, e.g. via 'convertType', the /// mostly recently added conversions will be invoked first. template (std::forward(callback))); } - /// Convert the given type. This function should return failure if no valid + /// Convert the given type. This function returns failure if no valid /// conversion exists, success otherwise. If the new set of types is empty, /// the type is removed and any usages of the existing value are expected to /// be removed during conversion. + /// + /// Note: This overload invokes only context-unaware type conversion + /// functions. Users should call the other overload if possible. LogicalResult convertType(Type t, SmallVectorImpl &results) const; + /// Convert the type of the given value. This function returns failure if no + /// valid conversion exists, success otherwise. If the new set of types is + /// empty, the type is removed and any usages of the existing value are + /// expected to be removed during conversion. + /// + /// Note: This overload invokes both context-aware and context-unaware type + /// conversion functions. + LogicalResult convertType(Value v, SmallVectorImpl &results) const; + /// This hook simplifies defining 1-1 type conversions. This function returns /// the type to convert to on success, and a null type on failure. Type convertType(Type t) const; + Type convertType(Value v) const; /// Attempts a 1-1 type conversion, expecting the result type to be /// `TargetType`. Returns the converted type cast to `TargetType` on success, @@ -258,13 +281,23 @@ class TypeConverter { TargetType convertType(Type t) const { return dyn_cast_or_null(convertType(t)); } + template + TargetType convertType(Value v) const { + return dyn_cast_or_null(convertType(v)); + } - /// Convert the given set of types, filling 'results' as necessary. This - /// returns failure if the conversion of any of the types fails, success + /// Convert the given types, filling 'results' as necessary. This returns + /// "failure" if the conversion of any of the types fails, "success" /// otherwise. LogicalResult convertTypes(TypeRange types, SmallVectorImpl &results) const; + /// Convert the types of the given values, filling 'results' as necessary. + /// This returns "failure" if the conversion of any of the types fails, + /// "success" otherwise. + LogicalResult convertTypes(ValueRange values, + SmallVectorImpl &results) const; + /// Return true if the given type is legal for this type converter, i.e. the /// type converts to itself. bool isLegal(Type type) const; @@ -328,7 +361,7 @@ class TypeConverter { /// types is empty, the type is removed and any usages of the existing value /// are expected to be removed during conversion. using ConversionCallbackFn = std::function( - Type, SmallVectorImpl &)>; + std::variant, SmallVectorImpl &)>; /// The signature of the callback used to materialize a source conversion. /// @@ -348,13 +381,14 @@ class TypeConverter { /// Generate a wrapper for the given callback. This allows for accepting /// different callback forms, that all compose into a single version. - /// With callback of form: `std::optional(T)` + /// With callback of form: `std::optional(T)`, where `T` can be a + /// `Value` or a `Type` (or a class derived from `Type`). template std::enable_if_t, ConversionCallbackFn> - wrapCallback(FnT &&callback) const { + wrapCallback(FnT &&callback) { return wrapCallback([callback = std::forward(callback)]( - T type, SmallVectorImpl &results) { - if (std::optional resultOpt = callback(type)) { + T typeOrValue, SmallVectorImpl &results) { + if (std::optional resultOpt = callback(typeOrValue)) { bool wasSuccess = static_cast(*resultOpt); if (wasSuccess) results.push_back(*resultOpt); @@ -364,20 +398,49 @@ class TypeConverter { }); } /// With callback of form: `std::optional( - /// T, SmallVectorImpl &, ArrayRef)`. + /// T, SmallVectorImpl &)`, where `T` is a type. template - std::enable_if_t &>, + std::enable_if_t &> && + std::is_base_of_v, ConversionCallbackFn> wrapCallback(FnT &&callback) const { return [callback = std::forward(callback)]( - Type type, + std::variant type, SmallVectorImpl &results) -> std::optional { - T derivedType = dyn_cast(type); + T derivedType; + if (Type *t = std::get_if(&type)) { + derivedType = dyn_cast(*t); + } else if (Value *v = std::get_if(&type)) { + derivedType = dyn_cast(v->getType()); + } else { + llvm_unreachable("unexpected variant"); + } if (!derivedType) return std::nullopt; return callback(derivedType, results); }; } + /// With callback of form: `std::optional( + /// T, SmallVectorImpl)`, where `T` is a `Value`. + template + std::enable_if_t &> && + std::is_same_v, + ConversionCallbackFn> + wrapCallback(FnT &&callback) { + hasContextAwareTypeConversions = true; + return [callback = std::forward(callback)]( + std::variant type, + SmallVectorImpl &results) -> std::optional { + if (Type *t = std::get_if(&type)) { + // Context-aware type conversion was called with a type. + return std::nullopt; + } else if (Value *v = std::get_if(&type)) { + return callback(*v, results); + } + llvm_unreachable("unexpected variant"); + return std::nullopt; + }; + } /// Register a type conversion. void registerConversion(ConversionCallbackFn callback) { @@ -504,6 +567,12 @@ class TypeConverter { mutable DenseMap> cachedMultiConversions; /// A mutex used for cache access mutable llvm::sys::SmartRWMutex cacheMutex; + /// Whether the type converter has context-aware type conversions. I.e., + /// conversion rules that depend on the SSA value instead of just the type. + /// Type conversion caching is deactivated when there are context-aware + /// conversions because the type converter may return different results for + /// the same input type. + bool hasContextAwareTypeConversions = false; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 09326242eec2a..de4612fa0846a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern { SmallVector offsets; offsets.push_back(0); // Do the type conversion and record the offsets. - for (Type type : op.getResultTypes()) { - if (failed(typeConverter->convertTypes(type, dstTypes))) + for (Value v : op.getResults()) { + if (failed(typeConverter->convertType(v, dstTypes))) return rewriter.notifyMatchFailure(op, "could not convert result type"); offsets.push_back(dstTypes.size()); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bd11bbe58a3f6..2a1d154faeaf3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // If there is no legal conversion, fail to match this pattern. SmallVector legalTypes; - if (failed(currentTypeConverter->convertType(origType, legalTypes))) { + if (failed(currentTypeConverter->convertType(operand, legalTypes))) { notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { diag << "unable to convert type for " << valueDiagTag << " #" << it.index() << ", type was " << origType; @@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t, return failure(); } +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl &results) const { + assert(v && "expected non-null value"); + + // If this type converter does not have context-aware type conversions, call + // the type-based overload, which has caching. + if (!hasContextAwareTypeConversions) { + return convertType(v.getType(), results); + } + + // Walk the added converters in reverse order to apply the most recently + // registered first. + for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { + if (std::optional result = converter(v, results)) { + if (!succeeded(*result)) + return failure(); + return success(); + } + } + return failure(); +} + Type TypeConverter::convertType(Type t) const { // Use the multi-type result version to convert the type. SmallVector results; @@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const { return results.size() == 1 ? results.front() : nullptr; } +Type TypeConverter::convertType(Value v) const { + // Use the multi-type result version to convert the type. + SmallVector results; + if (failed(convertType(v, results))) + return nullptr; + + // Check to ensure that only one type was produced. + return results.size() == 1 ? results.front() : nullptr; +} + LogicalResult TypeConverter::convertTypes(TypeRange types, SmallVectorImpl &results) const { @@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types, return success(); } +LogicalResult +TypeConverter::convertTypes(ValueRange values, + SmallVectorImpl &results) const { + for (Value value : values) + if (failed(convertType(value, results))) + return failure(); + return success(); +} + bool TypeConverter::isLegal(Type type) const { return convertType(type) == type; } @@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands, newOp.addOperands(operands); SmallVector newResultTypes; - if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes))) + if (failed(converter.convertTypes(op->getResults(), newResultTypes))) return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); newOp.addTypes(newResultTypes); diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index db8bd0f6378d2..ccb18cb81a6c0 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() { }) : () -> () return } + +// ----- + +// CHECK-LABEL: func @context_aware_conversion() +func.func @context_aware_conversion() { + // Case 1: Convert i37 --> i38. + // CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38 + // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> () + %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37) + "test.replace_with_legal_op"(%0) : (i37) -> () + + // Case 2: Convert i37 --> i39. + // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i39 + // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> () + %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37) + "test.replace_with_legal_op"(%1) : (i37) -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d073843484d81..bd85e6fd9ae7f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern { : ConversionPattern(converter, "test.replace_with_legal_op", /*benefit=*/1, ctx) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, operands[0]); + rewriter.replaceOpWithNewOp(op, operands[0].front()); return success(); } }; @@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver return nullptr; }); converter.addConversion([](IntegerType type, SmallVectorImpl &) { - // Drop all integer types. + // Drop all other integer types. return success(); }); converter.addConversion( @@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver results.push_back(result); return success(); }); + converter.addConversion([](Value v) -> std::optional { + auto intType = dyn_cast(v.getType()); + if (!intType || intType.getWidth() != 37) + return std::nullopt; + Operation *op = v.getDefiningOp(); + if (!op) + return std::nullopt; + auto incrementAttr = op->getAttrOfType("increment"); + if (!incrementAttr) + return std::nullopt; + return IntegerType::get(v.getContext(), + intType.getWidth() + incrementAttr.getInt()); + }); /// Add the legal set of type materializations. converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, @@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver // Otherwise, fail. return nullptr; }); + // Materialize i37 to any desired type with unrealized_conversion_cast. + converter.addTargetMaterialization([](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1 || !inputs[0].getType().isInteger(37)) + return Value(); + return builder.create(loc, type, inputs) + .getResult(0); + }); // Initialize the conversion target. mlir::ConversionTarget target(getContext()); + target.addLegalOp(OperationName("test.context_op", &getContext())); target.addLegalOp(); target.addDynamicallyLegalOp([](TestTypeProducerOp op) { auto recursiveType = dyn_cast(op.getType()); From c61f6597429658e5e3f6f511a9211aba087676a6 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 19 May 2025 09:43:20 +0900 Subject: [PATCH 2/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/include/mlir/Transforms/DialectConversion.h | 2 +- mlir/lib/Transforms/Utils/DialectConversion.cpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 07adbde3a5a60..4bd506a4b763c 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -361,7 +361,7 @@ class TypeConverter { /// types is empty, the type is removed and any usages of the existing value /// are expected to be removed during conversion. using ConversionCallbackFn = std::function( - std::variant, SmallVectorImpl &)>; + PointerUnion, SmallVectorImpl &)>; /// The signature of the callback used to materialize a source conversion. /// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 2a1d154faeaf3..1b23faa683fe3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2905,9 +2905,8 @@ LogicalResult TypeConverter::convertType(Value v, // If this type converter does not have context-aware type conversions, call // the type-based overload, which has caching. - if (!hasContextAwareTypeConversions) { + if (!hasContextAwareTypeConversions) return convertType(v.getType(), results); - } // Walk the added converters in reverse order to apply the most recently // registered first. From 12e6416884fe1bb03d5d7e20d0738bfa56127afc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 19 May 2025 02:54:13 +0200 Subject: [PATCH 3/3] fix --- .../mlir/Transforms/DialectConversion.h | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 4bd506a4b763c..8152aa72d6db2 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -18,7 +18,6 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" #include -#include namespace mlir { @@ -405,13 +404,13 @@ class TypeConverter { ConversionCallbackFn> wrapCallback(FnT &&callback) const { return [callback = std::forward(callback)]( - std::variant type, + PointerUnion typeOrValue, SmallVectorImpl &results) -> std::optional { T derivedType; - if (Type *t = std::get_if(&type)) { - derivedType = dyn_cast(*t); - } else if (Value *v = std::get_if(&type)) { - derivedType = dyn_cast(v->getType()); + if (Type t = dyn_cast(typeOrValue)) { + derivedType = dyn_cast(t); + } else if (Value v = dyn_cast(typeOrValue)) { + derivedType = dyn_cast(v.getType()); } else { llvm_unreachable("unexpected variant"); } @@ -429,13 +428,13 @@ class TypeConverter { wrapCallback(FnT &&callback) { hasContextAwareTypeConversions = true; return [callback = std::forward(callback)]( - std::variant type, + PointerUnion typeOrValue, SmallVectorImpl &results) -> std::optional { - if (Type *t = std::get_if(&type)) { + if (Type t = dyn_cast(typeOrValue)) { // Context-aware type conversion was called with a type. return std::nullopt; - } else if (Value *v = std::get_if(&type)) { - return callback(*v, results); + } else if (Value v = dyn_cast(typeOrValue)) { + return callback(v, results); } llvm_unreachable("unexpected variant"); return std::nullopt;