Skip to content

Commit b60c6cb

Browse files
authored
[MLIR][TableGen] Migrate MLIR backends to use const RecordKeeper (#107505)
- Migrate MLIR backends to use a const RecordKeeper reference.
1 parent c1e3b99 commit b60c6cb

18 files changed

+146
-152
lines changed

mlir/include/mlir/TableGen/CodeGenHelpers.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,15 @@ class StaticVerifierFunctionEmitter {
106106
StringRef tag = "");
107107

108108
/// Collect and unique all the constraints used by operations.
109-
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
109+
void collectOpConstraints(ArrayRef<const llvm::Record *> opDefs);
110110

111111
/// Collect and unique all compatible type, attribute, successor, and region
112112
/// constraints from the operations in the file and emit them at the top of
113113
/// the generated file.
114114
///
115115
/// Constraints that do not meet the restriction that they can only reference
116116
/// `$_self` and `$_op` are not uniqued.
117-
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
117+
void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs);
118118

119119
/// Unique all compatible type and attribute constraints from a pattern file
120120
/// and emit them at the top of the generated file.

mlir/include/mlir/TableGen/GenInfo.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class RecordKeeper;
2121
namespace mlir {
2222

2323
/// Generator function to invoke.
24-
using GenFunction =
25-
std::function<bool(llvm::RecordKeeper &recordKeeper, raw_ostream &os)>;
24+
using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
25+
raw_ostream &os)>;
2626

2727
/// Structure to group information about a generator (argument to invoke via
2828
/// mlir-tblgen, description, and generator function).
@@ -34,7 +34,7 @@ class GenInfo {
3434
: arg(arg), description(description), generator(std::move(generator)) {}
3535

3636
/// Invokes the generator and returns whether the generator failed.
37-
bool invoke(llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
37+
bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
3838
assert(generator && "Cannot call generator with null generator");
3939
return generator(recordKeeper, os);
4040
}

mlir/lib/TableGen/CodeGenHelpers.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
4949
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
5050

5151
void StaticVerifierFunctionEmitter::emitOpConstraints(
52-
ArrayRef<llvm::Record *> opDefs) {
52+
ArrayRef<const llvm::Record *> opDefs) {
5353
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
5454
emitTypeConstraints();
5555
emitAttrConstraints();
@@ -264,14 +264,14 @@ void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
264264
}
265265

266266
void StaticVerifierFunctionEmitter::collectOpConstraints(
267-
ArrayRef<Record *> opDefs) {
267+
ArrayRef<const Record *> opDefs) {
268268
const auto collectTypeConstraints = [&](Operator::const_value_range values) {
269269
for (const NamedTypeConstraint &value : values)
270270
if (value.hasPredicate())
271271
collectConstraint(typeConstraints, "type", value.constraint);
272272
};
273273

274-
for (Record *def : opDefs) {
274+
for (const Record *def : opDefs) {
275275
Operator op(*def);
276276
/// Collect type constraints.
277277
collectTypeConstraints(op.getOperands());

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class Parser {
164164
SmallVectorImpl<ast::Decl *> &decls);
165165

166166
/// Process the records of a parsed tablegen include file.
167-
void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
167+
void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
168168
SmallVectorImpl<ast::Decl *> &decls);
169169

170170
/// Create a user defined native constraint for a constraint imported from
@@ -863,7 +863,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
863863
return success();
864864
}
865865

866-
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
866+
void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
867867
SmallVectorImpl<ast::Decl *> &decls) {
868868
// Return the length kind of the given value.
869869
auto getLengthKind = [](const auto &value) {
@@ -887,7 +887,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
887887

888888
// Process the parsed tablegen records to build ODS information.
889889
/// Operations.
890-
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
890+
for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
891891
tblgen::Operator op(def);
892892

893893
// Check to see if this operation is known to support type inferrence.
@@ -920,13 +920,13 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
920920
}
921921
}
922922

923-
auto shouldBeSkipped = [this](llvm::Record *def) {
923+
auto shouldBeSkipped = [this](const llvm::Record *def) {
924924
return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
925925
def->isSubClassOf("DeclareInterfaceMethods");
926926
};
927927

928928
/// Attr constraints.
929-
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
929+
for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
930930
if (shouldBeSkipped(def))
931931
continue;
932932

@@ -936,7 +936,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
936936
constraint.getStorageType()));
937937
}
938938
/// Type constraints.
939-
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
939+
for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
940940
if (shouldBeSkipped(def))
941941
continue;
942942

@@ -947,7 +947,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
947947
}
948948
/// OpInterfaces.
949949
ast::Type opTy = ast::OperationType::get(ctx);
950-
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
950+
for (const llvm::Record *def :
951+
tdRecords.getAllDerivedDefinitions("OpInterface")) {
951952
if (shouldBeSkipped(def))
952953
continue;
953954

mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ static bool findUse(Record &record, Init *deprecatedInit,
9090
});
9191
}
9292

93-
static void warnOfDeprecatedUses(RecordKeeper &records) {
93+
static void warnOfDeprecatedUses(const RecordKeeper &records) {
9494
// This performs a direct check for any def marked as deprecated and then
9595
// finds all uses of deprecated def. Deprecated defs are not expected to be
9696
// either numerous or long lived.

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ using namespace mlir::tblgen;
3030
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
3131
/// specified and can only find one dialect's defs, use that.
3232
static void collectAllDefs(StringRef selectedDialect,
33-
std::vector<llvm::Record *> records,
33+
ArrayRef<const llvm::Record *> records,
3434
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
3535
// Nothing to do if no defs were found.
3636
if (records.empty())
@@ -690,14 +690,15 @@ class DefGenerator {
690690
bool emitDefs(StringRef selectedDialect);
691691

692692
protected:
693-
DefGenerator(const std::vector<llvm::Record *> &defs, raw_ostream &os,
693+
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
694694
StringRef defType, StringRef valueType, bool isAttrGenerator)
695695
: defRecords(defs), os(os), defType(defType), valueType(valueType),
696696
isAttrGenerator(isAttrGenerator) {
697697
// Sort by occurrence in file.
698-
llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
699-
return lhs->getID() < rhs->getID();
700-
});
698+
llvm::sort(defRecords,
699+
[](const llvm::Record *lhs, const llvm::Record *rhs) {
700+
return lhs->getID() < rhs->getID();
701+
});
701702
}
702703

703704
/// Emit the list of def type names.
@@ -706,7 +707,7 @@ class DefGenerator {
706707
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
707708

708709
/// The set of def records to emit.
709-
std::vector<llvm::Record *> defRecords;
710+
std::vector<const llvm::Record *> defRecords;
710711
/// The attribute or type class to emit.
711712
/// The stream to emit to.
712713
raw_ostream &os;
@@ -721,13 +722,13 @@ class DefGenerator {
721722

722723
/// A specialized generator for AttrDefs.
723724
struct AttrDefGenerator : public DefGenerator {
724-
AttrDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
725+
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
725726
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
726727
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
727728
};
728729
/// A specialized generator for TypeDefs.
729730
struct TypeDefGenerator : public DefGenerator {
730-
TypeDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
731+
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
731732
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
732733
"Type", "Type", /*isAttrGenerator=*/false) {}
733734
};
@@ -1029,9 +1030,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
10291030

10301031
/// Find all type constraints for which a C++ function should be generated.
10311032
static std::vector<Constraint>
1032-
getAllTypeConstraints(llvm::RecordKeeper &records) {
1033+
getAllTypeConstraints(const llvm::RecordKeeper &records) {
10331034
std::vector<Constraint> result;
1034-
for (llvm::Record *def :
1035+
for (const llvm::Record *def :
10351036
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
10361037
// Ignore constraints defined outside of the top-level file.
10371038
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
@@ -1046,7 +1047,7 @@ getAllTypeConstraints(llvm::RecordKeeper &records) {
10461047
return result;
10471048
}
10481049

1049-
static void emitTypeConstraintDecls(llvm::RecordKeeper &records,
1050+
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
10501051
raw_ostream &os) {
10511052
static const char *const typeConstraintDecl = R"(
10521053
bool {0}(::mlir::Type type);
@@ -1056,7 +1057,7 @@ bool {0}(::mlir::Type type);
10561057
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
10571058
}
10581059

1059-
static void emitTypeConstraintDefs(llvm::RecordKeeper &records,
1060+
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
10601061
raw_ostream &os) {
10611062
static const char *const typeConstraintDef = R"(
10621063
bool {0}(::mlir::Type type) {
@@ -1087,13 +1088,13 @@ static llvm::cl::opt<std::string>
10871088

10881089
static mlir::GenRegistration
10891090
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
1090-
[](llvm::RecordKeeper &records, raw_ostream &os) {
1091+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
10911092
AttrDefGenerator generator(records, os);
10921093
return generator.emitDefs(attrDialect);
10931094
});
10941095
static mlir::GenRegistration
10951096
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
1096-
[](llvm::RecordKeeper &records, raw_ostream &os) {
1097+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
10971098
AttrDefGenerator generator(records, os);
10981099
return generator.emitDecls(attrDialect);
10991100
});
@@ -1109,28 +1110,28 @@ static llvm::cl::opt<std::string>
11091110

11101111
static mlir::GenRegistration
11111112
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
1112-
[](llvm::RecordKeeper &records, raw_ostream &os) {
1113+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
11131114
TypeDefGenerator generator(records, os);
11141115
return generator.emitDefs(typeDialect);
11151116
});
11161117
static mlir::GenRegistration
11171118
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
1118-
[](llvm::RecordKeeper &records, raw_ostream &os) {
1119+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
11191120
TypeDefGenerator generator(records, os);
11201121
return generator.emitDecls(typeDialect);
11211122
});
11221123

11231124
static mlir::GenRegistration
11241125
genTypeConstrDefs("gen-type-constraint-defs",
11251126
"Generate type constraint definitions",
1126-
[](llvm::RecordKeeper &records, raw_ostream &os) {
1127+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
11271128
emitTypeConstraintDefs(records, os);
11281129
return false;
11291130
});
11301131
static mlir::GenRegistration
11311132
genTypeConstrDecls("gen-type-constraint-decls",
11321133
"Generate type constraint declarations",
1133-
[](llvm::RecordKeeper &records, raw_ostream &os) {
1134+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
11341135
emitTypeConstraintDecls(records, os);
11351136
return false;
11361137
});

mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,14 +429,15 @@ struct AttrOrType {
429429

430430
static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
431431
MapVector<StringRef, AttrOrType> dialectAttrOrType;
432-
for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
432+
for (const Record *it :
433+
records.getAllDerivedDefinitions("DialectAttributes")) {
433434
if (!selectedBcDialect.empty() &&
434435
it->getValueAsString("dialect") != selectedBcDialect)
435436
continue;
436437
dialectAttrOrType[it->getValueAsString("dialect")].attr =
437438
it->getValueAsListOfDefs("elems");
438439
}
439-
for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
440+
for (const Record *it : records.getAllDerivedDefinitions("DialectTypes")) {
440441
if (!selectedBcDialect.empty() &&
441442
it->getValueAsString("dialect") != selectedBcDialect)
442443
continue;

mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
using llvm::Clause;
2424
using llvm::ClauseVal;
2525
using llvm::raw_ostream;
26+
using llvm::Record;
2627
using llvm::RecordKeeper;
2728

2829
// LLVM has multiple places (Clang, Flang, MLIR) where information about
@@ -49,13 +50,11 @@ static bool emitDecls(const RecordKeeper &recordKeeper, llvm::StringRef dialect,
4950
"'--directives-dialect'");
5051
}
5152

52-
const auto &directiveLanguages =
53+
const auto directiveLanguages =
5354
recordKeeper.getAllDerivedDefinitions("DirectiveLanguage");
5455
assert(!directiveLanguages.empty() && "DirectiveLanguage missing.");
5556

56-
const auto &clauses = recordKeeper.getAllDerivedDefinitions("Clause");
57-
58-
for (const auto &r : clauses) {
57+
for (const Record *r : recordKeeper.getAllDerivedDefinitions("Clause")) {
5958
Clause c{r};
6059
const auto &clauseVals = c.getClauseVals();
6160
if (clauseVals.empty())

mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,14 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
136136
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
137137
raw_ostream &os) {
138138
os << fileHeader;
139-
for (auto &it :
139+
for (const llvm::Record *it :
140140
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
141141
EnumAttr enumAttr(*it);
142142
emitEnumClass(enumAttr, os);
143143
emitAttributeBuilder(enumAttr, os);
144144
}
145-
for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
145+
for (const llvm::Record *it :
146+
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
146147
AttrOrTypeDef attr(&*it);
147148
if (!attr.getMnemonic()) {
148149
llvm::errs() << "enum case " << attr

mlir/tools/mlir-tblgen/EnumsGen.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -645,8 +645,8 @@ class {1} : public ::mlir::{2} {
645645
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
646646
llvm::emitSourceFileHeader("Enum Utility Declarations", os, recordKeeper);
647647

648-
auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
649-
for (const auto *def : defs)
648+
for (const Record *def :
649+
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
650650
emitEnumDecl(*def, os);
651651

652652
return false;
@@ -683,8 +683,8 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
683683
static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
684684
llvm::emitSourceFileHeader("Enum Utility Definitions", os, recordKeeper);
685685

686-
auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
687-
for (const auto *def : defs)
686+
for (const Record *def :
687+
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
688688
emitEnumDef(*def, os);
689689

690690
return false;

0 commit comments

Comments
 (0)