Draft: Yifeit/assume pure #308


Draft · wants to merge 9 commits into main from yifeit/assume-pure

Conversation

tengyifei (Collaborator)

No description provided.

@tengyifei tengyifei force-pushed the yifeit/assume-pure branch 2 times, most recently from cd77464 to d577a40 Compare June 14, 2025 02:07
def replace_nn_linear_with_einsum(module: torch.nn.Module, config: DictConfig):
"""Recursively replace `nn.Linear` layers with `EinsumLinear` in the module.

Without this patch, an `nn.Linear` module in PyTorch/XLA will lower to reshapes
Collaborator

I recall that XLA's einsum worked fine when used without a custom op. Can you point me to where the einsum becomes a custom op, and remind me why?

Do we need to use torch_xla.distributed.spmd.xla_sharding.apply_xla_patch_to_nn_linear, or could we use assume_pure on the linear module directly in this code? That question may be missing some basic knowledge of assume_pure; my apologies if so.
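For context, the recursive swap that replace_nn_linear_with_einsum performs can be sketched without PyTorch/XLA. This is a minimal stand-alone sketch, not the torchprime implementation: Module, Linear, and EinsumLinear below are hypothetical stubs mimicking the relevant parts of torch.nn.Module.

```python
class Module:
    """Stub of the parts of torch.nn.Module used by the replacement pass."""

    def __init__(self):
        self._children = {}

    def named_children(self):
        return list(self._children.items())

    def add_module(self, name, mod):
        self._children[name] = mod


class Linear(Module):
    """Stub for nn.Linear."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features


class EinsumLinear(Linear):
    """Stub replacement that would lower to an einsum instead of reshapes."""


def replace_linear_with_einsum(module):
    # Walk the module tree depth-first, swapping each plain Linear
    # for an EinsumLinear with the same shape configuration.
    for name, child in module.named_children():
        if type(child) is Linear:
            module.add_module(name, EinsumLinear(child.in_features, child.out_features))
        else:
            replace_linear_with_einsum(child)


root = Module()
root.add_module("proj", Linear(4, 8))
block = Module()
block.add_module("fc", Linear(8, 8))
root.add_module("block", block)
replace_linear_with_einsum(root)
```

The real pass additionally copies weights and respects the config argument; the sketch only shows the traversal-and-replace structure.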

tengyifei added a commit to pytorch/xla that referenced this pull request Jun 16, 2025
- Support nested tuples in `assume_pure(mark_sharding)`
- Add a `PureModule` from AI-Hypercomputer/torchprime#308
- Support `PureModule(EinsumLinear)` which uses `torch.ops.xla.einsum_linear_forward`
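The PureModule idea referenced in the commit above can be illustrated with a small sketch: wrap a module so that its forward pass is expressed as a closed function of (parameters, inputs), which a tracer may then assume is side-effect free. Everything here is a stand-in — assume_pure below merely forwards the call, whereas the real torch_xla version traces the function under a purity assumption — and Scale is a hypothetical toy module.

```python
def assume_pure(fn):
    # Stand-in: the real torch_xla assume_pure traces `fn` assuming it
    # has no side effects; here we just pass the call through.
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped


class PureModule:
    """Wraps a module so its forward runs as a pure function of
    (params, inputs) rather than closing over mutable state."""

    def __init__(self, module):
        self._module = module
        self._pure_forward = assume_pure(self._functional_forward)

    def _functional_forward(self, params, x):
        # Parameters are passed explicitly, so the traced function is
        # fully determined by its arguments.
        return self._module.forward_with_params(params, x)

    def __call__(self, x):
        return self._pure_forward(self._module.params(), x)


class Scale:
    """Toy module: multiplies its input by a stored factor."""

    def __init__(self, factor):
        self._factor = factor

    def params(self):
        return {"factor": self._factor}

    def forward_with_params(self, params, x):
        return params["factor"] * x


m = PureModule(Scale(3))
```

Calling `m(2)` routes through the purity-assuming wrapper and returns the same result as the unwrapped module, which is the contract PureModule(EinsumLinear) relies on.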
@tengyifei tengyifei force-pushed the yifeit/assume-pure branch from 574f142 to c8ba416 Compare June 17, 2025 08:21