-
Notifications
You must be signed in to change notification settings - Fork 5
Support skipping tracing of selected pure modules #308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
8074c58
wip
tengyifei b2e9778
Support configuring the pure modules
tengyifei d2db238
wip
tengyifei 145f767
solve EinsumLinear patch
tengyifei e40187c
Fix EinsumLinear replacement
tengyifei de44d77
update
tengyifei 58ba60c
Add test
tengyifei 224c7df
revert
tengyifei 757fda4
fix test
tengyifei ec87c92
Add E2E test
tengyifei 961d168
Fix step time bounds
tengyifei File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
torchprime/torch_xla_models/model_rewriting/assume_pure.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch.nn as nn | ||
from omegaconf import DictConfig | ||
from torch_xla.experimental.assume_pure import PureModule | ||
|
||
from torchprime.sharding.shard_model import wrap_module | ||
from torchprime.torch_xla_models.model_rewriting.rematerialization_utils import ( | ||
get_classes_by_names, | ||
) | ||
|
||
|
||
def mark_pure_modules(model: nn.Module, config: DictConfig) -> nn.Module: | ||
"""Wrap the requested modules in the module tree with `PureModule`. | ||
|
||
There are a few advantages of wrapping a module whose forward pass you know is | ||
free of side-effects and whose behavior only depends on inputs in a `PureModule`: | ||
|
||
- `PureModule`s will only be traced once. | ||
- Framework profile scopes added via `xp.Trace` will show up in both the forward | ||
and the backward pass. | ||
|
||
Args: | ||
model: Model to transform. | ||
config: Config with model.pure_modules settings. | ||
|
||
Returns: | ||
Transformed model. | ||
""" | ||
pure_module_config = config.model.pure_modules | ||
pure_module_classes = get_classes_by_names(model, pure_module_config) | ||
|
||
def transform(mod: nn.Module, _: str): | ||
if isinstance(mod, pure_module_classes): | ||
return PureModule(mod) | ||
return mod | ||
|
||
return wrap_module(model, transform) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch_xla | ||
from omegaconf import OmegaConf | ||
|
||
from torchprime.torch_xla_models.model_rewriting.assume_pure import ( | ||
PureModule, | ||
mark_pure_modules, | ||
) | ||
|
||
|
||
def test_nn_linear(): | ||
inputs = torch.randn((4,), device="xla") | ||
linear = nn.Linear(4, 8) | ||
linear = linear.to("xla") | ||
expected_output = linear(inputs) | ||
torch_xla.sync() | ||
pure_linear = PureModule(linear) | ||
actual_output = pure_linear(inputs) | ||
torch_xla.sync() | ||
torch.testing.assert_close(actual_output, expected_output) | ||
|
||
|
||
def test_rewrite(): | ||
# Arrange | ||
class Foo(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return input | ||
|
||
class Bar(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return input | ||
|
||
class Model(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.foo = Foo() | ||
self.bar = Bar() | ||
|
||
def forward(self, input): | ||
return self.foo(input) + self.bar(input) | ||
|
||
model = Model() | ||
config = OmegaConf.create( | ||
{ | ||
"model": { | ||
"pure_modules": ["Foo"], | ||
}, | ||
} | ||
) | ||
|
||
# Act | ||
model = mark_pure_modules(model, config) | ||
|
||
# Assert | ||
assert isinstance(model.foo, PureModule) | ||
assert not isinstance(model.bar, PureModule) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.