33 changes: 17 additions & 16 deletions scripts/train_eagle3_online.py
@@ -96,7 +96,11 @@ def parse_args():
action="store_true",
help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
)

parser.add_argument(
"--is-think-mode",
action="store_true",
help="Whether the input data need to handle special think token format.",
)
# distributed training
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--dp-size", type=int, default=1)
@@ -291,22 +295,17 @@ def main():

# load model with resume
if draft_model_last_checkpoint:
draft_model = (
AutoEagle3DraftModel.from_pretrained(
draft_model_last_checkpoint, attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16
)
.cuda()

)
draft_model = AutoEagle3DraftModel.from_pretrained(
draft_model_last_checkpoint,
attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16,
).cuda()
else:
draft_model = (
AutoEagle3DraftModel.from_config(
draft_model_config, attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16
)
.cuda()
)
draft_model = AutoEagle3DraftModel.from_config(
draft_model_config,
attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16,
).cuda()
draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
print_with_rank("Initialized draft model")
@@ -341,6 +340,7 @@ def main():
cache_key=cache_key,
is_vlm=args.is_vlm,
is_preformatted=args.is_preformatted,
is_think_mode=args.is_think_mode,
processor=processor,
num_proc=args.build_dataset_num_proc,
)
@@ -388,6 +388,7 @@ def main():
processor=processor,
num_proc=args.build_dataset_num_proc,
is_preformatted=args.is_preformatted,
is_think_mode=args.is_think_mode,
)
eval_dataloader = prepare_dp_dataloaders(
eval_eagle3_dataset,
30 changes: 23 additions & 7 deletions specforge/data/parse.py
@@ -38,15 +38,25 @@ def parse(

class GeneralParser(Parser):

def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
chat_template: ChatTemplate,
is_think_mode: bool = False,
):
super().__init__(tokenizer, chat_template)
self.system_prompt = chat_template.system_prompt
self.is_think_mode = is_think_mode
self.chat_template = chat_template
self.user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
self.assistant_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
)
if is_think_mode:
self.assistant_message_separator = f"{chat_template.end_of_turn_token}{chat_template.assistant_think_header}"
else:
self.assistant_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
)
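
With the qwen3 template registered in template.py below, the two branches above yield different separators. A minimal sketch with the template strings inlined (values copied from this PR's template.py):

    end_of_turn_token = "<|im_end|>\n"
    # non-think header pre-emits an empty think block
    assistant_header = "<|im_start|>assistant\n<think>\n\n</think>\n\n"
    # think header stops at the assistant tag so the model writes its own <think> span
    assistant_think_header = "<|im_start|>assistant\n"

    # is_think_mode=False -> "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    plain_separator = f"{end_of_turn_token}{assistant_header}"
    # is_think_mode=True  -> "<|im_end|>\n<|im_start|>assistant\n"
    think_separator = f"{end_of_turn_token}{assistant_think_header}"
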

def parse(
self, conversation: "Conversation", max_length: int, preformatted: bool = False
@@ -66,7 +76,7 @@ def parse(
messages.append({"role": "system", "content": self.system_prompt})

convroles = ["user", "assistant"]
for j, sentence in enumerate(conversation):
for j, sentence in enumerate(conversation[:-1]):
role = sentence["role"]
if role != convroles[j % 2]:
warnings.warn(
@@ -75,10 +85,16 @@ def parse(
break
messages.append({"role": role, "content": sentence["content"]})

conversation = self.tokenizer.apply_chat_template(
conversation_ = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
add_generation_prompt=True,
enable_thinking=self.is_think_mode,
)
conversation = (
conversation_
+ conversation[-1]["content"]
+ self.chat_template.end_of_turn_token
)

if not self.tokenizer.pad_token_id:
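
The reworked parse() renders every message except the final assistant reply through the tokenizer's own chat template with add_generation_prompt=True, so enable_thinking is honored exactly as it would be at inference time, and then splices the gold reply plus the end-of-turn token on by hand. A toy sketch of that assembly with apply_chat_template stubbed out for a qwen3-style template (the stub is illustrative, not the real tokenizer):

    # Stub standing in for tokenizer.apply_chat_template(..., tokenize=False).
    def fake_apply_chat_template(messages, add_generation_prompt, enable_thinking):
        rendered = "".join(
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
        )
        if add_generation_prompt:
            # think mode leaves <think> open; non-think pre-fills an empty block
            rendered += ("<|im_start|>assistant\n" if enable_thinking
                         else "<|im_start|>assistant\n<think>\n\n</think>\n\n")
        return rendered

    conversation = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
    prefix = fake_apply_chat_template(
        conversation[:-1], add_generation_prompt=True, enable_thinking=False
    )
    text = prefix + conversation[-1]["content"] + "<|im_end|>\n"
    # text == "<|im_start|>user\nHi<|im_end|>\n"
    #         "<|im_start|>assistant\n<think>\n\n</think>\n\nHello!<|im_end|>\n"

This is also why the loop above iterates over conversation[:-1] and why preprocessing.py now skips odd-length conversations: the final message must be the assistant reply that gets appended after the generation prompt.
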
21 changes: 16 additions & 5 deletions specforge/data/preprocessing.py
@@ -57,6 +57,7 @@ def _apply_loss_mask_from_chat_template(
text: str,
offsets: torch.Tensor,
chat_template: ChatTemplate,
is_think_mode: bool = False,
) -> torch.Tensor:
"""
Apply loss mask to identify assistant response spans using chat template.
@@ -74,9 +75,14 @@
user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
assistant_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
)
if is_think_mode:
assistant_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.assistant_think_header}"
)
else:
assistant_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
)
Comment on lines +83 to +85

Collaborator:

    assistant_message_separator = (
        f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
        f"{chat_template.end_of_turn_token}"
    )

Any reason for removing f"{chat_template.end_of_turn_token}"?

Contributor (author):

I adopted the suggested code from the Gemini review; I have added chat_template.end_of_turn_token back, restoring the original assistant_message_separator.

# Find spans of assistant responses using regex
assistant_pattern = (
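
The pattern itself is elided in this view; a hypothetical reconstruction of the idea (the names and regex below are mine, not the file's): everything between an assistant separator and the next user separator, i.e. the gold reply plus its end-of-turn token, is what the loss mask keeps.

    import re

    text = (
        "<|im_start|>user\nHi<|im_end|>\n"
        "<|im_start|>assistant\nHello!<|im_end|>\n"
    )
    assistant_sep = "<|im_end|>\n<|im_start|>assistant\n"
    user_sep = "<|im_end|>\n<|im_start|>user\n"

    # Lazily capture everything after an assistant separator up to the next
    # user separator, or the absolute end of the text (\Z).
    pattern = re.escape(assistant_sep) + r"(.*?)(?=" + re.escape(user_sep) + r"|\Z)"
    for m in re.finditer(pattern, text, flags=re.DOTALL):
        print(m.span(1), repr(m.group(1)))  # captures "Hello!<|im_end|>\n"
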
@@ -116,6 +122,7 @@ def preprocess_conversations(
chat_template: ChatTemplate,
max_length: int = 2048,
is_preformatted: bool = False,
is_think_mode: bool = False,
) -> Dict[str, List[torch.Tensor]]:
"""
Preprocess a batch of ShareGPT style conversations or pre-formatted text.
@@ -139,14 +146,14 @@
results = {"input_ids": [], "loss_mask": [], "attention_mask": []}

if chat_template.parser_type == "general":
parser = GeneralParser(tokenizer, chat_template)
parser = GeneralParser(tokenizer, chat_template, is_think_mode)
elif chat_template.parser_type == "openai-harmony":
parser = HarmonyParser(tokenizer, chat_template)
else:
raise ValueError(f"Invalid parser type: {chat_template.parser_type}")

for source in conversations:
if not source:
if not source or len(source) % 2 != 0:
            # skip empty sources and conversations without complete user/assistant pairs
continue
input_ids, loss_mask = parser.parse(
@@ -286,6 +293,7 @@ def build_eagle3_dataset(
is_vlm: Optional[bool] = False,
processor: Optional[ImageProcessingMixin] = None,
is_preformatted: Optional[bool] = False,
is_think_mode: Optional[bool] = False,
) -> HFDataset:
"""
build eagle3 dataset
@@ -311,6 +319,7 @@
the assistant spans for loss mask generation.
If True, expects "text" column with ready-to-train text.
If False, expects "conversations" column with ShareGPT format.
is_think_mode: Whether to enable think mode in the chat template processing.

Returns:
The processed HF dataset.
@@ -352,6 +361,7 @@ def preprocess_function(examples):
template,
max_length,
is_preformatted=True,
is_think_mode=is_think_mode,
)
else:
# Handle ShareGPT conversations
@@ -365,6 +375,7 @@
template,
max_length,
is_preformatted=False,
is_think_mode=is_think_mode,
)

return processed
35 changes: 35 additions & 0 deletions specforge/data/template.py
@@ -10,12 +10,14 @@ class ChatTemplate(BaseModel):

Args:
assistant_header(str): The header for the assistant.
assistant_think_header(str): The header for the assistant in the think mode.
user_header(str): The header for the user.
system_prompt(str): The system prompt.
end_of_turn_token(str): The end token of a turn of conversation.
"""

assistant_header: str | None
assistant_think_header: str | None
user_header: str | None
system_prompt: str | None
end_of_turn_token: str | None
@@ -34,6 +36,7 @@ class TemplateRegistry:
name="custom",
template=ChatTemplate(
assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
assistant_think_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
user_header="<|start_header_id|>user<|end_header_id|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|eot_id|>"
@@ -89,6 +92,7 @@ def get_all_template_names(self) -> List[str]:
name="llama3",
template=ChatTemplate(
assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
assistant_think_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
user_header="<|start_header_id|>user<|end_header_id|>",
system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.",
end_of_turn_token="<|eot_id|>",
@@ -99,6 +103,7 @@ def get_all_template_names(self) -> List[str]:
name="llama4",
template=ChatTemplate(
assistant_header="<|header_start|>assistant<|header_end|>\n\n",
assistant_think_header="<|header_start|>assistant<|header_end|>\n\n",
user_header="<|header_start|>user<|header_end|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|eot|>",
@@ -109,16 +114,41 @@ def get_all_template_names(self) -> List[str]:
name="qwen",
template=ChatTemplate(
assistant_header="<|im_start|>assistant\n",
assistant_think_header="<|im_start|>assistant\n",
user_header="<|im_start|>user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>\n",
),
)

TEMPLATE_REGISTRY.register(
name="qwen3",
template=ChatTemplate(
assistant_header="<|im_start|>assistant\n<think>\n\n</think>\n\n",
assistant_think_header="<|im_start|>assistant\n",
user_header="<|im_start|>user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>\n",
),
)

TEMPLATE_REGISTRY.register(
name="qwq",
template=ChatTemplate(
assistant_header="<|im_start|>assistant\n<think>\n",
assistant_think_header="<|im_start|>assistant\n<think>\n",
user_header="<|im_start|>user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>\n",
),
)
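
With assistant_think_header now a first-class ChatTemplate field, adding a think-aware template for another model family follows the same registration pattern. A hypothetical entry (every header string below is a placeholder, not taken from this PR):

    TEMPLATE_REGISTRY.register(
        name="my-think-model",  # hypothetical
        template=ChatTemplate(
            # non-think mode: pre-fill an empty think block, as qwen3 does above
            assistant_header="<|assistant|>\n<think>\n\n</think>\n\n",
            # think mode: stop at the header so the model opens <think> itself
            assistant_think_header="<|assistant|>\n",
            user_header="<|user|>\n",
            system_prompt="You are a helpful assistant.",
            end_of_turn_token="<|end|>\n",
        ),
    )
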


TEMPLATE_REGISTRY.register(
name="qwen2-vl",
template=ChatTemplate(
assistant_header="<|im_start|>assistant\n",
assistant_think_header="<|im_start|>assistant\n",
user_header="<|im_start|>user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>\n",
@@ -129,6 +159,7 @@ def get_all_template_names(self) -> List[str]:
name="deepseek",
template=ChatTemplate(
assistant_header="Assistant:",
assistant_think_header="Assistant:",
user_header="User:",
system_prompt="You are a helpful assistant.",
end_of_turn_token="",
@@ -139,6 +170,7 @@ def get_all_template_names(self) -> List[str]:
name="phi3",
template=ChatTemplate(
assistant_header="<|assistant|>\n",
assistant_think_header="<|assistant|>\n",
user_header="<|user|>\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|end|>\n",
@@ -149,6 +181,7 @@ def get_all_template_names(self) -> List[str]:
name="phi4",
template=ChatTemplate(
assistant_header="<|im_start|>assistant<|im_sep|>",
assistant_think_header="<|im_start|>assistant<|im_sep|>",
user_header="<|im_start|>user<|im_sep|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|im_end|>",
@@ -159,6 +192,7 @@ def get_all_template_names(self) -> List[str]:
name="phi4-mini",
template=ChatTemplate(
assistant_header="<|assistant|>",
assistant_think_header="<|assistant|>",
user_header="<|user|>",
system_prompt="You are a helpful assistant.",
end_of_turn_token="<|end|>",
@@ -170,6 +204,7 @@ def get_all_template_names(self) -> List[str]:
name="gpt-oss",
template=ChatTemplate(
assistant_header=None, # the headers are not applicable to openai-harmony's channel tags
assistant_think_header=None,
user_header=None,
system_prompt=None,
end_of_turn_token=None,