diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. + Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..a1f357de9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -23,6 +23,7 @@ class SSMDimNames: # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + c_heads = "c_heads" class SSMBlockType(enum.StrEnum): @@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + @config_class() class SSMConfig(LLMBlockConfig): @@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig): desc="The MLP intermediate activation type. 
Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( - default=False, - desc="debug_ssm", - hint=FieldHint.optional, - ) dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", @@ -147,18 +159,6 @@ class SSMConfig(LLMBlockConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) dt_scale: float = Field( default=1.0, desc="Scale for dt", diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..c0ae7e781 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,13 +1,15 @@ import logging import math +import typing import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -36,29 +38,29 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Other options are all experimental and should not need to be configured. 
""" # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -101,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 @@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs): out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py 
b/fast_llm/layers/ssm/mamba2.py index a03509abb..74c212add 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,9 +4,12 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -43,24 +46,36 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.inner_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.c_heads, + SSMDimNames.state_dim, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale( self.config.mamba_lr_scale, layer_lr_scale ) @@ -83,9 +98,9 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), + -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 lr_scale=mamba_layer_lr_scale, @@ -96,9 +111,9 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), + (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), + -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), ), ) @@ -119,7 +134,13 @@ def __init__( weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, ) - + self.dt_in_proj = Linear( + td_model, + tdt_rank, + bias=config.add_bias_linear, + weight_init_method=kaiming_init_(transformer_config.hidden_size), + lr_scale=mamba_layer_lr_scale, + ) # Initialize special dt projection to preserve variance at initialization dt_scale = config.dt_scale # 1.0 dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -130,24 +151,6 @@ def __init__( else: raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor - ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( tdt_rank, td_inner, @@ -157,18 +160,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor + ), + lr_scale=mamba_layer_lr_scale, ) - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( (td_inner, td_state), - init_method=init_from_tensor_(A_log), + init_method=init_A(self.config.state_size, self.config.d_inner), lr_scale=mamba_layer_lr_scale, weight_decay=False, ) @@ -200,8 +201,8 @@ def forward(self, hidden_states, kwargs): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, 
C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) x = einops.rearrange(x, "b l d -> b d l") z = einops.rearrange(z, "b l d -> b d l") @@ -211,7 +212,7 @@ def forward(self, hidden_states, kwargs): B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -236,6 +237,13 @@ def forward(self, hidden_states, kwargs): x = repeat_kv(x, self.repeat_group) x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(B, "b", self._BC_DIMS, kwargs) + self._debug_log(C, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + y = selective_scan_fn( x, dt, @@ -249,6 +257,9 @@ def forward(self, hidden_states, kwargs): return_last_state=False, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if ssm_state is not None: y, last_state = y ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..4493332ce 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,12 +1,15 @@ import math +import typing from typing import Callable import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -44,12 +47,12 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( + min=dt_init_floor + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) tensor.copy_(inv_dt) @@ -58,20 +61,18 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() + 
super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm # Tensor dims: td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) @@ -88,7 +89,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( @@ -97,7 +98,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=kaiming_init_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) @@ -113,7 +114,6 @@ def __init__( weight_init_method=kaiming_init_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -127,7 +127,7 @@ def __init__( self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor ), lr_scale=mamba_layer_lr_scale, ) @@ -153,10 +153,8 @@ def __init__( bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input def forward(self, hidden_states, kwargs): assert _mamba_available @@ -168,8 +166,6 @@ def forward(self, hidden_states, kwargs): "d (b l) -> b d l", l=seqlen, ) - if self._debug_mode: - print("XZ: ", xz.shape) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat @@ -189,6 +185,4 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,9 +13,9 @@ TransformerKwargs, TransformerSubLayerName, ) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,11 +50,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. 
""" + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -64,7 +66,7 @@ class Attention(torch.nn.Module): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -73,19 +75,9 @@ class Attention(torch.nn.Module): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -108,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) @@ -178,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -300,7 +258,7 @@ def _decide_window_size(self) -> int | None: # 
https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -341,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -395,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for 
mixer modules.
+    """
+
+    _mixer_name: typing.ClassVar[str]
+
+    def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0):
+        super().__init__()
+        self._tensor_space = tensor_space
+        self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel
+        self._block_index = block_index
+        self._debug_level = debug_level
+
+    @abc.abstractmethod
+    def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """
+        Mixer module forward. Returns the output hidden states and an optional bias,
+        in case its addition can be made more efficient in `_bias_dropout_add`.
+        """
+
+    def _get_meta(
+        self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any]
+    ) -> TensorMeta:
+        hidden_dims = {
+            dim.name: dim
+            for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],)
+        }
+        return TensorMeta.from_dims(
+            tuple(
+                hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name)
+                for dim_name in dim_names
+            ),
+            tensor_name=f"Block {self._block_index} {self._mixer_name} {name}",
+            dtype=input_.dtype,
+        )
+
+    def _debug_log(
+        self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any]
+    ) -> None:
+        # TODO: Local vs global
+        Assert.gt(self._debug_level, 0)
+        log_distributed_tensor(
+            "",
+            tensor,
+            level=self._debug_level,
+            meta=self._get_meta(tensor, name, dim_names, kwargs),
+            distributed=self._tensor_space.distributed,
+        )
+        if tensor.requires_grad:
+            log_distributed_grad(
+                "",
+                tensor,
+                level=self._debug_level,
+                meta=self._get_meta(tensor, name + " grad", dim_names, kwargs),
+                distributed=self._tensor_space.distributed,
+            )
+
+
 class BaseBlock(Layer, abc.ABC):
     """
     A transformer-like decoder base block with abstract mixer.
     """
 
-    _mixer_module_name = "self_attn"
+    # TODO: Standardize to `mixer`
+    _mixer_module_name: typing.ClassVar[str] = "mixer"
 
     def __init__(
-        self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False
+        self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False
     ):
         super().__init__()
         self._config: TransformerConfig = config
@@ -35,18 +95,19 @@ def __init__(
         # For multi-token prediction, return a stack of shared_hidden and transformer_output.
         self._return_input: bool = return_input
 
-        self._layer_index = layer_index
+        self._block_index = block_index
         self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory
         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
         # Note, layer_lr_scale does not impact the norms
-        # TODO: add a seperate norm_lr_scale
+        # TODO: add a separate norm_lr_scale
        self.norm_1 = self._config.normalization.get_layer(hidden_dim)
         self.norm_2 = self._config.normalization.get_layer(hidden_dim)
 
-        self._create_mixer()
+        # The mixer needs to be created here for backward-compatible weight ordering.
+        setattr(self, self._mixer_module_name, self._create_mixer())
 
         self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)(
-            self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index
+            self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index
         )
 
         # PEFT.
@@ -54,7 +115,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -137,14 +198,17 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,10 +31,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,11 +68,11 @@ def 
get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -91,10 +91,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index a049118cc..1e3ac546b 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,9 +9,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -23,14 +22,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -40,9 +39,8 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. 
+ ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ @@ -79,9 +77,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank + inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) def _validate(self): with self._set_implicit_default(None): @@ -95,30 +94,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. 
+ Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, 
**factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 63a3ebf33..3ba6b1a62 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -39,52 +35,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. 
""" - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +69,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. 
""" - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/setup.cfg b/setup.cfg index 843aa15ca..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,14 +48,9 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - cartesia_pytorch>=0.0.2 - -GENERATION = - lm_eval>=0.4.9 - DEV = # Pre-commit git hook diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, 
model_testing_config.model_type)._multi_stage
     model_frozen = _get_trainer_from_args(
@@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config):
         model_frozen._num_stages,
     )
     frozen_parameter_counts = [
-        sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0
+        sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0
         for layer in model_ref.base_model.layers
     ]
     for weight_buffer_ref, weight_buffer_frozen in zip(