diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc
index 3e95233ef..cba38382f 100644
--- a/tc/core/tc2halide.cc
+++ b/tc/core/tc2halide.cc
@@ -71,6 +71,12 @@ Type translateScalarType(int tcType) {
   }
 }
 
+struct TensorInfo {
+  Type type;
+  vector<string> args;
+  map<string, Interval> bounds;
+};
+
 // Translate the TC def input params to corresponding Halide components.
 // params, inputs will be populated here.
 void translateParam(
@@ -112,20 +118,13 @@ void translateParam(
   (*params)[imageParam.name()] = imageParam.parameter();
 }
 
-void translateOutput(
-    const lang::Param& p,
-    const map<string, Function>& funcs,
-    vector<Function>* outputs) {
-  outputs->push_back(funcs.at(p.ident().name()));
-}
-
 Expr translateExpr(
     const lang::TreeRef& expr,
     const map<string, Parameter>& params,
-    const map<string, Function>& funcs,
+    const map<string, TensorInfo>& tensors,
     const map<string, Expr>& lets) {
   auto t = [&](int idx) {
-    return translateExpr(expr->tree(idx), params, funcs, lets);
+    return translateExpr(expr->tree(idx), params, tensors, lets);
   };
   switch (expr->kind()) {
     case lang::TK_IDENT: {
@@ -139,17 +138,18 @@ Expr translateExpr(
       auto a = lang::Access(expr);
       string tensorName = a.name().name();
       auto paramIt = params.find(tensorName);
-      auto funcIt = funcs.find(tensorName);
+      auto tensorIt = tensors.find(tensorName);
       vector<Expr> args;
       for (auto e : a.arguments()) {
-        args.push_back(translateExpr(e, params, funcs, lets));
+        args.push_back(translateExpr(e, params, tensors, lets));
       }
       if (paramIt != params.end()) {
         // Accessing an input tensor
         return Call::make(paramIt->second, args);
-      } else if (funcIt != funcs.end()) {
+      } else if (tensorIt != tensors.end()) {
         // Call to a Func
-        return Call::make(funcIt->second, args);
+        return Call::make(
+            tensorIt->second.type, tensorName, args, Call::Halide);
       } else {
         LOG(FATAL) << "Access to unknown symbol: " << a << '\n';
         return Expr();
@@ -203,7 +203,7 @@ Expr translateExpr(
       auto b = lang::BuiltIn(expr);
       vector<Expr> exprs;
       for (auto a : b.arguments()) {
-        exprs.push_back(translateExpr(a, params, funcs, lets));
+        exprs.push_back(translateExpr(a, params, tensors, lets));
       }
       auto output_type = translateScalarType(b.type()->kind());
       return Call::make(output_type, b.name(), exprs, Call::PureExtern);
@@ -220,7 +220,7 @@ Expr translateExpr(
     }
     case lang::TK_CAST: {
       auto c = lang::Cast(expr);
-      auto v = translateExpr(c.value(), params, funcs, lets);
+      auto v = translateExpr(c.value(), params, tensors, lets);
       return cast(translateScalarType(c.type()->kind()), v);
     }
     default:
@@ -229,7 +229,7 @@ Expr translateExpr(
   }
 }
 
-vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
+vector<const Variable*> unboundVariables(const vector<Expr>& lhs, Expr rhs) {
   class FindUnboundVariables : public IRVisitor {
     using IRVisitor::visit;
@@ -254,9 +254,9 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
     set<string> visited;
 
    public:
-    FindUnboundVariables(const vector<Var>& lhs) {
+    FindUnboundVariables(const vector<Expr>& lhs) {
       for (auto v : lhs) {
-        bound.push(v.name());
+        bound.push(v.as<Variable>()->name);
       }
     }
     vector<const Variable*> result;
@@ -265,11 +265,9 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
   return finder.result;
 }
 
-typedef map<Function, map<string, Interval>, Function::Compare> FunctionBounds;
-
 void forwardBoundsInference(
     const std::vector<Expr>& exprs,
-    const FunctionBounds& bounds,
+    const map<string, TensorInfo>& tensors,
     const lang::TreeRef& comprehension,
     const tc::CompilerOptions& compilerOptions,
     Scope<Interval>* solution) {
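For illustration, a minimal sketch (not part of the patch; the tensor name, index names, and element type are hypothetical) of the IR node the new translateExpr emits for a tensor access such as O1(m, n). The access is now identified purely by name and element type, with no Halide::Internal::Function attached:

    #include "Halide.h"
    using namespace Halide;
    using namespace Halide::Internal;

    // What a TC access O1(m, n) lowers to after this change.
    Expr accessToO1() {
      std::vector<Expr> args = {Variable::make(Int(32), "m"),
                                Variable::make(Int(32), "n")};
      // Name-plus-type call; dimension names and bounds live in the
      // TensorInfo entry for "O1" instead of in a Function object.
      return Call::make(Float(32), "O1", args, Call::Halide);
    }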
@@ -288,13 +286,11 @@ void forwardBoundsInference(
 
       // Create inequalities that assert this is not an out-of-bounds access.
       if (op->call_type == Call::Halide) {
-        TC_CHECK(op->func.defined())
-            << "Expected a Call of type Halide to have an associated Function\n";
-        const auto& it = bounds.find(Function(op->func));
-        if (it != bounds.end()) {
-          const map<string, Interval>& b = it->second;
+        const auto& tensorInfo = tensors.find(op->name);
+        if (tensorInfo != tensors.end()) {
+          const map<string, Interval>& b = tensorInfo->second.bounds;
           for (size_t i = 0; i < op->args.size(); i++) {
-            const string& dim = Function(op->func).args()[i];
+            const string& dim = tensorInfo->second.args[i];
             const auto& it = b.find(dim);
             if (it != b.end()) {
               Interval interval = it->second;
@@ -332,9 +328,9 @@ void forwardBoundsInference(
    public:
     vector<Expr> result;
     set<string> freeVars;
-    const FunctionBounds& bounds;
-    CreateConstraints(const FunctionBounds& b) : bounds(b) {}
-  } constraints(bounds);
+    const map<string, TensorInfo>& tensors;
+    CreateConstraints(const map<string, TensorInfo>& t) : tensors(t) {}
+  } constraints(tensors);
   for (auto& expr : exprs) {
     expr.accept(&constraints);
   }
@@ -501,36 +497,24 @@ Expr reductionUpdate(Expr e) {
   return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
 }
 
-// Translate a single TC comprehension/statement to Halide components: funcs,
-// bounds, reductions.
+// Translate a single TC comprehension/statement to a Halide Stmt
 //
 // Note that the function definitions created by translateComprehension may
 // contain kReductionUpdate intrinsics. These may have to be removed
 // in order to be able to apply internal Halide analysis passes on them.
-void translateComprehension(
+Stmt translateComprehension(
     const lang::Comprehension& comprehension,
     const map<string, Parameter>& params,
     const tc::CompilerOptions& compilerOptions,
-    map<string, Function>* funcs,
-    FunctionBounds* bounds) {
-  Function f;
-  auto it = funcs->find(comprehension.ident().name());
-  if (it != funcs->end()) {
-    f = it->second;
-  } else {
-    f = Function(comprehension.ident().name());
-    (*funcs)[comprehension.ident().name()] = f;
-  }
-  // Function is the internal Halide IR type for a pipeline
-  // stage. Func is the front-end class that wraps it. Here it's
-  // convenient to use both.
-  Func func(f);
+    map<string, TensorInfo>* tensors) {
+  TensorInfo info;
+
+  auto tensorName = comprehension.ident().name();
 
-  vector<Var> lhs;
-  vector<Expr> lhs_as_exprs;
+  vector<Expr> lhs;
   for (lang::Ident id : comprehension.indices()) {
     lhs.push_back(Var(id.name()));
-    lhs_as_exprs.push_back(lhs.back());
+    info.args.push_back(id.name());
   }
 
   // we currently inline all of the let bindings generated in where clauses
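As a hedged sketch of the new bookkeeping (the tensor O(i, j) and its extents N and M are invented for illustration), here is what a populated TensorInfo would hold, and what forwardBoundsInference now consults in place of FunctionBounds:

    #include "Halide.h"
    #include <map>
    #include <string>
    #include <vector>
    using namespace Halide;
    using namespace Halide::Internal;

    struct TensorInfo {  // mirrors the struct added above
      Type type;
      std::vector<std::string> args;
      std::map<std::string, Interval> bounds;
    };

    TensorInfo exampleInfo() {
      TensorInfo info;
      info.type = Float(32);
      info.args = {"i", "j"};  // dimension names, in TC declaration order
      info.bounds["i"] = Interval(0, Variable::make(Int(32), "N") - 1);
      info.bounds["j"] = Interval(0, Variable::make(Int(32), "M") - 1);
      return info;
    }

With this in hand, an access O(e0, e1) can be bounds-checked by looking up tensors.at("O").bounds at the dimension name tensors.at("O").args[i], which is what the rewritten CreateConstraints visitor does by name instead of through a Function.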
@@ -540,66 +524,50 @@ void translateComprehension(
   for (auto wc : comprehension.whereClauses()) {
     if (wc->kind() == lang::TK_LET) {
       auto let = lang::Let(wc);
-      lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets);
+      lets[let.name().name()] =
+          translateExpr(let.rhs(), params, *tensors, lets);
     }
   }
 
-  Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets);
+  Expr rhs = translateExpr(comprehension.rhs(), params, *tensors, lets);
+
+  info.type = rhs.type();
 
   std::vector<Expr> all_exprs;
   for (auto wc : comprehension.whereClauses()) {
     if (wc->kind() == lang::TK_EXISTS) {
       all_exprs.push_back(
-          translateExpr(lang::Exists(wc).exp(), params, *funcs, lets));
+          translateExpr(lang::Exists(wc).exp(), params, *tensors, lets));
     }
   }
 
-  // Halide doesn't have first-class reductions. We map reductions to recursion.
-  bool added_implicit_initialization = false;
-
-  auto setupIdentity = [&](const Expr& identity, bool zero) {
-    if (!f.has_pure_definition()) {
-      added_implicit_initialization = true;
-      func(lhs) = (zero) ? identity
-                         : undef(rhs.type()); // undef causes the original value
-                                              // to remain in input arrays
-    }
-  };
-
   // Each reduction operator has two variants
   // (1) +=, TK_PLUS_EQ which updates the tensor inplace using its existing
   // values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
   // for the reduction and then applies the reduction.
-  bool should_zero = false;
+  Expr currentVal = Call::make(rhs.type(), tensorName, lhs, Call::Halide);
+  Expr identity;
   switch (comprehension.assignment()->kind()) {
     case lang::TK_PLUS_EQ_B:
-      should_zero = true; // fallthrough
-    case lang::TK_PLUS_EQ:
-      setupIdentity(make_zero(rhs.type()), should_zero);
-      rhs = func(lhs) + rhs;
+      identity = make_zero(rhs.type());
+    case lang::TK_PLUS_EQ: // fallthrough
+      rhs = currentVal + rhs;
       break;
-
     case lang::TK_TIMES_EQ_B:
-      should_zero = true; // fallthrough
-    case lang::TK_TIMES_EQ:
-      setupIdentity(make_one(rhs.type()), should_zero);
-      rhs = func(lhs) * rhs;
+      identity = make_one(rhs.type());
+    case lang::TK_TIMES_EQ: // fallthrough
+      rhs = currentVal * rhs;
       break;
-
     case lang::TK_MIN_EQ_B:
-      should_zero = true; // fallthrough
-    case lang::TK_MIN_EQ:
-      setupIdentity(rhs.type().max(), should_zero);
-      rhs = min(func(lhs), rhs);
+      identity = rhs.type().max();
+    case lang::TK_MIN_EQ: // fallthrough
+      rhs = min(currentVal, rhs);
       break;
-
     case lang::TK_MAX_EQ_B:
-      should_zero = true; // fallthrough
-    case lang::TK_MAX_EQ:
-      setupIdentity(rhs.type().min(), should_zero);
-      rhs = max(func(lhs), rhs);
+      identity = rhs.type().min();
+    case lang::TK_MAX_EQ: // fallthrough
+      rhs = max(currentVal, rhs);
       break;
-
     case '=':
       break;
     default:
@@ -654,8 +622,8 @@ void translateComprehension(
       continue;
     auto constraint = lang::RangeConstraint(constraint_);
     Interval i;
-    i.min = translateExpr(constraint.start(), params, *funcs, lets);
-    i.max = translateExpr(constraint.end(), params, *funcs, lets) - 1;
+    i.min = translateExpr(constraint.start(), params, *tensors, lets);
+    i.max = translateExpr(constraint.end(), params, *tensors, lets) - 1;
 
     // TODO: In the future we'll want to make any non-trivial bounds
     // into hidden scalar parameters, and just pass variables to the
@@ -671,7 +639,7 @@ void translateComprehension(
   // Infer the rest
   all_exprs.push_back(rhs);
   forwardBoundsInference(
-      all_exprs, *bounds, comprehension, compilerOptions, &solution);
+      all_exprs, *tensors, comprehension, compilerOptions, &solution);
 
   // TODO: What if subsequent updates have incompatible bounds
   // (e.g. an in-place stencil)?. The .bound directive will use the
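To make the two reduction variants concrete, here is a hedged sketch (the tensor s and its index m are hypothetical) of what the switch above implies for "s(m) +=! A(m, n)" versus plain "+=". The "!" form records an identity that later becomes an initialization Provide; the plain form leaves identity undefined, so existing tensor values are updated in place:

    #include "Halide.h"
    using namespace Halide;
    using namespace Halide::Internal;

    // For "+=!", the identity is written out first; with plain "+=" no
    // such Stmt is ever built, so the reduction accumulates into whatever
    // values the output tensor already holds.
    Stmt initFor(const std::string& name, Type t, const std::vector<Expr>& lhs) {
      Expr identity = make_zero(t);  // identity of the sum reduction
      return Provide::make(name, {identity}, lhs);
    }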
@@ -680,16 +648,16 @@ void translateComprehension(
 
   // Does a tensor have a single bound, or can its bounds shrink over
   // time? Solve for a single bound for now.
-  for (Var v : lhs) {
-    if (!solution.contains(v.name())) {
+  for (lang::Ident id : comprehension.indices()) {
+    if (!solution.contains(id.name())) {
       throw lang::ErrorReport(comprehension)
-          << "Free variable " << v
+          << "Free variable " << id.name()
           << " was not solved in range inference. May not be used right-hand side";
     }
     // TODO: We're enforcing a single bound across all comprehensions
     // for now. We should really check later ones are equal to earlier
     // ones instead of just clobbering.
-    (*bounds)[f][v.name()] = solution.get(v.name());
+    info.bounds[id.name()] = solution.get(id.name());
   }
 
   // Free variables that appear on the rhs but not the lhs are
@@ -714,7 +682,7 @@ void translateComprehension(
       Expr v_min = bound.min;
       Expr v_extent = simplify(bound.max - bound.min + 1);
       rVars.push_back({v->name, v_min, v_extent});
-      (*bounds)[f][v->name] = bound;
+      info.bounds[v->name] = bound;
     }
     ReductionDomain domain(rVars);
     for (auto v : unbound) {
@@ -724,182 +692,90 @@ void translateComprehension(
     rdom = RDom(domain);
   }
 
-  Stage stage{func(lhs) = rhs};
+  // Now construct the Stmt
+  Stmt stmt = Provide::make(tensorName, {rhs}, lhs);
 
-  // Use the simplest possible Halide schedule, but reorder the loop
-  // indices to match TC convention.
-  vector<VarOrRVar> loop_nest;
+  // Wrap the reduction loops
   if (rdom.defined()) {
     for (int i = 0; i < rdom.dimensions(); i++) {
-      loop_nest.push_back(rdom[i]);
+      stmt = For::make(
+          rdom[i].name(),
+          rdom[i].min(),
+          rdom[i].extent(),
+          ForType::Serial,
+          DeviceAPI::None,
+          stmt);
     }
   }
-  while (!lhs.empty()) {
-    loop_nest.push_back(lhs.back());
-    lhs.pop_back();
+
+  // Add an initialization if needed
+  Stmt init;
+  if (identity.defined()) {
+    init = Provide::make(tensorName, {identity}, lhs);
   }
 
-  if (added_implicit_initialization) {
-    // Also reorder reduction initializations to the TC convention
-    vector<Var> funcArgs = func.args();
-    loop_nest.clear();
-    while (!funcArgs.empty()) {
-      loop_nest.push_back(funcArgs.back());
-      funcArgs.pop_back();
+  // Wrap the rest of the loops
+  for (auto id = info.args.rbegin(); id != info.args.rend(); id++) {
+    Interval in = info.bounds[*id];
+    Expr extent = simplify(in.max - in.min + 1);
+    stmt =
+        For::make(*id, in.min, extent, ForType::Serial, DeviceAPI::None, stmt);
+    if (init.defined()) {
+      init = For::make(
+          *id, in.min, extent, ForType::Serial, DeviceAPI::None, init);
     }
-    func.reorder(loop_nest);
   }
 
-  func.compute_root();
-  stage.reorder(loop_nest);
+  if (init.defined()) {
+    stmt = Block::make(init, stmt);
+  }
+
+  auto existingInfo = tensors->find(tensorName);
+
+  // Record information about this tensor for later stages of
+  // translation to refer to.
+  if (existingInfo == tensors->end()) {
+    tensors->emplace(tensorName, std::move(info));
+  } else {
+    // Clobber the bounds information with the possibly-updated
+    // constraints.
+    existingInfo->second.bounds = info.bounds;
+  }
+
+  return stmt;
 }
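Putting the pieces together, a hedged sketch (a hypothetical one-dimensional reduction into s(m) with symbolic extents M and N; the concrete right-hand side is simplified to a constant increment) of how the Provide, For, and Block nodes built above compose:

    #include "Halide.h"
    using namespace Halide;
    using namespace Halide::Internal;

    Stmt exampleLoopNest() {
      Expr m = Variable::make(Int(32), "m");
      Expr M = Variable::make(Int(32), "M");
      Expr N = Variable::make(Int(32), "N");

      // currentVal: a self-reference to the tensor being reduced into.
      Expr currentVal = Call::make(Float(32), "s", {m}, Call::Halide);

      // Update: innermost Provide, wrapped by the reduction loop and then
      // by the index loop, inside-out as in the code above.
      Stmt update = Provide::make("s", {currentVal + 1.0f}, {m});
      update = For::make("r_n", 0, N, ForType::Serial, DeviceAPI::None, update);
      update = For::make("m", 0, M, ForType::Serial, DeviceAPI::None, update);

      // Init: the identity written over the same index loops.
      Stmt init = Provide::make("s", {make_zero(Float(32))}, {m});
      init = For::make("m", 0, M, ForType::Serial, DeviceAPI::None, init);

      return Block::make(init, update);  // init runs before the reduction
    }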
 
 // Translate a semantically checked TC def to HalideComponents struct.
 HalideComponents translateDef(
     const lang::Def& def,
     const tc::CompilerOptions& compilerOptions) {
-  map<string, Function> funcs;
   HalideComponents components;
   components.def = def;
-  FunctionBounds bounds;
+
+  map<string, TensorInfo> tensors;
   for (auto p : def.params()) {
     translateParam(p, &components.params, &components.inputs);
   }
-  for (auto c : def.statements()) {
-    translateComprehension(
-        c, components.params, compilerOptions, &funcs, &bounds);
-  }
-  vector<Function> outputs;
-  for (auto p : def.returns()) {
-    translateOutput(p, funcs, &outputs);
-  }
-
-  // Now apply an extremely simplified version of Halide lowering
-
-  // Compute an environment
-  map<string, Function> env;
-  for (auto f : outputs) {
-    populate_environment(f, env);
-  }
-
-  // Finalize all the LoopLevels
-  for (auto& iter : env) {
-    iter.second.lock_loop_levels();
-  }
-
-  // Compute a realization order. This is a topological order on the
-  // pipeline of groups of Funcs. For our purposes, each group has a
-  // single Func in it. The Halide scheduling directive compute_with,
-  // (which does general loop fusion) can create groups with multiple
-  // Funcs in it, but we don't use it here.
-  vector<string> order;
-  vector<vector<string>> fused_groups;
-  std::tie(order, fused_groups) = realization_order(outputs, env);
-
-  // Create loop nests
-  bool any_memoized = false;
-  // This part of lowering requires a target, but it will never be
-  // used in the pipelines we construct here, so just make a host target.
-  Target target("host");
-  Stmt s = schedule_functions(outputs, fused_groups, env, target, any_memoized);
-
-  // we insert these to allow for inplace mutation of in/out tensors
-  s = remove_undef(s);
-
-  // Apply forward bounds inference results. This replaces the usual Halide
-  // bounds inference.
-  for (auto p : bounds) {
-    const Function& f = p.first;
-    for (auto b : p.second) {
-      const string& var = b.first;
-      const Interval& bound = b.second;
-      for (size_t i = 0; i <= f.updates().size(); i++) {
-        // Halide lowers function loop bounds as follows:
-        string qualified_var_name =
-            f.name() + ".s" + std::to_string(i) + "." + var;
-        s = LetStmt::make(qualified_var_name + ".min", bound.min, s);
-        s = LetStmt::make(qualified_var_name + ".max", bound.max, s);
-      }
+  for (auto c : def.statements()) {
+    Stmt next =
+        translateComprehension(c, components.params, compilerOptions, &tensors);
+    if (!components.stmt.defined()) {
+      components.stmt = next;
+    } else {
+      components.stmt = Block::make(components.stmt, next);
     }
   }
-
-  // Collect the arguments (inputs and outputs)
-  s = uniquify_variable_names(s);
-  s = simplify(s);
-
-  // Trim ProducerConsumer annotations. TC doesn't use them.
-  class RemoveProducerConsumer : public IRMutator2 {
-    using IRMutator2::visit;
-    Stmt visit(const ProducerConsumer* op) {
-      return mutate(op->body);
-    }
-  } removeProducerConsumer;
-
-  s = removeProducerConsumer.mutate(s);
-
-  // Rename all loop variables to be valid C identifiers, to ease
-  // conversion to isl.
-  class RenameVariables : public IRMutator2 {
-    using IRMutator2::visit;
-
-    map<string, string> new_names;
-
-    Expr visit(const Variable* op) override {
-      auto it = new_names.find(op->name);
-      if (it != new_names.end()) {
-        return Variable::make(
-            op->type, it->second, op->image, op->param, op->reduction_domain);
-      } else {
-        return op;
-      }
-    }
-
-    Stmt visit(const For* op) override {
-      string sanitized = replace_all(op->name, ".", "_");
-      Expr min = mutate(op->min);
-      Expr extent = mutate(op->extent);
-      new_names[op->name] = sanitized;
-      Stmt body = mutate(op->body);
-      return For::make(
-          sanitized,
-          std::move(min),
-          std::move(extent),
-          op->for_type,
-          op->device_api,
-          std::move(body));
-    }
-  } renameVariables;
-
-  s = renameVariables.mutate(s);
-
-  // We don't handle Let nodes after this point
-  class SubstituteAllLets : public IRMutator2 {
-    Scope<Expr> scope;
-    Stmt visit(const LetStmt* op) override {
-      ScopedBinding<Expr> bind(scope, op->name, mutate(op->value));
-      return mutate(op->body);
-    }
-    Expr visit(const Let* op) override {
-      ScopedBinding<Expr> bind(scope, op->name, mutate(op->value));
-      return mutate(op->body);
-    }
-    Expr visit(const Variable* op) override {
-      if (scope.contains(op->name)) {
-        return scope.get(op->name);
-      } else {
-        return op;
-      }
-    }
-  };
-  s = SubstituteAllLets().mutate(s);
-
-  components.stmt = s;
-
-  for (Function f : outputs) {
-    OutputImageParam o = Func(f).output_buffers()[0];
-    // Apply forward bounds inference results to the output buffers.
-    const auto& b = bounds[f];
+  // Populate the output bounds
+  for (auto p : def.returns()) {
+    // TODO: unify bounds and tensors map?
+    const auto& t = tensors[p.ident().name()];
+    ImageParam o(t.type, t.args.size(), p.ident().name());
     for (int i = 0; i < o.dimensions(); i++) {
-      const Interval& bound = b.at(f.args()[i]);
+      string arg = t.args[i];
+      const Interval& bound = t.bounds.at(arg);
       o.dim(i).set_bounds(bound.min, simplify(bound.max - bound.min + 1));
     }
     components.outputs.push_back(o);
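One consequence worth noting before the test changes: the removed lowering produced loop variables with Halide's stage-qualified names, which RenameVariables then sanitized for isl, while the new code emits For loops directly over the TC index names. A small sketch of the old naming scheme (helper name hypothetical):

    #include <string>

    // The removed code bound loop ranges under "<func>.s<stage>.<var>"
    // (e.g. "O1.s0.m") via LetStmt, and RenameVariables rewrote '.' to '_'
    // for isl, yielding names like "O1_s0_m".
    std::string oldLoopName(const std::string& f, int stage, const std::string& v) {
      return f + "_s" + std::to_string(stage) + "_" + v;  // post-sanitize form
    }

This is why the expected outputs in test_core.cc below shrink from O1_s0_m-style names to the plain TC index names (m, n, ...).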
diff --git a/test/test_core.cc b/test/test_core.cc
index 88a2ff8bc..26b464b27 100644
--- a/test/test_core.cc
+++ b/test/test_core.cc
@@ -60,19 +60,14 @@ dtype {
 struct GenericHalideCoreTest : public ::testing::Test {
   void CheckC(const std::string& tc, const std::vector<std::string>& expected) {
-    auto curPos = std::string::npos;
+    auto curPos = 0;
     auto halide = tc2halide::translate(
         isl::with_exceptions::globalIslCtx(), tc, CompilerOptions());
     auto res = tc::halideCodegenC(halide.stmt);
     for (const auto& e : expected) {
-      auto newPos = res.find(e);
-      if (curPos == std::string::npos) {
-        curPos = newPos;
-      }
-      ASSERT_NE(std::string::npos, res.find(e)) << "No: " << e << " in:\n"
-                                                << res;
-      ASSERT_GE(newPos, curPos)
-          << "Improper ordering of expected outputs in:" << res;
+      auto newPos = res.find(e, curPos);
+      ASSERT_NE(std::string::npos, newPos)
+          << "No: " << e << " in:\n" << res;
       curPos = newPos;
     }
   }
@@ -100,27 +95,27 @@ def fun(float(M, K) I, float(K, N) W1, float(N, P) W2) -> (O1, O2) {
   CheckC(
       tc,
       R"C(
-for (int O1_s0_m = 0; O1_s0_m < M; O1_s0_m++) {
-  for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
-    O1[O1_s0_m][O1_s0_n] = 0.000000f;
+for (int m = 0; m < M; m++) {
+  for (int n = 0; n < N; n++) {
+    O1[m][n] = 0.000000f;
   }
 }
-for (int O1_s1_m = 0; O1_s1_m < M; O1_s1_m++) {
-  for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
-    for (int O1_s1_r_k = 0; O1_s1_r_k < K; O1_s1_r_k++) {
-      O1[O1_s1_m][O1_s1_n] = (O1[O1_s1_m][O1_s1_n] + (I[O1_s1_m][O1_s1_r_k]*W1[O1_s1_r_k][O1_s1_n]));
+for (int m = 0; m < M; m++) {
+  for (int n = 0; n < N; n++) {
+    for (int r_k = 0; r_k < K; r_k++) {
+      O1[m][n] = (O1[m][n] + (I[m][r_k]*W1[r_k][n]));
     }
   }
 }
-for (int O2_s0_m = 0; O2_s0_m < M; O2_s0_m++) {
-  for (int O2_s0_p = 0; O2_s0_p < P; O2_s0_p++) {
-    O2[O2_s0_m][O2_s0_p] = 0.000000f;
+for (int m = 0; m < M; m++) {
+  for (int p = 0; p < P; p++) {
+    O2[m][p] = 0.000000f;
   }
 }
-for (int O2_s1_m = 0; O2_s1_m < M; O2_s1_m++) {
-  for (int O2_s1_p = 0; O2_s1_p < P; O2_s1_p++) {
-    for (int O2_s1_r_n = 0; O2_s1_r_n < N; O2_s1_r_n++) {
-      O2[O2_s1_m][O2_s1_p] = (O2[O2_s1_m][O2_s1_p] + (O1[O2_s1_m][O2_s1_r_n]*W2[O2_s1_r_n][O2_s1_p]));
+for (int m = 0; m < M; m++) {
+  for (int p = 0; p < P; p++) {
+    for (int r_n = 0; r_n < N; r_n++) {
+      O2[m][p] = (O2[m][p] + (O1[m][r_n]*W2[r_n][p]));
     }
   }
 }
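The rewritten CheckC resumes each search at the previous match instead of scanning the whole string again, so the expected snippets must appear in order. A standalone sketch of the same logic (function name hypothetical):

    #include <string>
    #include <vector>

    bool containsInOrder(const std::string& res,
                         const std::vector<std::string>& expected) {
      std::string::size_type curPos = 0;
      for (const auto& e : expected) {
        auto newPos = res.find(e, curPos);  // search from the last match
        if (newPos == std::string::npos) {
          return false;  // snippet missing, or present only out of order
        }
        curPos = newPos;
      }
      return true;
    }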
@@ -136,23 +131,23 @@ def fun(float(N, C, H, W) I1, float(C, F, KH, KW) W1) -> (O1) {
   CheckC(
       tc,
       R"C(
-for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
-  for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
-    for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
-      for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
-        O1[O1_s0_n][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
+for (int n = 0; n < N; n++) {
+  for (int f = 0; f < F; f++) {
+    for (int h = 0; h < ((H - KH) + 1); h++) {
+      for (int w = 0; w < ((W - KW) + 1); w++) {
+        O1[n][f][h][w] = 0.000000f;
       }
     }
   }
 }
-for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
-  for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
-    for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
-      for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
-        for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
-          for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
-            for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
-              O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
+for (int n = 0; n < N; n++) {
+  for (int f = 0; f < F; f++) {
+    for (int h = 0; h < ((H - KH) + 1); h++) {
+      for (int w = 0; w < ((W - KW) + 1); w++) {
+        for (int r_c = 0; r_c < C; r_c++) {
+          for (int r_kh = 0; r_kh < KH; r_kh++) {
+            for (int r_kw = 0; r_kw < KW; r_kw++) {
+              O1[n][f][h][w] = (O1[n][f][h][w] + (I1[n][r_c][(h + r_kh)][(w + r_kw)]*W1[r_c][f][r_kh][r_kw]));
             }
           }
         }
@@ -166,10 +161,10 @@ for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
 TEST_F(GenericHalideCoreTest, Copy) {
   CheckC(
       makeCopyTc(3),
-      {"for (int O_s0_i0 = 0; O_s0_i0 < P0; O_s0_i0++) {",
-       "  for (int O_s0_i1 = 0; O_s0_i1 < P1; O_s0_i1++) {",
-       "    for (int O_s0_i2 = 0; O_s0_i2 < P2; O_s0_i2++) {",
-       "      O[O_s0_i0][O_s0_i1][O_s0_i2] = I[O_s0_i0][O_s0_i1][O_s0_i2];"});
+      {"for (int i0 = 0; i0 < P0; i0++) {",
+       "  for (int i1 = 0; i1 < P1; i1++) {",
+       "    for (int i2 = 0; i2 < P2; i2++) {",
+       "      O[i0][i1][i2] = I[i0][i1][i2];"});
 }
 
 TEST_F(GenericHalideCoreTest, GroupConvolution) {
@@ -181,26 +176,26 @@ def fun(float(N, G, C, H, W) I1, float(G, C, F, KH, KW) W1) -> (O1) {
   CheckC(
       tc,
       R"C(
-for (int O1_s0_n = 0; O1_s0_n < N; O1_s0_n++) {
-  for (int O1_s0_g = 0; O1_s0_g < G; O1_s0_g++) {
-    for (int O1_s0_f = 0; O1_s0_f < F; O1_s0_f++) {
-      for (int O1_s0_h = 0; O1_s0_h < ((H - KH) + 1); O1_s0_h++) {
-        for (int O1_s0_w = 0; O1_s0_w < ((W - KW) + 1); O1_s0_w++) {
-          O1[O1_s0_n][O1_s0_g][O1_s0_f][O1_s0_h][O1_s0_w] = 0.000000f;
+for (int n = 0; n < N; n++) {
+  for (int g = 0; g < G; g++) {
+    for (int f = 0; f < F; f++) {
+      for (int h = 0; h < ((H - KH) + 1); h++) {
+        for (int w = 0; w < ((W - KW) + 1); w++) {
+          O1[n][g][f][h][w] = 0.000000f;
         }
       }
     }
   }
 }
-for (int O1_s1_n = 0; O1_s1_n < N; O1_s1_n++) {
-  for (int O1_s1_g = 0; O1_s1_g < G; O1_s1_g++) {
-    for (int O1_s1_f = 0; O1_s1_f < F; O1_s1_f++) {
-      for (int O1_s1_h = 0; O1_s1_h < ((H - KH) + 1); O1_s1_h++) {
-        for (int O1_s1_w = 0; O1_s1_w < ((W - KW) + 1); O1_s1_w++) {
-          for (int O1_s1_r_c = 0; O1_s1_r_c < C; O1_s1_r_c++) {
-            for (int O1_s1_r_kh = 0; O1_s1_r_kh < KH; O1_s1_r_kh++) {
-              for (int O1_s1_r_kw = 0; O1_s1_r_kw < KW; O1_s1_r_kw++) {
-                O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] = (O1[O1_s1_n][O1_s1_g][O1_s1_f][O1_s1_h][O1_s1_w] + (I1[O1_s1_n][O1_s1_g][O1_s1_r_c][(O1_s1_h + O1_s1_r_kh)][(O1_s1_w + O1_s1_r_kw)]*W1[O1_s1_g][O1_s1_r_c][O1_s1_f][O1_s1_r_kh][O1_s1_r_kw]));
+for (int n = 0; n < N; n++) {
+  for (int g = 0; g < G; g++) {
+    for (int f = 0; f < F; f++) {
+      for (int h = 0; h < ((H - KH) + 1); h++) {
+        for (int w = 0; w < ((W - KW) + 1); w++) {
+          for (int r_c = 0; r_c < C; r_c++) {
+            for (int r_kh = 0; r_kh < KH; r_kh++) {
+              for (int r_kw = 0; r_kw < KW; r_kw++) {
+                O1[n][g][f][h][w] = (O1[n][g][f][h][w] + (I1[n][g][r_c][(h + r_kh)][(w + r_kw)]*W1[g][r_c][f][r_kh][r_kw]));
               }
             }
           }
@@ -216,15 +211,15 @@ TEST_F(GenericHalideCoreTest, Matmul) {
   CheckC(
       makeMatmulTc(false, false),
       R"C(
-for (int O_s0_i = 0; O_s0_i < N; O_s0_i++) {
-  for (int O_s0_j = 0; O_s0_j < M; O_s0_j++) {
-    O[O_s0_i][O_s0_j] = 0.000000f;
+for (int i = 0; i < N; i++) {
+  for (int j = 0; j < M; j++) {
+    O[i][j] = 0.000000f;
   }
 }
-for (int O_s1_i = 0; O_s1_i < N; O_s1_i++) {
-  for (int O_s1_j = 0; O_s1_j < M; O_s1_j++) {
-    for (int O_s1_k = 0; O_s1_k < K; O_s1_k++) {
-      O[O_s1_i][O_s1_j] = (O[O_s1_i][O_s1_j] + (A[O_s1_i][O_s1_k]*B[O_s1_k][O_s1_j]));
+for (int i = 0; i < N; i++) {
+  for (int j = 0; j < M; j++) {
+    for (int k = 0; k < K; k++) {
+      O[i][j] = (O[i][j] + (A[i][k]*B[k][j]));
     }
   }
 }
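For reference, a hedged usage sketch mirroring the CheckC fixture above (the include path comes from this diff; the namespace qualification on CompilerOptions is an assumption, since the test file appears to use it unqualified):

    #include "tc/core/tc2halide.h"

    // Translate a TC definition and emit the C text the tests match against.
    std::string emitC(const std::string& tc) {
      auto halide = tc2halide::translate(
          isl::with_exceptions::globalIslCtx(), tc, tc::CompilerOptions());
      return tc::halideCodegenC(halide.stmt);
    }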