XLAShardedTensor.to_local() support #9505
base: master
Conversation
Force-pushed from f0c89b9 to 933a964: "…ure proper XLA support and maintain consistency with PyTorch/XLA SPMD integration."
Force-pushed from 812a69a to 6dc2351.
```python
# Since global tensor is detached, add requires_grad and grad values back to the local tensor
if self.requires_grad:
    result.requires_grad = self.requires_grad
    result.grad = self.grad
```
The grad doesn't need to be cloned? Is it fine that we break the reference for the tensor, but not for its grad? See if we can add a test showing that this is the case (e.g. both the prior and the newer tensor updating the same grad).
Updated to clone `self.grad` if available.
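A minimal sketch of the aliasing concern the reviewer raised, using plain CPU tensors in place of `XLAShardedTensor` (the variable names here are illustrative, not the PR's code): without a clone, the original and the copy share the same grad storage, so mutating one mutates the other.

```python
import torch

# Without clone: the copy's grad aliases the original's grad
t = torch.ones(3, requires_grad=True)
t.sum().backward()                                  # t.grad == [1., 1., 1.]
shared = t.detach().clone().requires_grad_(True)
shared.grad = t.grad                                # aliased reference
shared.grad.add_(1.0)
assert torch.equal(t.grad, torch.full((3,), 2.0))   # original mutated too

# With clone (as in the updated PR): grads are independent
t2 = torch.ones(3, requires_grad=True)
t2.sum().backward()
independent = t2.detach().clone().requires_grad_(True)
independent.grad = t2.grad.clone()
independent.grad.add_(1.0)
assert torch.equal(t2.grad, torch.ones(3))          # original unchanged
```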
```python
result = self.global_tensor.clone()
# Since global tensor is detached, add requires_grad and grad values back to the local tensor
if self.requires_grad:
    result.requires_grad = self.requires_grad
```
nit: `result.requires_grad_(self.requires_grad)` for in-place.
Done!
The implementation adds a `to_local()` method to the `XLAShardedTensor` class that converts a sharded tensor back to its local representation while preserving gradient information. Tests verify that:
- the `requires_grad` property is preserved
- gradients are properly calculated and maintained
- the backward pass works correctly through the local tensor
- gradient values are accurately preserved
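The behavior described above can be sketched with plain CPU tensors standing in for the sharded tensor (the function name `to_local_sketch` is hypothetical and only illustrates the clone-and-restore pattern discussed in this review, not the PR's actual implementation):

```python
import torch

# Hypothetical sketch: clone the (detached) global tensor and restore
# requires_grad plus a cloned grad, as agreed in the review above.
def to_local_sketch(global_tensor, requires_grad, grad):
    result = global_tensor.clone()
    if requires_grad:
        result.requires_grad_(requires_grad)   # in-place, per review nit
        result.grad = grad.clone() if grad is not None else None
    return result

t = torch.ones(2, 2, requires_grad=True)
(t * 3).sum().backward()                       # t.grad is all 3s
local = to_local_sketch(t.detach(), t.requires_grad, t.grad)

assert local.requires_grad                     # requires_grad preserved
assert torch.equal(local.grad, t.grad)         # grad values preserved
assert local.grad is not t.grad                # independent grad storage

(local * 2).sum().backward()                   # backward works on local
assert torch.equal(local.grad, torch.full((2, 2), 5.0))  # 3 + 2 accumulated
```

Note that autograd accumulates into the pre-existing `.grad`, which is why the final gradient is 5 rather than 2.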