Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,6 +2511,16 @@ def test_flip_raises_error_on_duplicated_dims(self):
f"from {dims} to {dims_suggestion}.")
self.assertEqual(str(e), expected_error)

def test_full_raises_error_on_negative_size(self):
shape = [2, -2, 2]
try:
torch.full(shape, 1.5, device="xla")
except RuntimeError as e:
expected_error = (
"full(): expected concrete sizes (i.e. non-symbolic) to be "
f"positive values. However found negative ones: {shape}.")
self.assertEqual(str(e), expected_error)


class MNISTComparator(nn.Module):

Expand Down
38 changes: 18 additions & 20 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1702,16 +1702,14 @@ at::Tensor XLANativeFunctions::empty_symint(
// does not actually end up doing any memory initialization, we use that and
// avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
// s_copy_().
XLATensorPtr xla_tensor;
if (all_dims_static) {
xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype));
} else {
xla_tensor =
tensor_methods::full_symint(sym_size, 0, GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype));
}
XLATensorPtr xla_tensor = GetValueOrThrow(
all_dims_static
? tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype))
: tensor_methods::full_symint(sym_size, 0,
GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype)));
// `tensor.to` will trigger an `empty` + `_to_copy`. In the egaer mode, the
// `full` will be evulated eagerly and got a replicated sharding. We should
// leave the sharding to be empty.
Expand Down Expand Up @@ -1858,9 +1856,9 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
} else {
intend_dtype = fill_value.type();
}
return bridge::AtenFromXlaTensor(
return bridge::AtenFromXlaTensor(GetValueOrThrow(
tensor_methods::full(absl::Span<const int64_t>(size), fill_value,
GetXlaDeviceOrCurrent(device), intend_dtype));
GetXlaDeviceOrCurrent(device), intend_dtype)));
}

at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
Expand Down Expand Up @@ -2681,8 +2679,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss2d_forward(
int64_t ignore_index) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
XLATensorPtr total_weight = tensor_methods::full(
{}, 1, self_tensor->GetDevice(), self_tensor->dtype());
XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full(
{}, 1, self_tensor->GetDevice(), self_tensor->dtype()));
return std::make_tuple(
bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d(
self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)),
Expand Down Expand Up @@ -2716,8 +2714,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss_forward(
int64_t ignore_index) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
XLATensorPtr total_weight = tensor_methods::full(
{}, 1, self_tensor->GetDevice(), self_tensor->dtype());
XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full(
{}, 1, self_tensor->GetDevice(), self_tensor->dtype()));
return std::make_tuple(
bridge::AtenFromXlaTensor(tensor_methods::nll_loss(
self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)),
Expand Down Expand Up @@ -4038,10 +4036,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_linalg_svd(
if (!compute_uv) {
// When compute_uv is false, torch::_linalg_svd returns an empty tensor for
// u and vh.
u = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
self_tensor->dtype());
vh = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
self_tensor->dtype());
u = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
self_tensor->dtype()));
vh = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
self_tensor->dtype()));
}
return std::make_tuple(bridge::AtenFromXlaTensor(u),
bridge::AtenFromXlaTensor(s),
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/index_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base,
base_dimensions.begin() + start_dim + indices.size(),
base_dimensions.end());

return tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype());
return GetValueOrThrow(
tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype()));
}

XLATensorPtr IndexByTensors(const XLATensorPtr& base,
Expand Down
58 changes: 41 additions & 17 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,39 @@ absl::Status CheckFlipDimensionsAreUnique(
return absl::OkStatus();
}

template <class F>
absl::Status CheckFullSizesArePositiveImpl(absl::Span<const int64_t> sizes,
const F& original_sizes_as_str) {
const bool has_concrete_negative_size = std::any_of(
sizes.begin(), sizes.end(), [](const int64_t size) { return size < 0; });
if (has_concrete_negative_size) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat("full(): expected concrete sizes (i.e. non-symbolic) to "
"be positive values. However found negative ones: [",
original_sizes_as_str(), "].")));
}
return absl::OkStatus();
}

absl::Status CheckFullSizesArePositive(absl::Span<const int64_t> sizes) {
return CheckFullSizesArePositiveImpl(
sizes, [&]() { return absl::StrJoin(sizes, /* sep= */ ", "); });
}

absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
std::vector<int64_t> concrete_sizes_or_zero;
std::transform(sym_sizes.begin(), sym_sizes.end(),
std::back_inserter(concrete_sizes_or_zero),
[](at::SymInt sym) { return sym.maybe_as_int().value_or(0); });
return CheckFullSizesArePositiveImpl(concrete_sizes_or_zero, [&]() {
return absl::StrJoin(sym_sizes.begin(), sym_sizes.end(), /* sep= */ ", ",
[](std::string* out, at::SymInt sym) {
absl::StrAppendFormat(out, "%s",
absl::FormatStreamed(sym));
});
});
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1767,10 +1800,10 @@ XLATensorPtr fmod(const XLATensorPtr& input, const at::Scalar& other,
logical_element_type);
}

XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type) {
CheckShapeDimensions(size);
absl::StatusOr<absl_nonnull XLATensorPtr> full(
absl::Span<const int64_t> size, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) {
XLA_RETURN_IF_ERROR(CheckFullSizesArePositive(size));
xla::Shape shape =
MakeArrayShapeFromDimensions(size, /*dynamic_dimensions=*/{},
MakeXlaPrimitiveType(scalar_type, &device),
Expand All @@ -1794,19 +1827,10 @@ XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
device, *scalar_type);
}

XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type) {
XLA_CHECK(std::all_of(sym_size.begin(), sym_size.end(), [](at::SymInt dim) {
// TODO: It should be OK to perform this test on symbolic ints too, not
// sure why you conditionalized it.
if (auto c = dim.maybe_as_int()) {
return *c >= 0;
}
return true;
})) << "Dimensions cannot be negative numbers";

absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) {
XLA_RETURN_IF_ERROR(CheckFullConcreteSizesArePositive(sym_size));
return XLATensor::Create(
XLAGraphExecutor::Get()->GetIrValueForScalar(
fill_value, MakeXlaPrimitiveType(scalar_type, &device), sym_size,
Expand Down
13 changes: 6 additions & 7 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,16 +460,15 @@ XLATensorPtr fmod(
const XLATensorPtr& input, const at::Scalar& other,
std::optional<at::ScalarType> logical_element_type = std::nullopt);

XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type);
absl::StatusOr<absl_nonnull XLATensorPtr> full(
absl::Span<const int64_t> size, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device,
std::optional<at::ScalarType> scalar_type);
XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type);
absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);

XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
const XLATensorPtr& index);
Expand Down
12 changes: 6 additions & 6 deletions torch_xla/csrc/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,16 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
int64_t numel = xla::ShapeUtil::ElementsIn(indices_shape_ref.get());
XLATensorPtr grad =
tensor_methods::view(grad_output, {numel, grad_output->size(-1)});
XLATensorPtr grad_weight =
XLATensorPtr grad_weight = GetValueOrThrow(
tensor_methods::full({num_weights, grad_output->size(-1)}, 0,
grad_output->GetDevice(), grad_output->dtype());
grad_output->GetDevice(), grad_output->dtype()));
XLATensorPtr indices_rank1 = tensor_methods::view(indices, {numel});
if (scale_grad_by_freq) {
// Compute the histogram of index values.
XLATensorPtr counts = tensor_methods::full(
{num_weights}, 0, indices->GetDevice(), indices->dtype());
XLATensorPtr ones = tensor_methods::full({numel}, 1, indices->GetDevice(),
indices->dtype());
XLATensorPtr counts = GetValueOrThrow(tensor_methods::full(
{num_weights}, 0, indices->GetDevice(), indices->dtype()));
XLATensorPtr ones = GetValueOrThrow(tensor_methods::full(
{numel}, 1, indices->GetDevice(), indices->dtype()));
tensor_methods::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0,
/*values=*/ones,
/*accumulate=*/true, /*result_permutation=*/{0});
Expand Down