diff --git a/test/test_operations.py b/test/test_operations.py
index cb790a07414..a544d9ba19a 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2511,6 +2511,17 @@ def test_flip_raises_error_on_duplicated_dims(self):
           f"from {dims} to {dims_suggestion}.")
       self.assertEqual(str(e), expected_error)
 
+  def test_full_raises_error_on_negative_size(self):
+    shape = [2, -2, 2]
+    expected_error = (
+        "full(): expected concrete sizes (i.e. non-symbolic) to be "
+        f"positive values. However found negative ones: {shape}.")
+    # assertRaises fails the test if no error is raised; a bare try/except
+    # would let the test pass without ever checking the message.
+    with self.assertRaises(RuntimeError) as ctx:
+      torch.full(shape, 1.5, device="xla")
+    self.assertEqual(str(ctx.exception), expected_error)
+
 
 class MNISTComparator(nn.Module):
 
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 005e0e98dcc..ceacd59603e 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1702,16 +1702,14 @@ at::Tensor XLANativeFunctions::empty_symint(
   // does not actually end up doing any memory initialization, we use that and
   // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
   // s_copy_().
-  XLATensorPtr xla_tensor;
-  if (all_dims_static) {
-    xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
-                                      GetXlaDeviceOrCurrent(device),
-                                      at::dtype_or_default(dtype));
-  } else {
-    xla_tensor =
-        tensor_methods::full_symint(sym_size, 0, GetXlaDeviceOrCurrent(device),
-                                    at::dtype_or_default(dtype));
-  }
+  XLATensorPtr xla_tensor = GetValueOrThrow(
+      all_dims_static
+          ? tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
+                                 GetXlaDeviceOrCurrent(device),
+                                 at::dtype_or_default(dtype))
+          : tensor_methods::full_symint(sym_size, 0,
+                                        GetXlaDeviceOrCurrent(device),
+                                        at::dtype_or_default(dtype)));
   // `tensor.to` will trigger an `empty` + `_to_copy`. In the egaer mode, the
   // `full` will be evulated eagerly and got a replicated sharding. We should
   // leave the sharding to be empty.
@@ -1858,9 +1856,9 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size, } else { intend_dtype = fill_value.type(); } - return bridge::AtenFromXlaTensor( + return bridge::AtenFromXlaTensor(GetValueOrThrow( tensor_methods::full(absl::Span(size), fill_value, - GetXlaDeviceOrCurrent(device), intend_dtype)); + GetXlaDeviceOrCurrent(device), intend_dtype))); } at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, @@ -2681,8 +2679,8 @@ std::tuple XLANativeFunctions::nll_loss2d_forward( int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr total_weight = tensor_methods::full( - {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); + XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full( + {}, 1, self_tensor->GetDevice(), self_tensor->dtype())); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d( self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), @@ -2716,8 +2714,8 @@ std::tuple XLANativeFunctions::nll_loss_forward( int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr total_weight = tensor_methods::full( - {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); + XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full( + {}, 1, self_tensor->GetDevice(), self_tensor->dtype())); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss( self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), @@ -4038,10 +4036,10 @@ std::tuple XLANativeFunctions::_linalg_svd( if (!compute_uv) { // When compute_uv is false, torch::_linalg_svd returns an empty tensor for // u and vh. 
-    u = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
-                             self_tensor->dtype());
-    vh = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
-                              self_tensor->dtype());
+    u = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
+                                             self_tensor->dtype()));
+    vh = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
+                                              self_tensor->dtype()));
   }
   return std::make_tuple(bridge::AtenFromXlaTensor(u),
                          bridge::AtenFromXlaTensor(s),
diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp
index a189446bbef..25b0b4f4b85 100644
--- a/torch_xla/csrc/ops/index_ops.cpp
+++ b/torch_xla/csrc/ops/index_ops.cpp
@@ -315,7 +315,8 @@ XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base,
                     base_dimensions.begin() + start_dim + indices.size(),
                     base_dimensions.end());
 
-  return tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype());
+  return GetValueOrThrow(
+      tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype()));
 }
 
 XLATensorPtr IndexByTensors(const XLATensorPtr& base,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 4a749d50ac7..7534191f042 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -409,6 +409,46 @@ absl::Status CheckFlipDimensionsAreUnique(
   return absl::OkStatus();
 }
 
+// Returns an InvalidArgument status if any entry of `sizes` is negative (zero
+// is accepted). `original_sizes_as_str` is called only on failure, to render
+// the size list that is embedded in the error message.
+template <typename F>
+absl::Status CheckFullSizesArePositiveImpl(absl::Span<const int64_t> sizes,
+                                           const F& original_sizes_as_str) {
+  const bool has_concrete_negative_size = std::any_of(
+      sizes.begin(), sizes.end(), [](const int64_t size) { return size < 0; });
+  if (has_concrete_negative_size) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("full(): expected concrete sizes (i.e. non-symbolic) to "
+                     "be positive values. However found negative ones: [",
+                     original_sizes_as_str(), "].")));
+  }
+  return absl::OkStatus();
+}
+
+// Checks sizes that are all concrete integers.
+absl::Status CheckFullSizesArePositive(absl::Span<const int64_t> sizes) {
+  return CheckFullSizesArePositiveImpl(
+      sizes, [&]() { return absl::StrJoin(sizes, /* sep= */ ", "); });
+}
+
+// Checks possibly-symbolic sizes. Entries for which `maybe_as_int()` yields no
+// value are replaced by 0 so they can never fail the check; the error message
+// still prints the original `sym_sizes`.
+absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
+  std::vector<int64_t> concrete_sizes_or_zero;
+  std::transform(sym_sizes.begin(), sym_sizes.end(),
+                 std::back_inserter(concrete_sizes_or_zero),
+                 [](at::SymInt sym) { return sym.maybe_as_int().value_or(0); });
+  return CheckFullSizesArePositiveImpl(concrete_sizes_or_zero, [&]() {
+    return absl::StrJoin(sym_sizes.begin(), sym_sizes.end(), /* sep= */ ", ",
+                         [](std::string* out, at::SymInt sym) {
+                           absl::StrAppendFormat(out, "%s",
+                                                 absl::FormatStreamed(sym));
+                         });
+  });
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1767,10 +1807,10 @@ XLATensorPtr fmod(const XLATensorPtr& input, const at::Scalar& other,
                                     logical_element_type);
 }
 
-XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
-                  const torch::lazy::BackendDevice& device,
-                  at::ScalarType scalar_type) {
-  CheckShapeDimensions(size);
+absl::StatusOr<XLATensorPtr> full(
+    absl::Span<const int64_t> size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) {
+  XLA_RETURN_IF_ERROR(CheckFullSizesArePositive(size));
   xla::Shape shape =
       MakeArrayShapeFromDimensions(size, /*dynamic_dimensions=*/{},
                                    MakeXlaPrimitiveType(scalar_type, &device),
@@ -1794,19 +1834,10 @@ XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
               device, *scalar_type);
 }
 
-XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
-                         const at::Scalar& fill_value,
-                         const torch::lazy::BackendDevice& device,
-                         at::ScalarType scalar_type) {
-  XLA_CHECK(std::all_of(sym_size.begin(), sym_size.end(), [](at::SymInt dim) {
-    // TODO: It should be OK to perform this test on symbolic ints too, not
-    // sure why you conditionalized it.
-    if (auto c = dim.maybe_as_int()) {
-      return *c >= 0;
-    }
-    return true;
-  })) << "Dimensions cannot be negative numbers";
-
+absl::StatusOr<XLATensorPtr> full_symint(
+    at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) {
+  XLA_RETURN_IF_ERROR(CheckFullConcreteSizesArePositive(sym_size));
   return XLATensor::Create(
       XLAGraphExecutor::Get()->GetIrValueForScalar(
           fill_value, MakeXlaPrimitiveType(scalar_type, &device), sym_size,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index fb7eae93f8d..869dcaa8dff 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -460,16 +460,15 @@ XLATensorPtr fmod(
     const XLATensorPtr& input, const at::Scalar& other,
     std::optional<at::ScalarType> logical_element_type = std::nullopt);
 
-XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
-                  const torch::lazy::BackendDevice& device,
-                  at::ScalarType scalar_type);
+absl::StatusOr<XLATensorPtr> full(
+    absl::Span<const int64_t> size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
 XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
                        const torch::lazy::BackendDevice& device,
                        std::optional<at::ScalarType> scalar_type);
-XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
-                         const at::Scalar& fill_value,
-                         const torch::lazy::BackendDevice& device,
-                         at::ScalarType scalar_type);
+absl::StatusOr<XLATensorPtr> full_symint(
+    at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
 XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
                     const XLATensorPtr& index);
 
diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp
index 2b925d7c381..edb7d22297c 100644
--- a/torch_xla/csrc/tensor_ops.cpp
+++ b/torch_xla/csrc/tensor_ops.cpp
@@ -207,16 +207,16 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
   int64_t numel = xla::ShapeUtil::ElementsIn(indices_shape_ref.get());
   XLATensorPtr grad =
       tensor_methods::view(grad_output, {numel, grad_output->size(-1)});
-  XLATensorPtr grad_weight =
+  XLATensorPtr grad_weight = GetValueOrThrow(
       tensor_methods::full({num_weights, grad_output->size(-1)}, 0,
-                           grad_output->GetDevice(), grad_output->dtype());
+                           grad_output->GetDevice(), grad_output->dtype()));
   XLATensorPtr indices_rank1 = tensor_methods::view(indices, {numel});
   if (scale_grad_by_freq) {
     // Compute the histogram of index values.
-    XLATensorPtr counts = tensor_methods::full(
-        {num_weights}, 0, indices->GetDevice(), indices->dtype());
-    XLATensorPtr ones = tensor_methods::full({numel}, 1, indices->GetDevice(),
-                                             indices->dtype());
+    XLATensorPtr counts = GetValueOrThrow(tensor_methods::full(
+        {num_weights}, 0, indices->GetDevice(), indices->dtype()));
+    XLATensorPtr ones = GetValueOrThrow(tensor_methods::full(
+        {numel}, 1, indices->GetDevice(), indices->dtype()));
     tensor_methods::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0,
                                /*values=*/ones, /*accumulate=*/true,
                                /*result_permutation=*/{0});