[WIP] Experimental implementation of gpt-oss (grouped GEMM MoE + FlexAttention sink/sliding) #1559

Open · wants to merge 3 commits into main
53 changes: 53 additions & 0 deletions torchtitan/experiments/gpt_oss/__init__.py
@@ -0,0 +1,53 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
"parallelize_gptoss",
"GptOssModelArgs",
"GptOssModel",
"gptoss_configs",
]


gptoss_configs = {
"debugmodel": GptOssModelArgs(
hidden_size=256,
num_hidden_layers=4,
),
"20B": GptOssModelArgs(
num_hidden_layers=24,
num_local_experts=32,
),
"120B": GptOssModelArgs(
num_hidden_layers=36,
num_local_experts=128,
),
}


register_train_spec(
TrainSpec(
name="gpt_oss",
cls=GptOssModel,
config=gptoss_configs,
parallelize_fn=parallelize_gptoss,
pipelining_fn=None,
build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
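
For context (not part of the diff): once the train spec is registered, a flavor can be pulled from gptoss_configs by name. A minimal sketch, assuming the import path introduced by this file; the GptOssModel(args) constructor call is an assumption based on the usual torchtitan pattern, not something this PR shows.

# Hypothetical usage sketch -- not part of this PR.
from torchtitan.experiments.gpt_oss import gptoss_configs, GptOssModel

args = gptoss_configs["debugmodel"]              # small flavor for smoke tests
print(args.hidden_size, args.num_hidden_layers)  # 256, 4
model = GptOssModel(args)                        # assumed constructor signature (model args only)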
141 changes: 141 additions & 0 deletions torchtitan/experiments/gpt_oss/model/args.py
@@ -0,0 +1,141 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Literal

from torch import nn

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger

# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class GptOssModelArgs(BaseModelArgs):
"""
Data class for defining model arguments and hyperparameters.

Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
"""

max_batch_size: int = 8
max_seq_len: int = 131072
dtype: Literal["bf16", "fp8"] = "bf16"
vocab_size: int = 201088
hidden_size: int = 2880
num_hidden_layers: int = 24
norm_eps: float = 1e-5 # eps used for RMSNorm
# MoE
num_local_experts: int = 32
num_experts_per_tok: int = 4
use_grouped_mm: bool = True
# Attention (GQA with sliding window; sink/sliding handled via FlexAttention)
head_dim: int = 64
num_attention_heads: int = 64
num_key_value_heads: int = 8
sliding_window: int = 128
use_flex_attn: bool = True
attn_mask_type: str = "causal"
# YaRN (RoPE scaling)
original_seq_len: int = 4096
rope_theta: float = 150000.0
rope_factor: float = 32
beta_fast: int = 32
beta_slow: int = 1

def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"""
Update the model args from the given job config.
"""
# self.vocab_size = tokenizer.vocab_size # TODO: add tiktokenizer support?
self.max_seq_len = job_config.training.seq_len

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
"""
Adapted from the llama4 implementation.
"""
nparams_embedding = 0
nparams_moe_router = 0
nparams_shared_expert = 0
nparams_experts = 0
nparams_dense = 0

for name, p in model.named_parameters():
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
elif "moe.shared_expert" in name:
nparams_shared_expert += p.numel()
elif "moe.router" in name:
nparams_moe_router += p.numel()
elif "moe.experts" in name:
nparams_experts += p.numel()
else:
nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
nparams = nparams_dense + nparams_sparse
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_experts * self.num_experts_per_tok // self.num_local_experts
)

logger.info(
f"Total parameter count: dense {nparams_dense:,}, "
f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
)

l, h, q, t = (
self.num_hidden_layers,
self.num_attention_heads,
self.hidden_size // self.num_attention_heads,
seq_len,
)
# Reasoning behind the factor of 12 for the self-attention part of the formula:
# 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
# 2. the flash attention does 1 more matmul recomputation in the backward
# but recomputation should not be counted in calculating MFU (+0)
# 3. each matmul performs 1 multiplication and 1 addition (*2)
# 4. we follow the convention and do not account for sparsity in causal attention
num_flops_per_token = (
6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+ 12 * l * h * q * t
)

return nparams, num_flops_per_token
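
As a rough sanity check on the formula above (not part of the diff), a worked example of the attention term for the "20B" flavor at seq_len 4096, using the same q = hidden_size // num_attention_heads that the code uses (which is 45 here, not head_dim = 64):

# Hedged worked example -- not part of this PR.
l, h, q, t = 24, 64, 2880 // 64, 4096      # layers, heads, hidden_size // heads (= 45), seq_len
attn_flops_per_token = 12 * l * h * q * t  # = 3,397,386,240, i.e. ~3.4 GFLOPs per token
# The 6 * params term counts only active experts:
# num_experts_per_tok / num_local_experts = 4 / 32, so each MoE layer contributes
# 1/8 of its expert parameters to nparams_sparse_active.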