
Commit ee8c22a

Merge pull request #480 from tensor-compiler/multidim-workspace
Fix some precompute transformation algorithm bugs that arose
2 parents: 97edc84 + d460e29

2 files changed: +235, -54 lines

src/index_notation/transformations.cpp

Lines changed: 53 additions & 44 deletions
@@ -383,67 +383,76 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
     Forall foralli(node);
     std::vector<IndexVar> i_vars = precompute.getIVars();

-    vector<IndexVar> forallIndexVars;
+    bool containsWhere = false;
     match(foralli,
-          function<void(const ForallNode*)>([&](const ForallNode* op) {
-            forallIndexVars.push_back(op->indexVar);
+          function<void(const WhereNode*)>([&](const WhereNode* op) {
+            containsWhere = true;
           })
     );

-    IndexStmt s = foralli.getStmt();
-    TensorVar ws = precompute.getWorkspace();
-    IndexExpr e = precompute.getExpr();
-    std::vector<IndexVar> iw_vars = precompute.getIWVars();
+    if (!containsWhere) {
+      vector<IndexVar> forallIndexVars;
+      match(foralli,
+            function<void(const ForallNode*)>([&](const ForallNode* op) {
+              forallIndexVars.push_back(op->indexVar);
+            })
+      );

-    map<IndexVar, IndexVar> substitutions;
-    taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
+      IndexStmt s = foralli.getStmt();
+      TensorVar ws = precompute.getWorkspace();
+      IndexExpr e = precompute.getExpr();
+      std::vector<IndexVar> iw_vars = precompute.getIWVars();

-    for (int index = 0; index < (int)i_vars.size(); index++) {
-      substitutions[i_vars[index]] = iw_vars[index];
-    }
+      map<IndexVar, IndexVar> substitutions;
+      taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";

-    // Build consumer by replacing with temporary (in replacedStmt)
-    IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
-    if (replacedStmt != s) {
-      // Then modify the replacedStmt to have the correct foralls
-      // by concretizing the consumer assignment
+      for (int index = 0; index < (int)i_vars.size(); index++) {
+        substitutions[i_vars[index]] = iw_vars[index];
+      }

-      auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
-      auto consumerIndexVars = consumerAssignment.getIndexVars();
+      // Build consumer by replacing with temporary (in replacedStmt)
+      IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
+      if (replacedStmt != s) {
+        // Then modify the replacedStmt to have the correct foralls
+        // by concretizing the consumer assignment

-      auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
-      auto producerIndexVars = producerAssignment.getIndexVars();
+        auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
+        auto consumerIndexVars = consumerAssignment.getIndexVars();

-      vector<IndexVar> producerForallIndexVars;
-      vector<IndexVar> consumerForallIndexVars;
-      vector<IndexVar> outerForallIndexVars;
+        auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
+        auto producerIndexVars = producerAssignment.getIndexVars();

-      bool stopForallDistribution = false;
-      for (auto &i : util::reverse(forallIndexVars)) {
-        if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
-          producerForallIndexVars.push_back(substitutions[i]);
-          consumerForallIndexVars.push_back(i);
-        } else {
-          auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
-          auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
-          if (stopForallDistribution || (producerContains && consumerContains)) {
-            outerForallIndexVars.push_back(i);
-            stopForallDistribution = true;
-          } else if (!stopForallDistribution && consumerContains) {
+        vector<IndexVar> producerForallIndexVars;
+        vector<IndexVar> consumerForallIndexVars;
+        vector<IndexVar> outerForallIndexVars;
+
+        bool stopForallDistribution = false;
+        for (auto &i : util::reverse(forallIndexVars)) {
+          if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
+            producerForallIndexVars.push_back(substitutions[i]);
             consumerForallIndexVars.push_back(i);
-          } else if (!stopForallDistribution && producerContains) {
-            producerForallIndexVars.push_back(i);
+          } else {
+            auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
+            auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
+            if (stopForallDistribution || (producerContains && consumerContains)) {
+              outerForallIndexVars.push_back(i);
+              stopForallDistribution = true;
+            } else if (!stopForallDistribution && consumerContains) {
+              consumerForallIndexVars.push_back(i);
+            } else if (!stopForallDistribution && producerContains) {
+              producerForallIndexVars.push_back(i);
+            }
           }
         }
-      }

-      IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
+        IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);

-      IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
-      Where where(consumer, producer);
+        IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
+        Where where(consumer, producer);

-      stmt = generateForalls(where, outerForallIndexVars);
-      return;
+        stmt = generateForalls(where, outerForallIndexVars);
+        return;
+      }
     }
     IndexNotationRewriter::visit(node);
   }
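Note on the change above: the rewriter now first checks whether the matched forall nest already contains a WhereNode and only transforms nests that do not, then collects the nest's index variables and distributes them between the producer loops (over the substituted workspace variables), the consumer loops, and any outer loops that must wrap the resulting where(consumer, producer). The sketch below shows the scheduling-level view of that transformation. It is a minimal, standalone adaptation of the tile_dotProduct_3 test added in this commit, not part of the commit itself; the main() wrapper and printing are ours, and a working taco build is assumed.

// Sketch: tiled dot product whose product is hoisted into a dense workspace.
// Mirrors the tile_dotProduct_3 test below; the rewriter should nest a single
// where(consumer, producer) under the outer tile loop i0.
#include <iostream>
#include "taco.h"
using namespace taco;

int main() {
  const int N = 1024;
  Tensor<double> A("A");                        // scalar result
  Tensor<double> B("B", {N}, Format({Dense}));
  Tensor<double> C("C", {N}, Format({Dense}));
  for (int i = 0; i < N; i++) {
    B.insert({i}, (double) i);
    C.insert({i}, (double) i);
  }
  B.pack();
  C.pack();

  IndexVar i("i"), i_bounded("i_bounded"), i0("i0"), i1("i1");
  IndexExpr BExpr = B(i);
  IndexExpr CExpr = C(i);
  IndexExpr precomputedExpr = BExpr * CExpr;
  A() = precomputedExpr;

  TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
  TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);

  // Tile the reduction loop, hoist the whole product at the outer tile loop i0,
  // and stage each operand into its own workspace at the inner loop i1.
  IndexStmt stmt = A.getAssignment().concretize();
  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
             .split(i_bounded, i0, i1, 32);
  stmt = stmt.precompute(precomputedExpr, i0, i0, precomputed);
  stmt = stmt.precompute(BExpr, i1, i1, B_new)
             .precompute(CExpr, i1, i1, C_new);
  stmt = stmt.concretize();

  A.compile(stmt);
  A.assemble();
  A.compute();

  std::cout << stmt << std::endl;  // inspect where the where(...) nodes were placed
  return 0;
}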

test/tests-workspaces.cpp

Lines changed: 182 additions & 10 deletions
@@ -45,7 +45,7 @@ TEST(workspaces, tile_vecElemMul_NoTail) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 }
 
 TEST(workspaces, tile_vecElemMul_Tail1) {
@@ -83,7 +83,7 @@ TEST(workspaces, tile_vecElemMul_Tail1) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 }
 
 TEST(workspaces, tile_vecElemMul_Tail2) {
@@ -121,7 +121,7 @@ TEST(workspaces, tile_vecElemMul_Tail2) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 
   // ir::IRPrinter irp = ir::IRPrinter(cout);
   //
@@ -171,7 +171,7 @@ TEST(workspaces, tile_denseMatMul) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 
   // ir::IRPrinter irp = ir::IRPrinter(cout);
   //
@@ -218,7 +218,7 @@ TEST(workspaces, precompute2D_add) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 
 }
 
@@ -263,7 +263,7 @@ TEST(workspaces, precompute4D_add) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 }
 
 TEST(workspaces, precompute4D_multireduce) {
@@ -305,7 +305,7 @@ TEST(workspaces, precompute4D_multireduce) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 }
 
 TEST(workspaces, precompute3D_TspV) {
@@ -344,7 +344,7 @@ TEST(workspaces, precompute3D_TspV) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 
 }
 
@@ -388,7 +388,7 @@ TEST(workspaces, precompute3D_multipleWS) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 
 }
 
@@ -431,6 +431,178 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) {
   expected.compile();
   expected.assemble();
   expected.compute();
-  ASSERT_TENSOR_EQ(A, expected);
+  ASSERT_TENSOR_EQ(expected, A);
 
 }
+
+TEST(workspaces, DISABLED_tile_dotProduct_1) {
+  // FIXME: Disabled because currently the precompute algorithm does not appropriately
+  // find the correct forall substmt to nest the WhereNode in after i has been
+  // split into i0 and i1. As an example, the first precompute below is incorrect
+  // since it should transform
+  // forall(i0, forall(i1, A() += B(i) * C(i))) -->
+  // forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i))))
+  //
+  // But currently the algorithm does
+  // forall(i0, forall(i1, A() += B(i) * C(i))) -->
+  // where(forall(i1, A() += ws(i1)), forall(i0, forall(i1, ws(i1) += B(i) * C(i))))
+
+  int N = 1024;
+  Tensor<double> A("A");
+  Tensor<double> B("B", {N}, Format({Dense}));
+  Tensor<double> C("C", {N}, Format({Dense}));
+
+  for (int i = 0; i < N; i++) {
+    B.insert({i}, (double) i);
+    C.insert({i}, (double) i);
+  }
+
+  B.pack();
+  C.pack();
+
+  IndexVar i("i");
+  IndexVar i_bounded("i_bounded");
+  IndexVar i0("i0"), i1("i1");
+  IndexExpr BExpr = B(i);
+  IndexExpr CExpr = C(i);
+  IndexExpr precomputedExpr = (BExpr) * (CExpr);
+  A() = precomputedExpr;
+
+  IndexStmt stmt = A.getAssignment().concretize();
+  TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
+
+  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
+             .split(i_bounded, i0, i1, 32);
+  stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);
+  stmt = stmt.precompute(BExpr, i1, i1, B_new)
+             .precompute(CExpr, i1, i1, C_new);
+
+  stmt = stmt.concretize();
+
+  A.compile(stmt);
+  A.assemble();
+  A.compute();
+
+  ir::IRPrinter irp = ir::IRPrinter(cout);
+
+  cout << stmt << endl;
+
+  std::shared_ptr<ir::CodeGen> codegen = ir::CodeGen::init_default(cout, ir::CodeGen::ImplementationGen);
+  ir::Stmt compute = lower(stmt, "compute", false, true);
+
+  irp.print(compute);
+  cout << endl;
+  codegen->compile(compute, false);
+
+  Tensor<double> expected("expected");
+  expected() = B(i) * C(i);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, A);
+}
+
+TEST(workspaces, DISABLED_tile_dotProduct_2) {
+  // FIXME: This is also currently disabled since split(...) scheduling commands
+  // only split on the FIRST INSTANCE of an indexVar (assumes only one).
+  // This is wrong if the indexVar is not renamed across iw_vars since an indexVar can
+  // then occur on BOTH the consumer and producer side and should be split across both.
+
+  int N = 1024;
+  Tensor<double> A("A");
+  Tensor<double> B("B", {N}, Format({Dense}));
+  Tensor<double> C("C", {N}, Format({Dense}));
+
+  for (int i = 0; i < N; i++) {
+    B.insert({i}, (double) i);
+    C.insert({i}, (double) i);
+  }
+
+  B.pack();
+  C.pack();
+
+  IndexVar i("i");
+  IndexVar i_bounded("i_bounded");
+  IndexVar i0("i0"), i1("i1");
+  IndexExpr BExpr = B(i);
+  IndexExpr CExpr = C(i);
+  IndexExpr precomputedExpr = (BExpr) * (CExpr);
+  A() = precomputedExpr;
+
+  IndexStmt stmt = A.getAssignment().concretize();
+  TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
+
+  stmt = stmt.precompute(precomputedExpr, i, i, precomputed);
+
+  stmt = stmt.precompute(BExpr, i, i, B_new)
+             .precompute(CExpr, i, i, C_new);
+
+  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
+             .split(i_bounded, i0, i1, 32);
+
+  stmt = stmt.concretize();
+
+  A.compile(stmt);
+  A.assemble();
+  A.compute();
+
+  Tensor<double> expected("expected");
+  expected() = B(i) * C(i);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, A);
+}
+
+TEST(workspaces, tile_dotProduct_3) {
+  int N = 1024;
+  Tensor<double> A("A");
+  Tensor<double> B("B", {N}, Format({Dense}));
+  Tensor<double> C("C", {N}, Format({Dense}));
+
+  for (int i = 0; i < N; i++) {
+    B.insert({i}, (double) i);
+    C.insert({i}, (double) i);
+  }
+
+  B.pack();
+  C.pack();
+
+  IndexVar i("i");
+  IndexVar i_bounded("i_bounded");
+  IndexVar i0("i0"), i1("i1");
+  IndexExpr BExpr = B(i);
+  IndexExpr CExpr = C(i);
+  IndexExpr precomputedExpr = (BExpr) * (CExpr);
+  A() = precomputedExpr;
+
+  IndexStmt stmt = A.getAssignment().concretize();
+  TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
+
+  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
+             .split(i_bounded, i0, i1, 32);
+  stmt = stmt.precompute(precomputedExpr, i0, i0, precomputed);
+
+  stmt = stmt.precompute(BExpr, i1, i1, B_new)
+             .precompute(CExpr, i1, i1, C_new);
+
+
+  stmt = stmt.concretize();
+
+  A.compile(stmt);
+  A.assemble();
+  A.compute();
+
+  Tensor<double> expected("expected");
+  expected() = B(i) * C(i);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, A);
+}
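The FIXME comments in the two disabled tests above describe the remaining gaps: when precompute targets the inner variable i1 of a split loop, the where(consumer, producer) is hoisted above the outer loop i0 instead of being nested inside it, and split(...) only rewrites the first occurrence of an index variable, which breaks schedules where the same variable survives on both the producer and the consumer side. One way to observe the first problem is to build that schedule and print the concretized statement, as in the sketch below. This is our illustration, reusing only calls that already appear in the tests; it only prints the statement and does not compile or run the kernel.

// Sketch: reproduce the schedule DISABLED_tile_dotProduct_1 flags as mishandled
// and print the resulting index statement to see where the where(...) lands.
#include <iostream>
#include "taco.h"
using namespace taco;

int main() {
  const int N = 1024;
  Tensor<double> A("A");
  Tensor<double> B("B", {N}, Format({Dense}));
  Tensor<double> C("C", {N}, Format({Dense}));

  IndexVar i("i"), i_bounded("i_bounded"), i0("i0"), i1("i1");
  IndexExpr precomputedExpr = B(i) * C(i);
  A() = precomputedExpr;

  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);

  // Precompute at the *inner* tile variable i1, after bounding and splitting i.
  IndexStmt stmt = A.getAssignment().concretize();
  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
             .split(i_bounded, i0, i1, 32);
  stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);

  // Desired result (per the FIXME):
  //   forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i))))
  // Currently produced:
  //   where(forall(i1, A() += ws(i1)), forall(i0, forall(i1, ws(i1) += B(i) * C(i))))
  std::cout << stmt.concretize() << std::endl;
  return 0;
}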
