Support skipping tracing of selected pure modules #308

Merged (11 commits) on Jun 18, 2025
24 changes: 24 additions & 0 deletions .github/workflows/e2e_test.yml
@@ -22,6 +22,7 @@ jobs:
ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
outputs:
llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
llama-3-8b-pure-mlp-name: ${{ steps.run-llama-3-8b-pure-mlp.outputs.name }}
llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
@@ -83,6 +84,28 @@ jobs:
ici_mesh.fsdp=4 \
profile_start_step=3

- name: Run Llama 3.0 8B (@assume_pure)
id: run-llama-3-8b-pure-mlp
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-pure-mlp)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3-8b \
dataset=wikitext \
task=train \
task.global_batch_size=8 \
task.lr_scheduler.type=constant \
task.max_steps=15 \
ici_mesh.fsdp=4 \
profile_start_step=3 \
model.pure_modules=[LlamaMLP,EinsumLinear]

- name: Run Llama 3.1 8B (Splash Attention)
id: run-llama-3_1-8b-SplashAttention
env:
@@ -259,6 +282,7 @@ jobs:
jobset_name: >-
${{
matrix.config.benchmark == 'llama-3-8b' && needs.tp-run.outputs.llama-3-8b-name ||
matrix.config.benchmark == 'llama-3-8b-pure-mlp' && needs.tp-run.outputs.llama-3-8b-pure-mlp-name ||
matrix.config.benchmark == 'llama-3_1-8b-sa' && needs.tp-run.outputs.llama-3_1-8b-sa-name ||
matrix.config.benchmark == 'llama-3_1-8b-scan-offload' && needs.tp-run.outputs.llama-3_1-8b-scan-offload-name ||
matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
9 changes: 9 additions & 0 deletions e2e_testing/step_time_bounds.yaml
@@ -6,6 +6,15 @@ benchmarks:
confidence_interval: 0.05407
average: 2.7352
sample_size: 427
llama-3-8b-pure-mlp:
name: Llama 3.0 8B (@assume_pure)
# Bounds are copied from `llama-3-8b`. They will be overwritten the next time
# somebody runs `e2e_testing/update_step_time.py`.
step_time_lower_bound: 2.68109009
step_time_upper_bound: 2.789223
confidence_interval: 0.05407
average: 2.7352
sample_size: 1
llama-3_1-8b-sa:
name: Llama 3.1 8B (Splash Attention)
step_time_lower_bound: 2.34653077
19 changes: 19 additions & 0 deletions e2e_testing/update_step_time.py
@@ -29,6 +29,23 @@ def match_llama3_8b(row):
and config["dcn_mesh"]["data"] == 1
and config["dcn_mesh"]["fsdp"] == 1
and config["ici_mesh"]["tensor"] == 1
and (
"pure_modules" not in config["model"] or len(config["model"]["pure_modules"]) == 0
)
)


def match_llama3_8b_pure_mlp(row):
config = json.loads(row.configs_framework)
return (
row.run_id.startswith("llama-3-8b-pure-mlp")
and config["dcn_mesh"]["data"] == 1
and config["dcn_mesh"]["fsdp"] == 1
and config["ici_mesh"]["tensor"] == 1
and (
"pure_modules" in config["model"]
and config["model"]["pure_modules"] == ["LlamaMLP", "EinsumLinear"]
)
)


@@ -86,6 +103,7 @@ def match_llama_3_8b_ddp_fsdp(row):

BENCHMARKS = {
"Llama 3.0 8B": match_llama3_8b,
"Llama 3.0 8B (@assume_pure)": match_llama3_8b_pure_mlp,
"Llama 3.1 8B (Splash Attention)": match_llama3_1_8b_sa,
"Llama 3.1 8B (Scan + Offload)": match_llama3_1_8b_scan_offload,
"Llama 3.0 8B (2D sharding)": match_llama3_8b_2d,
@@ -96,6 +114,7 @@ def match_llama_3_8b_ddp_fsdp(row):

STEP_ID_MAPPING = {
"Llama 3.0 8B": "llama-3-8b",
"Llama 3.0 8B (@assume_pure)": "llama-3-8b-pure-mlp",
"Llama 3.1 8B (Splash Attention)": "llama-3_1-8b-sa",
"Llama 3.1 8B (Scan + Offload)": "llama-3_1-8b-scan-offload",
"Llama 3.0 8B (2D sharding)": "llama-3-8b-2d",
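As a side note on the matchers above: run IDs produced by the new e2e job start with `llama-3-8b-pure-mlp`, which shares the `llama-3-8b` prefix, so the `pure_modules` guard added to `match_llama3_8b` is what keeps the two benchmark buckets disjoint. Below is a small, hypothetical sanity check, not part of the diff; it assumes `update_step_time` can be imported from the repo root and fakes only the two row fields the new matcher reads.

```python
# Hypothetical sanity check (not part of the diff). `SimpleNamespace` stands in
# for a result row; only `run_id` and `configs_framework` are populated because
# those are the only fields `match_llama3_8b_pure_mlp` reads.
import json
from types import SimpleNamespace

# Assumes e2e_testing/update_step_time.py is importable from the repo root.
from e2e_testing.update_step_time import match_llama3_8b_pure_mlp

row = SimpleNamespace(
    run_id="llama-3-8b-pure-mlp-20250618-1234",  # made-up run name
    configs_framework=json.dumps(
        {
            "dcn_mesh": {"data": 1, "fsdp": 1},
            "ici_mesh": {"tensor": 1},
            "model": {"pure_modules": ["LlamaMLP", "EinsumLinear"]},
        }
    ),
)

# Matches the new benchmark; the non-empty `pure_modules` list is also what the
# guard added to `match_llama3_8b` uses to exclude this row from the plain bucket.
assert match_llama3_8b_pure_mlp(row)
```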
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -47,7 +47,7 @@ dev = [
tp = "torchprime.launcher.cli:cli"

[tool.torchprime]
torch_xla_version = "20250606"
torch_xla_version = "20250617"

[tool.setuptools.packages.find]
where = [""]
4 changes: 2 additions & 2 deletions torchprime/sharding/shard_model.py
@@ -186,7 +186,7 @@ def shard_torchax_model_from_config(
"""
import jax
from jax.sharding import NamedSharding, PartitionSpec
from torchax.interop import torch_view
from torchax.interop import jax_view, torch_view

jax_mark_sharding = torch_view(jax.lax.with_sharding_constraint)

@@ -197,7 +197,7 @@ def shard_param(tensor, spec: tuple[str, ...]):
# and models are usually constructed eagerly in torchax.
return torch_view(
jax.make_array_from_callback(
tensor.shape, sharding, lambda slice_index: tensor[slice_index]
tensor.shape, sharding, lambda slice_index: jax_view(tensor[slice_index])
)
)

3 changes: 2 additions & 1 deletion torchprime/sharding/tests/test_shard_model.py
@@ -239,7 +239,8 @@ def test_shard_model_from_config_torchax():
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, ("fsdp",))

model = shard_torchax_model_from_config(model, config, mesh)
with torchax.default_env():
model = shard_torchax_model_from_config(model, config, mesh)

# In order to shard activations, corresponding modules are
# wrapped with ShardedModule.
10 changes: 10 additions & 0 deletions torchprime/torch_xla_models/configs/default.yaml
@@ -57,6 +57,16 @@ dcn_mesh:
# They can be overridden on the command line or by importing one of the presets
# in the `model/remat` directory.
model:
# Names of classes in the module tree that are functionally pure.
#
# Wrapping a module in a `PureModule` has a few advantages when its forward pass
# is free of side effects and its behavior depends only on its inputs:
# - `PureModule`s will only be traced once.
# - Framework profile scopes added via `xp.Trace` will show up in both the forward
# and the backward pass.
pure_modules: []

# Options for controlling tensor rematerialization.
remat:
# The class names of model layers whose intermediate activations should be
# recomputed during the backward pass (i.e. activation checkpointing).
36 changes: 36 additions & 0 deletions torchprime/torch_xla_models/model_rewriting/assume_pure.py
@@ -0,0 +1,36 @@
import torch.nn as nn
from omegaconf import DictConfig
from torch_xla.experimental.assume_pure import PureModule

from torchprime.sharding.shard_model import wrap_module
from torchprime.torch_xla_models.model_rewriting.rematerialization_utils import (
get_classes_by_names,
)


def mark_pure_modules(model: nn.Module, config: DictConfig) -> nn.Module:
"""Wrap the requested modules in the module tree with `PureModule`.

Wrapping a module in a `PureModule` has a few advantages when its forward pass
is free of side effects and its behavior depends only on its inputs:

- `PureModule`s will only be traced once.
- Framework profile scopes added via `xp.Trace` will show up in both the forward
and the backward pass.

Args:
model: Model to transform.
config: Config with model.pure_modules settings.

Returns:
Transformed model.
"""
pure_module_config = config.model.pure_modules
pure_module_classes = get_classes_by_names(model, pure_module_config)

def transform(mod: nn.Module, _: str):
if isinstance(mod, pure_module_classes):
return PureModule(mod)
return mod

return wrap_module(model, transform)
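
To make the new `model.pure_modules` knob concrete, here is a minimal usage sketch, not part of the diff and in the spirit of the unit test added below. `Mlp` and `Toy` are made-up classes for illustration; only submodules whose class name is listed get wrapped in `PureModule`, and everything else is left untouched.

```python
# Illustrative only: `Mlp` and `Toy` are made-up classes; the config mirrors the
# `model.pure_modules` entry added to configs/default.yaml.
import torch.nn as nn
from omegaconf import OmegaConf
from torch_xla.experimental.assume_pure import PureModule

from torchprime.torch_xla_models.model_rewriting.assume_pure import mark_pure_modules


class Mlp(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = Mlp()
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.mlp(x))


config = OmegaConf.create({"model": {"pure_modules": ["Mlp"]}})
model = mark_pure_modules(Toy(), config)

assert isinstance(model.mlp, PureModule)       # wrapped: class name is listed
assert not isinstance(model.head, PureModule)  # untouched: not listed
```

In the trainer, the same transform runs right after sharding is set up and before activation checkpointing and optimization barriers are applied (see the `base_trainer.py` change below).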
torchprime/torch_xla_models/model_rewriting/rematerialization_utils.py
@@ -51,7 +51,7 @@ def add_activation_checkpointing_and_scan(
NotImplementedError: If checkpointed layer does not match scanned layer.
"""
remat_config = config.model.remat
remat_classes = _get_classes_by_names(
remat_classes = get_classes_by_names(
model, remat_config.get("activation_checkpoint_layers", [])
)
layers_to_scan = remat_config.get("scan_layers", None)
@@ -112,7 +112,7 @@ def add_optimization_barriers(model: nn.Module, config: DictConfig) -> nn.Module
Modified model with optimization barriers.
"""
remat_config = config.model.remat
classes = _get_classes_by_names(
classes = get_classes_by_names(
model, remat_config.get("optimization_barrier_layers", [])
)
if not classes:
@@ -128,7 +128,7 @@ def maybe_add_barrier(mod: nn.Module, _name: str) -> nn.Module:
return wrap_module(model, maybe_add_barrier)


def _get_classes_by_names(
def get_classes_by_names(
model: nn.Module, class_names: list[str]
) -> tuple[type[nn.Module], ...]:
"""Helper to resolve string class names to actual model classes.
63 changes: 63 additions & 0 deletions torchprime/torch_xla_models/tests/test_assume_pure.py
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import torch_xla
from omegaconf import OmegaConf

from torchprime.torch_xla_models.model_rewriting.assume_pure import (
PureModule,
mark_pure_modules,
)


def test_nn_linear():
inputs = torch.randn((4,), device="xla")
linear = nn.Linear(4, 8)
linear = linear.to("xla")
expected_output = linear(inputs)
torch_xla.sync()
pure_linear = PureModule(linear)
actual_output = pure_linear(inputs)
torch_xla.sync()
torch.testing.assert_close(actual_output, expected_output)


def test_rewrite():
# Arrange
class Foo(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return input

class Bar(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return input

class Model(nn.Module):
def __init__(self):
super().__init__()
self.foo = Foo()
self.bar = Bar()

def forward(self, input):
return self.foo(input) + self.bar(input)

model = Model()
config = OmegaConf.create(
{
"model": {
"pure_modules": ["Foo"],
},
}
)

# Act
model = mark_pure_modules(model, config)

# Assert
assert isinstance(model.foo, PureModule)
assert not isinstance(model.bar, PureModule)
2 changes: 1 addition & 1 deletion torchprime/torch_xla_models/tests/test_trainer.py
@@ -67,6 +67,7 @@ def dummy_config():
return OmegaConf.create(
{
"model": {
"pure_modules": [],
"remat": {
"activation_checkpoint_layers": [],
"optimization_barrier_layers": [],
@@ -91,7 +92,6 @@ def dummy_config():
"profile_start_step": -1,
"profile_end_step": -1,
"profile_dir": "/tmp/profile",
"profile_duration": 5,
"ici_mesh": {"data": 1, "fsdp": 1, "tensor": 1},
"dcn_mesh": {},
}
4 changes: 4 additions & 0 deletions torchprime/torch_xla_models/trainer/base_trainer.py
@@ -37,6 +37,9 @@

from torchprime.metrics.mfu import compute_mfu
from torchprime.metrics.step_duration import step_duration_from_latest_profile
from torchprime.torch_xla_models.model_rewriting.assume_pure import (
mark_pure_modules,
)
from torchprime.torch_xla_models.model_rewriting.auto_trace import auto_trace
from torchprime.torch_xla_models.model_rewriting.rematerialization_utils import (
add_activation_checkpointing_and_scan,
@@ -99,6 +102,7 @@ def __init__(
model, self.input_sharding_spec, self.minibatch = setup_sharding_and_mesh(
model, config
)
model = mark_pure_modules(model, config)
model = add_activation_checkpointing_and_scan(model, config)
model = add_optimization_barriers(model, config)
self.model = model