XLAShardedTensor.to_local() support #9505


Open
wants to merge 12 commits into master from toLocal_wspec

Conversation

Hoomaaan (Contributor)

The implementation adds a to_local() method to the XLAShardedTensor class that converts a sharded tensor back to its local representation while preserving gradient information. A sketch of the method's overall shape follows the list below.

  1. Core Functionality:
  • Returns the global tensor representation containing the combined data across all devices
  • Maintains the same device placement as the original XLAShardedTensor
  • Creates a clone of the global tensor to ensure data independence
  2. Gradient Handling:
  • Preserves the requires_grad setting from the original tensor
  • Maintains gradient values when converting to the local representation
  • Ensures proper gradient flow through the converted tensor
  3. Test Coverage:
  • Basic functionality test verifying shape and value preservation
  • Dedicated gradient flow test ensuring:
    • the requires_grad property is preserved
    • gradients are properly calculated and maintained
    • the backward pass works correctly through the local tensor
    • gradient values are accurately preserved
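A minimal sketch of the method, assembled from the snippets quoted in the review threads below; self.global_tensor is the attribute shown in the diff, and the rest is illustrative rather than the exact merged code:

def to_local(self):
    # Clone the combined global tensor so the returned tensor does not
    # share storage with the sharded tensor.
    result = self.global_tensor.clone()
    # The global tensor is detached, so restore the autograd state explicitly.
    if self.requires_grad:
        result.requires_grad_(self.requires_grad)
        # Clone the gradient as well so the two tensors do not share .grad
        # (added after review; see the thread below).
        if self.grad is not None:
            result.grad = self.grad.clone()
    return result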

Hoomaaan force-pushed the toLocal_wspec branch 2 times, most recently from f0c89b9 to 933a964 on July 24, 2025 at 20:39
Hoomaaan requested a review from bfolie on August 13, 2025 at 00:30
# Since global tensor is detached, add requires_grad and grad values back to the local tensor
if self.requires_grad:
    result.requires_grad = self.requires_grad
    result.grad = self.grad
Collaborator
The grad doesn't need to be cloned? Is it fine that we break the reference for the tensor but not its grad? See if we can add a test that shows whether that is the case (e.g. both the prior and newer tensor updating the same grad).
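A hedged sketch of the kind of check being asked for here; sharded stands in for an XLAShardedTensor whose .grad has already been populated by a backward pass, and the helper name is illustrative only, not the PR's actual test:

import torch

def grads_are_independent(sharded):
    # Convert to the local representation and snapshot the sharded
    # tensor's gradient before mutating anything.
    local = sharded.to_local()
    before = sharded.grad.clone()
    # Mutate the local tensor's gradient in place.
    local.grad.add_(1.0)
    # If to_local() copied the .grad reference without cloning it, the
    # in-place update above would also be visible on the sharded tensor.
    return torch.equal(sharded.grad, before)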

Contributor Author

Updated to clone self.grad when it is available.
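Presumably along these lines (a sketch of the follow-up, not the exact committed diff):

# Clone the gradient only when one exists, so the local tensor's .grad
# does not alias the sharded tensor's .grad.
if self.grad is not None:
    result.grad = self.grad.clone()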

result = self.global_tensor.clone()
# Since global tensor is detached, add requires_grad and grad values back to the local tensor
if self.requires_grad:
    result.requires_grad = self.requires_grad
Collaborator
nit: result.requires_grad_(self.requires_grad) for in-place.
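For reference, the two forms behave the same on a leaf tensor; requires_grad_() additionally returns the tensor, which allows chaining. Standalone example, unrelated to the PR's actual code:

import torch

t = torch.zeros(2, 3)
t.requires_grad = True                      # attribute assignment
u = torch.zeros(2, 3).requires_grad_(True)  # in-place setter, returns the tensor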

Contributor Author

Done!
