Description
`torch_xla/csrc/tensor_ops.cpp#L232` converts an `int64_t` `padding_idx` to `double`. This introduces FP64 ops even though `indices_rank1` is always `int32` or `int64`, while the upstream ATen implementation keeps everything in integer space.
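For illustration, here is a minimal standalone C++ sketch (illustrative only, with made-up variable names, not the actual torch_xla code path) of why the cast matters: once the scalar is widened to `double`, the comparison itself is promoted to FP64, which is what then gets lowered on the device.

```cpp
// Plain C++17 sketch -- illustrative only, not the torch_xla code path.
#include <cstdint>
#include <type_traits>

int main() {
  int64_t padding_idx = 3;  // stand-in for the padding index
  int32_t index = 3;        // stand-in for an element of indices_rank1

  // Mirrors the cast at tensor_ops.cpp#L232: the scalar becomes FP64.
  auto pad_as_double = static_cast<double>(padding_idx);

  // The usual arithmetic conversions then promote the whole comparison
  // to double, so the lowered compare is an FP64 op as well.
  static_assert(std::is_same_v<
                std::common_type_t<decltype(index), decltype(pad_as_double)>,
                double>);

  bool is_padding = (index == pad_as_double);  // FP64 comparison
  return is_padding ? 0 : 1;
}
```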
Questions
- Was the double cast introduced for a specific historical or hardware reason?
- How is this handled on devices without native FP64 support (e.g. TPU)?
Proposed fix
- Keep `padding_idx` as `int64_t`, or
- Cast it to the same dtype as `indices_rank1` before comparison.
Either option would avoid unnecessary FP64 operations and align with ATen’s behavior.
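As a rough sketch of the proposed direction (plain C++ with hypothetical names, not the actual torch_xla API), the mask can be built entirely in the indices' own integer dtype:

```cpp
// Plain C++17 sketch of the proposed fix -- hypothetical helper, not the
// torch_xla implementation: build the `!= padding_idx` mask in the
// indices' own integer dtype, so no FP64 op is ever emitted.
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename IndexT>
std::vector<bool> BuildPaddingMask(const std::vector<IndexT>& indices,
                                   int64_t padding_idx) {
  // Cast the scalar to the indices' dtype (option 2); for int64 indices
  // this is a no-op, which is equivalent to option 1.
  const IndexT pad = static_cast<IndexT>(padding_idx);
  std::vector<bool> mask(indices.size());
  for (std::size_t i = 0; i < indices.size(); ++i) {
    mask[i] = (indices[i] != pad);  // integer compare, no double involved
  }
  return mask;
}

int main() {
  std::vector<int32_t> indices = {0, 3, 1, 3, 2};
  auto mask = BuildPaddingMask(indices, /*padding_idx=*/3);
  return mask[1] ? 1 : 0;  // mask == {1, 0, 1, 0, 1}
}
```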
Thanks for taking a look!