
Unnecessary FP64 cast for padding_idx in EmbeddingDenseBackward #9392

Open
@unterumarmung

Description


In torch_xla/csrc/tensor_ops.cpp #L232, the int64_t padding_idx is cast to double before being compared against the indices. This introduces FP64 operations into the lowering even though indices_rank1 is always int32 or int64; the upstream ATen implementation keeps this comparison entirely in integer space.

Questions

  • Was the double cast introduced for a specific historical or hardware reason?
  • How is this handled on devices without native FP64 support (e.g. TPU)?

Proposed fix

  1. Keep padding_idx as int64_t, or
  2. Cast it to the same dtype as indices_rank1 before comparison.

Either option would avoid unnecessary FP64 operations and align with ATen’s behavior.

Thanks for taking a look!

Metadata



    Labels

    enhancement (New feature or request), lowering (ATen Operation lowering)
