diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py
new file mode 100644
index 0000000000..14c3600dde
--- /dev/null
+++ b/torchtitan/experiments/gpt_oss/__init__.py
@@ -0,0 +1,53 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers
+
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+
+from .infra.parallelize import parallelize_gptoss
+from .model.args import GptOssModelArgs
+from .model.model import GptOssModel
+
+__all__ = [
+    "parallelize_gptoss",
+    "GptOssModelArgs",
+    "GptOssModel",
+    "gptoss_configs",
+]
+
+
+gptoss_configs = {
+    "debugmodel": GptOssModelArgs(
+        hidden_size=256,
+        num_hidden_layers=4,
+    ),
+    "20B": GptOssModelArgs(
+        num_hidden_layers=24,
+        num_local_experts=32,
+    ),
+    "120B": GptOssModelArgs(
+        num_hidden_layers=36,
+        num_local_experts=128,
+    ),
+}
+
+
+register_train_spec(
+    TrainSpec(
+        name="gpt_oss",
+        cls=GptOssModel,
+        config=gptoss_configs,
+        parallelize_fn=parallelize_gptoss,
+        pipelining_fn=None,
+        build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py
new file mode 100644
index 0000000000..e91441092f
--- /dev/null
+++ b/torchtitan/experiments/gpt_oss/model/args.py
@@ -0,0 +1,141 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+from typing import Literal
+
+from torch import nn
+
+from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.config_manager import JobConfig
+from torchtitan.protocols.train_spec import BaseModelArgs
+from torchtitan.tools.logging import logger
+
+# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION
+
+
+# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+@dataclass
+class GptOssModelArgs(BaseModelArgs):
+    """
+    Data class for defining model arguments and hyperparameters.
+
+    Attributes:
+        max_batch_size (int): Maximum batch size.
+        max_seq_len (int): Maximum sequence length.
+        dtype (Literal["bf16", "fp8"]): Data type for computations.
+        vocab_size (int): Vocabulary size.
+        hidden_size (int): Model (hidden) dimension.
+        num_hidden_layers (int): Number of transformer layers.
+        norm_eps (float): Epsilon used by RMSNorm.
+        num_local_experts (int): Number of routed experts in each MoE layer.
+        num_experts_per_tok (int): Number of experts activated per token.
+        use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
+        head_dim (int): Dimension of each attention head.
+        num_attention_heads (int): Number of query heads.
+        num_key_value_heads (int): Number of key/value heads for grouped-query attention.
+        sliding_window (int): Sliding-window size used on alternating layers.
+        use_flex_attn (bool): Whether to use FlexAttention.
+        attn_mask_type (str): Attention mask type.
+        original_seq_len (int): Original (pre-training) sequence length used for YaRN scaling.
+        rope_theta (float): Base for rotary positional encoding.
+        rope_factor (float): Scaling factor for extended sequence lengths.
+        beta_fast (int): Fast beta correction factor.
+        beta_slow (int): Slow beta correction factor.
+    """
+
+    max_batch_size: int = 8
+    max_seq_len: int = 131072
+    dtype: Literal["bf16", "fp8"] = "bf16"
+    vocab_size: int = 201088
+    hidden_size: int = 2880
+    num_hidden_layers: int = 24
+    norm_eps: float = 1e-5  # eps used for RMSNorm
+    # MoE
+    num_local_experts: int = 32
+    num_experts_per_tok: int = 4
+    use_grouped_mm: bool = True
+    # Attention (GQA with sinks)
+    head_dim: int = 64
+    num_attention_heads: int = 64
+    num_key_value_heads: int = 8
+    sliding_window: int = 128
+    use_flex_attn: bool = True
+    attn_mask_type: str = "causal"
+    # yarn
+    original_seq_len: int = 4096
+    rope_theta: float = 150000.0
+    rope_factor: float = 32
+    beta_fast: int = 32
+    beta_slow: int = 1
+
+    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
+        """
+        Update the model config from the given job config.
+        """
+        # self.vocab_size = tokenizer.vocab_size  # TODO: add tiktokenizer support?
+        self.max_seq_len = job_config.training.seq_len
+
+    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
+        """
+        Adapted from the llama4 implementation.
+        """
+        nparams_embedding = 0
+        nparams_moe_router = 0
+        nparams_shared_expert = 0
+        nparams_experts = 0
+        nparams_dense = 0
+
+        for name, p in model.named_parameters():
+            if "embedding" in name:
+                nparams_embedding += p.numel()
+                nparams_dense += p.numel()
+            elif "moe.shared_expert" in name:
+                nparams_shared_expert += p.numel()
+            elif "moe.router" in name:
+                nparams_moe_router += p.numel()
+            elif "moe.experts" in name:
+                nparams_experts += p.numel()
+            else:
+                nparams_dense += p.numel()
+
+        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams = nparams_dense + nparams_sparse
+        nparams_sparse_active = (
+            nparams_moe_router
+            + nparams_shared_expert
+            + nparams_experts * self.num_experts_per_tok // self.num_local_experts
+        )
+
+        logger.info(
+            f"Total parameter count: dense {nparams_dense:,}, "
+            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
+        )
+
+        l, h, q, t = (
+            self.num_hidden_layers,
+            self.num_attention_heads,
+            self.hidden_size // self.num_attention_heads,
+            seq_len,
+        )
+        # Reasoning behind the factor of 12 for the self-attention part of the formula:
+        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
+        # 2. the flash attention does 1 more matmul recomputation in the backward
+        #    but recomputation should not be counted in calculating MFU (+0)
+        # 3. each matmul performs 1 multiplication and 1 addition (*2)
+        # 4.
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py new file mode 100644 index 0000000000..2fea0bf2c8 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -0,0 +1,473 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +from torch import nn +from torchtitan.models.attention import build_attention +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import GptOssModelArgs +from .moe import MoE + +# TODO: may be able to remove this once parallelized properly +def convert_submodules_to_bf16( + module: nn.Module, + exclude_names: tuple[str, ...] = ("freqs_cis", "attention_norm", "ffn_norm", "norm"), + attr_opt_out: str = "no_bf16", # if a submodule sets `self.no_bf16 = True`, it will be skipped + ) -> None: + """ + Recursively convert parameters & buffers of submodules to bfloat16, + except: + - modules whose *qualified name* ends with any of `exclude_names` + - modules with attribute `{attr_opt_out} == True` + Conversion is *shallow per-module* so exclusions are respected even deep in the tree. + """ + + def should_skip(qname: str, mod: nn.Module) -> bool: + base = qname.rsplit(".", 1)[-1] # local (leaf) name + if base in exclude_names: + return True + if getattr(mod, attr_opt_out, False): + return True + return False + + def convert_shallow(mod: nn.Module): + # convert parameters owned by this module + for _, p in mod.named_parameters(recurse=False): + if p.is_floating_point(): + p.data = p.data.to(torch.bfloat16) + # convert buffers owned by this module + for _, b in mod.named_buffers(recurse=False): + # keep non-float buffers (e.g., ints, bool masks) as-is + if torch.is_floating_point(b): + b.data = b.data.to(torch.bfloat16) + + # walk the module tree; convert only *this* module's tensors if not skipped + for qname, mod in module.named_modules(): + # skip the root container name (empty) check gracefully + local_name = qname.rsplit(".", 1)[-1] if qname else "" + if local_name and should_skip(qname, mod): + continue + convert_shallow(mod) + +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (GptOssModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + original_seq_len = args.original_seq_len + + # YaRN default m-scale (attention_factor). Matches HF when attention_factor is None. + mscale = 0.1 * math.log(factor) + 1.0 + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. 
+ dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> Tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. + if seqlen > original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + # Create position indices + t = torch.arange(seqlen) + + # Outer product: [positions] Ɨ [frequencies] + freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) + freqs_cis = torch.polar(torch.full_like(freqs, fill_value=mscale), freqs) + + return freqs_cis + + +def apply_rotary_emb_inner(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + +def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): + """ + HF-style inputs (half-split last dim) -> interleave -> Torchtitan complex RoPE -> de-interleave. + Shapes: + q, k: [B, T, H, D] with D even (HF half-split: first D/2 real, last D/2 imag) + freqs_cis: complex, last dim == D/2. Typically [T, D/2] or [1, T, D/2]. 
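+    Illustration (for D = 4): the half-split layout [r0, r1, i0, i1] (first half
+    real, second half imaginary components) is interleaved to [r0, i0, r1, i1]
+    before the complex rotation, then split back afterwards.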
+ Returns: + q_out, k_out in HF half-split layout (same shape as q, k). + """ + B, T, H, D = q.shape + assert D % 2 == 0, "head_dim must be even for RoPE" + rot = D // 2 + assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2" + freqs_cis = freqs_cis[:T, :] + + # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...) + # q_i, k_i: [B, T, H, D] + q_i = torch.empty_like(q) + k_i = torch.empty_like(k) + q_i[..., 0::2] = q[..., :rot] + q_i[..., 1::2] = q[..., rot:] + k_i[..., 0::2] = k[..., :rot] + k_i[..., 1::2] = k[..., rot:] + + # --- Torchtitan default complex apply (expects interleaved last dim) + # freqs_cis will be reshaped inside to [1, T, 1, rot] + q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis) # uses TT's complex path + k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis) + + # --- inline: interleaved -> HF half-split + q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1) + k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1) + return q_out, k_out + +# Torch Attention backup implementation (for debugging and sampling) from HuggingFace +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def eager_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + attention_mask: torch.Tensor, + scaling: float, + dropout: float = 0.0, + num_key_value_groups: int = 1, + **kwargs, +): + key_states = repeat_kv(key, num_key_value_groups) + value_states = repeat_kv(value, num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk] + # Convert boolean "allowed" -> additive mask + if attention_mask.dtype == torch.bool: + m = attention_mask + add_mask = torch.zeros_like(m, dtype=attn_weights.dtype) + add_mask = add_mask.masked_fill(~m, -float("inf")) + else: + add_mask = attention_mask.to(attn_weights.dtype) + + # Truncate to current key length and add (broadcasts if needed) + add_mask = add_mask[..., : key_states.shape[-2]] + attn_weights = attn_weights + add_mask + + sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = nn.functional.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value_states) + return attn_output + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. 
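+    Grouped-query attention with learned per-head sink logits and an optional
+    sliding window (enabled on alternating layers); rotary embeddings are applied
+    to queries and keys via ``apply_rotary_emb``.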
+ """ + + def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = False): + super().__init__() + + self.sliding_window = model_args.sliding_window if use_sliding_attention else None + self.head_dim = model_args.head_dim + + self.wq = nn.Linear( + model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True + ) + self.wk = nn.Linear( + model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + ) + self.wv = nn.Linear( + model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + ) + self.wo = nn.Linear( + model_args.num_attention_heads * model_args.head_dim, model_args.hidden_size, bias=True + ) + self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads)) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.attn = build_attention(True, model_args.attn_mask_type) + else: + # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed + self.attn = eager_attention_forward + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + hidden_shape = (bsz, seqlen, -1, self.head_dim) + + q = self.wq(x).view(hidden_shape) + k = self.wk(x).view(hidden_shape) + v = self.wv(x).view(hidden_shape) + + q, k = apply_rotary_emb(q, k, freqs_cis) + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + if self.use_flex_attn: + output = self.attn(q, k, v, self.sinks, sliding_window=self.sliding_window, enable_gqa=True) + else: + output = self.attn( + q, k, v, self.sinks, + attention_mask=self.sliding_window_causal(seqlen, x.device), + scaling=self.head_dim**-0.5, + dropout=0.0, + num_key_value_groups=8, + ) + output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) + + # Reshape and project output + output = output.reshape(bsz, seqlen, -1).contiguous() # (bsz, seqlen, n_heads * v_head_dim) + output = self.wo(output) # (bsz, seqlen, dim) + return output + + def init_weights(self, init_std: float): + linear_list = [ + self.wq, + self.wk, + self.wv, + ] + + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + # TODO: statically init the mask using train.seq_len + def sliding_window_causal(self, seqlen, device): + i = torch.arange(seqlen, device=device) + q_idx = i[:, None] + kv_idx = i[None, :] + + causal_mask = q_idx >= kv_idx + if self.sliding_window is None: + return causal_mask + window_mask = q_idx - kv_idx <= self.sliding_window + return causal_mask & window_mask + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. 
+ """ + + def __init__(self, layer_id: int, model_args: GptOssModelArgs): + + super().__init__() + use_sliding_attention = layer_id % 2 == 0 + self.attention = Attention(model_args, use_sliding_attention=use_sliding_attention) + self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + + self.moe = MoE(model_args) + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.moe(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.moe.init_weights(self.weight_init_std, buffer_device) + + +class GptOssModel(nn.Module, ModelProtocol): + """ + GPT-OSS Transformer model with attention and feed-forward layers. + """ + + def __init__(self, model_args: GptOssModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=True + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.num_hidden_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16) + convert_submodules_to_bf16(self.layers[str(layer_id)]) + + self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.output = nn.Linear( + model_args.hidden_size, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + self.init_weights() + convert_submodules_to_bf16(self) + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.hidden_size**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def forward(self, tokens: torch.Tensor): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
+ """ + h = self.tok_embeddings(tokens) + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + h = self.norm(h) + output = self.output(h) + return output diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py new file mode 100644 index 0000000000..1bbd7a838a --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -0,0 +1,280 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn +from torchtitan.experiments.llama4.infra.expert_parallel import expert_parallel + +from .args import GptOssModelArgs + +def swiglu(x, alpha: float = 1.702, limit: float = 7.0): + x_glu, x_linear = x[..., ::2], x[..., 1::2] + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + num_experts: int, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.use_grouped_mm = use_grouped_mm + + self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2))) + self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2))) + self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim))) + self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.use_grouped_mm: + return GroupedExperts._run_experts_grouped_mm( + self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert + ) + else: + return GroupedExperts._run_experts_for_loop( + self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + # @expert_parallel + @staticmethod + def _run_experts_for_loop( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = torch.matmul(x_expert, mlp1_weight[expert_idx]) + mlp1_bias[expert_idx] + h = swiglu(h) + h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx] + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = torch.bmm(x, mlp1_weight) + mlp1_bias.unsqueeze(1) + h = swiglu(h) + out = torch.bmm(h, mlp2_weight) + mlp2_bias.unsqueeze(1) 
+ + return out + + # @expert_parallel # TODO: e-sharding currently breaks shapes + @staticmethod + def _run_experts_grouped_mm( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) + + h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) + h += mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + h = swiglu(h) + h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) + h += mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + + return h + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) + + def extra_repr(self): + return (f"num_experts={self.num_experts}, " + f"use_grouped_mm={self.use_grouped_mm}, " + f"mlp1_weight={tuple(self.mlp1_weight.shape)}, " + f"mlp1_bias={tuple(self.mlp1_bias.shape)}, " + f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " + f"mlp2_bias={tuple(self.mlp2_bias.shape)}") + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + dim (int): Dimension of the input. + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + ): + super().__init__() + + self.dim = dim + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(self.dim, self.num_experts, bias=True) + + def forward( + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + TODO: We haven't implement the group-based routing (node limit routing), + and currently EP is not supporting node limit routing yet. + + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. 
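+        Note:
+            The first returned tensor holds the softmaxed top-k router scores,
+            reordered to match the returned token indices, which point into the
+            flattened ``bs*slen`` token dimension.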
+ """ + # scores shape (bs*slen, num_experts) + router_logits = self.gate(x) + + # top scores shape (bs*slen, top_k) + top_scores, selected_experts_indices = torch.topk( + router_logits, k=self.top_k, dim=1 + ) + + top_scores = F.softmax(top_scores, dim=1) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + # Reorder the token indices to match the order of the experts + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + + # reorder the scores to match the order of the token indices + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + return top_scores, token_indices_experts_sorted, num_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +class MoE(nn.Module): + def __init__(self, model_args: GptOssModelArgs): + + super().__init__() + dim = model_args.hidden_size + + num_experts = model_args.num_local_experts + top_k = model_args.num_experts_per_tok + + self.experts = GroupedExperts( + dim=dim, + num_experts=num_experts, + use_grouped_mm=model_args.use_grouped_mm, + ) + self.router = TokenChoiceTopKRouter( + dim=dim, + num_experts=num_experts, + top_k=top_k, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. + """ + bs, slen, dim = x.shape + + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim)) + + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) + + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + # Accumulate multiple expert results becase each token can be routed to multiple experts + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights( + self, + init_std: float, + buffer_device: torch.device, + ): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) diff --git a/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py new file mode 100644 index 0000000000..dbbb880af5 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py @@ -0,0 +1,405 @@ +""" +Compare logits and generations of GPT-OSS implemented in TorchTitan and HuggingFace. +This requires at least a 2xH100. 
+ +First ensure you convert the HF model to a TorchTitan DCP checkpoint: +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ + +Then you can run a comparison like this: +uv run torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py \ + --tt_config torchtitan/models/gpt_oss/train_configs/gpt_oss_20b.toml \ + --tt_checkpoint_path gptoss_dcp/ \ + --hf_model_path openai/gpt-oss-20b \ + --prompt "Once upon a time, in a land far away," \ + --temperature 0.8 \ + --max_new_tokens 256 \ + --batch_size 1 \ + --out +""" + +import json +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Sequence, Tuple, NamedTuple + +import torch +import torch.nn as nn +import torch.distributed.checkpoint as dcp +import tyro +from transformers import AutoModelForCausalLM, AutoTokenizer + +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.utils import device_module, device_type +from torchtitan.components.metrics import build_device_memory_monitor +from torchtitan.config_manager import ConfigManager +from torchtitan.protocols.train_spec import get_train_spec +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torch.distributed import DeviceMesh +from torch.distributed.elastic.multiprocessing.errors import record + +# -------- Torchtitan Sampling Utils -------- +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + logits = model(x) # (B, T, vocab_size) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def tt_generate_text( + model, + input_ids: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for i in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens.to(input_ids.device), + temperature=temperature, + top_k=top_k, + rng=rng, + ) + print(f"generated token {i}: {next_token}") + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +@dataclass +class GenerateConfig: + """Configuration for test generation.""" + hf_model_path: Optional[str] = None + """HuggingFace model path to load (if provided).""" + tt_config: Optional[str] = None + """TOML config file path for TorchTitan 
model.""" + tt_checkpoint_path: Optional[str] = None + """Checkpoint path for the TorchTitan model (if provided).""" + tt_tokenizer_path: Optional[str] = "libs/torchtitan/torchtitan/models/gpt_oss_20b/tokenizer" + """Tokenizer path to load.""" + temperature: float = 1.0 + """Sampling temperature (0 for greedy).""" + max_new_tokens: int = 32 + """Max number of tokens to generate.""" + batch_size: int = 1 + """Batch size for inputs.""" + top_k: Optional[int] = None + """Top-k sampling (optional).""" + seed: Optional[int] = None + """Random seed for reproducibility.""" + deterministic: bool = False + """Use deterministic algorithms.""" + prompt: str = "" + """Input prompt string.""" + out: bool = False + """If true, print JSON report at end.""" + + +class LogitsComparison(NamedTuple): + max_abs_diff: float + mean_abs_diff: float + max_rel_diff: float + mean_rel_diff: float + allclose_results: Sequence[Tuple[float, float, str, bool]] + sample_diffs: Optional[torch.Tensor] + systematic_offset: Optional[Tuple[float, float]] + + +def load_hf_model(path: str, device: torch.device) -> nn.Module: + model = AutoModelForCausalLM.from_pretrained(path).to(device) + model.eval() + return model + +def print_param_dtypes_first_block(model): + """ + Prints the dtype of every parameter in the given model. + For any parameters under a 'layers' module (e.g., layers.), + only prints those from the first block (idx == "0"). + This works for both GptOssForCausalLM (with a .model submodule) + and GptOssModel architectures. + """ + for name, param in model.named_parameters(): + parts = name.split('.') + # If this parameter is under a 'layers' module, check its index + if 'layers' in parts: + idx = parts.index('layers') + 1 + if idx < len(parts) and parts[idx] != '0': + continue + print(f"{name:50s} → {param.dtype}") + +def get_logits(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + out = model(input_ids) + if hasattr(out, "logits"): + return out.logits + else: + return out + + +def compare_logits( + tt_logits: torch.Tensor, + hf_logits: torch.Tensor, + tolerances: Sequence[Tuple[float, float, str]] = ( + (1e-4, 1e-6, "Very Strict"), + (1e-2, 1e-4, "Strict"), + (1e-1, 1e-2, "Moderate"), + ), +) -> LogitsComparison: + # Apply softmax to convert logits to probabilities + hf_logits = torch.nn.functional.softmax(hf_logits.float(), dim=-1) + tt_logits = torch.nn.functional.softmax(tt_logits.float(), dim=-1) + + diff = torch.abs(tt_logits - hf_logits) + max_abs = float(torch.max(diff)) + mean_abs = float(torch.mean(diff)) + rel = diff / (torch.abs(tt_logits) + 1e-8) + max_rel = float(torch.max(rel)) + mean_rel = float(torch.mean(rel)) + + results = [] + any_match = False + for rtol, atol, name in tolerances: + match = torch.allclose(tt_logits, hf_logits, rtol=rtol, atol=atol) + results.append((rtol, atol, name, bool(match))) + if match: + any_match = True + break + + sample_diffs = None + sys_offset = None + if not any_match: + flat = (tt_logits - hf_logits).flatten() + sample_diffs = flat[:25] + sys_offset = (float(torch.mean(flat)), float(torch.std(flat))) + + return LogitsComparison(max_abs, mean_abs, max_rel, mean_rel, results, sample_diffs, sys_offset) + + +def generate_text( + model: nn.Module, + input_ids: torch.Tensor, + max_new_tokens: int, + temperature: float = 0.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + do_sample = temperature > 0 + temp_arg = temperature if do_sample else None + with torch.no_grad(): + return model.generate( + input_ids, + 
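+            # temperature == 0 selects greedy decoding: do_sample=False and the
+            # temperature argument is passed as None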
max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temp_arg, + top_k=top_k, + ) + + +def print_logits_comparison(comp: LogitsComparison): + print("\n" + "="*70) + print("LOGITS COMPARISON") + print("="*70) + print(f"Max abs diff: {comp.max_abs_diff:.6f}") + print(f"Mean abs diff: {comp.mean_abs_diff:.6f}") + print(f"Max rel diff: {comp.max_rel_diff:.6f}") + print(f"Mean rel diff: {comp.mean_rel_diff:.6f}\n") + print("Tolerance tests:") + for rtol, atol, name, match in comp.allclose_results: + print(f" {'āœ…' if match else 'āŒ'} {name} (rtol={rtol}, atol={atol})") + if comp.sample_diffs is not None: + print("\nšŸ” Sample diffs (first 25):") + for v in comp.sample_diffs.tolist(): + print(f" {v:.6f}") + mean, std = comp.systematic_offset + print(f"\nSystematic offset: mean={mean:.6f}, std={std:.6f}") + + +def print_generation(title: str, outputs: torch.Tensor, tokenizer): + text = tokenizer.decode(outputs[0].tolist()) + print("\n" + "="*60) + print(title) + print("="*60) + print(text) + print("="*60) + + +def print_generation_comparison( + tt_out: torch.Tensor, + hf_out: torch.Tensor, + tokenizer, + prompt_len: int, +): + tt_tokens = tt_out[0][prompt_len:].tolist() + hf_tokens = hf_out[0][prompt_len:].tolist() + n = min(len(tt_tokens), len(hf_tokens)) + matches = sum(1 for i in range(n) if tt_tokens[i] == hf_tokens[i]) + print("\n" + "="*70) + print("GENERATION COMPARISON") + print("="*70) + print(f"Match rate: {matches}/{n} ({matches/n*100:.1f}%)") + if matches != n or len(tt_tokens) != len(hf_tokens): + print("First mismatches:") + for i in range(min(10, n)): + if tt_tokens[i] != hf_tokens[i]: + tt_txt = tokenizer.decode([tt_tokens[i]]) + hf_txt = tokenizer.decode([hf_tokens[i]]) + print(f" Pos {i}: TT='{tt_txt}' vs HF='{hf_txt}'") + + +@record +def test_generate(args: GenerateConfig): + init_logger() + + if not args.hf_model_path and not args.tt_config: + raise ValueError("Either hf_model_path or tt_config must be provided.") + if not args.prompt: + logger.warning("Empty prompt; generating from scratch.") + + # --- Common setup: tokenizer & inputs --- + if args.hf_model_path: + tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) + input_ids = tokenizer.encode(args.prompt, add_special_tokens=False, return_tensors="pt") + print(input_ids) + if args.tt_config: + config_mgr = ConfigManager() + config = config_mgr.parse_args([ + f"--job.config_file={args.tt_config}", + f"--model.tokenizer_path={args.tt_tokenizer_path}", + ]) + train_spec = get_train_spec(config.model.name) + + # --- HuggingFace model (optional) --- + hf_model = None + hf_logits = None + hf_out = None + if args.hf_model_path: # NOTE: comment this block out for rapid tt testing + hf_device = torch.device(f"{device_type}:0") + hf_model = load_hf_model(args.hf_model_path, hf_device) + print("\n" + "="*60) + print("HUGGINGFACE MODEL ARCHITECTURE:") + print(hf_model) + print("="*60) + print_param_dtypes_first_block(hf_model) + print("="*60) + + hf_in = input_ids.to(hf_device) + hf_logits = get_logits(hf_model, hf_in).to(input_ids.device) + print(f"hf_logits: {hf_logits[:, :, 42069:42072]}") + hf_out = generate_text( + hf_model, hf_in, + max_new_tokens=args.max_new_tokens, + temperature=0.0, + top_k=args.top_k, + ).to(input_ids.device) + + # --- TorchTitan model (optional) --- + tt_model = None + tt_logits = None + tt_out = None + if args.tt_config: + # (Original TT setup: distributed, device, checkpoint load, etc.) 
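+        # NOTE: the TorchTitan model is placed on device index 1 while the HF model
+        # (if loaded) sits on device index 0, so both can be resident at once.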
+ world_size = int(os.environ.get("WORLD_SIZE", 1)) + device = torch.device(f"{device_type}:1") + device_module.set_device(device) + dist_utils.set_determinism(None, device, args.seed, args.deterministic) + + # instantiate & load TT model + model_args = train_spec.config[config.model.flavor] + model_args.update_from_config(config, tokenizer) + init_dev = "meta" if world_size > 1 else device + with torch.device(init_dev): + tt_model = train_spec.cls(model_args) + if world_size > 1: + # parallelize if needed + pass + print("\n" + "="*60) + print("TORCHTITAN MODEL ARCHITECTURE:") + print(tt_model) + print("="*60) + print_param_dtypes_first_block(tt_model) + print("="*60) + + tt_model.eval() + if args.tt_checkpoint_path: # only load checkpoint if provided + tt_state = tt_model.state_dict() + tt_state.pop("freqs_cis", None) + state = {"model": tt_state} + dcp.load(state, checkpoint_id=args.tt_checkpoint_path) + + tt_logits = get_logits(tt_model, input_ids.to(device)).to(hf_logits.device if hf_logits is not None else device) + print(f"āœ… Torchtitan model forward pass succeeded: {tt_logits.shape=}") + print(f"tt_logits: {tt_logits[:, :, 42069:42072]}") + + tt_out = tt_generate_text( + tt_model, input_ids.to(device), + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + seed=args.seed, + ) + + # --- Logits comparison (if both present) --- + if hf_logits is not None and tt_logits is not None: + comp = compare_logits(tt_logits, hf_logits) + print_logits_comparison(comp) + + # --- Print generations --- + if hf_out is not None: + print_generation("HUGGINGFACE MODEL OUTPUT:", hf_out, tokenizer) + if tt_out is not None: + print_generation("TORCHTITAN MODEL OUTPUT:", tt_out, tokenizer) + + # --- Generation comparison --- + if hf_out is not None and tt_out is not None: + prompt_len = input_ids.size(1) + print_generation_comparison(tt_out, hf_out, tokenizer, prompt_len) + + +if __name__ == "__main__": + args = tyro.cli(GenerateConfig) + test_generate(args) diff --git a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py new file mode 100644 index 0000000000..f69d5898d7 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py @@ -0,0 +1,513 @@ +""" +Convert checkpoints between TorchTitan and HuggingFace. + +# Convert HF to TorchTitan DCP +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ + +# Convert TorchTitan DCP to HF +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py dcp-to-hf --input-path gptoss_dcp/ --output-path gptoss_hf/ +""" + +import tempfile +from pathlib import Path +from typing import Union + +import torch +import torch.distributed.checkpoint as DCP +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaConfig +from tqdm import tqdm +from tyro.extras import SubcommandApp + +from torchtitan.tools.logging import init_logger, logger + +app = SubcommandApp() + + +def validate_config_compatibility(hf_config, torchtitan_config_name, torchtitan_configs): + """Validate that HF config is compatible with TorchTitan config.""" + if torchtitan_config_name not in torchtitan_configs: + available = list(torchtitan_configs.keys()) + raise ValueError(f"TorchTitan config '{torchtitan_config_name}' not found. 
Available: {available}") + + tt_config = torchtitan_configs[torchtitan_config_name] + + # Critical configuration checks with proper field mappings + checks = [ + ("vocab_size", "vocab_size"), + ("hidden_size", "hidden_size"), + ("num_hidden_layers", "num_hidden_layers"), + ("head_dim", "head_dim"), + ("num_attention_heads", "num_attention_heads"), + ("num_key_value_heads", "num_key_value_heads"), + ("sliding_window", "sliding_window"), + ("num_local_experts", "num_local_experts"), + ("num_experts_per_tok", "num_experts_per_tok"), + ("rope_theta", "rope_theta"), + # ("rope_scaling.factor", "rope_factor"), + # ("rope_scaling.beta_fast", "beta_fast"), + # ("rope_scaling.beta_slow", "beta_slow"), + ] + + mismatches = [] + warnings = [] + + for hf_attr, tt_attr in checks: + hf_val = getattr(hf_config, hf_attr, None) + tt_val = getattr(tt_config, tt_attr, None) + + if hf_val != tt_val: + mismatches.append(f"{hf_attr}: HF={hf_val} vs TT.{tt_attr}={tt_val}") + + if mismatches: + raise ValueError(f"Config mismatch for {torchtitan_config_name}:\n" + "\n".join(mismatches)) + + if warnings: + print(f"āš ļø Configuration warnings for {torchtitan_config_name}:") + for warning in warnings: + print(f" {warning}") + print(" These differences might affect model behavior but won't prevent conversion.") + + print(f"āœ“ Configuration validation passed for {torchtitan_config_name}") + return tt_config + +def validate_tt_keys(tt_sd, n_layers, strict=True): + """Ensure the TorchTitan dict looks like gpt-oss as encoded in hf->tt mapping.""" + top_expected = [ + "tok_embeddings.weight", + "output.weight", + "norm.weight", + ] + per_layer_expected = [ + # attention projections + biases + sinks + "attention.wq.weight", "attention.wq.bias", + "attention.wk.weight", "attention.wk.bias", + "attention.wv.weight", "attention.wv.bias", + "attention.wo.weight", "attention.wo.bias", + "attention.sinks", + # MoE experts (mlp1/2) + biases + "moe.experts.mlp1_weight", "moe.experts.mlp1_bias", + "moe.experts.mlp2_weight", "moe.experts.mlp2_bias", + # Router + "moe.router.gate.weight", "moe.router.gate.bias", + # Norms + "attention_norm.weight", "ffn_norm.weight", + ] + + missing = [] + for k in top_expected: + if k not in tt_sd: + missing.append(k) + + for i in range(n_layers): + base = f"layers.{i}." 
+ for suffix in per_layer_expected: + key = base + suffix + if key not in tt_sd: + missing.append(key) + + if missing and strict: + preview = "\n - " + "\n - ".join(missing[:20]) + more = "" if len(missing) <= 20 else f"\n ...and {len(missing)-20} more" + raise KeyError( + "TorchTitan checkpoint is missing keys required for gpt-oss inverse mapping:" + f"{preview}{more}" + ) + return missing # may be useful for logging if strict=False + +def validate_hf_keys(hf_state_dict, model_config, model_name): + """Validate that all expected weight keys exist in the HF state dict.""" + missing_keys = [] + n_layers = model_config.num_hidden_layers + + # Check basic weights + required_keys = [ + "model.embed_tokens.weight", + "lm_head.weight", + "model.norm.weight" + ] + + for key in required_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + # Check layer weights + for layer_idx in range(n_layers): + layer_prefix = f'model.layers.{layer_idx}' + + # Check attention weights + attention_keys = [ + f"{layer_prefix}.self_attn.q_proj.weight", + f"{layer_prefix}.self_attn.k_proj.weight", + f"{layer_prefix}.self_attn.v_proj.weight", + f"{layer_prefix}.self_attn.o_proj.weight", + f"{layer_prefix}.self_attn.q_proj.bias", + f"{layer_prefix}.self_attn.k_proj.bias", + f"{layer_prefix}.self_attn.v_proj.bias", + f"{layer_prefix}.self_attn.o_proj.bias", + f"{layer_prefix}.input_layernorm.weight", + f"{layer_prefix}.post_attention_layernorm.weight", + ] + + for key in attention_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + # Check MoE weights + mlp_keys = [ + f"{layer_prefix}.mlp.router.weight", + f"{layer_prefix}.mlp.router.bias", + f"{layer_prefix}.mlp.experts.gate_up_proj", + f"{layer_prefix}.mlp.experts.gate_up_proj_bias", + f"{layer_prefix}.mlp.experts.down_proj", + f"{layer_prefix}.mlp.experts.down_proj_bias", + ] + + for key in mlp_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + if missing_keys: + logger.error(f"Missing {len(missing_keys)} expected weight keys in HF model:") + for key in missing_keys[:10]: # Show first 10 + logger.error(f" - {key}") + if len(missing_keys) > 10: + logger.error(f" ... and {len(missing_keys) - 10} more") + + # Try to diagnose the issue + logger.info("Available keys in HF model:") + available_keys = list(hf_state_dict.keys()) + for key in available_keys[:20]: # Show first 20 + logger.info(f" - {key}") + if len(available_keys) > 20: + logger.info(f" ... and {len(available_keys) - 20} more") + + raise ValueError(f"HF model '{model_name}' is missing expected weight keys. " + f"This suggests the model architecture doesn't match expectations.") + + logger.info(f"āœ“ Weight key validation passed - found all expected keys") + + +def map_hf_to_torchtitan(hf_state_dict, model_config, max_seq_len=131072, rope_theta=500000.0, model_name="meta-llama/Llama-3.1-8B"): + """Map HuggingFace state dict to TorchTitan format. + + Note: TorchTitan and HuggingFace use different RoPE implementations: + - TorchTitan: Adjacent element pairing with complex arithmetic + - HuggingFace: First/second half pairing with cos/sin arithmetic + + This difference is architectural, not a bug. Converted models will have + slightly different positional encoding but typically minimal impact on performance. 
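+
+    For gpt-oss checkpoints, attention projections (and their biases), sink logits,
+    MoE expert tensors, and router weights are copied one-to-one; no head
+    permutation is applied in this path.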
+ """ + + # Validate that all expected keys exist + validate_hf_keys(hf_state_dict, model_config, model_name) + + n_layers = model_config.num_hidden_layers + n_heads = model_config.num_attention_heads + dim = model_config.hidden_size + dims_per_head = dim // n_heads + + # Fix: Corrected model family detection logic + if "llama" in model_name.lower(): + model_family = "llama3" + elif "qwen" in model_name.lower(): + model_family = "qwen3" + max_seq_len = model_config.max_position_embeddings + rope_theta = model_config.rope_theta + elif "gpt-oss" in model_name.lower(): + model_family = "gptoss" + max_seq_len = model_config.max_position_embeddings + rope_theta = model_config.rope_theta + else: + raise ValueError(f"Unsupported HuggingFace model for conversion: {model_name}") + + # Determine n_kv_heads for GQA models + n_kv_heads = model_config.num_key_value_heads + head_dim = model_config.head_dim + print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}, max_seq_len={max_seq_len}, rope_theta={rope_theta}") + torchtitan_state_dict = {} + + # Convert embeddings and output + torchtitan_state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"].clone() + torchtitan_state_dict["output.weight"] = hf_state_dict["lm_head.weight"].clone() + torchtitan_state_dict["norm.weight"] = hf_state_dict["model.norm.weight"].clone() + + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Convert layers + for layer_idx in tqdm(range(n_layers), desc="Converting layers"): + hf_layer_prefix = f'model.layers.{layer_idx}' + layer_prefix = f'layers.{layer_idx}' + + wq = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] = wq.clone() + wq_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wq.bias'] = wq_bias.clone() + + wk = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] = wk.clone() + wk_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wk.bias'] = wk_bias.clone() + + wv = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'] = wv.clone() + wv_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wv.bias'] = wv_bias.clone() + + wo = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'] = wo.clone() + wo_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wo.bias'] = wo_bias.clone() + + sinks = hf_state_dict[f'{hf_layer_prefix}.self_attn.sinks'] + torchtitan_state_dict[f'{layer_prefix}.attention.sinks'] = sinks.clone() + + # MoE weights + mlp1 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_weight'] = mlp1.clone() + + mlp1_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_bias'] = mlp1_bias.clone() + + mlp2 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj'] + 
torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_weight'] = mlp2.clone() + + mlp2_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj_bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_bias'] = mlp2_bias.clone() + + # router + gate = hf_state_dict[f'{hf_layer_prefix}.mlp.router.weight'] + torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.weight'] = gate.clone() + router_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.bias'] = router_bias.clone() + + # # @vwxyzjn: This is technically not needed, but we added here because we haven't figured out + # # how to tell torchtitan to ignore this parameter. + # tokens_per_expert = torch.zeros_like(expert_bias) + # torchtitan_state_dict[f'{layer_prefix}.moe.tokens_per_expert'] = tokens_per_expert.clone() + + # Layer norms + attention_norm = hf_state_dict[f'{hf_layer_prefix}.input_layernorm.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention_norm.weight'] = attention_norm.clone() + ffn_norm = hf_state_dict[f'{hf_layer_prefix}.post_attention_layernorm.weight'] + torchtitan_state_dict[f'{layer_prefix}.ffn_norm.weight'] = ffn_norm.clone() + + # Precompute RoPE frequencies + # NOTE: we no longer precompute RoPE frequencies in TorchTitan + # this `model_config` is HF but needs to be TT (to include e.g. beta_fast) + # torchtitan_state_dict["freqs_cis"] = precompute_freqs_cis(model_config) + + print(f"Converted {len(torchtitan_state_dict)} parameters from HuggingFace to TorchTitan format") + return torchtitan_state_dict + + +def num_layers_from_keys(state_dict): + layer_idxs = [] + pat = re.compile(r"^layers\.(\d+)\.") + for k in state_dict.keys(): + m = pat.match(k) + if m: + layer_idxs.append(int(m.group(1))) + if not layer_idxs: + raise ValueError("Could not find any 'layers..' keys in the TorchTitan state dict.") + return max(layer_idxs) + 1 + +# TODO: correctness of map_torchtitan_to_hf is not yet tested for GPT-OSS +def map_torchtitan_to_hf(torchtitan_state_dict, *, strict=True): + """ + Map TorchTitan (DCP) state dict -> HuggingFace format for *gpt-oss only*. + + This is the exact inverse of your `map_hf_to_torchtitan`: + - No weight permutations. + - Copies biases for q/k/v/o and MoE projections. + - Preserves `.attention.sinks`. + - MoE and router parameters use the same custom names you used on the HF side + (i.e., HF bias keys are `gate_up_proj_bias` / `down_proj_bias`). + + Parameters + ---------- + torchtitan_state_dict : dict[str, Tensor-like] + TorchTitan checkpoint (flat dict). + strict : bool + If True, error on any missing keys. If False, copy what exists and skip missing. + + Returns + ------- + dict[str, Tensor-like] + HuggingFace-formatted state dict. 
+ """ + tt = torchtitan_state_dict + n_layers = num_layers_from_keys(tt) + validate_tt_keys(tt, n_layers, strict=strict) + + hf = {} + + # Top-level + if "tok_embeddings.weight" in tt: hf["model.embed_tokens.weight"] = tt["tok_embeddings.weight"].clone() + if "output.weight" in tt: hf["lm_head.weight"] = tt["output.weight"].clone() + if "norm.weight" in tt: hf["model.norm.weight"] = tt["norm.weight"].clone() + + # Per-layer mappings (exact inverse of your hf->tt) + for i in range(n_layers): + tt_pref = f"layers.{i}" + hf_pref = f"model.layers.{i}" + + # Attention projections (+biases) + m = { + f"{tt_pref}.attention.wq.weight": (f"{hf_pref}.self_attn.q_proj.weight",), + f"{tt_pref}.attention.wq.bias": (f"{hf_pref}.self_attn.q_proj.bias",), + f"{tt_pref}.attention.wk.weight": (f"{hf_pref}.self_attn.k_proj.weight",), + f"{tt_pref}.attention.wk.bias": (f"{hf_pref}.self_attn.k_proj.bias",), + f"{tt_pref}.attention.wv.weight": (f"{hf_pref}.self_attn.v_proj.weight",), + f"{tt_pref}.attention.wv.bias": (f"{hf_pref}.self_attn.v_proj.bias",), + f"{tt_pref}.attention.wo.weight": (f"{hf_pref}.self_attn.o_proj.weight",), + f"{tt_pref}.attention.wo.bias": (f"{hf_pref}.self_attn.o_proj.bias",), + + # Sinks tensor + f"{tt_pref}.attention.sinks": (f"{hf_pref}.self_attn.sinks",), + + # MoE experts (your custom naming on HF side) + f"{tt_pref}.moe.experts.mlp1_weight": (f"{hf_pref}.mlp.experts.gate_up_proj",), + f"{tt_pref}.moe.experts.mlp1_bias": (f"{hf_pref}.mlp.experts.gate_up_proj_bias",), + f"{tt_pref}.moe.experts.mlp2_weight": (f"{hf_pref}.mlp.experts.down_proj",), + f"{tt_pref}.moe.experts.mlp2_bias": (f"{hf_pref}.mlp.experts.down_proj_bias",), + + # Router + f"{tt_pref}.moe.router.gate.weight": (f"{hf_pref}.mlp.router.weight",), + f"{tt_pref}.moe.router.gate.bias": (f"{hf_pref}.mlp.router.bias",), + + # Norms + f"{tt_pref}.attention_norm.weight": (f"{hf_pref}.input_layernorm.weight",), + f"{tt_pref}.ffn_norm.weight": (f"{hf_pref}.post_attention_layernorm.weight",), + } + + for tt_key, (hf_key,) in m.items(): + if tt_key in tt: + hf[hf_key] = tt[tt_key].clone() + elif strict: + raise KeyError(f"Missing expected key in TorchTitan state dict: '{tt_key}'") + + print(f"Converted {len(hf)} parameters from TorchTitan to HuggingFace format (gpt-oss).") + return hf + + +@app.command(name="hf_to_dcp") +@torch.inference_mode() +def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, dtype: str = "auto", torchtitan_config: str = "20B"): + """Convert HuggingFace model to TorchTitan DCP format. + + Args: + input_path: HuggingFace model name or path + output_path: Output DCP checkpoint path + max_seq_len: Max sequence length for RoPE + rope_theta: RoPE theta parameter + dtype: Data type to use ("auto" to preserve original, or specific dtype like "float32") + torchtitan_config: TorchTitan model config name (e.g., "16B-A3B", "debugmodel") + """ + # Import TorchTitan configs + try: + from torchtitan.models.gpt_oss import gptoss_configs + except ImportError: + raise ImportError("Cannot import TorchTitan GPT-OSS configs. 
+
+    logger.info(f"Loading model from {input_path}")
+
+    # Load model with original dtype if "auto", otherwise use the specified dtype
+    torch_dtype = "auto" if dtype == "auto" else getattr(torch, dtype)
+    hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch_dtype)
+
+    # Validate configuration compatibility
+    logger.info(f"Validating config compatibility with TorchTitan config: {torchtitan_config}")
+    validate_config_compatibility(hf_model.config, torchtitan_config, gptoss_configs)
+
+    hf_state_dict = hf_model.state_dict()
+    logger.info(f"Loaded model with dtype: {next(iter(hf_state_dict.values())).dtype}")
+
+    logger.info("Converting weights to TorchTitan format")
+    torchtitan_state_dict = map_hf_to_torchtitan(hf_state_dict, hf_model.config, max_seq_len, rope_theta, input_path)
+
+    logger.info(f"Writing to DCP at '{output_path}'")
+    output_path.mkdir(parents=True, exist_ok=True)
+    storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8)
+    DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer)
+
+    # Save metadata for reference
+    import json
+    from datetime import datetime, timezone
+
+    metadata = {
+        "original_hf_model": input_path,
+        "torchtitan_config": torchtitan_config,
+        "conversion_time": datetime.now(timezone.utc).isoformat(),
+        "hf_config": dict(hf_model.config.__dict__),
+        "torchtitan_config_dict": dict(gptoss_configs[torchtitan_config].__dict__),
+    }
+    with open(output_path / "conversion_metadata.json", "w") as f:
+        json.dump(metadata, f, indent=2, default=str)
+
+    logger.info("Conversion complete!")
+    logger.info(f"šŸ“‹ Saved conversion metadata to {output_path}/conversion_metadata.json")
+    logger.info(f"šŸš€ To use in TorchTitan, specify model config: {torchtitan_config}")
+
+    # Final reminder about RoPE differences
+    if "gpt-oss" in input_path.lower():
+        logger.info("")
+        logger.info("šŸ”” IMPORTANT: Converted GPT-OSS model uses TorchTitan's RoPE implementation")
+        logger.info("   This differs from HuggingFace but is expected behavior")
+        logger.info("   See conversion script documentation for details")
+
+
+@app.command(name="dcp_to_hf")
+@torch.inference_mode()
+def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, default_model: str = "openai/gpt-oss-20b"):
+    """Convert TorchTitan DCP format to HuggingFace model.
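+
+    Loads the DCP checkpoint, maps keys back to HuggingFace names via `map_torchtitan_to_hf`
+    (gpt-oss naming only, untested; see the TODO above), loads them into `default_model`'s
+    architecture with `strict=True`, and saves the model and tokenizer to `output_path`.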
+
+    Args:
+        input_path: Input DCP checkpoint path
+        output_path: Output HuggingFace model path
+        max_seq_len: Max sequence length for RoPE (currently unused)
+        rope_theta: RoPE theta parameter (currently unused)
+        default_model: Default HuggingFace model for config
+    """
+    from torchtitan.datasets.transformation import get_tokenizer_with_chat_template
+    from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+    logger.info(f"Loading DCP checkpoint from {input_path}")
+
+    # Load the DCP checkpoint
+    state_dict = {}
+    _load_state_dict(
+        state_dict,
+        storage_reader=DCP.filesystem.FileSystemReader(input_path),
+        planner=_EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
+    torchtitan_state_dict = state_dict["model"]
+    logger.info("Converting weights to HuggingFace format")
+    hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict)
+
+    # Create HuggingFace config
+    hf_config = AutoConfig.from_pretrained(default_model)
+
+    # Create and load model
+    logger.info("Creating HuggingFace model")
+    # tokenizer = AutoTokenizer.from_pretrained(default_model)
+    tokenizer = get_tokenizer_with_chat_template(default_model, "tulu", override=True)
+    hf_model = AutoModelForCausalLM.from_pretrained(default_model)
+
+    # Load state dict
+    logger.info("Loading state dict")
+    hf_model.load_state_dict(hf_state_dict, strict=True)
+
+    # Save model
+    logger.info(f"Saving model to {output_path}")
+    output_path.mkdir(parents=True, exist_ok=True)
+    hf_model.save_pretrained(output_path)
+    tokenizer.save_pretrained(output_path)
+    logger.info("Conversion complete!")
+
+
+if __name__ == "__main__":
+    init_logger()
+    app.cli()
diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml
new file mode 100644
index 0000000000..878e478ff5
--- /dev/null
+++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml
@@ -0,0 +1,73 @@
+# torchtitan Config.toml
+
+[job]
+dump_folder = "./outputs"
+description = "GPT-OSS debug training"
+print_args = false
+use_for_integration_test = true
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+disable_color_printing = false
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = false
+
+[model]
+name = "gpt_oss"
+flavor = "debugmodel"
+# test tokenizer, for debug purpose only
+tokenizer_path = "./tests/assets/tokenizer"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 8e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+lr_min = 0.0
+
+[training]
+local_batch_size = 8
+seq_len = 2048
+max_norm = 1.0 # grad norm clipping
+steps = 1
+compile = false
+dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default" # default / never / always
+tensor_parallel_degree = 2
+enable_async_tensor_parallel = false
+expert_parallel_degree = 1
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval = 10
+last_save_model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml new file mode 100644 index 0000000000..81908972ad --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml @@ -0,0 +1,70 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS 120B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 10 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "120B" +tokenizer_path = "./assets/tokenizer/GPT-OSS" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 2.2e-5 + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10_000 +compile = false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml new file mode 100644 index 0000000000..88d1c4d27f --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml @@ -0,0 +1,70 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS 20B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 10 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "20B" +tokenizer_path = "./assets/tokenizer/GPT-OSS" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 2.2e-5 + +[training] +local_batch_size = 8 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 1000 +compile 
= false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 570d894f51..334d3da935 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -76,15 +76,80 @@ def __init__( def mask_key(self) -> FLEX_ATTN_MASK_T: return (self.attn_mask_type, self.fixed_block_size) - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: float | None = None, - ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + def forward(self, q, k, v, sink_weights=None, sliding_window=0, enable_gqa=False): + """ + q : (B, H_q, S_q, D) + k : (B, H_kv, S_kv, D) -- without sink + v : (B, H_kv, S_kv, D) + sink_weights : (H_q,) or (H, M) -- broadcast to all queries + sliding_window : int + enable_gqa : bool + """ + if sink_weights is None: + block_mask = FlexAttention.block_masks[self.mask_key] + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask) + + B, H_q, S_q, D = q.shape + _, H_kv, S_kv, _ = k.shape + sink_idx = S_kv # sink occupies final key slot + + sink_k = k.new_zeros(B, H_kv, 1, D) # this needn't be 0's since it's overwritten + sink_v = v.new_zeros(B, H_kv, 1, D) # 0 value nullifies sink weight in output + + k_ext = torch.cat([k, sink_k], dim=2) + v_ext = torch.cat([v, sink_v], dim=2) + + # masks ensure sinks are included in softmax + if sliding_window is not None and sliding_window > 0: + mask_mod = FlexAttention._get_sliding_window_with_sink_mask_mod(sliding_window, sink_idx) + else: + mask_mod = FlexAttention._get_causal_with_sink_mask_mod(sink_idx) + + block_mask = FlexAttention.compiled_create_block_mask( + mask_mod, B, H_q, S_q, S_kv+1 + ) + + # overwrite the dummy sink scores with actual sink weights + def score_mod(score, b, h_q, q_idx, kv_idx): + return torch.where( + kv_idx == sink_idx, + sink_weights[h_q].to(score.dtype) + 0.0, # cast + keep grad + score + ) + + return FlexAttention.flex_attn( + q, k_ext, v_ext, + block_mask=block_mask, + score_mod=score_mod, + enable_gqa=enable_gqa + ) + + @staticmethod + def _get_causal_with_sink_mask_mod(sink_idx): + """ + Returns a mask_mod function that + - only allows kv_idx ≤ q_idx (causal) + - or if kv_idx == sink_idx (always allow the sink) + """ + orig = FlexAttention._get_causal_mask_mod() + def causal_with_sink(b, h, q_idx, kv_idx): + return orig(b, h, q_idx, kv_idx) | (kv_idx == sink_idx) + return causal_with_sink + + @staticmethod + def _get_sliding_window_with_sink_mask_mod(window: int, sink_idx: int): + """ + Returns a mask_mod function that + - only allows kv_idx ≤ q_idx (causal) + - and only if (q_idx - kv_idx) ≤ window + - or if kv_idx == sink_idx (always allow the sink) + """ + def sliding_mod(b, h, 
q_idx, kv_idx): + # causal within window + keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window) + # always allow the sink slot + return keep | (kv_idx == sink_idx) + return sliding_mod @staticmethod def _get_causal_mask_mod() -> _mask_mod_signature: