diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py
index c4df940d..e64735a1 100644
--- a/scripts/train_eagle3_online.py
+++ b/scripts/train_eagle3_online.py
@@ -96,7 +96,11 @@ def parse_args():
action="store_true",
help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
)
-
+ parser.add_argument(
+ "--is-think-mode",
+ action="store_true",
+ help="Whether the input data need to handle special think token format.",
+ )
+
# distributed training
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--dp-size", type=int, default=1)
@@ -291,22 +295,17 @@ def main():
# load model with resume
if draft_model_last_checkpoint:
- draft_model = (
- AutoEagle3DraftModel.from_pretrained(
- draft_model_last_checkpoint, attention_backend=args.attention_backend,
- torch_dtype=torch.bfloat16
- )
- .cuda()
-
- )
+ draft_model = AutoEagle3DraftModel.from_pretrained(
+ draft_model_last_checkpoint,
+ attention_backend=args.attention_backend,
+ torch_dtype=torch.bfloat16,
+ ).cuda()
else:
- draft_model = (
- AutoEagle3DraftModel.from_config(
- draft_model_config, attention_backend=args.attention_backend,
- torch_dtype=torch.bfloat16
- )
- .cuda()
- )
+ draft_model = AutoEagle3DraftModel.from_config(
+ draft_model_config,
+ attention_backend=args.attention_backend,
+ torch_dtype=torch.bfloat16,
+ ).cuda()
draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
print_with_rank("Initialized draft model")
@@ -341,6 +340,7 @@ def main():
cache_key=cache_key,
is_vlm=args.is_vlm,
is_preformatted=args.is_preformatted,
+ is_think_mode=args.is_think_mode,
processor=processor,
num_proc=args.build_dataset_num_proc,
)
@@ -388,6 +388,7 @@ def main():
processor=processor,
num_proc=args.build_dataset_num_proc,
is_preformatted=args.is_preformatted,
+ is_think_mode=args.is_think_mode,
)
eval_dataloader = prepare_dp_dataloaders(
eval_eagle3_dataset,
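
The new switch is a plain `store_true` flag, so think-mode handling stays off unless requested explicitly. A minimal, self-contained sketch of the flag's behavior:

```python
import argparse

# mirrors the flag added above: store_true defaults to False
parser = argparse.ArgumentParser()
parser.add_argument("--is-think-mode", action="store_true")

assert parser.parse_args([]).is_think_mode is False
assert parser.parse_args(["--is-think-mode"]).is_think_mode is True
```
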
diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index a4c095df..561cf79c 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -38,15 +38,25 @@ def parse(
class GeneralParser(Parser):
- def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
+ def __init__(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ chat_template: ChatTemplate,
+ is_think_mode: bool = False,
+ ):
super().__init__(tokenizer, chat_template)
self.system_prompt = chat_template.system_prompt
+ self.is_think_mode = is_think_mode
+ self.chat_template = chat_template
self.user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
- self.assistant_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
- )
+        if is_think_mode:
+            self.assistant_message_separator = (
+                f"{chat_template.end_of_turn_token}{chat_template.assistant_think_header}"
+            )
+        else:
+            self.assistant_message_separator = (
+                f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
+            )
+
def parse(
self, conversation: "Conversation", max_length: int, preformatted: bool = False
@@ -66,7 +76,7 @@ def parse(
messages.append({"role": "system", "content": self.system_prompt})
convroles = ["user", "assistant"]
-        for j, sentence in enumerate(conversation):
+        # hold the final assistant turn out of apply_chat_template; it is
+        # re-attached after the generation prompt below
+        for j, sentence in enumerate(conversation[:-1]):
role = sentence["role"]
if role != convroles[j % 2]:
warnings.warn(
@@ -75,10 +85,16 @@ def parse(
break
messages.append({"role": role, "content": sentence["content"]})
-        conversation = self.tokenizer.apply_chat_template(
+        templated = self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=False,
+            add_generation_prompt=True,
+            enable_thinking=self.is_think_mode,
+        )
+        # re-attach the held-out assistant reply after the generation prompt
+        # and close the turn explicitly
+        conversation = (
+            templated
+            + conversation[-1]["content"]
+            + self.chat_template.end_of_turn_token
)
if not self.tokenizer.pad_token_id:
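
For intuition, here is a hedged, hand-rolled approximation of the string the reworked `parse` now assembles, using the qwen3 headers registered further down: the final assistant turn is held out of the template, a generation prompt is emitted (with or without the pre-filled empty think block, depending on `enable_thinking`), and the held-out reply plus end-of-turn token is appended. The real code delegates to `tokenizer.apply_chat_template`; this sketch only mimics its output shape:

```python
def build_training_text(conversation, is_think_mode=False):
    # sketch only: literal qwen3 headers, no tokenizer involved
    text = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    for turn in conversation[:-1]:  # final assistant turn is held out
        text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
    if is_think_mode:
        # think mode: the model emits its own <think> block
        text += "<|im_start|>assistant\n"
    else:
        # non-think mode: Qwen3-style pre-filled empty think block
        text += "<|im_start|>assistant\n<think>\n\n</think>\n\n"
    # re-attach the held-out reply and close the turn explicitly
    return text + conversation[-1]["content"] + "<|im_end|>\n"
```
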
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index 3eb7849a..0b4b7651 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -57,6 +57,7 @@ def _apply_loss_mask_from_chat_template(
text: str,
offsets: torch.Tensor,
chat_template: ChatTemplate,
+ is_think_mode: bool = False,
) -> torch.Tensor:
"""
Apply loss mask to identify assistant response spans using chat template.
@@ -74,9 +75,14 @@ def _apply_loss_mask_from_chat_template(
user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
- assistant_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
- )
+ if is_think_mode:
+ assistant_message_separator = (
+ f"{chat_template.end_of_turn_token}{chat_template.assistant_think_header}"
+ )
+ else:
+ assistant_message_separator = (
+ f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
+ )
# Find spans of assistant responses using regex
assistant_pattern = (
@@ -116,6 +122,7 @@ def preprocess_conversations(
chat_template: ChatTemplate,
max_length: int = 2048,
is_preformatted: bool = False,
+ is_think_mode: bool = False,
) -> Dict[str, List[torch.Tensor]]:
"""
Preprocess a batch of ShareGPT style conversations or pre-formatted text.
@@ -139,14 +146,14 @@ def preprocess_conversations(
results = {"input_ids": [], "loss_mask": [], "attention_mask": []}
if chat_template.parser_type == "general":
- parser = GeneralParser(tokenizer, chat_template)
+ parser = GeneralParser(tokenizer, chat_template, is_think_mode)
elif chat_template.parser_type == "openai-harmony":
parser = HarmonyParser(tokenizer, chat_template)
else:
raise ValueError(f"Invalid parser type: {chat_template.parser_type}")
for source in conversations:
-        if not source:
-            # if the source is None, skip it
+        if not source or len(source) % 2 != 0:
+            # skip empty sources and conversations with an odd number of turns:
+            # the parser expects alternating user/assistant pairs ending with
+            # an assistant reply
continue
input_ids, loss_mask = parser.parse(
@@ -286,6 +293,7 @@ def build_eagle3_dataset(
is_vlm: Optional[bool] = False,
processor: Optional[ImageProcessingMixin] = None,
is_preformatted: Optional[bool] = False,
+ is_think_mode: Optional[bool] = False,
) -> HFDataset:
"""
build eagle3 dataset
@@ -311,6 +319,7 @@ def build_eagle3_dataset(
the assistant spans for loss mask generation.
If True, expects "text" column with ready-to-train text.
If False, expects "conversations" column with ShareGPT format.
+        is_think_mode: Whether to format conversations and build loss masks using the think-mode assistant header.
Returns:
The processed HF dataset.
@@ -352,6 +361,7 @@ def preprocess_function(examples):
template,
max_length,
is_preformatted=True,
+ is_think_mode=is_think_mode,
)
else:
# Handle ShareGPT conversations
@@ -365,6 +375,7 @@ def preprocess_function(examples):
template,
max_length,
is_preformatted=False,
+ is_think_mode=is_think_mode,
)
return processed
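
The loss-mask side mirrors the parser: only text between the (now think-aware) assistant separator and the next user separator receives loss. Below is a simplified, self-contained sketch of the span search; the actual pattern in `_apply_loss_mask_from_chat_template` also handles the end-of-turn token, so treat this regex as an approximation:

```python
import re

def assistant_spans(text, assistant_sep, user_sep):
    # capture from an assistant header up to the next user turn or end of text
    pattern = re.escape(assistant_sep) + r"(.*?)(?=" + re.escape(user_sep) + r"|$)"
    return [m.span(1) for m in re.finditer(pattern, text, flags=re.DOTALL)]

text = (
    "<|im_start|>user\nhi<|im_end|>\n"
    "<|im_start|>assistant\n<think>\nhm\n</think>\nhello<|im_end|>\n"
)
# in think mode the separator ends at "assistant\n", so the <think> block
# falls inside the captured span and is trained on as well
print(assistant_spans(text, "<|im_end|>\n<|im_start|>assistant\n",
                      "<|im_end|>\n<|im_start|>user\n"))
```
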
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 030aa5d0..cc1c56e5 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -10,12 +10,14 @@ class ChatTemplate(BaseModel):
Args:
assistant_header(str): The header for the assistant.
+    assistant_think_header(str): The header for the assistant in think mode.
user_header(str): The header for the user.
system_prompt(str): The system prompt.
end_of_turn_token(str): The end token of a turn of conversation.
"""
assistant_header: str | None
+ assistant_think_header: str | None
user_header: str | None
system_prompt: str | None
end_of_turn_token: str | None
@@ -34,6 +36,7 @@ class TemplateRegistry:
name="custom",
template=ChatTemplate(
assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
+ assistant_think_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
user_header="<|start_header_id|>user<|end_header_id|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|eot_id|>"
@@ -89,6 +92,7 @@ def get_all_template_names(self) -> List[str]:
name="llama3",
template=ChatTemplate(
assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
+ assistant_think_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
user_header="<|start_header_id|>user<|end_header_id|>",
system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.",
end_of_turn_token="<|eot_id|>",
@@ -99,6 +103,7 @@ def get_all_template_names(self) -> List[str]:
name="llama4",
template=ChatTemplate(
assistant_header="<|header_start|>assistant<|header_end|>\n\n",
+ assistant_think_header="<|header_start|>assistant<|header_end|>\n\n",
user_header="<|header_start|>user<|header_end|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|eot|>",
@@ -109,16 +114,41 @@ def get_all_template_names(self) -> List[str]:
name="qwen",
template=ChatTemplate(
assistant_header="<|im_start|>assistant\n",
+ assistant_think_header="<|im_start|>assistant\n",
user_header="<|im_start|>user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>\n",
),
)
+
+TEMPLATE_REGISTRY.register(
+ name="qwen3",
+ template=ChatTemplate(
+ assistant_header="<|im_start|>assistant\n\n\n\n\n",
+ assistant_think_header="<|im_start|>assistant\n",
+ user_header="<|im_start|>user\n",
+ system_prompt="You are a helpful assistant.",
+ end_of_turn_token="<|im_end|>\n",
+ ),
+)
+
+TEMPLATE_REGISTRY.register(
+ name="qwq",
+ template=ChatTemplate(
+ assistant_header="<|im_start|>assistant\n\n",
+ assistant_think_header="<|im_start|>assistant\n\n",
+ user_header="<|im_start|>user\n",
+ system_prompt="You are a helpful assistant.",
+ end_of_turn_token="<|im_end|>\n",
+ ),
+)
+
TEMPLATE_REGISTRY.register(
name="qwen2-vl",
template=ChatTemplate(
assistant_header="<|im_start|>assistant\n",
+ assistant_think_header="<|im_start|>assistant\n",
user_header="<|im_start|>user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>\n",
@@ -129,6 +159,7 @@ def get_all_template_names(self) -> List[str]:
name="deepseek",
template=ChatTemplate(
assistant_header="Assistant:",
+ assistant_think_header="Assistant:",
user_header="User:",
system_prompt="You are a helpful assistant.",
end_of_turn_token="",
@@ -139,6 +170,7 @@ def get_all_template_names(self) -> List[str]:
name="phi3",
template=ChatTemplate(
assistant_header="<|assistant|>\n",
+ assistant_think_header="<|assistant|>\n",
user_header="<|user|>\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|end|>\n",
@@ -149,6 +181,7 @@ def get_all_template_names(self) -> List[str]:
name="phi4",
template=ChatTemplate(
assistant_header="<|im_start|>assistant<|im_sep|>",
+ assistant_think_header="<|im_start|>assistant<|im_sep|>",
user_header="<|im_start|>user<|im_sep|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>",
@@ -159,6 +192,7 @@ def get_all_template_names(self) -> List[str]:
name="phi4-mini",
template=ChatTemplate(
assistant_header="<|assistant|>",
+ assistant_think_header="<|assistant|>",
user_header="<|user|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|end|>",
@@ -170,6 +204,7 @@ def get_all_template_names(self) -> List[str]:
name="gpt-oss",
template=ChatTemplate(
assistant_header=None, # the headers are not applicable to openai-harmony's channel tags
+ assistant_think_header=None,
user_header=None,
system_prompt=None,
end_of_turn_token=None,
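
Downstream code then picks the header per mode. A small usage sketch, assuming the registry exposes a `get(name)` accessor (only `register` and `get_all_template_names` appear in this diff):

```python
from specforge.data.template import TEMPLATE_REGISTRY

# assumption: TEMPLATE_REGISTRY.get(name) returns the registered ChatTemplate
template = TEMPLATE_REGISTRY.get("qwen3")

def assistant_header_for(template, is_think_mode=False):
    # think mode drops the pre-filled empty think block so the draft model
    # learns to produce the reasoning tokens itself
    if is_think_mode:
        return template.assistant_think_header
    return template.assistant_header

assert assistant_header_for(template, is_think_mode=True) == "<|im_start|>assistant\n"
```
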