62 changes: 61 additions & 1 deletion scripts/train_eagle3_offline.py
@@ -2,6 +2,7 @@
 import hashlib
 import math
 import os
+import re
 import time
 from collections import defaultdict
 
@@ -209,10 +210,48 @@ def main():
 
     # detecting last ckpt for draft model
     draft_model_last_checkpoint = None
+    resume_training_state = None
+    start_epoch = 0
     if args.resume and os.path.isdir(args.output_dir):
         print_on_rank0(args.output_dir)
         draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
         print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
+        if draft_model_last_checkpoint:
+            epoch_dir_name = os.path.basename(
+                draft_model_last_checkpoint.rstrip(os.sep)
+            )
+            match = re.search(r"epoch_(\d+)", epoch_dir_name)
+            if match:
+                start_epoch = int(match.group(1)) + 1
+                print_on_rank0(f"Resuming training from epoch {start_epoch}")
+            else:
+                print_on_rank0(
+                    "Failed to parse epoch index from checkpoint directory; starting from epoch 0"
+                )
+            training_state_path = os.path.join(
+                draft_model_last_checkpoint, "training_state.pt"
+            )
+            if os.path.isfile(training_state_path):
+                try:
+                    resume_training_state = torch.load(
+                        training_state_path,
+                        map_location="cpu",
+                        weights_only=False,
+                    )
+                    saved_epoch = resume_training_state.get("epoch")
+                    if saved_epoch is not None:
+                        start_epoch = saved_epoch + 1
+                    print_on_rank0(
+                        f"Loaded training state from {training_state_path}"
+                    )
+                except Exception as exc:  # pragma: no cover - informational
+                    print_on_rank0(
+                        f"Failed to load training state from {training_state_path}: {exc}"
+                    )
+            else:
+                print_on_rank0(
+                    f"No training_state.pt found in {draft_model_last_checkpoint}; starting from epoch {start_epoch}"
+                )
 
     # build target and draft model
     target_head = TargetHead(args.target_model_path)
@@ -366,10 +405,30 @@ def main():
     )
     print_with_rank("Initialized optimizer and scheduler")
 
+    if resume_training_state:
+        optimizer.load_state_dict(resume_training_state)
+        global_step = resume_training_state.get("global_step")
+        if global_step is None:
+            steps_per_epoch = math.ceil(
+                len(train_dataloader) / args.draft_accumulation_steps
+            )
+            global_step = int(start_epoch * steps_per_epoch)
+        print_on_rank0(
+            f"Resumed optimizer and scheduler state (global_step={global_step})"
+        )
+
     last_time = time.time()
 
     # start running
-    for epoch in range(args.num_epochs):
+    if start_epoch >= args.num_epochs:
+        print_on_rank0(
+            f"Start epoch {start_epoch} is >= total epochs {args.num_epochs}; nothing to resume."
+        )
+        tracker.close()
+        destroy_distributed()
+        return
+
+    for epoch in range(start_epoch, args.num_epochs):
         # Run training
         train_dataloader.sampler.set_epoch(epoch + 1)
         draft_model.train()
@@ -535,6 +594,7 @@ def main():
             state_to_save = {
                 "epoch": epoch,
                 "args": args,
+                "global_step": global_step,
             }
             state_to_save.update(optimizer.state_dict())
            draft_model_state_dict = {
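
For context, a minimal standalone sketch of the resume bookkeeping this patch adds. The epoch_<N> directory pattern, the training_state.pt filename, and the saved keys ("epoch", "args", "global_step" plus the optimizer state) are taken from the diff above; the helper name infer_start_epoch is illustrative and not part of the script.

import os
import re

import torch


def infer_start_epoch(checkpoint_dir: str) -> int:
    """Decide which epoch to resume from, mirroring the patch's two-step logic."""
    start_epoch = 0
    # Fallback: parse the index out of a directory name like ".../epoch_4".
    match = re.search(r"epoch_(\d+)", os.path.basename(checkpoint_dir.rstrip(os.sep)))
    if match:
        start_epoch = int(match.group(1)) + 1

    # Preferred: read the epoch recorded in training_state.pt, if present.
    state_path = os.path.join(checkpoint_dir, "training_state.pt")
    if os.path.isfile(state_path):
        state = torch.load(state_path, map_location="cpu", weights_only=False)
        if state.get("epoch") is not None:
            start_epoch = state["epoch"] + 1
    return start_epoch


# Example: a run whose last checkpoint is output_dir/epoch_4 resumes at epoch 5.
# infer_start_epoch("output_dir/epoch_4") -> 5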