From 6f6849581a9ac0725c0623b7f0a1236f1ab461c8 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin
Date: Wed, 25 Jun 2025 11:03:47 +0000
Subject: [PATCH 1/3] EmbeddingDenseBackward: Remove `padding_idx` double cast

Fixes #9392
---
 torch_xla/csrc/tensor_ops.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp
index 8f1b17799e10..ad8d58e13727 100644
--- a/torch_xla/csrc/tensor_ops.cpp
+++ b/torch_xla/csrc/tensor_ops.cpp
@@ -229,7 +229,7 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
   // Don't accumulate gradients for indices which are equal with the given
   // padding_idx.
   XLATensorPtr skip_padding = tensor_methods::unsqueeze(
-      tensor_methods::ne(indices_rank1, static_cast<double>(padding_idx)), 1);
+      tensor_methods::ne(indices_rank1, padding_idx), 1);
   skip_padding = tensor_methods::expand(
       skip_padding,
       torch::lazy::ToVector<int64_t>(grad->shape().get().dimensions()));

From 26e0a2c117493d89fc582abe1a3290bf93081438 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin
Date: Thu, 26 Jun 2025 21:49:35 +0000
Subject: [PATCH 2/3] add test

---
 test/test_operations.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/test/test_operations.py b/test/test_operations.py
index 3f6774e87413..656a361d3ef6 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -752,6 +752,28 @@ def test_rrelu_module(self):
     xla_output.sum().backward()
     self.assertEqual(a.grad, xla_a.grad.cpu())
 
+  def test_embedding_module(self):
+    num_embeddings = 16
+    embed_dim = 4
+    input_shape = (2, 3) 
+
+    xla_device = torch_xla.device()
+
+    idx = torch.randint(0, num_embeddings, input_shape, dtype=torch.long)
+    xla_idx = idx.to(xla_device)
+
+    m = nn.Embedding(num_embeddings, embed_dim)
+    xla_m = nn.Embedding(num_embeddings, embed_dim).to(xla_device)
+    xla_m.weight.data.copy_(m.weight.data) # keep parameters in sync
+
+    output = m(idx)
+    xla_output = xla_m(xla_idx) 
+    self.assertEqual(output, xla_output.cpu())
+
+    output.sum().backward()
+    xla_output.sum().backward()
+    self.assertEqual(m.weight.grad, xla_m.weight.grad.cpu())
+
   def test_max_broadcast(self):
     xla_device = torch_xla.device()
     a = torch.rand(3, 1, 2)

From 9022b1a197ea998e490299c0c0b925994dbfc1d8 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin
Date: Thu, 26 Jun 2025 22:42:35 +0000
Subject: [PATCH 3/3] fix style

---
 test/test_operations.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index 656a361d3ef6..7b52d6c9fa8d 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -755,19 +755,20 @@ def test_rrelu_module(self):
   def test_embedding_module(self):
     num_embeddings = 16
     embed_dim = 4
-    input_shape = (2, 3) 
+    input_shape = (2, 3)
 
     xla_device = torch_xla.device()
 
     idx = torch.randint(0, num_embeddings, input_shape, dtype=torch.long)
     xla_idx = idx.to(xla_device)
 
     m = nn.Embedding(num_embeddings, embed_dim)
     xla_m = nn.Embedding(num_embeddings, embed_dim).to(xla_device)
-    xla_m.weight.data.copy_(m.weight.data) # keep parameters in sync
+    # keep parameters in sync
+    xla_m.weight.data.copy_(m.weight.data)
 
     output = m(idx)
-    xla_output = xla_m(xla_idx) 
+    xla_output = xla_m(xla_idx)
     self.assertEqual(output, xla_output.cpu())
 
     output.sum().backward()
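
A standalone repro sketch to go with the series. The committed test_embedding_module does not pass padding_idx, while the comparison fixed in PATCH 1/3 (`tensor_methods::ne(indices_rank1, padding_idx)`) only does meaningful work when the embedding has a real padding index, so the sketch below supplies one. padding_idx=0 and the literal index values are illustrative assumptions, not taken from the patches:

    import torch
    import torch.nn as nn
    import torch_xla

    xla_device = torch_xla.device()

    # Same layer on CPU and XLA, with row 0 reserved as padding.
    m = nn.Embedding(16, 4, padding_idx=0)
    xla_m = nn.Embedding(16, 4, padding_idx=0).to(xla_device)
    xla_m.weight.data.copy_(m.weight.data)  # keep parameters in sync

    # Long indices that include the padding index.
    idx = torch.tensor([[0, 1, 2], [3, 0, 5]], dtype=torch.long)

    out = m(idx)
    xla_out = xla_m(idx.to(xla_device))

    out.sum().backward()
    xla_out.sum().backward()

    # The padding row must receive zero gradient, and XLA must agree with CPU.
    assert m.weight.grad[0].abs().sum().item() == 0
    assert torch.allclose(m.weight.grad, xla_m.weight.grad.cpu())

With the series applied, both assertions should hold on an XLA build; the final gradient comparison mirrors the one added in PATCH 2/3.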