Skip to content

Commit 6f68495

Browse files
committed
EmbeddingDenseBackward: Remove padding_idx double cast
Fixes #9392
1 parent 7e3efc5 commit 6f68495

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch_xla/csrc/tensor_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
229229
// Don't accumulate gradients for indices which are equal with the given
230230
// padding_idx.
231231
XLATensorPtr skip_padding = tensor_methods::unsqueeze(
232-
tensor_methods::ne(indices_rank1, static_cast<double>(padding_idx)), 1);
232+
tensor_methods::ne(indices_rank1, padding_idx), 1);
233233
skip_padding = tensor_methods::expand(
234234
skip_padding,
235235
torch::lazy::ToVector<int64_t>(grad->shape().get().dimensions()));

0 commit comments

Comments
 (0)