Skip to content

Commit 7e3efc5

Browse files
authored
Allgather coalescee: Check tuple shape only if return shape is tuple. (#9403)
1 parent 926700d commit 7e3efc5

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torch_xla/csrc/cross_replica_reduces.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,9 @@ AllGatherResultCoalesced BuildAllGatherCoalesced(
309309
xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
310310
dim, shard_count, cc_groups);
311311
}
312-
if (ShapeHelper::ShapeOfXlaOp(all_gather_result).tuple_shapes().size() !=
313-
0) {
312+
if (ShapeHelper::ShapeOfXlaOp(all_gather_result).IsTuple() &&
313+
ShapeHelper::ShapeOfXlaOp(all_gather_result).tuple_shapes().size() !=
314+
0) {
314315
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
315316
size_t op_idx = type_ctx.second.indices[i];
316317
result[op_idx] = xla::GetTupleElement(all_gather_result, i);

0 commit comments

Comments
 (0)