Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ ENV PIP_CONSTRAINT=""
# There is no pre-built mamba image for pytorch 2.8, so we build it before the rest to avoid rebuilds.
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
# Using varlen_mamba for variable length sequence support
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
Expand Down
26 changes: 26 additions & 0 deletions fast_llm/layers/common/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_
from fast_llm.utils import Assert

try:
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn as mamba_rmsnorm_fn

_mamba_ssm_available = True
except ImportError:
_mamba_ssm_available = False

try:
import fused_layer_norm_cuda # noqa

Expand Down Expand Up @@ -288,3 +295,22 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor:

def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps)


class MambaRMSNormGated(RMSNorm):
    """
    Gated RMS normalization that delegates to `mamba_ssm`'s Triton `rmsnorm_fn` kernel.

    The kernel is always called without a bias and with `norm_before_gate=False`.
    NOTE(review): in `mamba_ssm`, `norm_before_gate=False` is understood to apply the gate
    before normalizing (`norm(x * silu(z))`) -- confirm against the installed version.
    """

    def __init__(self, hidden_dim: TensorDim, group_size: int, eps: float = 1e-5, lr_scale: float | None = None):
        """
        Args:
            hidden_dim: Dimension being normalized; forwarded to `RMSNorm.__init__`.
            group_size: Forwarded to the kernel as `group_size` on every call.
            eps: Forwarded to `RMSNorm.__init__`.
            lr_scale: Forwarded to `RMSNorm.__init__`.

        Raises:
            AssertionError: If the `mamba_ssm` import at the top of the module failed.
        """
        # Fail early with an actionable message instead of an empty AssertionError
        # (or a later NameError on `mamba_rmsnorm_fn`).
        assert _mamba_ssm_available, (
            "MambaRMSNormGated requires `mamba_ssm` "
            "(could not import `rmsnorm_fn` from `mamba_ssm.ops.triton.layernorm_gated`)."
        )
        super().__init__(hidden_dim, eps=eps, lr_scale=lr_scale)
        self.group_size = group_size
        # Overrides whatever implementation the parent constructor selected.
        self._forward = mamba_rmsnorm_fn

    def forward(self, input_: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        """
        Normalize `input_` with the Triton kernel, passing `gate` as the kernel's `z` argument.

        Args:
            input_: Tensor to normalize.
            gate: Optional gate tensor; `None` is passed through to the kernel unchanged.

        Returns:
            The kernel output.
        """
        return mamba_rmsnorm_fn(
            x=input_,
            weight=self.weight,
            bias=None,  # No bias
            z=gate,
            eps=self._eps,
            group_size=self.group_size,
            norm_before_gate=False,
        )
92 changes: 89 additions & 3 deletions fast_llm/layers/ssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,33 @@
from fast_llm.tensor import Initializer


class BaseSSMKwargs:
    """
    Base for namespaces of SSM kwarg key names.

    Every subclass gets one class attribute per entry of `_kwargs_attributes`, whose value
    is the entry's key name, prefixed with `<prefix>_` when a non-empty `prefix` is given
    in the class statement (`class Foo(BaseSSMKwargs, prefix="foo")`).

    Notes:
        * The attributes are only set on subclasses, not on `BaseSSMKwargs` itself.
        * The prefix is not inherited: a subclass declared without `prefix` gets the
          unprefixed names, even if its parent was declared with one.
    """

    # Maps the class attribute name to the base (unprefixed) key name.
    _kwargs_attributes = {
        "cu_seqlens": "cu_seqlens",
        "seq_idx": "seq_idx",
        "ssm_position_ids": "ssm_position_ids",
    }

    _prefix = ""

    def __init_subclass__(cls, prefix: str = "", **kwargs) -> None:
        """
        Populate the key-name attributes of a newly created subclass.

        Args:
            prefix: Optional prefix for the key names; empty means no prefix.
            **kwargs: Forwarded to the next `__init_subclass__` in the MRO.
        """
        super().__init_subclass__(**kwargs)
        cls._prefix = prefix
        # Set the attribute named by the mapping's key (the original used the value as
        # the attribute name, which only works while every key equals its value).
        for attr, key in BaseSSMKwargs._kwargs_attributes.items():
            setattr(cls, attr, f"{prefix}_{key}" if prefix else key)


class SSMKwargs(BaseSSMKwargs, prefix=""):
    """Kwarg key names for SSM layers, declared with an empty prefix (names are left unprefixed)."""


class SSMDimNames:
# TODO: Use separate tensor space for different mixers so there is no risk of name conflict.
state = "ssm_state" # State dimension (N), aka head size / num channels
head_dim = "ssm_head_dim"
head_groups = "ssm_head_groups"
group_heads = "ssm_group_heads"
conv1d_dim = "ssm_conv1d_dim"

# Mamba 2
x_proj_dim_2 = "x_proj_dim_2" # d_xb
Expand All @@ -28,7 +49,10 @@ class SSMDimNames:
# Composite dimensions
composite_heads = "ssm_composite_heads"
composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim"
composite_heads_and_head_dim_nontp = "ssm_composite_heads_and_head_dim_nontp"
composite_heads_and_state_dim = "ssm_composite_heads_and_state_dim"
composite_head_groups_and_state = "ssm_composite_head_groups_and_state"
composite_head_groups_and_head = "ssm_composite_head_groups_and_head"

# Concatenated dimensions
concatenated_convolution = "ssm_concatenated_convolution"
Expand All @@ -45,6 +69,7 @@ class SSMBlockType(enum.StrEnum):
mamba2_discrete = "m2d"
mamba2 = "m2"
transformer = "t"
nemotron_h_mamba2 = "nm2"

def get_mixer_class(self):
if self == SSMBlockType.mamba:
Expand All @@ -59,6 +84,10 @@ def get_mixer_class(self):
from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2

return DiscreteMamba2
elif self == SSMBlockType.nemotron_h_mamba2:
from fast_llm.layers.ssm.mamba2 import NemotronHMamba2

return NemotronHMamba2
else:
raise NotImplementedError(self)

Expand Down Expand Up @@ -206,6 +235,21 @@ class SSMConfig(LLMBlockConfig):
valid=check_field(Assert.gt, 0),
)

# Nemotron H Mamba2 (the real mamba2 actually)
# here instead of setting d_inner, we set head dim. and number of heads
# Note: we do not implement n_groups for Mamba2 because, since we do MiL init, we do not want to share B and C parameters across heads.
# Instead, we mimic the GQA behaviour (x -> v, B -> k, C -> q), where x and B are shared across heads. So this is the same as having n_groups = n_heads?
# n_groups: int = Field(
# default=8,
# desc="Number of groups for Mamba2. Allows sharing B and C parameters accross heads.",
# hint=FieldHint.architecture,
# )
head_dim: int = Field(
default=64,
desc="Head dimension for Nemotron H",
hint=FieldHint.architecture,
)

def _validate(self) -> None:
with self._set_implicit_default():
if self.activation_type is None:
Expand All @@ -223,6 +267,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
elif block_type == SSMBlockType.mamba2:
num_heads = div(self.d_inner, self.state_size)
num_head_groups = div(self.d_xb, self.state_size)
elif block_type == SSMBlockType.nemotron_h_mamba2:
# head dim and state size are not the same
num_heads = div(self.d_inner, self.head_dim)
num_head_groups = div(self.d_xb, self.head_dim)
elif block_type == SSMBlockType.mamba2_discrete:
# TODO: Use different variables?
num_heads = self.n_v_heads
Expand All @@ -233,6 +281,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size))
if block_type == SSMBlockType.mamba2_discrete:
tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads)))
elif block_type == SSMBlockType.nemotron_h_mamba2:
tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, self.head_dim))
else:
head_dim = state

Expand All @@ -241,14 +291,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
tensor_space.add_tensor_dim(
heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))
)
# full d_inner or intermediate_size (e.g. for z gate, also the d_inner size for C in mamba2)
tensor_space.add_tensor_dim(
heads_and_head_dim := CompositeTensorDim(
SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim)
)
)
# d_xb
tensor_space.add_tensor_dim(
head_groups_and_state := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_state, (head_groups, state)
head_groups_and_head := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_head, (head_groups, head_dim)
)
)
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension))
Expand All @@ -272,7 +324,41 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection,
(heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim),
(heads_and_head_dim, head_groups_and_head, head_groups_and_head, heads_and_head_dim),
)
)
elif block_type == SSMBlockType.nemotron_h_mamba2:
# for the norm
tensor_space.add_tensor_dim(
TensorDim(
SSMDimNames.composite_heads_and_head_dim_nontp, num_head_groups * group_heads.size * head_dim.size
)
)
# state and head dim are not the same
# C: for each head, size of state
tensor_space.add_tensor_dim(
heads_and_state_dim := CompositeTensorDim(
SSMDimNames.composite_heads_and_state_dim, (head_groups, group_heads, state)
)
)
# B: for each head group, size of state
tensor_space.add_tensor_dim(
head_groups_and_state := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_state, (head_groups, state)
)
)
# here we apply depthwise conv. layer to xBC, so the dim. is x (d_xb) x B (d_bb) x C
tensor_space.add_tensor_dim(
conv1d_dim := ConcatenatedTensorDim(
SSMDimNames.conv1d_dim, (heads_and_state_dim, head_groups_and_head, head_groups_and_state)
)
)

# inner projection dimension: also includes z (gate), which has size d_inner (heads_and_head_dim)
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection,
(conv1d_dim, heads_and_head_dim),
)
)
elif block_type == SSMBlockType.mamba2_discrete:
Expand Down
Loading