2 changes: 1 addition & 1 deletion Megatron-LM
30 changes: 15 additions & 15 deletions docs/developer_guide/conversion.md
@@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define:

```python
def _create_weight_converters(self) -> list[WeightConverter]:
    converters = []
    # The set of converters may depend on the base model configuration,
    # which is accessible through `self._model.config.base_model`.
    num_layers = self._model.config.base_model.transformer.num_layers

    # A simple renaming example, for the word embeddings.
    converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

    # We usually want to loop dynamically over layers.
    for i in range(num_layers):
        # A `SplitWeightConverter` example, splitting a weight in two.
        converters.append(SplitWeightConverter(
            f"layers.{i + 1}.weight",
            (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
        ))
    return converters
```

And that's it! We're ready to use the new checkpoint format in Fast-LLM.
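
For intuition, here is a small self-contained sketch of the mapping these converters express. The NumPy split below is illustrative only, not Fast-LLM's actual conversion machinery:

```python
import numpy as np

# Renaming: "layers.0.word_embeddings_weight" <-> "model.embed_tokens.weight".
# Splitting: "layers.{i + 1}.weight" <-> ("model.layers.{i}.weight_1", "model.layers.{i}.weight_2").
# A split converter conceptually cuts one tensor into two on export:
weight = np.arange(12).reshape(6, 2)
weight_1, weight_2 = np.split(weight, 2, axis=0)
# ...and concatenates them back on import:
assert (np.concatenate([weight_1, weight_2], axis=0) == weight).all()
```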
22 changes: 22 additions & 0 deletions fast_llm/config.py
@@ -1028,6 +1028,28 @@ def __init__(self, config: ConfigType, *args, **kwargs):
# Handle multiple inheritance.
super().__init__(*args, **kwargs)

def __init_subclass__(cls):
# Automatically set `config_class` based on the bound type.
# Make sure `ConfigType` is bound and respects class hierarchy.
try:
config_class = None
for base in types.get_original_bases(cls):
if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable):
for arg in base.__args__:
if arg.__name__ == "ConfigType":
if config_class is None:
config_class = arg.__bound__
else:
assert arg.__bound__ is config_class
assert config_class is not None
except Exception as e:
raise TypeError(
f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. "
"Please make sure to declare in the format "
f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`.] "
)
cls.config_class = config_class

@property
def config(self) -> ConfigType:
return self._config
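
A minimal sketch of the declaration pattern this enables. The class names below are hypothetical; it assumes Python 3.12+ (PEP 695 generic syntax and `types.get_original_bases`) and the `Config`, `Configurable`, and `config_class` exports of `fast_llm.config`:

```python
from fast_llm.config import Config, Configurable, config_class


@config_class()
class AwesomeConfig(Config):
    pass


# Declaring the bound on the generic parameter is now enough:
# `__init_subclass__` resolves `config_class` from it automatically.
class AwesomeRunner[ConfigType: AwesomeConfig](Configurable[ConfigType]):
    pass


assert AwesomeRunner.config_class is AwesomeConfig
```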
2 changes: 0 additions & 2 deletions fast_llm/data/preparator/config.py
@@ -19,8 +19,6 @@ def _get_runnable(self) -> typing.Callable[[], None]:


class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC):
config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig

@abc.abstractmethod
def run(self) -> None:
raise NotImplementedError
19 changes: 0 additions & 19 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -30,25 +30,6 @@ class SourceSchemaConfig(Config):
pass


@config_class(dynamic_type={SourceSchemaConfig: "prompt_completion"})
class PromptCompletionConfig(SourceSchemaConfig):
prompt_column: str = Field(
default="prompt",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
completion_column: str = Field(
default="completion",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
delimiter: str = Field(
default="",
desc="Delimiter between prompt and completion.",
hint=FieldHint.optional,
)


@config_class(dynamic_type={SourceSchemaConfig: "text_column"})
class TextColumnConfig(SourceSchemaConfig):
input_column: str = Field(
94 changes: 28 additions & 66 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -24,11 +24,7 @@
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.config import DatasetPreparator
from fast_llm.data.preparator.gpt_memmap.config import (
GPTMemmapDatasetPreparatorConfig,
PromptCompletionConfig,
TextColumnConfig,
)
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
@@ -37,8 +33,6 @@


class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig

_tokenizer: Tokenizer
_data_type: DataType
_text_column: str
@@ -54,30 +48,6 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
"num_tokens": num_tokens,
}

def _tokenize_prompt_completion_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
"""
Tokenize prompt and completion columns separately, then concatenate.
Returns input_ids, token_spans (prompt len), and num_tokens.
"""
prompt_col = self._config.dataset.source_schema.prompt_column
completion_col = self._config.dataset.source_schema.completion_column
delimiter = self._config.dataset.source_schema.delimiter
input_ids = []
token_spans = []
for prompt, completion in zip(batch[prompt_col], batch[completion_col]):
prompt_tokens = self._tokenizer.tokenize(prompt, begin=True, end=False)
completion_tokens = self._tokenizer.tokenize(f"{delimiter}{completion}", begin=False, end=True)
combined = prompt_tokens + completion_tokens
input_ids.append(np.array(combined, dtype=self._data_type.numpy))
token_spans.append(np.array((0, len(prompt_tokens) - 1), dtype=np.int32).reshape(-1, 2))

num_tokens = [len(x) for x in input_ids]
return {
"input_ids": input_ids,
"token_spans": token_spans,
"num_tokens": num_tokens,
}

def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
input_ids, token_spans = map(
list,
@@ -171,7 +141,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig:
shard_output_path = self._config.output_path / prefix

def _document_generator():
if "token_spans" in shard_dataset.column_names:
if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None:
for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
yield GPTSample(
np.array(item["input_ids"], dtype=self._data_type.numpy),
@@ -317,46 +287,37 @@ def run(self) -> None:
)

# Set data column and loss masking spans column based on source schema
source_schema = self._config.dataset.source_schema
if isinstance(source_schema, TextColumnConfig):
self._text_column = source_schema.input_column
self._loss_masking_spans_column = source_schema.loss_masking_spans_column
elif isinstance(source_schema, PromptCompletionConfig):
Assert.incl(source_schema.prompt_column, dataset.column_names)
Assert.incl(source_schema.completion_column, dataset.column_names)
tokenize_fn = self._tokenize_prompt_completion_batch
if isinstance(self._config.dataset.source_schema, TextColumnConfig):
self._text_column = self._config.dataset.source_schema.input_column
self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column
else:
raise ValueError(
f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'."
)

# TODO: Add a new schema for preference datasets then drop class vars _loss_masking_spans_column & _text_column
if isinstance(source_schema, TextColumnConfig):
if self._text_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._text_column}'.")
if self._text_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._text_column}'.")

if self._config.dataset.source_schema.loss_masking_spans_column is not None and (
self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
):
raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.")
if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None):
raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")

# route tokenize function
if self._loss_masking_spans_column is not None:
if self._loss_masking_spans_column not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
tokenize_fn = self._tokenize_batch_with_spans
elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
if self._config.dataset.chosen_text not in dataset.column_names:
raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.")
if self._config.dataset.rejected_text not in dataset.column_names:
raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.")
tokenize_fn = self._tokenize_preference_batch_with_spans
else:
tokenize_fn = self._tokenize_batch

# Tokenize the dataset in parallel
tokenized_dataset = dataset.map(
Expand All @@ -368,6 +329,7 @@ def run(self) -> None:

# Calculate total number of tokens
total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens"))

# Split dataset into shards based on number of tokens
num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard))
shards = [
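
Restated as a standalone sketch (not repo code; the argument names mirror the config fields used in `run()` above), the routing rules are:

```python
def pick_tokenize_fn(loss_masking_spans_column, chosen_text, rejected_text) -> str:
    # Span-based loss masking and chosen/rejected (preference) masking are mutually exclusive.
    if loss_masking_spans_column is not None and (chosen_text is not None or rejected_text is not None):
        raise ValueError("Cannot enable both loss masking spans and chosen/rejected loss masking spans.")
    # The chosen and rejected columns must be provided together.
    if (chosen_text is None) != (rejected_text is None):
        raise ValueError("Both chosen and rejected loss masking spans must be specified if one is specified.")
    if loss_masking_spans_column is not None:
        return "_tokenize_batch_with_spans"
    if chosen_text is not None:
        return "_tokenize_preference_batch_with_spans"
    return "_tokenize_batch"


assert pick_tokenize_fn(None, None, None) == "_tokenize_batch"
```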
58 changes: 26 additions & 32 deletions fast_llm/engine/base_model/base_model.py
@@ -7,7 +7,6 @@

from fast_llm.config import Configurable
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import ParameterMeta, TensorMeta
@@ -20,11 +19,18 @@
class Module(torch.nn.Module, abc.ABC):
""" """

def forward(self, input_, kwargs):
"""
Run a forward pass for the module, with autograd support.
"""
raise NotImplementedError()
_is_setup: bool = False
_distributed: Distributed

def __init__(self, distributed_config: DistributedConfig):
self._distributed_config = distributed_config
super().__init__()

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
distributed.check_config(self._distributed_config)
self._distributed = distributed
self._is_setup = True


class Layer(Module):
@@ -39,9 +45,9 @@ def forward(


class Sequential(Layer):
def __init__(self, layers: list[Layer]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def __init__(self, distributed_config: DistributedConfig):
super().__init__(distributed_config)
self.layers = torch.nn.ModuleList(self.get_layers())

def __getitem__(self, item):
return self.layers[item]
Expand All @@ -59,6 +65,15 @@ def forward(
input_ = layer(input_, kwargs, losses, metrics)
return input_

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass

def setup(self, distributed: Distributed) -> None:
super().setup(distributed)
for layer in self.layers:
layer.setup(distributed)


@dataclasses.dataclass()
class LossDef:
@@ -71,29 +86,14 @@ class LossDef:
dtype: torch.dtype = torch.float32


class SequentialLayers(Sequential, abc.ABC):
# Small class defined to fix the MRO of BaseModel.__init__
def __init__(self):
super().__init__(self.get_layers())

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass


class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC):
config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig
_is_setup: bool = False
class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):

def __init__(
self,
config: BaseModelConfig,
distributed_config: DistributedConfig,
):
self._tensor_space: TensorSpace = TensorSpace(distributed_config)
config.setup_tensor_space(self._tensor_space)

super().__init__(config)
super().__init__(config, distributed_config)

for key, value in self.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
Expand All @@ -104,12 +104,6 @@ def __init__(
# TODO: Add basic handling (preprocessor) in this class.
self._reference_models: dict[str, "InferenceRunner"] = {}

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
distributed.check_config(self._tensor_space.distributed_config)
self._tensor_space.setup(distributed)
self._is_setup = True

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass
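
For orientation, a toy self-contained sketch of the new construction/setup flow; every class below is a stand-in, not the repo's:

```python
import abc


class ToyDistributed:
    # Stand-in for fast_llm.engine.distributed.distributed.Distributed.
    def check_config(self, distributed_config) -> None:
        pass


class ToyModule:
    _is_setup: bool = False

    def __init__(self, distributed_config):
        # Only the config is taken at construction; the Distributed object arrives in `setup()`.
        self._distributed_config = distributed_config

    def setup(self, distributed) -> None:
        assert not self._is_setup
        distributed.check_config(self._distributed_config)
        self._distributed = distributed
        self._is_setup = True


class ToySequential(ToyModule, abc.ABC):
    def __init__(self, distributed_config):
        super().__init__(distributed_config)
        self.layers = self.get_layers()  # layers come from the subclass, as in the new `Sequential`

    @abc.abstractmethod
    def get_layers(self) -> list:
        ...

    def setup(self, distributed) -> None:
        super().setup(distributed)
        for layer in self.layers:
            layer.setup(distributed)  # setup propagates through the layer tree


class ToyModel(ToySequential):
    def get_layers(self) -> list:
        return [ToyModule(self._distributed_config)]


model = ToyModel(distributed_config=object())
model.setup(ToyDistributed())
assert model._is_setup and model.layers[0]._is_setup
```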
7 changes: 2 additions & 5 deletions fast_llm/engine/base_model/config.py
@@ -6,7 +6,7 @@
from fast_llm.utils import compare_nested, log

if typing.TYPE_CHECKING:
from fast_llm.engine.config_utils.tensor_space import TensorSpace
import torch


@config_class()
Expand All @@ -18,9 +18,6 @@ class BaseModelConfig(Config):

_abstract = True

def setup_tensor_space(self, tensor_space: "TensorSpace") -> None:
raise NotImplementedError()

def compare_architecture(
self,
model_config: typing.Self,
@@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
pass

@abc.abstractmethod
def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
pass
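
A sketch of a conforming preprocessor under the updated signature. The abstract class these methods belong to is not shown in this hunk, so the stub below is duck-typed rather than inheriting from it; the position-ids logic is purely illustrative:

```python
import typing

import torch


class ExamplePreprocessor:
    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
        pass  # nothing to precompute in this stub

    def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
        # Mutates `kwargs` in place, matching the `-> None` contract.
        kwargs["position_ids"] = torch.arange(batch.size(-1), device=batch.device)
```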