Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions tests/unit_tests/test_activation_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,28 @@ def get_bw_flops(model_fn):
ac_config_no_force = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
per_op_sac_force_save_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(model_selective_ac, ac_config_no_force)
flops_selective_ac = get_bw_flops(model_selective_ac)

# 3. Per-op SAC with force recompute "moe.router.gate"
# This leads to two mms being recomputed since they share the same shape!
# 3. Per-op SAC with force save "moe.router.gate"
# This leads to two mms being saved since they share the same shape!
model_with_force_first = ToyModule()
ac_config_with_force_first = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(model_with_force_first, ac_config_with_force_first)
flops_with_force_first = get_bw_flops(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
# 4. Per-op SAC with force save "output"
model_with_force_last = ToyModule()
ac_config_with_force_last = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
per_op_sac_force_save_mm_shapes_by_fqns=["output"],
)
apply_ac(model_with_force_last, ac_config_with_force_last)
flops_with_force_last = get_bw_flops(model_with_force_last)
Expand All @@ -101,8 +101,8 @@ def get_bw_flops(model_fn):

self.assertEqual(flops_no_ac, 8.0)
self.assertEqual(flops_selective_ac, 9.0)
self.assertEqual(flops_with_force_first, 10.0)
self.assertEqual(flops_with_force_last, 11.0)
self.assertEqual(flops_with_force_first, 8.0)
self.assertEqual(flops_with_force_last, 9.0)
self.assertEqual(flops_full_ac, 12.0)

def test_mem(self):
Expand Down Expand Up @@ -131,28 +131,28 @@ def get_act_mem(model_fn):
ac_config_no_force = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
per_op_sac_force_save_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(model_selective_ac, ac_config_no_force)
mem_selective_ac = get_act_mem(model_selective_ac)

# 3. Per-op SAC with force recompute "moe.router.gate"
# This leads to two mms being recomputed since they share the same shape!
# 3. Per-op SAC with force save "moe.router.gate"
# This leads to two mms being saved since they share the same shape!
model_with_force_first = ToyModule().cuda()
ac_config_with_force_first = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(model_with_force_first, ac_config_with_force_first)
mem_with_force_first = get_act_mem(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
# 4. Per-op SAC with force save "output"
model_with_force_last = ToyModule().cuda()
ac_config_with_force_last = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
per_op_sac_force_save_mm_shapes_by_fqns=["output"],
)
apply_ac(model_with_force_last, ac_config_with_force_last)
mem_with_force_last = get_act_mem(model_with_force_last)
Expand All @@ -167,8 +167,8 @@ def get_act_mem(model_fn):

self.assertEqual(mem_no_ac, 2.0)
self.assertEqual(mem_selective_ac, 3.0)
self.assertEqual(mem_with_force_first, 2.0)
self.assertEqual(mem_with_force_last, 1.0)
self.assertEqual(mem_with_force_first, 4.0)
self.assertEqual(mem_with_force_last, 3.0)
self.assertEqual(mem_full_ac, 0.0)
# Note: SAC > no-AC here because it unnecessarily saves "output"
# even though that is not needed for recomputation and output is double
Expand All @@ -184,7 +184,7 @@ def test_correctness(self):
ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
per_op_sac_force_save_mm_shapes_by_fqns=[],
),
)
model_force_first = ToyModule()
Expand All @@ -194,7 +194,7 @@ def test_correctness(self):
ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
per_op_sac_force_save_mm_shapes_by_fqns=["moe.router.gate"],
),
)

Expand All @@ -205,7 +205,7 @@ def test_correctness(self):
ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
per_op_sac_force_save_mm_shapes_by_fqns=["output"],
),
)

Expand Down
4 changes: 2 additions & 2 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,12 @@ class ActivationCheckpoint:
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""

per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field(
per_op_sac_force_save_mm_shapes_by_fqns: list[str] = field(
default_factory=lambda: ["moe.router.gate"]
)
"""
When per-op selective ac is used, this list of fully qualified names is used
to determine which mm shapes to force save, rather than being considered
by the rest of the sac policy, e.g. save every other mm. Only nn.Linear modules
are supported today.
supported today.

Expand Down
16 changes: 8 additions & 8 deletions torchtitan/models/llama3/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,35 +269,35 @@ def _apply_ac_to_transformer_block(
create_selective_checkpoint_contexts,
)

mm_recompute_shapes = set()
if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
mm_save_shapes = set()
if len(ac_config.per_op_sac_force_save_mm_shapes_by_fqns) > 0:
for module_fqn, submod in module.named_modules():
fqn = module_fqn
if base_fqn is not None:
fqn = f"{base_fqn}.{module_fqn}"
if not any(
filter_fqn in fqn
for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
for filter_fqn in ac_config.per_op_sac_force_save_mm_shapes_by_fqns
):
continue
if not isinstance(submod, nn.Linear):
raise ValueError(
"per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
"per_op_sac_force_save_mm_shapes_by_fqns expected to match "
f"a nn.Linear, but got: {submod}"
)
out_f, in_f = submod.weight.shape
mm_recompute_shapes.add((in_f, out_f))
mm_save_shapes.add((in_f, out_f))
logger.debug(
f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
f"Selective op AC force saving mms with rhs shapes {mm_save_shapes}"
)

def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
if args[1].shape in mm_recompute_shapes:
return CheckpointPolicy.PREFER_RECOMPUTE
if args[1].shape in mm_save_shapes:
return CheckpointPolicy.MUST_SAVE
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in _save_list and not (
Expand Down
Loading