2 changes: 1 addition & 1 deletion Megatron-LM
2 changes: 2 additions & 0 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -191,6 +191,8 @@ def initialize_weights(self) -> None:
# Initialize all global weights on every gpu, then select the appropriate slice if applicable.
global_param = parameter.new_empty(global_shape, device=self._distributed.device)
meta.init_parameter(global_param, distributed=self._distributed)
# It happens: an initializer can hand back a tensor with an unexpected shape.
Assert.eq(global_param.shape, global_shape)
if self._mode.on_device:
parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name))
elif self._mode.on_device:
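For context, the stage initializes each weight at its global shape on every GPU and only then slices out the local shard; the new assertion fails fast if the initializer changed the shape. A minimal sketch of that pattern in plain torch, with a hypothetical slicing step standing in for fsdp.parameter_global_to_shard:

```python
import torch

def init_then_shard(global_shape, init_fn, shard_slice):
    """Initialize a weight at its full (global) shape, then keep the local shard."""
    global_param = torch.empty(global_shape)
    global_param = init_fn(global_param)
    # An init function may hand back a tensor it built itself; fail fast if its
    # shape no longer matches what the stage expects (the role of Assert.eq above).
    assert global_param.shape == torch.Size(global_shape), (global_param.shape, global_shape)
    return global_param[shard_slice].contiguous()

# The rank owning rows 256..511 of a (1024, 512) weight:
shard = init_then_shard((1024, 512), torch.nn.init.xavier_normal_, slice(256, 512))
print(shard.shape)  # torch.Size([256, 512])
```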
16 changes: 10 additions & 6 deletions fast_llm/layers/language_model/head.py
@@ -125,12 +125,16 @@ def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
if isinstance(input_, TensorMeta):
return TensorMeta.from_tensor_space(
(DefaultDimNames.scalar,),
self._tensor_space,
tensor_name="Loss",
reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa
)
if self._is_last_head:
return TensorMeta.from_tensor_space(
(DefaultDimNames.scalar,),
self._tensor_space,
tensor_name="Loss",
reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa
)
else:
return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden")

if not self._is_last_head:
# MTP: split the stacked input
shared_hidden, input_ = torch.unbind(input_, dim=0)
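In context, the new branch distinguishes multi-token-prediction heads: the last head returns the averaged "Loss" meta, while earlier heads return a "Shared hidden" meta and, at runtime, unbind the stacked input. A minimal sketch of the stack/unbind convention (shapes and variable names are illustrative assumptions, not the Fast-LLM API):

```python
import torch

batch, seq, hidden = 2, 8, 16
shared_hidden = torch.randn(batch, seq, hidden)   # passed on to the next head
head_input = torch.randn(batch, seq, hidden)      # what this head predicts from

# Upstream layers stack the two tensors along a new leading dimension ...
stacked = torch.stack([shared_hidden, head_input], dim=0)  # (2, batch, seq, hidden)

# ... and every non-last head splits them back apart, as in forward() above:
shared_hidden, input_ = torch.unbind(stacked, dim=0)
assert input_.shape == (batch, seq, hidden)
```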
34 changes: 17 additions & 17 deletions fast_llm/layers/ssm/config.py
@@ -23,6 +23,7 @@ class SSMDimNames:

# Mamba 2
x_proj_dim_2 = "x_proj_dim_2" # d_xb
c_heads = "c_heads"


class SSMBlockType(enum.StrEnum):
@@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum):
mamba2 = "m2"
transformer = "t"

def get_mixer_class(self):
if self == SSMBlockType.mamba:
from fast_llm.layers.ssm.mamba_layer import MambaLayer

return MambaLayer
elif self == SSMBlockType.mamba2:
from fast_llm.layers.ssm.mamba2 import Mamba2

return Mamba2
elif self == SSMBlockType.mamba2_discrete:
from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2

return DiscreteMamba2
else:
raise NotImplementedError(self)
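
A hedged usage sketch of the new helper: callers resolve the mixer implementation from the configured block type instead of branching on it themselves (the surrounding setup is assumed, not part of this diff):

```python
from fast_llm.layers.ssm.config import SSMBlockType

block_type = SSMBlockType.mamba2_discrete
mixer_cls = block_type.get_mixer_class()  # DiscreteMamba2, imported lazily
print(mixer_cls.__name__)                 # "DiscreteMamba2"

# An unhandled value raises NotImplementedError(self), so new block types
# fail loudly until a mixer class is wired up for them.
```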


@config_class()
class SSMConfig(LLMBlockConfig):
@@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig):
desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
hint=FieldHint.architecture,
)
debug_ssm: bool = Field(
default=False,
desc="debug_ssm",
hint=FieldHint.optional,
)
dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
@@ -147,18 +159,6 @@
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_scale: float = Field(
default=1.0,
desc="Scale for dt",
27 changes: 13 additions & 14 deletions fast_llm/layers/ssm/discrete_mamba2.py
@@ -1,13 +1,15 @@
import logging
import math
import typing

import einops
import torch

from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace
from fast_llm.layers.common.linear import Linear
from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames
from fast_llm.layers.transformer.config import TransformerKwargs
from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs
from fast_llm.layers.transformer.transformer import Mixer
from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_
from fast_llm.utils import get_lr_scale

@@ -36,29 +38,29 @@ def bias_init_method(conv_weight):
return init_uniform_(-bound, bound)


class DiscreteMamba2(torch.nn.Module):
class DiscreteMamba2(Mixer):
"""DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py)."""

_mixer_name: typing.ClassVar[str] = "discrete_mamba_2"

def __init__(
self,
config: SSMConfig,
layer_idx: int,
block_index: int,
tensor_space: TensorSpace,
return_input: bool = False,
transformer_config: TransformerConfig,
):
"""
See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args.
Other options are all experimental and should not need to be configured.
"""
# factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16}
super().__init__()
super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer)
self.config: SSMConfig = config
bias = config.add_bias_linear
self.layer_idx = layer_idx
self._return_input = return_input
layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None
layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None
mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale)
logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}")
logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}")

td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim)
td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim)
@@ -101,7 +103,7 @@ def __init__(
)

self.conv1d_weight = ParameterMeta.from_dims(
(td_conv, TensorDim("1", 1), td_conv_kernel),
(td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel),
init_method=init_uniform_(
1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
@@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs):
out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias))
outputs["hidden_states"] = out[:, :seqlen, :].contiguous()

if self._return_input:
return torch.stack([input_, outputs["hidden_states"]], dim=0)

# TODO: inference is not supported for now, so we only return the hidden states.
return outputs["hidden_states"], None

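In the conv1d weight hunk above, the only change is that the ad-hoc size-1 dimension is replaced by the shared scalar dim; the initialization itself is untouched. A small sketch of what that init computes, in plain torch, with the depthwise reading of the (channels, 1, kernel) shape inferred from the shape rather than stated in the hunk:

```python
import math
import torch

def conv1d_weight_init(conv_dim: int, kernel_size: int) -> torch.Tensor:
    # Same bound as in the diff above: 1 / sqrt(conv_dim * kernel_size),
    # applied symmetrically, so the weight is U(-b, b) with b tied to fan-in.
    bound = 1.0 / math.sqrt(conv_dim * kernel_size)
    # Shape matches (td_conv, scalar, td_conv_kernel): one input channel per
    # group, i.e. a depthwise-style conv weight.
    weight = torch.empty(conv_dim, 1, kernel_size)
    return torch.nn.init.uniform_(weight, -bound, bound)

w = conv1d_weight_init(conv_dim=1536, kernel_size=4)
print(bool(w.abs().max() <= 1.0 / math.sqrt(1536 * 4)))  # True
```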
29 changes: 16 additions & 13 deletions fast_llm/layers/ssm/llamba_block.py
@@ -1,34 +1,37 @@
import typing

from fast_llm.layers.transformer.transformer import BaseBlock
from fast_llm.layers.transformer.transformer import BaseBlock, Mixer

if typing.TYPE_CHECKING:
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.layers.ssm.config import SSMConfig
from fast_llm.layers.transformer.config import TransformerConfig


class LlambaBlock(BaseBlock):
class SSMBlock(BaseBlock):
"""
A transformer-like decoder block with an SSM mixer, see https://arxiv.org/abs/2502.14458
"""

_name = "Llamba block"
_mixer_module_name = "mixer"

def __init__(
self,
config_transformer: "TransformerConfig",
config_ssm: "SSMConfig",
transformer_config: "TransformerConfig",
ssm_config: "SSMConfig",
tensor_space: "TensorSpace",
mixer_cls,
layer_index: int,
mixer_cls: type[Mixer],
block_index: int,
return_input: bool = False,
):
self.mixer_cls = mixer_cls
self._config_ssm = config_ssm
self._debug_mode = self._config_ssm.debug_ssm
super().__init__(config_transformer, tensor_space, layer_index, return_input)
self._ssm_config = ssm_config
self._mixer_cls = mixer_cls
super().__init__(transformer_config, tensor_space, block_index, return_input)

def _create_mixer(self):
self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space)
def _create_mixer(self) -> Mixer:
return self._mixer_cls(
self._ssm_config,
tensor_space=self._tensor_space,
block_index=self._block_index,
transformer_config=self._config,
)
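
The net effect of the refactor is that `_create_mixer` now returns the mixer rather than assigning `self.mixer` itself, presumably leaving attribute ownership to `BaseBlock` (the base-class side is outside this diff). A toy sketch of that pattern, with placeholder classes rather than the real Fast-LLM ones:

```python
class Mixer:
    def __init__(self, config, tensor_space, block_index, transformer_config):
        self.block_index = block_index

class BaseBlock:
    def __init__(self):
        # The base class owns the attribute; subclasses only build the object.
        self.mixer = self._create_mixer()

    def _create_mixer(self) -> "Mixer":
        raise NotImplementedError

class Block(BaseBlock):
    def __init__(self, ssm_config, tensor_space, mixer_cls, block_index, transformer_config):
        self._ssm_config = ssm_config
        self._tensor_space = tensor_space
        self._mixer_cls = mixer_cls
        self._block_index = block_index
        self._config = transformer_config
        super().__init__()

    def _create_mixer(self) -> "Mixer":
        # Return the mixer; storing it is the base class's job.
        return self._mixer_cls(
            self._ssm_config,
            tensor_space=self._tensor_space,
            block_index=self._block_index,
            transformer_config=self._config,
        )

block = Block(ssm_config=None, tensor_space=None, mixer_cls=Mixer, block_index=3, transformer_config=None)
print(block.mixer.block_index)  # 3
```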