From a4509d760564eb594b5efe2434933b0de1e8573f Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 18 Jul 2025 11:42:08 -0400 Subject: [PATCH 01/37] fix dim name (#331) --- fast_llm/layers/ssm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada38..46d629aa 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -22,7 +22,7 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim_2 = "x_proj_dim_2" # d_xb class SSMBlockType(enum.StrEnum): From 82eed2b44c30c891ef2e07c2c80c4f5fcfa1e7f1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Jul 2025 17:17:26 -0400 Subject: [PATCH 02/37] TP mamba --- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/ssm/config.py | 214 +++++++++---- fast_llm/layers/ssm/discrete_mamba2.py | 39 ++- fast_llm/layers/ssm/llamba_block.py | 18 +- fast_llm/layers/ssm/mamba2.py | 302 +++++++----------- fast_llm/layers/ssm/mamba_layer.py | 159 ++++----- fast_llm/layers/transformer/attention.py | 3 +- fast_llm/layers/transformer/transformer.py | 27 +- fast_llm/models/custom/model.py | 4 +- fast_llm/models/gpt/model.py | 8 +- fast_llm/models/ssm/config.py | 42 +-- .../external/llamba/modeling_mtp_llamba.py | 10 +- fast_llm/models/ssm/model.py | 34 +- fast_llm/tensor.py | 8 +- setup.cfg | 2 +- tests/test_multi_stage.py | 4 +- 16 files changed, 407 insertions(+), 473 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac68..07dadbc2 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { 
"hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada38..f4c8067d 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,28 +1,35 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads + # TODO: Use separate tensor 
space for different mixers so there is no risk of name conflict. + state = "ssm_state" # State dimension (N), aka head size / num channels + + head_groups = "ssm_head_groups" + group_heads = "ssm_group_heads" + + composite_heads = "ssm_composite_heads" + composite_heads_and_state = "ssm_composite_heads_and_state" + composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + + # Inner projection total dimension. + inner_projection = "ssm_inner_projection" + composite_inner_projection = "ssm_composite_inner_projection" + + # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) + conv_dim = "ssm_conv_dim" + + dt_rank = "ssm_dt_rank" - # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim = "x_proj_dim" # X projection dimension + conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers class SSMBlockType(enum.StrEnum): @@ -36,6 +43,16 @@ class SSMBlockType(enum.StrEnum): transformer = "t" +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float): + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + + @config_class() class SSMConfig(LLMBlockConfig): _abstract = False @@ -45,79 +62,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", 
hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + # xb_size [Mamba2] + d_xb: int = Field( default=None, - desc="The MLP intermediate activation type. 
Default: SiLU for gated MLP, GeLU otherwise.", + desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( + + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( default=False, - desc="debug_ssm", - hint=FieldHint.optional, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( + default=None, + hint=FieldHint.architecture, ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, ) - - d_inner: None | int = Field( - default=None, - desc="Inner dimension for Mamba2 blocks.", - hint=FieldHint.core, + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -125,43 +150,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - dt_init: str = Field( + # Initialization + # 
dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( default="random", desc="Initialization method for dt", hint=FieldHint.core, ) - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + # dt_bias_initialization_min [MambaLayer, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", + # dt_bias_initialization_max [MambaLayer, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +192,59 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + num_heads = div(self.d_inner, self.state_size) + # Head groups are configured differently depending on the block type. 
+ if block_type == SSMBlockType.mamba: + num_head_groups = num_heads + # (head_groups, 2 * group_heads * state_dim) + inner_projection_size = self.d_inner * 2 + elif block_type == SSMBlockType.mamba2: + num_head_groups = div(self.d_xb, self.state_size) + # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) + inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank + elif block_type == SSMBlockType.mamba2_discrete: + Assert.eq(num_heads, self.n_v_heads) + num_head_groups = self.n_qk_heads + # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) + inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + else: + raise NotImplementedError(block_type) + + tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) + tensor_space.add_tensor_dim( + group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) + ) + tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + ) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + + # DT projection + if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + + if block_type == SSMBlockType.mamba: + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + inner_projection_size = 2 * num_group_heads * self.state_size + elif block_type == SSMBlockType.mamba2: + inner_projection_size = 2 * 
(num_group_heads + 1) * self.state_size + elif block_type == SSMBlockType.mamba2_discrete: + inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: (head_groups, group_heads + 2, state_size) + tensor_space.add_tensor_dim( + TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ) + + tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) + tensor_space.add_tensor_dim( + CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5..d06b4796 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,8 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def bias_init_method(conv_weight): fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) + return init_uniform_centered_(bound) class DiscreteMamba2(torch.nn.Module): @@ -53,21 +54,20 @@ def __init__( # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__() self.config: SSMConfig = config - bias = 
config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + td_state = tensor_space.get_tensor_dim(SSMDimNames.state) + td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) + td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size @@ -85,8 +85,8 @@ def __init__( self.in_proj = Linear( td_model, td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -96,15 +96,13 @@ def __init__( init_method=init_zeros_, lr_scale=mamba_layer_lr_scale, ) - if not bias + if not config.add_bias_linear else 0.0 ) self.conv1d_weight = ParameterMeta.from_dims( (td_conv, TensorDim("1", 1), td_conv_kernel), - 
init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -123,12 +121,12 @@ def __init__( self.out_proj = Linear( td_inner, td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ ON variable names and pep8: keeping some variable names as in the original code for clarity. @@ -144,7 +142,6 @@ def forward(self, hidden_states, kwargs): raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - input_ = hidden_states outputs = {} # assert state is None batch, seqlen, dim = input_.shape diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d..e877ff9c 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -14,21 +14,19 @@ class LlambaBlock(BaseBlock): """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, + mixer_cls: type[Mixer], layer_index: 
int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, layer_index, return_input) + self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def get_mixer(self) -> Mixer: + return self.mixer diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509ab..011889d0 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,14 +1,15 @@ -import math -import typing - -import einops import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -25,25 +26,7 @@ _causal_conv1d_available = False -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of 
torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -53,207 +36,138 @@ def __init__( config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): super().__init__() - self.config: SSMConfig = config - bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale - ) - - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.repeat_kv_before_conv = config.repeat_kv_before_conv + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) + self._head_groups = div(self._config.d_xb, self._config.state_size) + self._heads = div(self._config.d_inner, self._config.state_size) + self._group_heads = div(self._heads, self._head_groups) - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - 
(td_xb, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - - self.activation = "silu" - - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head - - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + conv1d_dim = ( + inner_dim + if self._config.repeat_kv_before_conv + else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) ) - - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor + self.conv1d_weight = ParameterMeta.from_dims( + (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + ) + self.in_proj = OutputParallelLinear( + hidden_dim, + 
tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( - tdt_rank, - td_inner, + tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_from_tensor_(A_log), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, 
) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), ) def forward(self, hidden_states, kwargs): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) - - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") - - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + assert _causal_conv1d_available + + inner_projection = self.in_proj(hidden_states) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + sequence_length = hidden_states.size(1) + + z, x, b, c, dt = torch.split( + inner_projection, + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + dim=2, + ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + z = z.transpose(1, 2) + + # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( 
+ x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + else: + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = ( + x.unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .flatten(1, 2) + ) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._heads) + ) - assert self.activation in ["silu", "swish"] - if _causal_conv1d_available: - x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D - else: - raise RuntimeError("Causal conv1d is not available. 
Please install causal_conv1d.") + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) + dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) y = selective_scan_fn( x, dt, - A, - B, - C, + -torch.exp(self.A_log.float()), + b, + c, self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), + z, + delta_bias=self.dt_proj_bias.float(), delta_softplus=True, - return_last_state=False, ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) - - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) + out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) + # TODO: Is contiguous needed? 
+ return out.contiguous(), None diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d23..fa2789b1 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,14 +1,18 @@ +import logging import math +import typing from typing import Callable -import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -17,6 +21,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. @@ -26,169 +32,126 @@ def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization # TODO: adopt this initialization to work for tensor parallel setting! 
- A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor + if tensor.numel() != d_state * d_inner: + raise ValueError(f"_init_A requires not supported for tensor slices.") + return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) return init_ def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + return tensor.add_(torch.log(-torch.expm1(-tensor))) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): def __init__( self, config: SSMConfig, layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, ): - factory_kwargs = {} super().__init__() - self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm + assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? 
+ Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + # TODO: Backward compatibility? + # TODO: lr_scale? 
+ self.in_proj = Linear( + hidden_dim, + tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + inner_dim, + tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - 
lr_scale=mamba_layer_lr_scale, + init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self._debug_mode: - print("XZ: ", xz.shape) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, - self.conv1d_weight, - self.conv1d_bias, + in_proj, + self.conv1d_weight.unsqueeze(1), + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # 
input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c990..76b8ed1c 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,6 +13,7 @@ TransformerKwargs, TransformerSubLayerName, ) +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale @@ -50,7 +51,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 14745207..f80e903f 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -18,13 +18,24 @@ logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. 
""" - _mixer_module_name = "self_attn" - def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +65,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def get_mixer(self) -> Mixer: pass @torch.compile @@ -115,7 +126,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = self.get_mixer()(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -137,14 +148,14 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - - def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + def get_mixer(self) -> Mixer: + return self.self_attn diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef40..a9cf3bb8 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import 
CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,7 +31,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b..a3a68e0a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? 
@@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11b..c294fe52 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,12 +6,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,7 +23,7 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( @@ -51,38 +50,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank - 
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + self.ssm.setup_tensor_space(tensor_space) def _validate(self): with self._set_implicit_default(None): diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 6d9746db..8f49ded4 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,19 +322,21 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config.d_model, + d_model=self.config._hidden_size, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.input_layernorm = LlamaRMSNorm( + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + ) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config.d_model, + hidden_size=self.config._hidden_size, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac23..3e57689b 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -9,7 +9,7 @@ from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba2 import Mamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from 
fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -39,14 +39,14 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=len(self._config.hybrid_block_layout), @@ -55,8 +55,8 @@ def get_output_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -65,8 +65,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -75,8 +75,8 @@ def get_output_layers(self) -> list[Layer]: layers.append(mamba_block) elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + 
ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, @@ -94,14 +94,14 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, layer_index=i + 1, @@ -112,8 +112,8 @@ def get_layers(self) -> list[Layer]: ) elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=DiscreteMamba2, layer_index=i + 1, tensor_space=self._tensor_space, @@ -126,8 +126,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=MambaLayer, layer_index=i + 1, tensor_space=self._tensor_space, @@ -139,8 +139,8 @@ def get_layers(self) -> list[Layer]: elif block_type == SSMBlockType.mamba2: mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, mixer_cls=Mamba2, layer_index=i + 1, tensor_space=self._tensor_space, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6..b474fe87 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ 
-354,7 +354,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -def kaiming_init_(d_in): +def init_kaiming_(d_in): return init_normal_(0.0, math.sqrt(2.0 / d_in)) @@ -369,3 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return tensor return init_ + + +def init_uniform_centered_( + high, max_val=None, mean=0.0 +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: + return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) diff --git a/setup.cfg b/setup.cfg index 2f69b8e0..c086af7d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170..e5fbc7d6 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -39,7 +39,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + 
sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( From 4e310c74634a70c4d8117cc025f18a040ffbd098 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 13:04:54 -0400 Subject: [PATCH 03/37] TP mamba --- fast_llm/engine/config_utils/tensor_space.py | 174 ++++++++++++------- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/ssm/config.py | 45 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba2.py | 22 ++- fast_llm/layers/ssm/mamba_layer.py | 2 +- fast_llm/tensor.py | 31 ++-- 9 files changed, 184 insertions(+), 108 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf7..dceeb7da 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -5,6 +5,8 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed @@ -23,7 +25,7 @@ def __repr__(self) -> str: f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +40,134 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + 
return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.parallel_group is not None: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? 
+ assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor 
if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + # TODO: Implement + raise NotImplementedError() + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim)[0] + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + if self.is_parallel and expand: + raise NotImplementedError() + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +200,22 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if 
isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim def get_tensor_dim(self, name: str) -> TensorDim: return self._tensor_dims[name] + + # TODO: Replace uses + __getitem__ = get_tensor_dim diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a..7249ef56 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beae..bccc1d62 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e5..08f3e535 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index f4c8067d..ce37a980 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,7 +1,7 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig @@ -20,8 +20,7 @@ class SSMDimNames: composite_head_groups_and_state = "ssm_composite_head_groups_and_state" # Inner projection total dimension. - inner_projection = "ssm_inner_projection" - composite_inner_projection = "ssm_composite_inner_projection" + concatenated_inner_projection = "ssm_concatenated_inner_projection" # Convolution shape in discrete mamba 2. 
TODO: Remove (dim too complex) conv_dim = "ssm_conv_dim" @@ -210,7 +209,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType Assert.eq(num_heads, self.n_v_heads) num_head_groups = self.n_qk_heads # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - inner_projection_size = 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads + 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) @@ -219,12 +218,18 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim( group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) ) - tensor_space.add_tensor_dim(CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim)) + heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim)) + heads_and_state := CompositeTensorDim( + SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + ) + ) + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + ) ) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) @@ -234,17 +239,27 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType if block_type == SSMBlockType.mamba: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) - inner_projection_size = 2 * num_group_heads * self.state_size + # TODO: Use composition instead + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, 
(heads_and_state, heads_and_state)) + ) elif block_type == SSMBlockType.mamba2: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + # TODO: Factor out state? + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + ) + ) elif block_type == SSMBlockType.mamba2_discrete: - inner_projection_size = 2 * (num_group_heads + 1) * self.state_size + num_group_heads + # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + ) + ) # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) ) - - tensor_space.add_tensor_dim(inner_projection := TensorDim(SSMDimNames.inner_projection, inner_projection_size)) - tensor_space.add_tensor_dim( - CompositeTensorDim(SSMDimNames.composite_inner_projection, (head_groups, inner_projection)) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index d06b4796..988a0950 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -67,7 +67,7 @@ def __init__( td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) self.d_model = td_model.size self.d_inner = td_inner.size diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 011889d0..dff1356e 100644 
--- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -45,6 +45,7 @@ def __init__( inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) + dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) self._head_groups = div(self._config.d_xb, self._config.state_size) self._heads = div(self._config.d_inner, self._config.state_size) @@ -65,13 +66,21 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, weight_init_method=init_kaiming_(hidden_dim.size), lr_scale=lr_scale, ) - self.dt_proj = Linear( - tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank), + + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, inner_dim, bias=False, # Initialize special dt projection to preserve variance at initialization @@ -110,16 +119,19 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) + dt = self.dt_in_proj(hidden_states) # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) sequence_length = hidden_states.size(1) - z, x, b, c, dt = torch.split( + z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner, self._config.dt_rank], + [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], dim=2, ) + # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) z = 
z.transpose(1, 2) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index fa2789b1..0cdcb524 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -74,7 +74,7 @@ def __init__( # TODO: lr_scale? self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.composite_inner_projection), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b474fe87..f312f196 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -166,14 +166,13 @@ def local_to_global( ) -> tuple[torch.Tensor, ...]: # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -187,23 +186,19 @@ def local_to_global( def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. + # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. """ # Take a trivial slice to convert safetensor slices. 
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] - - return tensor_ if expand else tensor_.reshape(self.shape) + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim, expand) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): From 3cc41182a71d28e02918d76cd882978ca8384f73 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 16:57:38 -0400 Subject: [PATCH 04/37] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +- fast_llm/layers/ssm/config.py | 24 +++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 + fast_llm/layers/ssm/llamba_block.py | 10 +- fast_llm/layers/ssm/mamba_layer.py | 13 ++- fast_llm/layers/transformer/transformer.py | 20 ++-- fast_llm/models/ssm/config.py | 41 +++----- fast_llm/models/ssm/model.py | 99 +++++--------------- fast_llm/tensor.py | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_multi_stage.py | 6 +- tests/utils/model_configs.py | 43 +++++---- 14 files changed, 127 insertions(+), 148 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index dceeb7da..d927f2e7 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -70,7 +70,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else: return tensor - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( 
tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -108,7 +108,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -150,7 +150,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> torch.Tensor: + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() return ( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index ce37a980..aa011f75 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -41,6 +41,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + class DTInitType(enum.StrEnum): constant = "constant" @@ -199,17 +215,13 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # Head groups are configured differently 
depending on the block type. if block_type == SSMBlockType.mamba: num_head_groups = num_heads - # (head_groups, 2 * group_heads * state_dim) - inner_projection_size = self.d_inner * 2 elif block_type == SSMBlockType.mamba2: num_head_groups = div(self.d_xb, self.state_size) - # (head_groups, 2 * group_heads + 2, state_dim) + (dt,) - inner_projection_size: int = 2 * self.d_inner + 2 * num_head_groups * self.state_size + self.dt_rank elif block_type == SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) + # TODO: Fix (Du einsum crashes) + Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads - # (head_groups, (2 * group_heads + 2) * state_dim + group_heads) - 2 * self.d_inner + 2 * num_head_groups * self.state_size + num_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 988a0950..14fb8aae 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -216,6 +216,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result + print("AHNFIUWEGIUWEI", self.D.shape, x.shape) + # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index e877ff9c..774ee730 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -8,7 +8,7 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -24,9 +24,9 @@ def __init__( layer_index: int, return_input: bool = False, ): - self._debug_mode = self._config_ssm.debug_ssm + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls 
super().__init__(transformer_config, tensor_space, layer_index, return_input) - self.mixer = mixer_cls(ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) - def get_mixer(self) -> Mixer: - return self.mixer + def _create_mixer(self) -> Mixer: + return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0cdcb524..8235f4f1 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,7 +1,6 @@ import logging import math import typing -from typing import Callable import torch @@ -30,21 +29,25 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa # TODO: adopt this initialization to work for tensor parallel setting! 
if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner), out=tensor) + return torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + ) return init_ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = tensor.uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min(dt_init_floor) + tensor = ( + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index f80e903f..a0611cd2 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP @@ -36,6 +35,9 @@ class BaseBlock(Layer, abc.ABC): A transformer-like decoder base block with abstract mixer. 
""" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" + def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): @@ -54,7 +56,8 @@ def __init__( self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index @@ -65,7 +68,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def get_mixer(self) -> Mixer: + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -126,7 +129,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = self.get_mixer()(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): @@ -150,12 +153,15 @@ def forward( class TransformerBlock(BaseBlock): _name = "Transformer layer" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__(config, tensor_space, layer_index, return_input) - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - def get_mixer(self) -> Mixer: - return self.self_attn + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention + + return Attention(self._config, self._tensor_space, 
self._layer_index) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index c294fe52..6b9e2858 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -30,7 +30,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -43,14 +43,16 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - self.ssm.setup_tensor_space(tensor_space) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -64,30 +66,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. 
Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): @@ -162,12 +155,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." ) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3e57689b..4a95891a 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,10 +5,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -53,38 +49,17 @@ 
def get_output_layers(self) -> list[Layer]: return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -110,47 +85,19 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + layer_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f312f196..1111fd04 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -369,4 +369,9 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_uniform_centered_( high, max_val=None, mean=0.0 ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - return init_uniform_(mean - high, mean + high, min_val=mean - max_val, max_val=mean + max_val) + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782df..3e6c3763 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b..4f36cdf8 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f195..004b9628 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() 
get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index e5fbc7d6..2f125717 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock +from fast_llm.layers.ssm.llamba_block import SSMBlock from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963..b834ed4d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -451,16 
+451,14 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. + # Tests hybrid Mamba, llamba converter. "llama", "llamba", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.chunk_size=32", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=8", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, @@ -468,26 +466,31 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # SSMs don't support sequence-first configurations. - skip_tests=("sf", "sdp", "stp", "ms"), + # Micro-sequence split not supported. + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid discrete Mamba 2. + "llama", "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + # TODO: Set to 16 once fixed. 
+ "model.base_model.ssm.n_qk_heads=32", + "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -497,17 +500,23 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + # Micro-sequence split and sequence-first not supported. + skip_tests=("sf", "stp", "sdp", "ms"), ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", + # Tests hybrid Mamba 2. + "llama", "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", + "model.base_model.ssm.d_xb=256", ], megatron_args=None, checkpoint_format=None, @@ -517,8 +526,10 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + # Micro-sequence split not supported. 
+ skip_tests=("sdp", "ms"), ) From 9f7f75c72f1fff36a781773c8c772441d7fa9067 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Jul 2025 19:56:35 -0400 Subject: [PATCH 05/37] fix --- fast_llm/engine/config_utils/tensor_space.py | 6 +++++- fast_llm/layers/ssm/config.py | 2 -- fast_llm/layers/ssm/discrete_mamba2.py | 4 +--- fast_llm/layers/ssm/mamba2.py | 19 +++++++++++-------- fast_llm/layers/ssm/mamba_layer.py | 5 ++++- fast_llm/tensor.py | 6 ++++++ tests/utils/model_configs.py | 9 +++++---- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index d927f2e7..2ca7e3e9 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -21,7 +21,7 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," @@ -134,6 +134,8 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: raise NotImplementedError() def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + return ( torch.concatenate( [ @@ -153,6 +155,8 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() + import torch + return ( torch.concatenate( [ diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index aa011f75..7da4283b 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -219,8 +219,6 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType num_head_groups = div(self.d_xb, self.state_size) elif block_type == 
SSMBlockType.mamba2_discrete: Assert.eq(num_heads, self.n_v_heads) - # TODO: Fix (Du einsum crashes) - Assert.eq(self.n_qk_heads, self.n_v_heads) num_head_groups = self.n_qk_heads else: raise NotImplementedError(block_type) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 14fb8aae..102accb8 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -111,7 +111,7 @@ def __init__( # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (td_n_v_heads,), weight_decay=False, init_method=init_ones_, lr_scale=mamba_layer_lr_scale, @@ -216,8 +216,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ else: y = result - print("AHNFIUWEGIUWEI", self.D.shape, x.shape) - # TODO: h different for D and x (qk_heads, v_heads) Du = torch.einsum("h,blhp->blhp", self.D, x) y = einops.rearrange(y + Du, "b l h p -> b l (h p)") diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index dff1356e..11ab91e4 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.discrete_mamba2 import bias_init_method from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -62,7 +61,9 @@ def __init__( lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), init_method=bias_init_method(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) self.in_proj = 
OutputParallelLinear( hidden_dim, @@ -124,7 +125,7 @@ def forward(self, hidden_states, kwargs): if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) - sequence_length = hidden_states.size(1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( inner_projection, @@ -177,9 +178,11 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) - # y: (batch, heads * state, sequence) -> out: (batch, sequence, hidden) - out = self.out_proj(y.transpose(1, 2))[:, :sequence_length] + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: - out = out.transpose(0, 1) - # TODO: Is contiguous needed? - return out.contiguous(), None + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + a, b = self.out_proj(y) + Assert.eq(a.shape, hidden_states.shape) + return a, b diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 8235f4f1..49b9e45b 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -35,7 +35,10 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) if tensor.numel() != d_state * d_inner: raise ValueError(f"_init_A requires not supported for tensor slices.") return torch.log( - torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device).repeat(d_inner), out=tensor + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, ) return init_ diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 1111fd04..25ae49a3 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -164,6 +164,9 @@ def local_to_global( *, distributed: Distributed, ) -> tuple[torch.Tensor, ...]: + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always 
either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication is_first_rank, modified = distributed.config.tensor_rank == 0, False @@ -195,6 +198,9 @@ def global_to_local( # Take a trivial slice to convert safetensor slices. tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b834ed4d..47314263 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -487,9 +487,8 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - # TODO: Set to 16 once fixed. - "model.base_model.ssm.n_qk_heads=32", - "model.base_model.ssm.n_v_heads=32", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", "model.base_model.ssm.chunk_size=32", ], megatron_args=None, @@ -503,6 +502,7 @@ def _update_and_add_testing_config( # TODO: Implement ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, + compare_factor=2.0, # Micro-sequence split and sequence-first not supported. skip_tests=("sf", "stp", "sdp", "ms"), ) @@ -515,7 +515,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", - "model.base_model.ssm.state_size=16", + "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", ], megatron_args=None, @@ -528,6 +528,7 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, # Micro-sequence split not supported. 
skip_tests=("sdp", "ms"), ) From 4054e047d7318c2dfd6e37712f3b6b94d3beca5b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 15:22:24 -0400 Subject: [PATCH 06/37] fixes --- fast_llm/engine/config_utils/tensor_space.py | 11 ++++-- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/ssm/mamba2.py | 41 +++++++++++--------- fast_llm/tensor.py | 2 + 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 2ca7e3e9..0d971a88 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -130,8 +133,10 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - # TODO: Implement - raise NotImplementedError() + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": import torch @@ -139,7 +144,7 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return ( torch.concatenate( [ - tensor_dim.local_to_global(tensor_, dim)[0] + tensor_dim.local_to_global(tensor_, dim) for tensor_, tensor_dim in zip( tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), self._tensor_dims, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f136..9a8ce209 
100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 11ab91e4..a285711c 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,5 @@ +import logging + import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -24,6 +26,8 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False +logger = logging.getLogger(__name__) + class Mamba2(Mixer): """ @@ -43,21 +47,20 @@ def __init__( lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - self._head_groups = div(self._config.d_xb, self._config.state_size) - self._heads = div(self._config.d_inner, self._config.state_size) - self._group_heads = div(self._heads, self._head_groups) + self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size + self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size - conv1d_dim = ( - inner_dim - if self._config.repeat_kv_before_conv - else tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - ) + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), - init_method=init_uniform_centered_((conv1d_dim.size * self._config.conv_kernel_dimension) ** -0.5), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( @@ -69,7 +72,7 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) @@ -77,7 +80,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.size), + 
weight_init_method=init_kaiming_(hidden_dim.global_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -129,7 +132,7 @@ def forward(self, hidden_states, kwargs): z, x, b, c = torch.split( inner_projection, - [self._config.d_inner, self._config.d_xb, self._config.d_xb, self._config.d_inner], + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], dim=2, ) @@ -140,28 +143,28 @@ def forward(self, hidden_states, kwargs): x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") else: x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") x = ( - x.unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) b = ( b.transpose(1, 2) - .unflatten(1, (self._head_groups, self._config.state_size)) - .repeat_interleave(self._group_heads, 1, output_size=self._heads) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) ) # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) - c = c.transpose(1, 2).unflatten(1, (self._heads, self._config.state_size)) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) # dt: (batch, sequence, dt_rank) -> (batch, heads * state, 
sequence) dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 25ae49a3..6995e9e9 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -184,6 +184,7 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank def global_to_local( @@ -204,6 +205,7 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) + Assert.eq(tensor.shape, self.shape) return tensor @classmethod From 0014cc6b3f79138e53610dc86cb654a5eaba90a0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 18:02:43 -0400 Subject: [PATCH 07/37] fix --- fast_llm/layers/ssm/discrete_mamba2.py | 27 +++----- fast_llm/layers/ssm/llamba_block.py | 11 ++- fast_llm/layers/ssm/mamba2.py | 53 ++++++++++---- fast_llm/layers/ssm/mamba_layer.py | 11 +-- fast_llm/layers/transformer/attention.py | 69 ++++--------------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +-- fast_llm/layers/transformer/transformer.py | 63 ++++++++++++++--- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/model.py | 4 +- fast_llm/models/ssm/model.py | 8 +-- tests/utils/model_configs.py | 6 +- 12 files changed, 154 insertions(+), 116 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 102accb8..b95ff76d 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -8,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import 
TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -37,28 +38,23 @@ def bias_init_method(conv_weight): return init_uniform_centered_(bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. 
- """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) td_state = tensor_space.get_tensor_dim(SSMDimNames.state) @@ -223,9 +219,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index 774ee730..98660663 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -21,12 +21,17 @@ def __init__( ssm_config: "SSMConfig", tensor_space: "TensorSpace", mixer_cls: type[Mixer], - layer_index: int, + block_index: int, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, layer_index, return_input) + super().__init__(transformer_config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: - return self._mixer_cls(self._ssm_config, layer_idx=self._layer_index, tensor_space=self._tensor_space) + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a285711c..88fe4abc 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,4 +1,5 @@ import logging +import typing import torch @@ -7,7 +8,7 @@ from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale @@ -34,16 +35,31 @@ class Mamba2(Mixer): This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: 
typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_state, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) @@ -72,7 +88,8 @@ def __init__( hidden_dim, tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -80,7 +97,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(hidden_dim.global_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -91,6 +108,7 @@ def __init__( weight_init_method=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn @@ -116,6 
+134,8 @@ def __init__( hidden_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + # TODO: lr_scale? ) def forward(self, hidden_states, kwargs): @@ -123,11 +143,12 @@ def forward(self, hidden_states, kwargs): assert _causal_conv1d_available inner_projection = self.in_proj(hidden_states) - dt = self.dt_in_proj(hidden_states) + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) + sequence_length = inner_projection.size(1) z, x, b, c = torch.split( @@ -166,8 +187,15 @@ def forward(self, hidden_states, kwargs): # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) - # dt: (batch, sequence, dt_rank) -> (batch, heads * state, sequence) - dt = (self.dt_proj(dt) + self.dt_proj_bias).transpose(1, 2) + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) + + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, @@ -181,11 +209,12 @@ def forward(self, hidden_states, kwargs): delta_softplus=True, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() - a, b = self.out_proj(y) - Assert.eq(a.shape, hidden_states.shape) - return a, b + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49b9e45b..49afa910 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale @@ -58,13 +58,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? @@ -73,7 +76,7 @@ def __init__( # Tensor dims: inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? 
diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 76b8ed1c..174e1958 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,9 +14,8 @@ TransformerSubLayerName, ) from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -56,6 +55,8 @@ class Attention(Mixer): A self-attention layer. """ + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -65,7 +66,7 @@ class Attention(Mixer): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -74,19 +75,9 @@ class Attention(Mixer): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -109,7 +100,7 @@ def __init__( 
hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) @@ -179,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -201,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " 
grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -301,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -342,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -396,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af138..73f83ccf 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", 
block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa..efe0c5cc 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None 
lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index a0611cd2..d08db9a9 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -13,6 +13,7 @@ from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -22,6 +23,15 @@ class Mixer(torch.nn.Module, abc.ABC): Base class for mixer modules. """ + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + @abc.abstractmethod def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -29,6 +39,43 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ in case its addition can be made more efficient in `_bias_dropout_add`. 
""" + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + class BaseBlock(Layer, abc.ABC): """ @@ -39,7 +86,7 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -48,11 +95,11 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -60,7 +107,7 @@ def __init__( setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. @@ -81,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -157,11 +204,11 @@ class TransformerBlock(BaseBlock): _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: from fast_llm.layers.transformer.attention import Attention - return Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index a9cf3bb8..534d813f 100644 --- 
a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -34,7 +34,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a3a68e0a..4c1eab46 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,7 +72,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -94,7 +94,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 4a95891a..89f0cd4a 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -45,7 +45,7 @@ def get_output_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) @@ -55,7 +55,7 @@ def get_output_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), tensor_space=self._tensor_space, return_input=i != self._config.prediction_heads - 1, ) @@ -79,7 +79,7 @@ def get_layers(self) -> list[Layer]: TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), @@ -91,7 +91,7 @@ def get_layers(self) -> list[Layer]: transformer_config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), - layer_index=i + 1, + block_index=i + 1, tensor_space=self._tensor_space, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 47314263..4090e5a3 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -517,6 +517,7 @@ def _update_and_add_testing_config( "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, 
checkpoint_format=None, @@ -530,7 +531,10 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=("sdp", "ms"), + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) From 47ad5485454236d557570a32771c5888bbb3658e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:03:01 -0400 Subject: [PATCH 08/37] fixes --- Megatron-LM | 2 +- fast_llm/layers/language_model/head.py | 16 ++++++++++------ fast_llm/logging.py | 2 ++ fast_llm/tensor.py | 3 ++- tests/test_attention.py | 4 ++-- tests/utils/model_configs.py | 2 +- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5c..75b0d978 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28..21bf3bbd 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6..6d555a0b 100644 --- 
a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 6995e9e9..899e7000 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -205,7 +205,8 @@ def global_to_local( for dim, tensor_dim in reversed(list(enumerate(self.dims))): tensor = tensor_dim.global_to_local(tensor, dim, expand) - Assert.eq(tensor.shape, self.shape) + if not expand: + Assert.eq(tensor.shape, self.shape) return tensor @classmethod diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e5..dd36b840 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4090e5a3..18db0d40 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -467,7 +467,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: 
ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, From 6a074fa3c72bbe16c617a11cff690c543e4c5e86 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:50:05 -0400 Subject: [PATCH 09/37] fixes --- fast_llm/layers/ssm/config.py | 2 +- fast_llm/models/ssm/conversion.py | 18 ++++++---- tests/utils/model_configs.py | 55 ++++++++++++++++--------------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7da4283b..15a6a821 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -168,7 +168,7 @@ class SSMConfig(LLMBlockConfig): # Initialization # dt_weight_initialization_method [Mamba2] dt_init: DTInitType = Field( - default="random", + default=DTInitType.random, desc="Initialization method for dt", hint=FieldHint.core, ) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d5730025..43e3c67e 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class 
HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 18db0d40..3ffc3281 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,7 +19,10 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) from 
tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -467,7 +470,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, @@ -477,47 +480,49 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) - _update_and_add_testing_config( - # Tests hybrid discrete Mamba 2. + # Tests hybrid Mamba 2. "llama", - "hybrid_discrete_mamba2", + "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.hybrid_block_layout=['t','m2']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=16", - "model.base_model.ssm.chunk_size=32", + "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + 
ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, - # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) + _update_and_add_testing_config( - # Tests hybrid Mamba 2. + # Tests hybrid discrete Mamba 2. "llama", - "hybrid_mamba2", + "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", - "model.base_model.ssm.d_xb=256", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=None, @@ -527,14 +532,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # Micro-sequence split not supported. - skip_tests=( - "sdp", - "ms", - ), # "pp","dp", "ce","16", "bf", "df", "stp"), + # Micro-sequence split and sequence-first not supported. 
+ skip_tests=("sf", "stp", "sdp", "ms"), ) From d66651f5433392794d1b45560282d9237824881d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Jul 2025 19:56:19 -0400 Subject: [PATCH 10/37] Update external --- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a052..4fde7245 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", 
dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: From 4e67fbfb44ba583e3d7915bd0204c345212d8a32 Mon Sep 17 00:00:00 2001 From: Denis Kochetkov Date: Thu, 24 Jul 2025 14:55:11 +0300 Subject: [PATCH 11/37] Adds lm_eval to evaluations (#282) Co-authored-by: Joel Lamy-Poirier --- .github/workflows/ci.yaml | 2 +- .github/workflows/docs.yaml | 2 +- Dockerfile | 2 +- docs/user_guide/evaluators.md | 134 +++ fast_llm/cli.py | 3 + fast_llm/config.py | 2 +- fast_llm/core/distributed.py | 121 +++ fast_llm/engine/config_utils/run.py | 16 +- fast_llm/engine/evaluation/config.py | 61 +- fast_llm/engine/evaluation/evaluator.py | 10 +- .../engine/evaluation/lm_eval/evaluator.py | 90 ++ .../evaluation/lm_eval/fast_llm_wrapper.py | 909 ++++++++++++++++++ fast_llm/engine/evaluation/lm_eval/utils.py | 244 +++++ fast_llm/engine/inference/huggingface.py | 10 + fast_llm/engine/multi_stage/stage.py | 3 +- fast_llm/engine/training/config.py | 2 +- fast_llm/engine/training/trainer.py | 17 +- fast_llm/engine/training/wandb.py | 4 +- fast_llm/utils.py | 22 + mkdocs.yaml | 1 + setup.cfg | 5 + tests/conftest.py | 10 +- tests/models/test_lm_eval.py | 124 +++ tests/test_ssms.py | 14 +- tests/utils/dataset.py | 12 +- tests/utils/model_configs.py | 15 + tests/utils/run_test_script.py | 10 +- 27 files changed, 1791 insertions(+), 54 deletions(-) create mode 100644 docs/user_guide/evaluators.md create mode 100644 fast_llm/engine/evaluation/lm_eval/evaluator.py create mode 100644 fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py create mode 100644 fast_llm/engine/evaluation/lm_eval/utils.py create mode 100644 tests/models/test_lm_eval.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ca7ea749..59cdd51b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -31,7 +31,7 @@ jobs: pip 
install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 632fa7b9..0eef80a3 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -33,7 +33,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/Dockerfile b/Dockerfile index e98223de..71f59fff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/docs/user_guide/evaluators.md b/docs/user_guide/evaluators.md new file mode 100644 index 00000000..0e075ceb --- /dev/null +++ b/docs/user_guide/evaluators.md @@ -0,0 +1,134 @@ +# Evaluations + +Fast-LLM allows you to perform various evaluations during training or as a separate evaluation step. 
In both cases, you need to use your training config with `training.evaluators` specified. + +For evaluators used during training, both `interval` and `offset` must be specified. Then, start training as usual with: + +`fast-llm train gpt --config path/to/training/config.yaml` + +To perform evaluation as a separate step, use the same training config. Depending on the training progress, either the start model or the latest checkpoint will be loaded, and `interval` and `offset` will be ignored. To start evaluation: + +`fast-llm evaluate gpt --config path/to/training/config.yaml` + +## Currently Supported Evaluators + +- `loss` +- `lm_eval` + +## Loss Evaluator + +To set up loss evaluation, specify a dataset to be used in the `data.datasets` section of the config. You must also define the loss evaluator in the `training.evaluators` config section. See example below. + +```yaml +training: + evaluations: + stack_3b: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + interval: 10 +data: + datasets: + stack_3b: + type: memmap + path: path/to/memmap/dataset + fineweb: + type: memmap + path: path/to/memmap/dataset1 +``` + +## Evaluation Harness (`lm_eval`) Evaluator + +**Note:** Only data parallelism is currently supported for the `lm_eval` evaluator. + +To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks. 
+ +The following environment variables may need to be set: + +- `HF_HOME`: Path for Hugging Face data caching +- `WANDB_API_KEY_PATH`: Path to a file containing your Weights & Biases API key (if logging to W&B) +- `HUGGINGFACE_API_KEY_PATH`: Path to a file containing your Hugging Face hub token +- `NLTK_DATA`: Path to a directory that will contain downloaded NLTK packages (needed for some tasks) +- `HF_ALLOW_CODE_EVAL=1`: Required for some evaluation tasks + +You may need to specify additional environment variables depending on the `lm_eval` tasks you want to run. + +To specify an `lm_eval` task, the evaluator config includes the following fields: + +### Model Config + +The model instantiated for training is reused for evaluation, so you don't need to specify it separately. However, there are some parameters specific to `lm_eval`. See `fast_llm/engine/evaluation/config.EvaluatorLmEvalConfig` for details. + +### CLI Parameters for `lm_eval` + +All other parameters are specified as if you were calling the `lm_eval` CLI, using a list of strings. Some CLI parameters are ignored or restricted—specifically those related to model loading, W&B, batch sizes, and device setup, as these are managed by the rest of the Fast-LLM configuration. + +Also, the tokenizer must be specified in `data.tokenizer`. If the tokenizer does not have a `bos_token`, it must be specified explicitly in `data.tokenizer.bos_token`. Although `lm_eval` does not use the `bos_token` directly, it is still required because the same tokenizer is used by other Fast-LLM components. 
+ +Below is an example of the config: + +```yaml +training: + evaluations: + lm_eval_tasks1: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k,xnli_en,wikitext,ifeval + - --output_path + - /path/to/lm_eval/output +data: + tokenizer: + path: path/to/the/tokenizer +``` + +It is also possible to run different tasks with different intervals and offsets—for example, to run slower or more comprehensive tasks less frequently.: + +```yaml +training: + evaluations: + gsm8k: + interval: 20 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k + - --output_path + - /path/to/lm_eval/output + - --limit + - "64" + ifeval: + offset: 10 + interval: 40 + evaluator: + type: lm_eval + cli_args: + - --tasks + - ifeval + - --output_path + - /path/to/lm_eval/output + - --limit + - "32" + faster_tasks: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - xnli_en,wikitext + - --output_path + - /path/to/lm_eval/output +data: + tokenizer: + path: path/to/the/tokenizer +``` diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 66ce096d..c4a13c5d 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.utils import set_global_variables # Import these submodules to ensure classes are added to the dynamic class registry. import fast_llm.data.auto # isort: skip @@ -20,6 +21,8 @@ def fast_llm_main_wrapper(): # (Pre-)configure logging configure_logging() + # Set global and environment variables before third-party imports. 
+ set_global_variables() try: yield except Exception as e: diff --git a/fast_llm/config.py b/fast_llm/config.py index 0004501b..c534b11f 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -735,7 +735,7 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: "Config| dict[str, typing.Any]]", + default: "Config| dict[str, typing.Any]", *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index e82e0801..86f8e729 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -8,10 +8,13 @@ import contextlib import datetime +import io import logging +import pickle import typing import torch +import torch.monitor from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, @@ -46,6 +49,7 @@ def broadcast( return work else: work.wait() + return None def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None: @@ -110,6 +114,7 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta return work else: work.wait() + return None def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: @@ -119,6 +124,7 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta return work else: work.wait() + return None @contextlib.contextmanager @@ -133,3 +139,118 @@ def set_generator(generator: torch.Generator) -> typing.Generator[None, None, No finally: generator.set_state(default_generator.get_state()) default_generator.set_state(old_state) + + +def gather( + tensor: torch.Tensor, + gather_list: list[torch.Tensor] | None = None, + group: ProcessGroup | None = None, + async_op: bool = False, + dst: int = 0, +): + assert group is not None + opts = torch.distributed.GatherOptions() + opts.rootRank = dst + work = 
group.gather([gather_list] if dst == group.rank() else [], [tensor], opts) + + if async_op: + return work + elif work is not None: + work.wait() + return None + + +def scatter( + tensor: torch.Tensor, + scatter_list: list[torch.Tensor] | None = None, + group: ProcessGroup | None = None, + async_op: bool = False, + src: int = 0, +): + assert group is not None + opts = torch.distributed.ScatterOptions() + opts.rootRank = src + opts.asyncOp = async_op + work = group.scatter( + [tensor if not tensor.is_complex() else torch.view_as_real(tensor)], + [[t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list]] if src == group.rank() else [], + opts, + ) + if async_op: + return work + elif work is not None: + work.wait() + return None + + +def _object_to_tensor(obj: typing.Any) -> torch.Tensor: + f = io.BytesIO() + pickle.Pickler(f).dump(obj) + return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8) + + +def _tensor_to_object(tensor: torch.Tensor) -> typing.Any: + return pickle.Unpickler(io.BytesIO(tensor.numpy(force=True).tobytes())).load() + + +def gather_object( + obj: typing.Any, + group: ProcessGroup | None = None, + dst: int = 0, +) -> list[typing.Any] | None: + assert group is not None + group_rank = group.rank() + group_size = group.size() + device = torch.cuda.current_device() + + obj_tensor = _object_to_tensor(None if group_rank == dst else obj) + sizes = torch.full([group.size()], len(obj_tensor), dtype=torch.int64, device=device) + all_gather_into_tensor(sizes, sizes[group.rank()], group=group) + sizes = sizes.tolist() + max_size = max(sizes) + + input_tensor = torch.empty(max_size, dtype=torch.uint8, device=device) + + if group_rank == dst: + output_tensors = list(torch.empty(max_size * group_size, dtype=torch.uint8, device=device).chunk(group_size)) + gather(input_tensor, output_tensors, dst=dst, group=group) + return [ + obj if rank_ == dst else _tensor_to_object(tensor[:size]) + for rank_, 
(tensor, size) in enumerate(zip(output_tensors, sizes, strict=True)) + ] + else: + input_tensor[: obj_tensor.numel()].copy_(obj_tensor) + gather(input_tensor, None, dst=dst, group=group) + return None + + +def scatter_object( + scatter_object_input_list: typing.Optional[list[typing.Any]] = None, + group: ProcessGroup | None = None, + src: int = 0, +) -> typing.Any: + assert group is not None + group_rank = group.rank() + group_size = group.size() + device = torch.cuda.current_device() + + if group_rank == src: + tensor_list = [ + _object_to_tensor(None if rank_ == src else obj) for rank_, obj in enumerate(scatter_object_input_list) + ] + sizes = [tensor.numel() for tensor in tensor_list] + max_size = max(sizes) + size_tensor = torch.tensor([[size, max_size] for size in sizes], dtype=torch.int64, device=device) + scatter(size_tensor[group_rank], list(size_tensor.unbind()), src=src, group=group) + scatter_list = list(torch.empty(max_size * group_size, dtype=torch.uint8, device=device).chunk(group_size)) + for scatter_tensor, tensor, size in zip(scatter_list, tensor_list, sizes, strict=True): + scatter_tensor[:size].copy_(tensor) + scatter(scatter_list[src], scatter_list, src=src, group=group) + return scatter_object_input_list[src] + else: + size_tensor = torch.empty(2, dtype=torch.int64, device=device) + scatter(size_tensor, None, src=src, group=group) + size, max_size = size_tensor.tolist() + output_tensor = torch.empty(max_size, dtype=torch.uint8, device=device) + scatter(output_tensor, None, src=src, group=group) + return _tensor_to_object(output_tensor[:size]) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index f8cfa8c5..7ab5b8e4 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -2,7 +2,6 @@ import os import pathlib import typing -import warnings import yaml @@ -10,7 +9,7 @@ from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig, configure_logging from 
fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.utils import log +from fast_llm.utils import log, set_global_variables if typing.TYPE_CHECKING: from fast_llm.engine.distributed.distributed import Distributed @@ -99,20 +98,9 @@ def get_run(self, distributed: "Distributed") -> "Run": TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) - self._set_external_variables() + set_global_variables(not self.run.torch_dynamo_enable) return run - def _set_external_variables(self) -> None: - import torch._dynamo - - # TODO: Find an alternative to get reliable tensor-parallel overlap. - if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS", ""): - warnings.warn("Setting CUDA_DEVICE_MAX_CONNECTIONS breaks things.") - if "PYTHONHASHSEED" not in os.environ: - warnings.warn("PYTHONHASHSEED should be set and to the same value for all workers.") - - torch._dynamo.config.disable = not self.run.torch_dynamo_enable # noqa - _MAIN_RANK = 0 diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 7223631f..04e4227f 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLoss + from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator @config_class() @@ -40,7 +40,7 @@ def _from_dict( @config_class(dynamic_type={EvaluatorConfig: "loss"}) -class EvaluatorLossConfig(EvaluatorConfig): +class LossEvaluatorConfig(EvaluatorConfig): _abstract: typing.ClassVar[bool] = False iterations: int | None = Field( @@ -58,7 +58,58 @@ def get_evaluator( batch_config: BatchConfig, data_load_num_proc: int, train_iters: int | None = None, - ) -> "EvaluatorLoss": - 
from fast_llm.engine.evaluation.evaluator import EvaluatorLoss + ) -> "LossEvaluator": + from fast_llm.engine.evaluation.evaluator import LossEvaluator - return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) + return LossEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class(dynamic_type={EvaluatorConfig: "lm_eval"}) +class LmEvalEvaluatorConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + + cli_args: list[str] = Field( + default_factory=lambda: [], + desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", + ) + + truncation: bool = Field( + default=False, + desc="Whether to use truncation during tokenization (useful when inputs exceed model's max length);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + logits_cache: bool = Field( + default=True, + desc="Whether to enable logits caching for speedup and avoiding recomputation during repeated evaluations;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + add_bos_token: bool = Field( + default=False, + desc="Whether to prepend a beginning-of-sequence (BOS) token, required for some models like LLaMA;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + prefix_token_id: int | None = Field( + default=None, + desc="Token ID to use as a prefix to the input (e.g., for control codes or prompts);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + max_length: int | None = Field( + default=None, + desc="Maximum sequence length including both prompt and newly generated tokens." 
+ " If not set, it is inferred from the Fast-LLM model config or tokenizer.", + ) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLmEval": + from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator + + return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 3fee32ba..3bdc2407 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.run import Run, log_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, EvaluatorLossConfig +from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, LossEvaluatorConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -20,8 +20,6 @@ from fast_llm.logging import format_metrics from fast_llm.utils import get_and_reset_memory_usage_mib -# from fast_llm.engine.training.lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate - logger = logging.getLogger(__name__) @@ -53,7 +51,7 @@ class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): def __init__( self, name: str, - eval_config: EvaluatorLossConfig, + eval_config: LossEvaluatorConfig, batch_config: BatchConfig, data_load_num_proc: int, train_iters: int | None = None, @@ -97,8 +95,8 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: """ -class EvaluatorLoss[ConfigType: EvaluatorLossConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[EvaluatorLossConfig]] = 
EvaluatorLossConfig +class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig def setup( self, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py new file mode 100644 index 00000000..162ceaf6 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -0,0 +1,90 @@ +import logging +import os +import pathlib +import typing + +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import LmEvalEvaluatorConfig +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorSamplingParameters, + TrainingProgress, +) +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner + +if typing.TYPE_CHECKING: + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + +logger = logging.getLogger(__name__) + + +class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig + + _hf_model: "HuggingfaceBaseModelForCausalLM" = None + _flm_wrapper: "FastLLMLmEvalWrapper" = None + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + if "HUGGINGFACE_API_KEY_PATH" in os.environ: + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + else: + if not "HF_TOKEN" in os.environ: + logger.warning( + "No `HF_TOKEN` or `HUGGINGFACE_API_KEY_PATH` 
environment variable provided. " + "Assuming the user has already logged in to the Hugging Face Hub." + ) + + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + + super().setup(distributed, run, multi_stage, runner, data, phase) + + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( + self._multi_stage, runner=self._runner + ) + + # For reporting purposes, just to indicate it is from Fast-LLM + # as lm_eval.simple_evaluate will take it for results['config']['model'] + self._hf_model.config.name_or_path = type(self._hf_model).__name__ + + self._flm_wrapper = FastLLMLmEvalWrapper( + model=self._hf_model, + tokenizer=self._data.tokenizer.tokenizer, + truncation=self._config.truncation, + logits_cache=self._config.logits_cache, + add_bos_token=self._config.add_bos_token, + prefix_token_id=self._config.prefix_token_id, + max_length=self._config.max_length, + ) + self._is_setup = True + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + + # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/ + completed_steps = 0 if training_progress is None else training_progress.completed_steps + + self._flm_wrapper.run(self._config.cli_args, completed_steps, self._run.index) + + # lm_eval logs to disc, wandb and prints to screen itself + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py new file mode 100644 index 00000000..8f4dffed --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -0,0 +1,909 @@ +import copy +import logging + +import jinja2 +import lm_eval.api.instance +import lm_eval.api.model +import lm_eval.evaluator +import lm_eval.models.utils 
+import lm_eval.utils +import torch +import torch.nn.functional as F +import tqdm.auto +import transformers + +from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results +from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.layers.transformer.rotary.config import NoRotaryConfig + +logger = logging.getLogger(__name__) + + +class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): + _DEFAULT_MAX_LENGTH = 2048 + _DEFAULT_MAX_GEN_TOKENS = 256 + + def __init__( + self, + model: HuggingfaceBaseModelForCausalLM, + tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, + truncation: bool | None = False, + logits_cache: bool = True, + add_bos_token: bool | None = False, + prefix_token_id: int | None = None, + max_length: int | None = None, + ): + super().__init__() + + # === Distributed setup === + self._rank = 0 # For lm_eval: always run on main rank + self._world_size = 1 + self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed + + if ( + self._distributed.config.sequence_data_rank == 0 + and self._distributed.config.pipeline_rank == 0 + and self._distributed.config.tensor_rank == 0 + ): + self._group = self._distributed.batch_data_group + else: + self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER + + # === Model & tokenizer setup === + self._model = model + self._device = model.device + self._config = model.config + + assert isinstance(tokenizer, (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast)) + self._tokenizer = tokenizer + self._tokenizer = lm_eval.models.utils.configure_pad_token(self._tokenizer, model_config=self._config) + + # === Generation/configuration parameters === + self._truncation = truncation + self._logits_cache = 
logits_cache + self._add_bos_token = add_bos_token + self._max_length = max_length + self._custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + + # === Internal constants === + self._backend = "causal" + self._vocab_size = self._tokenizer.vocab_size + + # === Batch configuration === + self._batch_schedule = 1 + self._batch_sizes = {} # Not used dynamically by lm_eval + self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size + self._batch_size = self._batch_size_per_gpu * self._distributed.config.batch_data_parallel + self._max_batch_size = self._batch_size + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self._tokenizer.eos_token_id + + # overrides from TemplateLM, but not used externally + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self._custom_prefix_token_id is not None: + return self._custom_prefix_token_id + if self._tokenizer.bos_token_id is not None: + return self._tokenizer.bos_token_id + return self._tokenizer.eos_token_id + + @property + def max_length(self): + # if max length manually set, return it + if self._max_length: + return self._max_length + + # check if it is absolute positional encoding and return max_position_embeddings + if hasattr(self._config.fast_llm_config.base_model, "transformer"): + # NOTE: will need to extend if more relative encoding types will be added + if isinstance(self._config.fast_llm_config.base_model.transformer.rotary, NoRotaryConfig): + return self._config.fast_llm_config.base_model.max_position_embeddings + + # check if tokenizer holds model sequence leigh info + if hasattr(self._tokenizer, "model_max_length"): + if self._tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return 
self._tokenizer.model_max_length + + # finally try to get sequence length from batch config + if hasattr(self._model._inference_runner._batch_config, "sequence_length"): + return self._model._inference_runner._batch_config.sequence_length + + return self._DEFAULT_MAX_LENGTH + + # @property + # def device(self): + # # only used for world_size when lm_eval world size > 1 and + # # should not be called with current lm_eval support implementation + # return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer(self): + return self._tokenizer + + @property + def tokenizer_name(self) -> str: + return self._tokenizer.name_or_path.replace("/", "__") + + def run(self, cli_args: list[str], completed_steps: int, run_index: int): + if self._distributed.config.rank == 0: + args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params(cli_args, completed_steps, run_index) + simple_eval_kwargs["model"] = self + + # Needed for reporting as batch_size is set from args not lm for reporting in evaluate + simple_eval_kwargs["batch_size"] = self._batch_size + simple_eval_kwargs["max_batch_size"] = self._max_batch_size + + # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 + # Expected to be a string even if empty and not None in simple_evaluate + simple_eval_kwargs["model_args"] = "" + + results = lm_eval.evaluator.simple_evaluate(**simple_eval_kwargs) + self.stop_workers() + + # Evaluation_tracker save expects model to be either string, but if model is passed + # LM wrapper needs to be deep copyable and json serializable + simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = ( + self._model.config.name_or_path + ) + + if results is not None: + process_lm_eval_results( + args, + results, + simple_eval_kwargs["evaluation_tracker"], + completed_steps, + ) + else: + self.worker_model_invoke() + + # TODO: do we need it here as self.stop_workers() 
and self.worker_model_invoke() + # already have barrier + safe_barrier(self._distributed.world_group, f"lm_eval Run end") + + def _model_invoke( + self, + input_ids, + attention_mask, + labels, + max_length, + stop, + generate: bool, + continue_generate: bool, + **generation_kwargs, + ): + # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple. + # Messages could include types like logits, generate, finished. + + # Group is always None if world size is 1 + if self._group is None: + # Must not be called with continue_generate false on one process + assert continue_generate + return self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + world_size = self._group.size() + + assert self._group.rank() == 0 + + if continue_generate: + assert input_ids is not None + if generate: + assert max_length is not None and stop is not None + + # always divide by world_size, if not full batch, some ranks will get less work or not at all + assert self._batch_size % world_size == 0 + step = self._batch_size // world_size + + input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] + attention_mask = [ + attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None + for i in range(world_size) + ] + labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)] + + scatter_list = [ + [ + input_ids[i], + attention_mask[i], + labels[i], + max_length, + stop, + generate, + continue_generate, + generation_kwargs, + ] + for i in range(world_size) + ] + else: + scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)] + + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + scatter_object( + scatter_list, + group=self._group, + ) + ) + + if not continue_generate: + return None + + assert len(input_ids) > 0 + 
+ result = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + gather_list = gather_object(result, group=self._group) + # Clean gather list from empty shards + gather_list = [el for el in gather_list if len(el) > 0] + + # If it was model generate tensors could be of different length + # so we aggregate results to list instead of a tensor + if generate: + result = sum((el.tolist() for el in gather_list), []) + else: + assert all(el.device.type == "cpu" for el in gather_list) + result = torch.cat(gather_list, dim=0) + + return result + + def worker_model_invoke(self): + assert self._group is not None + # if isinstance(self.group, dist.ProcessGroup): + if not isinstance(self._group, int): + # groups is None for world_size 1 + assert self._group.rank() != 0 + # on worker ranks the function need to wait to be called multiple times + while True: + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + scatter_object( + None, + group=self._group, + ) + ) + + # Stop signal was send, end waiting/processing loop + if not continue_generate: + break + + # if some data was received, work, otherwise return empty tensor + if len(input_ids) > 0: + result = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + else: + result = input_ids + + gather_object(result, group=self._group) + else: + # TODO: implement distributed model support + assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def stop_workers(self): + # Group is always None if world size is 1 + if self._group is None: + return + self._model_invoke(None, None, None, None, None, None, continue_generate=False) + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def _model_invoke_inner( + self, input_ids, attention_mask, labels, max_length, stop, 
generate: bool, **generation_kwargs + ): + if generate: + return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) + else: + return self._model_call_inner(input_ids, attention_mask, labels) + + def _model_call(self, input_ids, attention_mask=None, labels=None): + return self._model_invoke( + input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True + ) + + def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + return self._model_invoke( + input_ids, + attention_mask, + None, + max_length, + stop, + generate=True, + continue_generate=True, + **generation_kwargs, + ) + + def _model_call_inner(self, input_ids, attention_mask=None, labels=None): + """ + :param input_ids: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attention_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + if attention_mask is not None or labels is not None: + assert attention_mask is not None and labels is not None + + # TODO: do we need no_grad for fast_llm model? + with torch.no_grad(): + # We move logits to the CPU because they will be copied across processes and nodes + # in a multi-GPU, multi-node setup and eventually collected on the main rank. + # We cannot afford to accumulate them on rank 0 GPU, as GPU memory may already be tight. 
+ # CPU tensors are slower, but we typically have much more CPU RAM available. + + # TODO: Check if it's possible to move some of the _loglikelihood_tokens work here + # and pass only the results around instead of the full logits. + # Computing errors here is also preferable, as copying logits across nodes and GPUs + # is inefficient and can involve gigabytes of data. + return self._model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits.cpu() + + def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self._tokenizer, stop, input_ids.shape[1], input_ids.shape[0] + ) + + return self._model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self._tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used 
for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self._backend == "causal": + special_tokens_kwargs = {"add_special_tokens": False or self._add_bos_token} + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self._tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: list[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self._tokenizer.padding_side + self._tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self._backend == "causal": + add_special_tokens = {"add_special_tokens": False or self._add_bos_token} + + encoding = self._tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. 
Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self._tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: + if self._backend == "causal": + assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self._backend == "seq2seq": + assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False + ) -> list[float]: + adaptive_batch_size = None + if self._batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. 
Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm.auto.tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + # The tokenizer may raise: "Token indices sequence length is longer than the specified maximum sequence length for this model" + # This is expected and fine, as the sequence will be split into chunks of max_length later. + rolling_token_windows: list[tuple[list[int], list[int]]] = list( + map( + lm_eval.utils.make_disjoint_window, + lm_eval.utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self._device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self._batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + 
requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial("loglikelihood_rolling", (string,), request_total) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self._batch_schedule) + if sched in self._batch_sizes: + return self._batch_sizes[sched] + if (len(self._batch_sizes) > 1) and (self._batch_sizes[sched - 1] == self._max_batch_size): + # if previous batch size is already maximal, skip recomputation + self._batch_sizes[sched] = self._max_batch_size + return self._batch_sizes[sched] + print(f"Passed argument batch_size = auto:{self._batch_schedule}. 
Detecting largest batch size") + self._batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self._batch_sizes[sched]}") + return self._batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: list[tuple[tuple[str, str], list[int], list[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> list[tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + # NOTE: for the sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + # NOTE: the group_fn Defines the key to group and lookup one-token continuations + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. 
+ re_ord = lm_eval.models.utils.Collator( + requests, + sort_fn=lambda req: (-(len(req[1]) + len(req[2])), tuple(req[1]) + tuple(req[2])), + group_by="contexts" if self._backend == "causal" and self._logits_cache else None, + group_fn=lambda req: req[-2] + req[-1][:-1], + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = self._batch_size if self._batch_size != "auto" else override_bs if override_bs is not None else 0 + batch_fn = ( + self._batch_scheduler + if self._batch_size == "auto" and n_reordered_requests > 0 and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm.auto.tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self._backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + logger.warning( + f"Combined 
length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self._device, + ) + (inplen,) = inp.shape + elif self._backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self._device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self._device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen + + padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self._backend == "causal": + batched_inps = lm_eval.models.utils.pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self._backend == "seq2seq": + # TODO: left-pad encoder inps and mask? 
+ batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] + batched_conts = lm_eval.models.utils.pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = lm_eval.models.utils.pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attention_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), dim=-1 + ) # [batch, padding_length (inp or cont), vocab] + + # TODO: Consider moving this part to per-shard execution in a multi-GPU and multi-node setup + # to avoid copying logits between GPUs and nodes, and to enable performing logits computations on the GPU. + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self._backend == "causal" else None + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. 
+ # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + # NOTE: Currently, computations are performed on the CPU due to limited GPU memory. + cont_toks = torch.tensor(cont_toks, dtype=torch.long, device="cpu").unsqueeze(0) # [1, seq] + + max_equal = (greedy_tokens == cont_toks).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial("loglikelihood", request_str, answer) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: + res = [] + + pbar = tqdm.auto.tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self._batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. 
+ batch_size = ( + self._batch_size + if self._batch_size != "auto" + else adaptive_batch_size if adaptive_batch_size is not None else 0 + ) + batch_fn = self._batch_scheduler if self._batch_size == "auto" and not adaptive_batch_size else None + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + # NOTE: for sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + re_ords = lm_eval.models.utils.Collator( + [reg.args for reg in requests], + sort_fn=lambda req: (-len(self.tok_encode(req[0])), req[0]), + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. 
+ if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = lm_eval.models.utils.handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self._DEFAULT_MAX_GEN_TOKENS + + # set the max length in tokens of inputs ("context_enc") + if self._backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert ( + max_ctx_len > 0 + ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." + elif self._backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + input_ids, attention_mask = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self._truncation, + ) + input_ids = input_ids.to(self._device) + attention_mask = attention_mask.to(self._device) + + if "max_length" not in kwargs: + kwargs["max_length"] = input_ids.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + input_ids=input_ids, + attention_mask=attention_mask, + stop=until, + **kwargs, + ) + + # cont_toks_list = cont.tolist() + cont_toks_list = cont + + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self._backend == "causal": + cont_toks = cont_toks[input_ids.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where 
self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self._tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + logger.warning("Failed to apply chat template. removing the system role in chat history.") + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self._tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated diff --git a/fast_llm/engine/evaluation/lm_eval/utils.py b/fast_llm/engine/evaluation/lm_eval/utils.py new file mode 100644 index 00000000..afcfc1a9 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/utils.py @@ -0,0 +1,244 @@ +import argparse +import json +import logging +import os +import pathlib +import sys +from pathlib import Path + +import lm_eval.__main__ +import lm_eval.evaluator +import lm_eval.loggers +import lm_eval.tasks +import lm_eval.utils + +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +def parse_eval_args(parser: argparse.ArgumentParser, args: list[str]) -> argparse.Namespace: + lm_eval.__main__.check_argument_types(parser) + return parser.parse_args(args) + + +def prepare_lm_eval_simple_eval_params( + cli_args: list[str], + completed_steps: int, + run_index: int, +) -> 
tuple[argparse.Namespace, dict[str, any]]: + """ + Parses CLI arguments for an LM evaluation run and prepares keyword arguments + for the `evaluate` function. + + This function wraps argument parsing, environment configuration, task resolution, + and metadata setup needed for evaluation with Fast-LLM's `lm_eval` wrapper. It also + handles special cases like hub token injection, dynamic sample loading, and task + listing commands. + + Args: + cli_args (list[str]): Command-line arguments, excluding the program name. + completed_steps (int): Current number of completed training steps, used to + uniquely tag evaluation output paths. + run_index (int): index of the current run of Fast-LLM experiment + + Returns: + tuple: + - argparse.Namespace: Parsed CLI arguments. + - dict: Keyword arguments to pass into `simple_evaluate`, including task list, + tracker, cache settings, random seeds, and generation parameters. + + Raises: + ValueError: If required fields like `--tasks` or `--output_path` are missing + when needed, or if misconfigured combinations are detected. + SystemExit: If special task listing flags are used. 
+ """ + parser = lm_eval.__main__.setup_parser() + parser.add_argument( + "--no_defaults", + action="store_true", + ) + args = parse_eval_args(parser, cli_args) + + # NOTE: all this args are set by fast_llm on the model directly or not used here + Assert.eq(args.wandb_args, "") + Assert.eq(args.wandb_config_args, "") + Assert.eq(args.model, "hf") + Assert.eq(args.model_args, "") + Assert.eq(int(args.batch_size), 1) + Assert.none(args.max_batch_size) + Assert.none(args.device) + + # update the evaluation tracker args with the output path and the HF token + evaluation_tracker_args = "" + if args.output_path: + args.output_path = str(pathlib.Path(args.output_path) / f"runs/{run_index}/{completed_steps}") + evaluation_tracker_args += f",output_path={args.output_path}" + + evaluation_tracker_args = lm_eval.utils.simple_parse_args_string(evaluation_tracker_args) + evaluation_tracker = lm_eval.loggers.EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError("Specify --output_path if providing --log_samples or --predict_only") + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." 
+ ) + + if args.include_path is not None: + args.include_path = args.include_path.split(",") + logger.info(f"Including paths: {args.include_path}") + metadata = ( + lm_eval.utils.simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args if isinstance(args.model_args, dict) else {} + ) | (args.metadata if isinstance(args.metadata, dict) else lm_eval.utils.simple_parse_args_string(args.metadata)) + + task_manager = lm_eval.tasks.TaskManager( + verbosity=args.verbosity, + include_path=args.include_path, + include_defaults=not args.no_defaults, + metadata=metadata, + ) + + if args.limit: + logger.warning(" --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") + if args.samples: + assert args.limit is None, "If --samples is not None, then --limit must be None." + if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = lm_eval.utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = 
lm_eval.utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + logger.error( + f"Tasks were not found: {missing}\n" + f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all" + " available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG'" + " to troubleshoot task registration issues." + ) + + logger.info(f"Selected Tasks: {task_names}") + + request_caching_args = lm_eval.evaluator.request_caching_arg_to_dict(cache_requests=args.cache_requests) + + eval_kwargs = dict( + tasks=task_names, + num_fewshot=args.num_fewshot, + # batch_size=args.batch_size, + # max_batch_size=args.max_batch_size, + # device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + return args, eval_kwargs + + +def process_lm_eval_results( + args: argparse.Namespace, + results: dict[str, any], + evaluation_tracker: lm_eval.loggers.EvaluationTracker, + completed_steps: int | None, +) -> None: + if results is not None: + 
completed_steps = 0 if completed_steps is None else completed_steps + import wandb + + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps(results, indent=2, default=lm_eval.utils.handle_non_serializable, ensure_ascii=False) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging if we have the run to log to + # we expect the rest of the fast_llm code will finish the run. + if wandb.run is not None: + try: + wandb_logger = lm_eval.loggers.WandbLogger(init_args={"step": completed_steps}) + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples(task_name=task_name, samples=samples[task_name]) + + if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub: + evaluation_tracker.recreate_metadata_card() + + # TODO: convert to logging entries instead? 
+        print(
+            f"{results['config']['model']}, gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
+            f"batch_size: {results['config']['batch_size']}{f' ({batch_sizes})' if batch_sizes else ''}"
+        )
+        print(lm_eval.utils.make_table(results))
+        if "groups" in results:
+            print(lm_eval.utils.make_table(results, "groups"))
diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py
index 3c2db428..54a82492 100644
--- a/fast_llm/engine/inference/huggingface.py
+++ b/fast_llm/engine/inference/huggingface.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import pathlib
 import typing
@@ -14,6 +15,8 @@
 from fast_llm.engine.schedule.runner import ScheduleRunner
 from fast_llm.utils import Assert
 
+logger = logging.getLogger(__name__)
+
 
 class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
     config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig
@@ -41,6 +44,8 @@ def __init__(
         # The HF constructor performs a deep copy of the config,
         # but config.fast_llm_config may contain non-picklable items like process groups.
         # Temporarily remove it before the call and restore it afterward.
+        # TODO: Find a clean solution — overriding __deepcopy__ doesn't work here
+        # because internally they use copy.deepcopy(self.__dict__).
         fast_llm_config = config.fast_llm_config
         config.fast_llm_config = None
         super().__init__(config, **kwargs)
@@ -64,6 +69,11 @@ def __init__(
         with transformers.modeling_utils.no_init_weights():
             self.post_init()
 
+        if fast_llm_model.config.multi_stage.zero_stage == 3:
+            logger.warning(
+                "zero_stage=3 is used for the model; forward and generate will be extremely slow during inference."
+ ) + @classmethod def from_pretrained( cls, diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index a2a9d9d3..87eac31c 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -113,7 +113,8 @@ def forward( losses, metrics, ) - self._log_layer_forward(output, kwargs, i) + if output is not None: + self._log_layer_forward(output, kwargs, i) # TODO: very slow and memory consuming, only use for debugging for now # TODO: decide if and how we want to return diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index efe8f714..4b8d805b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -54,7 +54,7 @@ class IntervalConfig(Config): def _validate(self) -> None: if self.interval: - with self._set_implicit_default(): + with self._set_implicit_default(None): self.offset %= self.interval super()._validate() diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 64408bb0..5f5511a1 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -303,7 +303,16 @@ def _run_training(self) -> None: else: metrics = {} done = True - self._evaluator_runner.run(metrics=metrics) + self._evaluator_runner.run( + metrics=metrics, + # This is set to ensure that evaluators like lm_eval log results at the correct step if a checkpoint was loaded. + training_progress=TrainingProgress( + done=done, + completed_steps=self._completed_steps, + consumed_samples=self._consumed_samples, + consumed_tokens=self._consumed_tokens, + ), + ) if done and PhaseType.test in self._samples_per_split: log_main_rank(lambda: f"Running test phase ...") @@ -318,7 +327,7 @@ def _run_training(self) -> None: log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. 
- self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=True) def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # Tracking loss. @@ -339,6 +348,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._config.training.prefetch_factor, ) + has_test_phase = PhaseType.test in self._samples_per_split + log_main_rank("Training ...") # TODO: Synchronization is probably unnecessary. @@ -456,7 +467,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: ) if is_main_rank() and metrics: - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=not (done and has_test_phase)) stop = done or self._config.training.shutdown.enabled(self._completed_steps) diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index 185b89c2..724b5b71 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -44,12 +44,12 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): else: self._wandb = None - def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]]) -> None: + def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]], commit: bool) -> None: # Note: metrics modified in-place if self._wandb is not None: import wandb - wandb.log(metrics, step=completed_steps) # noqa + wandb.log(metrics, step=completed_steps, commit=commit) # noqa def alert(self, title, text, level="INFO", wait=0.001) -> None: if self._wandb is not None and self._config.alert.post_alerts: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 472f5e9b..58285d40 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -2,8 +2,10 @@ import itertools import logging import math +import os import signal import typing +import warnings from typing import Callable if typing.TYPE_CHECKING: 
@@ -395,6 +397,26 @@ def interrupted(self): return self._interrupted +def set_global_variables(disable_torch_dynamo: bool = False) -> None: + # Set global and environment variables. This needs to be called before importing any third-party package. + # TODO: Find an alternative to get reliable tensor-parallel overlap. + if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS", ""): + warnings.warn("Setting CUDA_DEVICE_MAX_CONNECTIONS breaks things.") + # All distributed workers need the same hash seed for consistent hashing. + if "PYTHONHASHSEED" not in os.environ: + warnings.warn("PYTHONHASHSEED should be set and to the same value for all workers.") + # On systems with more than 64 cores, numexpr may log an error and ignore the thread setting. + if "NUMEXPR_MAX_THREADS" not in os.environ: + import multiprocessing + + os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count()) + + if disable_torch_dynamo: + import torch._dynamo + + torch._dynamo.config.disable = True # noqa + + _global_max_allocated = 0 _global_max_reserved = 0 diff --git a/mkdocs.yaml b/mkdocs.yaml index ab71bc23..85fd4bff 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -179,6 +179,7 @@ nav: - Configuration: user_guide/configuration.md - Multi-Stage: user_guide/multi-stage.md - Parallelism: user_guide/parallelism.md + - Evaluators: user_guide/evaluators.md - Developer Guide: - Configuration: developer_guide/configuration.md - Model: diff --git a/setup.cfg b/setup.cfg index 2f69b8e0..843aa15c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,6 +51,11 @@ HUGGINGFACE = # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 + cartesia_pytorch>=0.0.2 + +GENERATION = + lm_eval>=0.4.9 + DEV = # Pre-commit git hook diff --git a/tests/conftest.py b/tests/conftest.py index 298117e1..19bdfe5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,9 +8,13 @@ import pytest import 
xdist.scheduler -from fast_llm.utils import get_and_reset_memory_usage_mib +from fast_llm.utils import get_and_reset_memory_usage_mib, set_global_variables from tests.utils.depends import DependencyManager +# TODO: Is this early enough? +set_global_variables() # isort: skip + + if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. @@ -225,7 +229,9 @@ def pytest_terminal_summary(terminalreporter): terminalreporter.write_sep("=", "Highest gpu memory usage", bold=True) sorted_nodeids = sorted( resource_reports.keys(), - key=lambda nodeid: resource_reports[nodeid]["max_reserved"], + key=lambda nodeid: ( + resource_reports[nodeid]["max_reserved"] if "max_reserved" in resource_reports[nodeid] else 0 + ), reverse=True, ) for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py new file mode 100644 index 00000000..b9e2aa8c --- /dev/null +++ b/tests/models/test_lm_eval.py @@ -0,0 +1,124 @@ +import pathlib +import shutil + +import pytest + +from tests.utils.dataset import TOKENIZER_PATH, download_santacoder_tokenizer +from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.model_configs import ModelTestingGroup +from tests.utils.utils import requires_cuda + +# NOTE: These tests only verify that the functionality runs without crashing. +# NOTE: The tokenizer is from a LLaMA-style model, which may not be suitable for all models, +# but it should be sufficient since we are not concerned with actual accuracy in this tests. 
+ + +@pytest.fixture(scope="module") +def tokenizer_path(): + download_santacoder_tokenizer() + return TOKENIZER_PATH + + +@pytest.fixture(scope="function") +def get_lm_eval_config(tokenizer_path, monkeypatch): + # TODO: Investigate why loading the tokenizer here gives a vocab_size + # smaller than 49157, which is the size when loaded by Fast-LLM. + import lm_eval.evaluator + + # lm_eval gathers lots of system info when reporting results, and this is extremely slow, so we skip here. + monkeypatch.setattr(lm_eval.evaluator, "add_env_info", lambda x: None, raising=True) + + def do_get_lm_eval_config(base_path): + import lm_eval.tasks + + task_dir = pathlib.Path(lm_eval.tasks.__file__).parent.resolve() + return [ + f"data.tokenizer.path={tokenizer_path}", + f"model.base_model.vocab_size=49157", + "training.evaluators.evaluation_test.interval=2", + "training.evaluators.evaluation_test.evaluator.type=lm_eval", + "training.evaluators.evaluation_test.evaluator.cli_args=" + f'["--tasks=wikitext",' + f'"--output_path={str(base_path / "lm_eval")}",' + # lm_eval loads all available tasks by default which is slow. 
+ f'"--include_path={str(task_dir / "wikitext")}",' + f'"--no_defaults",' + f'"--limit=1",' + f'"--batch_size=1",' + f'"--verbosity=DEBUG"]', + ] + + return do_get_lm_eval_config + + +# "gsm8k,xnli_en,wikitext" + + +@requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.generate) +def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): + run_test_script_for_all_models( + distributed_testing_config=DistributedTestingConfig( + name="lm_eval_in_training", + config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_in_training") + + ["training.checkpoint.interval=2"], + ) + ) + + +@pytest.fixture(scope="module") +def copy_training_output(run_test_script_base_path: pathlib.Path): + def do_copy_training_output(distributed_testing_config: DistributedTestingConfig): + self_path = run_test_script_base_path / distributed_testing_config.name + shutil.copytree(run_test_script_base_path / distributed_testing_config.compare, self_path) + + return do_copy_training_output + + +@requires_cuda +@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.generate) +def test_lm_eval_evaluation_last_checkpoint( + run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config, copy_training_output +): + distributed_testing_config = DistributedTestingConfig( + name="lm_eval_evaluation_last_checkpoint", + config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_evaluation_last_checkpoint"), + compare="lm_eval_in_training", + ) + copy_training_output(distributed_testing_config) + run_test_script_for_all_models(distributed_testing_config=distributed_testing_config, runnable_type="evaluate") + + +@requires_cuda +@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.generate) +def test_lm_eval_evaluation_from_pretrained( + 
run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config +): + run_test_script_for_all_models( + distributed_testing_config=DistributedTestingConfig( + name="lm_eval_evaluation_from_pretrained", + config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_evaluation_from_pretrained") + + [ + "pretrained.format=distributed", + f"pretrained.path={run_test_script_base_path/'lm_eval_in_training/checkpoint/2'}", + "pretrained.model_weights=True", + ], + ) + ) + + +# TODO: rewrite for a new distributed test function +# @requires_cuda +# @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) +# @pytest.mark.model_testing_group(ModelTestingGroup.generate, ModelTestingGroup.distributed) +# def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): +# run_test_script_for_all_models( +# distributed_testing_config=DistributedTestingConfig( +# name="lm_eval_in_training_dp2", +# config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_in_training_dp2") +# + ["training.checkpoint.interval=1"], +# num_gpus=2, +# ) +# ) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 52b51c8a..694faa55 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -14,21 +14,15 @@ from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel -try: - from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel -except ImportError: - LMHeadModel = None - +@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") @pytest.mark.slow -@pytest.mark.skipif( - LMHeadModel is None, - reason=f"cartesia_pytorch.Llamba not installed", -) def test_load_from_llamba_checkpoint(): """ Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
""" + import cartesia_pytorch.Llamba.llamba + vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json batch_size = 2 seq_length = 32 @@ -38,7 +32,7 @@ def test_load_from_llamba_checkpoint(): x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") + hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) hf_logits = hf_model(x)["logits"].cpu() del hf_model diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a4136c40..b770675d 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -24,6 +24,13 @@ MODEL_TEST_VOCAB_SIZE = 384 +def download_santacoder_tokenizer(): + if not TOKENIZER_FILE.is_file(): + import transformers + + transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + + def get_test_dataset( prefix: pathlib.Path = DATASET_PREFIX, seed: int = 1234, @@ -32,10 +39,7 @@ def get_test_dataset( vocab_size: int = TEST_VOCAB_SIZE, max_spans: int = 0, ): - if not TOKENIZER_FILE.is_file(): - import transformers - - transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + download_santacoder_tokenizer() if not ( prefix.with_suffix(".idx").is_file() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963..1eee3675 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -23,6 +23,10 @@ from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig +from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration + EvaluatorsConfig, +) + _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) @@ -70,6 +74,17 @@ def trainer_config(self) -> 
TrainerConfig: # See `RunnableConfig._from_parsed_args` return self.trainer_config_class.from_dict(self.trainer_config_class._parse_updates(self.config_args)) + @functools.cached_property + def evaluators_config_class(self) -> type[EvaluatorsConfig]: + # EvaluatorsConfig is a base class that, during parse_and_run, replaces itself with the appropriate TrainingConfig subclass. + # Therefore, the arguments passed to EvaluatorsConfig.parse_and_run must include the model type as the first element. + return EvaluatorsConfig + + @functools.cached_property + def evaluators_config(self) -> EvaluatorsConfig: + # See `RunnableConfig._from_parsed_args` + return self.evaluators_config_class.from_dict(self.evaluators_config_class._parse_updates(self.config_args)) + @functools.cached_property def model_config_class(self) -> type[FastLLMModelConfig]: # TODO: Ok to assume the model and trainer have the same name? diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index b8f996a8..7d706ebd 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -69,12 +69,13 @@ def do_run_test_script_for_all_models( distributed_testing_config: DistributedTestingConfig, model_testing_config: ModelTestingConfig, base_path: pathlib.Path, + runnable_type: str = "train", ): Assert.leq(distributed_testing_config.num_gpus, DistributedConfig.default_world_size) get_model_test_dataset() args = [ "fast-llm", - "train", + runnable_type, model_testing_config.model_type, *model_testing_config.config_args, *distributed_testing_config.config_args, @@ -83,7 +84,12 @@ def do_run_test_script_for_all_models( f"run.experiment_dir={base_path/distributed_testing_config.name}", ] print(" ".join(args)) - model_testing_config.trainer_config_class.parse_and_run(args[3:]) + if runnable_type == "train": + model_testing_config.trainer_config_class.parse_and_run(args[3:]) + elif runnable_type == "evaluate": + model_testing_config.evaluators_config_class.parse_and_run(args[2:]) 
+ else: + raise ValueError(f"Unknown runnable_type {runnable_type}") @pytest.fixture(scope="function") From 50083ba88a0bfa58747d2bc8307814b62af1a79a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 15:14:13 -0400 Subject: [PATCH 12/37] SSM debugging --- Megatron-LM | 2 +- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/language_model/head.py | 16 ++- fast_llm/layers/ssm/config.py | 34 +++--- fast_llm/layers/ssm/discrete_mamba2.py | 23 ++-- fast_llm/layers/ssm/llamba_block.py | 29 +++-- fast_llm/layers/ssm/mamba2.py | 38 ++++-- fast_llm/layers/ssm/mamba_layer.py | 36 +++--- fast_llm/layers/transformer/attention.py | 72 +++-------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +- fast_llm/layers/transformer/transformer.py | 94 ++++++++++++--- fast_llm/logging.py | 2 + fast_llm/models/gpt/model.py | 12 +- fast_llm/models/ssm/config.py | 40 +++---- fast_llm/models/ssm/model.py | 113 +++++------------- setup.cfg | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_attention.py | 4 +- tests/test_multi_stage.py | 8 +- tests/utils/model_configs.py | 1 + 23 files changed, 271 insertions(+), 282 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5c..75b0d978 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f136..9a8ce209 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. 
global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28..21bf3bbd 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa..a1f357de 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -23,6 +23,7 @@ class SSMDimNames: # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + c_heads = "c_heads" class SSMBlockType(enum.StrEnum): @@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == 
SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + @config_class() class SSMConfig(LLMBlockConfig): @@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig): desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( - default=False, - desc="debug_ssm", - hint=FieldHint.optional, - ) dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", @@ -147,18 +159,6 @@ class SSMConfig(LLMBlockConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) dt_scale: float = Field( default=1.0, desc="Scale for dt", diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5..734e35b2 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -36,29 +38,29 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class 
DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Other options are all experimental and should not need to be configured. """ # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs): out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d..98660663 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509ab..ead32fa2 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,8 @@ from fast_llm.engine.config_utils.tensor_space import 
TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -43,24 +45,36 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.inner_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.c_heads, + SSMDimNames.state_dim, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale( self.config.mamba_lr_scale, layer_lr_scale ) @@ -236,6 +250,13 @@ def forward(self, hidden_states, kwargs): x = repeat_kv(x, self.repeat_group) x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(B, "b", self._BC_DIMS, kwargs) + self._debug_log(C, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + y = selective_scan_fn( x, dt, @@ -249,6 +270,9 @@ def forward(self, hidden_states, kwargs): return_last_state=False, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if ssm_state is not None: y, last_state = y ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d23..a95e94c0 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,4 +1,5 @@ import math +import typing from typing import Callable import einops @@ -7,6 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -44,12 +47,12 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # 
noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( + min=dt_init_floor + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) tensor.copy_(inv_dt) @@ -58,20 +61,18 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm # Tensor dims: td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) @@ -88,7 +89,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( @@ -113,7 +114,6 @@ def __init__( weight_init_method=kaiming_init_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -127,7 +127,7 @@ def __init__( self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + 
self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor ), lr_scale=mamba_layer_lr_scale, ) @@ -153,10 +153,8 @@ def __init__( bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input def forward(self, hidden_states, kwargs): assert _mamba_available @@ -168,8 +166,6 @@ def forward(self, hidden_states, kwargs): "d (b l) -> b d l", l=seqlen, ) - if self._debug_mode: - print("XZ: ", xz.shape) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat @@ -189,6 +185,4 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c990..174e1958 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,9 +13,9 @@ TransformerKwargs, TransformerSubLayerName, ) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,11 +50,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. 
""" + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -64,7 +66,7 @@ class Attention(torch.nn.Module): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -73,19 +75,9 @@ class Attention(torch.nn.Module): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -108,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
@@ -178,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -300,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make 
universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -341,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -395,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af138..73f83ccf 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa..efe0c5cc 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, 
layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 14745207..d08db9a9 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. 
+ """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. 
""" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -35,18 +95,19 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -54,7 +115,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -137,14 +198,17 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6..6d555a0b 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b..4c1eab46 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from 
fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,11 +68,11 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -91,10 +91,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11b..9ca0123b 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,9 +9,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,14 +23,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,9 +40,8 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. 
If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ @@ -83,6 +81,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) def _validate(self): with self._set_implicit_default(None): @@ -96,30 +95,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + 
logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac23..89f0cd4a 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, 
SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -39,52 +35,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - 
config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +69,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - 
mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/setup.cfg b/setup.cfg index 843aa15c..c086af7d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,14 +48,9 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - cartesia_pytorch>=0.0.2 - -GENERATION = - lm_eval>=0.4.9 - DEV = # Pre-commit git hook diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782df..3e6c3763 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b..4f36cdf8 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f195..004b9628 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e5..dd36b840 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert 
attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170..2f125717 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for 
weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eee3675..42252c62 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,6 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 7b32699be7c1a1fb29cc7386eb33280b0bc19a5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:28:56 -0400 Subject: [PATCH 13/37] stuff --- fast_llm/layers/ssm/mamba2.py | 57 ++++++++++++++--------------------- fast_llm/models/ssm/config.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index ead32fa2..b936ccf1 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ @@ -97,9 +98,9 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), + (td_inner, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), + -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), ), # see 
https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 lr_scale=mamba_layer_lr_scale, @@ -110,9 +111,9 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), + (td_xb, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), + -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), ), ) @@ -133,7 +134,13 @@ def __init__( weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, ) - + self.dt_in_proj = Linear( + td_model, + tdt_rank, + bias=config.add_bias_linear, + weight_init_method=kaiming_init_(transformer_config.hidden_size), + lr_scale=mamba_layer_lr_scale, + ) # Initialize special dt projection to preserve variance at initialization dt_scale = config.dt_scale # 1.0 dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -144,24 +151,6 @@ def __init__( else: raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor - ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( tdt_rank, td_inner, @@ -171,18 +160,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = 
ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor + ), + lr_scale=mamba_layer_lr_scale, ) - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( (td_inner, td_state), - init_method=init_from_tensor_(A_log), + init_method=init_A(self.config.state_size, self.config.d_inner), lr_scale=mamba_layer_lr_scale, weight_decay=False, ) @@ -214,8 +201,8 @@ def forward(self, hidden_states, kwargs): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) x = einops.rearrange(x, "b l d -> b d l") z = einops.rearrange(z, "b l d -> b d l") @@ -225,7 +212,7 @@ def forward(self, hidden_states, kwargs): B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -238,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + weight=self.conv1d_weight, bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py 
index 9ca0123b..b04b1f21 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -78,7 +78,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank + inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42252c62..4976ad2b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 1feccc866c1dea2da66567476fc911a37a855038 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:48:23 -0400 Subject: [PATCH 14/37] stuff --- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/ssm/mamba_layer.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 88fe4abc..fdba10be 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -111,7 +111,7 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn + # define bias 
outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 49afa910..11db3791 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -48,9 +48,7 @@ def init_dtprojbias( dt_max: float, dt_min: float, dt_init_floor: float ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - tensor = ( - tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min(dt_init_floor) - ) + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 return tensor.add_(torch.log(-torch.expm1(-tensor))) From e528b50ba5c5e2ea726876779db010f83fccd8ef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:00:20 -0400 Subject: [PATCH 15/37] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 12 ++++++++---- fast_llm/layers/ssm/mamba_layer.py | 10 +++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index b95ff76d..fdce9bf6 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import 
TransformerConfig, TransformerDimNames, TransformerKwargs @@ -97,7 +97,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index fdba10be..8be9dcb9 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,7 +3,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -75,7 +75,11 @@ def __init__( conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( - (conv1d_dim, tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel)), + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) @@ -168,9 +172,9 @@ def forward(self, hidden_states, kwargs): .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias, activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, 
bias=self.conv1d_bias.squeeze(1), activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 11db3791..07eec38e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,7 +4,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames @@ -87,7 +87,11 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.conv_kernel)), + ( + inner_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) @@ -146,7 +150,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( in_proj, - self.conv1d_weight.unsqueeze(1), + self.conv1d_weight, None, self.x_proj.weight, self.dt_proj_weight, From b49c42febac4f32dc1be83655b242d6199a385bc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:16:42 -0400 Subject: [PATCH 16/37] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 8 ++++---- fast_llm/layers/ssm/mamba_layer.py | 4 ++-- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ tests/utils/model_configs.py | 1 - 5 files changed, 22 insertions(+), 15 deletions(-) diff --git 
a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 734e35b2..c0ae7e78 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b936ccf1..74c212ad 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * 
td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, td_conv_kernel), + (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -225,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=self.conv1d_weight, + weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index a95e94c0..4493332c 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig @@ -98,7 +98,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=kaiming_init_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a052..4fde7245 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = 
self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4976ad2b..1eee3675 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,6 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From c1b7f44a10ff379a067b10b76df296f3bee4cac1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:19:08 -0400 Subject: [PATCH 17/37] misc --- 
.../models/ssm/external/llamba/modeling_mtp_llamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py index 8f49ded4..6d9746db 100644 --- a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py +++ b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py @@ -322,21 +322,19 @@ def __init__(self, config, factory_kwargs, layer_idx, **kwargs): # Mixer self.mixer = DiscreteMamba2( - d_model=self.config._hidden_size, + d_model=self.config.d_model, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) # Other components - self.input_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs - ) + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) self.post_attention_layernorm = LlamaRMSNorm( - hidden_size=self.config._hidden_size, eps=1e-5, factory_kwargs=factory_kwargs + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs ) self.mlp = LlamaMLP( - hidden_size=self.config._hidden_size, + hidden_size=self.config.d_model, **config.mlp_cfg, factory_kwargs=factory_kwargs, ) From 31f5d415ef0c7eeca54a26d415076cbf3ba33cfd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:20:26 -0400 Subject: [PATCH 18/37] misc --- fast_llm/models/ssm/conversion.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d5730025..43e3c67e 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from 
fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", 
f"model.layers.{i}.mixer.dt_proj", False ) From 0a9ff25f6e0a699caef881dfcaeef0b19f825764 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:22:24 -0400 Subject: [PATCH 19/37] misc --- fast_llm/models/ssm/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 6b9e2858..d2a69303 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -40,9 +40,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? # TODO: Support combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) From e7d9636819ab83df7204cc2b021fd4565188e946 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 19:55:53 -0400 Subject: [PATCH 20/37] Parallel discrete mamba 2 --- fast_llm/layers/ssm/config.py | 12 +- fast_llm/layers/ssm/discrete_mamba2.py | 212 ++++++++++--------------- fast_llm/layers/ssm/mamba2.py | 6 +- 3 files changed, 95 insertions(+), 135 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 15a6a821..7f0b3cf6 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -211,23 +211,25 @@ def _validate(self) -> None: def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - num_heads = div(self.d_inner, self.state_size) # Head groups are configured differently depending on the block type. 
if block_type == SSMBlockType.mamba: + num_heads = div(self.d_inner, self.state_size) num_head_groups = num_heads elif block_type == SSMBlockType.mamba2: + num_heads = div(self.d_inner, self.state_size) num_head_groups = div(self.d_xb, self.state_size) elif block_type == SSMBlockType.mamba2_discrete: - Assert.eq(num_heads, self.n_v_heads) + # TODO: Use different variables? + num_heads = self.n_v_heads num_head_groups = self.n_qk_heads + # v_heads have size `headdim` that may be different from `state_size`. + Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim( - group_heads := TensorDim(SSMDimNames.group_heads, num_group_heads := div(num_heads, num_head_groups)) - ) + tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index fdce9bf6..ac4fb87c 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,12 +1,12 @@ import logging -import math import typing import einops import torch from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer @@ -32,12 +32,6 @@ 
_causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_centered_(bound) - - class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" @@ -51,198 +45,162 @@ def __init__( transformer_config: TransformerConfig, ): super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self.config: SSMConfig = config + self._config: SSMConfig = config layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state) - td_model = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.head_groups) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection) - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim = 
tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + + self._local_heads = heads_dim.size + self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_inner_size = inner_dim.size + self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not config.add_bias_linear - else 0.0 - ) - self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=init_uniform_centered_((td_conv.size * td_conv_kernel.size) ** -0.5), - lr_scale=mamba_layer_lr_scale, + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (conv1d_dim,), 
+ init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_v_heads,), + (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. 
- """ if kwargs[TransformerKwargs.sequence_first]: raise NotImplementedError(f"Sequence-first not supported for SSMs.") assert _mamba_available - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - - state = None - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # Project input - xBCzA_log = self.in_proj(u) + inner_projection = self.in_proj(input_) + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
- xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + xBC = self.convolutional_forward(xBC, sequence_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( + y = _mamba_chunk_scan_combined( x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() + if not self._config.add_bias_linear: + z = z + self.z_bias - # TODO: since we do not support inference for now, we only 
return the hidden states for now. - return outputs["hidden_states"], None + # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + return self.out_proj(y) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, "swish", - "identity", + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 8be9dcb9..cba28f8b 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -142,12 +142,12 @@ def __init__( # TODO: lr_scale? 
) - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available - inner_projection = self.in_proj(hidden_states) - dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) From f88fb2f1f44aca0852ac293fc9b8950941cc7dd2 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Jul 2025 19:10:28 +0000 Subject: [PATCH 21/37] rename vit layer to block --- fast_llm/layers/transformer/transformer.py | 2 +- fast_llm/models/gpt/model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 19443a77..d2f3bfba 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -217,5 +217,5 @@ def _create_mixer(self) -> Mixer: return Attention(self._config, self._tensor_space, self._block_index) -class VisionTransformerLayer(TransformerLayer): +class VisionTransformerBlock(TransformerBlock): _name: str = "Vision transformer layer" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 62a58546..47100d67 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -24,7 +24,7 @@ VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerBlock, VisionTransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock, VisionTransformerBlock from fast_llm.layers.vision_encoder.adapter import VisionAdapter from 
fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.patch_conv import PatchConv @@ -100,7 +100,7 @@ def get_output_layers(self) -> list[Layer]: def get_vision_layers(self) -> list[Layer]: vit_layers = [ - VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ From 22296b32ce1b296c505e42dab7ca8893d0bf75a4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Jul 2025 19:24:58 +0000 Subject: [PATCH 22/37] block_index --- fast_llm/models/gpt/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 47100d67..c76c2191 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,7 +100,7 @@ def get_output_layers(self) -> list[Layer]: def get_vision_layers(self) -> list[Layer]: vit_layers = [ - VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ From c14b7643ae3f840f8da23404922f9482ff507284 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Jul 2025 17:14:17 -0400 Subject: [PATCH 23/37] Mamba 2, misc --- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/ssm/config.py | 62 ++++++++++--------- fast_llm/layers/ssm/discrete_mamba2.py | 50 ++++++++++----- fast_llm/layers/ssm/mamba2.py | 22 ++++--- fast_llm/layers/ssm/mamba_layer.py | 27 ++++----- fast_llm/tensor.py | 74 +++++++++++++++-------- tests/models/test_checkpoint.py | 11 +++- tests/utils/model_configs.py | 9 +-- 8 files changed, 160 
insertions(+), 100 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce209..3218a196 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 7f0b3cf6..c06d8514 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,31 +5,31 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
state = "ssm_state" # State dimension (N), aka head size / num channels - + head_dim = "ssm_head_dim" head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers + + dt_rank = "ssm_dt_rank" + + # Composite dimensions composite_heads = "ssm_composite_heads" - composite_heads_and_state = "ssm_composite_heads_and_state" + composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - # Inner projection total dimension. + # Concatenated dimensions + concatenated_convolution = "ssm_concatenated_convolution" + concatenated_x_projection = "ssm_x_concatenated_x_projection" concatenated_inner_projection = "ssm_concatenated_inner_projection" - # Convolution shape in discrete mamba 2. TODO: Remove (dim too complex) - conv_dim = "ssm_conv_dim" - - dt_rank = "ssm_dt_rank" - - x_proj_dim = "x_proj_dim" # X projection dimension - conv_kernel = "conv_kernel" # Kernel size of the conv1d in mamba layers - class SSMBlockType(enum.StrEnum): """ @@ -62,7 +62,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float): + def get_init_method(self, scale: float) -> Initializer: from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) @@ -222,56 +222,64 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType # TODO: Use different variables? num_heads = self.n_v_heads num_head_groups = self.n_qk_heads - # v_heads have size `headdim` that may be different from `state_size`. 
- Assert.multiple(self.d_inner, num_heads) else: raise NotImplementedError(block_type) - tensor_space.add_tensor_dim(state_dim := TensorDim(SSMDimNames.state, self.state_size)) + tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) + if block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + else: + head_dim = state + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - heads_and_state := CompositeTensorDim( - SSMDimNames.composite_heads_and_state, (head_groups, group_heads, state_dim) + heads_and_head_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) ) ) tensor_space.add_tensor_dim( head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state_dim) + SSMDimNames.composite_head_groups_and_state, (head_groups, state) ) ) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel, self.conv_kernel_dimension)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) # DT projection if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.dt_rank + self.state_size * 2)) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) + ) # TODO: Use composition 
instead tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_inner_projection, (heads_and_state, heads_and_state)) + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) + ) ) elif block_type == SSMBlockType.mamba2: # TODO: Factor out state? tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), ) ) elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Factor as (head_groups, (group_heads + 2) * state_size + group_heads)? tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_state, head_groups_and_state, head_groups_and_state, heads_and_state, heads), + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), ) ) - # TODO: (head_groups, group_heads + 2, state_size) tensor_space.add_tensor_dim( - TensorDim(SSMDimNames.conv_dim, self.d_inner + 2 * self.n_qk_heads * self.state_size) + ConcatenatedTensorDim( + SSMDimNames.concatenated_convolution, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state), + ) ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index ac4fb87c..64377b93 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,14 +49,18 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = 
tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) - self._local_heads = heads_dim.size + # local_head_groups = head_groups / TP self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations @@ -80,7 +84,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -107,24 +111,25 @@ def __init__( ) def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - assert _mamba_available - sequence_length = input_.size(0 if kwargs[TransformerKwargs.sequence_first] else 1) + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and not self._sequence_parallel + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + # inner_projection : 
(batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) + # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) inner_projection = self.in_proj(input_) - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, padded_sequence, inner_projection) if kwargs[TransformerKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) + print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -134,9 +139,13 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) + print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) + print("QAIKOFNMJOWENM z", z.shape) + print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer - xBC = self.convolutional_forward(xBC, sequence_length) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, @@ -148,13 +157,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=-1, ) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward y = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, A=-torch.ones(self._local_heads, device=A_log.device), @@ -169,23 +181,31 @@ def forward(self, input_: 
torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if not self._config.add_bias_linear: z = z + self.z_bias - # y: (batch, sequence, heads, state) -> (batch, sequence, heads * state) + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + a, b = self.out_proj(y) + logger.info(f"EKFBN y {y.shape}") + logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) + def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" if _causal_conv1d_available and self._config.activation_type in ( ActivationType.silu, - "swish", ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, activation=( None diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index cba28f8b..1ae25e44 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -39,7 +39,7 @@ class Mamba2(Mixer): _XZ_DIMS = ( TransformerDimNames.batch, - SSMDimNames.composite_heads_and_state, + SSMDimNames.composite_heads_and_head_dim, TransformerDimNames.sequence_q, ) _BC_DIMS = ( @@ -62,7 +62,7 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_state) + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) @@ -78,7 +78,7 @@ def __init__( ( conv1d_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -146,6 +146,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ assert _mamba_available assert _causal_conv1d_available + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) @@ -161,10 +163,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dim=2, ) - # z: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) z = z.transpose(1, 2) - # x: (batch, sequence, head_groups * state) -> (batch, heads * state, sequence) + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: x = ( @@ -172,16 +174,16 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ .repeat_interleave(self._group_heads, 1, 
output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight, bias=self.conv1d_bias.squeeze(1), activation="silu") + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - # b: (batch, sequence, head_groups * state) -> (batch, heads, state, sequence) + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) b = ( b.transpose(1, 2) .unflatten(1, (self._local_head_groups, self._config.state_size)) @@ -216,9 +218,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._debug_level: self._debug_log(y, "y", self._XZ_DIMS, kwargs) - # y: (batch, heads * state, sequence) -> (batch, sequence, heads * state) + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] if kwargs[TransformerKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 07eec38e..64c8227f 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale try: @@ -29,30 +29,27 @@ """ -def init_A(d_state, d_inner) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # TODO: adopt this initialization to work for tensor parallel setting! 
+def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa if tensor.numel() != d_state * d_inner: - raise ValueError(f"_init_A requires not supported for tensor slices.") - return torch.log( + raise ValueError("_init_A requires not supported for tensor slices.") + torch.log( torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) .unsqueeze(0) .expand(d_inner, d_state), out=tensor, ) - return init_ + return LambdaInitializer(init_, requires_global_initialization=True) -def init_dtprojbias( - dt_max: float, dt_min: float, dt_init_floor: float -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - return tensor.add_(torch.log(-torch.expm1(-tensor))) + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) class MambaLayer(Mixer): @@ -72,7 +69,7 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_state) + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -90,7 +87,7 @@ def __init__( ( inner_dim, tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.conv_kernel), + 
tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -98,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim), + tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 899e7000..b89ed4a0 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import abc import functools import math import typing @@ -241,7 +242,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -251,7 +252,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. + assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -276,7 +281,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] 
= None, allow_sequence_tensor_parallel: bool = True, @@ -303,6 +308,10 @@ def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -334,11 +343,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False - return init_ + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -346,38 +376,32 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None 
-) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_kaiming_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def init_uniform_centered_( - high, max_val=None, mean=0.0 -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: return init_uniform_( mean - high, mean + high, diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 05acf23d..4bda5512 100644 --- a/tests/models/test_checkpoint.py +++ 
b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 038b53c2..722d8d63 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -20,6 +20,7 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import ( + 
AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat, ) @@ -540,19 +541,19 @@ def _update_and_add_testing_config( "model.base_model.ssm.chunk_size=32", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, # TODO: Implement - ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, compare_factor=2.0, # Micro-sequence split and sequence-first not supported. - skip_tests=("sf", "stp", "sdp", "ms"), + skip_tests=("sdp", "ms"), ) From fa211747ea1ed81528e771e140b58ed7b579c3b7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Jul 2025 21:29:07 +0000 Subject: [PATCH 24/37] flexible import --- .../external/llava_hybrid/modeling_llava_hybrid.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 9896d91d..0c7fd9b9 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -4,6 +4,16 @@ from .configuration_llava_hybrid import LlavaHybridConfig +try: + # In the fast-llm repo, import from the SSM modeling file + from ..apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielThinkerSSMHybridModel, + HybridMambaAttentionDynamicCache, + ) +except ImportError: + # In the exported checkpoint, 
import from local file + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlavaHybridConfig): @@ -42,7 +52,6 @@ def __init__(self, config: LlavaHybridConfig): assert ( config.text_config.model_type == "apriel_ssm_thinker_hybrid" ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" - from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel self.language_model = AprielThinkerSSMHybridModel(config.text_config) self.post_init() @@ -69,8 +78,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - from .modeling_ssm_hybrid_apriel15b import HybridMambaAttentionDynamicCache - # Copy of the method from `AprielThinkerSSMHybridForCausalLM` # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` From d3cc1583f0d75b60f99c3ab96a82424371adb83f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 14:10:39 +0000 Subject: [PATCH 25/37] update import --- .../models/ssm/external/llava_hybrid/modeling_llava_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 0c7fd9b9..b056d3a0 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -6,7 +6,7 @@ try: # In the fast-llm repo, import from the SSM modeling file - from ..apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache, ) From 6d245c0578cf5d91ef95a4ce528fffe3a5d69f3f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 18:21:44 +0000 Subject: [PATCH 26/37] fix automodel export --- 
fast_llm/models/ssm/conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 059bff43..64afbea0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -851,6 +851,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", }, ), ] From 61ecb5d8206cd3335747a860cd87114919d80666 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 18:59:06 +0000 Subject: [PATCH 27/37] try: remove assert for TP and distillation --- fast_llm/engine/training/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 4b8d805b..3dbec534 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -388,7 +388,7 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. 
- Assert.eq(self.model.distributed.tensor_parallel, 1) + # Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() From 2565dac89eac0c746d9e3868ac46bbc62a8de695 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 28 Jul 2025 22:46:41 +0000 Subject: [PATCH 28/37] more verbose config --- fast_llm/engine/config_utils/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 7ab5b8e4..b23037e8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -130,7 +130,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() From 7a7f12c54660a9091bcef3409867d426728dfc23 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 00:19:41 +0000 Subject: [PATCH 29/37] use local token_ids instead of modifying batch --- fast_llm/models/gpt/model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c76c2191..3d393fd4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -342,19 +342,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. 
- batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -374,10 +375,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? 
- labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config labels_cloned = False From 743b42c393b162d6c5881619a9a0d5bd2a21aec6 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 01:08:56 +0000 Subject: [PATCH 30/37] fix allreduce --- fast_llm/functional/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7a289b57..df93bca2 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -277,7 +277,8 @@ def _torch_reverse_kl_forward_backward( loss = (loss_per_sample * loss_mask).mean() if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() if grad_output is not None: loss.backward(torch.full_like(loss, grad_output)) From c7247dc4c70752a163c048d2fd720659e7e55200 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 29 Jul 2025 02:09:34 +0000 Subject: [PATCH 31/37] fix --- fast_llm/functional/cross_entropy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index df93bca2..95f141d9 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -151,7 +151,8 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() return loss, grad From 3074ec90d3d2e3164a77aeb57d040c7aa326c85f Mon Sep 17 00:00:00 
2001 From: Toolkit User Date: Tue, 29 Jul 2025 20:06:19 +0000 Subject: [PATCH 32/37] revert images_sizes conversion to np array --- fast_llm/data/dataset/gpt/memmap.py | 1 - fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 2a1986b6..493361f3 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -177,7 +177,6 @@ def _init( assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens - self._image_sizes = np.array(self._image_sizes, dtype=np.int32) def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 42062a58..29a784b7 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -143,7 +143,7 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. 
document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes.any(): + if image_sizes: image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index fce0f022..d6d47383 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -458,7 +458,7 @@ def _split_and_blend_dataset_configs( text_sizes, image_sizes = dataset.get_document_sizes() tokens_cumsum = text_sizes.cumsum() Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) - if image_sizes.any(): + if image_sizes: num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) # We use the patch sizes only for the purposes of even splitting and blending weights. # We can always use a different patch size for training without any significant impact From c4cdd86403942141e289221502392fb675398cca Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 30 Jul 2025 16:13:29 +0000 Subject: [PATCH 33/37] debug logs --- fast_llm/engine/schedule/runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559..338c7a5d 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -193,6 +193,8 @@ def run_step( for step in schedule: self._train_step(context, step) + logger.info("End of the schedule steps") + # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak. 
try: next(context.data_iterator) @@ -202,6 +204,7 @@ def run_step( raise AssertionError("Data iterator did not terminate") assert context.done, context + logger.info("End data-iterator") if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"End of the schedule steps", str)) @@ -240,7 +243,9 @@ def run_step( # TODO: Option to update with reduce (needs per-layer grad_norm and update_successful) # TODO: Avoid blocking synchronizations: async transfer, turn noop_flag into a real noop flag # (uncomment line in apex). + logger.info("Updating weights") update_successful = self._optimizer.step(metrics) + logger.info("Weights updated") if self._multi_stage.config.multi_stage.debug_tensor_parallel and self._distributed.tensor_group is not None: for stage in self._stages_on_device: @@ -275,6 +280,7 @@ def run_step( return self._reduce_losses(context), update_successful, metrics def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: + logger.info("Reducing losses") reduced_losses = {} num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs for name, losses in context.losses.items(): @@ -290,6 +296,7 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: else: reduced_loss = 0.0 reduced_losses[name] = reduced_loss + logger.info(f"Reduced losses: {reduced_losses}") return { name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss for name, reduced_loss in reduced_losses.items() From 24d7a05df5bf730629fee0f0ad9cbdae0da0bf22 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 5 Aug 2025 15:23:26 +0000 Subject: [PATCH 34/37] rm debug logs --- fast_llm/engine/schedule/runner.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 338c7a5d..8eca4559 100644 --- a/fast_llm/engine/schedule/runner.py +++ 
b/fast_llm/engine/schedule/runner.py @@ -193,8 +193,6 @@ def run_step( for step in schedule: self._train_step(context, step) - logger.info("End of the schedule steps") - # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak. try: next(context.data_iterator) @@ -204,7 +202,6 @@ def run_step( raise AssertionError("Data iterator did not terminate") assert context.done, context - logger.info("End data-iterator") if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"End of the schedule steps", str)) @@ -243,9 +240,7 @@ def run_step( # TODO: Option to update with reduce (needs per-layer grad_norm and update_successful) # TODO: Avoid blocking synchronizations: async transfer, turn noop_flag into a real noop flag # (uncomment line in apex). - logger.info("Updating weights") update_successful = self._optimizer.step(metrics) - logger.info("Weights updated") if self._multi_stage.config.multi_stage.debug_tensor_parallel and self._distributed.tensor_group is not None: for stage in self._stages_on_device: @@ -280,7 +275,6 @@ def run_step( return self._reduce_losses(context), update_successful, metrics def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: - logger.info("Reducing losses") reduced_losses = {} num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs for name, losses in context.losses.items(): @@ -296,7 +290,6 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: else: reduced_loss = 0.0 reduced_losses[name] = reduced_loss - logger.info(f"Reduced losses: {reduced_losses}") return { name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss for name, reduced_loss in reduced_losses.items() From 37ddef4d2d366b97a4b5fe9112df68c9d2bb5b5a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 5 Aug 2025 20:17:39 +0000 Subject: [PATCH 35/37] 
changes for stp reverse-kl --- fast_llm/functional/cross_entropy.py | 28 +++++++++++++++++--------- fast_llm/layers/language_model/head.py | 11 +++++++++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 95f141d9..afd7c2ef 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -226,7 +226,8 @@ def _torch_reverse_kl_forward_backward( ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - Much simpler and more reliable than custom implementation! + This works with sequence-tensor-parallel (distributing over the sequence dimension) as well as a non-TP case. + In sequence-tensor-parallel, where we split along the sequence dimension, we compute the per-split loss and then average the loss. """ Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) @@ -244,7 +245,6 @@ def _torch_reverse_kl_forward_backward( scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) @@ -254,14 +254,9 @@ def _torch_reverse_kl_forward_backward( with torch.enable_grad(): logits_ = logits.detach().requires_grad_(grad_output is not None) - # Use log_softmax for consistency instead of _fused_softmax scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - # Convert to probabilities for kl_div - # student_probs_ = torch.exp(student_log_probs) - # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( @@ -299,6 +294,7 @@ 
def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + vocab_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -342,6 +338,18 @@ def reverse_kl_forward_backward( if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? - return _torch_reverse_kl_forward_backward( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature - ) + if vocab_parallel: + Assert.eq(teacher_softmax_temperature, 1) + Assert.eq(logits_scale_factor, 1) + raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.") + else: + return _torch_reverse_kl_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + group, + teacher_softmax_temperature, + ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd..791b1f09 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -239,7 +239,15 @@ def _get_targets( lm_target = None targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: + # If we do distillation, no need to split it here as it has already been split in the embedding layer! + # if we do CPT/language modeling, we need to split the targets here! + if ( + self._config.distillation_model is not None + and self._sequence_parallel_logits + and not self._parallel_embeddings + and not self._sequence_parallel + ) or (self._config.distillation_model is None and self._sequence_parallel_logits): + # We dont split targets if they already have been split in the embedding layer! 
targets = [ None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) for target in targets @@ -412,6 +420,7 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), + vocab_parallel=logits.shape[-1] != self._config.vocab_size, ) elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( From a0d7a09fdce57e70fbf1524574ca8393adc0453b Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 03:49:09 +0000 Subject: [PATCH 36/37] reverse kl: add clamping --- fast_llm/functional/cross_entropy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index afd7c2ef..eaeaa0d1 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -245,6 +245,7 @@ def _torch_reverse_kl_forward_backward( scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) # Clamp to prevent extreme values before log_softmax + scaled_target = torch.clamp(scaled_target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) @@ -255,6 +256,7 @@ def _torch_reverse_kl_forward_backward( logits_ = logits.detach().requires_grad_(grad_output is not None) scaled_logits = logits_ * logits_scale_factor + scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) # Reverse KL: input=teacher_log_probs, target=student_probs From 72945bcfd099c09ee7bf399d13258052aef5b483 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 6 Aug 2025 14:09:35 +0000 Subject: [PATCH 37/37] add loss mask for vision. 
should also handle padded sequences --- fast_llm/layers/language_model/head.py | 6 +++++- fast_llm/models/gpt/model.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 791b1f09..eed2d134 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -238,7 +238,7 @@ def _get_targets( else: lm_target = None - targets = (dpo_target, lm_target, distillation_target, loss_mask) + targets = (dpo_target, lm_target, distillation_target) # If we do distillation, no need to split it here as it has already been split in the embedding layer! # if we do CPT/language modeling, we need to split the targets here! if ( @@ -252,6 +252,10 @@ def _get_targets( None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) for target in targets ] + # Loss mask may need to be split. It was not split in the embedding layer as it is not used there. + if loss_mask is not None and self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) + targets = (*targets, loss_mask) if not any(target is not None for target in targets): # Simplify so we don't have to check every time. 
targets = None diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 3d393fd4..be172af9 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -407,16 +407,32 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) if self._config.vision_encoder.enabled: + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) if self._config.vision_encoder.image_break_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + loss_mask = torch.where( + labels == self._config.vision_encoder.image_break_token, False, loss_mask + ) + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask if self._config.vision_encoder.image_end_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + loss_mask = torch.where( + labels == self._config.vision_encoder.image_end_token, False, loss_mask + ) + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + # TODO: Check that this works. Can we remove previous loss_masking? + if self._config.distillation_model is not None: + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) + loss_mask = torch.where(labels == -100, False, loss_mask) + kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i])