From 96f35808bd5d3a795c64f4e1ffb1893da56f0092 Mon Sep 17 00:00:00 2001
From: hanq-moreh
Date: Mon, 29 Sep 2025 13:55:15 +0900
Subject: [PATCH 1/5] fix resume logic. load optimizer state

---
 scripts/train_eagle3_offline.py | 78 +++++++++++++++++++++++++++++----
 1 file changed, 69 insertions(+), 9 deletions(-)

diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index 83f3961c..ee3424a2 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -2,6 +2,7 @@
 import hashlib
 import math
 import os
+import re
 import time
 from collections import defaultdict
 
@@ -209,10 +210,48 @@ def main():
 
     # detecting last ckpt for draft model
     draft_model_last_checkpoint = None
+    resume_training_state = None
+    start_epoch = 0
     if args.resume and os.path.isdir(args.output_dir):
         print_on_rank0(args.output_dir)
         draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
         print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
+        if draft_model_last_checkpoint:
+            epoch_dir_name = os.path.basename(
+                draft_model_last_checkpoint.rstrip(os.sep)
+            )
+            match = re.search(r"epoch_(\d+)", epoch_dir_name)
+            if match:
+                start_epoch = int(match.group(1)) + 1
+                print_on_rank0(f"Resuming training from epoch {start_epoch}")
+            else:
+                print_on_rank0(
+                    "Failed to parse epoch index from checkpoint directory; starting from epoch 0"
+                )
+            training_state_path = os.path.join(
+                draft_model_last_checkpoint, "training_state.pt"
+            )
+            if os.path.isfile(training_state_path):
+                try:
+                    resume_training_state = torch.load(
+                        training_state_path,
+                        map_location="cpu",
+                        weights_only=False,
+                    )
+                    saved_epoch = resume_training_state.get("epoch")
+                    if saved_epoch is not None:
+                        start_epoch = saved_epoch + 1
+                    print_on_rank0(
+                        f"Loaded training state from {training_state_path}"
+                    )
+                except Exception as exc:  # pragma: no cover - informational
+                    print_on_rank0(
+                        f"Failed to load training state from {training_state_path}: {exc}"
+                    )
+            else:
+                print_on_rank0(
+                    f"No training_state.pt found in {draft_model_last_checkpoint}; starting from epoch {start_epoch}"
+                )
 
     # build target and draft model
     target_head = TargetHead(args.target_model_path)
@@ -366,10 +405,31 @@ def main():
     )
     print_with_rank("Initialized optimizer and scheduler")
 
+    if resume_training_state:
+        optimizer.load_state_dict(resume_training_state)
+        global_step = resume_training_state.get("global_step")
+        if global_step is None:
+            global_step = start_epoch * math.ceil(
+                len(train_dataloader) / args.draft_accumulation_steps
+            )
+        if global_step is None or global_step < 0:
+            global_step = 0
+        print_on_rank0(
+            f"Resumed optimizer and scheduler state (global_step={global_step})"
+        )
+
     last_time = time.time()
 
     # start running
-    for epoch in range(args.num_epochs):
+    if start_epoch >= args.num_epochs:
+        print_on_rank0(
+            f"Start epoch {start_epoch} is >= total epochs {args.num_epochs}; nothing to resume."
+        )
+        tracker.close()
+        destroy_distributed()
+        return
+
+    for epoch in range(start_epoch, args.num_epochs):
         # Run training
         train_dataloader.sampler.set_epoch(epoch + 1)
         draft_model.train()
@@ -487,14 +547,13 @@ def main():
             eval_plosses = [[] for _ in range(eagle3_model.length)]
 
             for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"):
-                with torch.no_grad():
-                    plosses, _, acces = eagle3_model(
-                        input_ids=data["input_ids"].cuda(),
-                        attention_mask=data["attention_mask"].cuda(),
-                        loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
-                        hidden_states=data["hidden_state"].cuda(),
-                        target=data["target"].cuda(),
-                    )
+                plosses, _, acces = eagle3_model(
+                    input_ids=data["input_ids"].cuda(),
+                    attention_mask=data["attention_mask"].cuda(),
+                    loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
+                    hidden_states=data["hidden_state"].cuda(),
+                    target=data["target"].cuda(),
+                )
                 acces = torch.stack(acces).cpu().tolist()
                 eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]
                 eval_plosses = [
@@ -535,6 +594,7 @@ def main():
             state_to_save = {
                 "epoch": epoch,
                 "args": args,
+                "global_step": global_step,
             }
             state_to_save.update(optimizer.state_dict())
             draft_model_state_dict = {

From b8d0b1d31308b1800c913de77f95fa5b8104b603 Mon Sep 17 00:00:00 2001
From: hanq-moreh
Date: Mon, 29 Sep 2025 14:22:51 +0900
Subject: [PATCH 2/5] add evaluation torch.no_grad()

---
 scripts/train_eagle3_offline.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index ee3424a2..9c50cd97 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -547,13 +547,14 @@ def main():
             eval_plosses = [[] for _ in range(eagle3_model.length)]
 
             for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"):
-                plosses, _, acces = eagle3_model(
-                    input_ids=data["input_ids"].cuda(),
-                    attention_mask=data["attention_mask"].cuda(),
-                    loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
-                    hidden_states=data["hidden_state"].cuda(),
-                    target=data["target"].cuda(),
-                )
+                with torch.no_grad():
+                    plosses, _, acces = eagle3_model(
+                        input_ids=data["input_ids"].cuda(),
+                        attention_mask=data["attention_mask"].cuda(),
+                        loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
+                        hidden_states=data["hidden_state"].cuda(),
+                        target=data["target"].cuda(),
+                    )
                 acces = torch.stack(acces).cpu().tolist()
                 eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]
                 eval_plosses = [

From e9e37511f9b92ccdf1a11f8da9c19adfbc845b82 Mon Sep 17 00:00:00 2001
From: hanq-moreh
Date: Mon, 29 Sep 2025 14:58:59 +0900
Subject: [PATCH 3/5] remove redundant condition

---
 scripts/train_eagle3_offline.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index 9c50cd97..f3c86082 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -409,11 +409,10 @@ def main():
         optimizer.load_state_dict(resume_training_state)
         global_step = resume_training_state.get("global_step")
         if global_step is None:
-            global_step = start_epoch * math.ceil(
+            steps_per_epoch = math.ceil(
                 len(train_dataloader) / args.draft_accumulation_steps
             )
-        if global_step is None or global_step < 0:
-            global_step = 0
+            global_step = int(start_epoch * steps_per_epoch)
         print_on_rank0(
             f"Resumed optimizer and scheduler state (global_step={global_step})"
         )

From 0efdad980fd3da4d6427abb39ab17715de6b84d2 Mon Sep 17 00:00:00 2001
From: hankyu jang
Date: Tue, 21 Oct 2025 16:27:15 +0900
Subject: [PATCH 4/5] add finetune, postprocess, sanity_check

---
 requirements.txt                |  9 ++---
 scripts/train_eagle3_offline.py | 61 +++++++++++++++++----------
 specforge/core/loss.py          |  2 +-
 specforge/data/preprocessing.py |  7 ++--
 4 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 71f8944c..54967be9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,4 @@
 pre-commit
-torch==2.8.0
-torchaudio==2.8.0
-torchvision==0.23.0
 transformers==4.55.2
 qwen-vl-utils==0.0.11
 datasets
@@ -12,5 +9,9 @@ psutil
 numpy
 accelerate
 pydantic
-sglang[all]==0.5.1
 openai-harmony
+
+# sglang[all]==0.5.1
+# torch==2.8.0
+# torchaudio==2.8.0
+# torchvision==0.23.0
\ No newline at end of file
diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index f3c86082..61f440c1 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -60,9 +60,9 @@ def parse_args():
     )
 
     # add training-related arguments
-    parser.add_argument("--train-data-path", type=str, required=True)
+    # parser.add_argument("--train-data-path", type=str, required=True)
+    # parser.add_argument("--eval-data-path", type=str, default=None)
     parser.add_argument("--train-hidden-states-path", type=str, required=True)
-    parser.add_argument("--eval-data-path", type=str, default=None)
     parser.add_argument("--eval-hidden-states-path", type=str, default=None)
     parser.add_argument("--num-epochs", type=int, default=10)
     parser.add_argument("--draft-global-batch-size", type=int, default=16)
@@ -114,6 +114,10 @@ def parse_args():
     # resume
     parser.add_argument("--resume", action="store_true")
 
+    # finetune
+    parser.add_argument("--finetune", action="store_true")
+    parser.add_argument("--baseline-dir", type=str, default=None)
+
     # report backend
     parser.add_argument(
         "--report-to",
@@ -212,6 +216,12 @@ def main():
     draft_model_last_checkpoint = None
     resume_training_state = None
     start_epoch = 0
+    if args.finetune and args.baseline_dir is not None:
+        if os.path.isdir(args.baseline_dir):
+            draft_model_last_checkpoint = args.baseline_dir
+            print_on_rank0(f"Finetuning from baseline model: {draft_model_last_checkpoint}")
+        else:
+            raise ValueError(f"Provided baseline-dir {args.baseline_dir} is not a valid directory.")
     if args.resume and os.path.isdir(args.output_dir):
         print_on_rank0(args.output_dir)
         draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
@@ -222,7 +232,7 @@ def main():
             )
             match = re.search(r"epoch_(\d+)", epoch_dir_name)
             if match:
-                start_epoch = int(match.group(1)) + 1
+                start_epoch = int(match.group(1))
                 print_on_rank0(f"Resuming training from epoch {start_epoch}")
             else:
                 print_on_rank0(
@@ -301,24 +311,18 @@ def main():
 
     # convert to dataloader
     cache_params_string = (
-        f"{args.train_data_path}-"
+        f"{args.train_hidden_states_path}-"
         f"{args.max_length}-"
        f"{args.chat_template}-"
         f"{args.target_model_path}"  # Tokenizer may also different
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
-    train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
     with rank_0_priority():
-        train_eagle3_dataset_tmp = build_eagle3_dataset(
-            dataset=train_dataset,
-            tokenizer=tokenizer,
-            chat_template=args.chat_template,
-            is_preformatted=args.is_preformatted,
-            max_length=args.max_length,
-            cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
-            cache_key=cache_key,
-            num_proc=args.build_dataset_num_proc,
+        train_eagle3_dataset = build_offline_eagle3_dataset(
+            args.train_hidden_states_path,
+            args.max_length,
         )
+        train_eagle3_dataset_tmp = train_eagle3_dataset
         vocab_mapping_path = generate_vocab_mapping_file(
             dataset=train_eagle3_dataset_tmp,
             target_vocab_size=draft_model_config.vocab_size,
@@ -326,10 +330,7 @@ def main():
             cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
             cache_key=cache_key,
         )
-        train_eagle3_dataset = build_offline_eagle3_dataset(
-            args.train_hidden_states_path,
-            args.max_length,
-        )
+
 
     train_dataloader = prepare_dp_dataloaders(
         train_eagle3_dataset,
@@ -357,7 +358,7 @@ def main():
     draft_model.load_vocab_mapping(vocab_mapping_path)
     print_with_rank("Loaded vocab mapping")
 
-    if args.eval_data_path is not None:
+    if args.eval_hidden_states_path is not None:
         eval_eagle3_dataset = build_offline_eagle3_dataset(
             args.eval_hidden_states_path,
             args.max_length,
@@ -409,10 +410,11 @@ def main():
         optimizer.load_state_dict(resume_training_state)
         global_step = resume_training_state.get("global_step")
         if global_step is None:
-            steps_per_epoch = math.ceil(
+            global_step = start_epoch * math.ceil(
                 len(train_dataloader) / args.draft_accumulation_steps
             )
-            global_step = int(start_epoch * steps_per_epoch)
+        if global_step is None or global_step < 0:
+            global_step = 0
         print_on_rank0(
             f"Resumed optimizer and scheduler state (global_step={global_step})"
         )
@@ -539,21 +541,20 @@ def main():
         tracker.log(train_epoch_logdict, step=global_step)
 
         # run evaluation
-        if args.eval_data_path is not None and epoch % args.eval_interval == 0:
+        if args.eval_hidden_states_path is not None and epoch % args.eval_interval == 0:
             # Run evaluation
             draft_model.eval()
             eval_acces = [[] for _ in range(eagle3_model.length)]
             eval_plosses = [[] for _ in range(eagle3_model.length)]
 
             for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"):
-                with torch.no_grad():
-                    plosses, _, acces = eagle3_model(
-                        input_ids=data["input_ids"].cuda(),
-                        attention_mask=data["attention_mask"].cuda(),
-                        loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
-                        hidden_states=data["hidden_state"].cuda(),
-                        target=data["target"].cuda(),
-                    )
+                plosses, _, acces = eagle3_model(
+                    input_ids=data["input_ids"].cuda(),
+                    attention_mask=data["attention_mask"].cuda(),
+                    loss_mask=data["loss_mask"].unsqueeze(-1).cuda(),
+                    hidden_states=data["hidden_state"].cuda(),
+                    target=data["target"].cuda(),
+                )
                 acces = torch.stack(acces).cpu().tolist()
                 eval_acces = [eval_acces[i] + [acces[i]] for i in range(len(acces))]
                 eval_plosses = [
diff --git a/specforge/core/loss.py b/specforge/core/loss.py
index 5e4423bd..a83d2c25 100644
--- a/specforge/core/loss.py
+++ b/specforge/core/loss.py
@@ -30,7 +30,7 @@ def _calculate_settings(n):
         raise RuntimeError(
             f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
         )
-
+    return 1024, 4
     num_warps = 4
     if BLOCK_SIZE >= 32768:
         num_warps = 32
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index 3eb7849a..a990beee 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -366,7 +366,8 @@ def preprocess_function(examples):
             max_length,
             is_preformatted=False,
         )
-
+        if "request_id" in examples:
+            processed["request_id"] = examples["request_id"]
         return processed
 
     # Process dataset only once
@@ -509,8 +510,8 @@ def generate_vocab_mapping_file(
     # we first count the frequency of effectiev tokens in the dataset
     token_dict = Counter()
     for item in tqdm(dataset, desc="Counting tokens for vocab mapping"):
-        input_ids = item["input_ids"]
-        loss_mask = item["loss_mask"]
+        input_ids = item["input_ids"].unsqueeze(0)
+        loss_mask = item["loss_mask"].unsqueeze(0)
         masked_ids = input_ids[loss_mask == 1]
         unique_ids, counts = masked_ids.unique(return_counts=True)
         batch_token_dict = dict(zip(unique_ids.tolist(), counts.tolist()))

From c94089b6a24d8126382381b7b46abb43025c5842 Mon Sep 17 00:00:00 2001
From: hankyu jang
Date: Tue, 21 Oct 2025 16:36:11 +0900
Subject: [PATCH 5/5] add example of runtime training

---
 examples/runtime_training_example.sh | 74 ++++++++++++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 examples/runtime_training_example.sh

diff --git a/examples/runtime_training_example.sh b/examples/runtime_training_example.sh
new file mode 100644
index 00000000..705fb8f7
--- /dev/null
+++ b/examples/runtime_training_example.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+
+# run server
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+export SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
+
+cd SpecForge
+
+python3 -m sglang.launch_server --model /models/gpt-oss-120b --tp 4 \
+    --speculative-draft-model-path /cache/from_scratch_dumped_train_fixed_output/epoch_9 \
+    --speculative-num-steps 5 \
+    --speculative-eagle-topk 4 \
+    --speculative-num-draft-tokens 8 \
+    --mem-fraction 0.8 \
+    --speculative-algorithm EAGLE3 \
+    --cuda-graph-max-bs 32 \
+    --port 41555 \
+    --trust-remote-code \
+    --disable-radix-cache \
+    --enable-dump-hidden-states \
+    --hidden-states-dump-path /cache/hidden_states_default
+
+# data generation
+python3 -m sglang.bench_serving \
+    --backend sglang-oai-chat \
+    --dataset-name sharegpt \
+    --num-prompts 1000 \
+    --model /models/gpt-oss-120b \
+    --dataset-path /cache/dataset_new/cluster0_user_test.json \
+    --output-file output.jsonl \
+    --max-concurrency 32 \
+    --port 41555
+
+# postprocess
+python postprocess_test.py \
+    --data-path /cache/hidden_states_default/ \
+    --model-path /models/gpt-oss-120b/ \
+    --output-path /cache/dump_train
+
+python postprocess_test.py \
+    --data-path /cache/hidden_states_default/ \
+    --model-path /models/gpt-oss-120b/ \
+    --output-path /cache/dump_eval \
+    --test-mode
+
+# finetuning
+export NUM_GPUS=4
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+export WANDB_API_KEY=your_wandb_api_key_here
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    scripts/train_eagle3_offline.py \
+    --target-model-path /models/gpt-oss-120b \
+    --draft-model-config ./configs/gpt-oss-120B-eagle3.json \
+    --train-hidden-states-path /cache/dump_train \
+    --eval-hidden-states-path /cache/dump_eval \
+    --output-dir /cache/dump_output \
+    --num-epochs 10 \
+    --draft-global-batch-size 16 \
+    --draft-micro-batch-size 1 \
+    --learning-rate 5e-5 \
+    --draft-attention-backend flex_attention \
+    --max-length 2048 \
+    --chat-template gpt-oss \
+    --cache-dir /cache/dump_cache \
+    --dist-timeout 3600 \
+    --log-steps 1 \
+    --is-preformatted \
+    --finetune \
+    --baseline-dir /workspace/EAGLE3-gpt-oss-120b-bf16 \
+    --report-to wandb \
+    --wandb-project gpt-oss-120b-eagle3 \
+    --wandb-name dump-train-10epoch-batch16-lr5e-5
\ No newline at end of file
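
A minimal resume sketch (not part of the patches above): if the finetuning run is interrupted, the resume path added in PATCH 1-3 can be exercised by re-running the training command with --resume. This assumes the checkpoint layout the patched script writes under --output-dir (epoch_<N> directories containing training_state.pt); the paths below simply reuse the example, only the required and path-related flags are repeated, and all other options fall back to their argparse defaults.

# resume an interrupted run (sketch; reuses paths from runtime_training_example.sh)
torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    scripts/train_eagle3_offline.py \
    --target-model-path /models/gpt-oss-120b \
    --draft-model-config ./configs/gpt-oss-120B-eagle3.json \
    --train-hidden-states-path /cache/dump_train \
    --eval-hidden-states-path /cache/dump_eval \
    --output-dir /cache/dump_output \
    --num-epochs 10 \
    --max-length 2048 \
    --chat-template gpt-oss \
    --cache-dir /cache/dump_cache \
    --resume

With --resume, the patched script picks the latest epoch_<N> checkpoint under --output-dir, restores the optimizer state and global_step from training_state.pt, and exits early with a "nothing to resume" message when --num-epochs does not exceed the number of epochs already completed.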