diff --git a/Dockerfile b/Dockerfile
index 0f3c2d8c..7cf95101 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT=""
 # There is no pre-built mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
 # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
 # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
+# Use varlen_mamba for variable-length sequence (packing) support.
 RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
-RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
+RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
 # Copy dependency files with universal write permissions for all users.
 COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
 COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py
index bccc1d62..7b959e34 100644
--- a/fast_llm/layers/common/normalization.py
+++ b/fast_llm/layers/common/normalization.py
@@ -8,6 +8,13 @@ from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_
 from fast_llm.utils import Assert
 
+try:
+    from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn as mamba_rmsnorm_fn
+
+    _mamba_ssm_available = True
+except ImportError:
+    _mamba_ssm_available = False
+
 try:
     import fused_layer_norm_cuda  # noqa
 
@@ -288,3 +295,22 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor:
 
     def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
         return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps)
+
+
+class MambaRMSNormGated(RMSNorm):
+    def __init__(self, hidden_dim: TensorDim, group_size: int, eps=1e-5, lr_scale: float | None = None):
+        assert _mamba_ssm_available
+        super().__init__(hidden_dim, eps=eps, lr_scale=lr_scale)
+        self.group_size = group_size
+        self._forward = mamba_rmsnorm_fn
+
+    def forward(self, input_: torch.Tensor, gate=None):
+        return mamba_rmsnorm_fn(
+            x=input_,
+            weight=self.weight,
+            bias=None,  # No bias
+            z=gate,
+            eps=self._eps,
+            group_size=self.group_size,
+            norm_before_gate=False,
+        )
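For orientation, here is a rough pure-PyTorch sketch of what MambaRMSNormGated is expected to compute with norm_before_gate=False (my reading of the mamba_ssm gated RMS norm; the helper name is made up, and this is only an approximation to sanity-check against the Triton kernel, not a replacement for it):

```python
import torch
import torch.nn.functional as F


def gated_rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, z: torch.Tensor, eps: float, group_size: int) -> torch.Tensor:
    # Gate first (norm_before_gate=False), then RMS-normalize each group of `group_size` channels.
    gated = (x * F.silu(z)).float()
    grouped = gated.unflatten(-1, (-1, group_size))
    rstd = torch.rsqrt(grouped.pow(2).mean(dim=-1, keepdim=True) + eps)
    return ((grouped * rstd).flatten(-2) * weight).to(x.dtype)
```

In the diff, group_size is set to the local (per-rank) inner size, so each rank's slice is normalized over its own channels.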
diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py
index 3b21ca69..f8d57799 100644
--- a/fast_llm/layers/ssm/config.py
+++ b/fast_llm/layers/ssm/config.py
@@ -12,12 +12,33 @@
 from fast_llm.tensor import Initializer
 
 
+class BaseSSMKwargs:
+    _kwargs_attributes = {
+        "cu_seqlens": "cu_seqlens",
+        "seq_idx": "seq_idx",
+        "ssm_position_ids": "ssm_position_ids",
+    }
+
+    _prefix = ""
+
+    def __init_subclass__(cls, prefix="", **kwargs):
+        super().__init_subclass__(**kwargs)
+        cls._prefix = prefix
+        for attr, value in BaseSSMKwargs._kwargs_attributes.items():
+            setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value)
+
+
+class SSMKwargs(BaseSSMKwargs, prefix=""):
+    pass
+
+
 class SSMDimNames:
     # TODO: Use separate tensor space for different mixers so there is no risk of name conflict.
     state = "ssm_state"  # State dimension (N), aka head size / num channels
     head_dim = "ssm_head_dim"
     head_groups = "ssm_head_groups"
     group_heads = "ssm_group_heads"
+    conv1d_dim = "ssm_conv1d_dim"
 
     # Mamba 2
     x_proj_dim_2 = "x_proj_dim_2"  # d_xb
@@ -28,7 +49,10 @@ class SSMDimNames:
     # Composite dimensions
     composite_heads = "ssm_composite_heads"
     composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim"
+    composite_heads_and_head_dim_nontp = "ssm_composite_heads_and_head_dim_nontp"
+    composite_heads_and_state_dim = "ssm_composite_heads_and_state_dim"
     composite_head_groups_and_state = "ssm_composite_head_groups_and_state"
+    composite_head_groups_and_head = "ssm_composite_head_groups_and_head"
 
     # Concatenated dimensions
     concatenated_convolution = "ssm_concatenated_convolution"
@@ -45,6 +69,7 @@ class SSMBlockType(enum.StrEnum):
     mamba2_discrete = "m2d"
     mamba2 = "m2"
     transformer = "t"
+    nemotron_h_mamba2 = "nm2"
 
     def get_mixer_class(self):
         if self == SSMBlockType.mamba:
@@ -59,6 +84,10 @@ def get_mixer_class(self):
             from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
 
             return DiscreteMamba2
+        elif self == SSMBlockType.nemotron_h_mamba2:
+            from fast_llm.layers.ssm.mamba2 import NemotronHMamba2
+
+            return NemotronHMamba2
         else:
             raise NotImplementedError(self)
 
@@ -206,6 +235,21 @@ class SSMConfig(LLMBlockConfig):
         valid=check_field(Assert.gt, 0),
     )
 
+    # Nemotron-H Mamba2 (actually the real Mamba2).
+    # Here, instead of setting d_inner, we set the head dim. and the number of heads.
+    # Note: we do not implement n_groups for Mamba2 because, since we do MiL init, we do not want to share B and C parameters across heads.
+    # Instead, we mimic the GQA behaviour (x -> v, B -> k, C -> q), where x and B are shared across heads. So this is the same as having n_groups = n_heads?
+    # n_groups: int = Field(
+    #     default=8,
+    #     desc="Number of groups for Mamba2. Allows sharing B and C parameters across heads.",
+    #     hint=FieldHint.architecture,
+    # )
+    head_dim: int = Field(
+        default=64,
+        desc="Head dimension for Nemotron-H Mamba2.",
+        hint=FieldHint.architecture,
+    )
+
     def _validate(self) -> None:
         with self._set_implicit_default():
             if self.activation_type is None:
@@ -223,6 +267,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
         elif block_type == SSMBlockType.mamba2:
             num_heads = div(self.d_inner, self.state_size)
             num_head_groups = div(self.d_xb, self.state_size)
+        elif block_type == SSMBlockType.nemotron_h_mamba2:
+            # head dim and state size are not the same
+            num_heads = div(self.d_inner, self.head_dim)
+            num_head_groups = div(self.d_xb, self.head_dim)
         elif block_type == SSMBlockType.mamba2_discrete:
             # TODO: Use different variables?
             num_heads = self.n_v_heads
@@ -233,6 +281,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
         tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size))
         if block_type == SSMBlockType.mamba2_discrete:
             tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads)))
+        elif block_type == SSMBlockType.nemotron_h_mamba2:
+            tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, self.head_dim))
         else:
             head_dim = state
 
@@ -241,14 +291,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
         tensor_space.add_tensor_dim(
             heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))
         )
+        # full d_inner or intermediate_size (e.g. for the z gate; also the d_inner size of C in mamba2)
         tensor_space.add_tensor_dim(
             heads_and_head_dim := CompositeTensorDim(
                 SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim)
             )
        )
+        # d_xb
         tensor_space.add_tensor_dim(
-            head_groups_and_state := CompositeTensorDim(
-                SSMDimNames.composite_head_groups_and_state, (head_groups, state)
+            head_groups_and_head := CompositeTensorDim(
+                SSMDimNames.composite_head_groups_and_head, (head_groups, head_dim)
             )
         )
         tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension))
@@ -272,7 +324,41 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
             tensor_space.add_tensor_dim(
                 ConcatenatedTensorDim(
                     SSMDimNames.concatenated_inner_projection,
-                    (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim),
+                    (heads_and_head_dim, head_groups_and_head, head_groups_and_head, heads_and_head_dim),
                 )
             )
+        elif block_type == SSMBlockType.nemotron_h_mamba2:
+            # for the norm
+            tensor_space.add_tensor_dim(
+                TensorDim(
+                    SSMDimNames.composite_heads_and_head_dim_nontp, num_head_groups * group_heads.size * head_dim.size
+                )
+            )
+            # state and head dim are not the same
+            # C: one vector of size state per head
+            tensor_space.add_tensor_dim(
+                heads_and_state_dim := CompositeTensorDim(
+                    SSMDimNames.composite_heads_and_state_dim, (head_groups, group_heads, state)
+                )
+            )
+            # B: one vector of size state per head group
+            tensor_space.add_tensor_dim(
+                head_groups_and_state := CompositeTensorDim(
+                    SSMDimNames.composite_head_groups_and_state, (head_groups, state)
+                )
+            )
+            # The depthwise conv. layer is applied to the concatenated xBC, so its dim. is x (d_xb) + B (d_bb) + C.
+            tensor_space.add_tensor_dim(
+                conv1d_dim := ConcatenatedTensorDim(
+                    SSMDimNames.conv1d_dim, (heads_and_state_dim, head_groups_and_head, head_groups_and_state)
+                )
+            )
+
+            # Inner projection dimension: also includes z (gate), which has size d_inner (heads_and_head_dim).
+            tensor_space.add_tensor_dim(
+                ConcatenatedTensorDim(
+                    SSMDimNames.concatenated_inner_projection,
+                    (conv1d_dim, heads_and_head_dim),
+                )
+            )
         elif block_type == SSMBlockType.mamba2_discrete:
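To make the Nemotron-H dimension bookkeeping above concrete, here is a small worked example with assumed (illustrative, non-default) sizes:

```python
# Assumed sizes, only for illustration.
d_inner, head_dim, d_xb, state_size = 4096, 64, 1024, 128

num_heads = d_inner // head_dim              # 64 heads
num_head_groups = d_xb // head_dim           # 16 groups (x and B are shared within a group)
group_heads = num_heads // num_head_groups   # 4 heads per group

z_dim = num_heads * head_dim                 # 4096: gate, same size as d_inner
c_dim = num_heads * state_size               # 8192: C, one state vector per head
x_dim = num_head_groups * head_dim           # 1024: x (d_xb), shared across the heads of a group
b_dim = num_head_groups * state_size         # 2048: B (d_bb), shared across the heads of a group

conv1d_dim = c_dim + x_dim + b_dim           # depthwise conv is applied to the concatenated xBC
inner_projection = conv1d_dim + z_dim        # in_proj output size (dt is projected separately, one scalar per head)
print(num_heads, group_heads, conv1d_dim, inner_projection)  # 64 4 11264 15360
```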
diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py
index 77c1b386..f2735092 100644
--- a/fast_llm/layers/ssm/mamba2.py
+++ b/fast_llm/layers/ssm/mamba2.py
@@ -1,22 +1,37 @@
+import inspect
 import logging
 import typing
 
+import einops
 import torch
 
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
 from fast_llm.functional.config import ActivationType
 from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear
-from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
+from fast_llm.layers.common.normalization import MambaRMSNormGated
+from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames, SSMKwargs
 from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias
 from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
 from fast_llm.layers.transformer.transformer import Mixer
-from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_
+from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_
 from fast_llm.utils import Assert, div, get_lr_scale
 
+_mamba_varlen = False
 try:
     from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # noqa
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 
     _mamba_available = True
+    sig = inspect.signature(selective_scan_fn)
+    if "position_indices" in sig.parameters:
+        _mamba_varlen = True
+        logging.warning("Using selective_scan_fn from varlen_mamba, which supports packing")
+    else:
+        _mamba_varlen = False
+        logging.warning("Using selective_scan_fn from the original mamba, without packing support")
+    # For training with packing, install https://github.com/jxiw/varlen_mamba
+    # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md
+
 except (ImportError, RuntimeError):
     _mamba_available = False
@@ -63,7 +78,7 @@ def __init__(
         lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale)
 
         inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim]
-        xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state]
+        xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_head]
         hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden]
         dt_rank_dim = tensor_space[SSMDimNames.dt_rank]
@@ -73,6 +88,9 @@ def __init__(
         self._local_inner_size = inner_dim.size
         self._local_xb_size = xb_dim.size
 
+        state_size = tensor_space[SSMDimNames.state].size
+        div(self._local_inner_size, state_size)
+
         conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim
         self.conv1d_weight = ParameterMeta.from_dims(
             (
@@ -143,8 +161,16 @@ def __init__(
         )
 
     def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """
+        Note: we are not doing "real" sequence-tensor parallel training here, since inner_projection is gathered over all GPUs.
+        This is also desired, since the currently used mamba kernel does not support STP.
+        TODO: use the correct kernel from Mamba2!
+        """
         assert _mamba_available
         assert _causal_conv1d_available
+        cu_seqlens = kwargs[SSMKwargs.cu_seqlens]
+        seq_idx = kwargs[SSMKwargs.seq_idx]
+        position_indices = kwargs[SSMKwargs.ssm_position_ids]
 
         # inner_projection : (batch/local_sequence, local_sequence/batch, hidden)
         # -> (batch/sequence, sequence/batch, inner_projection)
@@ -156,7 +182,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             dt = dt.transpose(0, 1)
 
         sequence_length = inner_projection.size(1)
-
+        # Like Mamba 1, is the conv applied only to x here?
         z, x, b, c = torch.split(
             inner_projection,
             [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size],
@@ -168,15 +194,27 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
         # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence)
         x = x.transpose(1, 2)
+        # x: (batch, local_heads * state, sequence) -> (batch, local_head_groups, state, sequence)
         if self._config.repeat_kv_before_conv:
             x = (
                 x.unflatten(1, (self._local_head_groups, self._config.state_size))
                 .repeat_interleave(self._group_heads, 1, output_size=self._local_heads)
                 .flatten(1, 2)
             )
-            x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu")
+
+        if cu_seqlens is not None:
+            # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152
+            x = _causal_conv1d_fn(
+                x=x.transpose(1, 2).contiguous().transpose(1, 2),
+                weight=self.conv1d_weight.squeeze(1),
+                bias=self.conv1d_bias,
+                seq_idx=seq_idx,
+                activation="silu",
+            )
         else:
             x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu")
+
+        if not self._config.repeat_kv_before_conv:
             x = (
                 x.unflatten(1, (self._local_head_groups, self._config.state_size))
                 .repeat_interleave(self._group_heads, 1, output_size=self._local_heads)
                 .flatten(1, 2)
             )
@@ -203,17 +241,34 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             self._debug_log(c, "c", self._BC_DIMS, kwargs)
             self._debug_log(dt, "dt", self._XZ_DIMS, kwargs)
 
-        y = selective_scan_fn(
-            x,
-            dt,
-            -torch.exp(self.A_log.float()),
-            b,
-            c,
-            self.D.float(),
-            z,
-            delta_bias=self.dt_proj_bias.float(),
-            delta_softplus=True,
-        )
+        if not _mamba_varlen:
+            Assert.eq(cu_seqlens, None, msg="This version of Mamba2 does not support cu_seqlens, install varlen_mamba")
+            y = selective_scan_fn(
+                x,
+                dt,
+                -torch.exp(self.A_log.float()),
+                b,
+                c,
+                self.D.float(),
+                z,
+                delta_bias=self.dt_proj_bias.float(),
+                delta_softplus=True,
+            )
+        else:
+            position_indices = position_indices if cu_seqlens is not None else None
+
+            y = selective_scan_fn(
+                x,
+                dt,
+                -torch.exp(self.A_log.float()),
+                b,
+                c,
+                self.D.float(),
+                z,
+                delta_bias=self.dt_proj_bias.float(),
+                delta_softplus=True,
+                position_indices=position_indices,
+            )
 
         if self._debug_level:
             self._debug_log(y, "y", self._XZ_DIMS, kwargs)
@@ -226,3 +281,240 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
         # (batch/sequence, sequence/batch, local_heads * state)
         # -> (batch/local_sequence, local_sequence/batch, hidden)
         return self.out_proj(y)
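The NemotronHMamba2 mixer added below delegates the scan to mamba_chunk_scan_combined. As rough orientation only, here is a naive per-timestep sketch of the recurrence it is assumed to compute (the real kernel additionally handles chunking, the optional z gate, packing via seq_idx/cu_seqlens, and dt limits; ssd_reference is a hypothetical helper, not part of the diff):

```python
import torch


def ssd_reference(x, dt, A, B, C, D, dt_bias):
    # Assumed shapes: x (b, l, h, p), dt (b, l, h), A (h,) negative, B and C (b, l, h, n), D (h,), dt_bias (h,).
    b, l, h, p = x.shape
    dt = torch.nn.functional.softplus(dt + dt_bias)  # dt_softplus=True with dt_bias, as in the call below
    state = x.new_zeros(b, h, p, B.shape[-1])
    ys = []
    for t in range(l):
        decay = torch.exp(dt[:, t] * A)  # (b, h)
        state = decay[..., None, None] * state + torch.einsum("bh,bhn,bhp->bhpn", dt[:, t], B[:, t], x[:, t])
        ys.append(torch.einsum("bhpn,bhn->bhp", state, C[:, t]) + D[None, :, None] * x[:, t])
    return torch.stack(ys, dim=1)  # (b, l, h, p)
```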
+
+
+class NemotronHMamba2(Mixer):
+    """
+    This is the actual Mamba2, called NemotronHMamba2 for historical reasons.
+    It decouples d_state and head_dim.
+    A larger head dimension means we project the hidden state into a larger space (more channel mixing);
+    a larger state size means more temporal memory.
+
+
+    This code is adapted from https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py
+    """
+
+    _mixer_name: typing.ClassVar[str] = "mamba_2"
+
+    _XZ_DIMS = (
+        TransformerDimNames.batch,
+        SSMDimNames.composite_heads_and_head_dim,
+        TransformerDimNames.sequence_q,
+    )
+    _BC_DIMS = (
+        TransformerDimNames.batch,
+        SSMDimNames.composite_heads,
+        SSMDimNames.state,
+        TransformerDimNames.sequence_q,
+    )
+
+    def __init__(
+        self,
+        config: SSMConfig,
+        tensor_space: TensorSpace,
+        block_index: int,
+        transformer_config: TransformerConfig,
+    ):
+        super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer)
+        self._config: SSMConfig = config
+        Assert.eq(self._config.activation_type, ActivationType.silu)
+        layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None
+        lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale)
+
+        inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim]
+        inner_dim_non_tp: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim_nontp]
+        c_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_state_dim]
+        xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_head]
+        bb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state]
+        hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden]
+
+        self._head_dim_size: int = tensor_space[SSMDimNames.head_dim].size
+        self._local_heads = tensor_space[SSMDimNames.composite_heads].size
+        self._local_head_groups = tensor_space[SSMDimNames.head_groups].size
+        self._group_heads = div(self._local_heads, self._local_head_groups)
+        Assert.eq(self._local_heads, self._local_head_groups * self._group_heads)
+
+        self._local_inner_size = inner_dim.size
+        self._local_c_size = c_dim.size
+
+        Assert.eq(self._local_inner_size, self._head_dim_size * self._local_heads)
+        self._local_xb_size = xb_dim.size  # x has head dim, one per head group
+        self._local_bb_size = bb_dim.size  # b has state dim, one per head group
+        Assert.eq(self._local_xb_size, self._head_dim_size * self._local_head_groups)
+        Assert.eq(self._local_bb_size, self._config.state_size * self._local_head_groups)
+
+        conv1d_dim = tensor_space[SSMDimNames.conv1d_dim]  # applied to xBC, so d_xb + d_bb + c_dim
+        self.conv1d_weight = ParameterMeta.from_dims(
+            (
+                conv1d_dim,
+                tensor_space[DefaultDimNames.scalar],
+                tensor_space[SSMDimNames.convolution_kernel],
+            ),
+            init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5),
+            lr_scale=lr_scale,
+        )
+        self.conv1d_bias = ParameterMeta.from_dims(
+            (conv1d_dim,),
+            init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5),
+            lr_scale=lr_scale,
+        )
+        self.in_proj = OutputParallelLinear(
+            hidden_dim,
+            tensor_space[SSMDimNames.concatenated_inner_projection],
+            bias=config.add_bias_linear,
+            weight_init_method=init_kaiming_(transformer_config.hidden_size),
+            sequence_parallel=self._sequence_parallel,
+            lr_scale=lr_scale,
+        )
+
+        # dt: project a single scalar per head
+        self.dt_in_proj = OutputParallelLinear(
+            hidden_dim,
+            tensor_space[SSMDimNames.composite_heads],
+            bias=config.add_bias_linear,
+            weight_init_method=init_kaiming_(transformer_config.hidden_size),
+            sequence_parallel=self._sequence_parallel,
+            lr_scale=lr_scale,
+        )
+
+        self.dt_proj_bias = ParameterMeta.from_dims(
+            (tensor_space[SSMDimNames.composite_heads],),
+            init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor),
+            lr_scale=lr_scale,
+        )
+
+        def init_A_uniform(A_init_range: tuple[float, float] = (1, 16)) -> LambdaInitializer:
+            def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None:  # noqa
+                tensor.uniform_(*A_init_range).log_()
+
+            return LambdaInitializer(init_, requires_global_initialization=True)
+
+        self.A_log = ParameterMeta.from_dims(
+            (tensor_space[SSMDimNames.composite_heads],),
+            init_method=init_A_uniform(A_init_range=(1, 16)),
+            lr_scale=lr_scale,
+            weight_decay=False,
+        )
+        self.D = ParameterMeta.from_dims(
+            (tensor_space[SSMDimNames.composite_heads],),  # can also be nheads x headdim
+            weight_decay=False,
+            init_method=init_ones_,
+            lr_scale=lr_scale,
+        )
+        self.out_proj = InputParallelLinear(
+            inner_dim,
+            hidden_dim,
+            bias=config.add_bias_linear,
+            weight_init_method=init_kaiming_(self._config.d_inner),
+            sequence_parallel=self._sequence_parallel,
+            lr_scale=lr_scale,
+        )
+        # TODO: this norm does not support TP, so we need a workaround!
+        self.norm = MambaRMSNormGated(
+            inner_dim_non_tp,
+            group_size=self._local_inner_size,
+            eps=1e-5,
+            lr_scale=lr_scale,
+        )
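The forward() below shares x and B across the heads of each group (GQA-style). A minimal sketch with assumed shapes, showing that the expand/reshape used there is equivalent to repeat_interleave over the head axis:

```python
import torch

# Assumed toy shapes, for illustration only.
batch, head_groups, group_heads, seqlen, dim = 2, 4, 3, 5, 8
x = torch.randn(batch, head_groups, seqlen, dim)  # one x (or B) slice per head group
x_expanded = (
    x[:, :, None, :, :]
    .expand(batch, head_groups, group_heads, seqlen, dim)
    .reshape(batch, head_groups * group_heads, seqlen, dim)
)
# Same result as repeating each group consecutively along the head axis.
assert torch.equal(x_expanded, torch.repeat_interleave(x, group_heads, dim=1))
```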
+
+    def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """ """
+        assert _mamba_available
+        assert _causal_conv1d_available
+        cu_seqlens = kwargs[SSMKwargs.cu_seqlens]
+        seq_idx = kwargs[SSMKwargs.seq_idx]
+
+        # inner_projection : (batch/local_sequence, local_sequence/batch, hidden)
+        # -> (batch/sequence, sequence/batch, inner_projection)
+        inner_projection = self.in_proj(input_)
+        dt = self.dt_in_proj(input_)  # bs, seq, heads
+        # Standardize to (batch, sequence, inner_projection)
+        if kwargs[TransformerKwargs.sequence_first]:
+            inner_projection = inner_projection.transpose(0, 1)
+            dt = dt.transpose(0, 1)
+        # note: self.in_proj gathers the full sequence length here
+        sequence_length = inner_projection.size(1)
+
+        z, xBC = torch.split(
+            inner_projection,
+            [self._local_inner_size, self._local_xb_size + self._local_bb_size + self._local_c_size],
+            dim=2,
+        )
+
+        if cu_seqlens is not None:
+            xBC = _causal_conv1d_fn(
+                xBC.transpose(1, 2),
+                weight=self.conv1d_weight.squeeze(1),
+                bias=self.conv1d_bias,
+                seq_idx=seq_idx,
+                activation="silu",
+            ).transpose(1, 2)
+        else:
+            xBC = _causal_conv1d_fn(
+                x=xBC.transpose(1, 2), weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu"
+            ).transpose(1, 2)
+
+        x, b, c = torch.split(xBC, [self._local_xb_size, self._local_bb_size, self._local_c_size], dim=-1)
+        # Simulate GQA by repeating x and b across the heads of each group (x -> v, B -> k, C -> q).
+        x = einops.rearrange(
+            x, "b l (local_head_groups head_dim) -> b local_head_groups l head_dim", head_dim=self._head_dim_size
+        )  # x is b x local_head_groups x l x head_dim
+        b = einops.rearrange(
+            b,
+            "b l (local_head_groups state_size) -> b local_head_groups l state_size",
+            state_size=self._config.state_size,
+        )  # b is b x local_head_groups x l x state_size
+        batch, num_key_value_heads, slen, head_dim = x.shape
+        x = x[:, :, None, :, :].expand(batch, num_key_value_heads, self._group_heads, slen, head_dim)
+        x = x.reshape(batch, num_key_value_heads * self._group_heads, slen, head_dim)
+        b = b[:, :, None, :, :].expand(batch, num_key_value_heads, self._group_heads, slen, self._config.state_size)
+        b = b.reshape(batch, num_key_value_heads * self._group_heads, slen, self._config.state_size)
+
+        if self._debug_level:
+            self._debug_log(z, "z", self._XZ_DIMS, kwargs)
+            self._debug_log(x, "x", self._XZ_DIMS, kwargs)
+            self._debug_log(b, "b", self._BC_DIMS, kwargs)
+            self._debug_log(c, "c", self._BC_DIMS, kwargs)
+            self._debug_log(dt, "dt", self._XZ_DIMS, kwargs)
+
+        dt_limit_kwargs = (
+            {}
+        )  # can be used to set a time-step limit as in https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py#L424
+        # c is b x seq x (heads * state)
+        # b is b x heads x seq x state
+        # x is b x heads x seq x head_dim
+        # Note: we could use mamba_split_conv1d_scan_combined directly for training; however, because of the GQA sharing, we need the chunked version.
+        y = mamba_chunk_scan_combined(
+            einops.rearrange(x, "b g l p -> b l g p"),
+            dt,
+            A=-torch.exp(self.A_log.float()),
+            B=einops.rearrange(b, "b g l n -> b l g n"),
+            C=einops.rearrange(c, "b l (g n) -> b l g n", g=self._local_heads),
+            chunk_size=self._config.chunk_size,
+            D=self.D,
+            z=None,
+            dt_bias=self.dt_proj_bias,
+            dt_softplus=True,
+            seq_idx=seq_idx,  # assume this is used for packing
+            cu_seqlens=cu_seqlens,  # assume this is used for packing, but it may not be needed for training
+            **dt_limit_kwargs,
+            return_final_states=False,
+            return_varlen_states=False,
+        )
+
+        if self._debug_level:
+            self._debug_log(y, "y", self._XZ_DIMS, kwargs)
+
+        # y: (batch, sequence, local_heads, head_dim) -> (batch, sequence, local_heads * head_dim)
+        y = y.view(batch, sequence_length, -1)
+
+        if kwargs[TransformerKwargs.sequence_first]:
+            # TODO: Is contiguous needed?
+            y = y.transpose(0, 1).contiguous()
+            z = z.transpose(0, 1).contiguous()
+        # In TP, y and z would need to be gathered here, because the norm does not support TP.
+        # Gated norm:
+        y = self.norm(y, gate=z)
+        # (batch/sequence, sequence/batch, local_heads * state)
+        # -> (batch/local_sequence, local_sequence/batch, hidden)
+        out = self.out_proj(y)
+        return out
diff --git a/fast_llm/layers/ssm/preprocessing.py b/fast_llm/layers/ssm/preprocessing.py
new file mode 100644
index 00000000..343f0bb2
--- /dev/null
+++ b/fast_llm/layers/ssm/preprocessing.py
@@ -0,0 +1,68 @@
+import logging
+import typing
+
+import torch
+
+from fast_llm.engine.base_model.config import Preprocessor
+from fast_llm.engine.config_utils.tensor_space import TensorSpace
+from fast_llm.layers.ssm.config import SSMKwargs
+from fast_llm.layers.transformer.config import TransformerKwargs
+from fast_llm.models.ssm.config import HybridSSMBaseModelConfig
+from fast_llm.utils import Assert
+
+logger = logging.getLogger(__name__)
+
+
+class Mamba2Preprocessor(Preprocessor):
+    def __init__(self, config: HybridSSMBaseModelConfig, tensor_space: TensorSpace):
+        self._config = config
+        self._tensor_space = tensor_space
+        self._distributed_config = self._tensor_space.distributed_config
+        self._transformer_dim_names = config.transformer._transformer_dim_names
+
+    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
+        """
+        Simplified preprocessor that does not take micro-sequences into account.
+        """
+        if TransformerKwargs.sequence_lengths not in kwargs:
+            return
+        sequence_lengths = kwargs[TransformerKwargs.sequence_lengths]
+        if TransformerKwargs.cu_seqlens_k in kwargs:
+            # These were already set in the transformer preprocessor, so we can reuse them here.
+            cu_seqlens_k = kwargs[TransformerKwargs.cu_seqlens_k]
+            cu_seqlens_q = kwargs[TransformerKwargs.cu_seqlens_q]
+            Assert.eq(
+                cu_seqlens_k.shape[0],
+                cu_seqlens_q.shape[0],
+                msg="cu_seqlens_k and cu_seqlens_q have different lengths; is micro_sequence_length being used? This is currently not supported for Mamba.",
+            )
+            Assert.all_equal(cu_seqlens_k, cu_seqlens_q)
+            cu_seqlens = cu_seqlens_k
+        else:
+            seqlens = torch.cat(sequence_lengths)
+            cu_seqlens = torch.cat(
+                (
+                    torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device),
+                    torch.cumsum(seqlens, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device),
+                )
+            )
+        kwargs[SSMKwargs.cu_seqlens] = cu_seqlens
+        # from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/verl/models/mamba/hybrid_wrapper.py#L152
+        kwargs[SSMKwargs.seq_idx] = torch.cat(
+            [
+                torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
+                for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])
+            ],
+            dim=0,
+        ).unsqueeze(0)
+
+        sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths)
+        sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size
+        sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size
+        position_ids = torch.stack(
+            [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths]
+        ).to(self._tensor_space.distributed.device, dtype=torch.int64)
+        position_ids = position_ids[
+            :, sequence_k - sequence_q : sequence_k
+        ]  # this is only needed if we do micro-sequences?
+        kwargs[SSMKwargs.ssm_position_ids] = position_ids.to(torch.int32)
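For intuition, with an assumed batch of one sample containing two packed documents of lengths 3 and 2, the preprocessor above produces the following kwargs (plain ints here; the real code works on tensors and moves them to the distributed device):

```python
import torch

doc_lens = [3, 2]  # assumed: two packed documents in one sample
cu_seqlens = torch.tensor([0, *torch.cumsum(torch.tensor(doc_lens), 0).tolist()], dtype=torch.int32)
seq_idx = torch.cat([torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(doc_lens)]).unsqueeze(0)
position_ids = torch.cat([torch.arange(n, dtype=torch.int32) for n in doc_lens]).unsqueeze(0)

print(cu_seqlens.tolist())    # [0, 3, 5]         -> SSMKwargs.cu_seqlens
print(seq_idx.tolist())       # [[0, 0, 0, 1, 1]]  -> SSMKwargs.seq_idx (document index per token)
print(position_ids.tolist())  # [[0, 1, 2, 0, 1]]  -> SSMKwargs.ssm_position_ids (position within each document)
```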
diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py
index 5dca41a7..34f3151a 100644
--- a/fast_llm/models/ssm/config.py
+++ b/fast_llm/models/ssm/config.py
@@ -207,6 +207,11 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]:
 
     def _validate(self) -> None:
         super()._validate()
+        Assert.eq(
+            self.batch.micro_sequence_length,
+            self.batch.sequence_length,
+            msg="Micro-sequences are not supported for SSMs at this point.",
+        )
         if (name := self.model.base_model.distillation_model) is None:
             Assert.empty(self.reference_models)
         else:
diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py
index 64afbea0..18ab09c7 100644
--- a/fast_llm/models/ssm/conversion.py
+++ b/fast_llm/models/ssm/conversion.py
@@ -296,6 +296,9 @@ def _create_weight_converters(
         converters += self._get_weight_and_bias_converters(
             f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False
         )
+        converters += self._get_weight_and_bias_converters(
+            f"layers.{offset+i+1}.mixer.norm", f"{hf_base_prefix}model.layers.{i}.mixer.norm", True
+        )
         # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py)
         converters.append(
             WeightConverter(
@@ -304,6 +307,14 @@
                 self._model.config.base_model,
             )
         )
+        # For Nemotron Mamba2, the dt bias is a separate parameter.
+        converters.append(
+            WeightConverter(
+                f"layers.{offset+i+1}.mixer.dt_proj_bias",
+                f"{hf_base_prefix}model.layers.{i}.mixer.dt_bias",
+                self._model.config.base_model,
+            )
+        )
 
         converters.append(
             WeightConverter(
@@ -789,6 +800,10 @@ def _create_config_converters(cls) -> list[ParamConverter]:
             fast_llm_names=(("ssm", "d_inner"),),
             export_names=(("ssm_cfg", "d_inner"),),
         ),
+        RenameParamConverter(
+            fast_llm_names=(("ssm", "head_dim"),),
+            export_names=(("ssm_cfg", "head_dim"),),
+        ),
         IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None),
     ]
diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb
index a8f0c33b..cfc85ef7 100644
--- a/fast_llm/models/ssm/external/15B_hybrid.ipynb
+++
b/fast_llm/models/ssm/external/15B_hybrid.ipynb @@ -745,7 +745,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -846,22 +846,30 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", - "n_ssm = 25\n", + "# path_hybrid = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-hyb25distsftvrlm2-bs64-lr5e-06-lrs1-1-1-1-sl16384_ti60000_aprsft/export/apriel_ssm_thinker_hybrid/23000\"\n", + "path_hybrid=\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-h27distsftvrlm2f145-bs64-lr5e-06-lrs1-1-1-1-sl16384_ti60000_aprsft/export/apriel_ssm_thinker_hybrid/3500\"\n", "\n", + "n_ssm = 25\n", "\n", + "config_hybrid = AprielSSMHybridConfig.from_pretrained(path_hybrid)\n", "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", "# config_thinker.num_hidden_layers = 5\n", - "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "# hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", "# hybrid_block_layout[3] = \"m2\"\n", + "hybrid_block_layout = config_hybrid.hybrid_block_layout\n", + "# hybrid_block_layout[7] = \"m2\"\n", + "hybrid_block_layout[6] = \"m2\"\n", + "hybrid_block_layout[8] = \"m2\"\n", "\n", - "for i in range(n_ssm):\n", - " hybrid_block_layout[layer_importance[i]] = \"m2\"\n", + "\n", + "# for i in range(n_ssm):\n", + "# hybrid_block_layout[layer_importance[i]] = \"m2\"\n", "\n", "# group_size = 10 # 2nd layer importance is missing\n", "# for i in range(0, len(layer_importance), group_size):\n", @@ -922,7 +930,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -931,12 +939,12 @@ "['t',\n", " 't',\n", " 't',\n", + " 'm2',\n", " 't',\n", + " 'm2',\n", + " 'm2',\n", " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", + " 'm2',\n", " 't',\n", " 't',\n", " 't',\n", @@ -980,7 +988,7 @@ " 'm2']" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -991,25 +999,32 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 7/7 [00:04<00:00, 1.58it/s]\n" + "Loading checkpoint shards: 29%|██▊ | 2/7 [00:01<00:04, 1.13it/s]\n" ] }, { - "data": { - "text/plain": [ - "_IncompatibleKeys(missing_keys=['model.layers.3.mixer.A_log', 'model.layers.3.mixer.D', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.dt_proj.weight', 'model.layers.3.mixer.dt_proj.bias', 'model.layers.3.mixer.out_proj.weight'], unexpected_keys=['model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 
'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.16.self_attn.q_proj.weight', 
'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.18.mlp.gate_proj.weight', 'model.layers.18.mlp.up_proj.weight', 'model.layers.18.mlp.down_proj.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.18.post_attention_layernorm.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.19.mlp.gate_proj.weight', 'model.layers.19.mlp.up_proj.weight', 'model.layers.19.mlp.down_proj.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.20.mlp.gate_proj.weight', 'model.layers.20.mlp.up_proj.weight', 'model.layers.20.mlp.down_proj.weight', 'model.layers.20.input_layernorm.weight', 'model.layers.20.post_attention_layernorm.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.21.mlp.gate_proj.weight', 'model.layers.21.mlp.up_proj.weight', 'model.layers.21.mlp.down_proj.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.21.post_attention_layernorm.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.22.mlp.gate_proj.weight', 'model.layers.22.mlp.up_proj.weight', 'model.layers.22.mlp.down_proj.weight', 'model.layers.22.input_layernorm.weight', 'model.layers.22.post_attention_layernorm.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.23.mlp.gate_proj.weight', 'model.layers.23.mlp.up_proj.weight', 'model.layers.23.mlp.down_proj.weight', 'model.layers.23.input_layernorm.weight', 'model.layers.23.post_attention_layernorm.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.24.mlp.gate_proj.weight', 'model.layers.24.mlp.up_proj.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.24.input_layernorm.weight', 'model.layers.24.post_attention_layernorm.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 
'model.layers.25.self_attn.o_proj.weight', 'model.layers.25.mlp.gate_proj.weight', 'model.layers.25.mlp.up_proj.weight', 'model.layers.25.mlp.down_proj.weight', 'model.layers.25.input_layernorm.weight', 'model.layers.25.post_attention_layernorm.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.26.mlp.gate_proj.weight', 'model.layers.26.mlp.up_proj.weight', 'model.layers.26.mlp.down_proj.weight', 'model.layers.26.input_layernorm.weight', 'model.layers.26.post_attention_layernorm.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight', 'model.layers.27.mlp.gate_proj.weight', 'model.layers.27.mlp.up_proj.weight', 'model.layers.27.mlp.down_proj.weight', 'model.layers.27.input_layernorm.weight', 'model.layers.27.post_attention_layernorm.weight', 'model.layers.28.self_attn.q_proj.weight', 'model.layers.28.self_attn.k_proj.weight', 'model.layers.28.self_attn.v_proj.weight', 'model.layers.28.self_attn.o_proj.weight', 'model.layers.28.mlp.gate_proj.weight', 'model.layers.28.mlp.up_proj.weight', 'model.layers.28.mlp.down_proj.weight', 'model.layers.28.input_layernorm.weight', 'model.layers.28.post_attention_layernorm.weight', 'model.layers.29.self_attn.q_proj.weight', 'model.layers.29.self_attn.k_proj.weight', 'model.layers.29.self_attn.v_proj.weight', 'model.layers.29.self_attn.o_proj.weight', 'model.layers.29.mlp.gate_proj.weight', 'model.layers.29.mlp.up_proj.weight', 'model.layers.29.mlp.down_proj.weight', 'model.layers.29.input_layernorm.weight', 'model.layers.29.post_attention_layernorm.weight', 'model.layers.30.self_attn.q_proj.weight', 'model.layers.30.self_attn.k_proj.weight', 'model.layers.30.self_attn.v_proj.weight', 'model.layers.30.self_attn.o_proj.weight', 'model.layers.30.mlp.gate_proj.weight', 'model.layers.30.mlp.up_proj.weight', 'model.layers.30.mlp.down_proj.weight', 'model.layers.30.input_layernorm.weight', 'model.layers.30.post_attention_layernorm.weight', 'model.layers.31.self_attn.q_proj.weight', 'model.layers.31.self_attn.k_proj.weight', 'model.layers.31.self_attn.v_proj.weight', 'model.layers.31.self_attn.o_proj.weight', 'model.layers.31.mlp.gate_proj.weight', 'model.layers.31.mlp.up_proj.weight', 'model.layers.31.mlp.down_proj.weight', 'model.layers.31.input_layernorm.weight', 'model.layers.31.post_attention_layernorm.weight', 'model.layers.32.self_attn.q_proj.weight', 'model.layers.32.self_attn.k_proj.weight', 'model.layers.32.self_attn.v_proj.weight', 'model.layers.32.self_attn.o_proj.weight', 'model.layers.32.mlp.gate_proj.weight', 'model.layers.32.mlp.up_proj.weight', 'model.layers.32.mlp.down_proj.weight', 'model.layers.32.input_layernorm.weight', 'model.layers.32.post_attention_layernorm.weight', 'model.layers.33.self_attn.q_proj.weight', 'model.layers.33.self_attn.k_proj.weight', 'model.layers.33.self_attn.v_proj.weight', 'model.layers.33.self_attn.o_proj.weight', 'model.layers.33.mlp.gate_proj.weight', 'model.layers.33.mlp.up_proj.weight', 'model.layers.33.mlp.down_proj.weight', 'model.layers.33.input_layernorm.weight', 'model.layers.33.post_attention_layernorm.weight', 'model.layers.34.self_attn.q_proj.weight', 'model.layers.34.self_attn.k_proj.weight', 'model.layers.34.self_attn.v_proj.weight', 'model.layers.34.self_attn.o_proj.weight', 'model.layers.34.mlp.gate_proj.weight', 
'model.layers.34.mlp.up_proj.weight', 'model.layers.34.mlp.down_proj.weight', 'model.layers.34.input_layernorm.weight', 'model.layers.34.post_attention_layernorm.weight', 'model.layers.35.self_attn.q_proj.weight', 'model.layers.35.self_attn.k_proj.weight', 'model.layers.35.self_attn.v_proj.weight', 'model.layers.35.self_attn.o_proj.weight', 'model.layers.35.mlp.gate_proj.weight', 'model.layers.35.mlp.up_proj.weight', 'model.layers.35.mlp.down_proj.weight', 'model.layers.35.input_layernorm.weight', 'model.layers.35.post_attention_layernorm.weight', 'model.layers.36.self_attn.q_proj.weight', 'model.layers.36.self_attn.k_proj.weight', 'model.layers.36.self_attn.v_proj.weight', 'model.layers.36.self_attn.o_proj.weight', 'model.layers.36.mlp.gate_proj.weight', 'model.layers.36.mlp.up_proj.weight', 'model.layers.36.mlp.down_proj.weight', 'model.layers.36.input_layernorm.weight', 'model.layers.36.post_attention_layernorm.weight', 'model.layers.37.self_attn.q_proj.weight', 'model.layers.37.self_attn.k_proj.weight', 'model.layers.37.self_attn.v_proj.weight', 'model.layers.37.self_attn.o_proj.weight', 'model.layers.37.mlp.gate_proj.weight', 'model.layers.37.mlp.up_proj.weight', 'model.layers.37.mlp.down_proj.weight', 'model.layers.37.input_layernorm.weight', 'model.layers.37.post_attention_layernorm.weight', 'model.layers.38.self_attn.q_proj.weight', 'model.layers.38.self_attn.k_proj.weight', 'model.layers.38.self_attn.v_proj.weight', 'model.layers.38.self_attn.o_proj.weight', 'model.layers.38.mlp.gate_proj.weight', 'model.layers.38.mlp.up_proj.weight', 'model.layers.38.mlp.down_proj.weight', 'model.layers.38.input_layernorm.weight', 'model.layers.38.post_attention_layernorm.weight', 'model.layers.39.self_attn.q_proj.weight', 'model.layers.39.self_attn.k_proj.weight', 'model.layers.39.self_attn.v_proj.weight', 'model.layers.39.self_attn.o_proj.weight', 'model.layers.39.mlp.gate_proj.weight', 'model.layers.39.mlp.up_proj.weight', 'model.layers.39.mlp.down_proj.weight', 'model.layers.39.input_layernorm.weight', 'model.layers.39.post_attention_layernorm.weight', 'model.layers.40.self_attn.q_proj.weight', 'model.layers.40.self_attn.k_proj.weight', 'model.layers.40.self_attn.v_proj.weight', 'model.layers.40.self_attn.o_proj.weight', 'model.layers.40.mlp.gate_proj.weight', 'model.layers.40.mlp.up_proj.weight', 'model.layers.40.mlp.down_proj.weight', 'model.layers.40.input_layernorm.weight', 'model.layers.40.post_attention_layernorm.weight', 'model.layers.41.self_attn.q_proj.weight', 'model.layers.41.self_attn.k_proj.weight', 'model.layers.41.self_attn.v_proj.weight', 'model.layers.41.self_attn.o_proj.weight', 'model.layers.41.mlp.gate_proj.weight', 'model.layers.41.mlp.up_proj.weight', 'model.layers.41.mlp.down_proj.weight', 'model.layers.41.input_layernorm.weight', 'model.layers.41.post_attention_layernorm.weight', 'model.layers.42.self_attn.q_proj.weight', 'model.layers.42.self_attn.k_proj.weight', 'model.layers.42.self_attn.v_proj.weight', 'model.layers.42.self_attn.o_proj.weight', 'model.layers.42.mlp.gate_proj.weight', 'model.layers.42.mlp.up_proj.weight', 'model.layers.42.mlp.down_proj.weight', 'model.layers.42.input_layernorm.weight', 'model.layers.42.post_attention_layernorm.weight', 'model.layers.43.self_attn.q_proj.weight', 'model.layers.43.self_attn.k_proj.weight', 'model.layers.43.self_attn.v_proj.weight', 'model.layers.43.self_attn.o_proj.weight', 'model.layers.43.mlp.gate_proj.weight', 'model.layers.43.mlp.up_proj.weight', 'model.layers.43.mlp.down_proj.weight', 
'model.layers.43.input_layernorm.weight', 'model.layers.43.post_attention_layernorm.weight', 'model.layers.44.self_attn.q_proj.weight', 'model.layers.44.self_attn.k_proj.weight', 'model.layers.44.self_attn.v_proj.weight', 'model.layers.44.self_attn.o_proj.weight', 'model.layers.44.mlp.gate_proj.weight', 'model.layers.44.mlp.up_proj.weight', 'model.layers.44.mlp.down_proj.weight', 'model.layers.44.input_layernorm.weight', 'model.layers.44.post_attention_layernorm.weight', 'model.layers.45.self_attn.q_proj.weight', 'model.layers.45.self_attn.k_proj.weight', 'model.layers.45.self_attn.v_proj.weight', 'model.layers.45.self_attn.o_proj.weight', 'model.layers.45.mlp.gate_proj.weight', 'model.layers.45.mlp.up_proj.weight', 'model.layers.45.mlp.down_proj.weight', 'model.layers.45.input_layernorm.weight', 'model.layers.45.post_attention_layernorm.weight', 'model.layers.46.self_attn.q_proj.weight', 'model.layers.46.self_attn.k_proj.weight', 'model.layers.46.self_attn.v_proj.weight', 'model.layers.46.self_attn.o_proj.weight', 'model.layers.46.mlp.gate_proj.weight', 'model.layers.46.mlp.up_proj.weight', 'model.layers.46.mlp.down_proj.weight', 'model.layers.46.input_layernorm.weight', 'model.layers.46.post_attention_layernorm.weight', 'model.layers.47.self_attn.q_proj.weight', 'model.layers.47.self_attn.k_proj.weight', 'model.layers.47.self_attn.v_proj.weight', 'model.layers.47.self_attn.o_proj.weight', 'model.layers.47.mlp.gate_proj.weight', 'model.layers.47.mlp.up_proj.weight', 'model.layers.47.mlp.down_proj.weight', 'model.layers.47.input_layernorm.weight', 'model.layers.47.post_attention_layernorm.weight', 'model.layers.48.self_attn.q_proj.weight', 'model.layers.48.self_attn.k_proj.weight', 'model.layers.48.self_attn.v_proj.weight', 'model.layers.48.self_attn.o_proj.weight', 'model.layers.48.mlp.gate_proj.weight', 'model.layers.48.mlp.up_proj.weight', 'model.layers.48.mlp.down_proj.weight', 'model.layers.48.input_layernorm.weight', 'model.layers.48.post_attention_layernorm.weight', 'model.layers.49.self_attn.q_proj.weight', 'model.layers.49.self_attn.k_proj.weight', 'model.layers.49.self_attn.v_proj.weight', 'model.layers.49.self_attn.o_proj.weight', 'model.layers.49.mlp.gate_proj.weight', 'model.layers.49.mlp.up_proj.weight', 'model.layers.49.mlp.down_proj.weight', 'model.layers.49.input_layernorm.weight', 'model.layers.49.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight'])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Load state dict into hybrid model from Thinker\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model_base \u001b[38;5;241m=\u001b[39m \u001b[43mMistralForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_thinker\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m model_hybrid\u001b[38;5;241m.\u001b[39mload_state_dict(model_base\u001b[38;5;241m.\u001b[39mstate_dict(), strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File 
\u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:311\u001b[0m, in \u001b[0;36mrestore_default_torch_dtype.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m old_dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mget_default_dtype()\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 313\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_dtype(old_dtype)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:4839\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 4829\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_orig \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 4830\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_dtype(dtype_orig)\n\u001b[1;32m 4832\u001b[0m (\n\u001b[1;32m 4833\u001b[0m model,\n\u001b[1;32m 4834\u001b[0m missing_keys,\n\u001b[1;32m 4835\u001b[0m unexpected_keys,\n\u001b[1;32m 4836\u001b[0m mismatched_keys,\n\u001b[1;32m 4837\u001b[0m offload_index,\n\u001b[1;32m 4838\u001b[0m error_msgs,\n\u001b[0;32m-> 4839\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_load_pretrained_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4840\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4841\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4842\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4843\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4844\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_mismatched_sizes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_mismatched_sizes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4845\u001b[0m \u001b[43m \u001b[49m\u001b[43msharded_metadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharded_metadata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4846\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4847\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisk_offload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4848\u001b[0m \u001b[43m \u001b[49m\u001b[43moffload_state_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffload_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4849\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4850\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mhf_quantizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhf_quantizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4851\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4852\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_mesh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_mesh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4853\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_mapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4854\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweights_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4855\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4857\u001b[0m \u001b[38;5;66;03m# record tp degree the model sharded to\u001b[39;00m\n\u001b[1;32m 4858\u001b[0m model\u001b[38;5;241m.\u001b[39m_tp_size \u001b[38;5;241m=\u001b[39m tp_size\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:5302\u001b[0m, in \u001b[0;36mPreTrainedModel._load_pretrained_model\u001b[0;34m(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)\u001b[0m\n\u001b[1;32m 5299\u001b[0m args_list \u001b[38;5;241m=\u001b[39m logging\u001b[38;5;241m.\u001b[39mtqdm(args_list, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoading checkpoint shards\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5301\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m args \u001b[38;5;129;01min\u001b[39;00m args_list:\n\u001b[0;32m-> 5302\u001b[0m _error_msgs, disk_offload_index, cpu_offload_index \u001b[38;5;241m=\u001b[39m \u001b[43mload_shard_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5303\u001b[0m error_msgs \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m _error_msgs\n\u001b[1;32m 5305\u001b[0m \u001b[38;5;66;03m# Adjust offloaded weights name and save if needed\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:933\u001b[0m, in \u001b[0;36mload_shard_file\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;66;03m# Skip it with fsdp on ranks other than 0\u001b[39;00m\n\u001b[1;32m 932\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (is_fsdp_enabled() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_local_dist_rank_0() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_quantized):\n\u001b[0;32m--> 933\u001b[0m disk_offload_index, cpu_offload_index \u001b[38;5;241m=\u001b[39m \u001b[43m_load_state_dict_into_meta_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 934\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_to_load\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 935\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 936\u001b[0m \u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 937\u001b[0m \u001b[43m \u001b[49m\u001b[43mexpected_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 938\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mreverse_key_renaming_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 939\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 940\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisk_offload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisk_offload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 941\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisk_offload_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisk_offload_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 942\u001b[0m \u001b[43m \u001b[49m\u001b[43mcpu_offload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcpu_offload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 943\u001b[0m \u001b[43m \u001b[49m\u001b[43mcpu_offload_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcpu_offload_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 944\u001b[0m \u001b[43m \u001b[49m\u001b[43mhf_quantizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhf_quantizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 945\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_safetensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_offloaded_safetensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 946\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 947\u001b[0m \u001b[43m \u001b[49m\u001b[43munexpected_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43munexpected_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 948\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_mesh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_mesh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 949\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 951\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m error_msgs, disk_offload_index, cpu_offload_index\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:810\u001b[0m, in \u001b[0;36m_load_state_dict_into_meta_model\u001b[0;34m(model, state_dict, shard_file, expected_keys, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, cpu_offload_folder, cpu_offload_index, hf_quantizer, is_safetensors, keep_in_fp32_regex, unexpected_keys, device_mesh)\u001b[0m\n\u001b[1;32m 808\u001b[0m param \u001b[38;5;241m=\u001b[39m param[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]\n\u001b[1;32m 
809\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m casting_dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 810\u001b[0m param \u001b[38;5;241m=\u001b[39m \u001b[43mparam\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasting_dtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 811\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m to_contiguous:\n\u001b[1;32m 812\u001b[0m param \u001b[38;5;241m=\u001b[39m param\u001b[38;5;241m.\u001b[39mcontiguous()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] } ], "source": [ @@ -1050,7 +1065,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -1125,14 +1140,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00, 1.22it/s]\n" + "Loading checkpoint shards: 100%|██████████| 7/7 [00:07<00:00, 1.01s/it]\n" ] }, { @@ -1146,17 +1161,21 @@ "Converting layer %d... 2\n", "Skipping transformer layer 2...\n", "Converting layer %d... 3\n", - "Skipping transformer layer 3...\n", + "Converting layer 3...\n", + "Init Mamba using Attention\n", "Converting layer %d... 4\n", "Skipping transformer layer 4...\n", "Converting layer %d... 5\n", - "Skipping transformer layer 5...\n", + "Converting layer 5...\n", + "Init Mamba using Attention\n", "Converting layer %d... 6\n", - "Skipping transformer layer 6...\n", + "Converting layer 6...\n", + "Init Mamba using Attention\n", "Converting layer %d... 7\n", "Skipping transformer layer 7...\n", "Converting layer %d... 8\n", - "Skipping transformer layer 8...\n", + "Converting layer 8...\n", + "Init Mamba using Attention\n", "Converting layer %d... 9\n", "Skipping transformer layer 9...\n", "Converting layer %d... 
10\n", @@ -1277,7 +1296,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -1286,7 +1305,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1306,12 +1325,12 @@ " \"t\",\n", " \"t\",\n", " \"t\",\n", + " \"m2\",\n", " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", " \"t\",\n", - " \"t\",\n", - " \"t\",\n", - " \"t\",\n", - " \"t\",\n", + " \"m2\",\n", " \"t\",\n", " \"t\",\n", " \"t\",\n", @@ -1380,6 +1399,8 @@ " \"dt_rank\": \"auto\",\n", " \"dt_scale\": 1.0,\n", " \"expand\": 1,\n", + " \"head_dim\": 128,\n", + " \"layer_norm_epsilon\": 1e-05,\n", " \"n_qk_heads\": 32,\n", " \"n_v_heads\": 32\n", " },\n", @@ -1391,7 +1412,7 @@ "}" ] }, - "execution_count": 14, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1402,7 +1423,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -1411,26 +1432,40 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 427.77it/s]\n" + "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 37.80it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['model.layers.6.mixer.A_log', 'model.layers.6.mixer.D', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.dt_in_proj.weight', 'model.layers.6.mixer.dt_proj.weight', 'model.layers.6.mixer.dt_proj.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.8.mixer.A_log', 'model.layers.8.mixer.D', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.dt_in_proj.weight', 'model.layers.8.mixer.dt_proj.weight', 'model.layers.8.mixer.dt_proj.bias', 'model.layers.8.mixer.out_proj.weight']\n", + "['model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight']\n" ] } ], "source": [ "# load state dict from existing pretrained SSM?\n", - "path_25hyb = \"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", + "path_25hyb = path_hybrid #\"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", "model = AprielThinkerSSMHybridForCausalLM.from_pretrained(path_25hyb)\n", "state_dict = model.state_dict()\n", - "\n", - "# missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", - "# print(missing)\n", - "# print(unexpected)\n", + "missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", + "print(missing)\n", + "print(unexpected)\n", "\n", "\n", "\n", @@ -1450,7 +1485,7 @@ }, { 
"cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1458,7 +1493,10 @@ "# mamba2, state 16, expand 1, i.e. same as M1, but with discrete mamba2 and MIL\n", "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_m2_16hexp1_init_mil\") # 1 ssm\n", "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil\") # 25 ssm\n", - "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_27ssm_leastimportant_m2_16hexp1_init_hyb25distsftvrlm223k_mil\") # 25 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_30ssm_leastimportant_m2_16hexp1_init_hyb27distsftvrlm2_3p5ksteps_mil\") # 30 ssm\n", + "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_29ssm_leastimportant_m2_16hexp1_init_hyb27distsftvrlm2_3p5ksteps_mil\") # 29 ssm\n", "\n", "\n", "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_40ssm_leastimportant_m2_16hexp1_init_mil_uniform_from_25h5000lm6\") # 40 ssm" @@ -1482,39 +1520,1123 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Data mixing" + "# Nemotron-h mamba layer" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import gc\n", + "\n", + "import click\n", + "import torch\n", + "from transformers import AutoConfig, AutoModelForCausalLM\n", + "from transformers import MistralForCausalLM\n", + "\n", + "from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridForCausalLM, AprielSSMM2DecoderLayer, AprielSSMDecoderLayer, AprielSSMNemotronHM2DecoderLayer, NemotronHMamba2Mixer\n", + "from fast_llm.models.ssm.external.nemotron.modeling import NemotronHMamba2Mixer as NemotronHMamba2Mixer_original\n", + "from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig\n", + "from transformers.models.mistral.modeling_mistral import MistralDecoderLayer\n", + "\n", + "# enable file reload \n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Instantiating mamba2 with num_heads: 256, head_dim: 16, \n", + " intermediate_size: 4096, \n", + " d_xb: 1024, \n", + " number_xb_heads: 64, \n", + " repeat_groups: 4, \n", + " d_state: 16\n", + "Instantiating mamba2 with num_heads: 256, head_dim: 16, \n", + " intermediate_size: 4096, \n", + " d_xb: 1024, \n", + " number_xb_heads: 64, \n", + " repeat_groups: 4, \n", + " d_state: 16\n", + "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 382.75it/s]\n" + ] + } + ], + "source": [ + "# model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_nmtrhnm2_5l_debug\")\n", + "model = AprielThinkerSSMHybridForCausalLM.from_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_nmtrhnm2_5l_debug\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([])\n", - "KL (global, F.kl_div) = 0.738795\n", - "KL (sum of shards, manual) = 0.738795\n" + "model.embed_tokens.weight\n", + "model.layers.0.self_attn.q_proj.weight\n", + "model.layers.0.self_attn.k_proj.weight\n", + "model.layers.0.self_attn.v_proj.weight\n", + "model.layers.0.self_attn.o_proj.weight\n", + "model.layers.0.mlp.gate_proj.weight\n", + "model.layers.0.mlp.up_proj.weight\n", + "model.layers.0.mlp.down_proj.weight\n", + "model.layers.0.input_layernorm.weight\n", + "model.layers.0.post_attention_layernorm.weight\n", + "model.layers.1.self_attn.q_proj.weight\n", + "model.layers.1.self_attn.k_proj.weight\n", + "model.layers.1.self_attn.v_proj.weight\n", + "model.layers.1.self_attn.o_proj.weight\n", + "model.layers.1.mlp.gate_proj.weight\n", + "model.layers.1.mlp.up_proj.weight\n", + "model.layers.1.mlp.down_proj.weight\n", + "model.layers.1.input_layernorm.weight\n", + "model.layers.1.post_attention_layernorm.weight\n", + "model.layers.2.mixer.dt_bias\n", + "model.layers.2.mixer.A_log\n", + "model.layers.2.mixer.D\n", + "model.layers.2.mixer.conv1d.weight\n", + "model.layers.2.mixer.conv1d.bias\n", + "model.layers.2.mixer.in_proj.weight\n", + "model.layers.2.mixer.dt_in_proj.weight\n", + "model.layers.2.mixer.norm.weight\n", + "model.layers.2.mixer.out_proj.weight\n", + "model.layers.2.mlp.gate_proj.weight\n", + "model.layers.2.mlp.up_proj.weight\n", + "model.layers.2.mlp.down_proj.weight\n", + 
"model.layers.2.input_layernorm.weight\n", + "model.layers.2.post_attention_layernorm.weight\n", + "model.layers.3.mixer.dt_bias\n", + "model.layers.3.mixer.A_log\n", + "model.layers.3.mixer.D\n", + "model.layers.3.mixer.conv1d.weight\n", + "model.layers.3.mixer.conv1d.bias\n", + "model.layers.3.mixer.in_proj.weight\n", + "model.layers.3.mixer.dt_in_proj.weight\n", + "model.layers.3.mixer.norm.weight\n", + "model.layers.3.mixer.out_proj.weight\n", + "model.layers.3.mlp.gate_proj.weight\n", + "model.layers.3.mlp.up_proj.weight\n", + "model.layers.3.mlp.down_proj.weight\n", + "model.layers.3.input_layernorm.weight\n", + "model.layers.3.post_attention_layernorm.weight\n", + "model.layers.4.self_attn.q_proj.weight\n", + "model.layers.4.self_attn.k_proj.weight\n", + "model.layers.4.self_attn.v_proj.weight\n", + "model.layers.4.self_attn.o_proj.weight\n", + "model.layers.4.mlp.gate_proj.weight\n", + "model.layers.4.mlp.up_proj.weight\n", + "model.layers.4.mlp.down_proj.weight\n", + "model.layers.4.input_layernorm.weight\n", + "model.layers.4.post_attention_layernorm.weight\n", + "model.norm.weight\n", + "lm_head.weight\n" ] } ], - "source": [] + "source": [ + "for k,v in model.state_dict().items():\n", + " print(k)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "path_hybrid = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-hyb25distsftvrlm2-bs64-lr5e-06-lrs1-1-1-1-sl16384_ti60000_aprsft/export/apriel_ssm_thinker_hybrid/23000\"\n", + "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", + "\n", + "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", + "\n", + "# config_hybrid = AprielSSMHybridConfig.from_pretrained(path_hybrid)\n", + "# config_thinker.num_hidden_layers = 5\n", + "# hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "# # debug\n", + "# hybrid_block_layout[2] = \"nm2\"\n", + "# hybrid_block_layout[3] = \"nm2\"\n", + "\n", + "# 25/50\n", + "hybrid_block_layout = [\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"t\",\n", + " \"nm2\",\n", + " \"t\",\n", + " \"t\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"t\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"nm2\",\n", + " \"t\",\n", + " \"nm2\"\n", + " ]\n", + " \n", + "ssm_config = {\n", + " \"d_state\": 16,\n", + " \"d_xb\": 1024,\n", + " \"expand\": 1,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 4096,\n", + " \"conv_bias\": True,\n", + " \"bias\": False,\n", + " \"head_dim\": 16, # 4096/16 = 32 heads, 1024/128 = 8 KVheads and 4 repeat groups\n", + "}\n", + "config_thinker.hybrid_block_layout = hybrid_block_layout\n", + "# config_thinker.ssm_cfg = ssm_config\n", + "# model = AprielThinkerSSMHybridForCausalLM(config_hybrid)\n", + "# mixer = NemotronHMamba2Mixer(d_model=4096, **ssm_config)\n", + "\n", + "\n", + "\n", + "config_hybrid = AprielSSMHybridConfig(\n", + " **config_thinker.to_dict(),\n", + " ssm_cfg=ssm_config\n", + 
")\n" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "50" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(config_hybrid.hybrid_block_layout)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "\n", + "def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, attn_bias, torch_dtype):\n", + " config = transformer.config\n", + " embed_dim = config.hidden_size\n", + " num_heads = config.num_attention_heads\n", + " num_heads_kv = config.num_key_value_heads\n", + " head_dim = embed_dim // num_heads\n", + " q_dim = head_dim * num_heads\n", + " kv_dim = head_dim * num_heads_kv\n", + "\n", + " for layer_idx, type in enumerate(hybrid_block_layout):\n", + " print(\"Converting layer %d...\", layer_idx)\n", + " # Fetch the layer module for easier access\n", + " layer_module = transformer.model.layers._modules[f\"{layer_idx}\"]\n", + " if type == \"t\":\n", + " print(\"Skipping transformer layer %d...\" % layer_idx)\n", + " elif type == \"nm2\":\n", + " print(\"Converting layer %d...\" % layer_idx)\n", + " # Use MambaDecoderLayer for the remaining layers\n", + " mamba_encoder = AprielSSMNemotronHM2DecoderLayer(\n", + " mamba_config,\n", + " layer_idx,\n", + " device=\"cpu\",\n", + " dtype=torch_dtype,\n", + " )\n", + " \n", + " mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " num_xb_heads = mamba_config.ssm_cfg[\"d_xb\"] // mamba_config.ssm_cfg[\"head_dim\"]\n", + " num_heads = mamba_config.ssm_cfg[\"d_inner\"] // mamba_config.ssm_cfg[\"head_dim\"]\n", + "\n", + " if init_with_kqvo:\n", + " # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.v_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_xb_heads * mamba_config.ssm_cfg[\"d_state\"]), :\n", + " ].copy_(layer_module.self_attn.k_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_xb_heads * mamba_config.ssm_cfg[\"d_state\"]) : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_xb_heads * mamba_config.ssm_cfg[\"d_state\"]) + (num_heads * mamba_config.ssm_cfg[\"d_state\"]), :\n", + " ].copy_(layer_module.self_attn.q_proj.weight.data)\n", + "\n", + " print(\"Init Mamba using Attention\")\n", + "\n", + " transformer.model.layers[layer_idx] = mamba_encoder\n", + "\n", + " # elif type == \"m2d\":\n", + " # print(\"Converting layer %d...\" % layer_idx)\n", + " # mamba_encoder = AprielSSMDecoderLayer(\n", + " # mamba_config,\n", + " # layer_idx,\n", + " # device=\"cpu\",\n", + " # dtype=torch_dtype,\n", + " # )\n", + " # 
mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " # mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " # mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " # mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " # if init_with_kqvo:\n", + " \n", + "\n", + "\n", + " \n", + " else:\n", + " raise ValueError(f\"Invalid layer type: {type}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 0%| | 0/7 [00:00 torch.Tensor: if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[0].device) else: self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) return self.conv_states[layer_idx] def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[0].device) return self.ssm_states[layer_idx] def reset(self): @@ -217,7 +218,8 @@ def reset(self): self.ssm_states.zero_() -# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +# Copied from https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py whichis taken from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py class HybridMambaAttentionDynamicCache(DynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache @@ -242,7 +244,7 @@ def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float1 else config.ssm_cfg["expand"] * config.hidden_size ) ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] + self.conv_kernel_size = conv_kernel_size = config.ssm_cfg["d_conv"] self.n_qk_heads = config.ssm_cfg["n_qk_heads"] self.num_C_head = intermediate_size // ssm_state_size # mamba2 assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" @@ -341,14 +343,14 @@ def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False ) -> torch.Tensor: if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[0].device) else: self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[0].device) return self.conv_states[layer_idx] def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[0].device) return self.ssm_states[layer_idx] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: @@ -783,7 +785,411 @@ def convolutional_step(self, xBC, conv_state): return xBC, conv_state +def pad_tensor_by_size(input_tensor: 
torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, group_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + + # jan28b version + def forward(self, hidden_states, gate=None): + return rmsnorm_fn( + x=hidden_states, + weight=self.weight, + bias=None, # No bias + z=gate, + eps=self.variance_epsilon, + group_size=self.group_size, + norm_before_gate=False, + ) + + +class NemotronHMamba2Mixer(nn.Module): + """ + From https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py. + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + + Note: we assume n_groups = num_heads here, we do not want to share B or C over heads. + Why: we mimic GQA here, so sharing B over heads would result in additional complecity which we want to avoid at this point. + Note: to reconstruct the architecture of original nemotron-H mixer (but iwth n_groups = num_heads), d_xb needs to be same as d_inner. + """ + + def __init__( + self, + d_model, + d_inner, + d_xb=None, + d_state=16, + d_conv=4, + expand=2, + head_dim=128, + layer_norm_epsilon=1e-5, + conv_bias=True, + chunk_size=128, + bias=False, + layer_idx=None, + # device=None, + # dtype=None, + **kwargs, + ): + super().__init__() + self.hidden_size = d_model + self.ssm_state_size = d_state + self.conv_kernel_size = d_conv + self.expand = expand + self.intermediate_size = ( + d_inner if d_inner is not None else d_model * expand + ) # config.mamba_num_heads * config.mamba_head_dim + + self.d_xb = d_xb if d_xb is not None else self.intermediate_size + self.layer_idx = layer_idx + self.use_conv_bias = conv_bias + self.activation = "silu" + self.act = nn.SiLU() + self.head_dim = head_dim + assert self.intermediate_size % self.head_dim == 0, "intermediate_size must be divisible by head_dim" + self.num_heads = self.intermediate_size // self.head_dim + + # for GQA simulation, where we repeat x and B for each group + self.num_xb_heads = self.d_xb // self.head_dim + assert self.num_heads % self.num_xb_heads == 0, "num_heads must be divisible by num_xb_heads" + self.repeat_groups = self.num_heads // self.num_xb_heads + if self.d_xb == self.intermediate_size: + assert self.repeat_groups == 1 + logger.warning( + f"d_xb == intermediate_size, d_xb: {self.d_xb}, intermediate_size: {self.intermediate_size}, repeat_groups: {self.repeat_groups}" + ) + + self.layer_norm_epsilon = layer_norm_epsilon + + logger.warning( + f"Instantiating mamba2 with num_heads: {self.num_heads}, head_dim: {self.head_dim}, \n \ + intermediate_size: {self.intermediate_size}, \n \ + d_xb: {self.d_xb}, \n \ + number_xb_heads: {self.num_xb_heads}, \n \ + repeat_groups: {self.repeat_groups}, \n \ + d_state: {self.ssm_state_size}" + ) + + self.n_groups = ( + self.num_heads + ) # nemotron allows 
for any num_groups, we use the same as num_heads for now, otherwisxe it becomes too complecated with GQA simulation + self.chunk_size = chunk_size + + self.time_step_limit = (0.0, float("inf")) # hard coded + # conv is over xBC -- d_xb (head_dim), d_bb (state_dim), d_c (state_dim) + # self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv_dim = ( + self.d_xb # self.num_xb_heads x head_dim + + self.num_xb_heads * self.ssm_state_size + + self.num_heads * self.ssm_state_size + ) + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=self.use_conv_bias, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim # + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=bias, + ) + self.dt_in_proj = nn.Linear( + self.hidden_size, + self.num_heads, + bias=bias, + ) + # selective projection used to make dt, B and C input dependant + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated( + self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.use_bias = bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + outputs = {} + # 1. 
Gated MLP's linear projection + # Apply_mask_to_padding_states is not used in nemotron, + # because attention_mask is not pased, see https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py#L774 + attention_mask = None # so apply_mask_to_padding_states does nothing + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + # C dim, note we keep number of groups same as number of heads here + head_time_state_size = self.num_heads * self.ssm_state_size + num_xb_heads_time_state_size = self.num_xb_heads * self.ssm_state_size + + # d_mlp = ( + # projected_states.shape[-1] + # - 2 * self.intermediate_size + # - 2 * self.n_groups * self.ssm_state_size + # - self.num_heads + # ) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + gate, hidden_states_B_C = projected_states.squeeze(1).split( + [self.intermediate_size, self.conv_dim], dim=-1 + ) + dt = self.dt_in_proj(hidden_states).squeeze(1) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.d_xb, num_xb_heads_time_state_size, head_time_state_size], + dim=-1, + ) + # simulate GQA by repeating heads in x,b, x -> v, B -> k, C -> q + hidden_states = rearrange( + hidden_states, + "b (local_head_groups head_dim) -> b local_head_groups head_dim", + head_dim=self.head_dim, + ) # x is b x local_head_groups x l x head_dim + B = rearrange( + B, + "b (local_head_groups state_size) -> b local_head_groups state_size", + state_size=self.ssm_state_size, + ) # b is b x local_head_groups x l x state_size + batch, num_key_value_heads, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :].expand( + batch, num_key_value_heads, self.repeat_groups, head_dim + ) + hidden_states = hidden_states.reshape(batch, num_key_value_heads * self.repeat_groups, head_dim) + B = B[:, :, None, :].expand(batch, num_key_value_heads, self.repeat_groups, self.ssm_state_size) + B = B.reshape(batch, num_key_value_heads * self.repeat_groups, self.ssm_state_size) + + # 3. SSM transformation z + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + C = C.view(batch_size, self.num_heads, self.ssm_state_size) + # B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + # C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] 
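# Illustrative sketch (assumptions: toy sizes only, not part of the patched modeling file):
# both branches of cuda_kernels_forward simulate GQA by repeating the shared x ("v"-like)
# and B ("k"-like) heads `repeat_groups` times so that every SSM head (the "q"-like C heads)
# gets its own copy, mirroring the `[:, :, None, :].expand(...).reshape(...)` pattern above.
import torch

batch, head_dim, d_state = 2, 16, 16        # assumed toy values; the real ones come from ssm_cfg
num_xb_heads, repeat_groups = 8, 4          # e.g. d_xb // head_dim = 8, num_heads // num_xb_heads = 4
num_heads = num_xb_heads * repeat_groups

x = torch.randn(batch, num_xb_heads, head_dim)   # shared hidden states (d_xb split into heads)
B = torch.randn(batch, num_xb_heads, d_state)    # shared B states

# Repeat each shared head so the per-head SSM update sees num_heads independent copies.
x_rep = x[:, :, None, :].expand(batch, num_xb_heads, repeat_groups, head_dim).reshape(batch, num_heads, head_dim)
B_rep = B[:, :, None, :].expand(batch, num_xb_heads, repeat_groups, d_state).reshape(batch, num_heads, d_state)

assert x_rep.shape == (batch, num_heads, head_dim)
assert B_rep.shape == (batch, num_heads, d_state)
# In the prefill branch below, the same repetition is applied with an extra sequence-length
# dimension before calling the chunked scan kernel.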
+ + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + assert False, "Should not have ended here for inference" + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + # we are not using mlp here, leaving it here from nemotron modeling + gate, hidden_states_B_C = projected_states.split([self.intermediate_size, self.conv_dim], dim=-1) + dt = self.dt_in_proj(hidden_states) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + hidden_states_B_C = apply_mask_to_padding_states( + hidden_states_B_C, attention_mask + ) # this does not seem to do anything in nemotron + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.d_xb, num_xb_heads_time_state_size, head_time_state_size], + dim=-1, + ) + # simulate GQA by repeating heads in x,b, x -> v, B -> k, C -> q + hidden_states = rearrange( + hidden_states, + "b l (local_head_groups head_dim) -> b local_head_groups l head_dim", + head_dim=self.head_dim, + ) # x is b x local_head_groups x l x head_dim + B = rearrange( + B, + "b l (local_head_groups state_size) -> b local_head_groups l state_size", + state_size=self.ssm_state_size, + ) # b is b x local_head_groups x l x state_size + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, self.repeat_groups, slen, head_dim + ) + hidden_states = hidden_states.reshape(batch, num_key_value_heads * self.repeat_groups, slen, head_dim) + B = B[:, :, None, :, :].expand( + batch, num_key_value_heads, self.repeat_groups, slen, self.ssm_state_size + ) + B = B.reshape(batch, num_key_value_heads * self.repeat_groups, slen, self.ssm_state_size) + hidden_states = hidden_states.transpose(1, 2).contiguous() + B = B.transpose(1, 2).contiguous() + + # 3. 
SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), # (b, s, h, d) + dt, # (b, s, h) + A, # (h) + B.view(batch_size, seq_len, self.num_heads, -1), # (b, s, n_groups, state) + C.view(batch_size, seq_len, self.num_heads, -1), # (b, s, n_groups, state) + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + outputs["hidden_states"] = out[:, :seq_len, :] + return outputs + + def torch_forward(self, *args, **kwargs): + assert False, "Should not have ended here for inference, make sure all neccessary kernels are installed" + # see implementation in nemotron modeling https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py + + def forward( + self, + hidden_states, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + cache_params = past_key_value + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + class Mamba2(nn.Module): + """ + From https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py + """ + def __init__( self, d_model, @@ -1200,6 +1606,10 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return (hidden_states,) +class AprielSSMNemotronHM2DecoderLayer(AprielSSMDecoderLayer): + _mixer_class = NemotronHMamba2Mixer + + class AprielThinkerSSMHybridModel(MistralModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] @@ -1213,7 +1623,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): super().__init__(config_copy, **kwargs) self.config = config blocks = [] - logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + logger.info(f"Loading hybrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): if type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx)) @@ -1223,6 +1633,8 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): blocks.append(MistralDecoderLayer(config, layer_idx)) elif type == "i": blocks.append(AprielHybridIdentity(config)) + elif type == "nm2": + blocks.append(AprielSSMNemotronHM2DecoderLayer(config, layer_idx)) else: raise ValueError(f"Invalid block type: {type}") self.layers = nn.ModuleList(blocks) @@ -1246,7 +1658,7 @@ def forward( ) -> BaseModelOutputWithPast: use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache and past_key_values is None: - # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) output = super().forward( diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/test_modeling.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/test_modeling.py new file mode 100644 index 00000000..91857d8a --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/test_modeling.py @@ -0,0 +1,53 @@ +import pytest + +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import NemotronHMamba2Mixer +from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig +from fast_llm.models.ssm.external.nemotron.modeling import NemotronHMamba2Mixer as NemotronHMamba2Mixer_original + + +# in apriel's mamba2 mixer we do not used groups for B and C, but we have the d_xb dim, that simulates GQA +# so in order to reconstruct the original nemotron mixer, we need to set d_xb same as d_inner +@pytest.mark.parametrize( + "apriel_ssm_config, nemotron_h_config", + [ + ( + { + "d_state": 16, + "d_xb": 4096, + "expand": 1, + "d_conv": 4, + "d_inner": 4096, + "conv_bias": True, + "bias": False, + "head_dim": 128, # 4096/128 = 32 heads, 1024/128 = 8 KVheads and 4 repeat groups + }, + NemotronHConfig( + hidden_size=4096, + mamba_num_heads=32, + mamba_head_dim=128, + mamba_n_groups=32, + mamba_d_conv=4, + mamba_expand=1, + ssm_state_size=16, + use_bias=False, + mamba_hidden_act="silu", + ), + ) + ], +) +def test_nemotron_h_mamba2_mixers_identical(apriel_ssm_config: dict, nemotron_h_config: dict): + mixer_apriel = NemotronHMamba2Mixer(d_model=4096, **apriel_ssm_config) + mixer_nemotron_h = NemotronHMamba2Mixer_original(nemotron_h_config, 0) + + for k_a, v_a in mixer_apriel.state_dict().items(): + if k_a == "dt_in_proj.weight": + continue + v_b = mixer_nemotron_h.state_dict()[k_a] + if k_a == "in_proj.weight": + assert [v_a.shape[0], v_a.shape[1]] == [v_b.shape[0] - nemotron_h_config.mamba_num_heads, v_b.shape[1]] + else: + assert v_a.shape == v_b.shape + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git 
a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index ee2c83e0..ea1d6cd3 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -250,3 +250,76 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): use_cache=True, **generation_kwargs, ) + + +@register_model("nemotron_h") +class NemotronHWrapper(HFLM): + """Wrapper for NemotronH model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + **kwargs, + ) + + # Override device detection for distributed settings + self._device = _get_device() + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig + + self._config = NemotronHConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.nemotron.modeling import NemotronHForCausalLM + + # Ensure we're using the correct device + device = _get_device() + self._device = device + + self._model = NemotronHForCausalLM.from_pretrained( + pretrained, + # device=device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + config=self._config, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + # Ensure we're using the correct device + device = _get_device() + + # Ensure context is on the same device as the model + context = context.to(device) + self.model.to(device) + + # Move any tensors in generation_kwargs to the correct device + generation_kwargs = _move_tensors_to_device(generation_kwargs, device) + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py deleted file mode 100644 index dde11cfb..00000000 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_importance_15b_mil.py +++ /dev/null @@ -1,176 +0,0 @@ -import click -import torch -import transformers -from transformers import AutoConfig, AutoModelForCausalLM - -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( - AprielSSMM2DecoderLayer, - 
AprielThinkerSSMHybridForCausalLM, -) - -device = "cuda" if torch.cuda.is_available() else "cpu" - -print("Transformers version:", transformers.__version__) - - -def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype): - - for layer_idx, type in enumerate(hybrid_block_layout): - # print("Converting layer %d...", layer_idx) - # Fetch the layer module for easier access - layer_module = transformer.model.layers._modules[f"{layer_idx}"] - if type == "t": - print("Skipping transformer layer %d..." % layer_idx) - elif type == "m2": - print("Converting layer %d to Mamba2 with MIL init..." % layer_idx) - # Use MambaDecoderLayer for the remaining layers - mamba_encoder = AprielSSMM2DecoderLayer( - mamba_config, - layer_idx, - device="cpu", - dtype=torch_dtype, - ) - - mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) - mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) - mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) - mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) - - if init_with_kqvo: - # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : - ].copy_(layer_module.self_attn.v_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.k_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.q_proj.weight.data) - - print("Init Mamba using Attention") - - transformer.model.layers[layer_idx] = mamba_encoder - - elif type == "m2d": - raise NotImplementedError("Discrete Mamba2 not implemented") - else: - raise ValueError(f"Invalid layer type: {type}") - - -@click.command() -@click.option("--index_to_swap", type=int, required=True) -@click.option("--checkpoint", type=str, required=True) -@click.option("--output_model_path", type=str, required=True) -@click.option("--layer_type", type=str, default="m2") -@click.option("--mil_init", type=bool, default=True) -def main( - index_to_swap: int, - checkpoint=None, - output_model_path="/mnt/checkpoints/ssm/iterative_hybrids_15b_rkl_m2/apriel_ssm_thinker_15b_hybrid", - layer_type="m2", - mil_init=True, -): - print(f"index_to_swap: {index_to_swap}, checkpoint: {checkpoint}") - - layer_importance = [ - 47, - 39, - 24, - 36, - 31, - 43, - 32, - 20, - 38, - 37, - 30, - 33, - 22, - 23, - 40, - 42, - 44, - 35, - 41, - 27, - 21, - 46, - 45, - 49, - 25, - 34, - 29, - 28, - 19, - 26, - 18, - 17, - 16, - 13, - 15, - 14, - 8, - 9, - 12, - 6, - 11, - 5, - 48, - 7, - 10, - 3, - 4, - 1, - 0, - ] - path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" - config_base = AutoConfig.from_pretrained(path_base) - hybrid_block_layout = ["t"] * config_base.num_hidden_layers - - for i in range(index_to_swap + 1): - layer_idx = int(layer_importance[i]) - print(f"Swapping layer {layer_idx} to {layer_type}") - hybrid_block_layout[layer_idx] = layer_type - - transformer = AutoModelForCausalLM.from_pretrained(path_base) - model_hybrid_prev = 
AprielThinkerSSMHybridForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to( - torch.bfloat16 - ) - config_hybrid = AprielSSMHybridConfig(**model_hybrid_prev.config.to_dict()) - config_hybrid.hybrid_block_layout = hybrid_block_layout - convert_layers(transformer, config_hybrid, hybrid_block_layout, mil_init, torch.bfloat16) - - missing, unexpected = transformer.load_state_dict( - model_hybrid_prev.state_dict(), strict=False - ) # will not load the newly innitialized layer (will stay MIL), but will overwrite previous layers - if missing: - print("Missing keys:", missing) - if unexpected: - print("Unexpected keys:", unexpected) - transformer.to(torch.bfloat16) - model_hybrid_prev = None - print(transformer) - model_hybrid = AprielThinkerSSMHybridForCausalLM(config_hybrid) - missing, unexpected = model_hybrid.load_state_dict(transformer.state_dict()) - assert len(missing) == 0, "Missing keys: " + str(missing) - assert len(unexpected) == 0, "Unexpected keys: " + str(unexpected) - - model_hybrid.save_pretrained(f"{output_model_path}") - # config_hybrid.save_pretrained(f"{output_model_path}") - - -if __name__ == "__main__": - main() - # main( - # index_to_swap=1, - # checkpoint="/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-ihyb1lrklm216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000", - # layer_type="m2", - # ) diff --git a/fast_llm/models/ssm/external/nemotron/config.py b/fast_llm/models/ssm/external/nemotron/config.py new file mode 100644 index 00000000..058cd0cb --- /dev/null +++ b/fast_llm/models/ssm/external/nemotron/config.py @@ -0,0 +1,249 @@ +# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NemotronH model configuration""" + +import re + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NemotronHConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NemotronHModel`]. It is used to instantiate a + NemotronH model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the NemotronH-v0.1 model. + + [todo](todo) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the NemotronH model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`NemotronHModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. 
+ hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 21504): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 52): + Number of hidden layers in the Transformer encoder. + hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`): + The pattern of the hybrid model. The pattern is a string of characters where each character represents M: Mamba2, *: Attention, -: MLP + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + attention_head_dim (`int`, *optional*, defaults to 128): + Dimension of each attention head. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. + mlp_hidden_act (`str`, *optional*, defaults to "relu2"): + The non-linear activation function in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in attention layers. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in MLP layers. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the model. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + residual_in_fp32 (`bool`, *optional*, defaults to `False`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*, defaults to None): + Sliding window attention window size. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden states. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and + `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. + ssm_state_size (`int`, *optional*, defaults to 128): + The dimension of the mamba state space latents. 
+ mamba_num_heads (`int`, *optional*, defaults to 128): + Number of heads in Mamba layers. + mamba_n_groups (`int`, *optional*, defaults to 8): + Number of groups in Mamba layers. + mamba_head_dim (`int`, *optional*, defaults to 64): + Dimension of each Mamba head. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor used to determine the mamba intermediate size. + mamba_hidden_act (`str`, *optional*, defaults to "silu"): + The non-linear activation function in the Mamba layers. + mamba_dt_min (`float`, *optional*, defaults to 0.001): + Minimum value for the time step in Mamba. + mamba_dt_max (`float`, *optional*, defaults to 0.1): + Maximum value for the time step in Mamba. + mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))): + Limits for the time step in Mamba. + mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4): + Floor value for time step initialization in Mamba. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the input and output projections of the mamba mixer block. + mamba_chunk_size (`int`, *optional*, defaults to 256): + Size of chunks for Mamba processing. + rescale_prenorm_residual (`bool`, *optional*, defaults to `True`): + Whether to rescale the pre-normalization residual connections. + """ + + model_type = "nemotron_h" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=21504, + num_hidden_layers=52, + hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", + num_attention_heads=32, + attention_head_dim=128, + num_key_value_heads=8, # nemo: num_query_groups + mlp_hidden_act="relu2", + attention_bias=False, + mlp_bias=False, + use_bias=False, + initializer_range=0.02, # nemo: init_method_std + layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon + residual_in_fp32=False, # Megatron Core default value + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + max_position_embeddings=4096, + attention_dropout=0.0, + hidden_dropout=0.0, # * ADDED + use_mamba_kernels=True, + ssm_state_size=128, # mamba_state_size + mamba_num_heads=128, + mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads + mamba_head_dim=64, + mamba_d_conv=4, + mamba_expand=2, + mamba_hidden_act="silu", + mamba_dt_min=0.001, + mamba_dt_max=0.1, + mamba_dt_limit=(0.0, float("inf")), + mamba_dt_init_floor=1e-4, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_chunk_size=256, + rescale_prenorm_residual=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.hybrid_override_pattern = hybrid_override_pattern + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.sliding_window = sliding_window + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + + # Validate hybrid_override_pattern + # M: Mamba2, *: Attention, -: MLP + assert ( + len(self.hybrid_override_pattern) == self.num_hidden_layers + ), 
"hybrid_override_pattern must have the same length as num_hidden_layers" + assert re.match( + r"^[*-M]+$", self.hybrid_override_pattern + ), "hybrid_override_pattern must only contain characters 'M', '*', or '-'" + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_hidden_act = mlp_hidden_act + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.use_bias = use_bias + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.residual_in_fp32 = residual_in_fp32 + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.use_mamba_kernels = use_mamba_kernels + self.n_groups = mamba_n_groups + self.mamba_head_dim = mamba_head_dim + self.ssm_state_size = ssm_state_size + self.mamba_num_heads = mamba_num_heads + self.conv_kernel = mamba_d_conv + self.expand = mamba_expand + self.mamba_hidden_act = mamba_hidden_act + self.time_step_min = mamba_dt_min + self.time_step_max = mamba_dt_max + self.time_step_limit = mamba_dt_limit + self.time_step_floor = mamba_dt_init_floor + self.use_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.chunk_size = mamba_chunk_size + self.rescale_prenorm_residual = rescale_prenorm_residual + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return [ + ( + "mamba" + if self.hybrid_override_pattern[i] == "M" + else "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + ) + for i in range(self.num_hidden_layers) + ] diff --git a/fast_llm/models/ssm/external/nemotron/modeling.py b/fast_llm/models/ssm/external/nemotron/modeling.py new file mode 100644 index 00000000..154316c7 --- /dev/null +++ b/fast_llm/models/ssm/external/nemotron/modeling.py @@ -0,0 +1,1628 @@ +# Copyright 2024 HuggingFace Inc. team. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch NemotronH model.""" + +import math +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.cache_utils import DynamicCache # we need __iter__ and __len__ of pkv +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from transformers.utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_mamba_2_ssm_available, +) + +from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.mamba.modeling_mamba2.modeling_mamba2.py with MAMBA2->NEMOTRONH,Mamba2->NemotronH +# For Mamba2 components Mamba2->NemotronHMamba2 +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None + +try: + # from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn +except ImportError: + raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported") + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) +) + + +_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K" +_CONFIG_FOR_DOC = "NemotronHConfig" + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
+ """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_override_pattern + self.has_previous_state = False # only used by mamba + intermediate_size = config.expand * config.hidden_size + ssm_state_size = config.ssm_state_size + self.conv_kernel_size = conv_kernel_size = config.conv_kernel + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "M": + # Mamba layer + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + self.ssm_states += [ + torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed."""
+ # take any layer that contains cache and not empty tensor
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+ if len(self.key_cache) <= layer_idx:
+ return 0
+ return self.key_cache[layer_idx].shape[-2]
+
+ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+
+ @classmethod
+ def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+
+ # Copied from modeling_mamba2.py
+ def update_conv_state(
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
+ ) -> torch.Tensor:
+ if cache_init:
+ self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[0].device)
+ else:
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
+ return self.conv_states[layer_idx]
+
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[0].device)
+ return self.ssm_states[layer_idx]
+
+ def reset(self):
+ # conv_states and ssm_states are per-layer lists of tensors, so zero them layer by layer
+ for layer_idx in range(len(self.conv_states)):
+ self.conv_states[layer_idx].zero_()
+ self.ssm_states[layer_idx].zero_()
+
+
+class MambaRMSNormGated(torch.nn.Module):
+ def __init__(self, hidden_size, group_size, eps=1e-5):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.group_size = group_size
+
+ # jan28b version
+ def forward(self, hidden_states, gate=None):
+ return rmsnorm_fn(
+ x=hidden_states,
+ weight=self.weight,
+ bias=None, # No bias
+ z=gate,
+ eps=self.variance_epsilon,
+ group_size=self.group_size,
+ norm_before_gate=False,
+ )
+
+
+class NemotronHMamba2Mixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: NemotronHConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.ssm_state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.mamba_hidden_act + self.act = ACT2FN[config.mamba_hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + + self.n_groups = config.n_groups + self.head_dim = config.mamba_head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated( + self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. 
Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
+
+ # We need to guarantee that anything regarding the cache is on the same device
+ conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
+
+ hidden_states_B_C = torch.sum(
+ conv_states * self.conv1d.weight.squeeze(1), dim=-1
+ )
+ if self.use_conv_bias:
+ hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
+ hidden_states_B_C = self.act(hidden_states_B_C)
+ else:
+ # Init cache
+ if cache_params is not None:
+ hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
+ conv_states = nn.functional.pad(
+ hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
+ )
+ cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)
+
+ hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
+
+ hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
+ hidden_states, B, C = torch.split(
+ hidden_states_B_C,
+ [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
+ dim=-1
+ )
+
+ # 3. SSM transformation
+ A = -torch.exp(self.A_log.float()) # [num_heads]
+ if cache_params is not None and cache_position is not None and cache_position[0] > 0:
+ # We need to guarantee that anything regarding the cache is on the same device
+ cache_device = cache_params.ssm_states[self.layer_idx].device
+
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
+ # for batched generation
+ dt = dt[:, 0, :][:, None, ...]
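# Shape walk-through for this cached single-token branch: dt was sliced above to [batch, 1, num_heads]
# and is broadcast below to [batch, num_heads, head_dim]; after softplus and clamping it forms the
# per-head discretization dA = exp(dt * A) and dB = dt * B, the state advances as
# ssm_state <- ssm_state * dA + dB * x, and the output is read out as ssm_state contracted with C
# plus the D skip connection.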
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = (reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)) + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class NemotronHRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # Weights are in float32 + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + +class NemotronHBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + # M: Mamba2, *: Attention, -: MLP + self.block_type = config.layers_block_type[layer_idx] + if self.block_type == "mamba": + self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx) + elif self.block_type == "attention": + self.mixer = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + elif self.block_type == "mlp": + self.mixer = NemotronHMLP(config, layer_idx=layer_idx) + else: + raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}") + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: 
Optional[torch.Tensor] = None, + ): + with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): + # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + if self.block_type == "mamba": + hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position) + elif self.block_type == "attention": + hidden_states = self.mixer(hidden_states, cache_position=cache_position) + hidden_states = hidden_states[0] + elif self.block_type == "mlp": + hidden_states = self.mixer(hidden_states) + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + hidden_states = residual + hidden_states + return hidden_states + + +# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH +class NemotronHMLP(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.mlp_hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class NemotronHAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if config.attention_head_dim is not None: + self.head_dim = config.attention_head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias) + + def forward( + self, + hidden_states: torch.Tensor, + # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + # attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba +# class JambaFlashAttention2(JambaAttention): +class NemotronHFlashAttention2(NemotronHAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba +# class JambaSdpaAttention(JambaAttention): +class NemotronHSdpaAttention(NemotronHAttention): + """ + Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from NemotronHAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +NEMOTRONH_ATTENTION_CLASSES = { + "eager": NemotronHAttention, + "flash_attention_2": NemotronHFlashAttention2, + "sdpa": NemotronHSdpaAttention, +} + + +# Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel +class NemotronHPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = NemotronHConfig + base_model_prefix = "backbone" + _no_split_modules = ["NemotronHBlock"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, NemotronHMamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = torch.exp( + torch.rand(self.config.mamba_num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + # TODO: Check + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH +class NemotronHOutput(ModelOutput): + """ + Class for the NemotronH model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH +class NemotronHCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +NEMOTRONH_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +NEMOTRONH_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + cache_params (`HybridMambaAttentionDynamicCache`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the current input in the cache. This is used to ensure that the cache is correctly updated. + If `cache_params` is passed, `cache_position` should also be passed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) +""" + + +@add_start_docstrings( + "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.", + NEMOTRONH_START_DOCSTRING, +) +class NemotronHModel(NemotronHPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, NemotronHOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + # use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # From zamba_modeling.py + if use_cache and cache_params is None: + logger.warning_once( + "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." 
+ ) + + hidden_states = inputs_embeds + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + # Until HERE + + for layer_idx, mixer_block in enumerate(self.layers): + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + if mixer_block.block_type == "mamba": + layer_mask = mamba_mask + elif mixer_block.block_type == "attention": + layer_mask = causal_mask + elif mixer_block.block_type == "mlp": + layer_mask = None + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + + # TODO: Store attentions + # if output_attentions: + # if layer_outputs[1] is not None: + # # append attentions only of attention layers. Mamba layers return `None` as the attention weights + # all_self_attns += (layer_outputs[1],) + + # TODO (Check): should it happen before the forward pass? + # if output_hidden_states: + # all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return NemotronHOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and 
attention_mask.device.type == "cuda"
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    def _update_mamba_mask(self, attention_mask, cache_position):
+        """
+        No need for zeroing states when
+            1. Cached forward
+            2. Attending to all inputs
+        """
+        mamba_mask = attention_mask
+        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
+            mamba_mask = None
+        return mamba_mask
+
+
+@add_start_docstrings(
+    """
+    The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input
+    embeddings).
+    """,
+    NEMOTRONH_START_DOCSTRING,
+)
+class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.backbone = NemotronHModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.backbone.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings):
+        return self.backbone.set_input_embeddings(new_embeddings)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def get_decoder(self):
+        # There is no separate `self.model` attribute; the decoder is the hybrid backbone.
+        return self.backbone
+
+    def set_decoder(self, decoder):
+        self.backbone = decoder
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        **kwargs,
+    ):
+        # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
+        # Overwritten -- uses `cache_params` as opposed to `past_key_values`
+        empty_past_kv = past_key_values is None
+
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[tuple, NemotronHCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + cache_params = cache_params if cache_params is not None else kwargs["past_key_values"] + + nemotron_h_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = nemotron_h_outputs[0] + + # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2 + # logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + nemotron_h_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return NemotronHCausalLMOutput( + loss=loss, + logits=logits, + cache_params=nemotron_h_outputs.cache_params, + hidden_states=nemotron_h_outputs.hidden_states, + attentions=nemotron_h_outputs.attentions, + ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 29f115bd..fafe4409 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -6,6 +6,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.ssm.preprocessing import Mamba2Preprocessor from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel @@ -30,6 +31,7 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) + self._preprocessors.append(Mamba2Preprocessor(config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: """ diff --git a/setup.cfg b/setup.cfg index 6ea98610..c2eb1f6f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ HUGGINGFACE = # To install on cpu environment (ex. 
for IDE support):
 # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation
 SSM =
-    mamba_ssm[causal-conv1d]==2.2.4
+    mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba
     cartesia_pytorch>=0.0.2
 # GENERATION =
diff --git a/tests/test_ssms.py b/tests/test_ssms.py
index 694faa55..f9c7dc57 100644
--- a/tests/test_ssms.py
+++ b/tests/test_ssms.py
@@ -1,19 +1,60 @@
+import inspect
+import itertools
 import pathlib
+from functools import partial

 import pytest
 import torch
+from fast_llm.layers.ssm.mamba2 import Mamba2, NemotronHMamba2

 from fast_llm.config import NoAutoValidate
 from fast_llm.engine.checkpoint.config import CheckpointLoadConfig
+from fast_llm.engine.config_utils.tensor_space import TensorSpace
 from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.schedule.config import ScheduleConfig
 from fast_llm.engine.schedule.runner import ScheduleRunner
 from fast_llm.engine.schedule.schedule import Schedule
-from fast_llm.layers.transformer.config import TransformerKwargs
+from fast_llm.layers.ssm.config import SSMConfig
+from fast_llm.layers.ssm.llamba_block import SSMBlock
+from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs
 from fast_llm.models.gpt.config import GPTBatchConfig
-from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat
+from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat
 from fast_llm.models.ssm.model import HybridSSMModel

+_mamba_varlen = False
+try:
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # noqa
+
+    _mamba_available = True
+    sig = inspect.signature(selective_scan_fn)
+    if "position_indices" in sig.parameters:
+        _mamba_varlen = True
+    else:
+        _mamba_varlen = False
+    # For training with packing, install https://github.com/jxiw/varlen_mamba
+    # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md
+
+except (ImportError, RuntimeError):
+    _mamba_available = False
+
+
+def get_hybrid_config(hybrid_block_layout=["t", "m2"], prediction_heads=1, default_mtp_type=None):
+    hidden_size = 512
+    config = HybridSSMBaseModelConfig(
+        transformer=TransformerConfig(num_layers=len(hybrid_block_layout), hidden_size=hidden_size),
+        ssm=SSMConfig(d_xb=hidden_size, dt_rank=10, d_inner=hidden_size * 2, state_size=16, head_dim=8),
+        hybrid_block_layout=hybrid_block_layout,
+        prediction_heads=prediction_heads,
+        default_mtp_type=default_mtp_type,
+        init_method_std_embed=0.02,
+        init_method_min_embed=-0.02,
+        init_method_max_embed=0.02,
+        use_position_embeddings=True,
+        tie_word_embeddings=False,
+    )
+    return config
+

 @pytest.mark.skip("Disabled due to cartesia_pytorch installation issue")
 @pytest.mark.slow
@@ -80,3 +121,245 @@ def test_load_from_llamba_checkpoint():
     logits = input_data[0][1]["logits"].cpu()
     assert torch.allclose(logits, hf_logits, atol=1e-2)
+
+
+@pytest.fixture
+def distributed_config():
+    return DistributedConfig(
+        tensor_parallel=1,
+        pipeline_parallel=1,
+        sequence_data_parallel=1,
+        local_world_size=1,
+        world_size=1,
+    )
+
+
+@pytest.fixture
+def distributed(distributed_config):
+    return Distributed(config=distributed_config)
+
+
+def materialize_meta_tensors(model, tensor_space):
+    # Materialize parameters that are on meta device
+    for name, param in model.named_parameters():
+        if param.device.type == "meta":
+            # Check if the parameter is a custom
tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. 
for seq_len=10, packages=3 → [4,3,3]
+    lengths = [base + 1 if i < rem else base for i in range(packages_num)]
+
+    # split points exclude the final cumulative (seq_len)
+    split_points = list(itertools.accumulate(lengths))[:-1]
+
+    # cu_seqlens = [0] + split_points + [seq_len]
+    cu_seqlens = [0] + split_points + [seq_len]
+
+    # index: for each chunk, we emit 0,1,...,length-1
+    index = []
+    for length in lengths:
+        index.extend(range(length))
+
+    # sanity check
+    assert len(cu_seqlens) - 1 == packages_num
+    assert sum(lengths) == seq_len
+    assert len(index) == seq_len
+
+    return cu_seqlens, index
+
+
+# Quick and dirty test for the Mamba2 varlen block, adapted from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/tests/pack_mamba/test_mamba_layer.py
+# Checks that packed and unpacked inputs produce the same outputs and gradients.
+# TODO: integrate in the testing framework
+@pytest.mark.slow
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available")
+@pytest.mark.skipif(not _mamba_available, reason="Mamba2 is not available")
+@pytest.mark.parametrize(
+    "mixer_cls, hybrid_block_layout, tolerance",
+    [
+        pytest.param(
+            partial(NemotronHMamba2, block_index=0),
+            ["nm2", "t"],
+            1e-3,  # Not 100% sure why, but this mixer needs a looser tolerance (it may not fully support packing)
+            id="nemotron_hmamba2",
+        ),
+        pytest.param(
+            partial(Mamba2, block_index=0),
+            ["m2", "t"],
+            1e-4,
+            marks=pytest.mark.skipif(not _mamba_varlen, reason="Mamba2 varlen is not available"),
+            id="mamba2",
+        ),
+    ],
+)
+def test_mamba_varlen_block(mixer_cls, hybrid_block_layout, tolerance, distributed_config, distributed):
+    """
+    Check that the outputs and gradients of the packed and unpacked Mamba2 varlen block match.
+    """
+    hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout)
+    tensor_space = TensorSpace(distributed_config=distributed_config)
+    tensor_space.setup(distributed)
+    hybrid_config.setup_tensor_space(tensor_space)
+    layer_idx = 0
+
+    block_packed = SSMBlock(
+        hybrid_config.transformer,
+        hybrid_config.ssm,
+        mixer_cls=mixer_cls,
+        tensor_space=tensor_space,
+        block_index=layer_idx,
+    )
+    block_ref = SSMBlock(
+        hybrid_config.transformer,
+        hybrid_config.ssm,
+        mixer_cls=mixer_cls,
+        tensor_space=tensor_space,
+        block_index=layer_idx,
+    )
+    device = "cuda"
+    materialize_meta_tensors(block_packed, tensor_space)
+    materialize_meta_tensors(block_ref, tensor_space)
+    block_ref.load_state_dict(block_packed.state_dict())
+    block_packed.to(device)
+    block_ref.to(device)
+
+    batch_size = 2
+    seq_len = 64
+    packages_num = 2
+    hidden_dim = hybrid_config.transformer.hidden_size
+
+    cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num=packages_num)
+    cu_seqlens = torch.tensor(cu_seqlens).cuda()
+    ssm_position_ids = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda()
+    seq_idx = (
+        torch.cat(
+            [
+                torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
+                for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])
+            ],
+            dim=0,
+        )
+        .unsqueeze(0)
+        .repeat(batch_size, 1)
+    )
+
+    # Generate packed_hidden_states with random values for testing
+    hidden_states_list = [
+        torch.randn(length, hidden_dim, device=device, dtype=torch.bfloat16, requires_grad=True)
+        for length in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    ]
+    packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0)
+    packed_hidden_states = packed_hidden_states.expand(batch_size, -1, -1).contiguous()
+    # The unpacked reference hidden_states are forwarded without cu_seqlens
+    hidden_states = unpack(packed_hidden_states, cu_seqlens)
+
+    # Check: the total length of the items in hidden_states_list should equal the packed sequence length
+    assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1]
+    # Check: the longest item in hidden_states_list should equal the unpacked (padded) sequence length
+    assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1]
+
+    output_states_packed = block_packed(
+        packed_hidden_states,
+        {"cu_seqlens": cu_seqlens, "seq_idx": seq_idx, "ssm_position_ids": ssm_position_ids, "sequence_first": False},
+    )
+    output_states_unpacked = block_ref(
+        hidden_states.clone(), {"cu_seqlens": None, "seq_idx": None, "ssm_position_ids": None, "sequence_first": False}
+    )
+    assert output_states_packed.shape == packed_hidden_states.shape
+    assert output_states_unpacked.shape == hidden_states.shape
+    assert not torch.isnan(hidden_states).any()
+    assert not torch.isinf(hidden_states).any()
+
+    output_states_unpacked = pack(output_states_unpacked, cu_seqlens, batch_size)
+    assert torch.allclose(output_states_packed, output_states_unpacked, atol=tolerance)
+
+    loss = output_states_packed.sum()
+    loss.backward()
+    loss_ref = output_states_unpacked.sum()
+    loss_ref.backward()
+    assert torch.allclose(block_packed.mixer.conv1d_weight.grad, block_ref.mixer.conv1d_weight.grad, atol=tolerance)
+    assert torch.allclose(block_packed.mixer.conv1d_bias.grad, block_ref.mixer.conv1d_bias.grad, atol=tolerance)
+    assert torch.allclose(
+        block_packed.mixer.in_proj.weight.grad_buffer, block_ref.mixer.in_proj.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mixer.out_proj.weight.grad_buffer, block_ref.mixer.out_proj.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mixer.dt_in_proj.weight.grad_buffer,
+        block_ref.mixer.dt_in_proj.weight.grad_buffer,
+        atol=tolerance,
+    )
+
+    assert torch.allclose(
+        block_packed.mlp.layer_1.weight.grad_buffer, block_ref.mlp.layer_1.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mlp.layer_1.bias.grad_buffer, block_ref.mlp.layer_1.bias.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mlp.layer_2.weight.grad_buffer, block_ref.mlp.layer_2.weight.grad_buffer, atol=tolerance
+    )
+    assert torch.allclose(
+        block_packed.mlp.layer_2.bias.grad_buffer, block_ref.mlp.layer_2.bias.grad_buffer, atol=tolerance
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
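# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): a minimal, self-contained example of
# the packing metadata that test_mamba_varlen_block feeds to the SSM block.
# `cu_seqlens` marks document boundaries inside the packed sequence, `seq_idx`
# labels every token with its document index, and `ssm_position_ids` restarts
# positions at each boundary. Only torch and itertools are assumed; the
# sequence lengths below are made up for illustration.
import itertools

import torch

seq_lens = [3, 5, 2]  # three documents packed into one sequence of 10 tokens
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seq_lens)), dtype=torch.int32)
# cu_seqlens -> tensor([ 0,  3,  8, 10], dtype=torch.int32)

seq_idx = torch.cat(
    [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(seq_lens)]
).unsqueeze(0)
# seq_idx -> tensor([[0, 0, 0, 1, 1, 1, 1, 1, 2, 2]], dtype=torch.int32)

ssm_position_ids = torch.cat(
    [torch.arange(n, dtype=torch.int32) for n in seq_lens]
).unsqueeze(0)
# ssm_position_ids -> tensor([[0, 1, 2, 0, 1, 2, 3, 4, 0, 1]], dtype=torch.int32)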
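# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): rough end-to-end usage of the
# NemotronHForCausalLM class defined above. The module path
# `modeling_nemotron_h`, the checkpoint directory, and the prompt are
# placeholders/assumptions, not something this PR ships; adjust them to the
# actual package layout and an exported checkpoint.
import torch
from transformers import AutoTokenizer

from modeling_nemotron_h import NemotronHForCausalLM  # assumed module name

checkpoint = "path/to/exported-nemotron-h-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = NemotronHForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to("cuda").eval()

inputs = tokenizer("Packed variable-length training lets Mamba", return_tensors="pt").to("cuda")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(generated[0], skip_special_tokens=True))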