From 02f8af5e5ce9189ded97a83ed5c90b84d18a5ec3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 13:39:47 -0400 Subject: [PATCH 01/28] Block interface --- fast_llm/layers/block/__init__.py | 0 .../transformer.py => block/block.py} | 99 +--- fast_llm/layers/block/config.py | 120 ++++ fast_llm/layers/block/mixer.py | 68 +++ fast_llm/layers/block/mlp/__init__.py | 0 fast_llm/layers/block/mlp/config.py | 171 ++++++ .../mlp}/mixture_of_experts.py | 46 +- .../layers/{transformer => block/mlp}/mlp.py | 32 +- fast_llm/layers/block/peft.py | 128 +++++ fast_llm/layers/common/config.py | 12 - fast_llm/layers/language_model/config.py | 6 +- fast_llm/layers/language_model/embedding.py | 5 +- fast_llm/layers/language_model/head.py | 37 +- .../layers/language_model/preprocessing.py | 21 +- .../layers/ssm/{llamba_block.py => block.py} | 24 +- fast_llm/layers/ssm/config.py | 14 +- fast_llm/layers/ssm/discrete_mamba2.py | 29 +- fast_llm/layers/ssm/mamba2.py | 30 +- fast_llm/layers/ssm/mamba_layer.py | 16 +- fast_llm/layers/transformer/attention.py | 10 +- fast_llm/layers/transformer/block.py | 23 + fast_llm/layers/transformer/config.py | 542 +++--------------- fast_llm/models/gpt/conversion.py | 3 +- fast_llm/models/gpt/model.py | 14 +- tests/test_mlp.py | 14 +- 25 files changed, 749 insertions(+), 715 deletions(-) create mode 100644 fast_llm/layers/block/__init__.py rename fast_llm/layers/{transformer/transformer.py => block/block.py} (60%) create mode 100644 fast_llm/layers/block/config.py create mode 100644 fast_llm/layers/block/mixer.py create mode 100644 fast_llm/layers/block/mlp/__init__.py create mode 100644 fast_llm/layers/block/mlp/config.py rename fast_llm/layers/{transformer => block/mlp}/mixture_of_experts.py (89%) rename fast_llm/layers/{transformer => block/mlp}/mlp.py (77%) create mode 100644 fast_llm/layers/block/peft.py rename fast_llm/layers/ssm/{llamba_block.py => block.py} (52%) create mode 100644 fast_llm/layers/transformer/block.py diff 
--git a/fast_llm/layers/block/__init__.py b/fast_llm/layers/block/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/block/block.py similarity index 60% rename from fast_llm/layers/transformer/transformer.py rename to fast_llm/layers/block/block.py index 75d06f268..85da61c01 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/block/block.py @@ -1,83 +1,22 @@ import abc -import logging import typing import torch +from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.mlp import MLP +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert -logger = logging.getLogger(__name__) - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. 
- """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. - """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim - for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - - -class BaseBlock(Layer, abc.ABC): +class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): """ A transformer-like decoder base block with abstract mixer. 
""" @@ -85,11 +24,9 @@ class BaseBlock(Layer, abc.ABC): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False - ): + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__() - self._config: TransformerConfig = config + self._config = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout # For multi-token prediction, return a stack of shared_hidden and transformer_output. @@ -97,7 +34,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) @@ -131,7 +68,7 @@ def name(self) -> str: return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] + dims = kwargs[BlockKwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) @@ -196,19 +133,3 @@ def forward( if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - - -class TransformerBlock(BaseBlock): - _name = "Transformer layer" - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "self_attn" - - def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False - ): - super().__init__(config, tensor_space, block_index, return_input) - - def 
_create_mixer(self) -> Mixer: - from fast_llm.layers.transformer.attention import Attention - - return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py new file mode 100644 index 000000000..2f26d8d79 --- /dev/null +++ b/fast_llm/layers/block/config.py @@ -0,0 +1,120 @@ +import enum + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.block.mlp.config import MLPConfig +from fast_llm.layers.block.peft_config import TransformerPeftConfig +from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.utils import Assert + + +class BlockDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "batch" + # TODO: Distinguish micro-sequence? + sequence_q = "sequence_q" + sequence_q_tp = "sequence_q_tp" + sequence_k = "sequence_k" + hidden = "hidden" + + +class BlockKwargs: + sequence_first = "sequence_first" + hidden_dims = "hidden_dims" + sequence_q_dim = "sequence_q_dim" + sequence_k_dim = "sequence_k_dim" + sequence_length = "sequence_length" + # TODO: Belongs elsewhere? 
+ grad_output = "grad_output" + + +@config_class() +# TODO: Use composition for MLP config +class BlockConfig(MLPConfig, BaseModelConfig): + + # TODO: Review names + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layers architecture.", + hint=FieldHint.architecture, + ) + peft: TransformerPeftConfig = Field( + desc="Configuration for the parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) + hidden_dropout: float = Field( + default=0.0, + desc="Dropout applied to the residual connections.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + debug_transformer: int = Field( + default=0, + desc="Log the output of each operation in a transformer layer.", + hint=FieldHint.logging, + valid=check_field(Assert.geq, 0), + ) + debug_transformer_memory: bool = Field( + default=False, + desc="Log the memory usage after each operation in a transformer layer..", + hint=FieldHint.logging, + ) + add_linear_biases: bool | AddLinearBiasChoices = Field( + default=True, + desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + hint=FieldHint.architecture, + ) + + # TODO: Move these, not specific to a single block. 
+ num_layers: int = Field( + default=12, + desc="Number of layers in the transformer.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + hidden_size: int = Field( + default=1024, + desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + super()._validate() + + @property + def add_mlp_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + super().setup_tensor_space(tensor_space) + + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) + + +class AddLinearBiasChoices(str, enum.Enum): + nowhere = "nowhere" + everywhere = "everywhere" + only_attn_qkv = "only_attn_qkv" diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py new file mode 100644 index 000000000..5c811e330 --- /dev/null +++ b/fast_llm/layers/block/mixer.py @@ -0,0 +1,68 @@ +import abc +import typing + +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.logging import log_distributed_grad, log_distributed_tensor +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert + + +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. 
+ """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) diff --git a/fast_llm/layers/block/mlp/__init__.py b/fast_llm/layers/block/mlp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py new file mode 100644 index 
000000000..63e31219a --- /dev/null +++ b/fast_llm/layers/block/mlp/config.py @@ -0,0 +1,171 @@ +import enum + +from fast_llm.config import Config, Field, FieldHint, check_field, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.utils import Assert + + +class MLPDimNames: + # MLP dimensions + mlp = "mlp" + gate_and_up = "gate_and_up" + composite_gated_mlp = "composite_gated_mlp" + experts = "experts" + top_experts = "top_experts" + shared_experts = "shared_experts" + unshared_experts = "unshared_experts" + composite_expert_mlp = "composite_expert_mlp" + composite_gated_expert_mlp = "composite_gated_expert_mlp" + composite_shared_expert_mlp = "composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" + + +class MLPLossNames: + load_balancing_loss = "load_balancing_loss" + router_z_loss = "router_z_loss" + + +class RoutingType(str, enum.Enum): + topk = "aux_loss" + sinkhorn = "sinkhorn" + + +class MLPConfig(Config): + # TODO: Review names + _abstract = False + ffn_hidden_size: int = Field( + default=None, + desc="Hidden dimension of the MLP intermediate state. 
Default: 4 * hidden_size.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + num_experts: int = Field( + default=1, + desc="Number of MLP experts in a Mixture of Expert (MoE) model", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + num_shared_experts: int = Field( + default=0, + desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_unshared_experts: int = Field( + init=False, + desc="Number of MLP experts excluding shared ones", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_experts_per_token: int = Field( + default=1, + desc="Active experts for each token in a MoE model.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + expert_routing_type: RoutingType = Field( + default=RoutingType.topk, + desc="The routing method, i.e., the method used to assign experts to tokens.", + hint=FieldHint.architecture, + ) + gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) + # Default: hidden_size**-0.5 + # TODO: Allow custom initialization (InitializationConfig?) + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto + mlp_recompute_level: MLPRecomputeLevel = Field( + default=MLPRecomputeLevel.none, + desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. 
This provides a trade-off between memory and speed.", + hint=FieldHint.performance, + ) + expert_auxiliary_loss_coefficient: float = Field( + default=0.01, + desc="Scale of the load balancing auxiliary loss for topk routing.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + expert_z_loss_coefficient: float = Field( + default=0.0, + desc="Regularize the router during training by applying Z-loss to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + moe_jitter_eps: float = Field( + default=0.0, + desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + router_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate for the MoE router weight.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + dropless_moe: bool = Field( + default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + ) + dropless_dynamic_shape: bool = Field( + default=False, + desc="Use a dynamic shape for dropless MLP instead of the worst-case value." + " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. 
Not recommended.", + hint=FieldHint.expert, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + self.num_unshared_experts = self.num_experts - self.num_shared_experts + + super()._validate() + + Assert.leq(self.num_shared_experts, self.num_experts) + Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # MLP dimensions + tensor_space.add_tensor_dim(mlp := TensorDim(MLPDimNames.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(gate_and_up := TensorDim(MLPDimNames.gate_and_up, 2 if self.gated else 1)) + tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(MLPDimNames.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim( + CompositeTensorDim(MLPDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(MLPDimNames.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(MLPDimNames.unshared_experts, self.num_unshared_experts)) + + # shared_experts + if self.num_shared_experts: + tensor_space.add_tensor_dim( + shared_experts := TensorDim(MLPDimNames.shared_experts, self.num_shared_experts) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(MLPDimNames.composite_shared_expert_mlp, (shared_experts, 
mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(MLPDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp)) + ) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py similarity index 89% rename from fast_llm/layers/transformer/mixture_of_experts.py rename to fast_llm/layers/block/mlp/mixture_of_experts.py index 4fd2844d5..8d092b6dc 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -9,16 +9,11 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) -from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ from fast_llm.utils import Assert, get_lr_scale @@ -26,7 +21,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP(MLPBase): +class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -40,12 +35,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: 
str = "mlp", block_index: int = 0): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." super().__init__(config, tensor_space, name, block_index) - self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -63,8 +57,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space[TransformerDimNames.hidden], - tensor_space[TransformerDimNames.unshared_experts], + tensor_space[BlockDimNames.hidden], + tensor_space[MLPDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -86,7 +80,7 @@ def forward( hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug_mode: - self._debug_log(logits, "Router logits", TransformerDimNames.experts, kwargs) + self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) # Apply z_loss if applicable if self._z_loss_factor > 0.0: @@ -96,7 +90,7 @@ def forward( self.training, grad_scale=kwargs.get("grad_output"), losses=losses, - loss_name=TransformerLossNames.router_z_loss, + loss_name=MLPLossNames.router_z_loss, ) # Apply input_jitter if applicable: @@ -106,7 +100,7 @@ def forward( # Routing if self._routing_type == RoutingType.topk: - scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses) + scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) if self._num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) elif self._routing_type == RoutingType.sinkhorn: @@ -116,8 +110,8 
@@ def forward( if self._debug_mode: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", TransformerDimNames.top_experts, kwargs) - self._debug_log(top_experts, "Router top experts", TransformerDimNames.top_experts, kwargs) + self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) + self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -135,12 +129,12 @@ def _forward_dropless( None, self.layer_2.weight, None, - gated=self._gated, - activation_type=self._activation_type, + gated=self._config.gated, + activation_type=self._config.activation_type, group=self._intermediate_dim.parallel_group, sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._recompute_level, + recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=True, sparse_map=sparse_map, ) @@ -155,12 +149,12 @@ def _forward_looped( self.layer_1.weight, self.layer_2.weight, self._num_experts, - self._gated, - self._activation_type, + self._config.gated, + self._config.activation_type, self._intermediate_dim.parallel_group, self._sequence_parallel, self.training, - self._recompute_level, + self._config.mlp_recompute_level, ) @torch.compile @@ -185,7 +179,7 @@ def _topk_routing( probs.flatten(0, -2).mean(dim=0) * mask.flatten(0, -2).mean(dim=0, dtype=torch.float32) ) if losses is not None: - losses[TransformerLossNames.load_balancing_loss].append(aux_loss.detach()) + losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale @@ -255,7 +249,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - 
kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), + kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/block/mlp/mlp.py similarity index 77% rename from fast_llm/layers/transformer/mlp.py rename to fast_llm/layers/block/mlp/mlp.py index 101d97ef3..04b8506a4 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -1,21 +1,23 @@ import typing -from abc import ABC import torch +from fast_llm.config import Configurable from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale -class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__() +class MLPBase[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + super().__init__(config) self._name = name self._block_index = block_index @@ -30,13 +32,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space[TransformerDimNames.hidden] - 
self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] + hidden_dim = tensor_space[BlockDimNames.hidden] + self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._recompute_level = config.mlp_recompute_level - - self._gated = config.gated - self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None @@ -46,7 +44,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space[TransformerDimNames.composite_gated_expert_mlp], + tensor_space[MLPDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, @@ -68,8 +66,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name, block_index) @@ -89,12 +87,12 @@ def forward( self.layer_1.bias, self.layer_2.weight, None if parallel_group else self.layer_2.bias, - gated=self._gated, - activation_type=self._activation_type, + gated=self._config.gated, + activation_type=self._config.activation_type, group=parallel_group, 
sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._recompute_level, + recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=self.layer_2.transposed_weight, ), self.layer_2.bias if parallel_group else None, diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py new file mode 100644 index 000000000..269ed0aac --- /dev/null +++ b/fast_llm/layers/block/peft.py @@ -0,0 +1,128 @@ +""" +TODO: Generalize beyond transformers. +""" + +import abc +import enum +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig +from fast_llm.tensor import ParameterMeta +from fast_llm.utils import div + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.layers.common.linear import LinearBase, LinearLike + + +class TransformerSubLayerName(str, enum.Enum): + # TODO: Use this to replace AddLinearBiasChoices. + query = "query" + key = "key" + value_ = "value" + key_value = "key_value" + dense = "dense" + mlp_1 = "mlp_1" + mlp_2 = "mlp_2" + + +@config_class(registry=True) +class TransformerPeftConfig(PeftConfig): + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + pass + + @abc.abstractmethod + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + pass + + @abc.abstractmethod + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return TransformerNoPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={TransformerPeftConfig: "none"}) +class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): + _abstract = False + + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + return super().apply_linear(linear) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + + +@config_class(dynamic_type={TransformerPeftConfig: "lora"}) +class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): + layers: list[TransformerSubLayerName] = Field( + default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), + desc="The layers on which to apply LoRA.", + hint=FieldHint.feature, + ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + if layer_type is None or self.layers is None or layer_type in self.layers: + if layer_type == TransformerSubLayerName.key: + return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + elif layer_type == TransformerSubLayerName.value_: + return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) + else: + return super().apply_linear(linear) + elif self.freeze_others: + linear.weight.requires_grad = False + return linear + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + if self.freeze_others: + for parameter in module.parameters(): + parameter.requires_grad = False + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + 
return parameter + + def _validate(self) -> None: + super()._validate() + if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") + if TransformerSubLayerName.dense in self.layers: + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for attention dense layer.") + if ( + sum( + name in self.layers + for name in ( + TransformerSubLayerName.key_value, + TransformerSubLayerName.key, + TransformerSubLayerName.value_, + ) + ) + > 1 + ): + raise ValueError( + f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 07dadbc22..9d5ce3f3b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -14,18 +14,6 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm -@config_class() -class LLMBlockConfig(BaseModelConfig): - _abstract = False - - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) - - class NormalizationImplementation(str, enum.Enum): """ An enum for the available implementations of layer norm. 
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e2e97f1a..b667e5318 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,12 +5,13 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert -class LanguageModelDimNames: +class LanguageModelDimNames(BlockDimNames): # Embedding dimensions position_embed = "position_embed" vocab = "vocab" @@ -33,7 +34,7 @@ def multi_token_prediction_loss(index: int) -> str: return f"language_model_loss_{index}" -class LanguageModelKwargs: +class LanguageModelKwargs(BlockKwargs): position_ids = "position_ids" # TODO: These are generic labels = "labels" @@ -46,6 +47,7 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): + # TODO: block transformer: TransformerConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index f6f43d199..05678a700 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ 
from fast_llm.utils import Assert @@ -46,7 +45,7 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space[TransformerDimNames.hidden] + hidden_dim = tensor_space[LanguageModelDimNames.hidden] vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab ] @@ -129,7 +128,7 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], + kwargs[LanguageModelKwargs.hidden_dims], tensor_name="Embedding output", dtype=self._residual_dtype, ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 210cad644..bc672725c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -23,7 +23,6 @@ LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert, div, get_unique @@ -61,7 +60,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -168,20 +167,22 @@ def _forward_backward( if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: # The last hidden layer output is returned normalized in the HF Transformers-style output, at least for LLama style models. 
# So, if needed, we gather the data after normalization and set it as the output of the previous layer. - dims = list(kwargs[TransformerKwargs.hidden_dims]) - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims = list(kwargs[LanguageModelKwargs.hidden_dims]) + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + LanguageModelDimNames.sequence_q_tp, + dims[sequence_index].global_size, + DistributedDimNames.tensor, ) if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state - grad_output = kwargs[TransformerKwargs.grad_output] / ( + grad_output = kwargs[LanguageModelKwargs.grad_output] / ( self._group_size if self._sequence_parallel_logits else 1 ) @@ -221,18 +222,18 @@ def _get_targets( if lm_target is not None: # MTP: Shift the labels lm_target_sequence_length = ( - lm_target.size(1 - kwargs[TransformerKwargs.sequence_first]) + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._config.prediction_heads ) - if TransformerKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) lm_target_slice = slice( self._prediction_distance, self._prediction_distance + lm_target_sequence_length ) lm_target = ( lm_target[lm_target_slice] - if 
kwargs[TransformerKwargs.sequence_first] + if kwargs[LanguageModelKwargs.sequence_first] else lm_target[:, lm_target_slice] ).flatten() else: @@ -341,23 +342,23 @@ def _logits_cross_entropy_forward_backward( vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp ] - dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor ) if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) dim_names = ( - [TransformerDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + [LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] if self._sequence_parallel_logits - else [TransformerDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] ) - dim_names.insert(int(kwargs[TransformerKwargs.sequence_first]), TransformerDimNames.batch) + dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) log_distributed_tensor( "", logits, diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index c8d53a789..f5d915855 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -6,7 +6,6 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import 
DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -41,29 +40,29 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths)) is not None: + self._create_tensors(kwargs[LanguageModelKwargs.sequence_length]) + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: position_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ).to(self._tensor_space.distributed.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[LanguageModelKwargs.sequence_first]: position_ids = position_ids.transpose(0, 1) kwargs[LanguageModelKwargs.position_ids] = position_ids else: kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[TransformerKwargs.sequence_first])) + ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: # Position embeddings will be broadcast. 
- sequence_q_dim = kwargs[TransformerKwargs.sequence_q_dim] + sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( ( (sequence_q_dim, self._scalar_dim) - if kwargs[TransformerKwargs.sequence_first] + if kwargs[LanguageModelKwargs.sequence_first] else (self._scalar_dim, sequence_q_dim) ), tensor_name=LanguageModelKwargs.position_ids, @@ -82,8 +81,8 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/block.py similarity index 52% rename from fast_llm/layers/ssm/llamba_block.py rename to fast_llm/layers/ssm/block.py index 986606634..4854900a3 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,14 +1,12 @@ -import typing +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.block import Block +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.transformer import BaseBlock, Mixer -if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.layers.ssm.config import SSMConfig - from fast_llm.layers.transformer.config import TransformerConfig - - -class SSMBlock(BaseBlock): +# TODO: Sort out configs. 
+class SSMBlock[ConfigType: BlockConfig](Block[BlockConfig]): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -17,16 +15,16 @@ class SSMBlock(BaseBlock): def __init__( self, - transformer_config: "TransformerConfig", - ssm_config: "SSMConfig", - tensor_space: "TensorSpace", + config: BlockConfig, + ssm_config: SSMConfig, + tensor_space: TensorSpace, mixer_cls: type[Mixer], block_index: int, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, block_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> Mixer: return self._mixer_cls( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9b0949d55..efcf2d873 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,18 +1,18 @@ import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: from fast_llm.tensor import Initializer -class SSMDimNames: +class SSMDimNames(BlockDimNames): # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
state = "ssm_state" # State dimension (N), aka head size / num channels head_dim = "ssm_head_dim" @@ -72,15 +72,9 @@ def get_init_method(self, scale: float) -> "Initializer": @config_class() -class SSMConfig(LLMBlockConfig): +class SSMConfig(Config): _abstract = False - # Normalization - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) - # Model dimensions # TODO: Remove (redundant default) expansion_factor: int = Field( diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c9d555de9..550c44d0f 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -6,10 +6,10 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -42,15 +42,15 @@ def __init__( config: SSMConfig, block_index: int, tensor_space: TensorSpace, - transformer_config: TransformerConfig, + block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) self._config: SSMConfig = config - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + layer_lr_scale = 
block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[TransformerDimNames.hidden] + hidden_dim = tensor_space[SSMDimNames.hidden] conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] heads_dim = tensor_space[SSMDimNames.composite_heads] @@ -69,7 +69,7 @@ def __init__( hidden_dim, tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -113,12 +113,12 @@ def __init__( def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size + sequence_length = kwargs[BlockKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length + assert not kwargs[BlockKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) @@ -126,10 +126,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) inner_projection = self.in_proj(input_) # Standardize to (batch, padded_sequence, inner_projection) - if kwargs[TransformerKwargs.sequence_first]: + if 
kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) - print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) xBC, z, A_log = torch.split( inner_projection, [ @@ -139,10 +138,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) - print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) - print("QAIKOFNMJOWENM z", z.shape) - print("QAIKOFNMJOWENM A_log", A_log.shape) - # Convolutional layer # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) xBC = self.convolutional_forward(xBC, padded_length) @@ -183,14 +178,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) a, b = self.out_proj(y) - logger.info(f"EKFBN y {y.shape}") - logger.info(f"EKFBN a {a.shape}") return self.out_proj(y) @torch.compile diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 77c1b3869..712c420ee 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -5,11 +5,11 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale @@ -38,15 +38,15 @@ class Mamba2(Mixer): _mixer_name: typing.ClassVar[str] = "mamba_2" _XZ_DIMS = ( - TransformerDimNames.batch, + SSMDimNames.batch, SSMDimNames.composite_heads_and_head_dim, - TransformerDimNames.sequence_q, + SSMDimNames.sequence_q, ) _BC_DIMS = ( - TransformerDimNames.batch, + SSMDimNames.batch, SSMDimNames.composite_heads, SSMDimNames.state, - TransformerDimNames.sequence_q, + SSMDimNames.sequence_q, ) def __init__( @@ -54,17 +54,19 @@ def __init__( config: SSMConfig, tensor_space: TensorSpace, block_index: int, - transformer_config: TransformerConfig, + block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + 
super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = ( + block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None + ) lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] - hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + hidden_dim: TensorDim = tensor_space[SSMDimNames.hidden] dt_rank_dim = tensor_space[SSMDimNames.dt_rank] self._local_heads = tensor_space[SSMDimNames.composite_heads].size @@ -92,7 +94,7 @@ def __init__( hidden_dim, tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -101,7 +103,7 @@ def __init__( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -151,7 +153,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, inner_projection) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) @@ -220,7 +222,7 @@ def forward(self, input_: 
torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: # TODO: Is contiguous needed? y = y.transpose(0, 1).contiguous() # (batch/sequence, sequence/batch, local_heads * state) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 9343ef1b8..f5b0139cf 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -6,10 +6,10 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale @@ -60,9 +60,9 @@ def __init__( config: SSMConfig, block_index: int, tensor_space: TensorSpace, - transformer_config: TransformerConfig, + block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? 
@@ -70,8 +70,8 @@ def __init__( # Tensor dims: inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[TransformerDimNames.hidden] - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + hidden_dim = tensor_space[SSMDimNames.hidden] + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? @@ -141,7 +141,7 @@ def __init__( def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s @@ -160,6 +160,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c59b191af..b1de792e3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,14 +6,10 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.block.mixer import Mixer +from 
fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, -) -from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py new file mode 100644 index 000000000..4a0e818f0 --- /dev/null +++ b/fast_llm/layers/transformer/block.py @@ -0,0 +1,23 @@ +import logging +import typing + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.block import Block +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.transformer.attention import Attention +from fast_llm.layers.transformer.config import TransformerConfig + +logger = logging.getLogger(__name__) + + +class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): + _name = "Transformer layer" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" + _config: TransformerConfig + + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f6eaf5890..1c10753a8 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,44 +1,25 @@ -import abc -import enum import functools import logging -import math import typing import warnings -from fast_llm.config 
import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import LLMBlockConfig, LoRAConfig, NoPeftConfig, NormalizationConfig, PeftConfig +from fast_llm.functional.config import TritonConfig +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - import torch - - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta + pass logger = logging.getLogger(__name__) -class RoutingType(str, enum.Enum): - topk = "aux_loss" - sinkhorn = "sinkhorn" - - -class TransformerDimNames: +class TransformerDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? 
- sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" # Self-attention dimensions head_groups = "head_groups" group_heads = "group_heads" @@ -48,21 +29,9 @@ class TransformerDimNames: composite_query = "composite_query" composite_key_value = "composite_key_value" composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" -class TransformerKwargs: +class TransformerKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -75,164 +44,17 @@ class TransformerKwargs: # TODO: Review these presents = "presents" past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - # TODO: Move - grad_output = "grad_output" - - -class TransformerLossNames: - load_balancing_loss = "load_balancing_loss" - router_z_loss = "router_z_loss" - - -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - -class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. 
- query = "query" - key = "key" - value_ = "value" - key_value = "key_value" - dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={TransformerPeftConfig: "none"}) -class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): +class AttentionConfig(Config): + # TODO: Make mixer class dynamic. 
_abstract = False - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter - - -@config_class(dynamic_type={TransformerPeftConfig: "lora"}) -class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): - layers: list[TransformerSubLayerName] = Field( - default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), - desc="The layers on which to apply LoRA.", - hint=FieldHint.feature, - ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - return parameter - - def _validate(self) -> None: - super()._validate() - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. 
- raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." - ) - - -@config_class() -class TransformerConfig(LLMBlockConfig): - _abstract = False - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) + # TODO: Review names rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) - peft: TransformerPeftConfig = Field( - desc="Configuration for the parameter-efficient fine tuning.", - hint=FieldHint.architecture, - ) - num_layers: int = Field( - default=12, - desc="Number of layers in the transformer.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) head_groups: int = Field( default=1, @@ -241,60 +63,104 @@ class TransformerConfig(LLMBlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - add_linear_biases: bool | AddLinearBiasChoices = Field( - default=True, - desc="Add biases to all, none or Q, K, V layers. 
Accepted values: True, False, or AddLinearBiasChoices.", - hint=FieldHint.architecture, - ) - ffn_hidden_size: int = Field( - default=None, - desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) kv_channels: int = Field( default=None, desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - num_experts: int = Field( - default=1, - desc="Number of MLP experts in a Mixture of Expert (MoE) model", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_shared_experts: int = Field( - default=0, - desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", - hint=FieldHint.architecture, + attention_dropout: float = Field( + default=0.0, + desc="Dropout applied to the attention intermediate states.", + hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - num_unshared_experts: int = Field( - init=False, - desc="Number of MLP experts excluding shared ones", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), + # Use flash attention if possible (fp16 or bf16) + use_flash_attention: bool = Field( + default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional ) - num_experts_per_token: int = Field( - default=1, - desc="Active experts for each token in a MoE model.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), + window_size: int | None = Field( + default=None, + desc="Size of the attention sliding window. 
Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - expert_routing_type: RoutingType = Field( - default=RoutingType.topk, - desc="The routing method, i.e., the method used to assign experts to tokens.", - hint=FieldHint.architecture, + max_window_layers: int | None = Field( + default=None, + desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - activation_type: ActivationType = Field( + attention_lr_scale: float | None = Field( default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, + desc="Custom learning rate scale for the Attention projection weights.", + doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + attention_softmax_scale_power: float = Field( + default=0.5, + desc="The scaling power to apply to kv_channel in the attention calculation. " + " Under Standard Parameterization (SP): default to 0.5. " + " Under muP (if scaling kv_channels size): use 1. " + " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) 
+ + def _validate(self) -> None: + super()._validate() + + if not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + + Assert.multiple(self.num_attention_heads, self.head_groups) + + @functools.cached_property + def projection_size(self): + assert self._validated + return self.num_attention_heads * self.kv_channels + + def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: + return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + # Needed for multiple inheritance. + + tensor_space.add_tensor_dim( + head_groups := TensorDim( + TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + ) + ) + tensor_space.add_tensor_dim( + group_heads := TensorDim( + TransformerDimNames.group_heads, + div(self.num_attention_heads, self.head_groups), + None if self.head_groups > 1 else tensor, + ) + ) + tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + ) + + +@config_class() +# TODO: Use composition for attention config +class TransformerConfig(AttentionConfig, BlockConfig): + 
_abstract = False + + # TODO: Review names init_method_std: float = Field( default=None, desc="Default scale for weight initialization. Default: hidden_size**-0.5", @@ -375,125 +241,17 @@ class TransformerConfig(LLMBlockConfig): desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", hint=FieldHint.optional, ) - attention_dropout: float = Field( - default=0.0, - desc="Dropout applied to the attention intermediate states.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - hidden_dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) - # Use flash attention if possible (fp16 or bf16) - use_flash_attention: bool = Field( - default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional - ) - window_size: int | None = Field( - default=None, - desc="Size of the attention sliding window. Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - max_window_layers: int | None = Field( - default=None, - desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", - hint=FieldHint.optional, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( - default=MLPRecomputeLevel.none, - desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. 
This provides a trade-off between memory and speed.", - hint=FieldHint.performance, - ) - debug_transformer: int = Field( - default=0, - desc="Log the output of each operation in a transformer layer.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), - ) - debug_transformer_memory: bool = Field( - default=False, - desc="Log the memory usage after each operation in a transformer layer..", - hint=FieldHint.logging, - ) # Use random inits instead of constant values, useful for debugging. random_bias_init: bool = Field( default=False, desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", hint=FieldHint.testing, ) - expert_auxiliary_loss_coefficient: float = Field( - default=0.01, - desc="Scale of the load balancing auxiliary loss for topk routing.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - expert_z_loss_coefficient: float = Field( - default=0.0, - desc="Regularize the router during training by applying Z-loss to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - moe_jitter_eps: float = Field( - default=0.0, - desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - mlp_lr_scale: float | None | list[float | None] = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) - router_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate for the MoE router weight.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the Attention projection weights.", 
- doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( - default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " - " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. " - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - dropless_moe: bool = Field( - default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert - ) - dropless_dynamic_shape: bool = Field( - default=False, - desc="Use a dynamic shape for dropless MLP instead of the worst-case value." - " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", - hint=FieldHint.expert, - ) def _validate(self) -> None: with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu if self.init_method_std is None: self.init_method_std = self.hidden_size**-0.5 if self.init_method_std_qkv is None: @@ -532,40 +290,9 @@ def _validate(self) -> None: Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.leq(self.num_shared_experts, self.num_experts) - 
Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - Assert.multiple(self.num_attention_heads, self.head_groups) - Assert.geq(self.attention_dropout, 0) - Assert.geq(self.hidden_dropout, 0) - - if isinstance(self.mlp_lr_scale, list): - Assert.eq(len(self.mlp_lr_scale), self.num_experts) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) - elif self.mlp_lr_scale is not None: - Assert.geq(self.mlp_lr_scale, 0) - - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * self.kv_channels - - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - @property def add_attn_qkv_bias(self) -> bool: if isinstance(self.add_linear_biases, bool): @@ -581,84 +308,3 @@ def add_attn_dense_bias(self) -> bool: if self.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. 
- cls._handle_renamed_field( - default, - "use_rotary_embeddings", - ("rotary", "type"), - lambda x: "default" if x else "none", - ) - cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x)) - cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) - return super()._from_dict(default, strict, flat) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) - - # Self-attention dimensions - tensor_space.add_tensor_dim( - head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - TransformerDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - 
tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) - ) - ) - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786d..2dbef77f3 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,8 +24,9 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.mlp.config import RoutingType from fast_llm.layers.common.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import RoutingType, TransformerConfig +from 
fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 49a5dcbd3..da647de57 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,18 +10,14 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) +from fast_llm.layers.transformer.block import TransformerBlock +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -374,7 +370,7 @@ def loss_defs(self) -> list[LossDef]: ): loss_defs.append( LossDef( - 
name=TransformerLossNames.load_balancing_loss, + name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", count=self._config.transformer.num_layers, ) @@ -382,7 +378,7 @@ def loss_defs(self) -> list[LossDef]: if self._config.transformer.expert_z_loss_coefficient: loss_defs.append( LossDef( - name=TransformerLossNames.router_z_loss, + name=MLPLossNames.router_z_loss, formatted_name="router z loss", count=self._config.transformer.num_layers, ) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index bcfbaf693..4cf1ac458 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,8 +1,8 @@ -from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.mlp import MLP +from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.transformer.config import TransformerConfig def test_mlp_constructor(): @@ -20,11 +20,7 @@ def test_mlp_constructor(): def test_moe_mlp_constructor(): transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - num_experts=2, - add_linear_biases=False + num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False ) distributed_config = DistributedConfig() tensor_space = TensorSpace(distributed_config=distributed_config) From ce70b169e55dea29383eb3f6a488125b309487ce Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 18:13:29 -0400 Subject: [PATCH 02/28] fixes --- fast_llm/layers/block/config.py | 16 +++++++++------- fast_llm/layers/block/mlp/config.py | 3 ++- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/block/peft.py | 2 +- 
fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/transformer/config.py | 3 ++- fast_llm/models/custom/model.py | 2 +- fast_llm/models/ssm/model.py | 8 ++++---- tests/test_mlp.py | 2 +- tests/test_multi_stage.py | 4 ++-- 10 files changed, 24 insertions(+), 20 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 2f26d8d79..5a999fa6d 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.block.mlp.config import MLPConfig -from fast_llm.layers.block.peft_config import TransformerPeftConfig +from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert @@ -26,11 +26,19 @@ class BlockKwargs: hidden_dims = "hidden_dims" sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" + # TODO: These are confusing sequence_length = "sequence_length" + sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? 
grad_output = "grad_output" +class AddLinearBiasChoices(str, enum.Enum): + nowhere = "nowhere" + everywhere = "everywhere" + only_attn_qkv = "only_attn_qkv" + + @config_class() # TODO: Use composition for MLP config class BlockConfig(MLPConfig, BaseModelConfig): @@ -112,9 +120,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # Hidden dimension tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) - - -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 63e31219a..1d125c4f7 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,6 +1,6 @@ import enum -from fast_llm.config import Config, Field, FieldHint, check_field, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel @@ -32,6 +32,7 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" +@config_class() class MLPConfig(Config): # TODO: Review names _abstract = False diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 04b8506a4..19349671e 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -15,7 +15,7 @@ from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): +class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__(config) self._name = name diff --git a/fast_llm/layers/block/peft.py 
b/fast_llm/layers/block/peft.py index 269ed0aac..66bc675ed 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -8,13 +8,13 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig -from fast_llm.tensor import ParameterMeta from fast_llm.utils import div if typing.TYPE_CHECKING: import torch from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.tensor import ParameterMeta class TransformerSubLayerName(str, enum.Enum): diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 712c420ee..1c319f490 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1c10753a8..ebb976e63 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -36,7 +36,6 @@ class TransformerKwargs(BlockKwargs): rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" cu_seqlens_q = "cu_seqlens_q" cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" @@ -46,6 +45,7 @@ class TransformerKwargs(BlockKwargs): past_key_values = "past_key_values" +@config_class() class AttentionConfig(Config): # TODO: Make mixer class dynamic. 
_abstract = False @@ -126,6 +126,7 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Needed for multiple inheritance. + super().setup_tensor_space(tensor_space) # Noqa tensor_space.add_tensor_dim( head_groups := TensorDim( diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 534d813ff..3c0ad8ab4 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3ba6b1a62..ca840911f 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,8 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.ssm.block import SSMBlock +from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -52,7 +52,7 @@ def 
get_output_layers(self) -> list[Layer]: else: layers.append( SSMBlock( - transformer_config=self._config.transformer, + config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), block_index=len(self._config.hybrid_block_layout), @@ -88,7 +88,7 @@ def get_layers(self) -> list[Layer]: else: layers.append( SSMBlock( - transformer_config=self._config.transformer, + config=self._config.transformer, ssm_config=self._config.ssm, mixer_cls=self._config.ssm_block_type.get_mixer_class(), block_index=i + 1, diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 4cf1ac458..5875822ff 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,7 +1,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.mlp import MLP from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.layers.transformer.config import TransformerConfig diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 2f125717e..0639ec7ed 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.ssm.block import SSMBlock +from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup From a9f733d121e47df360b997097abb8bf2d5ac49d1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 18:33:05 -0400 Subject: [PATCH 03/28] fix --- 
fast_llm/layers/ssm/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 4854900a3..0bfa266ac 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -31,5 +31,5 @@ def _create_mixer(self) -> Mixer: self._ssm_config, tensor_space=self._tensor_space, block_index=self._block_index, - transformer_config=self._config, + block_config=self._config, ) From a5eb0767e99038e18c1bd07f7f78718634296c4c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 15:14:13 -0400 Subject: [PATCH 04/28] stuff --- docs/developer_guide/conversion.md | 30 ++- .../engine/config_utils/initialization.py | 178 ++++++++++++ fast_llm/layers/block/block.py | 132 +++++++-- fast_llm/layers/block/config.py | 160 +++++++++-- fast_llm/layers/block/mixer.py | 68 ----- fast_llm/layers/block/mlp/config.py | 79 +++++- .../layers/block/mlp/mixture_of_experts.py | 134 ++++------ fast_llm/layers/block/mlp/mlp.py | 72 ++--- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 3 +- fast_llm/layers/common/normalization.py | 3 +- fast_llm/layers/language_model/config.py | 122 ++++----- fast_llm/layers/language_model/embedding.py | 48 ++-- fast_llm/layers/language_model/head.py | 109 ++++---- .../layers/language_model/preprocessing.py | 10 +- fast_llm/layers/ssm/config.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba2.py | 5 +- fast_llm/layers/ssm/mamba_layer.py | 7 +- fast_llm/layers/transformer/attention.py | 142 +++++----- fast_llm/layers/transformer/config.py | 253 ++++++------------ fast_llm/layers/transformer/preprocessing.py | 52 ++-- .../transformer/rotary/preprocessing.py | 26 +- fast_llm/layers/transformer/rotary/rotary.py | 30 +-- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/config.py | 9 +- fast_llm/models/gpt/conversion.py | 6 +- fast_llm/models/gpt/huggingface.py | 10 +- fast_llm/models/gpt/model.py | 54 
++-- fast_llm/models/ssm/config.py | 10 +- fast_llm/models/ssm/conversion.py | 6 +- fast_llm/tensor.py | 70 +---- tests/layers/test_lm_head.py | 6 +- tests/models/test_generate.py | 2 +- tests/test_attention.py | 16 +- tests/test_ssms.py | 6 +- tests/utils/model_configs.py | 2 + 37 files changed, 1015 insertions(+), 857 deletions(-) create mode 100644 fast_llm/engine/config_utils/initialization.py delete mode 100644 fast_llm/layers/block/mixer.py diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beaea..719757df1 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,23 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: + + converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers - - # A simple renaming example, for the word embeddings. - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters +# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. +num_layers = self._model.config.base_model.transformer.num_blocks + +# A simple renaming example, for the word embeddings. +converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + +# We usually want to loop dynamically over layers +for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. 
import abc
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
    import torch

    from fast_llm.tensor import ParameterMeta


@config_class(registry=True)
class InitializationConfig(Config):
    """
    Dynamic config for parameter initialization. Subclasses register under a
    `type` key ("default", "fill", "zeros", "ones", "normal", "uniform").
    """

    _abstract = True
    # `False` means "no explicit initialization configured, use the layer's default".
    has_initialization: typing.ClassVar[bool] = True

    @classmethod
    def _from_dict(
        cls,
        default: dict[str, typing.Any],
        strict: bool = True,
        flat: bool = False,
    ) -> typing.Self:
        if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None:
            # No (or unknown) `type` specified: fall back to the default subclass.
            return DefaultInitializationConfig._from_dict(default, strict, flat)
        return super()._from_dict(default, strict=strict, flat=flat)

    def get_initializer(self) -> "Initializer":
        """Return the `Initializer` described by this config."""
        raise NotImplementedError()


@config_class(dynamic_type={InitializationConfig: "default"})
class DefaultInitializationConfig(InitializationConfig):
    # A placeholder indicating that the class default should be used instead.
    _abstract = False
    has_initialization = False


@config_class(dynamic_type={InitializationConfig: "fill"})
class FillInitializationConfig(InitializationConfig):
    """
    Constant initialization: fill(value).
    """

    # NOTE(review): this class was previously named `NormalInitializationConfig`
    # (a duplicate of the "normal" class below) and carried the wrong docstring;
    # renamed to match its actual behavior.
    _abstract = False

    value: float = Field(
        default=1,
        desc="Initialization value.",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0),
    )

    def get_initializer(self) -> "Initializer":
        return init_fill_(self.value)


@config_class(dynamic_type={InitializationConfig: "zeros"})
class ZeroInitializationConfig(InitializationConfig):
    """
    Constant initialization: fill(0).
    """

    _abstract = False

    def get_initializer(self) -> "Initializer":
        return init_zeros_


@config_class(dynamic_type={InitializationConfig: "ones"})
class OnesInitializationConfig(InitializationConfig):
    """
    Constant initialization: fill(1).
    """

    # NOTE(review): was a second class also named `ZeroInitializationConfig`,
    # shadowing the "zeros" variant above; renamed.
    _abstract = False

    def get_initializer(self) -> "Initializer":
        return init_ones_


@config_class(dynamic_type={InitializationConfig: "normal"})
class NormalInitializationConfig(InitializationConfig):
    """
    Normal initialization: normal(mean, std).clamp(min, max)
    """

    _abstract = False

    std: float = Field(
        default=1,
        desc="Standard deviation for normal initialization.",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0),
    )
    mean: float = Field(
        default=0,
        desc="Mean for normal initialization.",
        hint=FieldHint.optional,
    )
    min: float | None = Field(
        default=None,
        desc="Min value for initialization clamping.",
        hint=FieldHint.optional,
    )
    max: float | None = Field(
        default=None,
        # NOTE(review): desc previously said "Min value" (copy-paste typo).
        desc="Max value for initialization clamping.",
        hint=FieldHint.optional,
    )

    def get_initializer(self) -> "Initializer":
        return init_normal_(self.mean, self.std, self.min, self.max)


@config_class(dynamic_type={InitializationConfig: "uniform"})
class UniformInitializationConfig(InitializationConfig):
    """
    Uniform initialization: uniform(mean - scale, mean + scale)
    """

    _abstract = False

    scale: float = Field(
        default=None,
        desc="Initialization scale (half-width of the sampling interval).",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0),
    )
    # NOTE(review): the mean may legitimately be negative, so it is not
    # range-checked (the original had `valid=check_field(Assert.geq, 0)` here,
    # which would reject any negative mean).
    mean: float = Field(
        default=0,
        desc="Initialization mean.",
        hint=FieldHint.optional,
    )

    def get_initializer(self) -> "Initializer":
        return init_uniform_centered_(self.scale, self.mean)


class Initializer(abc.ABC):
    """Callable that initializes a parameter tensor in place."""

    @abc.abstractmethod
    def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:
        pass

    # Whether the initializer must be run on the full (global) tensor
    # rather than a local shard.
    requires_global_initialization = False


class LambdaInitializer(Initializer):
    """Wraps a plain function as an `Initializer`."""

    def __init__(
        self,
        init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None],
        requires_global_initialization: bool = False,
    ) -> None:
        self._init_method = init_method
        self.requires_global_initialization = requires_global_initialization

    def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:
        return self._init_method(meta, tensor, generator)


def init_fill_(value: float) -> LambdaInitializer:
    """Initializer that fills the tensor with the constant `value`."""

    def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:  # noqa
        tensor.fill_(value)

    return LambdaInitializer(init_)


init_zeros_ = init_fill_(0.0)
init_ones_ = init_fill_(1.0)


def init_normal_(
    mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None
) -> LambdaInitializer:
    """Initializer sampling from normal(mean, std), optionally clamped to [min_val, max_val]."""

    def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:  # noqa
        # `normal_` works in place; the rebinding in the original was redundant.
        tensor.normal_(mean, std, generator=generator)
        if min_val is not None or max_val is not None:
            tensor.clamp_(min=min_val, max=max_val)

    return LambdaInitializer(init_)


def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer:
    """Initializer sampling uniformly from [mean - scale, mean + scale]."""

    def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:  # noqa
        tensor.uniform_(mean - scale, mean + scale, generator=generator)

    return LambdaInitializer(init_)
import logging

# NOTE(review): PATCH 01 removed `import logging` / `logger` from this module,
# but `DebugLayer.__call__` still uses `logger.info` as a default argument,
# which would raise NameError at import time. Restored here.
logger = logging.getLogger(__name__)


class DebugLayer:
    """
    Debugging helper for block layers: logs tensor values, their gradients,
    and memory usage, gated by the configured debug level / memory flags.
    """

    # TODO: Move elsewhere?
    def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False):
        self._tensor_space = tensor_space
        self._name = name
        self._debug_level = debug_level
        self._debug_memory = debug_memory

    def _get_meta(
        self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any]
    ) -> TensorMeta:
        # Resolve string dim names against the hidden dims provided in `kwargs`,
        # falling back to the tensor space for anything else.
        hidden_dims = {
            dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],)
        }
        return TensorMeta.from_dims(
            tuple(
                (
                    dim
                    if isinstance(dim, TensorDim)
                    else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim]
                )
                for dim in dims
            ),
            tensor_name=f"{self._name} {name}",
            dtype=tensor.dtype,
        )

    @functools.cached_property
    def enabled(self) -> bool:
        """Whether any debug output (tensor logging or memory logging) is active."""
        return self._debug_level > 0 or self._debug_memory

    def __call__(
        self,
        tensor: torch.Tensor,
        name: str,
        dims: tuple[TensorDim | str, ...],
        kwargs: dict[str, typing.Any],
        scale: float = 1.0,
        global_: bool = True,
        # NOTE(review): original annotation referenced an undefined type var `T`;
        # replaced with `typing.Any`.
        log_fn: type[BaseException] | typing.Callable[[str], typing.Any] | None = logger.info,
    ) -> None:
        # TODO: Local vs global?
        if self._debug_memory:
            log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str))
        if self._debug_level > 0:
            log_distributed_tensor(
                "",
                tensor,
                level=self._debug_level,
                meta=self._get_meta(tensor, name, dims, kwargs),
                distributed=self._tensor_space.distributed,
                global_=global_,
                log_fn=log_fn,
                scale=scale,
            )
            # assumes grad logging is nested under debug_level > 0 — the flattened
            # patch makes the original indentation ambiguous; TODO confirm.
            if tensor.requires_grad:
                log_distributed_grad(
                    "",
                    tensor,
                    level=self._debug_level,
                    meta=self._get_meta(tensor, name + " grad", dims, kwargs),
                    distributed=self._tensor_space.distributed,
                    global_=global_,
                    log_fn=log_fn,
                    scale=scale,
                )
+ """ + + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config) + self._tensor_space = tensor_space + self._block_index = block_index + self._name = name + self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + self.config.block.debug_transformer, + self._config.block.debug_transformer_memory, + ) + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + +class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): + def __init__( + self, config: ConfigType, tensor_space: TensorSpace, block_index: int = 0, return_input: bool = False + ): super().__init__() self._config = config self._tensor_space: TensorSpace = tensor_space @@ -40,21 +136,19 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - # The mixer needs to be created here for backward-compatible weight ordering. - setattr(self, self._mixer_module_name, self._create_mixer()) - - self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. 
+ setattr( + self, + self._config.mixer.module_name, + self._config.mixer.get_layer(self._tensor_space, block_index, f"{self.name} mixer"), ) + self.mlp = self._config.mlp.get_layer(self._tensor_space, block_index, f"{self.name} mlp") + # PEFT. self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) - @abc.abstractmethod - def _create_mixer(self) -> Mixer: - pass - @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor @@ -113,13 +207,13 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs) if self._debug_mode: - self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) + self._debug_log(hidden_states, f"{self._config.mixer.module_name} output", kwargs, bias=bias) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug_mode: - self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) + self._debug_log(input_, f"{self._config.mixer.module_name} residual", kwargs) hidden_states = self.norm_2(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 2", kwargs) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 5a999fa6d..87bd6d249 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,13 +1,21 @@ +import abc import enum +import functools +import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.mlp.config import MLPConfig from 
fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import Block, BlockLayer + + +# TODO: Generalize these beyond language models? (Ex. vision) + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -39,10 +47,76 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" +@config_class(registry=True) +class BlockLayerConfig(BaseModelConfig): + _abstract = True + block: "BlockConfig" = Field(init=False) + + def _validate(self) -> None: + assert hasattr(self, "block") + Assert.is_(self.block.mlp, self) + super()._validate() + + @property + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() + + def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer": + return self.layer_class(self, tensor_space, block_index, name) + + +@config_class() +class MixerConfig(BlockLayerConfig): + _abstract = True + + # Needed for backward compatibility. + module_name: typing.ClassVar[str] = "mixer" + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.transformer.config import AttentionConfig + + # Default subclass. 
+ return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + @config_class() -# TODO: Use composition for MLP config -class BlockConfig(MLPConfig, BaseModelConfig): +class MLPBaseConfig(BlockLayerConfig): + _abstract = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.block.mlp.config import MLPConfig + # Default subclass. + return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class() +class BlockConfig(BaseModelConfig): + _abstract = False + mixer: MixerConfig = Field( + desc="Configuration for the mixer.", + hint=FieldHint.architecture, + ) + mlp: MLPBaseConfig = Field( + desc="Configuration for the MLP.", + hint=FieldHint.architecture, + ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", @@ -58,11 +132,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) debug_transformer: int = Field( default=0, desc="Log the output of each operation in a transformer layer.", @@ -80,8 +149,45 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, ) + block_sequence: "BlockSequenceConfig" = Field(init=False) + + def _validate(self) -> None: + assert hasattr(self, "block_sequence") + Assert.incl(self, self.block_sequence.blocks.values()) + self.mixer.block = self + self.mlp.block = self + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + 
self.mlp.setup_tensor_space(tensor_space) + self.mixer.setup_tensor_space(tensor_space) + + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.block_sequence.hidden_size)) + + @abc.abstractmethod + def get_block(self) -> "Block": + pass + + +@config_class() +class BlockSequenceConfig(BaseModelConfig): + _abstract = True + + blocks: dict[str, BlockConfig] = Field() + block_pattern: tuple[str, ...] = Field( + default=None, + desc="The pattern of blocks (referred by name) to use. The sequence is repeated until reaching `num_blocks`." + " Default: cycle over `blocks` in the order they are defined.", + ) + default_block: str = Field( + default=None, + desc="The default block configuration to use when referring to the model." + " Used to set some defaults in the language model.", + ) + # TODO: Move these, not specific to a single block. - num_layers: int = Field( + num_blocks: int = Field( default=12, desc="Number of layers in the transformer.", hint=FieldHint.architecture, @@ -93,30 +199,28 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, ) def _validate(self) -> None: - with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - + for block in self.blocks.values(): + block.validate() + if self.block_pattern is None: + self.block_pattern = tuple(self.blocks) + if self.default_block is None: + self.default_block = self.block_pattern[0] super()._validate() - @property - def add_mlp_bias(self) -> bool: - if 
isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + def get_block_config(self, block_index: int) -> BlockConfig: + return self.blocks[self.block_pattern[block_index % len(self.block_pattern)]] def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) + for block in self.blocks.values(): + block.setup_tensor_space(tensor_space) - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) + @functools.cached_property + def default_block_config(self) -> BlockConfig: + return self.blocks[self.default_block] diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py deleted file mode 100644 index 5c811e330..000000000 --- a/fast_llm/layers/block/mixer.py +++ /dev/null @@ -1,68 +0,0 @@ -import abc -import typing - -import torch - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. - """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. 
- """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 1d125c4f7..526c513db 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,11 +1,18 @@ import enum +import functools +import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.config 
import AddLinearBiasChoices, BlockLayerConfig + from fast_llm.layers.block.mlp.mlp import MLPBase + class MLPDimNames: # MLP dimensions @@ -32,9 +39,10 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class() -class MLPConfig(Config): +@config_class(dynamic_type={BlockLayerConfig: "mlp"}) +class MLPConfig(BlockLayerConfig): # TODO: Review names + # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( default=None, @@ -124,11 +132,52 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + layer_1_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer weights. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + layer_1_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + layer_2_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer weights." + " Default: (2 * num_blocks * hidden_size)**-0.5", + hint=FieldHint.feature, + ) + layer_2_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer biases. 
Default: fill with zeros.",
+        hint=FieldHint.feature,
+    )
+
+    @property
+    def layer_class(self) -> "type[MLPBase]":
+        if self.num_experts > 1:
+            from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP
+
+            return MixtureOfExpertMLP
+        else:
+            from fast_llm.layers.block.mlp.mlp import MLP
+
+            return MLP
+
+    @property
+    def add_bias(self) -> bool:
+        if isinstance(self.block.add_linear_biases, bool):
+            return self.block.add_linear_biases
+        if self.block.add_linear_biases == AddLinearBiasChoices.everywhere:
+            return True
+        return False
 
     def _validate(self) -> None:
+        assert hasattr(self, "block")
+
         with self._set_implicit_default():
             if self.activation_type is None:
                 self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu
+            if self.ffn_hidden_size is None:
+                # TODO: hidden_size not yet validated.
+                self.ffn_hidden_size = 4 * self.block.block_sequence.hidden_size
         self.num_unshared_experts = self.num_experts - self.num_shared_experts
 
         super()._validate()
@@ -144,6 +193,30 @@ def _validate(self) -> None:
         elif self.mlp_lr_scale is not None:
             Assert.geq(self.mlp_lr_scale, 0)
 
+    @functools.cached_property
+    def layer_1_weight_initialization_method(self) -> Initializer:
+        if self.layer_1_weight_initialization.has_initialization:
+            return self.layer_1_weight_initialization.get_initializer()
+        return self.block.block_sequence.hidden_size**-0.5
+
+    @functools.cached_property
+    def layer_1_bias_initialization_method(self) -> Initializer:
+        if self.layer_1_bias_initialization.has_initialization:
+            return self.layer_1_bias_initialization.get_initializer()
+        return init_zeros_
+
+    @functools.cached_property
+    def layer_2_weight_initialization_method(self) -> Initializer:
+        if self.layer_2_weight_initialization.has_initialization:
+            return self.layer_2_weight_initialization.get_initializer()
+        return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1)
+
+    @functools.cached_property
+    def 
layer_2_bias_initialization_method(self) -> Initializer: + if self.layer_2_bias_initialization.has_initialization: + return self.layer_2_bias_initialization.get_initializer() + return init_zeros_ + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 8d092b6dc..332d3109f 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,27 +1,24 @@ import logging -import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: 
MLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -35,23 +32,10 @@ class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - Assert.gt(config.num_experts, 1) + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) # TODO: Implement? - assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, block_index) - self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - self._num_experts = config.num_experts - self._experts_per_token = config.num_experts_per_token - self._num_shared_experts = config.num_shared_experts - self._num_unshared_experts = config.num_unshared_experts - - self._routing_type = config.expert_routing_type - self._load_balancing_factor = config.expert_auxiliary_loss_coefficient - self._z_loss_factor = config.expert_z_loss_coefficient - self._moe_jitter_eps = config.moe_jitter_eps + assert not self._config.add_linear_biases, "Biases not supported for MoE." 
layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) @@ -72,21 +56,20 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = " ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug_mode: - self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) + if self._debug.enabled: + self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) # Apply z_loss if applicable - if self._z_loss_factor > 0.0: + if self._config.expert_z_loss_coefficient > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.expert_z_loss_coefficient, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -94,24 +77,31 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._moe_jitter_eps > 0.0: + if self.training and self._config.moe_jitter_eps > 0.0: with set_generator(self._tensor_space.distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._routing_type == RoutingType.topk: + if self._config.expert_routing_type == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._num_shared_experts > 0: + if self._config.num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._routing_type == RoutingType.sinkhorn: + elif self._config.expert_routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._routing_type) + 
raise NotImplementedError(self._config.expert_routing_type) - if self._debug_mode: + if self._debug.enabled: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) - self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) + self._debug( + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + ) + self._debug( + top_experts, + "Router top experts", + kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs, + ) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -119,7 +109,9 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. - sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map( + top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape + ) # Sparse MLP return mlp_autograd( @@ -148,7 +140,7 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._num_experts, + self._config.num_experts, self._config.gated, self._config.activation_type, self._intermediate_dim.parallel_group, @@ -159,7 +151,9 @@ def _forward_looped( @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) + return logits * torch.empty_like(logits).uniform_( + 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps + ) def _topk_routing( self, @@ -167,11 +161,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + 
top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -182,7 +176,9 @@ def _topk_routing( losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale + scores, + aux_loss, + self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, ) return scores, top_experts @@ -191,69 +187,33 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype + self._config.num_unshared_experts, + self._config.num_experts, + device=top_experts.device, + dtype=top_experts.dtype, )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
- scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._experts_per_token == 1 + if self._config.num_experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) - def _debug_log( - self, - tensor: torch.Tensor | None, - name: str, - dim_name: str, - kwargs: dict[str, typing.Any], - global_: bool = True, - ) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - meta = self._get_meta(tensor, name, dim_name, kwargs) - log_distributed_tensor( - "", - tensor.view_as(meta), - level=self._config.debug_transformer, - meta=meta, - distributed=self._tensor_space.distributed, - global_=global_, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), - distributed=self._tensor_space.distributed, - grad_fn=lambda tensor_: tensor_.view_as(meta), - global_=global_, - ) - 
- def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: - return TensorMeta.from_dims( - kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 19349671e..aba5639b5 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,75 +2,77 @@ import torch -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import get_lr_scale -class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__(config) - self._name = name - self._block_index = block_index +class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): + _name: typing.ClassVar[str] = "mlp" + + def 
__init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) init_method_1 = init_normal_( - std=config.init_method_std_mlp_1, - min_val=config.init_method_min_mlp_1, - max_val=config.init_method_max_mlp_1, + std=self._config.init_method_std_mlp_1, + min_val=self._config.init_method_min_mlp_1, + max_val=self._config.init_method_max_mlp_1, ) init_method_2 = init_normal_( - std=config.init_method_std_mlp_2, - min_val=config.init_method_min_mlp_2, - max_val=config.init_method_max_mlp_2, + std=self._config.init_method_std_mlp_2, + min_val=self._config.init_method_min_mlp_2, + max_val=self._config.init_method_max_mlp_2, ) - hidden_dim = tensor_space[BlockDimNames.hidden] - self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel + hidden_dim = self._tensor_space[BlockDimNames.hidden] + self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + layer_lr_scale = ( + self._config.block.block_sequence.per_layer_lr_scale[self._block_index] + if self._config.block.block_sequence.per_layer_lr_scale + else None + ) + lr_scale = ( + tuple(self._config.mlp_lr_scale) + if isinstance(self._config.mlp_lr_scale, list) + else self._config.mlp_lr_scale + ) lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=config.add_mlp_bias, + self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + 
bias=self._config.add_bias, weight_init_method=init_method_1, - bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, + bias_init_method=init_method_1 if self._config.random_bias_init else init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=config.add_mlp_bias, + bias=self._config.add_bias, weight_init_method=init_method_2, - bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, - auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, + bias_init_method=init_method_2 if self._config.random_bias_init else init_zeros_, + auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) # PEFT. - self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) - + self.layer_1 = self._config.block.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._config.block.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, block_index) +class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9d5ce3f3b..2f45fdf9f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -87,7 +87,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_centered_ + from fast_llm.engine.config_utils.initialization 
import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 7249ef569..740b4847c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -14,7 +15,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta, init_zeros_ +from fast_llm.tensor import ParameterMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index bccc1d627..d44be3297 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,11 +1,12 @@ import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ +from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert try: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b667e5318..2e7d71963 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,13 +1,11 @@ -import typing +import functools from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from 
fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs, BlockSequenceConfig from fast_llm.utils import Assert @@ -46,27 +44,27 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelBaseConfig(BaseModelConfig): - # TODO: block - transformer: TransformerConfig = Field( - desc="Configuration for the transformer architecture.", +class LanguageModelConfig(BlockSequenceConfig): + decoder: BlockSequenceConfig = Field( hint=FieldHint.architecture, ) - max_position_embeddings: int = Field( - default=2048, - desc="Number of absolute position embeddings, if applicable.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - use_position_embeddings: bool = Field( + embedding_dropout: float = Field( + # TODO: backward compatibility? + default=0.0, + desc="Dropout applied to the embedding layer.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + absolute_position_embeddings: int | None = Field( + # TODO: backward compatibility? default=None, - desc="Enable absolute position embeddings. 
Default: Enable unless using rotary embeddings.", + desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -80,22 +78,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - init_method_std_embed: float = Field( - default=None, - desc="Initialization scale for the vocabulary embedding and output weights (logits).", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - init_method_max_embed: float | None = Field( - default=None, - desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) - init_method_min_embed: float | None = Field( - default=None, - desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) enable_dpo: bool | None = Field( default=False, desc="Whether to enable DPO loss", @@ -203,26 +185,27 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) + word_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for word embeddings. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + position_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for position embeddings. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + output_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for untied output weights. 
Default: hidden_size**-0.5",
+        hint=FieldHint.feature,
+    )
 
     def _validate(self) -> None:
-        self.transformer.validate()
         with self._set_implicit_default():
             if self.language_model_loss_factor is None:
                 if self.distillation_model is None:
                     self.language_model_loss_factor = 1.0
                 else:
                     self.language_model_loss_factor = 0.0
-            if self.use_position_embeddings is None:
-                self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig)
-            if self.init_method_std_embed is None:
-                self.init_method_std_embed = self.transformer.init_method_std
-            if self.init_method_max_embed is None:
-                self.init_method_max_embed = self.transformer.init_method_max
-            if self.init_method_min_embed is None:
-                self.init_method_min_embed = self.transformer.init_method_min
         super()._validate()
-        if self.init_method_max_embed is not None and self.init_method_min_embed is not None:
-            Assert.leq(self.init_method_min_embed, self.init_method_max_embed)
         if self.distillation_model is not None:
             if self.prediction_heads > 1:
                 raise NotImplementedError("Multi-token prediction not supported with distillation.")
@@ -230,43 +213,40 @@ def _validate(self) -> None:
             Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
             for coeff in self.prediction_loss_coefficient:
                 Assert.geq(coeff, 0)
-        if self.transformer.per_layer_lr_scale is not None:
-            # -1 because the first prediction head's transformer layer is accounted for in num_layers
-            # +1 because the layer index starts at 1
-            Assert.eq(
-                len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1
-            )
+
+        if self.position_embedding_weight_initialization.has_initialization:
+            assert self.use_absolute_position_embeddings
+        if self.output_weight_initialization.has_initialization:
+            assert not self.tie_word_embeddings
 
     def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
-        self.transformer.setup_tensor_space(tensor_space)
+        super().setup_tensor_space(tensor_space)
         tensor = 
tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) + if self.use_absolute_position_embeddings: + tensor_space.add_tensor_dim( + TensorDim(LanguageModelDimNames.position_embed, self.absolute_position_embeddings) + ) # TODO: Need both? tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @property - def num_absolute_position_embeddings(self) -> int: - # TODO: Rename from max embeddings. - return self.max_position_embeddings if self.use_absolute_position_embeddings else None + @functools.cached_property + def word_embedding_weight_initialization_method(self) -> Initializer: + if self.word_embedding_weight_initialization.has_initialization: + return self.word_embedding_weight_initialization.get_initializer() + else: + return self.hidden_size**-0.5 @property def use_absolute_position_embeddings(self) -> int: # TODO: Set through num embeddings instead instead. - return self.use_position_embeddings - - @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # The backward compatibility fix in `NormalizationArchitectureConfig` - # won't work for older checkpoints saved with a flat config. 
- # TODO v0.3: Remove flat format - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - return super().from_flat_dict(default, strict) + return self.absolute_position_embeddings is not None + + @functools.cached_property + def output_weight_initialization_method(self) -> Initializer: + if self.output_weight_initialization.has_initialization: + return self.output_weight_initialization.get_initializer() + else: + return self.hidden_size**-0.5 diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 05678a700..b49fef7ba 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,28 +7,28 @@ from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), together with optional absolute position embeddings and dropout. 
""" - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig # Ensure the layer is on its own stage. layer_count: float = 1000.0 def __init__( self, - config: LanguageModelBaseConfig, + config: LanguageModelConfig, tensor_space: TensorSpace, ): super().__init__(config) @@ -36,14 +36,14 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if self._config.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout - self._use_absolute_position_embeddings = config.use_absolute_position_embeddings + self._parallel_embeddings = ( + tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + ) hidden_dim = tensor_space[LanguageModelDimNames.hidden] vocab_dim = tensor_space[ @@ -56,23 +56,15 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - lr_scale=config.embeddings_lr_scale, + init_method=self._config.word_embedding_weight_initialization_method, + lr_scale=self._config.embeddings_lr_scale, ) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - 
min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=config.embeddings_lr_scale, + init_method=self._config.position_embedding_weight_initialization_method, + allow_sequence_tensor_parallel=not self._config.parallel_embeddings, + lr_scale=self._config.embeddings_lr_scale, ) # PEFT. @@ -84,21 +76,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -107,7 +99,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = 
embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) @@ -116,7 +108,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + embeddings = torch.dropout(embeddings, self._config.embedding_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bc672725c..098b2463b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,16 +15,16 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( - LanguageModelBaseConfig, + LanguageModelConfig, LanguageModelDimNames, LanguageModelKwargs, LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.logging import log_distributed_tensor -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) @@ -32,61 +32,67 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelHead[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): """ A language model head (GPT), which combines 
the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig def __init__( self, - config: LanguageModelBaseConfig, + config: LanguageModelConfig, tensor_space: TensorSpace, prediction_distance: int, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer - self._tie_word_embeddings = config.tie_word_embeddings + # TODO: Avoid default_block_config? + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + self._config.default_block_config.debug_transformer, + self._config.default_block_config.debug_transformer_memory, + ) self._tensor_space = tensor_space self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = ( + tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + ) self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings + tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings ) - self._cross_entropy_splits = config.cross_entropy_splits + self._cross_entropy_splits = self._config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( - config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + self._config.prediction_loss_coefficient[prediction_distance] + if self._config.prediction_loss_coefficient + else 1.0 ) self._loss_name = 
LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = config.logits_scale_factor - self._language_model_loss_factor = config.language_model_loss_factor - self._distillation_loss_factor = config.distillation_loss_factor - self._z_loss_factor = config.logit_z_loss + # TODO: Avoid default_block_config? + self.final_norm = self._config.default_block_config.normalization.get_layer(hidden_dim) + self._logits_scale_factor = self._config.logits_scale_factor + self._language_model_loss_factor = self._config.language_model_loss_factor + self._distillation_loss_factor = self._config.distillation_loss_factor + self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - self._init_output_weights(hidden_dim, config) + self._init_output_weights(hidden_dim, self._config) - self._use_dpo_loss = config.enable_dpo - if self._use_dpo_loss: - self.dpo_beta = config.dpo_beta - else: - self._cross_entropy_impl = config.cross_entropy_impl - self._distillation_loss_implementation = config.distillation_loss_implementation + if not self._config.enable_dpo: + self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: self._cross_entropy_impl = CrossEntropyImpl.fused @@ -104,7 +110,7 @@ def __init__( def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: # Only the first head defines the output weights - if self._tie_word_embeddings or self._prediction_distance > 0: + if self._config.tie_word_embeddings or self._prediction_distance > 0: 
return # untie embedding weights vocab_dim = self._tensor_space[ @@ -112,11 +118,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), + init_method=self._config.output_weight_initialization_method, lr_scale=config.output_lr_scale, ) @@ -201,7 +203,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) - if self._use_dpo_loss: + if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -251,7 +253,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._tie_word_embeddings: + if self._config.tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -338,35 +340,22 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._logits_scale_factor, ) - if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space[ + if self._debug.enabled and self._cross_entropy_splits is None: + vocab_dim = ( LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ] - dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) - - dim_names = ( - 
[LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + sequence_dim = ( + LanguageModelDimNames.sequence_q_tp if self._sequence_parallel_logits - else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + else LanguageModelDimNames.sequence_q ) - - dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) - log_distributed_tensor( - "", - logits, - level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), - distributed=self._tensor_space.distributed, - scale=self._logits_scale_factor, + batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] + dims = ( + (sequence_dim, batch_dim, vocab_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (batch_dim, sequence_dim, vocab_dim) ) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) if targets is None: return logits * self._logits_scale_factor, None @@ -379,7 +368,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, + self._config.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -401,7 +390,7 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -414,7 +403,7 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._distillation_loss_implementation == 
DistillationLossImpl.cross_entropy: + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index f5d915855..3c9f18c8d 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -20,11 +20,11 @@ class PositionEmbeddingPreprocessor(Preprocessor): def __init__( self, - config: LanguageModelBaseConfig, + config: LanguageModelConfig, tensor_space: TensorSpace, ): self._config = config - assert config.use_absolute_position_embeddings + assert config.absolute_position_embeddings is not None self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] @@ -34,7 +34,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) + Assert.leq(sequence_length, self._config.absolute_position_embeddings) self._position_ids = torch.arange( 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) @@ -71,7 +71,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, 
tensor_space: TensorSpace): + def __init__(self, config: LanguageModelConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index efcf2d873..00c709814 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.tensor import Initializer + from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ class SSMDimNames(BlockDimNames): @@ -66,8 +66,6 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": - from fast_llm.tensor import init_fill_, init_uniform_centered_ - return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 550c44d0f..04b27af47 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,13 +4,15 @@ import einops import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.layers.ssm.mamba_layer import init_kaiming_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import get_lr_scale logger = 
logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1c319f490..b02fbd401 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,14 +3,15 @@ import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale try: diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index f5b0139cf..e22852fe6 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,13 +4,14 @@ import torch +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, get_lr_scale try: @@ -163,3 +164,7 @@ def 
forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b1de792e3..2db7b0ac8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,11 +6,10 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs from fast_llm.utils import get_lr_scale try: @@ -46,55 +45,52 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(Mixer): +class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): """ A self-attention layer. 
""" - _mixer_name: typing.ClassVar[str] = "attn" - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_heads, + AttentionDimNames.kv_channels, ) _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.head_groups, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.head_groups, + AttentionDimNames.kv_channels, ) _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_dense, ) - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): - super().__init__(tensor_space, block_index, config.debug_transformer) + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - init_method_qkv = init_normal_( - std=self._config.init_method_std_qkv, - min_val=self._config.init_method_min_qkv, - max_val=self._config.init_method_max_qkv, - ) - init_method_std_attn_proj = init_normal_( - std=self._config.init_method_std_attn_proj, - min_val=self._config.init_method_min_attn_proj, - max_val=self._config.init_method_max_attn_proj, - ) - - self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size - self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size + # init_method_qkv = 
init_normal_( + # std=self._config.init_method_std_qkv, + # min_val=self._config.init_method_min_qkv, + # max_val=self._config.init_method_max_qkv, + # ) + # init_method_std_attn_proj = init_normal_( + # std=self._config.init_method_std_attn_proj, + # min_val=self._config.init_method_min_attn_proj, + # max_val=self._config.init_method_max_attn_proj, + # ) + self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size + self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) + self._softmax_scale: float = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -102,19 +98,19 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_query], + self._tensor_space[AttentionDimNames.composite_query], bias=self._config.add_attn_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_key_value], + self._tensor_space[AttentionDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -125,11 +121,11 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. 
self.dense = InputParallelLinear( - self._tensor_space[TransformerDimNames.composite_dense], + self._tensor_space[AttentionDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, - weight_init_method=init_method_std_attn_proj, - bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, + weight_init_method=self._config.dense_weight_initialization_method, + bias_init_method=self._config.dense_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -259,18 +255,24 @@ def _decide_window_size(self) -> int | None: return window_size - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[TransformerKwargs.sequence_first] + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[AttentionKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -279,9 +281,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._tensor_space.distributed.sequence_data_group: key_value = ( - key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] + key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] ) if sequence_first: @@ -295,9 +297,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_level: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug_log( + if self._debug.enabled: + self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug( key, "key_rotary_input", self._KV_DIMS, @@ -310,7 +312,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -320,9 +322,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + 
max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -345,25 +347,15 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], ) - if self._debug_level: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key", - self._KV_DIMS, - kwargs, - ) - self._debug_log( - value, - "value", - self._KV_DIMS, - kwargs, - ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug.enabled: + self._debug(query, "query", self._QUERY_DIMS, kwargs) + self._debug(key, "key", self._KV_DIMS, kwargs) + self._debug(value, "value", self._KV_DIMS, kwargs) + self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index ebb976e63..bd72bd305 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,22 +3,29 @@ import typing import warnings -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import ( + AddLinearBiasChoices, + BlockDimNames, + BlockKwargs, + BlockLayerConfig, + MixerConfig, +) from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - pass + from fast_llm.layers.transformer.attention import Attention logger = logging.getLogger(__name__) -class TransformerDimNames(BlockDimNames): +class AttentionDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. # Self-attention dimensions head_groups = "head_groups" @@ -31,7 +38,7 @@ class TransformerDimNames(BlockDimNames): composite_dense = "composite_dense" -class TransformerKwargs(BlockKwargs): +class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -45,9 +52,8 @@ class TransformerKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class() -class AttentionConfig(Config): - # TODO: Make mixer class dynamic. 
+@config_class(dynamic_type={BlockLayerConfig: "attention"}) +class AttentionConfig(MixerConfig): _abstract = False # TODO: Review names @@ -107,7 +113,30 @@ class AttentionConfig(Config): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + qkv_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer weights. Default: hidden_size**-0.5", + hint=FieldHint.feature, + ) + qkv_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + dense_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer weight. Default: (2 * num_blocks * hidden_size)**-0.5", + hint=FieldHint.feature, + ) + dense_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + def _validate(self) -> None: + + with self._set_implicit_default(): + if self.kv_channels is None: + # TODO: hidden_size not yet validated. 
+ self.kv_channels = div(self.block.block_sequence.hidden_size, self.num_attention_heads) + super()._validate() if not TritonConfig.TRITON_ENABLED: @@ -130,182 +159,74 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + AttentionDimNames.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) + tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, 
group_heads, kv_channels)) ) + def get_block(self) -> "Attention": + pass -@config_class() -# TODO: Use composition for attention config -class TransformerConfig(AttentionConfig, BlockConfig): - _abstract = False - - # TODO: Review names - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. 
Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. 
Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) - - def _validate(self) -> None: - with self._set_implicit_default(): - if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if 
self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - - super()._validate() - - @property - def add_attn_qkv_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: + @functools.cached_property + def add_qkv_bias(self) -> bool: + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + if self.block.add_linear_biases == AddLinearBiasChoices.nowhere: return False return True - @property - def add_attn_dense_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: + @functools.cached_property + def add_dense_bias(self) -> bool: + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False + + @functools.cached_property + def qkv_weight_initialization_method(self) -> Initializer: + if self.qkv_weight_initialization.has_initialization: + return self.qkv_weight_initialization.get_initializer() + else: + return self.block.block_sequence.hidden_size**-0.5 + + @functools.cached_property + def qkv_bias_initialization_method(self) -> Initializer: + if self.qkv_bias_initialization.has_initialization: + assert self.add_qkv_bias + return self.qkv_bias_initialization.get_initializer() + else: + return init_zeros_ + + @functools.cached_property + def dense_weight_initialization_method(self) -> 
Initializer: + if self.dense_weight_initialization.has_initialization: + return self.dense_weight_initialization.get_initializer() + else: + return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1) ** 0.5 + + @functools.cached_property + def dense_bias_initialization_method(self) -> Initializer: + if self.dense_bias_initialization.has_initialization: + assert self.add_dense_bias + return self.dense_bias_initialization.get_initializer() + else: + return init_zeros_ diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 3f0e14eb7..16e5811e6 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class BackupAttentionPreprocessor(Preprocessor): def __init__( self, - config: TransformerConfig, + config: AttentionConfig, tensor_space: TensorSpace, ): self._config = config @@ -51,13 +51,13 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._mask[ None, None,
sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -65,33 +65,33 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[AttentionKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, - kwargs[TransformerKwargs.sequence_k_dim], + kwargs[AttentionKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=AttentionKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, tensor_space: 
TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -107,12 +107,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if TransformerKwargs.sequence_lengths not in kwargs: + if AttentionKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if sequence_q < kwargs[TransformerKwargs.sequence_length]: + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets @@ -146,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index c357411b6..9f8732f85 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig from fast_llm.tensor import TensorMeta @@ -26,34 +26,34 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] 
+ self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 17b18a1ca..ebb629aa1 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ 
b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -83,44 +83,44 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - 
kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 3c0ad8ab4..eb24ef183 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -36,7 +36,7 @@ def get_layers(self) -> list[Layer]: self._tensor_space, block_index=i + 1, ) - for i in range(self._config.transformer.num_layers) + for i in range(self._config.transformer.num_blocks) ], CustomHead(self._config, self._tensor_space), ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0da16428e..a7fcad82d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig 
-from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -119,7 +119,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelBaseConfig): +class GPTBaseModelConfig(LanguageModelConfig): _abstract = False # Debug, to get an exact match with megatron init. @@ -192,15 +192,12 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: - if self.batch.sequence_length is None: - # TODO: Drop this. - self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) distillation_model = self.model.base_model.distillation_model dpo_reference_model = self.model.base_model.dpo_reference_model diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 2dbef77f3..f3e57fe13 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -176,7 +176,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks # Embeddings converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -256,7 +256,7 @@ def _create_transformer_layer_converters( return converters def _create_lm_head_converters(self) -> 
list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -654,7 +654,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index cf7da3872..4e3f258fc 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[TransformerKwargs.past_key_values] = past_key_values + kwargs[AttentionKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. 
- kwargs[TransformerKwargs.presents] = [] + kwargs[AttentionKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[TransformerKwargs.presents],) + outputs += (kwargs[AttentionKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[TransformerKwargs.presents], + past_key_values=kwargs[AttentionKwargs.presents], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index da647de57..30842597d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -16,7 +16,7 @@ from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_blocks + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -93,9 +93,9 @@ def get_layers(self) -> list[Layer]: block_index=i + 1, # The last layer only returns the transformer output. 
# The previous layers return a stack of shared_hidden and transformer_output. - return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_blocks - 1, ) - for i in range(self._config.transformer.num_layers) + for i in range(self._config.transformer.num_blocks) ], *self.get_output_layers(), ] @@ -119,7 +119,7 @@ def preprocess_meta( truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -128,13 +128,13 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? sequence_q_dim = TensorDim( - TransformerDimNames.sequence_q, + AttentionDimNames.sequence_q, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - TransformerDimNames.sequence_q_tp, + AttentionDimNames.sequence_q_tp, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim( DistributedDimNames.tensor_and_sequence_data @@ -151,7 +151,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first @@ -160,10 +160,10 @@ def preprocess_meta( common_kwargs = { LanguageModelKwargs.phase: phase, - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.hidden_dims: hidden_dims, - TransformerKwargs.sequence_length: sequence_length, - 
TransformerKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.hidden_dims: hidden_dims, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -182,7 +182,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -190,7 +190,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - TransformerKwargs.sequence_k_dim: sequence_k_dim, + AttentionKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -202,10 +202,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - TransformerKwargs.sequence_first, - TransformerKwargs.sequence_length, - TransformerKwargs.sequence_q_dim, - TransformerKwargs.sequence_k_dim, + AttentionKwargs.sequence_first, + AttentionKwargs.sequence_length, + AttentionKwargs.sequence_q_dim, + AttentionKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -231,8 +231,8 @@ def preprocess( preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size - sequence_first = common_kwargs[TransformerKwargs.sequence_first] + sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size + sequence_first = common_kwargs[AttentionKwargs.sequence_first] 
prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( @@ -264,14 +264,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -283,8 +283,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - TransformerKwargs.past_key_values: pasts, - TransformerKwargs.presents: presents, + AttentionKwargs.past_key_values: pasts, + AttentionKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels @@ -372,7 +372,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.transformer.expert_z_loss_coefficient: @@ -380,7 +380,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.logit_z_loss: @@ -421,7 +421,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s consumed_tokens_per_iteration = sequence_length * batch_size - 
num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + num_transformer_layers = transformer_config.num_blocks + self._config.base_model.prediction_heads - 1 transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9427f69be..a351522ca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -62,13 +62,13 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_blocks - if len(self.hybrid_block_layout) != self.transformer.num_layers: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: + if len(self.hybrid_block_layout) != self.transformer.num_blocks: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_blocks}" + if self.transformer.num_blocks % len(self.hybrid_block_layout) != 0: raise ValueError(message) - num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + num_repeats = self.transformer.num_blocks // len(self.hybrid_block_layout) logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 43e3c67e5..fb24c1aec 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -219,7 +219,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> 
list[WeightConverter]: converters = super()._create_weight_converters() or [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear for i in range(num_layers): @@ -383,7 +383,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: # not using super() because LLamba model is called backbone in the checkpoints converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -572,7 +572,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False # Embedding and output diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d080e6a1e..b12d12072 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,13 +1,12 @@ -import abc import functools import logging -import math import typing import torch from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op +from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -361,70 +360,3 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa - - 
-class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) - - -def init_uniform_( - low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.uniform_(low, high, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(high: float, max_val: float | None = None, 
mean: float = 0.0) -> LambdaInitializer: - return init_uniform_( - mean - high, - mean + high, - min_val=None if max_val is None else mean - max_val, - max_val=None if max_val is None else mean + max_val, - ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..8c33aed4d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.grad_output: 1.0, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index 7f0b902f8..cb9c69ccb 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -354,7 +354,7 @@ def _test_forward_return_hidden_states( # hidden_states include embeddings layer assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers + len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_blocks ) diff --git a/tests/test_attention.py b/tests/test_attention.py index dd36b840a..534e3800e 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import 
DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -77,13 +77,13 @@ def test_varlen_preprocessor(): varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), - TransformerKwargs.sequence_k_dim: TensorDim( - TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_k_dim: TensorDim( + AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_lengths: sequence_lengths, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_lengths: sequence_lengths, } varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 694faa55b..6c4c7f0cb 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -9,7 +9,7 @@ from fast_llm.engine.schedule.config import 
ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -71,8 +71,8 @@ def test_load_from_llamba_checkpoint(): schedule_runner.setup(model.distributed, optimizer=None) common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, + AttentionKwargs.sequence_first: True, + AttentionKwargs.grad_output: False, } input_data = [(x, common_kwargs)] diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 722d8d63a..4705ebb79 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -162,6 +162,7 @@ def _update_and_add_testing_config( "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.use_position_embeddings=True", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -258,6 +259,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", + "model.base_model.transformer.use_position_embeddings=False", # Unused, but prevents issues with conversion tests. "model.base_model.max_position_embeddings=2048", ], From ab484ac94555915bd2d808279d5909de42541550 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 15:29:57 -0400 Subject: [PATCH 05/28] Revert "stuff" This reverts commit a5eb0767e99038e18c1bd07f7f78718634296c4c. 
--- docs/developer_guide/conversion.md | 30 +-- .../engine/config_utils/initialization.py | 178 ------------ fast_llm/layers/block/block.py | 132 ++------- fast_llm/layers/block/config.py | 160 ++--------- fast_llm/layers/block/mixer.py | 68 +++++ fast_llm/layers/block/mlp/config.py | 79 +----- .../layers/block/mlp/mixture_of_experts.py | 134 ++++++---- fast_llm/layers/block/mlp/mlp.py | 72 +++-- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 3 +- fast_llm/layers/common/normalization.py | 3 +- fast_llm/layers/language_model/config.py | 122 +++++---- fast_llm/layers/language_model/embedding.py | 48 ++-- fast_llm/layers/language_model/head.py | 109 ++++---- .../layers/language_model/preprocessing.py | 10 +- fast_llm/layers/ssm/config.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba2.py | 5 +- fast_llm/layers/ssm/mamba_layer.py | 7 +- fast_llm/layers/transformer/attention.py | 142 +++++----- fast_llm/layers/transformer/config.py | 253 ++++++++++++------ fast_llm/layers/transformer/preprocessing.py | 52 ++-- .../transformer/rotary/preprocessing.py | 26 +- fast_llm/layers/transformer/rotary/rotary.py | 30 +-- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/config.py | 9 +- fast_llm/models/gpt/conversion.py | 6 +- fast_llm/models/gpt/huggingface.py | 10 +- fast_llm/models/gpt/model.py | 54 ++-- fast_llm/models/ssm/config.py | 10 +- fast_llm/models/ssm/conversion.py | 6 +- fast_llm/tensor.py | 70 ++++- tests/layers/test_lm_head.py | 6 +- tests/models/test_generate.py | 2 +- tests/test_attention.py | 16 +- tests/test_ssms.py | 6 +- tests/utils/model_configs.py | 2 - 37 files changed, 857 insertions(+), 1015 deletions(-) delete mode 100644 fast_llm/engine/config_utils/initialization.py create mode 100644 fast_llm/layers/block/mixer.py diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 719757df1..0620beaea 100644 --- a/docs/developer_guide/conversion.md +++ 
b/docs/developer_guide/conversion.md @@ -230,23 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - - converters = [] -# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. -num_layers = self._model.config.base_model.transformer.num_blocks - -# A simple renaming example, for the word embeddings. -converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - -# We usually want to loop dynamically over layers -for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) -return converters + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_layers + + # A simple renaming example, for the word embeddings. + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. + converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. 
diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py deleted file mode 100644 index d35c2220c..000000000 --- a/fast_llm/engine/config_utils/initialization.py +++ /dev/null @@ -1,178 +0,0 @@ -import abc -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.tensor import ParameterMeta - - -@config_class(registry=True) -class InitializationConfig(Config): - _abstract = True - has_initialization: typing.ClassVar[bool] = True - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return DefaultInitializationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - def get_initializer(self) -> "Initializer": - raise NotImplementedError() - - -@config_class(dynamic_type={InitializationConfig: "default"}) -class DefaultInitializationConfig(InitializationConfig): - # A placeholder indicating that the class default should be used instead. 
- _abstract = False - has_initialization = False - - -@config_class(dynamic_type={InitializationConfig: "fill"}) -class NormalInitializationConfig(InitializationConfig): - """ - Normal initialization: normal(mean, std).clamp(min,max) - """ - - _abstract = False - - value: float = Field( - default=1, - desc="Initialization value.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - - def get_initializer(self): - return init_fill_(self.value) - - -@config_class(dynamic_type={InitializationConfig: "zeros"}) -class ZeroInitializationConfig(InitializationConfig): - def get_initializer(self): - return init_zeros_ - - -@config_class(dynamic_type={InitializationConfig: "ones"}) -class ZeroInitializationConfig(InitializationConfig): - def get_initializer(self): - return init_ones_ - - -@config_class(dynamic_type={InitializationConfig: "normal"}) -class NormalInitializationConfig(InitializationConfig): - """ - Normal initialization: normal(mean, std).clamp(min,max) - """ - - _abstract = False - - std: float = Field( - default=1, - desc="Standard deviation for normal initialization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - mean: float = Field( - default=0, - desc="Mean for normal initialization.", - hint=FieldHint.optional, - ) - min: float | None = Field( - default=None, - desc="Min value for initialization clamping.", - hint=FieldHint.optional, - ) - max: float | None = Field( - default=None, - desc="Min value for initialization clamping.", - hint=FieldHint.optional, - ) - - def get_initializer(self): - return init_normal_(self.mean, self.std, self.min, self.max) - - -@config_class(dynamic_type={InitializationConfig: "uniform"}) -class UniformInitializationConfig(InitializationConfig): - """ - Uniform initialization: uniform(mean - scale, mean + scale).clamp(min,max) - """ - - _abstract = False - - scale: float = Field( - default=None, - desc="Initialization scale.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 
0), - ) - mean: float = Field( - default=None, - desc="Initialization mean.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - - def get_initializer(self) -> "Initializer": - return init_uniform_centered_(self.scale, self.mean) - - -class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer: - def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa - tensor.uniform_(mean - scale, mean + scale, generator=generator) - - return LambdaInitializer(init_) diff --git 
a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index d13b09807..85da61c01 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,5 +1,4 @@ import abc -import functools import typing import torch @@ -9,118 +8,23 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, BlockLayerConfig +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta -class DebugLayer: - # TODO: Move elsewhere? 
- def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): - self._tensor_space = tensor_space - self._name = name - self._debug_level = debug_level - self._debug_memory = debug_memory - - def _get_meta( - self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - ( - dim - if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] - ) - for dim in dims - ), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - - @functools.cached_property - def enabled(self) -> bool: - return self._debug_level > 0 or self._debug_memory - - def __call__( - self, - tensor: torch.Tensor, - name: str, - dims: tuple[TensorDim | str, ...], - kwargs: dict[str, typing.Any], - scale: float = 1.0, - global_: bool = True, - log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, - ) -> None: - # TODO: Local vs global? - if self._debug_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._debug_level > 0: - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dims, kwargs), - distributed=self._tensor_space.distributed, - global_=global_, - log_fn=log_fn, - scale=scale, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dims, kwargs), - distributed=self._tensor_space.distributed, - global_=global_, - log_fn=log_fn, - scale=scale, - ) - - -class BlockLayer[ConfigType: BlockLayerConfig](Configurable[ConfigType], torch.nn.Module): - """ - Base class for mixer and MLP modules. 
- """ - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config) - self._tensor_space = tensor_space - self._block_index = block_index - self._name = name - self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug = DebugLayer( - tensor_space, - f"Block {self._block_index} {self._name}", - self.config.block.debug_transformer, - self._config.block.debug_transformer_memory, - ) - - @abc.abstractmethod - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - pass - - -class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): +class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): """ A transformer-like decoder base block with abstract mixer. """ # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__( - self, config: ConfigType, tensor_space: TensorSpace, block_index: int = 0, return_input: bool = False - ): + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__() self._config = config self._tensor_space: TensorSpace = tensor_space @@ -136,19 +40,21 @@ def __init__( self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. - setattr( - self, - self._config.mixer.module_name, - self._config.mixer.get_layer(self._tensor_space, block_index, f"{self.name} mixer"), - ) + # The mixer needs to be created here for backward-compatible weight ordering. 
+ setattr(self, self._mixer_module_name, self._create_mixer()) - self.mlp = self._config.mlp.get_layer(self._tensor_space, block_index, f"{self.name} mlp") + self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index + ) # PEFT. self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) + @abc.abstractmethod + def _create_mixer(self) -> Mixer: + pass + @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor @@ -207,13 +113,13 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: - self._debug_log(hidden_states, f"{self._config.mixer.module_name} output", kwargs, bias=bias) + self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug_mode: - self._debug_log(input_, f"{self._config.mixer.module_name} residual", kwargs) + self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) hidden_states = self.norm_2(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 2", kwargs) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 87bd6d249..5a999fa6d 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,21 +1,13 @@ -import abc import enum -import functools -import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from 
fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import Block, BlockLayer - - -# TODO: Generalize these beyond language models? (Ex. vision) - class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -47,76 +39,10 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" -@config_class(registry=True) -class BlockLayerConfig(BaseModelConfig): - _abstract = True - block: "BlockConfig" = Field(init=False) - - def _validate(self) -> None: - assert hasattr(self, "block") - Assert.is_(self.block.mlp, self) - super()._validate() - - @property - def layer_class(self) -> "type[BlockLayer]": - raise NotImplementedError() - - def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer": - return self.layer_class(self, tensor_space, block_index, name) - - -@config_class() -class MixerConfig(BlockLayerConfig): - _abstract = True - - # Needed for backward compatibility. - module_name: typing.ClassVar[str] = "mixer" - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.transformer.config import AttentionConfig - - # Default subclass. 
- return AttentionConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - @config_class() -class MLPBaseConfig(BlockLayerConfig): - _abstract = True - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.block.mlp.config import MLPConfig +# TODO: Use composition for MLP config +class BlockConfig(MLPConfig, BaseModelConfig): - # Default subclass. - return MLPConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class() -class BlockConfig(BaseModelConfig): - _abstract = False - mixer: MixerConfig = Field( - desc="Configuration for the mixer.", - hint=FieldHint.architecture, - ) - mlp: MLPBaseConfig = Field( - desc="Configuration for the MLP.", - hint=FieldHint.architecture, - ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", @@ -132,6 +58,11 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) debug_transformer: int = Field( default=0, desc="Log the output of each operation in a transformer layer.", @@ -149,45 +80,8 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.architecture, ) - block_sequence: "BlockSequenceConfig" = Field(init=False) - - def _validate(self) -> None: - assert hasattr(self, "block_sequence") - Assert.incl(self, self.block_sequence.blocks.values()) - self.mixer.block = self - self.mlp.block = self - super()._validate() - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.mlp.setup_tensor_space(tensor_space) - 
self.mixer.setup_tensor_space(tensor_space) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.block_sequence.hidden_size)) - - @abc.abstractmethod - def get_block(self) -> "Block": - pass - - -@config_class() -class BlockSequenceConfig(BaseModelConfig): - _abstract = True - - blocks: dict[str, BlockConfig] = Field() - block_pattern: tuple[str, ...] = Field( - default=None, - desc="The pattern of blocks (referred by name) to use. The sequence is repeated until reaching `num_blocks`." - " Default: cycle over `blocks` in the order they are defined.", - ) - default_block: str = Field( - default=None, - desc="The default block configuration to use when referring to the model." - " Used to set some defaults in the language model.", - ) - # TODO: Move these, not specific to a single block. - num_blocks: int = Field( + num_layers: int = Field( default=12, desc="Number of layers in the transformer.", hint=FieldHint.architecture, @@ -199,28 +93,30 @@ class BlockSequenceConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, ) def _validate(self) -> None: - for block in self.blocks.values(): - block.validate() - if self.block_pattern is None: - self.block_pattern = tuple(self.blocks) - if self.default_block is None: - self.default_block = self.block_pattern[0] + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + super()._validate() - def get_block_config(self, block_index: int) -> BlockConfig: - return self.blocks[self.block_pattern[block_index % 
len(self.block_pattern)]] + @property + def add_mlp_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - for block in self.blocks.values(): - block.setup_tensor_space(tensor_space) + super().setup_tensor_space(tensor_space) - @functools.cached_property - def default_block_config(self) -> BlockConfig: - return self.blocks[self.default_block] + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py new file mode 100644 index 000000000..5c811e330 --- /dev/null +++ b/fast_llm/layers/block/mixer.py @@ -0,0 +1,68 @@ +import abc +import typing + +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.logging import log_distributed_grad, log_distributed_tensor +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert + + +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. 
+ """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 526c513db..1d125c4f7 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,18 +1,11 @@ import enum -import functools -import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - from fast_llm.layers.block.config 
import AddLinearBiasChoices, BlockLayerConfig - from fast_llm.layers.block.mlp.mlp import MLPBase - class MLPDimNames: # MLP dimensions @@ -39,10 +32,9 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class(dynamic_type={BlockLayerConfig: "mlp"}) -class MLPConfig(BlockLayerConfig): +@config_class() +class MLPConfig(Config): # TODO: Review names - # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( default=None, @@ -132,52 +124,11 @@ class MLPConfig(BlockLayerConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) - layer_1_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the first mlp layer weights. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - layer_1_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - layer_2_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the second mlp layer weights." - " Default: (2 * num_blocks * hidden_size)**-0.5", - hint=FieldHint.feature, - ) - layer_2_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the second mlp layer biases. 
Default: fill with zeros.", - hint=FieldHint.feature, - ) - - @property - def layer_class(self) -> "type[MLPBase]": - if self.num_experts > 1: - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - - return MixtureOfExpertMLP - else: - from fast_llm.layers.block.mlp.mlp import MLP - - return MLP - - @property - def add_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False def _validate(self) -> None: - assert hasattr(self, "block") - with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - if self.ffn_hidden_size is None: - # TODO: hidden_size not yet validated. - self.ffn_hidden_size = 4 * self.block.block_sequence.hidden_size self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() @@ -193,30 +144,6 @@ def _validate(self) -> None: elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) - @functools.cached_property - def layer_1_weight_initialization_method(self) -> Initializer: - if not self.layer_1_weight_initialization.has_initialization: - return self.layer_1_weight_initialization.get_initializer() - return self.block.block_sequence.hidden_size**-0.5 - - @functools.cached_property - def layer_1_bias_initialization_method(self) -> Initializer: - if not self.layer_1_bias_initialization.has_initialization: - return self.layer_1_bias_initialization.get_initializer() - return init_zeros_ - - @functools.cached_property - def layer_2_weight_initialization_method(self) -> Initializer: - if self.layer_2_weight_initialization.has_initialization: - return self.layer_2_weight_initialization.get_initializer() - return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1) - - @functools.cached_property - def 
layer_2_bias_initialization_method(self) -> Initializer: - if self.layer_2_bias_initialization.has_initialization: - return self.layer_2_bias_initialization.get_initializer() - return init_zeros_ - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 332d3109f..8d092b6dc 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,24 +1,27 @@ import logging +import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.utils import get_lr_scale +from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage +from fast_llm.tensor import TensorMeta, init_normal_ +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: 
BlockConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -32,10 +35,23 @@ class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + Assert.gt(config.num_experts, 1) # TODO: Implement? - assert not self._config.add_linear_biases, "Biases not supported for MoE." + assert not config.add_linear_biases, "Biases not supported for MoE." + super().__init__(config, tensor_space, name, block_index) + self._tensor_space = tensor_space + self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory + + self._num_experts = config.num_experts + self._experts_per_token = config.num_experts_per_token + self._num_shared_experts = config.num_shared_experts + self._num_unshared_experts = config.num_unshared_experts + + self._routing_type = config.expert_routing_type + self._load_balancing_factor = config.expert_auxiliary_loss_coefficient + self._z_loss_factor = config.expert_z_loss_coefficient + self._moe_jitter_eps = config.moe_jitter_eps layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) @@ -56,20 +72,21 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped + self._dynamic_shape = config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = 
input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug.enabled: - self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) + if self._debug_mode: + self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) # Apply z_loss if applicable - if self._config.expert_z_loss_coefficient > 0.0: + if self._z_loss_factor > 0.0: logits = z_loss( logits, - self._config.expert_z_loss_coefficient, + self._z_loss_factor, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -77,31 +94,24 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._config.moe_jitter_eps > 0.0: + if self.training and self._moe_jitter_eps > 0.0: with set_generator(self._tensor_space.distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._config.expert_routing_type == RoutingType.topk: + if self._routing_type == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._config.num_shared_experts > 0: + if self._num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._config.expert_routing_type == RoutingType.sinkhorn: + elif self._routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._config.expert_routing_type) + raise NotImplementedError(self._routing_type) - if self._debug.enabled: + if self._debug_mode: # To log all ranks set `global_=False` - self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs - ) - self._debug( - top_experts, - "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), - kwargs, - ) + self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) + self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) return 
self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -109,9 +119,7 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. - sparse_map = get_sparse_map( - top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape - ) + sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) # Sparse MLP return mlp_autograd( @@ -140,7 +148,7 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._config.num_experts, + self._num_experts, self._config.gated, self._config.activation_type, self._intermediate_dim.parallel_group, @@ -151,9 +159,7 @@ def _forward_looped( @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_( - 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps - ) + return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) def _topk_routing( self, @@ -161,11 +167,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. 
# In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -176,9 +182,7 @@ def _topk_routing( losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, - aux_loss, - self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, + scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale ) return scores, top_experts @@ -187,33 +191,69 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._config.num_unshared_experts, - self._config.num_experts, - device=top_experts.device, - dtype=top_experts.dtype, + self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
- scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._config.num_experts_per_token == 1 + if self._experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) + def _debug_log( + self, + tensor: torch.Tensor | None, + name: str, + dim_name: str, + kwargs: dict[str, typing.Any], + global_: bool = True, + ) -> None: + if self._config.debug_transformer_memory: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) + if self._config.debug_transformer and tensor is not None: + # TODO: Local vs global + meta = self._get_meta(tensor, name, dim_name, kwargs) + log_distributed_tensor( + "", + tensor.view_as(meta), + level=self._config.debug_transformer, + meta=meta, + distributed=self._tensor_space.distributed, + global_=global_, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._config.debug_transformer, + meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), + distributed=self._tensor_space.distributed, + grad_fn=lambda tensor_: tensor_.view_as(meta), + global_=global_, + ) + 
+ def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: + return TensorMeta.from_dims( + kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), + tensor_name=f"{self._name} {name}", + dtype=tensor.dtype, + ) + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index aba5639b5..19349671e 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,77 +2,75 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames +from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.mlp.config import MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import get_lr_scale +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): - _name: typing.ClassVar[str] = "mlp" - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) +class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): + def __init__(self, config: BlockConfig, 
tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + super().__init__(config) + self._name = name + self._block_index = block_index init_method_1 = init_normal_( - std=self._config.init_method_std_mlp_1, - min_val=self._config.init_method_min_mlp_1, - max_val=self._config.init_method_max_mlp_1, + std=config.init_method_std_mlp_1, + min_val=config.init_method_min_mlp_1, + max_val=config.init_method_max_mlp_1, ) init_method_2 = init_normal_( - std=self._config.init_method_std_mlp_2, - min_val=self._config.init_method_min_mlp_2, - max_val=self._config.init_method_max_mlp_2, + std=config.init_method_std_mlp_2, + min_val=config.init_method_min_mlp_2, + max_val=config.init_method_max_mlp_2, ) - hidden_dim = self._tensor_space[BlockDimNames.hidden] - self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] + hidden_dim = tensor_space[BlockDimNames.hidden] + self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] + self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = ( - self._config.block.block_sequence.per_layer_lr_scale[self._block_index] - if self._config.block.block_sequence.per_layer_lr_scale - else None - ) - lr_scale = ( - tuple(self._config.mlp_lr_scale) - if isinstance(self._config.mlp_lr_scale, list) - else self._config.mlp_lr_scale - ) + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - self._tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=self._config.add_bias, + tensor_space[MLPDimNames.composite_gated_expert_mlp], + 
bias=config.add_mlp_bias, weight_init_method=init_method_1, - bias_init_method=init_method_1 if self._config.random_bias_init else init_zeros_, + bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=self._config.add_bias, + bias=config.add_mlp_bias, weight_init_method=init_method_2, - bias_init_method=init_method_2 if self._config.random_bias_init else init_zeros_, - auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, + bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, + auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) # PEFT. - self.layer_1 = self._config.block.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = self._config.block.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + Assert.eq(config.num_experts, 1) + super().__init__(config, tensor_space, name, block_index) -class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 2f45fdf9f..9d5ce3f3b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -87,7 +87,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ + from fast_llm.tensor 
import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 740b4847c..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,7 +3,6 @@ import torch -from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -15,7 +14,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import ParameterMeta, init_zeros_ logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index d44be3297..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,12 +1,11 @@ import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient +from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ from fast_llm.utils import Assert try: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 2e7d71963..b667e5318 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,11 +1,13 @@ -import functools +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from 
fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert @@ -44,27 +46,27 @@ class LanguageModelKwargs(BlockKwargs): @config_class() -class LanguageModelConfig(BlockSequenceConfig): - decoder: BlockSequenceConfig = Field( +class LanguageModelBaseConfig(BaseModelConfig): + # TODO: block + transformer: TransformerConfig = Field( + desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + max_position_embeddings: int = Field( + default=2048, + desc="Number of absolute position embeddings, if applicable.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - embedding_dropout: float = Field( - # TODO: backward compatibility? - default=0.0, - desc="Dropout applied to the embedding layer.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - absolute_position_embeddings: int | None = Field( - # TODO: backward compatibility? + use_position_embeddings: bool = Field( default=None, - desc="Number of absolute position embeddings, if applicable.", + desc="Enable absolute position embeddings. 
Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -78,6 +80,22 @@ class LanguageModelConfig(BlockSequenceConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + init_method_std_embed: float = Field( + default=None, + desc="Initialization scale for the vocabulary embedding and output weights (logits).", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + init_method_max_embed: float | None = Field( + default=None, + desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", + hint=FieldHint.feature, + ) + init_method_min_embed: float | None = Field( + default=None, + desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", + hint=FieldHint.feature, + ) enable_dpo: bool | None = Field( default=False, desc="Whether to enable DPO loss", @@ -185,27 +203,26 @@ class LanguageModelConfig(BlockSequenceConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) - word_embedding_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for word embeddings. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - position_embedding_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for position embeddings. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - output_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for untied output weights. 
Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) def _validate(self) -> None: + self.transformer.validate() with self._set_implicit_default(): if self.language_model_loss_factor is None: if self.distillation_model is None: self.language_model_loss_factor = 1.0 else: self.language_model_loss_factor = 0.0 + if self.use_position_embeddings is None: + self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig) + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min super()._validate() + if self.init_method_max_embed is not None and self.init_method_min_embed is not None: + Assert.leq(self.init_method_min_embed, self.init_method_max_embed) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") @@ -213,40 +230,43 @@ def _validate(self) -> None: Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) - - if self.output_weight_initialization.has_initialization: - assert self.use_absolute_position_embeddings - if self.output_weight_initialization.has_initialization: - assert not self.tie_word_embeddings + if self.transformer.per_layer_lr_scale is not None: + # -1 because the first prediction head's transformer layer is accounted for in num_layers + # +1 because the layer index starts at 1 + Assert.eq( + len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 + ) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) + self.transformer.setup_tensor_space(tensor_space) tensor = 
tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions - if self.use_absolute_position_embeddings: - tensor_space.add_tensor_dim( - TensorDim(LanguageModelDimNames.position_embed, self.absolute_position_embeddings) - ) + tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) # TODO: Need both? tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @functools.cached_property - def word_embedding_weight_initialization_method(self) -> Initializer: - if self.word_embedding_weight_initialization.has_initialization: - return self.word_embedding_weight_initialization.get_initializer() - else: - return self.hidden_size**-0.5 + @property + def num_absolute_position_embeddings(self) -> int | None: + # TODO: Rename from max embeddings. + return self.max_position_embeddings if self.use_absolute_position_embeddings else None @property def use_absolute_position_embeddings(self) -> int: # TODO: Set through num embeddings instead instead. - return self.absolute_position_embeddings is not None - - @functools.cached_property - def output_weight_initialization_method(self) -> Initializer: - if self.output_weight_initialization.has_initialization: - return self.output_weight_initialization.get_initializer() - else: - return self.hidden_size**-0.5 + return self.use_position_embeddings + + @classmethod + def from_flat_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + ) -> typing.Self: + # The backward compatibility fix in `NormalizationArchitectureConfig` + # won't work for older checkpoints saved with a flat config. 
+ # TODO v0.3: Remove flat format + cls._handle_renamed_field(default, "normalization_type", "type") + cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") + cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + return super().from_flat_dict(default, strict) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index b49fef7ba..05678a700 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,28 +7,28 @@ from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig # Ensure the layer is on its own stage. 
layer_count: float = 1000.0 def __init__( self, - config: LanguageModelConfig, + config: LanguageModelBaseConfig, tensor_space: TensorSpace, ): super().__init__(config) @@ -36,14 +36,14 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if self._config.full_precision_residual + if config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) + self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._dropout_p = config.transformer.hidden_dropout + self._use_absolute_position_embeddings = config.use_absolute_position_embeddings hidden_dim = tensor_space[LanguageModelDimNames.hidden] vocab_dim = tensor_space[ @@ -56,15 +56,23 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=self._config.word_embedding_weight_initialization_method, - lr_scale=self._config.embeddings_lr_scale, + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), + lr_scale=config.embeddings_lr_scale, ) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), - init_method=self._config.position_embedding_weight_initialization_method, - allow_sequence_tensor_parallel=not self._config.parallel_embeddings, - lr_scale=self._config.embeddings_lr_scale, + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + 
max_val=config.init_method_max_embed, + ), + allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. @@ -76,21 +84,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -99,7 +107,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._config.use_absolute_position_embeddings: + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) @@ -108,7 +116,7 @@ def _forward(self, input_: torch.Tensor, position_ids: 
torch.Tensor | None, mask if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._config.embedding_dropout, self.training) + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 098b2463b..bc672725c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,16 +15,16 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( - LanguageModelConfig, + LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs, LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.logging import log_distributed_tensor +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) @@ -32,67 +32,61 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelConfig](Configurable[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). 
""" - config_class: typing.ClassVar[type[LanguageModelConfig]] = LanguageModelConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig def __init__( self, - config: LanguageModelConfig, + config: LanguageModelBaseConfig, tensor_space: TensorSpace, prediction_distance: int, ): super().__init__(config) - # TODO: Avoid default_block_config? - self._debug = DebugLayer( - tensor_space, - f"Block {self._block_index} {self._name}", - self._config.default_block_config.debug_transformer, - self._config.default_block_config.debug_transformer_memory, - ) + self._debug_transformer = config.transformer.debug_transformer + self._tie_word_embeddings = config.tie_word_embeddings self._tensor_space = tensor_space self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) + self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings + tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings ) - self._cross_entropy_splits = self._config.cross_entropy_splits + self._cross_entropy_splits = config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( - self._config.prediction_loss_coefficient[prediction_distance] - if self._config.prediction_loss_coefficient - else 1.0 + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 ) self._loss_name = 
LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - # TODO: Avoid default_block_config? - self.final_norm = self._config.default_block_config.normalization.get_layer(hidden_dim) - self._logits_scale_factor = self._config.logits_scale_factor - self._language_model_loss_factor = self._config.language_model_loss_factor - self._distillation_loss_factor = self._config.distillation_loss_factor - self._z_loss_factor = self._config.logit_z_loss + self.final_norm = config.transformer.normalization.get_layer(hidden_dim) + self._logits_scale_factor = config.logits_scale_factor + self._language_model_loss_factor = config.language_model_loss_factor + self._distillation_loss_factor = config.distillation_loss_factor + self._z_loss_factor = config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == config.prediction_heads - 1 - self._init_output_weights(hidden_dim, self._config) + self._init_output_weights(hidden_dim, config) - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_impl + self._use_dpo_loss = config.enable_dpo + if self._use_dpo_loss: + self.dpo_beta = config.dpo_beta + else: + self._cross_entropy_impl = config.cross_entropy_impl + self._distillation_loss_implementation = config.distillation_loss_implementation if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: self._cross_entropy_impl = CrossEntropyImpl.fused @@ -110,7 +104,7 @@ def __init__( def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: # Only the first head defines the output weights - if self._config.tie_word_embeddings or self._prediction_distance > 0: + if self._tie_word_embeddings or self._prediction_distance > 0: 
return # untie embedding weights vocab_dim = self._tensor_space[ @@ -118,7 +112,11 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=self._config.output_weight_initialization_method, + init_method=init_normal_( + std=config.init_method_std_embed, + min_val=config.init_method_min_embed, + max_val=config.init_method_max_embed, + ), lr_scale=config.output_lr_scale, ) @@ -203,7 +201,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) - if self._config.enable_dpo: + if self._use_dpo_loss: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -253,7 +251,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._config.tie_word_embeddings: + if self._tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -340,22 +338,35 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._logits_scale_factor, ) - if self._debug.enabled and self._cross_entropy_splits is None: - vocab_dim = ( + if self._debug_transformer and self._cross_entropy_splits is None: + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp + ] + dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) + dims[sequence_index] = ( + TensorDim( + LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + ) + if self._sequence_parallel_logits + else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) - sequence_dim = ( - 
LanguageModelDimNames.sequence_q_tp + + dim_names = ( + [LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] if self._sequence_parallel_logits - else LanguageModelDimNames.sequence_q + else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] ) - batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] - dims = ( - (sequence_dim, batch_dim, vocab_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, vocab_dim) + + dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) + log_distributed_tensor( + "", + logits, + level=self._debug_transformer, + meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), + distributed=self._tensor_space.distributed, + scale=self._logits_scale_factor, ) - self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) if targets is None: return logits * self._logits_scale_factor, None @@ -368,7 +379,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, + self.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -390,7 +401,7 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -403,7 +414,7 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._config.distillation_loss_implementation == 
DistillationLossImpl.cross_entropy: + elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 3c9f18c8d..f5d915855 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -20,11 +20,11 @@ class PositionEmbeddingPreprocessor(Preprocessor): def __init__( self, - config: LanguageModelConfig, + config: LanguageModelBaseConfig, tensor_space: TensorSpace, ): self._config = config - assert config.absolute_position_embeddings is not None + assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] @@ -34,7 +34,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.absolute_position_embeddings) + Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) self._position_ids = torch.arange( 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) @@ -71,7 +71,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelConfig, 
tensor_space: TensorSpace): + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 00c709814..efcf2d873 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ + from fast_llm.tensor import Initializer class SSMDimNames(BlockDimNames): @@ -66,6 +66,8 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": + from fast_llm.tensor import init_fill_, init_uniform_centered_ + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 04b27af47..550c44d0f 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,15 +4,13 @@ import einops import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_kaiming_ -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = 
logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b02fbd401..1c319f490 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,15 +3,14 @@ import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ -from fast_llm.tensor import ParameterMeta +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale try: diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index e22852fe6..f5b0139cf 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,14 +4,13 @@ import torch -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import Assert, get_lr_scale try: @@ -164,7 +163,3 @@ def 
forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 2db7b0ac8..b1de792e3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,10 +6,11 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale try: @@ -45,52 +46,55 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): +class Attention(Mixer): """ A self-attention layer. 
""" + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_heads, - AttentionDimNames.kv_channels, + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.composite_heads, + TransformerDimNames.kv_channels, ) _KV_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.head_groups, - AttentionDimNames.kv_channels, + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.head_groups, + TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_dense, + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.composite_dense, ) - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - # init_method_qkv = init_normal_( - # std=self._config.init_method_std_qkv, - # min_val=self._config.init_method_min_qkv, - # max_val=self._config.init_method_max_qkv, - # ) - # init_method_std_attn_proj = init_normal_( - # std=self._config.init_method_std_attn_proj, - # min_val=self._config.init_method_min_attn_proj, - # max_val=self._config.init_method_max_attn_proj, - # ) - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size - self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size + 
init_method_qkv = init_normal_( + std=self._config.init_method_std_qkv, + min_val=self._config.init_method_min_qkv, + max_val=self._config.init_method_max_qkv, + ) + init_method_std_attn_proj = init_normal_( + std=self._config.init_method_std_attn_proj, + min_val=self._config.init_method_min_attn_proj, + max_val=self._config.init_method_max_attn_proj, + ) + + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale: float = self._kv_channels ** (-self._config.attention_softmax_scale_power) + self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -98,19 +102,19 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_query], + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, - weight_init_method=self._config.qkv_weight_initialization_method, - bias_init_method=self._config.qkv_bias_initialization_method, + weight_init_method=init_method_qkv, + bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_key_value], + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, - weight_init_method=self._config.qkv_weight_initialization_method, - bias_init_method=self._config.qkv_bias_initialization_method, + weight_init_method=init_method_qkv, + bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -121,11 +125,11 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # Output. 
self.dense = InputParallelLinear( - self._tensor_space[AttentionDimNames.composite_dense], + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, - weight_init_method=self._config.dense_weight_initialization_method, - bias_init_method=self._config.dense_bias_initialization_method, + weight_init_method=init_method_std_attn_proj, + bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -255,24 +259,18 @@ def _decide_window_size(self) -> int | None: return window_size - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[AttentionKwargs.sequence_first] + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[TransformerKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(AttentionKwargs.presents)) is not None: + if (presents := kwargs.get(TransformerKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -281,9 +279,9 @@ def forward( if self._tensor_space.distributed.sequence_data_group: key_value = ( - key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] + key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] ) if sequence_first: @@ -297,9 +295,9 @@ def forward( key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug.enabled: - self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug( + if self._debug_level: + self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug_log( key, "key_rotary_input", self._KV_DIMS, @@ -312,7 +310,7 @@ def forward( if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -322,9 +320,9 @@ def forward( key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -347,15 +345,25 @@ def forward( query.flatten(-2), 
key.flatten(-2), value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], + kwargs[TransformerKwargs.attention_mask], + kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug.enabled: - self._debug(query, "query", self._QUERY_DIMS, kwargs) - self._debug(key, "key", self._KV_DIMS, kwargs) - self._debug(value, "value", self._KV_DIMS, kwargs) - self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug_level: + self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + self._debug_log( + key, + "key", + self._KV_DIMS, + kwargs, + ) + self._debug_log( + value, + "value", + self._KV_DIMS, + kwargs, + ) + self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index bd72bd305..ebb976e63 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,29 +3,22 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import ( - AddLinearBiasChoices, - BlockDimNames, - BlockKwargs, - BlockLayerConfig, - MixerConfig, -) +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.transformer.rotary.config import 
RotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.attention import Attention + pass logger = logging.getLogger(__name__) -class AttentionDimNames(BlockDimNames): +class TransformerDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. # Self-attention dimensions head_groups = "head_groups" @@ -38,7 +31,7 @@ class AttentionDimNames(BlockDimNames): composite_dense = "composite_dense" -class AttentionKwargs(BlockKwargs): +class TransformerKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -52,8 +45,9 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class(dynamic_type={BlockLayerConfig: "attention"}) -class AttentionConfig(MixerConfig): +@config_class() +class AttentionConfig(Config): + # TODO: Make mixer class dynamic. _abstract = False # TODO: Review names @@ -113,30 +107,7 @@ class AttentionConfig(MixerConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - qkv_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer weights. Default: hidden_size**-0.5", - hint=FieldHint.feature, - ) - qkv_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - dense_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer weight. Default: (2 * num_blocks * hidden_size)**-0.5", - hint=FieldHint.feature, - ) - dense_bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", - hint=FieldHint.feature, - ) - def _validate(self) -> None: - - with self._set_implicit_default(): - if self.kv_channels is None: - # TODO: hidden_size not yet validated. 
- self.kv_channels = div(self.block.block_sequence.hidden_size, self.num_attention_heads) - super()._validate() if not TritonConfig.TRITON_ENABLED: @@ -159,74 +130,182 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim( head_groups := TensorDim( - AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - AttentionDimNames.group_heads, + TransformerDimNames.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim( + CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + ) tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, 
group_heads, kv_channels)) ) - def get_block(self) -> "Attention": - pass - @functools.cached_property - def add_qkv_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True +@config_class() +# TODO: Use composition for attention config +class TransformerConfig(AttentionConfig, BlockConfig): + _abstract = False - @functools.cached_property - def add_dense_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + # TODO: Review names + init_method_std: float = Field( + default=None, + desc="Default scale for weight initialization. Default: hidden_size**-0.5", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_qkv: float = Field( + default=None, + desc="Scale for the query, key and value weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_qkv: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_qkv: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for query, key and value matrices. 
Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_attn_proj: float = Field( + default=None, + desc="Scale for the attention projection weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_attn_proj: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_attn_proj: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + # Use random inits instead of constant values, useful for debugging. 
+ random_bias_init: bool = Field( + default=False, + desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", + hint=FieldHint.testing, + ) - @functools.cached_property - def qkv_weight_initialization_method(self) -> Initializer: - if self.qkv_weight_initialization.has_initialization: - return self.qkv_weight_initialization.get_initializer() - else: - return self.block.block_sequence.hidden_size**-0.5 + def _validate(self) -> None: + with self._set_implicit_default(): + if self.kv_channels is None: + self.kv_channels = div(self.hidden_size, self.num_attention_heads) + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + 
if self.init_method_min is not None and self.init_method_max is not None:
+            Assert.leq(self.init_method_min, self.init_method_max)
+        if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
+            Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv)
+        if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
+            Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv)
+        if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None:
+            Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj)
+        if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None:
+            Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1)
+        if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None:
+            Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2)
-    @functools.cached_property
-    def qkv_bias_initialization_method(self) -> Initializer:
-        if self.qkv_bias_initialization.has_initialization:
-            assert self.add_qkv_bias
-            return self.qkv_bias_initialization.get_initializer()
-        else:
-            return init_zeros_
+        super()._validate()
-    @functools.cached_property
-    def dense_weight_initialization_method(self) -> Initializer:
-        if self.dense_weight_initialization.has_initialization:
-            return self.dense_weight_initialization.get_initializer()
-        else:
-            return self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1)
+    @property
+    def add_attn_qkv_bias(self) -> bool:
+        if isinstance(self.add_linear_biases, bool):
+            return self.add_linear_biases
+        if self.add_linear_biases == AddLinearBiasChoices.nowhere:
+            return False
+        return True
-    @functools.cached_property
-    def dense_bias_initialization_method(self) -> Initializer:
-        if self.dense_bias_initialization.has_initialization:
-            assert self.add_dense_bias
-            return self.dense_bias_initialization.get_initializer()
-        else:
-            return init_zeros_
+    @property
+ def add_attn_dense_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 16e5811e6..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class BackupAttentionPreprocessor(Preprocessor): def __init__( self, - config: AttentionConfig, + config: TransformerConfig, tensor_space: TensorSpace, ): self._config = config @@ -51,13 +51,13 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - kwargs[AttentionKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + kwargs[TransformerKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) 
for i, x in enumerate(sample_lens)]) @@ -65,33 +65,33 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[AttentionKwargs.attention_mask] = ( - kwargs[AttentionKwargs.attention_mask] + kwargs[TransformerKwargs.attention_mask] = ( + kwargs[TransformerKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[AttentionKwargs.attention_mask_value] = self._mask_value + kwargs[TransformerKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, - kwargs[AttentionKwargs.sequence_k_dim], + kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=AttentionKwargs.attention_mask, + tensor_name=TransformerKwargs.attention_mask, dtype=torch.bool, ) - kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=AttentionKwargs.attention_mask_value, + tensor_name=TransformerKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -107,12 +107,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. 
We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if AttentionKwargs.sequence_lengths not in kwargs: + if TransformerKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - if sequence_q < kwargs[AttentionKwargs.sequence_length]: + sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + if sequence_q < kwargs[TransformerKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets @@ -146,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( + kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( + kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index 9f8732f85..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig from fast_llm.tensor import TensorMeta @@ -26,34 +26,34 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] + 
self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_q, + tensor_name=TransformerKwargs.rotary_freq_q, ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_k, + tensor_name=TransformerKwargs.rotary_freq_k, ) def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ 
b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -83,44 +83,44 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - 
kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_q, + tensor_name=TransformerKwargs.rotary_freq_q, ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], + kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=AttentionKwargs.rotary_freq_k, + tensor_name=TransformerKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) return query, key def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index eb24ef183..3c0ad8ab4 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -36,7 +36,7 @@ def get_layers(self) -> list[Layer]: self._tensor_space, block_index=i + 1, ) - for i in range(self._config.transformer.num_blocks) + for i in range(self._config.transformer.num_layers) ], CustomHead(self._config, self._tensor_space), ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a7fcad82d..0da16428e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig 
-from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -119,7 +119,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelConfig): +class GPTBaseModelConfig(LanguageModelBaseConfig): _abstract = False # Debug, to get an exact match with megatron init. @@ -192,12 +192,15 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: + if self.batch.sequence_length is None: + # TODO: Drop this. + self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) distillation_model = self.model.base_model.distillation_model dpo_reference_model = self.model.base_model.dpo_reference_model diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f3e57fe13..2dbef77f3 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -176,7 +176,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers # Embeddings converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -256,7 +256,7 @@ def _create_transformer_layer_converters( return converters def _create_lm_head_converters(self) -> 
list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -654,7 +654,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 4e3f258fc..cf7da3872 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[AttentionKwargs.past_key_values] = past_key_values + kwargs[TransformerKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. 
- kwargs[AttentionKwargs.presents] = [] + kwargs[TransformerKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[AttentionKwargs.presents],) + outputs += (kwargs[TransformerKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[AttentionKwargs.presents], + past_key_values=kwargs[TransformerKwargs.presents], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 30842597d..da647de57 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -16,7 +16,7 @@ from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - block_index=max(self._config.transformer.num_blocks + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -93,9 +93,9 @@ def get_layers(self) -> list[Layer]: block_index=i + 1, # The last layer only returns the transformer output. 
# The previous layers return a stack of shared_hidden and transformer_output. - return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_blocks - 1, + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) - for i in range(self._config.transformer.num_blocks) + for i in range(self._config.transformer.num_layers) ], *self.get_output_layers(), ] @@ -119,7 +119,7 @@ def preprocess_meta( truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -128,13 +128,13 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? sequence_q_dim = TensorDim( - AttentionDimNames.sequence_q, + TransformerDimNames.sequence_q, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - AttentionDimNames.sequence_q_tp, + TransformerDimNames.sequence_q_tp, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim( DistributedDimNames.tensor_and_sequence_data @@ -151,7 +151,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first @@ -160,10 +160,10 @@ def preprocess_meta( common_kwargs = { LanguageModelKwargs.phase: phase, - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.hidden_dims: hidden_dims, - AttentionKwargs.sequence_length: sequence_length, - 
AttentionKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.hidden_dims: hidden_dims, + TransformerKwargs.sequence_length: sequence_length, + TransformerKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -182,7 +182,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -190,7 +190,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - AttentionKwargs.sequence_k_dim: sequence_k_dim, + TransformerKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -202,10 +202,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - AttentionKwargs.sequence_first, - AttentionKwargs.sequence_length, - AttentionKwargs.sequence_q_dim, - AttentionKwargs.sequence_k_dim, + TransformerKwargs.sequence_first, + TransformerKwargs.sequence_length, + TransformerKwargs.sequence_q_dim, + TransformerKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -231,8 +231,8 @@ def preprocess( preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size - sequence_first = common_kwargs[AttentionKwargs.sequence_first] + sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size + sequence_first = 
common_kwargs[TransformerKwargs.sequence_first] prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( @@ -264,14 +264,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -283,8 +283,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - AttentionKwargs.past_key_values: pasts, - AttentionKwargs.presents: presents, + TransformerKwargs.past_key_values: pasts, + TransformerKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels @@ -372,7 +372,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_blocks, + count=self._config.transformer.num_layers, ) ) if self._config.transformer.expert_z_loss_coefficient: @@ -380,7 +380,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_blocks, + count=self._config.transformer.num_layers, ) ) if self._config.logit_z_loss: @@ -421,7 +421,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s 
consumed_tokens_per_iteration = sequence_length * batch_size - num_transformer_layers = transformer_config.num_blocks + self._config.base_model.prediction_heads - 1 + num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index a351522ca..9427f69be 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -62,13 +62,13 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_blocks + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers - if len(self.hybrid_block_layout) != self.transformer.num_blocks: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_blocks}" - if self.transformer.num_blocks % len(self.hybrid_block_layout) != 0: + if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: raise ValueError(message) - num_repeats = self.transformer.num_blocks // len(self.hybrid_block_layout) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index fb24c1aec..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -219,7 +219,7 @@ def _create_config_converters(cls) -> 
list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear for i in range(num_layers): @@ -383,7 +383,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: # not using super() because LLamba model is called backbone in the checkpoints converters = [] - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -572,7 +572,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_blocks + num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False # Embedding and output diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b12d12072..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,12 +1,13 @@ +import abc import functools import logging +import math import typing import torch from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op -from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -360,3 +361,70 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: 
triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa + + +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) + + +init_zeros_ = init_fill_(0.0) +init_ones_ = init_fill_(1.0) + + +def init_normal_( + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) + + +def init_uniform_( + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor = tensor.uniform_(low, high, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + 
+def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 8c33aed4d..9a878c494 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.grad_output: 1.0, + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index cb9c69ccb..7f0b902f8 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -354,7 +354,7 @@ def _test_forward_return_hidden_states( # hidden_states include embeddings layer assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_blocks + len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers ) diff --git a/tests/test_attention.py b/tests/test_attention.py index 534e3800e..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py 
@@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -77,13 +77,13 @@ def test_varlen_preprocessor(): varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), - AttentionKwargs.sequence_k_dim: TensorDim( - AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), + TransformerKwargs.sequence_k_dim: TensorDim( + TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - AttentionKwargs.sequence_length: sequence_length, - AttentionKwargs.sequence_lengths: sequence_lengths, + TransformerKwargs.sequence_length: sequence_length, + TransformerKwargs.sequence_lengths: sequence_lengths, } varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 6c4c7f0cb..694faa55b 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ 
-9,7 +9,7 @@ from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -71,8 +71,8 @@ def test_load_from_llamba_checkpoint(): schedule_runner.setup(model.distributed, optimizer=None) common_kwargs = { - AttentionKwargs.sequence_first: True, - AttentionKwargs.grad_output: False, + TransformerKwargs.sequence_first: True, + TransformerKwargs.grad_output: False, } input_data = [(x, common_kwargs)] diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4705ebb79..722d8d63a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -162,7 +162,6 @@ def _update_and_add_testing_config( "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", - "model.base_model.transformer.use_position_embeddings=True", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -259,7 +258,6 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", - "model.base_model.transformer.use_position_embeddings=False", # Unused, but prevents issues with conversion tests. 
"model.base_model.max_position_embeddings=2048", ], From b68d36048852f6d78c2c8506f76f1708ada1f77e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 15:58:34 -0400 Subject: [PATCH 06/28] stuff --- fast_llm/layers/block/config.py | 25 ++- fast_llm/layers/block/mlp/config.py | 54 ++++- fast_llm/layers/transformer/attention.py | 74 +++---- fast_llm/layers/transformer/config.py | 192 ++++++------------ fast_llm/layers/transformer/preprocessing.py | 48 ++--- .../transformer/rotary/preprocessing.py | 26 +-- fast_llm/layers/transformer/rotary/rotary.py | 30 +-- fast_llm/models/gpt/conversion.py | 6 +- fast_llm/models/gpt/huggingface.py | 10 +- fast_llm/models/gpt/model.py | 42 ++-- tests/layers/test_lm_head.py | 6 +- tests/test_attention.py | 16 +- tests/test_ssms.py | 6 +- 13 files changed, 266 insertions(+), 269 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 5a999fa6d..489cd4f3f 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -40,7 +40,7 @@ class AddLinearBiasChoices(str, enum.Enum): @config_class() -# TODO: Use composition for MLP config +# TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): # TODO: Review names @@ -100,10 +100,33 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, ) + # TODO: Review initialization + init_method_std: float = Field( + default=None, + desc="Default scale for weight initialization. Default: hidden_size**-0.5", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. 
Default: -float('inf')", + hint=FieldHint.optional, + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.ffn_hidden_size is None: self.ffn_hidden_size = 4 * self.hidden_size + # TODO: Review initialization + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) super()._validate() diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 1d125c4f7..64e234544 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -72,8 +72,6 @@ class MLPConfig(Config): hint=FieldHint.architecture, ) gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) activation_type: ActivationType = Field( default=None, desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", @@ -124,11 +122,63 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + # TODO: Review initialization + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. 
Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) def _validate(self) -> None: with self._set_implicit_default(): + # TODO: Make this work without inheritance. if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # TODO: Review initialization + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + self.num_unshared_experts = self.num_experts - self.num_shared_experts super()._validate() diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b1de792e3..e84e92a96 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,7 +9,7 @@ from 
fast_llm.layers.block.mixer import Mixer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale @@ -54,21 +54,21 @@ class Attention(Mixer): _mixer_name: typing.ClassVar[str] = "attn" _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_heads, + AttentionDimNames.kv_channels, ) _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.head_groups, - TransformerDimNames.kv_channels, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.head_groups, + AttentionDimNames.kv_channels, ) _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, + AttentionDimNames.batch, + AttentionDimNames.sequence_q, + AttentionDimNames.composite_dense, ) def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): @@ -87,14 +87,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size - self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size + self._kv_channels = 
self._tensor_space[AttentionDimNames.kv_channels].size + self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -102,19 +102,19 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_query], - bias=self._config.add_attn_qkv_bias, + self._tensor_space[AttentionDimNames.composite_query], + bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_key_value], - bias=self._config.add_attn_qkv_bias, + self._tensor_space[AttentionDimNames.composite_key_value], + bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -125,11 +125,11 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, 
block_i # Output. self.dense = InputParallelLinear( - self._tensor_space[TransformerDimNames.composite_dense], + self._tensor_space[AttentionDimNames.composite_dense], hidden_dim, - bias=self._config.add_attn_dense_bias, + bias=self._config.add_dense_bias, weight_init_method=init_method_std_attn_proj, - bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -260,17 +260,17 @@ def _decide_window_size(self) -> int | None: return window_size def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[TransformerKwargs.sequence_first] + sequence_first = kwargs[AttentionKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -279,9 +279,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._tensor_space.distributed.sequence_data_group: key_value = ( - key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] + key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] ) if sequence_first: @@ -310,7 +310,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -320,9 +320,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -345,8 +345,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[AttentionKwargs.attention_mask], + 
kwargs[AttentionKwargs.attention_mask_value], ) if self._debug_level: diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index ebb976e63..a8245f7da 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class TransformerDimNames(BlockDimNames): +class AttentionDimNames(BlockDimNames): # A set of common tensor dim names packed into a namespace. # Self-attention dimensions head_groups = "head_groups" @@ -31,7 +31,7 @@ class TransformerDimNames(BlockDimNames): composite_dense = "composite_dense" -class TransformerKwargs(BlockKwargs): +class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" attention_mask = "attention_mask" @@ -106,78 +106,7 @@ class AttentionConfig(Config): " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - - def _validate(self) -> None: - super()._validate() - - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.multiple(self.num_attention_heads, self.head_groups) - - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * self.kv_channels - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Needed for multiple inheritance. 
- super().setup_tensor_space(tensor_space) # Noqa - - tensor_space.add_tensor_dim( - head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - TransformerDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - - -@config_class() -# TODO: Use composition for attention config -class TransformerConfig(AttentionConfig, BlockConfig): - _abstract = False - - # TODO: Review names - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) + # TODO: Review initialization init_method_std_qkv: float = Field( default=None, desc="Scale for the query, key and value weight initialization. 
Default: init_method_std", @@ -210,59 +139,17 @@ class TransformerConfig(AttentionConfig, BlockConfig): desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", hint=FieldHint.optional, ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) def _validate(self) -> None: with self._set_implicit_default(): + # TODO: Make this work without inheritance. 
if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 + # TODO: Review initialization if self.init_method_std_qkv is None: self.init_method_std_qkv = self.init_method_std if self.init_method_std_attn_proj is None: self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 if self.init_method_max_qkv is None: self.init_method_max_qkv = self.init_method_max if self.init_method_min_qkv is None: @@ -271,31 +158,61 @@ def _validate(self) -> None: self.init_method_max_attn_proj = self.init_method_max if self.init_method_min_attn_proj is None: self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: Assert.leq(self.init_method_min, self.init_method_max) if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and 
self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() + if not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + + Assert.multiple(self.num_attention_heads, self.head_groups) + + @functools.cached_property + def projection_size(self): + assert self._validated + return self.num_attention_heads * self.kv_channels + + def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: + return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + # Needed for multiple inheritance. 
+ super().setup_tensor_space(tensor_space) # Noqa + + tensor_space.add_tensor_dim( + head_groups := TensorDim( + AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + ) + ) + tensor_space.add_tensor_dim( + group_heads := TensorDim( + AttentionDimNames.group_heads, + div(self.num_attention_heads, self.head_groups), + None if self.head_groups > 1 else tensor, + ) + ) + tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) + tensor_space.add_tensor_dim( + CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + ) + @property - def add_attn_qkv_bias(self) -> bool: + def add_qkv_bias(self) -> bool: + # TODO: Make this work without inheritance. if isinstance(self.add_linear_biases, bool): return self.add_linear_biases if self.add_linear_biases == AddLinearBiasChoices.nowhere: @@ -303,9 +220,16 @@ def add_attn_qkv_bias(self) -> bool: return True @property - def add_attn_dense_bias(self) -> bool: + def add_dense_bias(self) -> bool: + # TODO: Make this work without inheritance. 
if isinstance(self.add_linear_biases, bool): return self.add_linear_biases if self.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False + + +@config_class() +# TODO: Use composition instead +class TransformerConfig(AttentionConfig, BlockConfig): + _abstract = False diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 3f0e14eb7..d8fa14a6d 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -51,13 +51,13 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -65,27 +65,27 @@ def preprocess(self, batch, kwargs: dict[str, 
typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[AttentionKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, - kwargs[TransformerKwargs.sequence_k_dim], + kwargs[AttentionKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=AttentionKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -107,12 +107,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. 
""" - if TransformerKwargs.sequence_lengths not in kwargs: + if AttentionKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if sequence_q < kwargs[TransformerKwargs.sequence_length]: + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. We also store the offsets @@ -146,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py 
b/fast_llm/layers/transformer/rotary/preprocessing.py index c357411b6..9f8732f85 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig from fast_llm.tensor import TensorMeta @@ -26,34 +26,34 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = 
TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def _create_tensors(self, sequence_length: int) -> None: diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 17b18a1ca..ebb629aa1 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -83,44 +83,44 @@ def __init__( self._tensor_space = tensor_space if self._tensor_space is not None: self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] + self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - 
kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], + kwargs[AttentionKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key def _create_tensors(self, 
sequence_length: int) -> None: diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 2dbef77f3..6e79388b0 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -199,19 +199,19 @@ def _create_transformer_layer_converters( ( f"{fast_llm_layer_name}.self_attn.query", f"{hf_layer_name}.self_attn.q_proj", - transformer_config.add_attn_qkv_bias, + transformer_config.add_qkv_bias, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.key_value", (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - transformer_config.add_attn_qkv_bias, + transformer_config.add_qkv_bias, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.dense", f"{hf_layer_name}.self_attn.o_proj", - transformer_config.add_attn_dense_bias, + transformer_config.add_dense_bias, WeightConverter, ), # Norm diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index cf7da3872..4e3f258fc 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[TransformerKwargs.past_key_values] = past_key_values + kwargs[AttentionKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. 
- kwargs[TransformerKwargs.presents] = [] + kwargs[AttentionKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[TransformerKwargs.presents],) + outputs += (kwargs[AttentionKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[TransformerKwargs.presents], + past_key_values=kwargs[AttentionKwargs.presents], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index da647de57..187ca618d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -16,7 +16,7 @@ from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -119,7 +119,7 @@ def preprocess_meta( truncate_documents = True batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -128,13 +128,13 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? 
sequence_q_dim = TensorDim( - TransformerDimNames.sequence_q, + AttentionDimNames.sequence_q, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - TransformerDimNames.sequence_q_tp, + AttentionDimNames.sequence_q_tp, micro_sequence_length, self._tensor_space.distributed_config.get_distributed_dim( DistributedDimNames.tensor_and_sequence_data @@ -151,7 +151,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first @@ -160,10 +160,10 @@ def preprocess_meta( common_kwargs = { LanguageModelKwargs.phase: phase, - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.hidden_dims: hidden_dims, - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.hidden_dims: hidden_dims, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -182,7 +182,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -190,7 +190,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - TransformerKwargs.sequence_k_dim: sequence_k_dim, + AttentionKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: 
kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -202,10 +202,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - TransformerKwargs.sequence_first, - TransformerKwargs.sequence_length, - TransformerKwargs.sequence_q_dim, - TransformerKwargs.sequence_k_dim, + AttentionKwargs.sequence_first, + AttentionKwargs.sequence_length, + AttentionKwargs.sequence_q_dim, + AttentionKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -231,8 +231,8 @@ def preprocess( preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size - sequence_first = common_kwargs[TransformerKwargs.sequence_first] + sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size + sequence_first = common_kwargs[AttentionKwargs.sequence_first] prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( @@ -264,14 +264,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? 
tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -283,8 +283,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - TransformerKwargs.past_key_values: pasts, - TransformerKwargs.presents: presents, + AttentionKwargs.past_key_values: pasts, + AttentionKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..8c33aed4d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.grad_output: 1.0, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/test_attention.py b/tests/test_attention.py index dd36b840a..534e3800e 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ 
-6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -77,13 +77,13 @@ def test_varlen_preprocessor(): varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), - TransformerKwargs.sequence_k_dim: TensorDim( - TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_k_dim: TensorDim( + AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_lengths: sequence_lengths, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_lengths: sequence_lengths, } varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 694faa55b..6c4c7f0cb 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ 
-9,7 +9,7 @@ from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -71,8 +71,8 @@ def test_load_from_llamba_checkpoint(): schedule_runner.setup(model.distributed, optimizer=None) common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, + AttentionKwargs.sequence_first: True, + AttentionKwargs.grad_output: False, } input_data = [(x, common_kwargs)] From 82c9dbd2d4270ea0bbe30afbe79520be4ebc7e68 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 16:29:55 -0400 Subject: [PATCH 07/28] misc --- fast_llm/layers/block/block.py | 104 ++++++++++++++++++++++- fast_llm/layers/block/mixer.py | 68 --------------- fast_llm/layers/language_model/head.py | 88 ++++++++++--------- fast_llm/layers/ssm/block.py | 7 +- fast_llm/layers/ssm/discrete_mamba2.py | 12 ++- fast_llm/layers/ssm/mamba2.py | 36 +++++--- fast_llm/layers/ssm/mamba_layer.py | 12 ++- fast_llm/layers/transformer/attention.py | 51 ++++++----- fast_llm/layers/transformer/block.py | 5 +- 9 files changed, 219 insertions(+), 164 deletions(-) delete mode 100644 fast_llm/layers/block/mixer.py diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 85da61c01..87a8f81cf 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,4 +1,5 @@ import abc +import functools import typing import torch @@ -9,13 +10,112 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from 
fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +class DebugLayer: + # TODO: Move elsewhere? + def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): + self._tensor_space = tensor_space + self._name = name + self._debug_level = debug_level + self._debug_memory = debug_memory + + def _get_meta( + self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + ( + dim + if isinstance(dim, TensorDim) + else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + ) + for dim in dims + ), + tensor_name=f"{self._name} {name}", + dtype=tensor.dtype, + ) + + @functools.cached_property + def enabled(self) -> bool: + return self._debug_level > 0 or self._debug_memory + + def __call__( + self, + tensor: torch.Tensor, + name: str, + dims: tuple[TensorDim | str, ...], + kwargs: dict[str, typing.Any], + scale: float = 1.0, + global_: bool = True, + log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, + ) -> None: + # TODO: Local vs global? 
+ if self._debug_memory: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) + if self._debug_level > 0: + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dims, kwargs), + distributed=self._tensor_space.distributed, + global_=global_, + log_fn=log_fn, + scale=scale, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dims, kwargs), + distributed=self._tensor_space.distributed, + global_=global_, + log_fn=log_fn, + scale=scale, + ) + + +class BlockLayer(torch.nn.Module, abc.ABC): + """ + Base class for mixer and MLP modules. + """ + + def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug_level: int, debug_memory: bool): + super().__init__() + self._tensor_space = tensor_space + self._block_index = block_index + self._name = name + self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + debug_level, + debug_memory, + ) + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + def _debug_log(self, tensor: torch.Tensor) -> None: + pass + + class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): """ A transformer-like decoder base block with abstract mixer. 
@@ -52,7 +152,7 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self) -> Mixer: + def _create_mixer(self) -> BlockLayer: pass @torch.compile diff --git a/fast_llm/layers/block/mixer.py b/fast_llm/layers/block/mixer.py deleted file mode 100644 index 5c811e330..000000000 --- a/fast_llm/layers/block/mixer.py +++ /dev/null @@ -1,68 +0,0 @@ -import abc -import typing - -import torch - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. - """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. 
- """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bc672725c..0623ac201 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,6 +15,7 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, @@ -46,47 +47,66 @@ def __init__( prediction_distance: int, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer - self._tie_word_embeddings = 
config.tie_word_embeddings + self._debug = DebugLayer( + tensor_space, + f"Language model head", + self._config.transformer.debug_transformer, + self._config.transformer.debug_transformer_memory, + ) self._tensor_space = tensor_space self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = ( + tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + ) self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings + tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings ) - self._cross_entropy_splits = config.cross_entropy_splits + self._cross_entropy_splits = self._config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] self._loss_coefficient = ( - config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + self._config.prediction_loss_coefficient[prediction_distance] + if self._config.prediction_loss_coefficient + else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = config.logits_scale_factor - self._language_model_loss_factor = config.language_model_loss_factor - self._distillation_loss_factor = config.distillation_loss_factor - self._z_loss_factor = config.logit_z_loss + self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + self._logits_scale_factor = self._config.logits_scale_factor + self._language_model_loss_factor = 
self._config.language_model_loss_factor + self._distillation_loss_factor = self._config.distillation_loss_factor + self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - self._init_output_weights(hidden_dim, config) + # Only the first head defines the output weights + if self._prediction_distance == 0 and not self._config.tie_word_embeddings: + # untie embedding weights + vocab_dim = self._tensor_space[ + LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab + ] + self.output_weights = ParameterMeta.from_dims( + (vocab_dim, hidden_dim), + init_method=init_normal_( + std=self._config.init_method_std_embed, + min_val=self._config.init_method_min_embed, + max_val=self._config.init_method_max_embed, + ), + lr_scale=self._config.output_lr_scale, + ) - self._use_dpo_loss = config.enable_dpo - if self._use_dpo_loss: - self.dpo_beta = config.dpo_beta - else: - self._cross_entropy_impl = config.cross_entropy_impl - self._distillation_loss_implementation = config.distillation_loss_implementation + self._use_dpo_loss = self._config.enable_dpo + if not self._use_dpo_loss: + self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: self._cross_entropy_impl = CrossEntropyImpl.fused @@ -102,24 +122,6 @@ def __init__( if hasattr(self, "output_weights"): self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) - def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: - # Only the first head defines the output weights - if self._tie_word_embeddings or self._prediction_distance 
> 0: - return - # untie embedding weights - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] - self.output_weights = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - lr_scale=config.output_lr_scale, - ) - def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -251,7 +253,7 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._tie_word_embeddings: + if self._config.tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -379,7 +381,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, + self._config.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -401,7 +403,7 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, @@ -414,7 +416,7 @@ def _logits_cross_entropy_forward_backward( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = 
cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, @@ -426,7 +428,9 @@ def _logits_cross_entropy_forward_backward( target_format=TargetFormat.logits, ) else: - raise ValueError(f"Invalid distillation loss implementation: {self._distillation_loss_implementation}") + raise ValueError( + f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" + ) distillation_loss = distillation_loss * self._distillation_loss_factor else: distillation_loss, distillation_grad = None, None diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 0bfa266ac..987d5fa0d 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,7 +1,6 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.block import Block +from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.block.config import BlockConfig -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.ssm.config import SSMConfig @@ -18,7 +17,7 @@ def __init__( config: BlockConfig, ssm_config: SSMConfig, tensor_space: TensorSpace, - mixer_cls: type[Mixer], + mixer_cls: type[BlockLayer], block_index: int, return_input: bool = False, ): @@ -26,7 +25,7 @@ def __init__( self._mixer_cls = mixer_cls super().__init__(config, tensor_space, block_index, return_input) - def _create_mixer(self) -> Mixer: + def _create_mixer(self) -> BlockLayer: return self._mixer_cls( self._ssm_config, tensor_space=self._tensor_space, diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 550c44d0f..e48636926 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -6,8 +6,8 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import 
BlockConfig, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ @@ -32,7 +32,7 @@ _causal_conv1d_available = False -class DiscreteMamba2(Mixer): +class DiscreteMamba2(BlockLayer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" @@ -44,7 +44,13 @@ def __init__( tensor_space: TensorSpace, block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=block_config.debug_transformer, + debug_memory=block_config.debug_transformer_memory, + ) self._config: SSMConfig = config layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1c319f490..4357c0e86 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -5,8 +5,8 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -30,7 +30,7 @@ logger = 
logging.getLogger(__name__) -class Mamba2(Mixer): +class Mamba2(BlockLayer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -56,7 +56,13 @@ def __init__( block_index: int, block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=block_config.debug_transformer, + debug_memory=block_config.debug_transformer_memory, + ) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = ( @@ -144,7 +150,13 @@ def __init__( # TODO: lr_scale? ) - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available @@ -198,12 +210,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) dt = dt.transpose(1, 2) - if self._debug_level: - self._debug_log(z, "z", self._XZ_DIMS, kwargs) - self._debug_log(x, "x", self._XZ_DIMS, kwargs) - self._debug_log(b, "b", self._BC_DIMS, kwargs) - self._debug_log(c, "c", self._BC_DIMS, kwargs) - self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + if self._debug.enabled: + self._debug(z, "z", self._XZ_DIMS, kwargs) + self._debug(x, "x", self._XZ_DIMS, kwargs) + self._debug(b, "b", self._BC_DIMS, kwargs) + self._debug(c, "c", self._BC_DIMS, kwargs) + self._debug(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, @@ -217,8 +229,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ 
delta_softplus=True, ) - if self._debug_level: - self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if self._debug.enabled: + self._debug(y, "y", self._XZ_DIMS, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index f5b0139cf..590edf18c 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -6,8 +6,8 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mixer import Mixer from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ @@ -52,7 +52,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class MambaLayer(Mixer): +class MambaLayer(BlockLayer): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( @@ -62,7 +62,13 @@ def __init__( tensor_space: TensorSpace, block_config: BlockConfig, ): - super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=block_config.debug_transformer, + debug_memory=block_config.debug_transformer_memory, + ) assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" self._config = config # TODO: It's not silu? 
diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index e84e92a96..6598d3a29 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -6,7 +6,7 @@ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig @@ -46,7 +46,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(Mixer): +class Attention(BlockLayer): """ A self-attention layer. """ @@ -72,7 +72,13 @@ class Attention(Mixer): ) def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): - super().__init__(tensor_space, block_index, config.debug_transformer) + super().__init__( + tensor_space, + block_index, + self._mixer_name, + debug_level=config.debug_transformer, + debug_memory=config.debug_transformer_memory, + ) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) @@ -259,7 +265,13 @@ def _decide_window_size(self) -> int | None: return window_size - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: sequence_first = kwargs[AttentionKwargs.sequence_first] query, 
key_value = self._query_key_value(input_, sequence_first) @@ -295,14 +307,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_level: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key_rotary_input", - self._KV_DIMS, - kwargs, - ) + if self._debug.enabled: + self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug(key, "key_rotary_input", self._KV_DIMS, kwargs) query, key = self._rotary(query, key, kwargs) window_size = self._decide_window_size() @@ -349,21 +356,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[AttentionKwargs.attention_mask_value], ) - if self._debug_level: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key", - self._KV_DIMS, - kwargs, - ) - self._debug_log( - value, - "value", - self._KV_DIMS, - kwargs, - ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug.enabled: + self._debug(query, "query", self._QUERY_DIMS, kwargs) + self._debug(key, "key", self._KV_DIMS, kwargs) + self._debug(value, "value", self._KV_DIMS, kwargs) + self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index 4a0e818f0..89d7a2e3b 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -2,8 +2,7 @@ import typing from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.block.block import Block -from fast_llm.layers.block.mixer import Mixer +from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig @@ -19,5 +18,5 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__(config, tensor_space, block_index, return_input) - def _create_mixer(self) -> Mixer: + def _create_mixer(self) -> BlockLayer: return Attention(self._config, self._tensor_space, self._block_index) From 9fbb9ff52081c7444a2a54547319ddb8ec05ad01 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 17:18:16 -0400 Subject: [PATCH 08/28] misc --- docs/developer_guide/conversion.md | 30 ++--- fast_llm/layers/block/block.py | 106 +++++++++--------- .../layers/block/mlp/mixture_of_experts.py | 9 +- fast_llm/layers/block/mlp/mlp.py | 19 ++-- tests/test_mlp.py | 4 +- 5 files changed, 88 insertions(+), 80 deletions(-) diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beaea..35a324db0 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. 
- num_layers = self._model.config.base_model.transformer.num_layers - - # A simple renaming example, for the word embeddings. - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters + converters = [] + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_layers + + # A simple renaming example, for the word embeddings. + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. + converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 87a8f81cf..84fb5f2d4 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -1,5 +1,6 @@ import abc import functools +import logging import typing import torch @@ -15,6 +16,8 @@ from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +logger = logging.getLogger(__name__) + class DebugLayer: # TODO: Move elsewhere? 
@@ -47,9 +50,11 @@ def _get_meta( def enabled(self) -> bool: return self._debug_level > 0 or self._debug_memory - def __call__( + def __call__[ + T + ]( self, - tensor: torch.Tensor, + tensor: torch.Tensor | None, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any], @@ -60,7 +65,7 @@ def __call__( # TODO: Local vs global? if self._debug_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._debug_level > 0: + if self._debug_level > 0 and tensor is not None: log_distributed_tensor( "", tensor, @@ -112,11 +117,8 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: pass - def _debug_log(self, tensor: torch.Tensor) -> None: - pass - -class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): +class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ @@ -125,10 +127,15 @@ class Block[ConfigType: BlockConfig](Layer, Configurable[ConfigType]): _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__() - self._config = config + super().__init__(config) self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout + self._debug = DebugLayer( + tensor_space, + f"Block {self._block_index} {self._name}", + self._config.debug_transformer, + self._config.debug_transformer_memory, + ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
self._return_input: bool = return_input @@ -144,7 +151,9 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index + self._config, + self._tensor_space, + self._block_index, ) # PEFT. @@ -163,35 +172,9 @@ def _bias_dropout_add( input_ = input_ + bias return residual + torch.dropout(input_, self._dropout_p, self.training) - @property - def name(self) -> str: - return f"{self._name} {self._block_index}" - - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[BlockKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) - - def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self.name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - log_distributed_tensor( - "", - tensor if bias is None else tensor + bias, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name, kwargs), - distributed=self._tensor_space.distributed, - ) - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", kwargs), - distributed=self._tensor_space.distributed, - ) + # @property + # def name(self) -> str: + # return f"{self._name} {self._block_index}" def forward( self, @@ -201,35 +184,50 @@ def forward( metrics: dict[str, typing.Any] | None = None, ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return self._get_meta(input_, "output", kwargs) + dims = kwargs[BlockKwargs.hidden_dims] + 
if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims( + dims, tensor_name=f"{self._name} {self._block_index} output", dtype=input_.dtype + ) generator = ( self._tensor_space.distributed.tp_generator if self._tensor_space.distributed_config.sequence_tensor_parallel else self._tensor_space.distributed.pp_generator ) - if self._debug_mode: - self._debug_log(None, "Begin", kwargs) + if self._debug.enabled: + self._debug(None, "Begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ hidden_states = self.norm_1(input_) - if self._debug_mode: - self._debug_log(hidden_states, "Norm 1", kwargs) + if self._debug.enabled: + self._debug(hidden_states, "Norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) - if self._debug_mode: - self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + f"{self._mixer_module_name} output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug_mode: - self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) + if self._debug.enabled: + self._debug(input_, f"{self._mixer_module_name} residual", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states = self.norm_2(input_) - if self._debug_mode: - self._debug_log(hidden_states, "Norm 2", kwargs) + if self._debug.enabled: + self._debug(hidden_states, "Norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug_mode: - self._debug_log(hidden_states, "MLP output", kwargs, bias=bias) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + kwargs[BlockKwargs.hidden_dims], + 
kwargs, + ) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug_mode: - self._debug_log(None, "MLP residual", kwargs, bias=bias) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 8d092b6dc..88d7ecf62 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -35,11 +35,16 @@ class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, name, block_index) + super().__init__( + config, + tensor_space, + block_index, + name, + ) self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 19349671e..a0980c39e 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,11 +2,10 @@ import torch -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.mlp.config import MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName @@ -15,9 +14,15 @@ from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: BlockConfig](Configurable[ConfigType], Layer): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__(config) +class MLPBase(BlockLayer): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + super().__init__( + tensor_space, + block_index, + name, + debug_level=config.debug_transformer, + debug_memory=config.debug_transformer_memory, + ) self._name = name self._block_index = block_index @@ -67,9 +72,9 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = " class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + def __init__(self, config: BlockConfig, tensor_space: 
TensorSpace, block_index: int = 0, name: str = "mlp"): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, block_index) + super().__init__(config, tensor_space, block_index, name) def forward( self, diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 5875822ff..802833eb2 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -15,7 +15,7 @@ def test_mlp_constructor(): tensor_space = TensorSpace(distributed_config=distributed_config) transformer_conf.setup_tensor_space(tensor_space) - MLP(transformer_conf, tensor_space, "name") + MLP(transformer_conf, tensor_space, 0, "name") def test_moe_mlp_constructor(): @@ -26,4 +26,4 @@ def test_moe_mlp_constructor(): tensor_space = TensorSpace(distributed_config=distributed_config) transformer_conf.setup_tensor_space(tensor_space) - MixtureOfExpertMLP(transformer_conf, tensor_space, "name") + MixtureOfExpertMLP(transformer_conf, tensor_space, 0, "name") From 44df195a207957254fb9bd50354c70cebe63766e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 18:08:29 -0400 Subject: [PATCH 09/28] misc --- .../engine/config_utils/initialization.py | 57 +++++++++++++++ fast_llm/layers/block/block.py | 28 ++++---- .../layers/block/mlp/mixture_of_experts.py | 3 +- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 3 +- fast_llm/layers/common/normalization.py | 3 +- fast_llm/layers/language_model/embedding.py | 3 +- fast_llm/layers/language_model/head.py | 3 +- fast_llm/layers/ssm/config.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 3 +- fast_llm/layers/ssm/mamba2.py | 3 +- fast_llm/layers/ssm/mamba_layer.py | 3 +- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/tensor.py | 70 +------------------ 15 files changed, 91 insertions(+), 98 deletions(-) create mode 100644 fast_llm/engine/config_utils/initialization.py diff --git a/fast_llm/engine/config_utils/initialization.py 
b/fast_llm/engine/config_utils/initialization.py new file mode 100644 index 000000000..b60070562 --- /dev/null +++ b/fast_llm/engine/config_utils/initialization.py @@ -0,0 +1,57 @@ +import abc +import typing + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.tensor import ParameterMeta + + +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) + + +init_zeros_ = init_fill_(0.0) +init_ones_ = init_fill_(1.0) + + +def init_normal_( + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_uniform_centered_(scale: float, mean: float = 0.0) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.uniform_(mean - scale, mean + scale, generator=generator) + + return 
LambdaInitializer(init_) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 84fb5f2d4..292d2c9a4 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -102,7 +102,7 @@ def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug = DebugLayer( tensor_space, - f"Block {self._block_index} {self._name}", + self._name, debug_level, debug_memory, ) @@ -128,19 +128,19 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__(config) + # TODO: Argument? + self._name = f"Block {self._block_index}" self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout + # For multi-token prediction, return a stack of shared_hidden and transformer_output. + self._return_input: bool = return_input + self._block_index = block_index self._debug = DebugLayer( tensor_space, - f"Block {self._block_index} {self._name}", + self._name, self._config.debug_transformer, self._config.debug_transformer_memory, ) - # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
- self._return_input: bool = return_input - - self._block_index = block_index - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale @@ -187,35 +187,33 @@ def forward( dims = kwargs[BlockKwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims( - dims, tensor_name=f"{self._name} {self._block_index} output", dtype=input_.dtype - ) + return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) generator = ( self._tensor_space.distributed.tp_generator if self._tensor_space.distributed_config.sequence_tensor_parallel else self._tensor_space.distributed.pp_generator ) if self._debug.enabled: - self._debug(None, "Begin", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ hidden_states = self.norm_1(input_) if self._debug.enabled: - self._debug(hidden_states, "Norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug.enabled: self._debug( hidden_states if bias is None else hidden_states + bias, - f"{self._mixer_module_name} output", + "mixer output", kwargs[BlockKwargs.hidden_dims], kwargs, ) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug.enabled: - self._debug(input_, f"{self._mixer_module_name} residual", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states = self.norm_2(input_) if self._debug.enabled: - self._debug(hidden_states, "Norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + self._debug(hidden_states, 
"norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) if self._debug.enabled: self._debug( diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 88d7ecf62..46005234c 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -5,6 +5,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped @@ -15,7 +16,7 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta, init_normal_ +from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index a0980c39e..7d4643673 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,6 +2,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd @@ -10,7 +11,6 @@ from fast_llm.layers.block.mlp.config import MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.tensor import init_normal_, init_zeros_ from 
fast_llm.utils import Assert, get_lr_scale diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9d5ce3f3b..2f45fdf9f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -87,7 +87,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_centered_ + from fast_llm.engine.config_utils.initialization import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 7249ef569..740b4847c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -14,7 +15,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.tensor import ParameterMeta, init_zeros_ +from fast_llm.tensor import ParameterMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index bccc1d627..d44be3297 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,11 +1,12 @@ import torch +from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient, 
init_ones_, init_zeros_ +from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert try: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 05678a700..68aa4882b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -6,9 +6,10 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 0623ac201..63d1a6b27 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -8,6 +8,7 @@ from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward @@ -25,7 +26,7 @@ ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.logging import log_distributed_tensor -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, 
get_unique logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index efcf2d873..00c709814 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.tensor import Initializer + from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ class SSMDimNames(BlockDimNames): @@ -66,8 +66,6 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": - from fast_llm.tensor import init_fill_, init_uniform_centered_ - return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index e48636926..e967ab9d1 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,13 +4,14 @@ import einops import torch +from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 4357c0e86..5d62c144f 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,6 +3,7 @@ import torch +from 
fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer @@ -10,7 +11,7 @@ from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale try: diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 590edf18c..0f3224f77 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,13 +4,14 @@ import torch +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_kaiming_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, get_lr_scale try: diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 6598d3a29..ba7f2bb6e 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -4,13 +4,13 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim 
+from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig -from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import get_lr_scale try: diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d080e6a1e..b12d12072 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,13 +1,12 @@ -import abc import functools import logging -import math import typing import torch from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op +from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -361,70 +360,3 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa - - -class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def 
__call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) - - -def init_uniform_( - low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.uniform_(low, high, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: - return init_uniform_( - mean - high, - mean + high, - min_val=None if max_val is None else mean - max_val, - max_val=None if max_val is None else mean + max_val, - ) From 3bb03cb3cf4bc64ba286f4f9a5074d0ecff8c227 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 19:15:14 -0400 Subject: [PATCH 10/28] misc --- fast_llm/layers/block/config.py | 2 - fast_llm/layers/block/mlp/config.py | 4 +- .../layers/block/mlp/mixture_of_experts.py | 127 ++++++------------ 
fast_llm/layers/block/mlp/mlp.py | 42 +++--- 4 files changed, 63 insertions(+), 112 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 489cd4f3f..6111c7e00 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -120,8 +120,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): def _validate(self) -> None: with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size # TODO: Review initialization if self.init_method_std is None: self.init_method_std = self.hidden_size**-0.5 diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 64e234544..92697de44 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -158,9 +158,11 @@ class MLPConfig(Config): def _validate(self) -> None: with self._set_implicit_default(): - # TODO: Make this work without inheritance. if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # TODO: Make this work without inheritance. 
+ if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size # TODO: Review initialization if self.init_method_std_mlp_1 is None: self.init_method_std_mlp_1 = self.init_method_std diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 46005234c..3a517db20 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,12 +1,10 @@ import logging -import typing import warnings import torch from fast_llm.core.distributed import ProcessGroup, set_generator from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map @@ -15,8 +13,6 @@ from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,59 +36,44 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__( - config, - tensor_space, - block_index, - name, - ) - self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - self._num_experts = config.num_experts - self._experts_per_token = config.num_experts_per_token - self._num_shared_experts = config.num_shared_experts - self._num_unshared_experts = config.num_unshared_experts - - self._routing_type = config.expert_routing_type - self._load_balancing_factor = config.expert_auxiliary_loss_coefficient - self._z_loss_factor = config.expert_z_loss_coefficient - self._moe_jitter_eps = config.moe_jitter_eps + super().__init__(config, tensor_space, block_index, name) - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( tensor_space[BlockDimNames.hidden], tensor_space[MLPDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( - std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max + std=self._config.init_method_std, + min_val=self._config.init_method_min, + max_val=self._config.init_method_max, ), lr_scale=router_lr_scale, ) - dropless_moe = config.dropless_moe + dropless_moe = self._config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." 
) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = config.dropless_dynamic_shape + self._dynamic_shape = self._config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug_mode: - self._debug_log(logits, "Router logits", MLPDimNames.experts, kwargs) + if self._debug.enabled: + self._debug(logits, "Router logits", MLPDimNames.experts, kwargs) # Apply z_loss if applicable - if self._z_loss_factor > 0.0: + if self._config.expert_z_loss_coefficient > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.expert_z_loss_coefficient, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, @@ -100,24 +81,24 @@ def forward( ) # Apply input_jitter if applicable: - if self.training and self._moe_jitter_eps > 0.0: + if self.training and self._config.moe_jitter_eps > 0.0: with set_generator(self._tensor_space.distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._routing_type == RoutingType.topk: + if self._config.expert_routing_type == RoutingType.topk: scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) - if self._num_shared_experts > 0: + if self._config.num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._routing_type == RoutingType.sinkhorn: + elif self._config.expert_routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._routing_type) + raise NotImplementedError(self._config.expert_routing_type) - if self._debug_mode: + if self._debug.enabled: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", MLPDimNames.top_experts, kwargs) - 
self._debug_log(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) + self._debug(scores, "Router scores", MLPDimNames.top_experts, kwargs) + self._debug(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -125,7 +106,7 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. - sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map(top_experts, self._config.num_experts, dynamic_shape=self._dynamic_shape) # Sparse MLP return mlp_autograd( @@ -154,7 +135,7 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._num_experts, + self._config.num_experts, self._config.gated, self._config.activation_type, self._intermediate_dim.parallel_group, @@ -165,7 +146,9 @@ def _forward_looped( @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) + return logits * torch.empty_like(logits).uniform_( + 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps + ) def _topk_routing( self, @@ -173,11 +156,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, 
num_classes=self._num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -188,7 +171,9 @@ def _topk_routing( losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale + scores, + aux_loss, + self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, ) return scores, top_experts @@ -197,69 +182,33 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype + self._config.num_unshared_experts, + self._config.num_experts, + device=top_experts.device, + dtype=top_experts.dtype, )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
- scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._experts_per_token == 1 + if self._config.num_experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) - def _debug_log( - self, - tensor: torch.Tensor | None, - name: str, - dim_name: str, - kwargs: dict[str, typing.Any], - global_: bool = True, - ) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - meta = self._get_meta(tensor, name, dim_name, kwargs) - log_distributed_tensor( - "", - tensor.view_as(meta), - level=self._config.debug_transformer, - meta=meta, - distributed=self._tensor_space.distributed, - global_=global_, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), - distributed=self._tensor_space.distributed, - grad_fn=lambda tensor_: tensor_.view_as(meta), - global_=global_, - ) - 
- def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: - return TensorMeta.from_dims( - kwargs[BlockKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 7d4643673..577986e3a 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -23,52 +23,54 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: debug_level=config.debug_transformer, debug_memory=config.debug_transformer_memory, ) - self._name = name - self._block_index = block_index + self._config = config init_method_1 = init_normal_( - std=config.init_method_std_mlp_1, - min_val=config.init_method_min_mlp_1, - max_val=config.init_method_max_mlp_1, + std=self._config.init_method_std_mlp_1, + min_val=self._config.init_method_min_mlp_1, + max_val=self._config.init_method_max_mlp_1, ) init_method_2 = init_normal_( - std=config.init_method_std_mlp_2, - min_val=config.init_method_min_mlp_2, - max_val=config.init_method_max_mlp_2, + std=self._config.init_method_std_mlp_2, + min_val=self._config.init_method_min_mlp_2, + max_val=self._config.init_method_max_mlp_2, ) - hidden_dim = tensor_space[BlockDimNames.hidden] - self._intermediate_dim = tensor_space[MLPDimNames.composite_expert_mlp] - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel + hidden_dim = self._tensor_space[BlockDimNames.hidden] + self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = 
tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + lr_scale = ( + tuple(self._config.mlp_lr_scale) + if isinstance(self._config.mlp_lr_scale, list) + else self._config.mlp_lr_scale + ) lr_scale = get_lr_scale(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=config.add_mlp_bias, + self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + bias=self._config.add_mlp_bias, weight_init_method=init_method_1, - bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=config.add_mlp_bias, + bias=self._config.add_mlp_bias, weight_init_method=init_method_2, - bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) # PEFT. 
- self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): From 98bae95d7595d13077c4608b0adedc23cdce1297 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 19:41:57 -0400 Subject: [PATCH 11/28] misc --- fast_llm/layers/block/config.py | 8 ---- fast_llm/layers/block/mlp/config.py | 10 ++++ .../layers/block/mlp/mixture_of_experts.py | 18 +++++-- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/language_model/embedding.py | 16 ++++--- fast_llm/layers/language_model/head.py | 47 +++++++------------ fast_llm/layers/ssm/discrete_mamba2.py | 11 ++++- fast_llm/layers/ssm/mamba2.py | 4 +- fast_llm/layers/ssm/mamba_layer.py | 14 +++++- fast_llm/layers/transformer/preprocessing.py | 6 +-- 10 files changed, 75 insertions(+), 61 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 6111c7e00..756e54dac 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -128,14 +128,6 @@ def _validate(self) -> None: super()._validate() - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 92697de44..70f05956a 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -4,6 +4,7 @@ from fast_llm.engine.config_utils.tensor_space import 
CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.block.config import AddLinearBiasChoices from fast_llm.utils import Assert @@ -156,6 +157,15 @@ class MLPConfig(Config): hint=FieldHint.optional, ) + @property + def add_mlp_bias(self) -> bool: + # TODO: Make this work without inheritance. + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 3a517db20..60cee9847 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -59,7 +59,6 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = self._config.dropless_dynamic_shape def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -67,7 +66,7 @@ def forward( hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: - self._debug(logits, "Router logits", MLPDimNames.experts, kwargs) + self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) # Apply z_loss if applicable if self._config.expert_z_loss_coefficient > 0.0: @@ -97,8 +96,15 @@ def forward( if self._debug.enabled: # To log all ranks set `global_=False` - self._debug(scores, "Router scores", MLPDimNames.top_experts, kwargs) - self._debug(top_experts, "Router top experts", MLPDimNames.top_experts, kwargs) + self._debug( + 
scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + ) + self._debug( + top_experts, + "Router top experts", + kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs, + ) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -106,7 +112,9 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. - sparse_map = get_sparse_map(top_experts, self._config.num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map( + top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape + ) # Sparse MLP return mlp_autograd( diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 577986e3a..6243c17bd 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -63,7 +63,7 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: bias=self._config.add_mlp_bias, weight_init_method=init_method_2, bias_init_method=init_zeros_, - auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, + auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 68aa4882b..051044ef6 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -22,19 +22,19 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L together with optional absolute position embeddings and dropout. 
""" - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType # Ensure the layer is on its own stage. layer_count: float = 1000.0 def __init__( self, - config: LanguageModelBaseConfig, + config: ConfigType, tensor_space: TensorSpace, ): super().__init__(config) - self._distributed_config = tensor_space.distributed_config self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config self._residual_dtype = ( self._distributed_config.optimization_dtype if config.transformer.full_precision_residual @@ -42,12 +42,14 @@ def __init__( ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = ( + self._tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + ) self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space[LanguageModelDimNames.hidden] - vocab_dim = tensor_space[ + hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab ] @@ -66,7 +68,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), + (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 63d1a6b27..2fa0b0f06 100644 --- 
a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -25,7 +25,6 @@ LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -34,16 +33,16 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType def __init__( self, - config: LanguageModelBaseConfig, + config: ConfigType, tensor_space: TensorSpace, prediction_distance: int, ): @@ -105,8 +104,7 @@ def __init__( lr_scale=self._config.output_lr_scale, ) - self._use_dpo_loss = self._config.enable_dpo - if not self._use_dpo_loss: + if not self._config.enable_dpo: self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: @@ -204,7 +202,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) 
- if self._use_dpo_loss: + if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -341,35 +339,22 @@ def _logits_cross_entropy_forward_backward( LanguageModelLossNames.z_loss, logits_scale_factor=self._logits_scale_factor, ) - if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space[ + if self._debug.enabled and self._cross_entropy_splits is None: + vocab_dim = ( LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ] - dims = [*kwargs[LanguageModelKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - LanguageModelDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) ) - - dim_names = ( - [LanguageModelDimNames.sequence_q_tp, LanguageModelDimNames.vocab] + sequence_dim = ( + LanguageModelDimNames.sequence_q_tp if self._sequence_parallel_logits - else [LanguageModelDimNames.sequence_q, LanguageModelDimNames.vocab_tp] + else LanguageModelDimNames.sequence_q ) - - dim_names.insert(int(kwargs[LanguageModelKwargs.sequence_first]), LanguageModelDimNames.batch) - log_distributed_tensor( - "", - logits, - level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), - distributed=self._tensor_space.distributed, - scale=self._logits_scale_factor, + batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] + dims = ( + (sequence_dim, batch_dim, vocab_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (batch_dim, sequence_dim, vocab_dim) ) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) if targets is 
None: return logits * self._logits_scale_factor, None diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index e967ab9d1..61291f845 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,13 +4,14 @@ import einops import torch -from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import get_lr_scale @@ -117,7 +118,13 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available sequence_length = kwargs[BlockKwargs.sequence_q_dim].global_size diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 5d62c144f..b6626e893 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,14 +3,14 @@ import torch -from fast_llm.engine.config_utils.initialization import init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ from 
fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 0f3224f77..0dcc29f0b 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -4,7 +4,7 @@ import torch -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_kaiming_, init_ones_ +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer @@ -146,7 +146,13 @@ def __init__( ) self.out_proj.weight.auto_grad_accumulation = True - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) @@ -170,3 +176,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if kwargs[BlockKwargs.sequence_first]: out = 
out.transpose(0, 1) return out, None + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index d8fa14a6d..16e5811e6 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -5,7 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class BackupAttentionPreprocessor(Preprocessor): def __init__( self, - config: TransformerConfig, + config: AttentionConfig, tensor_space: TensorSpace, ): self._config = config @@ -91,7 +91,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config From fd731ef76ba1ac52610291cb17e38eee7107be71 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 21:06:12 -0400 Subject: [PATCH 12/28] fixes --- fast_llm/data/data/abstract.py | 1 + fast_llm/data/data/gpt/data.py | 2 ++ fast_llm/engine/schedule/runner.py | 1 + fast_llm/layers/block/block.py | 9 ++++++--- fast_llm/layers/block/config.py | 10 ---------- fast_llm/layers/block/mlp/config.py | 3 ++- fast_llm/layers/block/mlp/mixture_of_experts.py | 2 +- fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/language_model/embedding.py | 2 +- fast_llm/layers/language_model/head.py | 
2 +- fast_llm/layers/ssm/config.py | 4 +++- fast_llm/layers/transformer/config.py | 11 +++++++++++ fast_llm/layers/transformer/rotary/rotary.py | 9 +++++++++ 13 files changed, 39 insertions(+), 19 deletions(-) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e24d39985..04da64a9d 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -13,6 +13,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): + config_class: typing.ClassVar[type[DataConfig]] = DataConfig _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] _cache_directory: pathlib.Path | None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..37cfd9020 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,6 +65,8 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): TODO: Separate generic and GPT classes. """ + config_class: typing.ClassVar[type[GPTDataConfig]] = GPTDataConfig + _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] _tokenizer: Tokenizer | None diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..7fdba1832 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -64,6 +64,7 @@ def __repr__(self): class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): + config_class: typing.ClassVar[type[ScheduleConfig]] = ScheduleConfig _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 292d2c9a4..03e0df928 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -11,8 +11,6 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, 
TensorSpace from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.block.mlp.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -123,18 +121,19 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ + config_class: typing.ClassVar[type[BlockConfig]] = BlockConfig # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): super().__init__(config) # TODO: Argument? + self._block_index = block_index self._name = f"Block {self._block_index}" self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._block_index = block_index self._debug = DebugLayer( tensor_space, self._name, @@ -150,6 +149,10 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # The mixer needs to be created here for backward-compatible weight ordering. setattr(self, self._mixer_module_name, self._create_mixer()) + # TODO: Use dynamic type. 
+ from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.block.mlp.mlp import MLP + self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 756e54dac..919f95b3f 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -118,16 +118,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.optional, ) - def _validate(self) -> None: - with self._set_implicit_default(): - # TODO: Review initialization - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - - super()._validate() - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 70f05956a..a99debacc 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -4,7 +4,6 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.block.config import AddLinearBiasChoices from fast_llm.utils import Assert @@ -159,6 +158,8 @@ class MLPConfig(Config): @property def add_mlp_bias(self) -> bool: + from fast_llm.layers.block.config import AddLinearBiasChoices + # TODO: Make this work without inheritance. 
if isinstance(self.add_linear_biases, bool): return self.add_linear_biases diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 60cee9847..e53693460 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP(MLPBase): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 6243c17bd..06850c8d0 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -73,7 +73,7 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MLP(MLPBase): def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, block_index, name) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 051044ef6..1ecafb344 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -22,7 +22,7 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig # Ensure the layer is on its own stage. 
layer_count: float = 1000.0 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 2fa0b0f06..6d1fedd26 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -38,7 +38,7 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = ConfigType + config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig def __init__( self, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 00c709814..dec0675b9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.initialization import Initializer, init_fill_, init_uniform_centered_ + from fast_llm.engine.config_utils.initialization import Initializer class SSMDimNames(BlockDimNames): @@ -66,6 +66,8 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": + from fast_llm.engine.config_utils.initialization import init_fill_, init_uniform_centered_ + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a8245f7da..f7c7fea9c 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -233,3 +233,14 @@ def add_dense_bias(self) -> bool: # TODO: Use composition instead class TransformerConfig(AttentionConfig, BlockConfig): _abstract = False + + def _validate(self) -> None: + with self._set_implicit_default(): + # Kept here for initialization order. 
+ # TODO: Review initialization + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + + super()._validate() diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..207cff7d3 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -42,6 +42,8 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): + config_class: typing.ClassVar[type[RotaryConfig]] = RotaryConfig + def __init__( self, config: ConfigType, @@ -58,6 +60,8 @@ def forward( class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): + config_class: typing.ClassVar[type[NoRotaryConfig]] = NoRotaryConfig + def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -71,6 +75,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): + config_class: typing.ClassVar[type[DefaultRotaryConfig]] = DefaultRotaryConfig _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 @@ -154,6 +159,8 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): + config_class: typing.ClassVar[type[Llama3RotaryConfig]] = Llama3RotaryConfig + def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor @@ -180,6 +187,8 @@ class 
YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): [original paper](https://arxiv.org/abs/2309.00071) """ + config_class: typing.ClassVar[type[YarnRotaryConfig]] = YarnRotaryConfig + def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor From f48332139d54a7e6cbe3171b480434832d7e5a8d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 31 Jul 2025 21:07:17 -0400 Subject: [PATCH 13/28] fixes --- tests/test_ssms.py | 82 ---------------------------------------------- 1 file changed, 82 deletions(-) delete mode 100644 tests/test_ssms.py diff --git a/tests/test_ssms.py b/tests/test_ssms.py deleted file mode 100644 index 6c4c7f0cb..000000000 --- a/tests/test_ssms.py +++ /dev/null @@ -1,82 +0,0 @@ -import pathlib - -import pytest -import torch - -from fast_llm.config import NoAutoValidate -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMModel - - -@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") -@pytest.mark.slow -def test_load_from_llamba_checkpoint(): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
- """ - import cartesia_pytorch.Llamba.llamba - - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - - hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) - hf_logits = hf_model(x)["logits"].cpu() - del hf_model - torch.cuda.empty_cache() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - # model = GPTModel.from_pretrained(checkpoint_config) - assert model.config.base_model.vocab_size == vocab_size - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(DistributedConfig.from_dict({})) - batch_config.validate() - schedule_runner = ScheduleRunner( - config=schedule_config, - multi_stage=model, - distributed_config=model.distributed.config, - ) - schedule = Schedule( - multi_stage=model, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=model.distributed.config, - phase=PhaseType.inference, - ) - schedule_runner.setup(model.distributed, optimizer=None) - - common_kwargs = { - AttentionKwargs.sequence_first: True, - AttentionKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] - - 
schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) - - logits = input_data[0][1]["logits"].cpu() - assert torch.allclose(logits, hf_logits, atol=1e-2) From 07c921182b31f2a1fff16da703fcd7e82b73a7fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Aug 2025 16:33:45 -0400 Subject: [PATCH 14/28] stuff --- fast_llm/config.py | 6 ++++++ fast_llm/data/data/abstract.py | 1 - fast_llm/data/data/gpt/data.py | 2 -- fast_llm/data/preparator/config.py | 2 -- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 -- fast_llm/engine/base_model/base_model.py | 1 - fast_llm/engine/distributed/distributed.py | 3 --- fast_llm/engine/evaluation/evaluator.py | 4 ---- fast_llm/engine/evaluation/lm_eval/evaluator.py | 2 -- fast_llm/engine/multi_stage/fast_llm_model.py | 1 - fast_llm/engine/multi_stage/multi_stage.py | 1 - fast_llm/engine/multi_stage/stage_base.py | 1 - fast_llm/engine/schedule/runner.py | 1 - fast_llm/engine/training/trainer.py | 3 --- fast_llm/layers/block/block.py | 1 - fast_llm/layers/language_model/embedding.py | 2 -- fast_llm/layers/language_model/head.py | 3 --- fast_llm/layers/transformer/rotary/rotary.py | 9 --------- fast_llm/models/custom/model.py | 6 +----- fast_llm/models/custom/trainer.py | 4 ---- fast_llm/models/gpt/model.py | 3 --- fast_llm/models/gpt/trainer.py | 2 -- fast_llm/models/ssm/model.py | 2 -- fast_llm/models/ssm/trainer.py | 1 - 24 files changed, 7 insertions(+), 56 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index c534b11f3..099670625 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1028,6 +1028,12 @@ def __init__(self, config: ConfigType, *args, **kwargs): # Handle multiple inheritance. super().__init__(*args, **kwargs) + def __init_subclass__(cls): + # Automatically set `config_class` based on the bound type. + # Make sure `ConfigType` is bound and respects class hierarchy. 
+ Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) + cls.config_class = config_class + @property def config(self) -> ConfigType: return self._config diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 04da64a9d..e24d39985 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -13,7 +13,6 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[DataConfig]] = DataConfig _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] _cache_directory: pathlib.Path | None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 37cfd9020..6724afb59 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,8 +65,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): TODO: Separate generic and GPT classes. """ - config_class: typing.ClassVar[type[GPTDataConfig]] = GPTDataConfig - _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] _tokenizer: Tokenizer | None diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index 7f6376c7d..160fccafc 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -19,8 +19,6 @@ def _get_runnable(self) -> typing.Callable[[], None]: class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig - @abc.abstractmethod def run(self) -> None: raise NotImplementedError diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 427309a99..33c40bf8f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -33,8 +33,6 @@ class GPTMemmapDatasetPreparator[ConfigType: 
GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): - config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - _tokenizer: Tokenizer _data_type: DataType _text_column: str diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index df603a910..caaf94794 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -82,7 +82,6 @@ def get_layers(self) -> list[Layer]: class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): - config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig _is_setup: bool = False def __init__( diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index f17a8f452..dc41539c0 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -1,6 +1,5 @@ import datetime import logging -import typing import torch import torch.distributed @@ -146,8 +145,6 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): TODO: Clarify cpu support. 
""" - config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig - def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) assert self._config.reference_config is None diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 3bdc2407f..6b8f8db00 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -44,8 +44,6 @@ class EvaluatorSamplingParameters: class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[EvaluatorConfig]] = EvaluatorConfig - _is_setup: bool = False def __init__( @@ -96,8 +94,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig - def setup( self, distributed: Distributed, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 162ceaf60..9040b11b4 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -25,8 +25,6 @@ class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig - _hf_model: "HuggingfaceBaseModelForCausalLM" = None _flm_wrapper: "FastLLMLmEvalWrapper" = None diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 56bae90fe..09ee788e6 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -14,7 +14,6 @@ class FastLLMModel[ConfigType: FastLLMModelConfig](MultiStageModel[ConfigType]): - config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig _is_loaded: bool = False def save_checkpoint( diff --git 
a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 1f734268b..e17bc4ff8 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -26,7 +26,6 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): - config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False _flat_shard: torch.Tensor diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 3218a1963..387a53a03 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -21,7 +21,6 @@ class StageBase(Configurable[StageConfig]): - config_class: typing.ClassVar[type[StageConfig]] = StageConfig _distributed: Distributed _mode: StageMode diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 7fdba1832..8eca4559d 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -64,7 +64,6 @@ def __repr__(self): class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): - config_class: typing.ClassVar[type[ScheduleConfig]] = ScheduleConfig _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 5f5511a15..e5bd5a583 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -43,8 +43,6 @@ class TrainingEvaluator[ConfigType: TrainingEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[TrainingEvaluatorConfig]] = TrainingEvaluatorConfig - evaluator: Evaluator def __init__( @@ -114,7 +112,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): - 
config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig # TODO: Generalize data, schedule, logging, etc. _is_setup: bool = False _distributed: Distributed diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 03e0df928..f06b2da45 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -121,7 +121,6 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ - config_class: typing.ClassVar[type[BlockConfig]] = BlockConfig # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1ecafb344..d90442e9f 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -22,8 +22,6 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - # Ensure the layer is on its own stage. layer_count: float = 1000.0 diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 6d1fedd26..8624612d6 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,5 +1,4 @@ import logging -import typing import torch from torch._C._distributed_c10d import ReduceOp # noqa @@ -38,8 +37,6 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). 
""" - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - def __init__( self, config: ConfigType, diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 207cff7d3..ebb629aa1 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -42,8 +42,6 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): - config_class: typing.ClassVar[type[RotaryConfig]] = RotaryConfig - def __init__( self, config: ConfigType, @@ -60,8 +58,6 @@ def forward( class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): - config_class: typing.ClassVar[type[NoRotaryConfig]] = NoRotaryConfig - def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -75,7 +71,6 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): - config_class: typing.ClassVar[type[DefaultRotaryConfig]] = DefaultRotaryConfig _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 @@ -159,8 +154,6 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): - config_class: typing.ClassVar[type[Llama3RotaryConfig]] = Llama3RotaryConfig - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor @@ -187,8 +180,6 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): [original paper](https://arxiv.org/abs/2309.00071) """ - config_class: 
typing.ClassVar[type[YarnRotaryConfig]] = YarnRotaryConfig - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 3c0ad8ab4..98937bdb1 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -8,16 +8,13 @@ from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig +from fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead -from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.tensor import TensorMeta class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - def __init__( self, config: CustomBaseModelConfig, @@ -66,5 +63,4 @@ def loss_defs(self) -> list[LossDef]: class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): - config_class: typing.ClassVar[type[CustomModelConfig]] = CustomModelConfig base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py index eba51235e..587adad3e 100644 --- a/fast_llm/models/custom/trainer.py +++ b/fast_llm/models/custom/trainer.py @@ -1,5 +1,3 @@ -import typing - from fast_llm.models.custom.config import CustomTrainerConfig from fast_llm.models.custom.data import CustomData from fast_llm.models.gpt.trainer import GPTTrainer @@ -7,8 +5,6 @@ class CustomTrainer[ConfigType: 
CustomTrainerConfig](GPTTrainer[ConfigType]): # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). - config_class: typing.ClassVar[type[CustomTrainerConfig]] = CustomTrainerConfig - def _get_data(self): # TODO: Adjust signature if needed. return CustomData( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 187ca618d..47df8ba1c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -31,8 +31,6 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): A transformer-based language model generalizing the GPT model architecture. """ - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - def __init__( self, config: GPTBaseModelConfig, @@ -410,7 +408,6 @@ def loss_defs(self) -> list[LossDef]: class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e1..7f2e83ab4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -10,8 +10,6 @@ class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): - config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig - def _get_data(self) -> GPTData: return GPTData( config=self._config.data, diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index ca840911f..32fbdad9b 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -20,7 +20,6 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. 
""" - config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig _is_setup: bool = False def __init__( @@ -110,7 +109,6 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): A hybrid model that combines Transformer and SSM blocks. """ - config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py index efa7b704f..39f589384 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/ssm/trainer.py @@ -6,5 +6,4 @@ class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): - config_class: typing.ClassVar[type[HybridSSMTrainerConfig]] = HybridSSMTrainerConfig model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel From 0a5e4584990165ad5ed69434fc3ca37f3e9ae856 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 16:22:15 -0400 Subject: [PATCH 15/28] Remove tensor space, fixes --- Megatron-LM | 2 +- fast_llm/config.py | 18 +- fast_llm/engine/base_model/base_model.py | 57 +++--- fast_llm/engine/base_model/config.py | 7 +- .../{tensor_space.py => tensor_dim.py} | 50 +---- fast_llm/engine/multi_stage/fast_llm_model.py | 1 + fast_llm/engine/multi_stage/fsdp.py | 16 +- fast_llm/engine/multi_stage/multi_stage.py | 2 +- fast_llm/engine/multi_stage/stage.py | 13 +- fast_llm/engine/multi_stage/stage_base.py | 4 +- fast_llm/engine/schedule/runner.py | 2 +- fast_llm/layers/block/block.py | 107 +++++----- fast_llm/layers/block/config.py | 7 - fast_llm/layers/block/mlp/config.py | 44 ---- .../layers/block/mlp/mixture_of_experts.py | 51 +++-- fast_llm/layers/block/mlp/mlp.py | 65 ++++-- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 2 +- fast_llm/layers/common/normalization.py | 2 +- fast_llm/layers/common/peft.py | 2 +- fast_llm/layers/language_model/config.py | 23 
+-- fast_llm/layers/language_model/embedding.py | 63 +++--- fast_llm/layers/language_model/head.py | 188 ++++++++---------- .../layers/language_model/preprocessing.py | 36 ++-- fast_llm/layers/ssm/block.py | 23 ++- fast_llm/layers/ssm/config.py | 96 +-------- fast_llm/layers/ssm/discrete_mamba2.py | 78 +++++--- .../layers/ssm/{mamba_layer.py => mamba.py} | 52 +++-- fast_llm/layers/ssm/mamba2.py | 93 +++++---- fast_llm/layers/transformer/attention.py | 184 +++++++++-------- fast_llm/layers/transformer/block.py | 10 +- fast_llm/layers/transformer/config.py | 48 +---- fast_llm/layers/transformer/preprocessing.py | 51 ++--- fast_llm/layers/transformer/rotary/config.py | 6 +- .../transformer/rotary/preprocessing.py | 68 ------- fast_llm/layers/transformer/rotary/rotary.py | 57 ++---- fast_llm/logging.py | 6 +- fast_llm/models/custom/model.py | 29 +-- fast_llm/models/gpt/model.py | 105 ++++++---- fast_llm/models/ssm/config.py | 9 - fast_llm/models/ssm/model.py | 113 +++-------- fast_llm/tensor.py | 25 +-- tests/functional/test_triton_kernels.py | 4 +- tests/test_attention.py | 35 +--- tests/test_mlp.py | 29 --- tests/utils/global_variables.py | 4 +- tests/utils/utils.py | 7 +- 47 files changed, 791 insertions(+), 1105 deletions(-) rename fast_llm/engine/config_utils/{tensor_space.py => tensor_dim.py} (81%) rename fast_llm/layers/ssm/{mamba_layer.py => mamba.py} (79%) delete mode 100644 fast_llm/layers/transformer/rotary/preprocessing.py delete mode 100644 tests/test_mlp.py diff --git a/Megatron-LM b/Megatron-LM index 75b0d9787..f02b413f7 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 +Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf diff --git a/fast_llm/config.py b/fast_llm/config.py index 099670625..3352f3570 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1031,7 +1031,23 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set 
`config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. - Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) + try: + config_class = None + for base in types.get_original_bases(cls): + if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable): + for arg in base.__args__: + if arg.__name__ == "ConfigType": + if config_class is None: + config_class = arg.__bound__ + else: + assert arg.__bound__ is config_class + assert config_class is not None + except Exception as e: + raise TypeError( + f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. " + "Please make sure to declare in the format " + f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`.] " + ) cls.config_class = config_class @property diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index caaf94794..832225803 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -7,7 +7,6 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta @@ -20,11 +19,18 @@ class Module(torch.nn.Module, abc.ABC): """ """ - def forward(self, input_, kwargs): - """ - Run a forward pass for the module, with autograd support. 
- """ - raise NotImplementedError() + _is_setup: bool = False + _distributed: Distributed + + def __init__(self, distributed_config: DistributedConfig): + self._distributed_config = distributed_config + super().__init__() + + def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._distributed_config) + self._distributed = distributed + self._is_setup = True class Layer(Module): @@ -39,9 +45,9 @@ def forward( class Sequential(Layer): - def __init__(self, layers: list[Layer]): - super().__init__() - self.layers = torch.nn.ModuleList(layers) + def __init__(self, distributed_config: DistributedConfig): + super().__init__(distributed_config) + self.layers = torch.nn.ModuleList(self.get_layers()) def __getitem__(self, item): return self.layers[item] @@ -59,6 +65,15 @@ def forward( input_ = layer(input_, kwargs, losses, metrics) return input_ + @abc.abstractmethod + def get_layers(self) -> list[Layer]: + pass + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + for layer in self.layers: + layer.setup(distributed) + @dataclasses.dataclass() class LossDef: @@ -71,28 +86,14 @@ class LossDef: dtype: torch.dtype = torch.float32 -class SequentialLayers(Sequential, abc.ABC): - # Small class defined to fix the MRO of BaseModel.__init__ - def __init__(self): - super().__init__(self.get_layers()) - - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - - -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): - _is_setup: bool = False +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): def __init__( self, config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space: TensorSpace = TensorSpace(distributed_config) - config.setup_tensor_space(self._tensor_space) - - super().__init__(config) + super().__init__(config, distributed_config) for key, value in self.named_parameters(): 
Assert.custom(isinstance, value, ParameterMeta) @@ -103,12 +104,6 @@ def __init__( # TODO: Add basic handling (preprocessor) in this class. self._reference_models: dict[str, "InferenceRunner"] = {} - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - @abc.abstractmethod def get_layers(self) -> list[Layer]: pass diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 4be42e069..22abb021b 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import compare_nested, log if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace + import torch @config_class() @@ -18,9 +18,6 @@ class BaseModelConfig(Config): _abstract = True - def setup_tensor_space(self, tensor_space: "TensorSpace") -> None: - raise NotImplementedError() - def compare_architecture( self, model_config: typing.Self, @@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass @abc.abstractmethod - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: pass diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_dim.py similarity index 81% rename from fast_llm/engine/config_utils/tensor_space.py rename to fast_llm/engine/config_utils/tensor_dim.py index 6c4b95b20..f67916a66 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_dim.py @@ -2,14 +2,13 @@ import math import typing -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim +from fast_llm.engine.distributed.config import DistributedDim from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: import torch 
from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -219,49 +218,4 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F ) -class DefaultDimNames: - # Scalar - scalar = "scalar" - - -class TensorSpace: - _is_setup: bool = False - _distributed: "Distributed" - - def __init__(self, distributed_config: DistributedConfig): - self._distributed_config = distributed_config - self._tensor_dims: dict[str, TensorDim] = {} - self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) - - def setup(self, distributed: "Distributed") -> None: - assert not self._is_setup - if distributed.config is not self._distributed_config: - distributed.config.compare(self._distributed_config, ValueError) - self._is_setup = True - self._distributed = distributed - - @property - def distributed_config(self) -> DistributedConfig: - return self._distributed_config - - @property - def distributed(self) -> "Distributed": - assert self._is_setup - return self._distributed - - def add_tensor_dim(self, tensor_dim: TensorDim) -> None: - if tensor_dim.name in self._tensor_dims: - Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) - else: - if tensor_dim.parallel_dim is not None: - assert ( - tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims - ), tensor_dim.parallel_dim.name - Assert.eq( - tensor_dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, - ) - self._tensor_dims[tensor_dim.name] = tensor_dim - - def __getitem__(self, name: str) -> TensorDim: - return self._tensor_dims[name] +scalar_dim = TensorDim("scalar", 1) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 09ee788e6..da4fe527e 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -51,6 +51,7 @@ def 
from_pretrained( use_cpu: bool = False, stage_filter: set | None = None, ) -> typing.Self: + print("IUGRGHIOERIO", cls, cls.config_class) metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index be15cd37a..cb0a02a67 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -9,7 +9,7 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode @@ -320,27 +320,31 @@ def import_state_tensor( return end - begin def export_shard( - self, shard: torch.Tensor, distributed: Distributed, data_type: DataType | None = None + self, shard: torch.Tensor, data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: if data_type is not None: shard = shard.to(dtype=data_type.torch) tensors = self.split_buffer(self.reconstruct_from_shard(shard)) for name, meta in self._parameter_metas.items(): - yield name, meta.local_to_global(tensors[name], distributed=distributed)[0] + yield name, meta.local_to_global(tensors[name])[0] def log_shard(self, name, shard, *, distributed: Distributed, level, global_: bool) -> None: # if global_ is None: # global_ = self._config.debug_global_tensors parameters = self.split_buffer(self.reconstruct_from_shard(shard)) if global_ else self.split_shard(shard) for parameter_name, parameter in 
parameters.items(): + meta = self.get_parameter_meta(parameter_name) log_distributed_tensor( name, parameter, level=level, - distributed=distributed, global_=global_, - duplicate_groups=(distributed.data_group,), - meta=self.get_parameter_meta(parameter_name), + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=( + distributed.data_group, + distributed.tensor_group, + ), + meta=meta, ) def restore_parameters(self) -> None: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index e17bc4ff8..d939bda2b 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -12,7 +12,7 @@ from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 87eac31c4..35547cd87 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import StageConfig, StageMode from fast_llm.engine.multi_stage.stage_base import StageBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, 
log_memory_usage, log_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient @@ -30,7 +30,7 @@ def hook(grad_inputs, grad_outputs): # noqa return hook -class Stage(StageBase): +class Stage[ConfigType: StageConfig](StageBase[ConfigType]): _is_restored: bool _training: bool | None = None # TODO: Handle all buffer sharing in multi_stage @@ -123,7 +123,7 @@ def forward( # Last layer does not provide output if output is not None: meta = self._meta_outputs[i] - output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed) + output_global, _ = meta.local_to_global(output.detach()) kwargs["hidden_states"][self._layer_range[i]] = { "layer_type": type(layer).__name__, "tensor": output_global, @@ -216,11 +216,13 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] if (nms := kwargs.get("micro_batch_splits", 1)) > 1: name = f"{name}, ms={kwargs.get('micro_batch_split',0)}/{nms}" + # Assuming all tensors are either duplicated of parallel in the TP direction. log_distributed_tensor( name, output, level=self._config.debug_layer_outputs, - distributed=self._distributed, + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_outputs[i], ) @@ -250,8 +252,9 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any name, input_, level=self._config.debug_layer_gradients, - distributed=self._distributed, grad_fn=lambda grad: grad / self._fsdp_size, + # Assuming all tensors are either duplicated of parallel in the TP direction. 
+ duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_inputs[i], ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 387a53a03..ded24e538 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -class StageBase(Configurable[StageConfig]): +class StageBase[ConfigType: StageConfig](Configurable[ConfigType]): _distributed: Distributed _mode: StageMode @@ -314,7 +314,7 @@ def _export_shard( self, shards: tuple[torch.Tensor], data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: for fsdp, shard in zip(self._fsdps, shards, strict=True): - yield from fsdp.export_shard(shard, self._distributed, data_type) + yield from fsdp.export_shard(shard, data_type) def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..21ecbe476 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -63,7 +63,7 @@ def __repr__(self): ) -class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): +class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ConfigType]): _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index f06b2da45..425731eb9 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -5,12 +5,14 @@ import torch -from fast_llm.config import Configurable +from fast_llm.config import Config, Configurable from fast_llm.core.distributed import set_generator -from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.base_model import Layer, Module 
from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -19,8 +21,7 @@ class DebugLayer: # TODO: Move elsewhere? - def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): - self._tensor_space = tensor_space + def __init__(self, name: str, debug_level: int = 0, debug_memory: bool = False): self._name = name self._debug_level = debug_level self._debug_memory = debug_memory @@ -36,9 +37,9 @@ def _get_meta( ( dim if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) ) - for dim in dims + for i, dim in enumerate(dims) ), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, @@ -69,7 +70,6 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name, dims, kwargs), - distributed=self._tensor_space.distributed, global_=global_, log_fn=log_fn, scale=scale, @@ -80,31 +80,45 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name + " grad", dims, kwargs), - distributed=self._tensor_space.distributed, global_=global_, log_fn=log_fn, scale=scale, ) -class BlockLayer(torch.nn.Module, abc.ABC): +class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): """ - Base class for mixer and MLP modules. + Base class for blocks, mixer and MLP modules. 
""" - def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug_level: int, debug_memory: bool): - super().__init__() - self._tensor_space = tensor_space + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + # TODO: Review `hidden_dim` and `block_index` + hidden_dim: TensorDim, + block_index: int, + name: str, + debug_level: int, + debug_memory: bool, + ): + super().__init__(config, distributed_config) + self._hidden_dim = hidden_dim self._block_index = block_index self._name = name - self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._debug = DebugLayer( - tensor_space, self._name, debug_level, debug_memory, ) + +class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): + """ + Base class for mixer and MLP modules. + """ + @abc.abstractmethod def forward( self, @@ -116,7 +130,7 @@ def forward( pass -class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): +class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ @@ -124,26 +138,30 @@ class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config) - # TODO: Argument? 
- self._block_index = block_index - self._name = f"Block {self._block_index}" - self._tensor_space: TensorSpace = tensor_space - self._dropout_p: float = self._config.hidden_dropout + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + return_input: bool = False, + ): + super().__init__( + config, + distributed_config, + hidden_dim, + block_index, + name, + config.debug_transformer, + config.debug_transformer_memory, + ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._debug = DebugLayer( - tensor_space, - self._name, - self._config.debug_transformer, - self._config.debug_transformer_memory, - ) - hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + self.norm_1 = self._config.normalization.get_layer(self._hidden_dim) + self.norm_2 = self._config.normalization.get_layer(self._hidden_dim) # The mixer needs to be created here for backward-compatible weight ordering. setattr(self, self._mixer_module_name, self._create_mixer()) @@ -153,15 +171,18 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i from fast_llm.layers.block.mlp.mlp import MLP self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, - self._tensor_space, - self._block_index, + self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} MLP" ) # PEFT. 
self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + getattr(self, self._mixer_module_name).setup(distributed) + self.mlp.setup(distributed) + @abc.abstractmethod def _create_mixer(self) -> BlockLayer: pass @@ -172,11 +193,7 @@ def _bias_dropout_add( ) -> torch.Tensor: if bias is not None: input_ = input_ + bias - return residual + torch.dropout(input_, self._dropout_p, self.training) - - # @property - # def name(self) -> str: - # return f"{self._name} {self._block_index}" + return residual + torch.dropout(input_, self._config.hidden_dropout, self.training) def forward( self, @@ -190,11 +207,7 @@ def forward( if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) - generator = ( - self._tensor_space.distributed.tp_generator - if self._tensor_space.distributed_config.sequence_tensor_parallel - else self._tensor_space.distributed.pp_generator - ) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator if self._debug.enabled: self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 919f95b3f..0da7a0c99 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -2,7 +2,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig @@ -117,9 +116,3 @@ class BlockConfig(MLPConfig, BaseModelConfig): 
desc="Min value for clamping initialized weights. Default: -float('inf')", hint=FieldHint.optional, ) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index a99debacc..57f7a9e03 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,27 +1,10 @@ import enum from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert -class MLPDimNames: - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - class MLPLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -206,30 +189,3 @@ def _validate(self) -> None: Assert.geq(scale, 0) elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(MLPDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := 
TensorDim(MLPDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(MLPDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_expert_mlp, (experts, mlp))) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(MLPDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp)) - ) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index e53693460..0bc531dad 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -5,11 +5,12 @@ from fast_llm.core.distributed import ProcessGroup, set_generator from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import 
MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP(MLPBase): +class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -32,18 +33,25 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, block_index, name) + super().__init__(config, distributed_config, hidden_dim, block_index, name) layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space[BlockDimNames.hidden], - tensor_space[MLPDimNames.unshared_experts], + hidden_dim, + TensorDim("router_experts", self._config.num_unshared_experts), bias=False, weight_init_method=init_normal_( std=self._config.init_method_std, @@ -53,20 +61,33 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: lr_scale=router_lr_scale, ) dropless_moe = self._config.dropless_moe - if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: + if dropless_moe and self._sequence_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped + if self._debug.enabled: + self._top_expert_dim = TensorDim("top_experts", self._config.num_experts_per_token) + + def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: + intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() + experts_dim = TensorDim("experts", self._config.num_experts) + return ( + CompositeTensorDim("moe_intermediate_1", (experts_dim, intermediate_1_dim)), + CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: - self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) + self._debug( + logits, "Router logits", 
kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + ) # Apply z_loss if applicable if self._config.expert_z_loss_coefficient > 0.0: @@ -81,7 +102,7 @@ def forward( # Apply input_jitter if applicable: if self.training and self._config.moe_jitter_eps > 0.0: - with set_generator(self._tensor_space.distributed.pp_generator): + with set_generator(self._distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing @@ -97,12 +118,12 @@ def forward( if self._debug.enabled: # To log all ranks set `global_=False` self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs ) self._debug( top_experts, "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs, ) @@ -126,7 +147,7 @@ def _forward_dropless( None, gated=self._config.gated, activation_type=self._config.activation_type, - group=self._intermediate_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.mlp_recompute_level, @@ -146,7 +167,7 @@ def _forward_looped( self._config.num_experts, self._config.gated, self._config.activation_type, - self._intermediate_dim.parallel_group, + self._parallel_dim.group, self._sequence_parallel, self.training, self._config.mlp_recompute_level, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 06850c8d0..dc5178479 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -3,27 +3,37 @@ import torch from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim +from 
fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, get_lr_scale -class MLPBase(BlockLayer): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): +class MLPBase[ConfigType: BlockConfig](BlockLayer[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, name, - debug_level=config.debug_transformer, - debug_memory=config.debug_transformer_memory, + config.debug_transformer, + config.debug_transformer_memory, ) - self._config = config + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() init_method_1 = init_normal_( std=self._config.init_method_std_mlp_1, @@ -36,8 +46,6 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: max_val=self._config.init_method_max_mlp_2, ) - hidden_dim = self._tensor_space[BlockDimNames.hidden] - self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else 
None @@ -51,19 +59,19 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + intermediate_1_dim, bias=self._config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_zeros_, lr_scale=lr_scale, ) self.layer_2 = LinearBase( - self._intermediate_dim, + intermediate_2_dim, hidden_dim, bias=self._config.add_mlp_bias, weight_init_method=init_method_2, bias_init_method=init_zeros_, - auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, + auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) @@ -72,11 +80,27 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + def _get_intermediate_dims(self): + intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) + if self._config.gated: + TensorDim("gate_and_up", 2) + intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) + else: + intermediate_1_dim = intermediate_2_dim + return intermediate_1_dim, intermediate_2_dim -class MLP(MLPBase): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, block_index, name) + super().__init__(config, distributed_config, hidden_dim, block_index, 
name) def forward( self, @@ -85,7 +109,6 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - parallel_group = self._intermediate_dim.parallel_group return ( mlp_autograd( input_, @@ -93,14 +116,14 @@ def forward( self.layer_1.weight, self.layer_1.bias, self.layer_2.weight, - None if parallel_group else self.layer_2.bias, + None if self._parallel_dim.group else self.layer_2.bias, gated=self._config.gated, activation_type=self._config.activation_type, - group=parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=self.layer_2.transposed_weight, ), - self.layer_2.bias if parallel_group else None, + self.layer_2.bias if self._parallel_dim.group else None, ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 2f45fdf9f..f56e2a2c1 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -9,7 +9,7 @@ if typing.TYPE_CHECKING: import torch - from fast_llm.engine.config_utils.tensor_space import TensorDim + from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization import LayerNorm, RMSNorm diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 740b4847c..ca807e67c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -4,7 +4,7 @@ import torch from fast_llm.engine.config_utils.initialization import init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, diff --git 
a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index d44be3297..2b928eb38 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -2,7 +2,7 @@ from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 08f3e535b..87991ef29 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -2,7 +2,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import Linear, LinearBase diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b667e5318..de3f9f196 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,24 +2,13 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.transformer.config import TransformerConfig from 
fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert -class LanguageModelDimNames(BlockDimNames): - # Embedding dimensions - position_embed = "position_embed" - vocab = "vocab" - vocab_tp = "vocab_tp" - # Misc - scalar = "scalar" - - class LanguageModelLossNames: language_model_loss = "language_model_loss" z_loss = "z_loss" @@ -237,16 +226,6 @@ def _validate(self) -> None: len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 ) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Embedding dimensions - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) - # TODO: Need both? - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @property def num_absolute_position_embeddings(self) -> int: # TODO: Rename from max embeddings. 
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index d90442e9f..d1b912167 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -2,20 +2,21 @@ import torch -from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), @@ -28,35 +29,39 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L def __init__( self, config: ConfigType, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? 
+ block_index: int, + name: str, ): - super().__init__(config) - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + super().__init__( + config, + distributed_config, + hidden_dim, + block_index, + name, + config.transformer.debug_transformer, + config.transformer.debug_transformer_memory, + ) self._residual_dtype = ( self._distributed_config.optimization_dtype if config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch - self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - self._tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None ) - self._dropout_p = config.transformer.hidden_dropout - self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size self.word_embeddings_weight = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), + (vocab_dim, self._hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, @@ -64,9 +69,9 @@ def __init__( ), lr_scale=config.embeddings_lr_scale, ) - if self._use_absolute_position_embeddings: + if 
self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), + (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, @@ -85,21 +90,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) - group = self._tensor_space.distributed.tensor_group + Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + group = self._parallel_dim.group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -108,16 +113,14 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._use_absolute_position_embeddings: + if 
self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator + self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8624612d6..cc6c69262 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -4,25 +4,20 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.initialization import init_normal_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, 
output_parallel_linear_forward -from fast_llm.layers.block.block import DebugLayer +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.language_model.config import ( - LanguageModelBaseConfig, - LanguageModelDimNames, - LanguageModelKwargs, - LanguageModelLossNames, -) +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -32,7 +27,7 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ @@ -40,31 +35,28 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config def __init__( self, config: ConfigType, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? 
+ block_index: int, + name: str, prediction_distance: int, ): - super().__init__(config) - self._debug = DebugLayer( - tensor_space, - f"Language model head", - self._config.transformer.debug_transformer, - self._config.transformer.debug_transformer_memory, + super().__init__( + config, + distributed_config, + hidden_dim, + block_index, + name, + config.transformer.debug_transformer, + config.transformer.debug_transformer_memory, ) - self._tensor_space = tensor_space + self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - self._group_size = tensor_space.distributed_config.tensor_parallel - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) - self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings - ) - self._cross_entropy_splits = self._config.cross_entropy_splits - if self._cross_entropy_splits is not None and self._sequence_parallel: - assert not self._parallel_embeddings - - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] + self._sequence_parallel_logits = self._sequence_parallel and not self._config.parallel_embeddings + if self._config.cross_entropy_splits is not None and self._sequence_parallel: + assert not self._parallel_logits self._loss_coefficient = ( self._config.prediction_loss_coefficient[prediction_distance] @@ -72,11 +64,6 @@ def __init__( else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = self._config.logits_scale_factor - self._language_model_loss_factor = self._config.language_model_loss_factor - 
self._distillation_loss_factor = self._config.distillation_loss_factor - self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction @@ -85,14 +72,28 @@ def __init__( self._prediction_distance = prediction_distance self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 + if not self._config.enable_dpo: + self._cross_entropy_impl = self._config.cross_entropy_impl + if self._cross_entropy_impl == CrossEntropyImpl.auto: + if self._parallel_logits: + self._cross_entropy_impl = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + self._cross_entropy_impl = CrossEntropyImpl.triton + else: + self._cross_entropy_impl = CrossEntropyImpl.fused + + self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) + + self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + + self._vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_logits else None + ) # Only the first head defines the output weights if self._prediction_distance == 0 and not self._config.tie_word_embeddings: # untie embedding weights - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] self.output_weights = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), + (self._vocab_dim, hidden_dim), init_method=init_normal_( std=self._config.init_method_std_embed, min_val=self._config.init_method_min_embed, @@ -101,18 +102,6 @@ def __init__( lr_scale=self._config.output_lr_scale, ) - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_impl - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - - 
self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) - # PEFT. self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) if hasattr(self, "output_weights"): @@ -123,11 +112,12 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): if self._is_last_head: - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, + return TensorMeta.from_dims( + (scalar_dim,), tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + reductions=( + (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), + ), # noqa ) else: return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") @@ -169,19 +159,19 @@ def _forward_backward( sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - LanguageModelDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor, ) if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) ) meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) - hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) + hidden_state, _ = meta.local_to_global(ln_output.detach()) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state grad_output = kwargs[LanguageModelKwargs.grad_output] / ( - self._group_size if self._sequence_parallel_logits else 1 + self._parallel_dim.size if self._sequence_parallel_logits else 1 ) output_weights = self._get_output_weights(kwargs) @@ -215,7 +205,7 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if self._config.distillation_model is None or self._language_model_loss_factor > 0.0: + if 
self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: lm_target = kwargs.get(LanguageModelKwargs.labels) if lm_target is not None: # MTP: Shift the labels @@ -239,10 +229,7 @@ def _get_targets( targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [ - None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) - for target in targets - ] + targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] if not any(target is not None for target in targets): # Simplify so we don't have to check every time. targets = None @@ -264,7 +251,7 @@ def _logits_cross_entropy_forward_backward_split( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None or targets is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) @@ -275,17 +262,18 @@ def _logits_cross_entropy_forward_backward_split( else: loss = None # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._cross_entropy_splits + grad_output /= self._config.cross_entropy_splits logit_input = input_.flatten(0, -2) if self.training: logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None split_size = div( - get_unique(target.size(0) for target in targets if target is not None), self._cross_entropy_splits + get_unique(target.size(0) for target in targets if target is not None), + self._config.cross_entropy_splits, ) tensors_split = [ - [None] * self._cross_entropy_splits if tensor is None else tensor.split(split_size) + [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) for tensor in [logit_input, *targets, 
logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): @@ -301,12 +289,14 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) + loss_count = (self._config.cross_entropy_splits or 1) * ( + self._parallel_dim.size if self._sequence_parallel_logits else 1 + ) if loss_count != 1: loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -318,43 +308,37 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + group = self._parallel_dim.group if self._parallel_logits else None logits, context = output_parallel_linear_forward( input_=input_, weight=weight, bias=None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + group=group, + sequence_parallel=self._sequence_parallel and self._parallel_logits, ) - if self._z_loss_factor > 0.0: + if self._config.logit_z_loss > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.logit_z_loss, self.training, grad_output, losses, LanguageModelLossNames.z_loss, - logits_scale_factor=self._logits_scale_factor, - ) - if self._debug.enabled and self._cross_entropy_splits is None: - vocab_dim = ( - LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) - sequence_dim = ( - LanguageModelDimNames.sequence_q_tp - if self._sequence_parallel_logits - else 
LanguageModelDimNames.sequence_q + logits_scale_factor=self._config.logits_scale_factor, ) + if self._debug.enabled and self._config.cross_entropy_splits is None: + sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] dims = ( - (sequence_dim, batch_dim, vocab_dim) + (sequence_dim, batch_dim, self._vocab_dim) if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, vocab_dim) + else (batch_dim, sequence_dim, self._vocab_dim) ) - self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) if targets is None: - return logits * self._logits_scale_factor, None + return logits * self._config.logits_scale_factor, None dpo_target, lm_target, distillation_target, loss_mask = targets if dpo_target is not None: @@ -375,25 +359,25 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._language_model_loss_factor, + group=group, + grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - lm_loss = lm_loss * self._language_model_loss_factor + lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._distillation_loss_factor > 0.0: + if distillation_target is not None and self._config.distillation_loss_factor > 0.0: if self._config.distillation_loss_implementation == 
DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - logits_scale_factor=self._logits_scale_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=group, + logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits @@ -404,17 +388,17 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, + group=group, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, ) else: raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - distillation_loss = distillation_loss * self._distillation_loss_factor + distillation_loss = distillation_loss * self._config.distillation_loss_factor else: distillation_loss, distillation_grad = None, None diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index f5d915855..5ba31c0d0 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -4,7 +4,8 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from 
fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -13,40 +14,31 @@ class PositionEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config assert config.use_absolute_position_embeddings - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) - self._position_ids = torch.arange( - 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 - ) + self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[LanguageModelKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size 
sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: position_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(self._tensor_space.distributed.device, dtype=torch.int64) + ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: position_ids = position_ids.transpose(0, 1) @@ -61,9 +53,9 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( ( - (sequence_q_dim, self._scalar_dim) + (sequence_q_dim, scalar_dim) if kwargs[LanguageModelKwargs.sequence_first] - else (self._scalar_dim, sequence_q_dim) + else (scalar_dim, sequence_q_dim) ), tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, @@ -71,11 +63,9 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 987d5fa0d..361fe9818 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,34 +1,37 @@ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import 
DistributedConfig from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.ssm.config import SSMConfig # TODO: Sort out configs. -class SSMBlock[ConfigType: BlockConfig](Block[BlockConfig]): +class SSMBlock[ConfigType: BlockConfig](Block[ConfigType]): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ - _name = "Llamba block" - def __init__( self, - config: BlockConfig, + config: ConfigType, ssm_config: SSMConfig, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, mixer_cls: type[BlockLayer], block_index: int, + name: str, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(config, tensor_space, block_index, return_input) + super().__init__(config, distributed_config, hidden_dim, block_index, name, return_input) def _create_mixer(self) -> BlockLayer: return self._mixer_cls( self._ssm_config, - tensor_space=self._tensor_space, - block_index=self._block_index, - block_config=self._config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index dec0675b9..2daad1186 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,11 +2,9 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.config import BlockDimNames -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.initialization import 
Initializer @@ -46,9 +44,9 @@ class SSMBlockType(enum.StrEnum): def get_mixer_class(self): if self == SSMBlockType.mamba: - from fast_llm.layers.ssm.mamba_layer import MambaLayer + from fast_llm.layers.ssm.mamba import Mamba - return MambaLayer + return Mamba elif self == SSMBlockType.mamba2: from fast_llm.layers.ssm.mamba2 import Mamba2 @@ -79,21 +77,21 @@ class SSMConfig(Config): # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, - desc="Expansion factor for Mamba blocks.", + desc="Expansion factor.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, - desc="State size for Mamba blocks.", + desc="State size.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, - desc="Conv kernel dimension for Mamba blocks.", + desc="Conv kernel dimension.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) @@ -106,19 +104,19 @@ class SSMConfig(Config): # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, - desc="Number of QK heads for Mamba2 blocks.", + desc="Number of QK heads.", hint=FieldHint.architecture, ) # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, - desc="Number of V heads for Mamba2 blocks.", + desc="Number of V heads.", hint=FieldHint.architecture, ) # c_size [MambaLayer, Mamba2, DiscreteMamba2]? 
d_inner: None | int = Field( default=None, - desc="Inner dimension for Mamba2 blocks.", + desc="Inner dimension.", hint=FieldHint.core, ) # xb_size [Mamba2] @@ -204,79 +202,3 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) - - def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Head groups are configured differently depending on the block type. - if block_type == SSMBlockType.mamba: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = num_heads - elif block_type == SSMBlockType.mamba2: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = div(self.d_xb, self.state_size) - elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Use different variables? - num_heads = self.n_v_heads - num_head_groups = self.n_qk_heads - else: - raise NotImplementedError(block_type) - - tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) - if block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) - else: - head_dim = state - - tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) - tensor_space.add_tensor_dim( - heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - heads_and_head_dim := CompositeTensorDim( - SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) - ) - ) - tensor_space.add_tensor_dim( - head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state) - ) - ) - 
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) - - # DT projection - if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) - - if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) - ) - # TODO: Use composition instead - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) - ) - ) - elif block_type == SSMBlockType.mamba2: - # TODO: Factor out state? - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), - ) - ) - elif block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), - ) - ) - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_convolution, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state), - ) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 61291f845..7e445cca1 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,15 +5,16 @@ import torch from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import 
ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_kaiming_ +from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import get_lr_scale +from fast_llm.utils import div, get_lr_scale logger = logging.getLogger(__name__) @@ -34,48 +35,69 @@ _causal_conv1d_available = False -class DiscreteMamba2(BlockLayer): - """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" +class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): + """ + This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py + """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + _config: SSMConfig def __init__( self, - config: SSMConfig, - block_index: int, - tensor_space: TensorSpace, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) - self._config: SSMConfig = config - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + state_dim = TensorDim("state", self._config.state_size) + v_head_size_dim = TensorDim(SSMDimNames.head_dim, div(self._config.d_inner, 
self._config.n_v_heads)) + + head_groups_dim = TensorDim( + SSMDimNames.head_groups, + self._config.n_qk_heads, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + group_heads_dim = TensorDim(SSMDimNames.group_heads, div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim(SSMDimNames.composite_heads, (head_groups_dim, group_heads_dim)) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) + bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[SSMDimNames.hidden] - conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] - heads_dim = tensor_space[SSMDimNames.composite_heads] + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, bc_dim, bc_dim, inner_dim, heads_dim), + ) + convolution_dim = ConcatenatedTensorDim("convolution", (inner_dim, bc_dim, bc_dim)) # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._local_head_groups = head_groups_dim.size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size + self._local_bc_size = bc_dim.size + + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, 
weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -90,15 +112,17 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( ( - conv1d_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_dim, + scalar_dim, + convolution_kernel_dim, + ), + init_method=init_uniform_centered_( + (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), + (convolution_dim,), init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba.py similarity index 79% rename from fast_llm/layers/ssm/mamba_layer.py rename to fast_llm/layers/ssm/mamba.py index 0dcc29f0b..ac6576a87 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba.py @@ -5,14 +5,15 @@ import torch from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, div, get_lr_scale try: from 
mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -53,31 +54,40 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class MambaLayer(BlockLayer): +class Mamba[ConfigType: SSMConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( self, - config: SSMConfig, - block_index: int, - tensor_space: TensorSpace, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) - assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" - self._config = config + assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" # TODO: It's not silu? 
Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[SSMDimNames.hidden] + heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) + state_dim = TensorDim("state", self._config.state_size) + inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) + x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -85,7 +95,7 @@ def __init__( # TODO: lr_scale? self.in_proj = Linear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -93,8 +103,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + scalar_dim, + convolution_kernel_dim, ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -102,7 +112,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space[SSMDimNames.concatenated_x_projection], + x_projection_dim, weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -111,7 +121,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.dt_rank]), + (inner_dim, 
dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -123,7 +133,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b6626e893..e6ca9ea12 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,13 +4,14 @@ import torch from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale @@ -31,38 +32,30 @@ logger = logging.getLogger(__name__) -class Mamba2(BlockLayer): +class Mamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ _mixer_name: typing.ClassVar[str] = "mamba_2" - _XZ_DIMS = ( - 
SSMDimNames.batch, - SSMDimNames.composite_heads_and_head_dim, - SSMDimNames.sequence_q, - ) - _BC_DIMS = ( - SSMDimNames.batch, - SSMDimNames.composite_heads, - SSMDimNames.state, - SSMDimNames.sequence_q, - ) - def __init__( self, - config: SSMConfig, - tensor_space: TensorSpace, - block_index: int, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) @@ -71,13 +64,32 @@ def __init__( ) lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] - hidden_dim: TensorDim = tensor_space[SSMDimNames.hidden] - dt_rank_dim = tensor_space[SSMDimNames.dt_rank] + num_heads = div(self._config.d_inner, self._config.state_size) + num_head_groups = div(self._config.d_xb, self._config.state_size) + + state_dim = TensorDim("state", self._config.state_size) + + head_groups_dim = TensorDim( + "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + ) + group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) + + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) - self._local_heads = tensor_space[SSMDimNames.composite_heads].size - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) + xb_dim = 
CompositeTensorDim("xb", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + + # DT projection + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, xb_dim, xb_dim, inner_dim), + ) + + self._local_heads = heads_dim.size + self._local_head_groups = head_groups_dim.size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -86,8 +98,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + scalar_dim, + convolution_kernel_dim, ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -99,7 +111,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -131,7 +143,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, @@ -151,6 +163,19 @@ def __init__( # TODO: lr_scale? 
) + if self._debug.enabled: + _xz_dims = ( + BlockDimNames.batch, + inner_dim, + BlockDimNames.sequence_q, + ) + _bc_dims = ( + BlockDimNames.batch, + heads_dim, + state_dim, + BlockDimNames.sequence_q, + ) + def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index ba7f2bb6e..d7a669295 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -5,13 +5,15 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig +from fast_llm.utils import div, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -46,41 +48,58 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(BlockLayer): +class Attention[ConfigType: TransformerConfig](BlockLayer[ConfigType]): """ A self-attention layer. 
""" - _mixer_name: typing.ClassVar[str] = "attn" - - _QUERY_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_heads, - AttentionDimNames.kv_channels, - ) - _KV_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.head_groups, - AttentionDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_dense, - ) - - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): super().__init__( - tensor_space, + config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=config.debug_transformer, - debug_memory=config.debug_transformer_memory, + name, + config.debug_transformer, + config.debug_transformer_memory, + ) + self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( + DistributedDimNames.sequence_data ) - self._config = config - self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) + head_group_dim = TensorDim( + "head_groups", self._config.head_groups, self._parallel_dim if self._config.head_groups > 1 else None + ) + group_heads_dim = TensorDim( + "group_heads", + div(self._config.num_attention_heads, self._config.head_groups), + None if self._config.head_groups > 1 else self._parallel_dim, + ) + self._local_head_groups = head_group_dim.size + self._local_heads_per_group = group_heads_dim.size + self._local_heads = self._local_head_groups * self._local_heads_per_group + + kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels) 
+ query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim)) + key_value_dim = ConcatenatedTensorDim( + "key_value", + ( + CompositeTensorDim("key", (head_group_dim, kv_channels_dim)), + CompositeTensorDim("value", (head_group_dim, kv_channels_dim)), + ), + ) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim)) + + self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) init_method_qkv = init_normal_( std=self._config.init_method_std_qkv, @@ -93,22 +112,13 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size - self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size - self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - - hidden_dim = self._tensor_space[AttentionDimNames.hidden] - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_query], + query_dim, bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_zeros_, @@ -117,7 +127,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_key_value], + key_value_dim, bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_zeros_, @@ -127,11 +137,11 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build() + self._rotary = self._config.rotary.build(kv_channels_dim) # Output. self.dense = InputParallelLinear( - self._tensor_space[AttentionDimNames.composite_dense], + dense_dim, hidden_dim, bias=self._config.add_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -145,6 +155,25 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + if self._debug.enabled: + self._query_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), + kv_channels_dim, + ) + self._kv_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + head_group_dim, + kv_channels_dim, + ) + self._context_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + dense_dim, + ) + def _attn_fused( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor ) -> torch.Tensor: @@ -153,16 +182,18 @@ def _attn_fused( sk = key.size(1) if 
self._local_head_groups == 1: - query = query.view(b, sq * self._local_heads, self._kv_channels) + query = query.view(b, sq * self._local_heads, self._config.kv_channels) key = key.transpose(-1, -2) else: query = ( - query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._kv_channels)) + query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels)) .transpose(1, 2) - .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._kv_channels) + .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels) + ) + key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1) + value = ( + value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 2).flatten(0, 1) ) - key = key.unflatten(-1, (self._local_head_groups, self._kv_channels)).movedim(1, 3).flatten(0, 1) - value = value.unflatten(-1, (self._local_head_groups, self._kv_channels)).transpose(1, 2).flatten(0, 1) attn_weights = torch.empty( (b * self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype @@ -179,7 +210,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): attn_weights = torch.dropout(attn_weights, self._config.attention_dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value @@ -189,7 +220,7 @@ def _attn_fused( return attn_output.view(b, sq, -1) else: return ( - attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._kv_channels) + attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels) 
.transpose(1, 2) .flatten(2) ) @@ -201,18 +232,16 @@ def _query_key_value_forward( handle = None - if self._head_groups == 1 and self._sequence_parallel: - key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.tensor_group, dim=0, async_op=True - ) + if self._config.head_groups == 1 and self._sequence_parallel: + key_value, handle = gather_op(key_value, group=self._parallel_dim.group, dim=0, async_op=True) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: if handle: # TODO: This is probably unnecessary. handle.wait() # sequence dim may not be zero, but this needs to be handled after `handle.wait()` key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.sequence_data_group, dim=0, async_op=True + key_value, group=self._sequence_data_parallel_dim.group, dim=0, async_op=True ) query, query_context = self.query.forward_only(input_) @@ -220,8 +249,8 @@ def _query_key_value_forward( if handle: handle.wait() - if self._tensor_space.distributed.sequence_data_group and not sequence_first: - key_value = swap_mult_dim(key_value, self._tensor_space.distributed_config.sequence_data_parallel, 0, 1) + if self._sequence_data_parallel_dim.group and not sequence_first: + key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1) context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} return query, key_value, context @@ -230,15 +259,12 @@ def _query_key_value_backward( self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict ) -> torch.Tensor: # TODO: De-allocate qkv grads quicker. 
- handle = None - - if self._tensor_space.distributed.sequence_data_group: - key_value_grad, handle = reduce_scatter_op( - key_value_grad, - group=self._tensor_space.distributed.sequence_data_group, - dim=1 - context["sequence_first"], - async_op=True, - ) + key_value_grad, handle = reduce_scatter_op( + key_value_grad, + group=self._sequence_data_parallel_dim.group, + dim=1 - context["sequence_first"], + async_op=True, + ) # TODO: Overlap with both. input_grad = self.query.backward(query_grad, context.pop("query")) @@ -246,7 +272,7 @@ def _query_key_value_backward( if handle: handle.wait() - if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group): + if self._config.head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0) else: @@ -289,7 +315,7 @@ def forward( # Manually add the gradients from later micro-sequences. key_value = AttachGrad.apply(key_value, present) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: key_value = ( key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first @@ -301,11 +327,11 @@ def forward( query = query.transpose(0, 1).contiguous() key_value = key_value.transpose(0, 1).contiguous() - key, value = key_value.split(self._local_head_groups * self._kv_channels, dim=-1) + key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1) - query = query.view(*query.shape[:2], self._local_heads, self._kv_channels) - key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) - value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) + query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels) + key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels) + value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels) if 
self._debug.enabled: self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) @@ -316,7 +342,7 @@ def forward( if self._use_flash_attention: assert _flash_available - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) @@ -357,10 +383,10 @@ def forward( ) if self._debug.enabled: - self._debug(query, "query", self._QUERY_DIMS, kwargs) - self._debug(key, "key", self._KV_DIMS, kwargs) - self._debug(value, "value", self._KV_DIMS, kwargs) - self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug(query, "query", self._query_dims, kwargs) + self._debug(key, "key", self._kv_dims, kwargs) + self._debug(value, "value", self._kv_dims, kwargs) + self._debug(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index 89d7a2e3b..a5aad45a9 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -1,7 +1,6 @@ import logging import typing -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig @@ -10,13 +9,10 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): - _name = "Transformer layer" # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - _config: TransformerConfig - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config, tensor_space, block_index, return_input) def _create_mixer(self) -> BlockLayer: - return Attention(self._config, self._tensor_space, self._block_index) + return Attention( + self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} attn" + ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f7c7fea9c..a40f676ca 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -5,10 +5,9 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config 
import AddLinearBiasChoices, BlockConfig, BlockKwargs from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div @@ -18,19 +17,6 @@ logger = logging.getLogger(__name__) -class AttentionDimNames(BlockDimNames): - # A set of common tensor dim names packed into a namespace. - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - - class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" @@ -180,36 +166,6 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Needed for multiple inheritance. 
- super().setup_tensor_space(tensor_space) # Noqa - - tensor_space.add_tensor_dim( - head_groups := TensorDim( - AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - AttentionDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - @property def add_qkv_bias(self) -> bool: # TODO: Make this work without inheritance. 
diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 16e5811e6..769177668 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -4,7 +4,8 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta @@ -12,25 +13,18 @@ class BackupAttentionPreprocessor(Preprocessor): - _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: AttentionConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -38,7 +32,7 @@ def _create_tensors(self, sequence_length: int) -> None: self._mask = torch.ones( (sequence_length, sequence_length), dtype=torch.bool, - device=self._tensor_space.distributed.device, + device=device, ).tril_() if self._config.window_size is not None: @@ -47,11 +41,11 @@ def 
_create_tensors(self, sequence_length: int) -> None: [], torch.finfo(self._distributed_config.training_dtype.torch).min, dtype=self._distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, + device=device, ) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size kwargs[AttentionKwargs.attention_mask] = self._mask[ @@ -64,7 +58,7 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: for sample_lens in sequence_lengths ] ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) kwargs[AttentionKwargs.attention_mask] = ( kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] @@ -74,30 +68,29 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( - self._scalar_dim, - self._scalar_dim, + scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_k_dim], ), tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( - (self._scalar_dim,), + (scalar_dim,), tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._tensor_space.distributed_config.training_dtype.torch, + dtype=self._distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def 
__init__(self, config: AttentionConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -148,14 +141,14 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: seqlens_k = torch.cat(sequence_lengths) kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), ) ) kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), ) ) kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index 748f2af28..f0e0079c7 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -5,7 +5,7 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import 
BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert @@ -29,8 +29,8 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, tensor_space: TensorSpace | None = None) -> "Rotary": - return self._get_configurable_class()(self, tensor_space) + def build(self, kv_channels_dim: TensorDim) -> "Rotary": + return self._get_configurable_class()(self, kv_channels_dim) @classmethod @abc.abstractmethod diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py deleted file mode 100644 index 9f8732f85..000000000 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ /dev/null @@ -1,68 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.tensor import TensorMeta - - -class RotaryEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__( - self, - config: DefaultRotaryConfig, - tensor_space: TensorSpace, - ): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - 
self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - - def _create_tensors(self, sequence_length: int) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._rotary_embedding_frequencies = self._config.get_frequencies( - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..bbf8b524a 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -6,9 +6,9 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import 
AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -41,14 +41,14 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): def __init__( self, config: ConfigType, - # The tensor space is only needed for preprocessing, so we make it optional. - tensor_space: TensorSpace | None = None, + kv_channels_dim: TensorDim, ): super().__init__(config) + self._kv_channels_dim = kv_channels_dim @abc.abstractmethod def forward( @@ -57,7 +57,7 @@ def forward( pass -class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): +class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -70,24 +70,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass -class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): +class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: ConfigType, - tensor_space: TensorSpace | None = None, - ): - super().__init__(config, tensor_space) - self._tensor_space = tensor_space - if self._tensor_space is not None: - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + def preprocess(self, batch: 
torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k @@ -95,21 +83,20 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, self._kv_channels_dim, ), tensor_name=AttentionKwargs.rotary_freq_q, ) kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, self._kv_channels_dim, ), tensor_name=AttentionKwargs.rotary_freq_k, @@ -123,7 +110,7 @@ def forward( key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -131,10 +118,10 @@ def _create_tensors(self, sequence_length: int) -> None: self._rotary_embedding_frequencies = self._get_frequencies( sequence_length, self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, + device=device, ) - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: # Calculate the complex frequencies 
(https://blog.eleuther.ai/rotary-embeddings/) # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, # `a = theta ** - (2 * (channel // 2) / kv_channels)`, @@ -149,12 +136,12 @@ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda" ).contiguous() return frequencies - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: return self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) -class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: +class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[ConfigType]): + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor high_frequency_wavelength = self._config.original_context_length / self._config.high_frequency_factor @@ -173,17 +160,17 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: return torch.stack(new_scales) -class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): +class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[ConfigType]): """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 [original paper](https://arxiv.org/abs/2309.00071) """ - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor - def _get_angle_scales(self, kv_channels: 
int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) # TODO: max_position_embeddings or original_context_length? # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 6d555a0bb..024d7d79c 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -14,7 +14,6 @@ if typing.TYPE_CHECKING: from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -254,7 +253,6 @@ def log_distributed_tensor[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] = (), global_: bool = True, log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, @@ -263,7 +261,7 @@ def log_distributed_tensor[ if level <= 0: return if global_: - tensor, is_first_rank = meta.local_to_global(tensor, distributed=distributed) + tensor, is_first_rank = meta.local_to_global(tensor) storage = False is_first_rank = is_first_rank and all(group.rank() == 0 for group in duplicate_groups if group) if not is_first_rank: @@ -289,7 +287,6 @@ def log_distributed_grad[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] 
= (), grad_fn: typing.Callable[[torch.Tensor], torch.Tensor] | None = None, global_: bool = True, @@ -305,7 +302,6 @@ def log_distributed_grad[ scale=scale, level=level, storage=storage, - distributed=distributed, duplicate_groups=duplicate_groups, global_=global_, log_fn=log_fn, diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 98937bdb1..3afd88ce1 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -3,11 +3,9 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import Layer, LossDef +from fast_llm.engine.base_model.base_model import LossDef from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.model import GPTBaseModel, GPTModel @@ -17,26 +15,21 @@ class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): def __init__( self, - config: CustomBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): # TODO: Implement / update. super().__init__(config, distributed_config) - def get_layers(self) -> list[Layer]: - # TODO: Adjust as needed. 
- return [ - LanguageModelEmbedding(self._config, self._tensor_space), - *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - ) - for i in range(self._config.transformer.num_layers) - ], - CustomHead(self._config, self._tensor_space), - ] + def _get_head(self, prediction_distance): + return CustomHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 47df8ba1c..41e0d607d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -6,17 +6,18 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from 
fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -36,6 +37,7 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): + self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) super().__init__(config, distributed_config) self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: @@ -45,59 +47,81 @@ def __init__( # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: - self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
- self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) + self._preprocessors.append( + self._config.transformer.rotary.build(TensorDim("kv_channels", self._config.transformer.kv_channels)) + ) if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._distributed_config)) if self._config.enable_dpo: # TODO better way to pass in? - self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._distributed_config)) - def get_output_layers(self) -> list[Layer]: + def _get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, + self._get_block( # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), + max(self._config.transformer.num_layers + i, 1), + f"MPT head {i} block", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=i < self._config.prediction_heads - 1, + i < self._config.prediction_heads - 1, ) ) - layers.append( - LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, - ) - ) + layers.append(self._get_head(i)) return layers def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + self._get_embeddings(), *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, + self._get_block( + i + 1, + f"Block {i + 1}", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], - *self.get_output_layers(), + *self._get_output_layers(), ] + def _get_block( + self, + block_index: int, + name: str, + return_input: bool = False, + ): + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + return_input, + ) + + def _get_embeddings(self): + return LanguageModelEmbedding(self._config, self._distributed_config, self._hidden_dim, 0, "Embeddings") + + def _get_head(self, prediction_distance): + return LanguageModelHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) + def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -116,8 +140,8 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True - batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = 
TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -126,19 +150,17 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? sequence_q_dim = TensorDim( - AttentionDimNames.sequence_q, + BlockDimNames.sequence_q, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), + self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - AttentionDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim( - DistributedDimNames.tensor_and_sequence_data - ), + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), ) - if self._tensor_space.distributed_config.sequence_tensor_parallel + if self._distributed_config.sequence_tensor_parallel else sequence_q_dim ) @@ -149,11 +171,10 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, hidden_dim) + (hidden_sequence_q_dim, batch_dim, self._hidden_dim) if sequence_first - else (batch_dim, hidden_sequence_q_dim, hidden_dim) + else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) ) common_kwargs = { @@ -166,7 +187,7 @@ def preprocess_meta( } sequence_k_pasts = range( - sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, + sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, ) @@ -180,7 +201,7 @@ def preprocess_meta( preprocessed_meta = [] 
for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(BlockDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -234,7 +255,7 @@ def preprocess( prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( - device=self._tensor_space.distributed.device, + device=self._distributed.device, dtype=torch.int64, non_blocking=True, ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 866de962f..9d54675be 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,7 +6,6 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig @@ -47,14 +46,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): # TODO: Support combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - """ - Setup the tensor space for the model. 
- """ - super().setup_tensor_space(tensor_space) - if self.ssm_block_type is not None: - self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) - def _validate(self): with self._set_implicit_default(None): if self.ssm.dt_rank == "auto" or self.ssm.dt_rank is None: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 32fbdad9b..7c67d7355 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,10 +1,6 @@ import logging import typing -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.block import SSMBlock from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel @@ -20,88 +16,39 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. """ - _is_setup: bool = False - - def __init__( + def _get_block( self, - config: HybridSSMBaseModelConfig, - distributed_config: DistributedConfig, + block_index: int, + name: str, + return_input: bool = False, ): - super().__init__(config, distributed_config) - - def get_output_layers(self) -> list[Layer]: - """ - Get the output layers of the model. - This includes the language model head and any additional heads specified in the configuration. 
- """ - layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] - - if self._config.prediction_heads > 1: + if block_index > self._config.transformer.num_layers: + # MTP block block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - for i in range(1, self._config.prediction_heads): - if block_type == SSMBlockType.transformer: - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=len(self._config.hybrid_block_layout), - return_input=i != self._config.prediction_heads - 1, - ) - ) - else: - layers.append( - SSMBlock( - config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - ) - layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) - - return layers - - def get_layers(self) -> list[Layer]: - """ - Create a list of layers for the model, interleaving Transformer and Mamba blocks - according to the block pattern. 
- """ - layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] - - # Create blocks according to pattern - for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer: - # Transformer block - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - else: - layers.append( - SSMBlock( - config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - - # Add the output layers - layers += self.get_output_layers() - - return layers + else: + # Decoder block + block_type = self._config.hybrid_block_layout[block_index - 1] + + if block_type == SSMBlockType.transformer: + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + return_input, + ) + else: + return SSMBlock( + self._config.transformer, + self._config.ssm, + self._distributed_config, + self._hidden_dim, + self._config.ssm_block_type.get_mixer_class(), + block_index, + name, + return_input, + ) class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b12d12072..b6180c190 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import 
DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy @@ -138,30 +138,11 @@ def from_dims( **kwargs, ) - @classmethod - def from_tensor_space( - cls, - dim_names: tuple[str, ...], - tensor_space: TensorSpace, - *, - tensor_name: str = "", - dtype: torch.dtype = torch.float32, - reductions: tuple[tuple[str, ReduceOp], ...] = (), - **kwargs: typing.Any, - ) -> typing.Self: - dims = tuple(tensor_space[dim_name] for dim_name in dim_names) - if reductions: - # kwarg not available for ParameterMeta, so we only provide if necessary. - kwargs["reductions"] = tuple( - (tensor_space.distributed_config.get_distributed_dim(name), op) for name, op in reductions - ) - return cls.from_dims(dims, tensor_name=tensor_name, dtype=dtype, **kwargs) - @property def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor) -> tuple[torch.Tensor, ...]: """ Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. Returns a view of the input tensor (or the input tensor itself) when possible. @@ -171,7 +152,7 @@ def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank, modified = distributed.config.tensor_rank == 0, False + is_first_rank, modified = True, False for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e61f72244..e4ad937b7 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build() + .build(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build()._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True).build(None)._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, diff --git a/tests/test_attention.py b/tests/test_attention.py index 534e3800e..7d05e0a66 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,11 +2,12 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import 
Assert @@ -30,19 +31,6 @@ def test_decide_window_size(): assert attention._decide_window_size() == 512 -def test_attention_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - Attention(transformer_conf, tensor_space, 1) - - def test_varlen_preprocessor(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: @@ -63,27 +51,24 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - transformer_cfg = TransformerConfig( + transformer_config = TransformerConfig( num_layers=2, num_attention_heads=2, hidden_size=16, use_flash_attention=True, ) - distributed_cfg = DistributedConfig(training_dtype="bfloat16") - distributed = Distributed(distributed_cfg, use_cpu=True) - tensor_space = TensorSpace(distributed_config=distributed_cfg) - tensor_space.setup(distributed) - transformer_cfg.setup_tensor_space(tensor_space) - varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) + distributed_config = DistributedConfig(training_dtype="bfloat16") + distributed = Distributed(distributed_config, use_cpu=True) + varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_config, distributed_config=distributed_config) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), AttentionKwargs.sequence_k_dim: TensorDim( - AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + BlockDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), 
AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(None, kwargs) + varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_mlp.py b/tests/test_mlp.py deleted file mode 100644 index 802833eb2..000000000 --- a/tests/test_mlp.py +++ /dev/null @@ -1,29 +0,0 @@ -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.block.mlp.mlp import MLP -from fast_llm.layers.transformer.config import TransformerConfig - - -def test_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MLP(transformer_conf, tensor_space, 0, "name") - - -def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MixtureOfExpertMLP(transformer_conf, tensor_space, 0, "name") diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 80232bf53..42e588911 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -29,8 +29,8 @@ def set_testing_global_variables(): num_gpus = len(gpus) gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] 
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") - os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") # TODO: Fixtures diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 88303a0f4..0dc3462eb 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -13,7 +13,6 @@ from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.logging import configure_logging -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage @@ -33,12 +32,8 @@ def result_path(): def get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. 
- distributed = Distributed(config.distributed) - tensor_space = TensorSpace(config.distributed) - config.base_model.setup_tensor_space(tensor_space) - tensor_space.setup(distributed) base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) - base_model.setup(distributed) + base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed From 797bd73befdb20c641b624b147a87971df266d55 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 19:49:12 -0400 Subject: [PATCH 16/28] stuff --- fast_llm/engine/multi_stage/fast_llm_model.py | 1 - fast_llm/layers/block/block.py | 15 +- fast_llm/layers/block/config.py | 41 ++++- fast_llm/layers/block/mlp/config.py | 23 ++- .../layers/block/mlp/mixture_of_experts.py | 19 ++- fast_llm/layers/block/mlp/mlp.py | 25 ++- fast_llm/layers/common/config.py | 137 +---------------- .../layers/common/normalization/__init__.py | 0 .../layers/common/normalization/config.py | 142 ++++++++++++++++++ .../{ => normalization}/normalization.py | 2 +- fast_llm/models/gpt/conversion.py | 2 +- fast_llm/models/ssm/conversion.py | 2 +- 12 files changed, 235 insertions(+), 174 deletions(-) create mode 100644 fast_llm/layers/common/normalization/__init__.py create mode 100644 fast_llm/layers/common/normalization/config.py rename fast_llm/layers/common/{ => normalization}/normalization.py (99%) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index da4fe527e..09ee788e6 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -51,7 +51,6 @@ def from_pretrained( use_cpu: bool = False, stage_filter: set | None = None, ) -> typing.Self: - print("IUGRGHIOERIO", cls, cls.config_class) metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: 
diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 425731eb9..09370e3af 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -94,27 +94,27 @@ class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, # TODO: Review `hidden_dim` and `block_index` hidden_dim: TensorDim, block_index: int, name: str, - debug_level: int, - debug_memory: bool, ): super().__init__(config, distributed_config) + self._block_config = block_config self._hidden_dim = hidden_dim self._block_index = block_index self._name = name self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._debug = DebugLayer( self._name, - debug_level, - debug_memory, + self._block_config.debug_transformer, + self._block_config.debug_transformer_memory, ) -class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): +class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): """ Base class for mixer and MLP modules. """ @@ -148,13 +148,12 @@ def __init__( return_input: bool = False, ): super().__init__( + config, config, distributed_config, hidden_dim, block_index, name, - config.debug_transformer, - config.debug_transformer_memory, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
self._return_input: bool = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 0da7a0c99..3df82e24e 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,12 +1,20 @@ +import abc import enum +import functools +import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import BlockLayer + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -38,6 +46,37 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" +@config_class() +class BlockLayerConfig(BaseModelConfig): + """ + A common class for mixers and mlps, which have the exact same interface. 
+ """ + + _abstract = True + + @functools.cached_property + @abc.abstractmethod + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() + + def get_layer( + self, + block_config: "BlockConfig", + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ) -> "BlockLayer": + return self.layer_class( + self, + block_config, + distributed_config, + hidden_dim, + block_index, + name, + ) + + @config_class() # TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 57f7a9e03..83e45f002 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,9 +1,15 @@ import enum +import functools +import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.block.config import BlockLayerConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.mlp.mlp import MLPBase + class MLPLossNames: load_balancing_loss = "load_balancing_loss" @@ -16,8 +22,8 @@ class RoutingType(str, enum.Enum): @config_class() -class MLPConfig(Config): - # TODO: Review names +class MLPConfig(BlockLayerConfig): + # TODO: Review names # TODO: Separate MoE? 
_abstract = False ffn_hidden_size: int = Field( default=None, @@ -150,6 +156,17 @@ def add_mlp_bias(self) -> bool: return True return False + @functools.cached_property + def layer_class(self) -> "type[MLPBase]": + if self.num_experts > 1: + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + + return MixtureOfExpertMLP + else: + from fast_llm.layers.block.mlp.mlp import MLP + + return MLP + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 0bc531dad..2a234ca94 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -10,7 +10,7 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -36,6 +36,7 @@ class MixtureOfExpertMLP[ConfigType: BlockConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, @@ -44,19 +45,21 @@ def __init__( Assert.gt(config.num_experts, 1) # TODO: 
Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) - layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + layer_lr_scale = ( + self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None + ) router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( - hidden_dim, + self._hidden_dim, TensorDim("router_experts", self._config.num_unshared_experts), bias=False, weight_init_method=init_normal_( - std=self._config.init_method_std, - min_val=self._config.init_method_min, - max_val=self._config.init_method_max, + std=self._block_config.init_method_std, + min_val=self._block_config.init_method_min, + max_val=self._block_config.init_method_max, ), lr_scale=router_lr_scale, ) diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index dc5178479..fd64713d1 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -9,29 +9,23 @@ from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, get_lr_scale -class MLPBase[ConfigType: BlockConfig](BlockLayer[ConfigType]): +class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, ): - super().__init__( - config, - distributed_config, 
- hidden_dim, - block_index, - name, - config.debug_transformer, - config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() @@ -48,7 +42,9 @@ def __init__( self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + layer_lr_scale = ( + self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None + ) lr_scale = ( tuple(self._config.mlp_lr_scale) if isinstance(self._config.mlp_lr_scale, list) @@ -77,8 +73,8 @@ def __init__( ) # PEFT. - self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = self._block_config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._block_config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) @@ -94,13 +90,14 @@ class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, ): Assert.eq(config.num_experts, 1) - super().__init__(config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) def forward( self, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index f56e2a2c1..b09672961 100644 --- 
a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,146 +1,11 @@ import abc -import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - - from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm - - -class NormalizationImplementation(str, enum.Enum): - """ - An enum for the available implementations of layer norm. - """ - - auto = "auto" - torch = "torch" - fused = "fused" - fast = "fast" - triton = "triton" - - -@config_class(registry=True) -class NormalizationConfig(BaseModelConfig): - pass - - @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return LayerNormalizationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={NormalizationConfig: "none"}) -class NoNormalizationConfig(NormalizationConfig): - _abstract = False - - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - return torch.nn.Identity() - - -@config_class() -class LayerNormalizationBaseConfig(NormalizationConfig): - """ - Common configuration for layer norm and rms norm - """ - - # TODO: Rename to normalization_epsilon - epsilon: float = Field( - default=1e-5, - desc="Regularizer for the division.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - zero_centered: bool = Field( - default=False, - desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", - hint=FieldHint.architecture, - ) - implementation: NormalizationImplementation = Field( - default=NormalizationImplementation.auto, - desc="The implementation to use for the normalization layer.", - hint=FieldHint.performance, - ) - # TODO: Rename to normalization_init_range - initialization_range: float = Field( - default=0.0, - desc="Randomize the initialization with a uniform noise. 
Used to test for issues that may not be visible with the default initialization.", - hint=FieldHint.testing, - valid=check_field(Assert.geq, 0), - ) - - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ - - kwargs = { - "hidden_dim": hidden_dim, - "eps": self.epsilon, - "implementation": self.implementation, - "zero_centered": self.zero_centered, - "lr_scale": lr_scale, - } - if self.initialization_range: - mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) - return self.module_class(**kwargs) - - @property - @abc.abstractmethod - def module_class(self): - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - cls._handle_renamed_field(default, "normalization_implementation", "implementation") - cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") - return super()._from_dict(default, strict, flat) - - -@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) -class LayerNormalizationConfig(LayerNormalizationBaseConfig): - _abstract = False - - @property - def module_class(self): - from fast_llm.layers.common.normalization import LayerNorm - - return LayerNorm - - -@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) -class RMSNormalizationConfig(LayerNormalizationBaseConfig): - _abstract = False - - @property - def module_class(self): - from fast_llm.layers.common.normalization import RMSNorm - - return RMSNorm @config_class() diff --git a/fast_llm/layers/common/normalization/__init__.py 
b/fast_llm/layers/common/normalization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py new file mode 100644 index 000000000..658d00dfc --- /dev/null +++ b/fast_llm/layers/common/normalization/config.py @@ -0,0 +1,142 @@ +import abc +import enum +import typing + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.engine.config_utils.tensor_dim import TensorDim + from fast_llm.layers.common.normalization import LayerNorm, RMSNorm + + +class NormalizationImplementation(str, enum.Enum): + """ + An enum for the available implementations of layer norm. + """ + + auto = "auto" + torch = "torch" + fused = "fused" + fast = "fast" + triton = "triton" + + +@config_class(registry=True) +class NormalizationConfig(BaseModelConfig): + pass + + @abc.abstractmethod + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return LayerNormalizationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={NormalizationConfig: "none"}) +class NoNormalizationConfig(NormalizationConfig): + _abstract = False + + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + return torch.nn.Identity() + + +@config_class() +class LayerNormalizationBaseConfig(NormalizationConfig): + """ + Common configuration for layer norm and rms norm + """ + + # TODO: Rename to normalization_epsilon + epsilon: float = Field( + default=1e-5, + desc="Regularizer for the division.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + zero_centered: bool = Field( + default=False, + desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", + hint=FieldHint.architecture, + ) + implementation: NormalizationImplementation = Field( + default=NormalizationImplementation.auto, + desc="The implementation to use for the normalization layer.", + hint=FieldHint.performance, + ) + # TODO: Rename to normalization_init_range + initialization_range: float = Field( + default=0.0, + desc="Randomize the initialization with a uniform noise. 
Used to test for issues that may not be visible with the default initialization.", + hint=FieldHint.testing, + valid=check_field(Assert.geq, 0), + ) + + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": + from fast_llm.engine.config_utils.initialization import init_uniform_centered_ + + kwargs = { + "hidden_dim": hidden_dim, + "eps": self.epsilon, + "implementation": self.implementation, + "zero_centered": self.zero_centered, + "lr_scale": lr_scale, + } + if self.initialization_range: + mean = 0 if self.zero_centered else 1 + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) + return self.module_class(**kwargs) + + @property + @abc.abstractmethod + def module_class(self): + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + cls._handle_renamed_field(default, "normalization_type", "type") + cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") + cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + cls._handle_renamed_field(default, "normalization_implementation", "implementation") + cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") + return super()._from_dict(default, strict, flat) + + +@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) +class LayerNormalizationConfig(LayerNormalizationBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import LayerNorm + + return LayerNorm + + +@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) +class RMSNormalizationConfig(LayerNormalizationBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import RMSNorm + + return RMSNorm diff --git a/fast_llm/layers/common/normalization.py 
b/fast_llm/layers/common/normalization/normalization.py similarity index 99% rename from fast_llm/layers/common/normalization.py rename to fast_llm/layers/common/normalization/normalization.py index 2b928eb38..06ee11564 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.config import NormalizationImplementation +from fast_llm.layers.common.normalization.config import NormalizationImplementation from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6e79388b0..e31c70a45 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,7 +25,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.mlp.config import RoutingType -from fast_llm.layers.common.config import LayerNormalizationConfig +from fast_llm.layers.common.normalization.config import LayerNormalizationConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index b5e77e0f0..e9b18b848 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -21,7 +21,7 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from 
fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import RMSNormalizationConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( From c0a37827488caca54a558aa5338e16b995dcf39c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 20:15:02 -0400 Subject: [PATCH 17/28] stuff --- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/layers/transformer/rotary/config.py | 2 +- fast_llm/models/gpt/model.py | 2 +- tests/functional/test_triton_kernels.py | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index d7a669295..9ad27534f 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -137,7 +137,7 @@ def __init__( self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build(kv_channels_dim) + self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. 
self.dense = InputParallelLinear( diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index f0e0079c7..6cc19fce8 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -29,7 +29,7 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, kv_channels_dim: TensorDim) -> "Rotary": + def get_layer(self, kv_channels_dim: TensorDim) -> "Rotary": return self._get_configurable_class()(self, kv_channels_dim) @classmethod diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 41e0d607d..92f7b8173 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -51,7 +51,7 @@ def __init__( # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. self._preprocessors.append( - self._config.transformer.rotary.build(TensorDim("kv_channels", self._config.transformer.kv_channels)) + self._config.transformer.rotary.get_layer(TensorDim("kv_channels", self._config.transformer.kv_channels)) ) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e4ad937b7..3f4446e4d 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build(None) + .get_layer(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,9 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, 
kv_channels, 3), - DefaultRotaryConfig(triton=True).build(None)._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True) + .get_layer(None) + ._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, From e60ded4467c564bcb795d923758574d98e2f407b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 13:30:30 -0400 Subject: [PATCH 18/28] stuff --- fast_llm/layers/block/block.py | 48 +++++++++++++++---- fast_llm/layers/block/config.py | 42 ++-------------- fast_llm/layers/block/mlp/config.py | 7 ++- .../layers/block/mlp/mixture_of_experts.py | 12 ++--- fast_llm/layers/block/mlp/mlp.py | 18 +++---- fast_llm/layers/ssm/block.py | 25 +++++----- fast_llm/layers/ssm/discrete_mamba2.py | 17 ++----- fast_llm/layers/ssm/mamba.py | 16 ++----- fast_llm/layers/ssm/mamba2.py | 20 ++------ fast_llm/layers/transformer/attention.py | 39 +++++++-------- fast_llm/layers/transformer/block.py | 16 ++++--- fast_llm/layers/transformer/config.py | 4 -- fast_llm/utils.py | 39 ++++++++------- 13 files changed, 133 insertions(+), 170 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 09370e3af..64ba31626 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -100,6 +100,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): super().__init__(config, distributed_config) self._block_config = 
block_config @@ -112,9 +113,10 @@ def __init__( self._block_config.debug_transformer, self._block_config.debug_transformer_memory, ) + self._lr_scale = lr_scale -class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): +class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): """ Base class for mixer and MLP modules. """ @@ -145,6 +147,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, return_input: bool = False, ): super().__init__( @@ -154,28 +157,53 @@ def __init__( hidden_dim, block_index, name, + lr_scale, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(self._hidden_dim) - self.norm_2 = self._config.normalization.get_layer(self._hidden_dim) + self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) + self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) - # The mixer needs to be created here for backward-compatible weight ordering. - setattr(self, self._mixer_module_name, self._create_mixer()) + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. + setattr( + self, + self._mixer_module_name, + self._mixer_class( + self._mixer_config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", + self._lr_scale, + ), + ) # TODO: Use dynamic type. 
from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.block.mlp.mlp import MLP self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} MLP" + self._config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} MLP", + lr_scale, ) - # PEFT. - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) + @functools.cached_property + @abc.abstractmethod + def _mixer_class(self) -> type[BlockLayer]: + pass + + @property + @abc.abstractmethod + def _mixer_config(self) -> Config: + pass def setup(self, distributed: Distributed) -> None: super().setup(distributed) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 3df82e24e..63b58722b 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,19 +1,17 @@ -import abc import enum -import functools import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import BlockLayer + pass + +# TODO: Generalize these beyond language models? (Ex. vision) class BlockDimNames: @@ -46,37 +44,6 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" -@config_class() -class BlockLayerConfig(BaseModelConfig): - """ - A common class for mixers and mlps, which have the exact same interface. 
- """ - - _abstract = True - - @functools.cached_property - @abc.abstractmethod - def layer_class(self) -> "type[BlockLayer]": - raise NotImplementedError() - - def get_layer( - self, - block_config: "BlockConfig", - distributed_config: DistributedConfig, - hidden_dim: TensorDim, - block_index: int, - name: str, - ) -> "BlockLayer": - return self.layer_class( - self, - block_config, - distributed_config, - hidden_dim, - block_index, - name, - ) - - @config_class() # TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): @@ -90,6 +57,7 @@ class BlockConfig(MLPConfig, BaseModelConfig): desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) + # TODO: Review names hidden_dropout: float = Field( default=0.0, desc="Dropout applied to the residual connections.", @@ -121,7 +89,7 @@ class BlockConfig(MLPConfig, BaseModelConfig): # TODO: Move these, not specific to a single block. num_layers: int = Field( default=12, - desc="Number of layers in the transformer.", + desc="Number of blocks in the model.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 83e45f002..89d423025 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -2,9 +2,8 @@ import functools import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.block.config import BlockLayerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -22,7 +21,7 @@ class RoutingType(str, enum.Enum): @config_class() -class MLPConfig(BlockLayerConfig): +class MLPConfig(Config): # TODO: Review names # TODO: Separate MoE? 
_abstract = False ffn_hidden_size: int = Field( @@ -90,7 +89,7 @@ class MLPConfig(BlockLayerConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | list[float | None] = Field( + mlp_lr_scale: float | None | tuple[float | None] = Field( default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 2a234ca94..d52f5a429 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -14,7 +14,7 @@ from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) @@ -41,16 +41,12 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) - - layer_lr_scale = ( - self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None - ) - router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self.router = Linear( self._hidden_dim, @@ -61,7 +57,7 @@ def __init__( min_val=self._block_config.init_method_min, max_val=self._block_config.init_method_max, ), - lr_scale=router_lr_scale, + lr_scale=combine_lr_scales(self._config.router_lr_scale, self._lr_scale), ) dropless_moe = self._config.dropless_moe if dropless_moe and self._sequence_parallel: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index fd64713d1..341ecf265 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -12,7 +12,7 @@ from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -24,8 +24,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() @@ -42,15 +43,7 @@ def __init__( self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = ( - 
self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None - ) - lr_scale = ( - tuple(self._config.mlp_lr_scale) - if isinstance(self._config.mlp_lr_scale, list) - else self._config.mlp_lr_scale - ) - lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( @@ -95,9 +88,10 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): Assert.eq(config.num_experts, 1) - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) def forward( self, diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 361fe9818..408f21041 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,3 +1,5 @@ +import functools + from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import Block, BlockLayer @@ -17,21 +19,20 @@ def __init__( ssm_config: SSMConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, - mixer_cls: type[BlockLayer], block_index: int, + lr_scale: float | list[float] | None, name: str, + mixer_class: type[BlockLayer], return_input: bool = False, ): self._ssm_config = ssm_config - self._mixer_cls = mixer_cls - super().__init__(config, distributed_config, hidden_dim, block_index, name, return_input) + self._mixer_class = mixer_class + super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale, return_input) + + @functools.cached_property + def _mixer_class(self) -> type[BlockLayer]: + return self._mixer_class - def _create_mixer(self) -> BlockLayer: - return self._mixer_cls( - 
self._ssm_config, - self._config, - self._distributed_config, - self._hidden_dim, - self._block_index, - f"{self._name} mixer", - ) + @property + def _mixer_config(self) -> SSMConfig: + return self._config diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 7e445cca1..fb78f09c5 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -14,7 +14,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import div, get_lr_scale +from fast_llm.utils import combine_lr_scales, div logger = logging.getLogger(__name__) @@ -41,7 +41,6 @@ class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" - _config: SSMConfig def __init__( self, @@ -51,16 +50,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - block_config.debug_transformer, - block_config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim(SSMDimNames.head_dim, div(self._config.d_inner, self._config.n_v_heads)) @@ -90,8 +82,7 @@ def __init__( # local_bc_size = local_head_groups * state self._local_bc_size = bc_dim.size - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: double check initializations # Projections diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index ac6576a87..37ac20ef1 100644 
--- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -13,7 +13,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -65,16 +65,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - block_config.debug_transformer, - block_config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) @@ -88,8 +81,7 @@ def __init__( inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: Backward compatibility? # TODO: lr_scale? 
diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index e6ca9ea12..bc40658e6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -13,7 +13,7 @@ from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -47,22 +47,10 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - block_config.debug_transformer, - block_config.debug_transformer_memory, - ) - self._config: SSMConfig = config + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = ( - block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - ) - lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) @@ -94,6 +82,8 @@ def __init__( self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( ( diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 9ad27534f..8abab1206 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,11 +9,11 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig -from fast_llm.utils import div, get_lr_scale +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.utils import combine_lr_scales, div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -48,7 +48,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention[ConfigType: TransformerConfig](BlockLayer[ConfigType]): +class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): """ A self-attention layer. 
""" @@ -56,20 +56,15 @@ class Attention[ConfigType: TransformerConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - distributed_config, - hidden_dim, - block_index, - name, - config.debug_transformer, - config.debug_transformer_memory, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -112,8 +107,10 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales( + self._lr_scale, + self._config.attention_lr_scale, + ) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( @@ -123,7 +120,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -132,7 +129,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -147,13 +144,13 @@ def __init__( weight_init_method=init_method_std_attn_proj, bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) # PEFT. 
- self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query) - self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) - self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + self.query = self._block_config.peft.apply_linear(self.query, TransformerSubLayerName.query) + self.key_value = self._block_config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) + self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) if self._debug.enabled: self._query_dims = ( diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index a5aad45a9..ba593461b 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -1,9 +1,10 @@ +import functools import logging import typing -from fast_llm.layers.block.block import Block, BlockLayer +from fast_llm.layers.block.block import Block from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import AttentionConfig, TransformerConfig logger = logging.getLogger(__name__) @@ -12,7 +13,10 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - def _create_mixer(self) -> BlockLayer: - return Attention( - self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} attn" - ) + @functools.cached_property + def _mixer_class(self) -> type[Attention]: + return Attention + + @property + def _mixer_config(self) -> AttentionConfig: + return self._config diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a40f676ca..02b741723 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,6 +1,5 @@ 
import functools import logging -import typing import warnings from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -11,9 +10,6 @@ from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div -if typing.TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 58285d408..f7f5e9663 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -348,22 +348,29 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def get_lr_scale( - lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None -) -> float | None | tuple[float | None, ...]: - """ - Combine module and layer lr_scale. - If one is None, return the other. - """ - if lr_scale is None: - return layer_lr_scale - if layer_lr_scale is None: - return lr_scale - if isinstance(lr_scale, float): - return lr_scale * layer_lr_scale - if isinstance(lr_scale, tuple): - return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) - raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): + # Remove `None` entries. + lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + if not lr_scales: + # Everything is None + return None + tuple_length = None + # Check if we have tuples, and determine the length. + for lr_scale in lr_scales: + if isinstance(lr_scale, tuple): + if tuple_length is None: + tuple_length = len(lr_scale) + else: + assert len(lr_scale) == tuple_length + if tuple_length is None: + # No tuple: simple product. + return math.prod(lr_scales) + else: + # Tuple(s): use recursion. 
+ return [ + combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) + for i in range(tuple_length) + ] class Interrupter: From 1483bcc7cbe7bf6fa763ab12b75573ed89015207 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 13:57:56 -0400 Subject: [PATCH 19/28] stuff --- fast_llm/layers/language_model/embedding.py | 5 +++-- fast_llm/layers/language_model/head.py | 7 ++++--- fast_llm/layers/ssm/config.py | 23 --------------------- fast_llm/layers/ssm/discrete_mamba2.py | 10 ++++----- fast_llm/layers/ssm/mamba.py | 7 ------- fast_llm/layers/ssm/mamba2.py | 3 +-- 6 files changed, 13 insertions(+), 42 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index d1b912167..fd4e8412e 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -37,12 +37,13 @@ def __init__( ): super().__init__( config, + config.transformer, distributed_config, hidden_dim, block_index, name, - config.transformer.debug_transformer, - config.transformer.debug_transformer_memory, + # TODO: Add lr scale? + None, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index cc6c69262..7b1b5f6d8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -44,12 +44,13 @@ def __init__( ): super().__init__( config, + config.transformer, distributed_config, hidden_dim, block_index, name, - config.transformer.debug_transformer, - config.transformer.debug_transformer_memory, + # TODO: Add lr scale? 
+ None, ) self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -161,7 +162,7 @@ def _forward_backward( TensorDim( BlockDimNames.sequence_q_tp, dims[sequence_index].global_size, - DistributedDimNames.tensor, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) if self._sequence_parallel_logits else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2daad1186..8917feaf6 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -3,35 +3,12 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.initialization import Initializer -class SSMDimNames(BlockDimNames): - # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
- state = "ssm_state" # State dimension (N), aka head size / num channels - head_dim = "ssm_head_dim" - head_groups = "ssm_head_groups" - group_heads = "ssm_group_heads" - - convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers - - dt_rank = "ssm_dt_rank" - - # Composite dimensions - composite_heads = "ssm_composite_heads" - composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" - composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - - # Concatenated dimensions - concatenated_convolution = "ssm_concatenated_convolution" - concatenated_x_projection = "ssm_x_concatenated_x_projection" - concatenated_inner_projection = "ssm_concatenated_inner_projection" - - class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index fb78f09c5..7fea3d480 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -11,7 +11,7 @@ from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import combine_lr_scales, div @@ -54,15 +54,15 @@ def __init__( ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) - v_head_size_dim = TensorDim(SSMDimNames.head_dim, div(self._config.d_inner, self._config.n_v_heads)) + v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) head_groups_dim = TensorDim( - SSMDimNames.head_groups, + "head_groups", 
self._config.n_qk_heads, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) - group_heads_dim = TensorDim(SSMDimNames.group_heads, div(self._config.n_v_heads, self._config.n_qk_heads)) - heads_dim = CompositeTensorDim(SSMDimNames.composite_heads, (head_groups_dim, group_heads_dim)) + group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 37ac20ef1..59fd03a1e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -91,7 +91,6 @@ def __init__( bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) - self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, @@ -101,7 +100,6 @@ def __init__( init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) - self.x_proj = Linear( inner_dim, x_projection_dim, @@ -110,27 +108,23 @@ def __init__( lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True - # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) - self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), lr_scale=lr_scale, ) - self.A_log = ParameterMeta.from_dims( (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, ) 
- # D "skip" parameter self.D = ParameterMeta.from_dims( (inner_dim,), @@ -138,7 +132,6 @@ def __init__( init_method=init_ones_, lr_scale=lr_scale, ) - self.out_proj = Linear( inner_dim, hidden_dim, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index bc40658e6..bf9c30521 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -81,10 +81,10 @@ def __init__( self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, @@ -107,7 +107,6 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.dt_in_proj = Linear( hidden_dim, dt_rank_dim, From 4deb501748a1f725d96aff0ba88034166b4ae04f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 14:12:57 -0400 Subject: [PATCH 20/28] misc --- fast_llm/layers/block/block.py | 4 ---- fast_llm/models/gpt/model.py | 6 ++++++ fast_llm/models/ssm/model.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 64ba31626..535ca12c5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -210,10 +210,6 @@ def setup(self, distributed: Distributed) -> None: getattr(self, self._mixer_module_name).setup(distributed) self.mlp.setup(distributed) - @abc.abstractmethod - def _create_mixer(self) -> BlockLayer: - pass - @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 92f7b8173..581429467 100644 --- 
a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,12 +100,18 @@ def _get_block( name: str, return_input: bool = False, ): + lr_scale = ( + None + if self._config.transformer.per_layer_lr_scale is None + else self._config.transformer.per_layer_lr_scale[block_index] + ) return TransformerBlock( self._config.transformer, self._distributed_config, self._hidden_dim, block_index, name, + lr_scale, return_input, ) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 7c67d7355..9afd7dabb 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -29,6 +29,12 @@ def _get_block( # Decoder block block_type = self._config.hybrid_block_layout[block_index - 1] + lr_scale = ( + None + if self._config.transformer.per_layer_lr_scale is None + else self._config.transformer.per_layer_lr_scale[block_index] + ) + if block_type == SSMBlockType.transformer: return TransformerBlock( self._config.transformer, @@ -36,7 +42,7 @@ def _get_block( self._hidden_dim, block_index, name, - return_input, + lr_scale.return_input, ) else: return SSMBlock( @@ -44,9 +50,9 @@ def _get_block( self._config.ssm, self._distributed_config, self._hidden_dim, - self._config.ssm_block_type.get_mixer_class(), block_index, name, + lr_scale.self._config.ssm_block_type.get_mixer_class(), return_input, ) From fc809e0ed20747314ca98f08734d7eebd43cfb22 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 15:37:38 -0400 Subject: [PATCH 21/28] Misc, tests pass --- fast_llm/layers/block/block.py | 4 ++-- fast_llm/layers/block/config.py | 2 +- fast_llm/layers/block/mlp/config.py | 18 +++--------------- .../layers/block/mlp/mixture_of_experts.py | 2 +- fast_llm/layers/block/mlp/mlp.py | 4 ++-- fast_llm/layers/ssm/block.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba.py | 2 +- fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/models/ssm/model.py | 6 ++++-- 
fast_llm/utils.py | 6 +++--- 12 files changed, 22 insertions(+), 32 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 535ca12c5..b8aad3903 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -100,7 +100,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, distributed_config) self._block_config = block_config @@ -147,7 +147,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, return_input: bool = False, ): super().__init__( diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 63b58722b..95bcb02af 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -99,7 +99,7 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - per_layer_lr_scale: list[float] | None = Field( + per_layer_lr_scale: list[float | None] | None = Field( default=None, desc="Custom learning rate scale for each layer.", doc="May be used to freeze some layers by setting their scale to zero.", diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 89d423025..88ce4af10 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,5 +1,4 @@ import enum -import functools import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -7,7 +6,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.mlp.mlp import MLPBase + pass class MLPLossNames: @@ -89,7 +88,7 @@ class MLPConfig(Config): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | tuple[float | None] = Field( + mlp_lr_scale: float | None | tuple[float | None, 
...] = Field( default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", @@ -155,17 +154,6 @@ def add_mlp_bias(self) -> bool: return True return False - @functools.cached_property - def layer_class(self) -> "type[MLPBase]": - if self.num_experts > 1: - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - - return MixtureOfExpertMLP - else: - from fast_llm.layers.block.mlp.mlp import MLP - - return MLP - def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: @@ -198,7 +186,7 @@ def _validate(self) -> None: Assert.leq(self.num_shared_experts, self.num_experts) Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - if isinstance(self.mlp_lr_scale, list): + if isinstance(self.mlp_lr_scale, tuple): Assert.eq(len(self.mlp_lr_scale), self.num_experts) for scale in self.mlp_lr_scale: if scale is not None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index d52f5a429..4f7cf2dc4 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -41,7 +41,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? 
diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 341ecf265..c3a714a42 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -24,7 +24,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -88,7 +88,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): Assert.eq(config.num_experts, 1) super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 408f21041..22d01a5cb 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -20,8 +20,8 @@ def __init__( distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, - lr_scale: float | list[float] | None, name: str, + lr_scale: float | None, mixer_class: type[BlockLayer], return_input: bool = False, ): @@ -35,4 +35,4 @@ def _mixer_class(self) -> type[BlockLayer]: @property def _mixer_config(self) -> SSMConfig: - return self._config + return self._ssm_config diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 7fea3d480..0d91fbaff 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -50,7 +50,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) diff --git a/fast_llm/layers/ssm/mamba.py 
b/fast_llm/layers/ssm/mamba.py index 59fd03a1e..79a0e5c8e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -65,7 +65,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index bf9c30521..eec134a22 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -47,7 +47,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 8abab1206..41d509512 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -61,7 +61,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 9afd7dabb..4b7785402 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -42,7 +42,8 @@ def _get_block( self._hidden_dim, block_index, name, - lr_scale.return_input, + lr_scale, + return_input, ) else: return SSMBlock( @@ -52,7 +53,8 @@ def _get_block( self._hidden_dim, block_index, name, - lr_scale.self._config.ssm_block_type.get_mixer_class(), + lr_scale, + self._config.ssm_block_type.get_mixer_class(), return_input, ) diff 
--git a/fast_llm/utils.py b/fast_llm/utils.py index f7f5e9663..51249c3fa 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -350,7 +350,7 @@ def check_equal_nested(config_a, config_b): def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): # Remove `None` entries. - lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) if not lr_scales: # Everything is None return None @@ -367,10 +367,10 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): return math.prod(lr_scales) else: # Tuple(s): use recursion. - return [ + return tuple( combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) for i in range(tuple_length) - ] + ) class Interrupter: From cdb67105cc9f70c234132ca7d248f7db0cdfef89 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 10:50:59 -0400 Subject: [PATCH 22/28] misc --- fast_llm/layers/block/config.py | 4 - .../layers/common/normalization/config.py | 44 +++--- .../common/normalization/normalization.py | 128 ++++++++++-------- fast_llm/layers/language_model/head.py | 4 +- fast_llm/layers/transformer/attention.py | 1 - 5 files changed, 94 insertions(+), 87 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 95bcb02af..29acaadf0 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,5 +1,4 @@ import enum -import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig @@ -8,9 +7,6 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - pass - # TODO: Generalize these beyond language models? (Ex. 
vision) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 658d00dfc..569d48b0e 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -7,10 +7,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - from fast_llm.engine.config_utils.tensor_dim import TensorDim - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm + from fast_llm.layers.common.normalization.normalization import Normalization class NormalizationImplementation(str, enum.Enum): @@ -29,10 +27,18 @@ class NormalizationImplementation(str, enum.Enum): class NormalizationConfig(BaseModelConfig): pass + @property @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def module_class(self) -> type["Normalization"]: pass + def get_layer( + self, + hidden_dim: "TensorDim", + lr_scale: float | None = None, + ) -> "Normalization": + return self.module_class(self, hidden_dim, lr_scale) + @classmethod def _from_dict( cls, @@ -50,8 +56,11 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - return torch.nn.Identity() + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization.normalization import NoNormalization + + return NoNormalization @config_class() @@ -85,21 +94,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ - - kwargs = { - "hidden_dim": hidden_dim, - "eps": self.epsilon, - "implementation": self.implementation, - "zero_centered": self.zero_centered, - "lr_scale": lr_scale, - } - if self.initialization_range: - mean = 0 if 
self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) - return self.module_class(**kwargs) - @property @abc.abstractmethod def module_class(self): @@ -126,9 +120,9 @@ class LayerNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization.normalization import LayerNorm + from fast_llm.layers.common.normalization.normalization import LayerNormalization - return LayerNorm + return LayerNormalization @config_class(dynamic_type={NormalizationConfig: "rms_norm"}) @@ -137,6 +131,6 @@ class RMSNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization.normalization import RMSNorm + from fast_llm.layers.common.normalization.normalization import RMSNormalization - return RMSNorm + return RMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 06ee11564..7f7d3eb65 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,11 +1,20 @@ +import abc + import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ +from fast_llm.config import Configurable +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.normalization.config import NormalizationImplementation +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + NormalizationConfig, + NormalizationImplementation, + RMSNormalizationConfig, +) from 
fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert @@ -139,7 +148,24 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, return grad_input, None, None, None -class LayerNorm(torch.nn.Module): +class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config) + self._hidden_dim = hidden_dim + self._lr_scale = lr_scale + assert not self._hidden_dim.is_parallel + + @abc.abstractmethod + def forward(self, input_: torch.Tensor) -> torch.Tensor: + pass + + +class NoNormalization[ConfigType: NoNormalizationConfig](Normalization[ConfigType]): + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + +class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[ConfigType]): """ A layer normalization layer, supporting multiple implementations. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -147,25 +173,17 @@ class LayerNorm(torch.nn.Module): TODO: Review this? 
""" - def __init__( - self, - hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - bias_init_method=init_zeros_, - zero_centered: bool = False, - lr_scale: float | None = None, - ): - super().__init__() - assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if _fast_normalization_available and hidden_dim.size in _PERSIST_LN_SIZES and not self._zero_centered: + if ( + _fast_normalization_available + and hidden_dim.size in _PERSIST_LN_SIZES + and not self._config.zero_centered + ): implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._zero_centered: + elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -174,7 +192,7 @@ def __init__( else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -187,44 +205,49 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + if self.config.initialization_range: + mean = 0 if self.zero_centered else 1 + weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) + else: + 
weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ self.weight = ParameterMeta.from_dims( (hidden_dim,), init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), - init_method=bias_init_method, + init_method=init_zeros_, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: return triton_normalization_autograd( - input_, self.weight, self.bias, self._eps, self.training, self._zero_centered + input_, self.weight, self.bias, self._config.epsilon, self.training, self._config.zero_centered ) def _forward_fast(self, input_: torch.Tensor) -> torch.Tensor: - return FastLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FastLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FusedLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.layer_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self.bias, self._eps) + return torch.layer_norm( + input_.to(self.weight.dtype), self._normalized_shape, self.weight, self.bias, self._config.epsilon + 
) -class RMSNorm(torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): """ A RMS normalization layer. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -232,22 +255,12 @@ class RMSNorm(torch.nn.Module): TODO: Review this? """ - def __init__( - self, - hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - zero_centered: bool = False, - lr_scale: float | None = None, - ): - super().__init__() + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._zero_centered: + if TritonConfig.TRITON_ENABLED or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") @@ -255,7 +268,7 @@ def __init__( else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -266,8 +279,11 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + if self.config.initialization_range: + mean = 0 if self.zero_centered else 1 + weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) + else: + weight_init_method = init_zeros_ if 
self._config.zero_centered else init_ones_ self.weight = ParameterMeta.from_dims( (hidden_dim,), @@ -276,16 +292,18 @@ def __init__( auto_grad_accumulation=True, lr_scale=lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: - return triton_normalization_autograd(input_, self.weight, None, self._eps, self.training, self._zero_centered) + return triton_normalization_autograd( + input_, self.weight, None, self._config.epsilon, self.training, self._config.zero_centered + ) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedRMSNorm.apply(input_, self.normalized_shape, self.weight, self._eps) + return FusedRMSNorm.apply(input_, self._normalized_shape, self.weight, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps) + return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 7b1b5f6d8..d0c0eb8f9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -118,7 +118,7 @@ def forward( tensor_name="Loss", reductions=( (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), - ), # noqa + ), ) else: return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") @@ -262,7 +262,7 @@ def _logits_cross_entropy_forward_backward_split( return None, None else: loss = None - # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length + # TODO 
MTP: allow a cross_entropy_splits that is not a divisor of the sequence length grad_output /= self._config.cross_entropy_splits logit_input = input_.flatten(0, -2) if self.training: diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 41d509512..91fca75b8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -64,7 +64,6 @@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) From 9ce72e04ead2857adefa2c13430c9cbcb373e506 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 12:51:26 -0400 Subject: [PATCH 23/28] Move files --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 2 +- fast_llm/layers/{transformer => attention}/__init__.py | 0 .../layers/{transformer => attention}/attention.py | 2 +- fast_llm/layers/{transformer => attention}/block.py | 4 ++-- fast_llm/layers/{transformer => attention}/config.py | 2 +- .../layers/{transformer => attention}/preprocessing.py | 2 +- .../{transformer => attention}/rotary/__init__.py | 0 .../layers/{transformer => attention}/rotary/config.py | 10 +++++----- .../layers/{transformer => attention}/rotary/rotary.py | 4 ++-- fast_llm/layers/block/peft.py | 2 +- fast_llm/layers/common/normalization/normalization.py | 2 +- fast_llm/layers/common/peft/__init__.py | 0 fast_llm/layers/common/{ => peft}/config.py | 2 +- fast_llm/layers/common/{peft.py => peft/lora.py} | 0 fast_llm/layers/language_model/config.py | 4 ++-- fast_llm/models/gpt/conversion.py | 6 +++--- fast_llm/models/gpt/huggingface.py | 2 +- fast_llm/models/gpt/megatron.py | 6 +++--- fast_llm/models/gpt/model.py | 6 +++--- fast_llm/models/ssm/model.py | 2 +- 
tests/functional/test_triton_kernels.py | 4 ++-- tests/layers/test_lm_head.py | 2 +- tests/test_attention.py | 6 +++--- tests/test_multi_stage.py | 2 +- 24 files changed, 36 insertions(+), 36 deletions(-) rename fast_llm/layers/{transformer => attention}/__init__.py (100%) rename fast_llm/layers/{transformer => attention}/attention.py (99%) rename fast_llm/layers/{transformer => attention}/block.py (77%) rename fast_llm/layers/{transformer => attention}/config.py (99%) rename fast_llm/layers/{transformer => attention}/preprocessing.py (98%) rename fast_llm/layers/{transformer => attention}/rotary/__init__.py (100%) rename fast_llm/layers/{transformer => attention}/rotary/config.py (92%) rename fast_llm/layers/{transformer => attention}/rotary/rotary.py (98%) create mode 100644 fast_llm/layers/common/peft/__init__.py rename fast_llm/layers/common/{ => peft}/config.py (95%) rename fast_llm/layers/common/{peft.py => peft/lora.py} (100%) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 8f4dffedf..439d1da2e 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,7 +16,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/__init__.py b/fast_llm/layers/attention/__init__.py similarity index 100% rename from fast_llm/layers/transformer/__init__.py rename to fast_llm/layers/attention/__init__.py diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/attention/attention.py 
similarity index 99% rename from fast_llm/layers/transformer/attention.py rename to fast_llm/layers/attention/attention.py index 91fca75b8..8a4c490c9 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -8,11 +8,11 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.utils import combine_lr_scales, div try: diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/attention/block.py similarity index 77% rename from fast_llm/layers/transformer/block.py rename to fast_llm/layers/attention/block.py index ba593461b..3396a2997 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/attention/block.py @@ -2,9 +2,9 @@ import logging import typing +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig from fast_llm.layers.block.block import Block -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionConfig, TransformerConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/attention/config.py similarity index 99% rename from fast_llm/layers/transformer/config.py rename to fast_llm/layers/attention/config.py index 02b741723..e5c638adc 100644 --- 
a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/attention/config.py @@ -6,8 +6,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig +from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs -from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/attention/preprocessing.py similarity index 98% rename from fast_llm/layers/transformer/preprocessing.py rename to fast_llm/layers/attention/preprocessing.py index 769177668..24ef3397c 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -6,7 +6,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/rotary/__init__.py b/fast_llm/layers/attention/rotary/__init__.py similarity index 100% rename from fast_llm/layers/transformer/rotary/__init__.py rename to fast_llm/layers/attention/rotary/__init__.py diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/attention/rotary/config.py similarity index 92% rename from fast_llm/layers/transformer/rotary/config.py rename to fast_llm/layers/attention/rotary/config.py index 6cc19fce8..4ebd6c5dc 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ 
b/fast_llm/layers/attention/rotary/config.py @@ -10,7 +10,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary @config_class(registry=True) @@ -44,7 +44,7 @@ class NoRotaryConfig(RotaryConfig): @classmethod def _get_configurable_class(self) -> "type[NoRotary]": - from fast_llm.layers.transformer.rotary.rotary import NoRotary + from fast_llm.layers.attention.rotary.rotary import NoRotary return NoRotary @@ -75,7 +75,7 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary return DefaultRotary @@ -97,7 +97,7 @@ def _validate(self) -> None: Assert.gt(self.high_frequency_factor, self.low_frequency_factor) def _get_configurable_class(self) -> "type[Llama3Rotary]": - from fast_llm.layers.transformer.rotary.rotary import Llama3Rotary + from fast_llm.layers.attention.rotary.rotary import Llama3Rotary return Llama3Rotary @@ -137,6 +137,6 @@ def _validate(self) -> None: super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": - from fast_llm.layers.transformer.rotary.rotary import YarnRotary + from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py similarity index 98% rename from fast_llm/layers/transformer/rotary/rotary.py rename to fast_llm/layers/attention/rotary/rotary.py index bbf8b524a..53b24c9bb 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -8,8 +8,8 @@ from 
fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.layers.transformer.rotary.config import ( +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index 66bc675ed..2261a7ea1 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -7,7 +7,7 @@ import typing from fast_llm.config import Field, FieldHint, config_class -from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig +from fast_llm.layers.common.peft.config import LoRAConfig, NoPeftConfig, PeftConfig from fast_llm.utils import div if typing.TYPE_CHECKING: diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 7f7d3eb65..a7eba72c8 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -247,7 +247,7 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: ) -class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Normalization[ConfigType], torch.nn.Module): """ A RMS normalization layer. 
Note: Converting input automatically to training dtype to match Apex behaviour, diff --git a/fast_llm/layers/common/peft/__init__.py b/fast_llm/layers/common/peft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/peft/config.py similarity index 95% rename from fast_llm/layers/common/config.py rename to fast_llm/layers/common/peft/config.py index b09672961..ae8ce3ba4 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -44,7 +44,7 @@ class LoRAConfig(PeftConfig): ) def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft import lora_linear + from fast_llm.layers.common.peft.lora import lora_linear # TODO: Init method? return lora_linear( diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft/lora.py similarity index 100% rename from fast_llm/layers/common/peft.py rename to fast_llm/layers/common/peft/lora.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index de3f9f196..df6969cfc 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -3,9 +3,9 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig from fast_llm.utils import Assert diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index e31c70a45..36975dea1 100644 --- a/fast_llm/models/gpt/conversion.py +++ 
b/fast_llm/models/gpt/conversion.py @@ -24,11 +24,11 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.block.mlp.config import RoutingType from fast_llm.layers.common.normalization.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 4e3f258fc..2f99ae4c3 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 20ed8e828..5d3130549 100644 --- a/fast_llm/models/gpt/megatron.py +++ 
b/fast_llm/models/gpt/megatron.py @@ -1,7 +1,7 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -94,7 +94,7 @@ def _init_attention_megatron( raise NotImplementedError(meta.tensor_name) if isinstance(config.rotary, DefaultRotaryConfig) and config.rotary.complex_format: - from fast_llm.layers.transformer.rotary.config import convert_rotary_real_to_complex + from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 581429467..b13c77724 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,15 +10,15 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.attention.block import TransformerBlock +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from 
fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 4b7785402..9b79e74a3 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,8 +1,8 @@ import logging import typing +from fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index 3f4446e4d..5a9065454 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -23,8 +23,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import ( +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.rotary.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, diff --git a/tests/layers/test_lm_head.py 
b/tests/layers/test_lm_head.py index 8c33aed4d..380ab0550 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,10 +6,10 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda diff --git a/tests/test_attention.py b/tests/test_attention.py index 7d05e0a66..9564a931f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -5,10 +5,10 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig -from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 0639ec7ed..56356cf7a 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,8 @@ from fast_llm.engine.distributed.distributed import 
Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup From 065b34fac5a44d87281c439ff173f1170126564b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 13:12:28 -0400 Subject: [PATCH 24/28] misc --- fast_llm/layers/block/peft.py | 57 +++-------------- fast_llm/layers/common/peft/config.py | 67 +++++++++++++++----- fast_llm/layers/common/peft/peft.py | 88 +++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 65 deletions(-) create mode 100644 fast_llm/layers/common/peft/peft.py diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index 2261a7ea1..b51d352bc 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -2,7 +2,6 @@ TODO: Generalize beyond transformers. """ -import abc import enum import typing @@ -11,14 +10,10 @@ from fast_llm.utils import div if typing.TYPE_CHECKING: - import torch - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. 
query = "query" key = "key" value_ = "value" @@ -30,18 +25,6 @@ class TransformerSubLayerName(str, enum.Enum): @config_class(registry=True) class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - @classmethod def _from_dict( cls, @@ -57,16 +40,7 @@ def _from_dict( @config_class(dynamic_type={TransformerPeftConfig: "none"}) class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter + pass @config_class(dynamic_type={TransformerPeftConfig: "lora"}) @@ -76,33 +50,18 @@ class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + out_channel_begin, out_channel_end = None, None if layer_type is None or self.layers is None or layer_type in self.layers: + enabled = True if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + out_channel_end = div(linear._out_dim.global_size, 2) elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, 
out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - return parameter + out_channel_begin = div(linear._out_dim.global_size, 2) + else: + enabled = False + return super().apply_linear(linear, enabled, out_channel_begin, out_channel_end) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index ae8ce3ba4..4b06623ba 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -5,23 +5,41 @@ from fast_llm.engine.base_model.config import BaseModelConfig if typing.TYPE_CHECKING: + import torch + from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.layers.common.normalization.normalization import Normalization + from fast_llm.tensor import ParameterMeta @config_class() class PeftConfig(BaseModelConfig): @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - pass + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + return self.apply_other(module) + + def apply_normalization(self, module: "Normalization") -> "Normalization": + return self.apply_other(module) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + for parameter in module.parameters(): + self.apply_weight(parameter) + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter 
@config_class() class NoPeftConfig(PeftConfig): _abstract = False - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - return linear - @config_class() class LoRAConfig(PeftConfig): @@ -42,17 +60,34 @@ class LoRAConfig(PeftConfig): desc="Dropout rate for LoRA.", hint=FieldHint.stability, ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + if not enabled: + return self.apply_other(module) - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft.lora import lora_linear + from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear + from fast_llm.layers.common.peft.peft import lora_linear + + if isinstance(module, InputParallelLinear): + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for InputParallelLinear.") + elif isinstance(module, OutputParallelLinear): + assert out_channel_begin is None and out_channel_end is None # TODO: Init method? 
- return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) + return lora_linear(module, self.rank, self.alpha, self.dropout, out_channel_begin, out_channel_end) + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + return parameter diff --git a/fast_llm/layers/common/peft/peft.py b/fast_llm/layers/common/peft/peft.py new file mode 100644 index 000000000..9e0ca0dd0 --- /dev/null +++ b/fast_llm/layers/common/peft/peft.py @@ -0,0 +1,88 @@ +import typing + +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.common.linear import Linear, LinearBase + + +def lora_linear( + module: LinearBase, + rank: int, + alpha: float, + dropout: float = 0.0, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, +): + module.weight.requires_grad = False + in_dim = module._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." + if in_dim.parallel_dim is not None: + in_dim = TensorDim(in_dim.name, in_dim.global_size) + out_dim = module._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." + if out_dim.parallel_dim is not None: + out_dim = TensorDim(out_dim.name, out_dim.global_size) + if out_channel_begin is not None or out_channel_end is not None: + if out_channel_begin is None: + out_channel_begin = 0 + if out_channel_end is None: + out_channel_end = out_dim.global_size + # TODO: This won't work with TP. Use Composite dim structure for proper split? 
+ out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin) + + middle_dim = TensorDim("lora_middle", rank) + + module.lora_0 = Linear( + in_dim, + middle_dim, + bias=False, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, + ) + module.lora_1 = Linear( + middle_dim, + out_dim, + bias=False, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, + ) + # TODO: Implement proper backward pass. + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True + + old_forward = module._forward + + def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + # TODO: torch compile? + input_ = input_.detach().requires_grad_() + with torch.enable_grad(): + output = old_forward(input_) + if isinstance(output, tuple): + layer_out, tp_bias = output[0] + assert tp_bias is None + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) + ) + if out_channel_begin is None: + output = output + lora_out + else: + output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out + return output.detach(), (input_, output) + + def backward( + grad_output: torch.Tensor, context: torch.Tensor + ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: + # TODO: Implement proper backward pass. 
+ input_, output = context + output.backward(grad_output) + return input_.grad + + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward + + return module From 4510b7b1a20aea4cb0348aeb233f704c2fcc30cf Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 13:15:19 -0400 Subject: [PATCH 25/28] misc --- fast_llm/layers/common/peft/config.py | 2 +- fast_llm/layers/common/peft/lora.py | 44 +++++++------- fast_llm/layers/common/peft/peft.py | 88 --------------------------- 3 files changed, 22 insertions(+), 112 deletions(-) delete mode 100644 fast_llm/layers/common/peft/peft.py diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 4b06623ba..12e1810ff 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -76,7 +76,7 @@ def apply_linear( return self.apply_other(module) from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear - from fast_llm.layers.common.peft.peft import lora_linear + from fast_llm.layers.common.peft.lora import lora_linear if isinstance(module, InputParallelLinear): # TODO: Support InputParallelLinear (different output format). diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index 87991ef29..9e0ca0dd0 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -8,21 +8,19 @@ def lora_linear( - layer: LinearBase, - init_method_0, - init_method_1, + module: LinearBase, rank: int, alpha: float, dropout: float = 0.0, out_channel_begin: int | None = None, out_channel_end: int | None = None, ): - layer.weight.requires_grad = False - in_dim = layer._in_dim + module.weight.requires_grad = False + in_dim = module._in_dim assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." 
if in_dim.parallel_dim is not None: in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = layer._out_dim + out_dim = module._out_dim assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: out_dim = TensorDim(out_dim.name, out_dim.global_size) @@ -36,27 +34,27 @@ def lora_linear( middle_dim = TensorDim("lora_middle", rank) - layer.lora_0 = Linear( + module.lora_0 = Linear( in_dim, middle_dim, bias=False, - weight_init_method=init_method_0, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) - layer.lora_1 = Linear( + module.lora_1 = Linear( middle_dim, out_dim, bias=False, - weight_init_method=init_method_1, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) # TODO: Implement proper backward pass. - layer.lora_0.weight.auto_grad_accumulation = True - layer.lora_1.weight.auto_grad_accumulation = True + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True - old_forward = layer._forward + old_forward = module._forward def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: torch compile? 
@@ -66,8 +64,8 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor if isinstance(output, tuple): layer_out, tp_bias = output[0] assert tp_bias is None - lora_out = (alpha / rank) * layer.lora_1( - layer.lora_0(torch.dropout(input_, dropout, layer.training) if dropout > 0.0 else input_) + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) ) if out_channel_begin is None: output = output + lora_out @@ -83,8 +81,8 @@ def backward( output.backward(grad_output) return input_.grad - layer._forward = wrap_forward_backward(forward_only, backward) - layer.forward_only = forward_only - layer.backward = backward + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward - return layer + return module diff --git a/fast_llm/layers/common/peft/peft.py b/fast_llm/layers/common/peft/peft.py deleted file mode 100644 index 9e0ca0dd0..000000000 --- a/fast_llm/layers/common/peft/peft.py +++ /dev/null @@ -1,88 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.common.linear import Linear, LinearBase - - -def lora_linear( - module: LinearBase, - rank: int, - alpha: float, - dropout: float = 0.0, - out_channel_begin: int | None = None, - out_channel_end: int | None = None, -): - module.weight.requires_grad = False - in_dim = module._in_dim - assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." - if in_dim.parallel_dim is not None: - in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = module._out_dim - assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." 
- if out_dim.parallel_dim is not None: - out_dim = TensorDim(out_dim.name, out_dim.global_size) - if out_channel_begin is not None or out_channel_end is not None: - if out_channel_begin is None: - out_channel_begin = 0 - if out_channel_end is None: - out_channel_end = out_dim.global_size - # TODO: This won't work with TP. Use Composite dim structure for proper split? - out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin) - - middle_dim = TensorDim("lora_middle", rank) - - module.lora_0 = Linear( - in_dim, - middle_dim, - bias=False, - weight_init_method=module.weight.param_init_method, - transposed_weight=module.transposed_weight, - lr_scale=module.weight.lr_scale, - ) - module.lora_1 = Linear( - middle_dim, - out_dim, - bias=False, - weight_init_method=module.weight.param_init_method, - transposed_weight=module.transposed_weight, - lr_scale=module.weight.lr_scale, - ) - # TODO: Implement proper backward pass. - module.lora_0.weight.auto_grad_accumulation = True - module.lora_1.weight.auto_grad_accumulation = True - - old_forward = module._forward - - def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - # TODO: torch compile? - input_ = input_.detach().requires_grad_() - with torch.enable_grad(): - output = old_forward(input_) - if isinstance(output, tuple): - layer_out, tp_bias = output[0] - assert tp_bias is None - lora_out = (alpha / rank) * module.lora_1( - module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) - ) - if out_channel_begin is None: - output = output + lora_out - else: - output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out - return output.detach(), (input_, output) - - def backward( - grad_output: torch.Tensor, context: torch.Tensor - ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]: - # TODO: Implement proper backward pass. 
- input_, output = context - output.backward(grad_output) - return input_.grad - - module._forward = wrap_forward_backward(forward_only, backward) - module.forward_only = forward_only - module.backward = backward - - return module From 39960cee17e7a047ac4cc7f3e6d33b0fa631ae5f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 15:01:37 -0400 Subject: [PATCH 26/28] Cleanup --- fast_llm/layers/ssm/mamba.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 79a0e5c8e..453c14af6 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -68,7 +68,7 @@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) @@ -84,12 +84,12 @@ def __init__( lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: Backward compatibility? - # TODO: lr_scale? self.in_proj = Linear( hidden_dim, inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.conv1d_weight = ParameterMeta.from_dims( ( diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index eec134a22..2659e415f 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -149,16 +149,16 @@ def __init__( bias=config.add_bias_linear, weight_init_method=init_kaiming_(self._config.d_inner), sequence_parallel=self._sequence_parallel, - # TODO: lr_scale? 
+ lr_scale=lr_scale, ) if self._debug.enabled: - _xz_dims = ( + self._xz_dims = ( BlockDimNames.batch, inner_dim, BlockDimNames.sequence_q, ) - _bc_dims = ( + self._bc_dims = ( BlockDimNames.batch, heads_dim, state_dim, @@ -176,10 +176,10 @@ def forward( assert _causal_conv1d_available # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) - # -> (batch/sequence, sequence/batch, inner_projection) + # -> (batch/sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias - # Standardize to (batch, sequence, inner_projection) + # Standardize to (batch, sequence, local_inner_projection) if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) @@ -226,11 +226,11 @@ def forward( dt = dt.transpose(1, 2) if self._debug.enabled: - self._debug(z, "z", self._XZ_DIMS, kwargs) - self._debug(x, "x", self._XZ_DIMS, kwargs) - self._debug(b, "b", self._BC_DIMS, kwargs) - self._debug(c, "c", self._BC_DIMS, kwargs) - self._debug(dt, "dt", self._XZ_DIMS, kwargs) + self._debug(z, "z", self._xz_dims, kwargs) + self._debug(x, "x", self._xz_dims, kwargs) + self._debug(b, "b", self._bc_dims, kwargs) + self._debug(c, "c", self._bc_dims, kwargs) + self._debug(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, @@ -245,7 +245,7 @@ def forward( ) if self._debug.enabled: - self._debug(y, "y", self._XZ_DIMS, kwargs) + self._debug(y, "y", self._xz_dims, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] From 654aeeb4be24eb64fba6f3885d72ebcf4992d532 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 Aug 2025 16:09:54 -0400 Subject: [PATCH 27/28] Fix merge --- fast_llm/layers/block/block.py | 10 +++++++--- fast_llm/layers/ssm/discrete_mamba2.py | 6 ++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git 
a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index b8aad3903..f90fce698 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -163,8 +163,12 @@ def __init__( self._return_input: bool = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) - self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) + self.norm_1 = self._config.peft.apply_other( + self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) + ) + self.norm_2 = self._config.peft.apply_other( + self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) + ) # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. setattr( @@ -192,7 +196,7 @@ def __init__( self._hidden_dim, self._block_index, f"{self._name} MLP", - lr_scale, + self._lr_scale, ) @functools.cached_property diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 0d91fbaff..f9462a942 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -150,11 +150,9 @@ def forward( assert not kwargs[BlockKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) - # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) - # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) + # -> (batch/padded_sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) - # Standardize to (batch, padded_sequence, inner_projection) + # Standardize to (batch, padded_sequence, local_inner_projection) if
kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) From 3f4a8ba8600adba99f791d490c0588f10221068e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 Aug 2025 15:52:35 -0400 Subject: [PATCH 28/28] fix --- fast_llm/layers/common/peft/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 12e1810ff..64a2ca57a 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -1,4 +1,3 @@ -import abc import typing from fast_llm.config import Field, FieldHint, config_class @@ -14,7 +13,6 @@ @config_class() class PeftConfig(BaseModelConfig): - @abc.abstractmethod def apply_linear( self, module: "LinearBase",