Commit ec87c92 (1 parent: 757fda4)

Add E2E test
File tree: 3 files changed (+79 lines, -37 lines)

.github/workflows/e2e_test.yml

Lines changed: 23 additions & 0 deletions
@@ -22,6 +22,7 @@ jobs:
       ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
     outputs:
       llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
+      llama-3-8b-pure-mlp-name: ${{ steps.run-llama-3-8b-pure-mlp.outputs.name }}
       llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
       llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
       llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
@@ -83,6 +84,27 @@ jobs:
           ici_mesh.fsdp=4 \
           profile_start_step=3
 
+    - name: Run Llama 3.0 8B (@assume_pure)
+      id: run-llama-3-8b-pure-mlp
+      env:
+        HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        XLA_IR_DEBUG: 1
+        XLA_HLO_DEBUG: 1
+      run: |
+        name=$(e2e_testing/gen_name.py llama-3-8b-pure-mlp)
+        echo "name=$name" >> "$GITHUB_OUTPUT"
+        tp run ${{ steps.docker-url-option.outputs.value }} \
+          --name $name \
+          torchprime/torch_xla_models/train.py \
+          model=llama-3-8b \
+          dataset=wikitext \
+          task=train \
+          task.global_batch_size=8 \
+          task.max_steps=15 \
+          ici_mesh.fsdp=4 \
+          profile_start_step=3 \
+          model.pure_modules=[LlamaMLP,EinsumLinear]
+
     - name: Run Llama 3.1 8B (Splash Attention)
       id: run-llama-3_1-8b-SplashAttention
       env:
@@ -259,6 +281,7 @@ jobs:
         jobset_name: >-
           ${{
             matrix.config.benchmark == 'llama-3-8b' && needs.tp-run.outputs.llama-3-8b-name ||
+            matrix.config.benchmark == 'llama-3-8b-pure-mlp' && needs.tp-run.outputs.llama-3-8b-pure-mlp-name ||
             matrix.config.benchmark == 'llama-3_1-8b-sa' && needs.tp-run.outputs.llama-3_1-8b-sa-name ||
             matrix.config.benchmark == 'llama-3_1-8b-scan-offload' && needs.tp-run.outputs.llama-3_1-8b-scan-offload-name ||
             matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
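The new "Run Llama 3.0 8B (@assume_pure)" step reuses the existing Llama 3 recipe and only adds model.pure_modules=[LlamaMLP,EinsumLinear], naming the module classes to be traced via the @assume_pure path referenced in the step name. As a rough sketch of how a trainer could resolve that config entry into concrete submodules (find_pure_modules is a hypothetical helper, not torchprime's actual train.py wiring):

import torch.nn as nn

def find_pure_modules(model: nn.Module, pure_names: list[str]) -> list[nn.Module]:
  # Collect submodules whose class name appears in model.pure_modules,
  # e.g. ["LlamaMLP", "EinsumLinear"]. These are the candidates to wrap
  # so XLA can trace them as pure (side-effect-free) functions.
  return [m for m in model.modules() if type(m).__name__ in pure_names]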

e2e_testing/step_time_bounds.yaml

Lines changed: 37 additions & 37 deletions
@@ -1,54 +1,54 @@
 benchmarks:
   llama-3-8b:
     name: Llama 3.0 8B
-    step_time_lower_bound: 2.68109009
-    step_time_upper_bound: 2.789223
-    confidence_interval: 0.05407
-    average: 2.7352
-    sample_size: 427
+    step_time_lower_bound: 0.894678
+    step_time_upper_bound: 4.54563437
+    confidence_interval: 1.82548
+    average: 2.7202
+    sample_size: 114
   llama-3_1-8b-sa:
     name: Llama 3.1 8B (Splash Attention)
-    step_time_lower_bound: 2.34653077
-    step_time_upper_bound: 2.467111
-    confidence_interval: 0.06029
-    average: 2.4068
-    sample_size: 428
+    step_time_lower_bound: 2.35428493
+    step_time_upper_bound: 2.470571
+    confidence_interval: 0.05814
+    average: 2.4124
+    sample_size: 112
   llama-3_1-8b-scan-offload:
     name: Llama 3.1 8B (Scan + Offload)
-    step_time_lower_bound: 2.74099553
-    step_time_upper_bound: 2.860302
-    confidence_interval: 0.05965
-    average: 2.8006
-    sample_size: 428
+    step_time_lower_bound: 2.74872464
+    step_time_upper_bound: 2.871284
+    confidence_interval: 0.06128
+    average: 2.81
+    sample_size: 94
   llama-3-8b-2d:
     name: Llama 3.0 8B (2D sharding)
-    step_time_lower_bound: 3.28827914
-    step_time_upper_bound: 3.38842977
-    confidence_interval: 0.05008
-    average: 3.3384
-    sample_size: 428
+    step_time_lower_bound: 3.31281298
+    step_time_upper_bound: 3.41371084
+    confidence_interval: 0.05045
+    average: 3.3633
+    sample_size: 114
   mixtral-8x7b:
     name: Mixtral 8x7B
-    step_time_lower_bound: 3.09900735
-    step_time_upper_bound: 3.19339336
-    confidence_interval: 0.04719
-    average: 3.1462
-    sample_size: 427
+    step_time_lower_bound: 3.12225098
+    step_time_upper_bound: 3.21734492
+    confidence_interval: 0.04755
+    average: 3.1698
+    sample_size: 114
   llama-3-8b-2-slice:
     name: Llama 3.0 8B (2 Slice)
-    step_time_lower_bound: 3.82985294
-    step_time_upper_bound: 4.087614
-    confidence_interval: 0.12888
-    average: 3.9587
-    sample_size: 416
+    step_time_lower_bound: 3.47510115
+    step_time_upper_bound: 4.505638
+    confidence_interval: 0.51527
+    average: 3.9904
+    sample_size: 110
   llama-3-8b-ddp-fsdp:
     name: Llama 3.0 8B (ddp + fsdp)
-    step_time_lower_bound: 3.22420277
-    step_time_upper_bound: 3.351676
-    confidence_interval: 0.06374
-    average: 3.2879
-    sample_size: 47
+    step_time_lower_bound: 3.2263914
+    step_time_upper_bound: 3.341676
+    confidence_interval: 0.05764
+    average: 3.284
+    sample_size: 110
 metadata:
-  query_start: '2025-05-26T18:37:58.674556-07:00'
-  query_end: '2025-06-13T13:20:09-07:00'
+  query_start: '2025-06-12T22:37:43+00:00'
+  query_end: '2025-06-17T22:37:43+00:00'
   confidence_level: 0.999
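For reference, the regenerated bounds are consistent with step_time_lower_bound and step_time_upper_bound being average -/+ confidence_interval at the stated 0.999 confidence level. A quick sanity check in Python, using the llama-3_1-8b-sa values above:

# Bounds appear to be derived as average +/- confidence_interval.
avg, ci = 2.4124, 0.05814  # llama-3_1-8b-sa: average, confidence_interval
lower, upper = avg - ci, avg + ci
assert abs(lower - 2.35428493) < 1e-3  # step_time_lower_bound
assert abs(upper - 2.470571) < 1e-3    # step_time_upper_bound

The much wider llama-3-8b interval (confidence_interval 1.82548, up from 0.05407) lines up with the smaller sample (114 runs vs. 427) drawn from the narrower five-day query window.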

e2e_testing/update_step_time.py

Lines changed: 19 additions & 0 deletions
@@ -29,6 +29,23 @@ def match_llama3_8b(row):
     and config["dcn_mesh"]["data"] == 1
     and config["dcn_mesh"]["fsdp"] == 1
     and config["ici_mesh"]["tensor"] == 1
+    and (
+      "pure_modules" not in config["model"] or len(config["model"]["pure_modules"]) == 0
+    )
+  )
+
+
+def match_llama3_8b_pure_mlp(row):
+  config = json.loads(row.configs_framework)
+  return (
+    row.run_id.startswith("llama-3-8b-pure-mlp")
+    and config["dcn_mesh"]["data"] == 1
+    and config["dcn_mesh"]["fsdp"] == 1
+    and config["ici_mesh"]["tensor"] == 1
+    and (
+      "pure_modules" in config["model"]
+      and config["model"]["pure_modules"] == ["LlamaMLP", "EinsumLinear"]
+    )
   )
 
 
@@ -86,6 +103,7 @@ def match_llama_3_8b_ddp_fsdp(row):
 
 BENCHMARKS = {
   "Llama 3.0 8B": match_llama3_8b,
+  "Llama 3.0 8B (@assume_pure)": match_llama3_8b_pure_mlp,
   "Llama 3.1 8B (Splash Attention)": match_llama3_1_8b_sa,
   "Llama 3.1 8B (Scan + Offload)": match_llama3_1_8b_scan_offload,
   "Llama 3.0 8B (2D sharding)": match_llama3_8b_2d,
@@ -96,6 +114,7 @@ def match_llama_3_8b_ddp_fsdp(row):
 
 STEP_ID_MAPPING = {
   "Llama 3.0 8B": "llama-3-8b",
+  "Llama 3.0 8B (@assume_pure)": "llama-3-8b-pure-mlp",
  "Llama 3.1 8B (Splash Attention)": "llama-3_1-8b-sa",
  "Llama 3.1 8B (Scan + Offload)": "llama-3_1-8b-scan-offload",
  "Llama 3.0 8B (2D sharding)": "llama-3-8b-2d",
