diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3b669f51a615f..623cb81111d03 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -869,9 +869,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Append a rewrite. Rewrites are committed upon success and rolled back upon /// failure. template - void appendRewrite(Args &&...args) { + RewriteTy *appendRewrite(Args &&...args) { rewrites.push_back( std::make_unique(*this, std::forward(args)...)); + return static_cast(rewrites.back().get()); } /// Undo the rewrites (motions, splits) one by one in reverse order until @@ -1181,7 +1182,6 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( mappedValues(std::move(mappedValues)) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); - rewriterImpl.unresolvedMaterializations[op] = this; } void UnresolvedMaterializationRewrite::rollback() { @@ -1471,8 +1471,9 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( mapping.map(valuesToMap, convertOp.getResults()); if (castOp) *castOp = convertOp; - appendRewrite( - convertOp, converter, kind, originalType, std::move(valuesToMap)); + unresolvedMaterializations[convertOp] = + appendRewrite( + convertOp, converter, kind, originalType, std::move(valuesToMap)); return convertOp.getResults(); }