diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index ab91b32d7ef43..06977e07fda5b 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -186,6 +186,10 @@ struct CppEmitter { /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); + /// Return the existing or a new name for a loop induction variable of an + /// emitc::ForOp. + StringRef getOrCreateInductionVarName(Value val); + // Returns the textual representation of a subscript operation. std::string getSubscriptName(emitc::SubscriptOp op); @@ -201,23 +205,39 @@ struct CppEmitter { /// Whether to map an mlir integer to a unsigned integer in C++. bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); - /// RAII helper function to manage entering/exiting C++ scopes. + /// Abstract RAII helper function to manage entering/exiting C++ scopes. struct Scope { + ~Scope() { emitter.labelInScopeCount.pop(); } + + private: + llvm::ScopedHashTableScope valueMapperScope; + llvm::ScopedHashTableScope blockMapperScope; + + protected: Scope(CppEmitter &emitter) : valueMapperScope(emitter.valueMapper), blockMapperScope(emitter.blockMapper), emitter(emitter) { - emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); } - ~Scope() { - emitter.valueInScopeCount.pop(); - emitter.labelInScopeCount.pop(); + CppEmitter &emitter; + }; + + /// RAII helper function to manage entering/exiting functions, while re-using + /// value names. + struct FunctionScope : Scope { + FunctionScope(CppEmitter &emitter) : Scope(emitter) { + // Re-use value names. + emitter.resetValueCounter(); } + }; - private: - llvm::ScopedHashTableScope valueMapperScope; - llvm::ScopedHashTableScope blockMapperScope; - CppEmitter &emitter; + /// RAII helper function to manage entering/exiting emitc::forOp loops and + /// handle induction variable naming. + struct LoopScope : Scope { + LoopScope(CppEmitter &emitter) : Scope(emitter) { + emitter.increaseLoopNestingLevel(); + } + ~LoopScope() { emitter.decreaseLoopNestingLevel(); } }; /// Returns wether the Value is assigned to a C++ variable in the scope. @@ -253,6 +273,15 @@ struct CppEmitter { return operandExpression == emittedExpression; }; + // Resets the value counter to 0. + void resetValueCounter(); + + // Increases the loop nesting level by 1. + void increaseLoopNestingLevel(); + + // Decreases the loop nesting level by 1. + void decreaseLoopNestingLevel(); + private: using ValueMapper = llvm::ScopedHashTable; using BlockMapper = llvm::ScopedHashTable; @@ -274,11 +303,19 @@ struct CppEmitter { /// Map from block to name of C++ label. BlockMapper blockMapper; - /// The number of values in the current scope. This is used to declare the - /// names of values in a scope. - std::stack valueInScopeCount; + /// Default values representing outermost scope. + llvm::ScopedHashTableScope defaultValueMapperScope; + llvm::ScopedHashTableScope defaultBlockMapperScope; + std::stack labelInScopeCount; + /// Keeps track of the amount of nested loops the emitter currently operates + /// in. + uint64_t loopNestingLevel{0}; + + /// Emitter-level count of created values to enable unique identifiers. + unsigned int valueCount{0}; + /// State of the current expression being emitted. ExpressionOp emittedExpression; SmallVector emittedExpressionPrecedence; @@ -860,7 +897,6 @@ static LogicalResult printOperation(CppEmitter &emitter, } static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { - raw_indented_ostream &os = emitter.ostream(); // Utility function to determine whether a value is an expression that will be @@ -879,12 +915,12 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) return failure(); os << " "; - os << emitter.getOrCreateName(forOp.getInductionVar()); + os << emitter.getOrCreateInductionVarName(forOp.getInductionVar()); os << " = "; if (failed(emitter.emitOperand(forOp.getLowerBound()))) return failure(); os << "; "; - os << emitter.getOrCreateName(forOp.getInductionVar()); + os << emitter.getOrCreateInductionVarName(forOp.getInductionVar()); os << " < "; Value upperBound = forOp.getUpperBound(); bool upperBoundRequiresParentheses = requiresParentheses(upperBound); @@ -895,13 +931,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { if (upperBoundRequiresParentheses) os << ")"; os << "; "; - os << emitter.getOrCreateName(forOp.getInductionVar()); + os << emitter.getOrCreateInductionVarName(forOp.getInductionVar()); os << " += "; if (failed(emitter.emitOperand(forOp.getStep()))) return failure(); os << ") {\n"; os.indent(); + CppEmitter::LoopScope lScope(emitter); + Region &forRegion = forOp.getRegion(); auto regionOps = forRegion.getOps(); @@ -988,8 +1026,6 @@ static LogicalResult printOperation(CppEmitter &emitter, } static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { - CppEmitter::Scope scope(emitter); - for (Operation &op : moduleOp) { if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) return failure(); @@ -998,7 +1034,6 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { } static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) { - CppEmitter::Scope classScope(emitter); raw_indented_ostream &os = emitter.ostream(); os << "class " << classOp.getSymName(); if (classOp.getFinalSpecifier()) @@ -1044,8 +1079,6 @@ static LogicalResult printOperation(CppEmitter &emitter, FileOp file) { if (!emitter.shouldEmitFile(file)) return success(); - CppEmitter::Scope scope(emitter); - for (Operation &op : file) { if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) return failure(); @@ -1161,7 +1194,7 @@ static LogicalResult printOperation(CppEmitter &emitter, return functionOp.emitOpError() << "cannot emit array type as result type"; } - CppEmitter::Scope scope(emitter); + CppEmitter::FunctionScope scope(emitter); raw_indented_ostream &os = emitter.ostream(); if (failed(emitter.emitTypes(functionOp.getLoc(), functionOp.getFunctionType().getResults()))) @@ -1189,7 +1222,7 @@ static LogicalResult printOperation(CppEmitter &emitter, "with multiple blocks needs variables declared at top"); } - CppEmitter::Scope scope(emitter); + CppEmitter::FunctionScope scope(emitter); raw_indented_ostream &os = emitter.ostream(); if (functionOp.getSpecifiers()) { for (Attribute specifier : functionOp.getSpecifiersAttr()) { @@ -1223,7 +1256,6 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, DeclareFuncOp declareFuncOp) { - CppEmitter::Scope scope(emitter); raw_indented_ostream &os = emitter.ostream(); auto functionOp = SymbolTable::lookupNearestSymbolFrom( @@ -1255,8 +1287,8 @@ static LogicalResult printOperation(CppEmitter &emitter, CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, StringRef fileId) : os(os), declareVariablesAtTop(declareVariablesAtTop), - fileId(fileId.str()) { - valueInScopeCount.push(0); + fileId(fileId.str()), defaultValueMapperScope(valueMapper), + defaultBlockMapperScope(blockMapper) { labelInScopeCount.push(0); } @@ -1297,7 +1329,26 @@ StringRef CppEmitter::getOrCreateName(Value val) { assert(!hasDeferredEmission(val.getDefiningOp()) && "cacheDeferredOpResult should have been called on this value, " "update the emitOperation function."); - valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + + valueMapper.insert(val, formatv("v{0}", ++valueCount)); + } + return *valueMapper.begin(val); +} + +/// Return the existing or a new name for a loop induction variable Value. +/// Loop induction variables follow natural naming: i, j, k, ..., t, uX. +StringRef CppEmitter::getOrCreateInductionVarName(Value val) { + if (!valueMapper.count(val)) { + + int64_t identifier = 'i' + loopNestingLevel; + + if (identifier >= 'i' && identifier <= 't') { + valueMapper.insert(val, + formatv("{0}{1}", (char)identifier, ++valueCount)); + } else { + // If running out of letters, continue with uX. + valueMapper.insert(val, formatv("u{0}", ++valueCount)); + } } return *valueMapper.begin(val); } @@ -1838,6 +1889,12 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef types) { return success(); } +void CppEmitter::resetValueCounter() { valueCount = 0; } + +void CppEmitter::increaseLoopNestingLevel() { loopNestingLevel++; } + +void CppEmitter::decreaseLoopNestingLevel() { loopNestingLevel--; } + LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop, StringRef fileId) { diff --git a/mlir/test/Target/Cpp/for_loop_induction_vars.mlir b/mlir/test/Target/Cpp/for_loop_induction_vars.mlir new file mode 100644 index 0000000000000..aa85612aef36a --- /dev/null +++ b/mlir/test/Target/Cpp/for_loop_induction_vars.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +// CHECK-LABEL: test_for_siblings +func.func @test_for_siblings() { + %start = emitc.literal "0" : index + %stop = emitc.literal "10" : index + %step = emitc.literal "1" : index + + %var1 = "emitc.variable"() <{value = 0 : index}> : () -> !emitc.lvalue + %var2 = "emitc.variable"() <{value = 0 : index}> : () -> !emitc.lvalue + + // CHECK: for (size_t [[ITER0:i[0-9]*]] = {{.*}}; [[ITER0]] < {{.*}}; [[ITER0]] += {{.*}}) { + emitc.for %i0 = %start to %stop step %step { + // CHECK: for (size_t [[ITER1:j[0-9]*]] = {{.*}}; [[ITER1]] < {{.*}}; [[ITER1]] += {{.*}}) { + emitc.for %i1 = %start to %stop step %step { + // CHECK: {{.*}} = [[ITER0]]; + //"emitc.assign"(%var1,%i0) : (!emitc.lvalue, !emitc.size_t) -> () + emitc.assign %i0 : index to %var1 : !emitc.lvalue + // CHECK: {{.*}} = [[ITER1]]; + //"emitc.assign"(%var2,%i1) : (!emitc.lvalue, !emitc.size_t) -> () + emitc.assign %i1 : index to %var2 : !emitc.lvalue + } + } + // CHECK: for (size_t [[ITER2:i[0-9]*]] = {{.*}}; [[ITER2]] < {{.*}}; [[ITER2]] += {{.*}}) + emitc.for %ki2 = %start to %stop step %step { + // CHECK: for (size_t [[ITER3:j[0-9]*]] = {{.*}}; [[ITER3]] < {{.*}}; [[ITER3]] += {{.*}}) + emitc.for %i3 = %start to %stop step %step { + %1 = emitc.call_opaque "f"() : () -> i32 + } + } + return +} + +// CHECK-LABEL: test_for_nesting +func.func @test_for_nesting() { + %start = emitc.literal "0" : index + %stop = emitc.literal "10" : index + %step = emitc.literal "1" : index + + // CHECK-COUNT-12: for (size_t [[ITER:[i-t][0-9]*]] = {{.*}}; [[ITER]] < {{.*}}; [[ITER]] += {{.*}}) { + emitc.for %i0 = %start to %stop step %step { + emitc.for %i1 = %start to %stop step %step { + emitc.for %i2 = %start to %stop step %step { + emitc.for %i3 = %start to %stop step %step { + emitc.for %i4 = %start to %stop step %step { + emitc.for %i5 = %start to %stop step %step { + emitc.for %i6 = %start to %stop step %step { + emitc.for %i7 = %start to %stop step %step { + emitc.for %i8 = %start to %stop step %step { + emitc.for %i9 = %start to %stop step %step { + emitc.for %i10 = %start to %stop step %step { + emitc.for %i11 = %start to %stop step %step { + // CHECK: for (size_t [[ITERu0:u13]] = {{.*}}; [[ITERu0]] < {{.*}}; [[ITERu0]] += {{.*}}) { + emitc.for %i14 = %start to %stop step %step { + // CHECK: for (size_t [[ITERu1:u14]] = {{.*}}; [[ITERu1]] < {{.*}}; [[ITERu1]] += {{.*}}) { + emitc.for %i15 = %start to %stop step %step { + %0 = emitc.call_opaque "f"() : () -> i32 + } + } + } + } + } + } + } + } + } + } + } + } + } + } + return +}