diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index f22ad1fd70db2..1b4ea6b1164ec 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -194,7 +194,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr(); // Ensure we don't lose information if the function is lowered before its // surrounding context. - auto *gpuDialect = cast(gpuFuncOp->getDialect()); + auto *gpuDialect = gpu::GPUDialect::getLoaded(gpuFuncOp); if (knownBlockSize) attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(), knownBlockSize); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index 1f158b271e5c6..d7aa5f70d984a 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -74,21 +74,18 @@ struct OpLowering : public ConvertOpToLLVMPattern { // 3. Discardable attributes on a surrounding function of any kind // The below code handles these in reverse order so that more important // sources overwrite less important ones. + auto *gpuDialect = gpu::GPUDialect::getLoaded(op); DenseI32ArrayAttr funcBounds = nullptr; if (auto funcOp = op->template getParentOfType()) { switch (indexKind) { case IndexKind::Block: { - auto blockHelper = - gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext()); - if (blockHelper.isAttrPresent(funcOp)) - funcBounds = blockHelper.getAttr(funcOp); + auto blockHelper = gpuDialect->getKnownBlockSizeAttrHelper(); + funcBounds = blockHelper.getAttr(funcOp); break; } case IndexKind::Grid: { - auto gridHelper = - gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext()); - if (gridHelper.isAttrPresent(funcOp)) - funcBounds = gridHelper.getAttr(funcOp); + auto gridHelper = gpuDialect->getKnownGridSizeAttrHelper(); + funcBounds = gridHelper.getAttr(funcOp); break; } case IndexKind::Other: diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index c6c695b442b4f..4a4c97dfc7bc0 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -326,7 +326,7 @@ struct LowerGpuOpsToROCDLOpsPass final configureGpuToROCDLConversionLegality(target); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); - auto *rocdlDialect = getContext().getLoadedDialect(); + auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(getContext()); auto reqdWorkGroupSizeAttrHelper = rocdlDialect->getReqdWorkGroupSizeAttrHelper(); auto flatWorkGroupSizeAttrHelper = @@ -374,8 +374,7 @@ void mlir::populateGpuToROCDLConversionPatterns( using gpu::index_lowering::IndexKind; using gpu::index_lowering::IntrType; using mlir::gpu::amd::Runtime; - auto *rocdlDialect = - converter.getContext().getLoadedDialect(); + auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(converter.getContext()); populateWithGenerated(patterns); patterns.add< gpu::index_lowering::OpLowering instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { - auto *dialect = dyn_cast(attribute.getNameDialect()); + auto *dialect = ROCDL::ROCDLDialect::getLoaded(op); llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); if (dialect->getKernelAttrHelper().getName() == attribute.getName()) { auto func = dyn_cast(op); diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 6cf71d2bb0174..700f68e940f13 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -110,15 +110,23 @@ tblgen::findDialectToGenerate(ArrayRef dialects) { /// {2}: The dialect parent class. static const char *const dialectDeclBeginStr = R"( class {0} : public ::mlir::{2} { + typedef {0} DialectType; explicit {0}(::mlir::MLIRContext *context); void initialize(); friend class ::mlir::MLIRContext; public: ~{0}() override; - static constexpr ::llvm::StringLiteral getDialectNamespace() { + static constexpr ::llvm::StringLiteral getDialectNamespace() {{ return ::llvm::StringLiteral("{1}"); } + static const DialectType *getLoaded(::mlir::MLIRContext &context) {{ + return context.getLoadedDialect(); + } + static const DialectType *getLoaded(::mlir::MLIRContext *context) {{ + return getLoaded(*context); + } + static const DialectType *getLoaded(::mlir::Operation *operation); )"; /// Registration for a single dependent dialect: to be inserted in the ctor @@ -206,28 +214,28 @@ static const char *const discardableAttrHelperDecl = R"( static constexpr ::llvm::StringLiteral getNameStr() {{ return "{4}.{1}"; } - constexpr ::mlir::StringAttr getName() {{ + constexpr ::mlir::StringAttr getName() const {{ return name; } {0}AttrHelper(::mlir::MLIRContext *ctx) : name(::mlir::StringAttr::get(ctx, getNameStr())) {{} - {2} getAttr(::mlir::Operation *op) {{ - return op->getAttrOfType<{2}>(name); - } - void setAttr(::mlir::Operation *op, {2} val) {{ - op->setAttr(name, val); - } - bool isAttrPresent(::mlir::Operation *op) {{ - return op->hasAttrOfType<{2}>(name); - } - void removeAttr(::mlir::Operation *op) {{ - assert(op->hasAttrOfType<{2}>(name)); - op->removeAttr(name); - } + {2} getAttr(::mlir::Operation *op) const {{ + return op->getAttrOfType<{2}>(name); + } + void setAttr(::mlir::Operation *op, {2} val) const {{ + op->setAttr(name, val); + } + bool isAttrPresent(::mlir::Operation *op) const {{ + return op->hasAttrOfType<{2}>(name); + } + void removeAttr(::mlir::Operation *op) const {{ + assert(op->hasAttrOfType<{2}>(name)); + op->removeAttr(name); + } }; - {0}AttrHelper get{0}AttrHelper() { + const {0}AttrHelper get{0}AttrHelper() const { return {3}AttrName; } private: @@ -342,6 +350,16 @@ static const char *const dialectDestructorStr = R"( )"; +/// The code block to generate a member funcs. +/// +/// {0}: The name of the dialect class. +static const char *const dialectStaticMemberDefs = R"( +const {0} *{0}::getLoaded(::mlir::Operation *operation) {{ + return getLoaded(*operation->getContext()); +} + +)"; + static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, raw_ostream &os) { std::string cppClassName = dialect.getCppClassName(); @@ -388,6 +406,9 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, discardableAttributesInit); if (!dialect.hasNonDefaultDestructor()) os << llvm::formatv(dialectDestructorStr, cppClassName); + + // Emit member function definitions. + os << llvm::formatv(dialectStaticMemberDefs, cppClassName); } static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {