diff --git a/test/test_operations.py b/test/test_operations.py
index 3f6774e8741..7b52d6c9fa8 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -752,6 +752,29 @@ def test_rrelu_module(self):
     xla_output.sum().backward()
     self.assertEqual(a.grad, xla_a.grad.cpu())
 
+  def test_embedding_module(self):
+    num_embeddings = 16
+    embed_dim = 4
+    input_shape = (2, 3)
+
+    xla_device = torch_xla.device()
+
+    idx = torch.randint(0, num_embeddings, input_shape, dtype=torch.long)
+    xla_idx = idx.to(xla_device)
+
+    m = nn.Embedding(num_embeddings, embed_dim)
+    xla_m = nn.Embedding(num_embeddings, embed_dim).to(xla_device)
+    # keep parameters in sync
+    xla_m.weight.data.copy_(m.weight.data)
+
+    output = m(idx)
+    xla_output = xla_m(xla_idx)
+    self.assertEqual(output, xla_output.cpu())
+
+    output.sum().backward()
+    xla_output.sum().backward()
+    self.assertEqual(m.weight.grad, xla_m.weight.grad.cpu())
+
   def test_max_broadcast(self):
     xla_device = torch_xla.device()
     a = torch.rand(3, 1, 2)
diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp
index 8f1b17799e1..ad8d58e1372 100644
--- a/torch_xla/csrc/tensor_ops.cpp
+++ b/torch_xla/csrc/tensor_ops.cpp
@@ -229,7 +229,7 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
   // Don't accumulate gradients for indices which are equal with the given
   // padding_idx.
   XLATensorPtr skip_padding = tensor_methods::unsqueeze(
-      tensor_methods::ne(indices_rank1, static_cast(padding_idx)), 1);
+      tensor_methods::ne(indices_rank1, padding_idx), 1);
   skip_padding = tensor_methods::expand(
       skip_padding, torch::lazy::ToVector(grad->shape().get().dimensions()));