diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index a5bdbbae7d2..9b062e4ce1a 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -309,8 +309,9 @@ AllGatherResultCoalesced BuildAllGatherCoalesced( xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim, shard_count, cc_groups); } - if (ShapeHelper::ShapeOfXlaOp(all_gather_result).tuple_shapes().size() != - 0) { + if (ShapeHelper::ShapeOfXlaOp(all_gather_result).IsTuple() && + ShapeHelper::ShapeOfXlaOp(all_gather_result).tuple_shapes().size() != + 0) { for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) { size_t op_idx = type_ctx.second.indices[i]; result[op_idx] = xla::GetTupleElement(all_gather_result, i);