From ec2c9fb5d01304a8b024c663129df7dd63ec431e Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 7 May 2025 22:37:45 +0000 Subject: [PATCH 01/25] initial data prep --- fast_llm/data/dataset/gpt/memmap.py | 12 ++++++ fast_llm/data/dataset/gpt/sampled.py | 2 + fast_llm/data/preparator/gpt_memmap/config.py | 6 +++ .../data/preparator/gpt_memmap/prepare.py | 34 ++++++++++------ fast_llm/data/tokenizer.py | 40 ++++++++++++------- 5 files changed, 68 insertions(+), 26 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 8651b8fcd..752c83cee 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -258,6 +258,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP image_lengths = [] im_positions = [] total_images = 0 + n_audio = [] + audio_lengths = [] + aud_positions = [] + total_audio = 0 pointers = [] offset = 0 # number of spans for each document @@ -295,6 +299,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) + if document.audio: + n_audio.append(len(document.audio)) + total_audio += len(document.audio) + for audio in document.audio: + audio_lengths.append(len(audio)) + bin_stream.write(audio.to_bytes(order="C")) + # total_aud_size += + aud_positions.append(document.audio_positions) # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0ba3f0e13..d700dcc00 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -33,6 +33,8 @@ class GPTSample: loss_masking_spans: np.ndarray | None = None images: np.ndarray | None = None image_positions: np.ndarray | None = None + audio: np.ndarray | None = None + audio_positions: np.ndarray | None = None sequence_lengths: np.ndarray | None = None diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 89fe904cd..a56d766a8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -68,6 +68,12 @@ class GPTHuggingfaceDatasetConfig(Config): images: None | str = Field( default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional ) + audio_positions: None | str = Field( + default=None, desc="Field containing audio positions within a document", hint=FieldHint.optional + ) + audio: None | str = Field( + default=None, desc="Field containing audio relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." 
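Note on the hunks above: together with the new GPTSample fields, the memmap writer now stores each document's raw waveforms in the .bin file after its token ids (and any image pixels), and records per-document audio counts, lengths and positions in the .idx file next to the existing image metadata. The following is a minimal sketch of feeding such a document to the writer; it assumes the dataset class in fast_llm/data/dataset/gpt/memmap.py is named GPTMemmapDataset, and the token ids, the 2-second 16 kHz waveform, and the output prefix are illustrative only.

import pathlib

import numpy as np

from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset  # class name assumed from the file path
from fast_llm.data.dataset.gpt.sampled import GPTSample

# Hypothetical document: a handful of token ids plus one 2-second, 16 kHz mono
# clip stored as a raw float32 waveform, anchored at character position 5.
document = GPTSample(
    token_ids=np.array([1, 15, 27, 42, 2], dtype=np.int32),
    audio=[np.zeros(2 * 16000, dtype=np.float32)],  # list of 1-D waveforms
    audio_positions=np.array([5], dtype=np.int32),
)

# The writer appends each waveform's bytes to the .bin stream after the token
# ids, and writes n_audio / audio_lengths / aud_positions into the index file.
GPTMemmapDataset.write_dataset(prefix=pathlib.Path("out/shard_0"), documents=[document])

The same fields drive the preparator: when GPTHuggingfaceDatasetConfig.audio and audio_positions name columns of the source dataset, the prepare step tokenizes around those positions and emits documents shaped like the one above.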
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 4965dfdfc..19858cbcf 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -50,22 +50,25 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, image_token_positions = map( + input_ids, image_token_positions, audio_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(image_token_positions, dtype=np.int32), + np.array(audio_token_positions, dtype=np.int32), ) - for input_ids, image_token_positions in [ + for input_ids, image_token_positions, audio_token_positions in [ self._tokenizer.tokenize( text, im_char_positions, + aud_char_positions, ) - for text, im_char_positions in zip( + for text, im_char_positions, aud_char_positions in zip( batch[self._config.dataset.field], batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + batch.get(self._config.dataset.audio_positions, itertools.repeat(None)), ) ] ] @@ -82,6 +85,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ return { "input_ids": input_ids, "image_positions": image_token_positions, + "audio_token_positions": audio_token_positions, "num_tokens": num_tokens, "num_pixels": num_pixels, } @@ -143,6 +147,8 @@ def _document_generator(): # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, + np.array(item[self._config.dataset.audio]) if self._config.dataset.audio else None, + item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): @@ -167,15 +173,19 @@ def _document_generator(): ) def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + try: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.loading_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) + except: + # backup if dataset is saved in arrow format (can we auto-detect this?) 
+ dataset = datasets.load_from_disk(dataset_path=self._config.dataset.data_directory) assert isinstance(dataset, datasets.Dataset) return dataset diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0e7d54709..98cfbb851 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,33 +42,45 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text, image_positions=None): - if not image_positions: + def tokenize(self, text, image_positions=None, audio_positions=None): + image_positions = image_positions or [] + audio_positions = audio_positions or [] + if len(set(image_positions).intersection(audio_positions)) > 0: + raise ValueError("Image and audio can not have the same position.") + multimodal_positions = sorted(image_positions + audio_positions) + if not multimodal_positions: return self._tokenize(text), [], [] - image_idx = 0 + multimodel_idx = 0 char_pos = 0 token_ids = [] image_token_positions = [] + audio_token_positions = [] beginning_of_text = True - while image_idx < len(image_positions): - if image_positions[image_idx] > len(text): + while multimodel_idx < len(multimodal_positions): + multimodal_char_pos = multimodal_positions[multimodel_idx] + multimodal_type = "image" if multimodal_char_pos in image_positions else "audio" + + if multimodal_char_pos > len(text): raise ValueError( - f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + f"{multimodal_type.capitalize()} position {multimodal_char_pos} is greater than text length {len(text)}" ) - curr_text = text[char_pos : image_positions[image_idx]] - tokenized_text = self._tokenize( - curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - ) + curr_text = text[char_pos:multimodal_char_pos] + tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=multimodal_char_pos >= len(text)) beginning_of_text = False token_ids.extend(tokenized_text) - image_token_positions = len(token_ids) - char_pos = image_positions[image_idx] - image_idx += 1 + + # store multimodal token positions + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + char_pos = multimodal_char_pos + multimodel_idx += 1 if char_pos < len(text): curr_text = text[char_pos:] tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) token_ids.extend(tokenized_text) - return token_ids, image_token_positions + return token_ids, image_token_positions, audio_token_positions def tokenize_with_spans( self, text: str, char_spans: list[tuple[int, int]] From 82e4edba3558df9531ae6b5226b44bc50b7c4952 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 9 May 2025 01:31:42 +0000 Subject: [PATCH 02/25] audio dataset changes --- fast_llm/data/data/gpt/data.py | 9 +++ fast_llm/data/dataset/gpt/config.py | 3 + fast_llm/data/dataset/gpt/indexed.py | 8 +- fast_llm/data/dataset/gpt/memmap.py | 105 ++++++++++++++++++++++++--- fast_llm/data/dataset/gpt/sampled.py | 14 +++- fast_llm/engine/schedule/config.py | 15 ++++ fast_llm/models/gpt/trainer.py | 3 + 7 files changed, 144 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 34b86f213..681d44437 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -92,6 +92,9 @@ def __init__( cross_document_attention: bool = True, patch_size: list[int] | None = None, max_image_size: 
int | None = None, + aud_downsampling_k: int | None = None, + aud_padding_duration: int | None = None, + aud_sampling_rate: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). @@ -103,6 +106,9 @@ def __init__( self._cross_document_attention = cross_document_attention self._patch_size = patch_size self._max_image_size = max_image_size + self._aud_downsampling_k = aud_downsampling_k + self._aud_padding_duration = aud_padding_duration + self._aud_sampling_rate = aud_sampling_rate def setup( self, @@ -152,6 +158,9 @@ def setup( cross_document_attention=self._cross_document_attention, patch_size=self._patch_size, image_size=self._max_image_size, + aud_downsampling_k=self._aud_downsampling_k, + aud_padding_duration=self._aud_padding_duration, + aud_sampling_rate=self._aud_sampling_rate, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 65adf0bda..aeb57ffe8 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -75,6 +75,9 @@ class GPTSamplingData(SamplingData): cross_document_attention: bool = True patch_size: int | None = None image_size: int | None = None + aud_downsampling_k: int | None = None + aud_padding_duration: int | None = None + aud_sampling_rate: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 209c6e317..1bbd30c78 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,8 +45,12 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] + doc_sizes, im_sizes, aud_sizes = self._dataset.get_document_sizes() + return ( + doc_sizes[self._begin : self._end], + im_sizes[self._begin : self._end], + aud_sizes[self._begin : self._end], + ) def get_document_size(self, index: int, patch_size: list[int]) -> int: return self._dataset.get_document_size(self._begin + index, patch_size) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 752c83cee..88e31d78c 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -59,6 +59,9 @@ def _init( if self._version >= 4: self._has_images = struct.unpack("= 5: + self._has_audio = struct.unpack("= 5: + self._n_audio = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._audio_lengths = [] + self._audio_positions = [] + audio_seen = 0 + + offset = offset + self._n_audio.nbytes + for n_audio in self._n_audio: + self._audio_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + audio_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + # self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._audio_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + + self._n_audio.sum() * np.dtype(np.int32).itemsize + + audio_seen * np.dtype(np.int32).itemsize, + ) + ) + audio_seen += n_audio self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -193,8 +227,30 @@ def get( n_pixels = image_length.prod(initial=3) images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels + + if self._has_audio: + audio_positions = self._audio_positions[idx] + all_audio = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.float32), + count=self._audio_lengths[idx].sum(), + offset=self._pointers[idx] + + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + + self._image_lengths.prod(initial=3) * np.dtype(np.uint8).itemsize, + ) + audio = [] + start = 0 + for audio_length in self._audio_lengths[idx]: + audio.append(all_audio[start : start + audio_length]) + start += audio_length # TODO Soham: return loss_masking_spans - return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + return GPTSample( + token_ids=token_ids, + images=images, + image_positions=image_positions, + audio=audio, + audio_positions=audio_positions, + ) # def get( # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False @@ -231,6 +287,10 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images + @property + def has_audio(self) -> bool: + return self._has_audio + # TODO: image sizes def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ @@ -238,7 +298,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
""" - return self._document_sizes, self._image_lengths + return self._document_sizes, self._image_lengths, self._audio_lengths def get_document_size(self, index: int, patch_size: list[int]) -> int: # return self._document_sizes[index].item() + ( @@ -246,7 +306,10 @@ def get_document_size(self, index: int, patch_size: list[int]) -> int: # if self._has_images # else 0 # ) - return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] + docsize = self._document_sizes[index].item() + imagesize = self._image_lengths[index] if self._has_images else [] + audiosize = self._audio_lengths if self._has_audio else 0 + return docsize, imagesize, audiosize @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -285,6 +348,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) total_im_size = 0 + total_aud_size = 0 if document.images: n_images.append(len(document.images)) total_images += len(document.images) @@ -299,13 +363,13 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) - if document.audio: + if document.audio is not None: n_audio.append(len(document.audio)) total_audio += len(document.audio) for audio in document.audio: audio_lengths.append(len(audio)) - bin_stream.write(audio.to_bytes(order="C")) - # total_aud_size += + bin_stream.write(audio.tobytes(order="C")) + total_aud_size += audio.size aud_positions.append(document.audio_positions) # Update metadata @@ -315,7 +379,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize + offset += ( + doc_length * np.dtype(dtype).itemsize + + total_im_size * np.dtype(np.uint8).itemsize + + total_aud_size * np.dtype(np.float32).itemsize + ) num_documents += 1 # Finalize metadata arrays @@ -329,10 +397,21 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.stack(image_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) + image_lengths = np.array([]) + im_positions = np.array([]) + + if total_audio: + n_audio = np.array(n_audio, dtype=np.int32) + audio_lengths = np.array(audio_lengths, dtype=np.int32) + aud_positions = np.array(aud_positions, dtype=np.int32) + else: + n_audio = np.array([]) + audio_lengths = np.array([]) + aud_positions = np.array([]) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: @@ -340,7 +419,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Indicates the version # Version 2 onwards optionally add loss-masking spans # Version 4 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) # Placeholder flag for preference spans @@ -367,5 +446,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP 
idx_stream.write(image_lengths.tobytes(order="C")) # Position of each image in the document idx_stream.write(im_positions.tobytes(order="C")) + # Number of audio per document + idx_stream.write(n_audio.tobytes(order="C")) + # Audio lengths + idx_stream.write(audio_lengths.tobytes(order="C")) + # Position of each audio in the document + idx_stream.write(aud_positions.tobytes(order="C")) # Document indices, unused but needed for compatibility with Megatron-LM idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index d700dcc00..d2a530660 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -94,6 +94,9 @@ def __init__( self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size self._image_size = sampling.image_size + self._aud_downsampling_k = sampling.aud_downsampling_k + self._aud_padding_duration = sampling.aud_padding_duration + self._aud_sampling_rate = sampling.aud_sampling_rate self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents @@ -138,13 +141,22 @@ def _sample(self) -> None: """ # Get the document sizes, the main information needed for sampling. # TODO Soham: verify numpy correctness - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) + # compute audio token sizes + if self._aud_padding_duration > 0 and len(audio_sizes) > 0: + self._aud_padding_duration * self._aud_sampling_rate + # 2. mel spectogram + + # 3. convolution + + # 4. downsampling + documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 517a9cff5..0f692482e 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -66,6 +66,21 @@ class BatchConfig(Config): desc="Applies attention to tokens from other documents in the packed sequence. 
Set to False for masking attention to other documents.", hint=FieldHint.feature, ) + aud_downsampling_k: int = Field( + default=5, + desc="Audio downsampling k parameter.", + hint=FieldHint.feature, + ) + aud_padding_duration: int = Field( + default=-1, + desc="Audio padding duration in seconds.", + hint=FieldHint.feature, + ) + aud_sampling_rate: int = Field( + default=16000, + desc="Audio sampling rate to use.", + hint=FieldHint.feature, + ) _distributed: DistributedConfig = Field( init=False, desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index bc16829b3..e2ce3fd9f 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -23,6 +23,9 @@ def _get_data(self) -> GPTData: cross_document_attention=self._config.batch.cross_document_attention, patch_size=self._config.batch.patch_size, max_image_size=self._config.batch.max_image_size, + aud_downsampling_k=self._config.batch.aud_downsampling_k, + aud_padding_duration=self._config.batch.aud_padding_duration, + aud_sampling_rate=self._config.batch.aud_sampling_rate, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 0d1cd96984612c22e012870ef8db092d65aa8a2a Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 9 May 2025 18:48:49 +0000 Subject: [PATCH 03/25] audio token computation --- fast_llm/data/dataset/gpt/sampled.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 25fba8b7e..7b70d598d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -138,18 +138,27 @@ def _sample(self) -> None: document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) # compute audio token sizes - if self._aud_padding_duration > 0 and len(audio_sizes) > 0: - self._aud_padding_duration * self._aud_sampling_rate - # 2. mel spectogram + audio_sizes = torch.tensor(audio_sizes) - # 3. convolution + # account for padding + if len(audio_sizes) > 0 and self._parameters.aud_padding_duration > 0: + raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate + audio_sizes.fill_(raw_audio_seq_length) # set all audio sizes to padded amount + long_audio_filter = audio_sizes > raw_audio_seq_length # filter audio that is too long + else: + audio_sizes > self._parameters.sequence_length + 1 - # 4. 
downsampling + # account for mel spectogram, convolution, downsampling k + audio_token_sizes = audio_sizes / 160 # default hop length + audio_token_sizes = audio_token_sizes // ( + 2 * self._parameters.aud_downsampling_k + ) # convolution (2) * downsampling documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() From 40f3882d6f7738238ad53579edc517e29f17d3f2 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Sat, 10 May 2025 16:52:12 +0000 Subject: [PATCH 04/25] implement mm packing --- fast_llm/data/dataset/gpt/memmap.py | 42 ++++--- fast_llm/data/dataset/gpt/sampled.py | 165 ++++++++++++++++++++------- 2 files changed, 151 insertions(+), 56 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 88e31d78c..cd4bf7b7f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -49,7 +49,7 @@ def _init( with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._image_lengths = [] - self._image_positions = [] images_seen = 0 # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: @@ -141,12 +141,12 @@ def _init( ) images_seen += n_images offset = offset + self._n_images.nbytes + 3 * self._n_images.sum() * np.dtype(np.int32).itemsize + self._audio_lengths = [] + self._audio_positions = [] if self._has_audio and self._version >= 5: self._n_audio = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._audio_lengths = [] - self._audio_positions = [] audio_seen = 0 offset = offset + self._n_audio.nbytes @@ -157,7 +157,7 @@ def _init( dtype=np.int32, count=n_audio, offset=offset + audio_seen * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) + ) ) # self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() self._audio_positions.append( @@ -177,11 +177,13 @@ def _init( # TODO Soham: fix num_tokens to include images. 
Get total number of image pixels from index file and assign # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + + # TODO Toby: Add audio num tokens check self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) - if num_pixels is not None: - assert self._num_pixels == num_pixels - if num_tokens is not None: - assert self._num_tokens == num_tokens + # if num_pixels is not None: + # assert self._num_pixels == num_pixels + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) @@ -212,6 +214,8 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = [] + image_positions = np.array([]) if self._has_images: image_positions = self._image_positions[idx] pixels = np.frombuffer( @@ -220,7 +224,6 @@ def get( count=self._image_lengths[idx].prod(initial=3), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) - images = [] start = 0 for image_length in self._image_lengths[idx]: # TODO Soham: verify reshape dimension order @@ -228,17 +231,19 @@ def get( images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels + audio = [] + audio_positions = np.array([]) if self._has_audio: audio_positions = self._audio_positions[idx] + offset = self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + if len(self._image_lengths) > 0: + offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( self._bin_buffer, dtype=np.dtype(np.float32), count=self._audio_lengths[idx].sum(), - offset=self._pointers[idx] - + self._document_sizes[idx] * np.dtype(self._dtype).itemsize - + self._image_lengths.prod(initial=3) * np.dtype(np.uint8).itemsize, + offset=offset, ) - audio = [] start = 0 for audio_length in self._audio_lengths[idx]: audio.append(all_audio[start : start + audio_length]) @@ -308,7 +313,7 @@ def get_document_size(self, index: int, patch_size: list[int]) -> int: # ) docsize = self._document_sizes[index].item() imagesize = self._image_lengths[index] if self._has_images else [] - audiosize = self._audio_lengths if self._has_audio else 0 + audiosize = self._audio_lengths[index] if self._has_audio else [] return docsize, imagesize, audiosize @classmethod @@ -370,7 +375,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP audio_lengths.append(len(audio)) bin_stream.write(audio.tobytes(order="C")) total_aud_size += audio.size - aud_positions.append(document.audio_positions) + if len(document.audio) > 0: + aud_positions.append(document.audio_positions) # Update metadata doc_length = len(document.token_ids) @@ -426,6 +432,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether audio is present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" 0: + raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate + sizes = sizes.copy() # original is read-only + to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long + 
sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount + + # account for mel spectogram, convolution, downsampling k + audio_token_size_arr = sizes // 160 # default hop length TODO: check divisible? + audio_token_size_arr = audio_token_size_arr // ( + 2 * self._parameters.aud_downsampling_k + ) # convolution (2) * downsampling + return audio_token_size_arr, to_filter + + def apply_audio_padding(self, audio): + if len(audio) == 0: + return audio + # TODO Toby: check 2d + padded_audio = [] + if self._parameters.aud_padding_duration > 0: + raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate + for aud in audio: + padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) + padded_audio.append(padded) + return padded_audio + else: + return audio + def _sample(self) -> None: """ Create a `GPTSampledDataset` with the requested parameters. @@ -139,29 +175,22 @@ def _sample(self) -> None: document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) + long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) - - # compute audio token sizes - audio_sizes = torch.tensor(audio_sizes) - - # account for padding - if len(audio_sizes) > 0 and self._parameters.aud_padding_duration > 0: - raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate - audio_sizes.fill_(raw_audio_seq_length) # set all audio sizes to padded amount - long_audio_filter = audio_sizes > raw_audio_seq_length # filter audio that is too long - else: - audio_sizes > self._parameters.sequence_length + 1 + image_token_sizes[i] = sum( + (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + ) - # account for mel spectogram, convolution, downsampling k - audio_token_sizes = audio_sizes / 160 # default hop length - audio_token_sizes = audio_token_sizes // ( - 2 * self._parameters.aud_downsampling_k - ) # convolution (2) * downsampling + for i, sizes in enumerate(audio_sizes): + audio_token_size_arr, to_filter = self._compute_audio_token_size(sizes) + audio_token_sizes[i] = audio_token_size_arr.sum() + long_audio_filter[i] = to_filter documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() + tokens_per_epoch = ( + document_sizes.sum().item() + image_token_sizes.sum().item() + audio_token_sizes.sum().item() + ) # Calculate basic stats. if not self._truncate_documents: @@ -169,14 +198,31 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
) - long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 + long_docs_filter = ( + document_sizes + image_token_sizes + audio_token_sizes > self._parameters.sequence_length + 1 + ) ignored_documents = sum(long_docs_filter) if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() + ignored_audio_samples = sum(long_audio_filter) + if ignored_audio_samples: + log_main_rank( + f" > {ignored_audio_samples}/{documents_per_epoch} samples contain audio longer than {self._parameters.aud_padding_duration} seconds and will be ignored.", + log_fn=logger.warning, + ) + long_docs_filter = long_docs_filter | long_audio_filter + tokens_per_epoch = ( + ( + document_sizes[~long_docs_filter] + + image_token_sizes[~long_docs_filter] + + audio_token_sizes[~long_docs_filter] + ) + .sum() + .item() + ) if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -215,7 +261,7 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, - "patch_size": self._patch_size, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } @@ -298,7 +344,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes, + document_sizes + image_token_sizes + audio_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -308,7 +354,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens + yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) @@ -323,7 +369,14 @@ def _sample(self) -> None: ) ] + image_token_sizes[ - document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ] + + audio_token_sizes[ + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) ], offset=self._unshuffled_tokens, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -416,8 +469,10 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] images = [] + audio = [] image_positions = [] - image_tokens_added = 0 + audio_positions = [] + mm_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
if document_sampling_index < self._unshuffled_documents: @@ -425,29 +480,40 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths, audio_lengths = self._indexed_dataset.get_document_size( + document_index, self._parameters.patch_size + ) image_sizes = [ get_num_patches( - *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), - self._patch_size, + *get_resize_dims(*image_length, self._image_size, self._image_size, self._parameters.patch_size), + self._parameters.patch_size, ) for image_length in image_lengths ] image_tokens = sum(image_sizes) + audio_token_size_arr, _ = self._compute_audio_token_size(audio_lengths) + audio_tokens = audio_token_size_arr.sum() + if not self._truncate_documents: - if document_size + image_tokens > self._parameters.sequence_length + 1: + if document_size + image_tokens + audio_tokens > self._parameters.sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + image_tokens + tokens_in_sample > self._parameters.sequence_length + 1: + if ( + document_size + image_tokens + audio_tokens + tokens_in_sample + > self._parameters.sequence_length + 1 + ): # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + try: + token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + except: + pass Assert.eq(token_count + padding_size, token_end) break else: @@ -455,7 +521,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size + image_tokens >= token_start: + if token_count + document_size + image_tokens + audio_tokens >= token_start: # Determine which part of the document belong to the sample, and add it to the list. 
token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -466,16 +532,32 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - for idx, im_position in enumerate(sample.image_positions): + multimodal_positions = np.concatenate( + [sample.image_positions.astype(np.int32), sample.audio_positions.astype(np.int32)] + ) + multimodal_positions.sort() + for idx, mm_position in enumerate(multimodal_positions): + if mm_position in sample.image_positions: # TODO Toby: use enum + mm_type = "image" + elif mm_position in sample.audio_positions: + mm_type = "audio" + else: + assert False # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens - start_pos = im_position + token_ids.append(sample.token_ids[start_pos:mm_position]) + if mm_type == "image": + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_positions.append(mm_position + len(token_ids) + mm_tokens_added) + mm_tokens_added += image_tokens + elif mm_type == "audio": + token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) + audio_positions.append(mm_position + mm_tokens_added) + mm_tokens_added += audio_tokens + start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) images.append(sample.images) + audio.append(self.apply_audio_padding(sample.audio)) # TODO Soham: add offsets for loss masking spans if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: @@ -489,7 +571,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. 
document_sampling_index += 1 - token_count += document_size + image_tokens + token_count += document_size + image_tokens + audio_tokens sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) @@ -504,6 +586,9 @@ def __getitem__(self, index: int) -> typing.Any: ) images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None + + audio = [aud for aud_list in audio for aud in aud_list] if audio else None + audio_positions = np.array(audio_positions) if audio_positions else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample( @@ -512,6 +597,8 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, + audio=audio, + audio_positions=audio_positions, ) @property From 94e439c9c3b5b2a2d486f2a940fc5947c9bd6e22 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 15 May 2025 22:53:23 +0000 Subject: [PATCH 05/25] data updates --- fast_llm/data/data/gpt/data.py | 21 +++++++++++- fast_llm/data/dataset/gpt/memmap.py | 8 ++--- fast_llm/data/dataset/gpt/sampled.py | 50 ++++++++++++++++------------ 3 files changed, 53 insertions(+), 26 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 4fcd42ae1..0e43ec2b0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -34,6 +34,8 @@ class GPTBatch: sequence_lengths: list[torch.Tensor] | None = None images: list[torch.Tensor] | None = None image_positions: list[torch.Tensor] | None = None + audio: list[torch.Tensor] | None = None + audio_positions: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: @@ -54,16 +56,33 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_images.append(None) batch_image_positions = [] for sample in batch: - if sample.image_positions is not None: + if sample.image_positions is not None and len(sample.image_positions) > 0: batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: batch_image_positions.append(None) + + has_audio = False + batch_audio = [] + for sample in batch: + if sample.audio is not None and len(sample.audio_positions) > 0: + batch_audio.append([torch.from_numpy(image) for image in sample.audio]) + has_audio = True + else: + batch_audio.append(None) + batch_audio_positions = [] + for sample in batch: + if sample.audio_positions is not None: + batch_audio_positions.append(torch.from_numpy(sample.audio_positions)) + else: + batch_audio_positions.append(None) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, images=batch_images if has_images else None, image_positions=batch_image_positions if has_images else None, + audio=batch_audio if has_audio else None, + audio_positions=batch_image_positions if has_audio else None, ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 619c56242..50d4b4165 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -111,8 +111,8 @@ def _init( + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 - self._image_lengths = None - self._image_positions = None + self._image_lengths = [] + self._image_positions = [] if 
self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset @@ -266,7 +266,7 @@ def get( if self._has_audio: audio_positions = self._audio_positions[idx] offset = self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize - if len(self._image_lengths) > 0: + if self._has_images and len(self._image_lengths) > 0: offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( self._bin_buffer, @@ -340,7 +340,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ return self._document_sizes, self._image_lengths, self._audio_lengths - def get_document_size(self, index: int, patch_size: list[int]) -> int: + def get_document_size(self, index: int) -> int: # return self._document_sizes[index].item() + ( # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) # if self._has_images diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index d5c8fc4b8..5074f1a09 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -169,26 +169,26 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = [] + image_token_sizes = torch.zeros_like(document_sizes).to(self._device) for i, sizes in enumerate(image_sizes): - image_token_sizes.append( - sum( - get_num_patches( - *get_resize_dims( - *size, - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ), + image_token_sizes[i] = sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, self._parameters.patch_size, - ) - for size in sizes + ), + self._parameters.patch_size, ) + for size in sizes ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + # image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) + long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # longer than audio padding for i, sizes in enumerate(audio_sizes): audio_token_size_arr, to_filter = self._compute_audio_token_size(sizes) audio_token_sizes[i] = audio_token_size_arr.sum() @@ -502,17 +502,22 @@ def __getitem__(self, index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) - document_size = text_size + image_tokens audio_token_size_arr, _ = self._compute_audio_token_size(audio_lengths) audio_tokens = audio_token_size_arr.sum() + document_size = text_size + image_tokens + audio_tokens + if not self._truncate_documents: + # Document too long, ignore if document_size > self._parameters.sequence_length + 1: - # Document too long, ignore document_sampling_index += 1 continue + + # Where are we currently in sample? tokens_in_sample = token_count % (self._parameters.sequence_length + 1) + + # Add padding if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. 
padding_size = self._parameters.sequence_length + 1 - tokens_in_sample @@ -540,6 +545,8 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 + + # add tokens and multi modal padding placeholders multimodal_positions = np.concatenate( [sample.image_positions.astype(np.int32), sample.audio_positions.astype(np.int32)] ) @@ -552,7 +559,7 @@ def __getitem__(self, index: int) -> typing.Any: else: assert False # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens + # Add placeholders for image and audio tokens tokens token_ids.append(sample.token_ids[start_pos:mm_position]) if mm_type == "image": token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) @@ -560,7 +567,7 @@ def __getitem__(self, index: int) -> typing.Any: mm_tokens_added += image_tokens elif mm_type == "audio": token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) - audio_positions.append(mm_position + mm_tokens_added) + audio_positions.append(len(token_ids)) mm_tokens_added += audio_tokens start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) @@ -593,12 +600,13 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] if images else None - image_positions = np.array(image_positions) if image_positions else None + # images = [im for img_list in images for im in img_list] if images else None + # image_positions = np.array(image_positions) if image_positions else None + images = None audio = [aud for aud_list in audio for aud in aud_list] if audio else None audio_positions = np.array(audio_positions) if audio_positions else None - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample( token_ids=token_ids, From 543fc0d53026f82bd0735f872e4109f5fea8f7fa Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 16 May 2025 20:52:23 +0000 Subject: [PATCH 06/25] changes --- fast_llm/data/dataset/gpt/memmap.py | 2 +- fast_llm/data/tokenizer.py | 6 +++--- fast_llm/engine/schedule/config.py | 11 ----------- fast_llm/layers/language_model/config.py | 8 +++++++- fast_llm/layers/transformer/config.py | 13 +++++++++++++ fast_llm/models/gpt/model.py | 22 ++++++++++++++++------ fast_llm/models/gpt/trainer.py | 4 ++-- 7 files changed, 42 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 50d4b4165..d63653e61 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -411,7 +411,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(audio.tobytes(order="C")) total_aud_size += audio.size if len(document.audio) > 0: - aud_positions.append(document.audio_positions) + aud_positions += document.audio_positions # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c6d7a51a2..cccf59856 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -73,7 +73,7 @@ def tokenize( token_ids.extend(tokenized_text) # update mm token positions - multimodal_type = "image" if multimodal_position in multimodal_positions else "audio" + multimodal_type = "image" if multimodal_position in 
image_positions else "audio" if multimodal_type == "image": image_token_positions.append(len(token_ids)) else: @@ -104,7 +104,7 @@ def tokenize( token_ids.extend(tokenized_text) # update mm token positions - multimodal_type = "image" if multimodal_position in multimodal_positions else "audio" + multimodal_type = "image" if multimodal_position in image_positions else "audio" if multimodal_type == "image": image_token_positions.append(len(token_ids)) else: @@ -141,7 +141,7 @@ def tokenize( token_ids.extend(tokenized_text) # update mm token positions - multimodal_type = "image" if multimodal_position in multimodal_positions else "audio" + multimodal_type = "image" if multimodal_position in image_positions else "audio" if multimodal_type == "image": image_token_positions.append(len(token_ids)) else: diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 7412beb0d..13aee09a6 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -60,22 +60,11 @@ class BatchConfig(Config): desc="Maximum image height and width", hint=FieldHint.optional, ) - # Audio inputs - aud_downsampling_k: int = Field( - default=5, - desc="Audio downsampling k parameter.", - hint=FieldHint.feature, - ) aud_padding_duration: int = Field( default=-1, desc="Audio padding duration in seconds.", hint=FieldHint.feature, ) - aud_sampling_rate: int = Field( - default=16000, - desc="Audio sampling rate to use.", - hint=FieldHint.feature, - ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 78de218f1..44bc5f30f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert @@ -47,11 +48,16 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, ) # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) - vision_encoder: VisionEncoderConfig = Field( + vision_encoder: VisionEncoderConfig | None = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) + audio_encoder: AudioEncoderConfig | None = Field( + default_factory=AudioEncoderConfig, + desc="Configuration for the audio encoder that transforms audio into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 38dc9ec48..40a29959b 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -765,3 +765,16 @@ class VisionTransformerConfig(TransformerConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.feature, ) + + +@config_class() +class AudioTransformerConfig(TransformerConfig): + """ + Configuration for the Audio Transformer model. 
+ """ + + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b832f1b04..b4bbf736e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,6 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from fast_llm.layers.audio_encoder.preprocessing import AudioPreprocessor from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead @@ -82,11 +84,8 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - # if self._config.vision_encoder.transformer.rotary.enabled: - # self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - # self._config.vision_encoder.transformer.rotary, self._tensor_space - # ) + if self._config.audio_encoder: + self._preprocessors.append(AudioPreprocessor(self._config.audio_encoder, self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] @@ -418,6 +417,17 @@ def preprocess( kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions kwargs[LanguageModelKwargs.tokens] = tokens + if batch.audio is not None: + kwargs[AudioEncoderKwargs.audio] = [ + [ + aud.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for aud in audio + ] + for audio in batch.audio + ] + kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions + kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) @@ -448,7 +458,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: if self._config.tie_word_embeddings: return { WORD_EMBEDDINGS_WEIGHT: ( - self.embedding.word_embeddings_weight, + self.layers[0].word_embeddings_weight, (self._config.vision_encoder is not None, *self.model_head_indices), ) } diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 66587b7c8..ed912ec40 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -32,9 +32,9 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.prediction_heads, "patch_size": self._config.batch.patch_size, "image_size": self._config.batch.image_size, - "aud_downsampling_k": self._config.batch.aud_downsampling_k, + "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, "aud_padding_duration": self._config.batch.aud_padding_duration, - "aud_sampling_rate": self._config.batch.aud_sampling_rate, + "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From 1a20913589ebd48f6d50ad324a7eacd46868a0e4 Mon Sep 17 00:00:00 
2001 From: Toby Liang Date: Tue, 20 May 2025 23:13:53 +0000 Subject: [PATCH 07/25] layer changes --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 10 +- fast_llm/data/dataset/gpt/sampled.py | 16 +- fast_llm/data/dataset/monitor.py | 22 +- .../data/preparator/gpt_memmap/prepare.py | 6 +- fast_llm/layers/audio_encoder/adapter.py | 54 ++++ fast_llm/layers/audio_encoder/config.py | 143 +++++++++++ fast_llm/layers/audio_encoder/encoder.py | 61 +++++ .../layers/audio_encoder/preprocessing.py | 47 ++++ fast_llm/layers/language_model/config.py | 4 +- .../layers/transformer/audio_transformer.py | 41 +++ fast_llm/layers/transformer/config.py | 20 +- fast_llm/models/gpt/config.py | 8 + fast_llm/models/gpt/conversion.py | 239 ++++++++++++++++++ fast_llm/models/gpt/model.py | 35 ++- 15 files changed, 680 insertions(+), 28 deletions(-) create mode 100644 fast_llm/layers/audio_encoder/adapter.py create mode 100644 fast_llm/layers/audio_encoder/config.py create mode 100644 fast_llm/layers/audio_encoder/encoder.py create mode 100644 fast_llm/layers/audio_encoder/preprocessing.py create mode 100644 fast_llm/layers/transformer/audio_transformer.py diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 0e43ec2b0..028f008a4 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,7 +65,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_audio = [] for sample in batch: if sample.audio is not None and len(sample.audio_positions) > 0: - batch_audio.append([torch.from_numpy(image) for image in sample.audio]) + batch_audio.append([torch.from_numpy(audio) for audio in sample.audio]) has_audio = True else: batch_audio.append(None) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 51fd4cc24..9ea397072 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -261,11 +261,17 @@ def get( images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels - audio = None + audio = [] audio_positions = None if self._has_audio: audio_positions = self._audio_positions[idx] - offset = self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + # increment offset by documents and images + offset = ( + self._pointers[idx] + + offset * np.dtype(self._dtype).itemsize + + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + ) + if self._has_images and len(self._image_lengths) > 0: offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 45ddeb86f..b1cdf8262 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -554,13 +554,19 @@ def __getitem__(self, index: int) -> typing.Any: # add tokens and multi modal padding placeholders multimodal_positions = np.concatenate( - [sample.image_positions.astype(np.int32), sample.audio_positions.astype(np.int32)] - ) + [ + arr.astype(np.int32) + for arr in (sample.image_positions, sample.audio_positions) + if arr is not None + ] + ) or np.array([], dtype=np.int32) multimodal_positions.sort() for idx, mm_position in enumerate(multimodal_positions): - if mm_position in sample.image_positions: # TODO Toby: use enum + if ( + sample.image_positions is not None and mm_position in sample.image_positions + ): # TODO Toby: use enum mm_type = 
"image" - elif mm_position in sample.audio_positions: + elif sample.audio_positions is not None and mm_position in sample.audio_positions: mm_type = "audio" else: assert False @@ -572,8 +578,8 @@ def __getitem__(self, index: int) -> typing.Any: image_positions.append(mm_position + len(token_ids) + mm_tokens_added) mm_tokens_added += image_tokens elif mm_type == "audio": + audio_positions.append(sum(t.size for t in token_ids)) token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) - audio_positions.append(len(token_ids)) mm_tokens_added += audio_tokens start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080fe..53df3add1 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -35,18 +35,16 @@ def __len__(self) -> int: def __getitem__(self, idx) -> typing.Any: start_time = time.perf_counter() - try: - sample = self._dataset[idx] - sample_time = (time.perf_counter() - start_time) * 1000 - if sample_time > self._data_sample_warn_time_ms: - logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" - ) - return sample - - except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") - raise + # try: + sample = self._dataset[idx] + sample_time = (time.perf_counter() - start_time) * 1000 + if sample_time > self._data_sample_warn_time_ms: + logger.warning(f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load") + return sample + + # except Exception as e: + # logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + # raise @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c697d54dc..bffa6c834 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -150,7 +150,11 @@ def _document_generator(): # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, - np.array(item[self._config.dataset.audio]) if self._config.dataset.audio else None, + ( + np.array(item[self._config.dataset.audio], dtype=np.float32) + if self._config.dataset.audio + else None + ), item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py new file mode 100644 index 000000000..4f77971e8 --- /dev/null +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -0,0 +1,54 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ + + +class AudioAdapter(Layer): + """ + Vision adapter layer for the LLM. 
+ """ + + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self._activation_type = config.adapter_activation_type + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? + self.layer_1 = Linear( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = Linear( + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py new file mode 100644 index 000000000..52a8673e4 --- /dev/null +++ b/fast_llm/layers/audio_encoder/config.py @@ -0,0 +1,143 @@ +import enum + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.transformer.config import AudioTransformerConfig + + +class AudioEncoderDimNames: + in_channels = "audio_in_channels" + out_channels = "audio_out_channels" + kernel_size = "audio_kernel_size" + adapter_size = "audio_adapter_size" + audio_channels = "audio_kv_channels" + + +class AudioTransformerDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "audio_batch" + # TODO: Distinguish micro-sequence? 
+ sequence_q = "audio_sequence_q" + sequence_q_tp = "audio_sequence_q_tp" + sequence_k = "audio_sequence_k" + hidden = "audio_hidden" + # Self-attention dimensions + head_groups = "audio_head_groups" + group_heads = "audio_group_heads" + key_and_value = "audio_key_value" + kv_channels = "audio_kv_channels" + composite_heads = "audio_composite_heads" + composite_query = "audio_composite_query" + composite_key_value = "audio_composite_key_value" + composite_dense = "audio_composite_dense" + # MLP dimensions + mlp = "audio_mlp" + gate_and_up = "audio_gate_and_up" + composite_gated_mlp = "audio_composite_gated_mlp" + experts = "audio_experts" + top_experts = "audio_top_experts" + shared_experts = "audio_shared_experts" + unshared_experts = "audio_unshared_experts" + composite_expert_mlp = "audio_composite_expert_mlp" + composite_gated_expert_mlp = "audio_composite_gated_expert_mlp" + composite_shared_expert_mlp = "audio_composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "audio_composite_gated_shared_expert_mlp" + + +class AudioEncoderKwargs: + audio = "audio" + audio_mel = "audio_mel" + audio_positions = "audio_positions" + kv_channels = "audio_kv_channels" + hidden_dims = "audio_hidden_dims" + + +class AudioEncoderType(str, enum.Enum): + none = "none" + whisper = "whisper" + + +# # TODO Toby: do we need all of them? +class AudioTransformerKwargs: + rotary_freq_q = "audio_rotary_freq_q" + rotary_freq_k = "audio_rotary_freq_k" + attention_mask = "audio_attention_mask" + attention_mask_value = "audio_attention_mask_value" + sequence_lengths = "audio_sequence_lengths" + cu_seqlens_q = "audio_cu_seqlens_q" + cu_seqlens_k = "audio_cu_seqlens_k" + max_seqlen_q = "audio_max_seqlen_q" + max_seqlen_k = "audio_max_seqlen_k" + # TODO: Review these + presents = "audio_presents" + past_key_values = "audio_past_key_values" + sequence_first = "audio_sequence_first" + hidden_dims = "audio_hidden_dims" + sequence_q_dim = "audio_sequence_q_dim" + sequence_k_dim = "audio_sequence_k_dim" + sequence_length = "audio_sequence_length" + micro_batch_size = "audio_micro_batch_size" + # TODO: Move + grad_output = "audio_grad_output" + patch_position_ids = "patch_position_ids" + + +@config_class() +class AudioEncoderConfig(BaseModelConfig): + _abstract = False + + transformer: AudioTransformerConfig = Field( + default_factory=AudioTransformerConfig, + desc="Configuration for the audio transformer architecture.", + hint=FieldHint.core, + ) + type: AudioEncoderType = Field( + default=AudioEncoderType.none, + desc="Type of the audio encoder. Choices: none, whisper.", + hint=FieldHint.architecture, + ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", + hint=FieldHint.core, + ) + aud_downsampling_k: int = Field( + default=5, + desc="Audio downsampling k parameter.", + hint=FieldHint.feature, + ) + aud_sampling_rate: int = Field( + default=16000, + desc="Audio sampling rate to use.", + hint=FieldHint.feature, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels)) + # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim( + AudioEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + ) + ) + self.transformer.setup_tensor_space(tensor_space, type="audio") + + @property + def enabled(self) -> bool: + return self.type != AudioEncoderType.none diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py new file mode 100644 index 000000000..8cd071bd4 --- /dev/null +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -0,0 +1,61 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames, AudioEncoderKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class AudioConv(Layer): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + # TODO Toby: lr_scale + self.conv1_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + ) + self.conv1_stride = 1 + + self.conv2_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), # in/out channels are the same + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + ) + self.conv2_stride = 2 + + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),) + ) + else: + self.bias = None + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[AudioEncoderKwargs.hidden_dims] + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) + input_ = torch.nn.functional.conv1d(input_, self.conv1_weight, self.bias, stride=self.conv1_stride) + input_ = torch.nn.functional.gelu(input_) + input_ = torch.nn.functional.conv1d(input_, self.conv2_weight, self.bias, stride=self.conv2_stride) + input_ = torch.nn.functional.gelu(input_) + + # TODO Toby: add learned position embeddings and dropout + audio_embeddings = audio_embeddings.reshape(*(x.size for x in hidden_dims)) + + return audio_embeddings diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py new 
file mode 100644 index 000000000..54bfeef65 --- /dev/null +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -0,0 +1,47 @@ +import typing + +import torch +from torchaudio.transforms import MelSpectrogram + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs + +# from transformers import WhisperFeatureExtractor + + +class AudioPreprocessor(Preprocessor): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + # self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) + + self.mel_transform = MelSpectrogram( + sample_rate=self._config.aud_sampling_rate, + n_fft=400, + win_length=400, + hop_length=160, + n_mels=80, + f_min=0.0, + f_max=8000.0, + mel_scale="slaney", + norm="slaney", + center=True, + power=2.0, + ) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + pass + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + audio_raw = kwargs[AudioEncoderKwargs.audio] + # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") + self.mel_transform.to(self._tensor_space.distributed.device) + + audio_mel = [] + for batch in audio_raw: + batch_stacked = torch.stack(batch).unsqueeze(1) + audio_mel.append(self.mel_transform(batch_stacked)) + kwargs[AudioEncoderKwargs.audio_mel] = torch.cat(audio_mel) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ea72de5c4..625e5da65 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -48,12 +48,12 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, ) # TODO Soham: make this None by default. 
Need to figure out how to handle this in the config (see ) - vision_encoder: VisionEncoderConfig | None = Field( + vision_encoder: VisionEncoderConfig = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) - audio_encoder: AudioEncoderConfig | None = Field( + audio_encoder: AudioEncoderConfig = Field( default_factory=AudioEncoderConfig, desc="Configuration for the audio encoder that transforms audio into embeddings.", hint=FieldHint.optional, diff --git a/fast_llm/layers/transformer/audio_transformer.py b/fast_llm/layers/transformer/audio_transformer.py new file mode 100644 index 000000000..43ee3f465 --- /dev/null +++ b/fast_llm/layers/transformer/audio_transformer.py @@ -0,0 +1,41 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.audio_encoder.config import AudioTransformerDimNames, AudioTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.tensor import TensorMeta + + +class AudioTransformerLayer(TransformerLayer): + """ + A vision transformer layer to encode image patches + """ + + def __init__( + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, return_input) + + hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + + # use regular layernorm (not rms norm) + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) + + self.norm_1 = self._config.peft.apply_other(self.norm_1) + self.norm_2 = self._config.peft.apply_other(self.norm_2) + + @property + def name(self) -> str: + return f"Audio transformer layer {self._layer_index}" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[AudioTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 40a29959b..3847fd314 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -170,6 +170,15 @@ class VisionRotaryConfig(RotaryConfig): ) +# @config_class() +# class AudioRotaryConfig(RotaryConfig): +# type: RotaryEmbeddingType = Field( +# default=RotaryEmbeddingType.none, +# desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", +# hint=FieldHint.feature, +# ) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -664,6 +673,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames transformer_dim_names = VisionTransformerDimNames + elif type == "audio": + from fast_llm.layers.audio_encoder.config import AudioTransformerDimNames + + transformer_dim_names = AudioTransformerDimNames else: transformer_dim_names = TransformerDimNames tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -775,6 +788,11 @@ class AudioTransformerConfig(TransformerConfig): causal: bool = FieldUpdate( default=False, - desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Audio Transformer.", hint=FieldHint.feature, ) + # rotary: AudioRotaryConfig = FieldUpdate( + # default_factory=AudioRotaryConfig, + # desc="Configuration for the rotary positional embeddings.", + # hint=FieldHint.feature, + # ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 162015768..f82051abf 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -51,13 +51,20 @@ class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" + class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True + class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llava" + +class WhisperGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "whisper" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -140,6 +147,7 @@ class GPTModelConfig(FastLLMModelConfig): MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, LlavaGPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad4df7378..ec765e55a 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,6 +25,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.audio_encoder.config import AudioEncoderType from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderType @@ -38,6 +39,7 @@ MTPLlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, ) from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm.models.gpt.model import GPTModel @@ -555,6 +557,242 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class WhisperHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = WhisperGPTHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + # lm_converters = super()._create_config_converters() + lm_converters = super()._create_config_converters() + for idx, converter in enumerate(lm_converters): + if converter.export_names == (("model_type",),): + continue + elif converter.export_names == (("architectures",),): + ignore_index = idx + if converter.export_names: + converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + + return ( + lm_converters[:ignore_index] + + lm_converters[ignore_index + 1 :] + + [ + ConstantImportParamConverter( + fast_llm_names=(("audio_encoder", "type"),), fast_llm_value=AudioEncoderType.whisper + ), 
+ ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] + ), + # Audio Adapter + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "adapter_size"),), + # export_names=(("text_config", "hidden_size"),), + # ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + # fast_llm_value=NormalizationType.rms_norm, + # ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), + # fast_llm_value=NormalizationType.rms_norm, + # ), + # Audio Transformer + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "num_layers"),), + export_names=(("encoder_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "num_attention_heads"),), + export_names=(("encoder_attention_heads",),), + ), + # RenameParamConverter( + # fast_llm_names=(("audio_encoder", "transformer", "head_groups"),), + # export_names=( + # ( + # "encoder_attention_heads", + # ), + # ), + # ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "ffn_hidden_size"),), + export_names=(("encoder_ffn_dim",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True + # ), + # MappedConfigParamConverter( + # fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + # export_names=(("projector_hidden_act",),), + # fast_llm_value=ActivationType.from_hf_name, + # export_value=lambda activation_type: activation_type.hf_name, + # ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + # export_names=(("vision_config", "rope_theta"),), + # ), + ] + ) + + def _create_vision_transformer_layer_converters( + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, + ) -> list[WeightConverter]: + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + names_bias_cls = [ + # Self-attn + ( + f"layers.{i+fast_llm_offset}.self_attn.query", + f"vision_tower.transformer.layers.{i}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"vision_tower.transformer.layers.{i}.attention.k_proj", + f"vision_tower.transformer.layers.{i}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"vision_tower.transformer.layers.{i}.attention.o_proj", + transformer_config.add_attn_dense_bias, + 
WeightConverter, + ), + # Norm + ( + f"layers.{i+fast_llm_offset}.norm_1", + f"vision_tower.transformer.layers.{i}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.norm_2", + f"vision_tower.transformer.layers.{i}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + () if ignore_export else hf_prefix, + use_bias, + cls=IgnoreExportWeightConverter if ignore_export else cls, + ) + + # MLP + if ignore_export: + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] + else: + converters += self._get_vision_transformer_mlp_converters( + f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" + ) + return converters + + def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + vision_transformer_converters = [] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 + adapter_converters = [ + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), + ] + + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + + 
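The converter lists above depend on a fixed layer-index layout on the Fast-LLM side (frontend conv at layer 0, encoder transformer blocks next, then adapter, multimodal embedding, and finally the language-model blocks). A minimal illustrative sketch of that arithmetic; `layer_layout` is not a function in the codebase, only a restatement of the offsets the converters use:

def layer_layout(num_encoder_layers: int) -> dict:
    # Mirrors the offsets used by the weight converters: the encoder frontend
    # (patch conv / audio conv) sits at layer 0, the encoder transformer blocks
    # at 1..N, the adapter at N + 1, the multimodal embedding at N + 2, and the
    # language-model transformer starts at N + 3 (the `offset` computed in
    # `_create_weight_converters`).
    return {
        "frontend": 0,
        "encoder_blocks": range(1, num_encoder_layers + 1),
        "adapter": num_encoder_layers + 1,
        "embedding": num_encoder_layers + 2,
        "first_lm_block": num_encoder_layers + 3,
    }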
def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + # Embeddings + lm_converters = [ + WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") + ] + for i in range(self._model.config.base_model.transformer.num_layers): + lm_converters += self._create_transformer_layer_converters( + fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + ) + lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) + return vision_encoder_converter + lm_converters + + class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat @@ -950,4 +1188,5 @@ class AutoGPTHuggingfaceCheckpointHandler( MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + WhisperGPTHuggingfaceCheckpointFormat.name: WhisperHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 1b91f3e65..9b450c731 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,13 +10,16 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.audio_encoder.adapter import AudioAdapter from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from fast_llm.layers.audio_encoder.encoder import AudioConv from fast_llm.layers.audio_encoder.preprocessing import AudioPreprocessor from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding +from fast_llm.layers.transformer.audio_transformer import AudioTransformerLayer from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, @@ -84,7 +87,7 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - if self._config.audio_encoder: + if self._config.audio_encoder.enabled: self._preprocessors.append(AudioPreprocessor(self._config.audio_encoder, self._tensor_space)) def get_output_layers(self) -> list[Layer]: @@ -124,12 +127,33 @@ def get_vision_layers(self) -> list[Layer]: MultiModalEmbedding(self._config, self._tensor_space), ] + def get_audio_layers(self) -> list[Layer]: + audio_conv = AudioConv(self._config.audio_encoder, self._tensor_space) + audio_layers = [ + AudioTransformerLayer(self._config.audio_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.audio_encoder.transformer.num_layers) + ] + return [ + audio_conv, + *audio_layers, + AudioAdapter(self._config.audio_encoder, self._tensor_space), + 
MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_multimodal_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + elif self._config.audio_encoder.enabled: + return self.get_audio_layers() + else: + assert False + def get_layers(self) -> list[Layer]: return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] - if not self._config.vision_encoder.enabled - else self.get_vision_layers() + if not self._config.vision_encoder.enabled and not self._config.audio_encoder.enabled + else self.get_multimodal_layers() ), *[ TransformerLayer( @@ -423,7 +447,7 @@ def preprocess( if batch.audio is not None: kwargs[AudioEncoderKwargs.audio] = [ [ - aud.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + aud.to(device=self._tensor_space.distributed.device, dtype=torch.float32, non_blocking=True) for aud in audio ] for audio in batch.audio @@ -434,8 +458,11 @@ def preprocess( for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + audio_mel = kwargs.get(AudioEncoderKwargs.audio_mel, None) if image_patches is not None: preprocessed.append((image_patches, kwargs)) + elif audio_mel is not None: + preprocessed.append((audio_mel, kwargs)) else: preprocessed.append((tokens, kwargs)) From 7eea79bdfe1cd0275fd4310d75487a2b78c7a998 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 28 May 2025 00:21:31 +0000 Subject: [PATCH 08/25] update audio encoder --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 7 +- fast_llm/layers/audio_encoder/adapter.py | 46 +++++-- fast_llm/layers/audio_encoder/config.py | 123 ++++++++---------- fast_llm/layers/audio_encoder/encoder.py | 52 ++++++-- .../layers/audio_encoder/preprocessing.py | 31 ++++- 6 files changed, 158 insertions(+), 103 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 577ebc794..a4b183ca0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -83,7 +83,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling images=batch_images if has_images else None, image_positions=batch_image_positions if has_images else None, audio=batch_audio if has_audio else None, - audio_positions=batch_image_positions if has_audio else None, + audio_positions=batch_audio_positions if has_audio else None, ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 36f8a0e70..c21e0825e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -622,11 +622,10 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - # images = [im for img_list in images for im in img_list] if images else None - # image_positions = np.array(image_positions) if image_positions else None - images = None + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None - audio = [aud for aud_list in audio for aud in aud_list] if audio else None + audio = [aud for aud_list in audio for aud in aud_list] if audio else None # flatten audio_positions = np.array(audio_positions) if audio_positions else None # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) diff --git a/fast_llm/layers/audio_encoder/adapter.py 
b/fast_llm/layers/audio_encoder/adapter.py index 4f77971e8..8c0c7175b 100644 --- a/fast_llm/layers/audio_encoder/adapter.py +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -5,9 +5,9 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames from fast_llm.tensor import TensorMeta, init_normal_ @@ -16,26 +16,36 @@ class AudioAdapter(Layer): Vision adapter layer for the LLM. """ - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): super().__init__() - input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + audio_hidden_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels) + input_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_input) self._activation_type = config.adapter_activation_type + self._use_adapter_bias = config.adapter_bias + + self.norm_1 = config.transformer.normalization.get_layer(audio_hidden_dim) + self.norm_2 = config.transformer.normalization.get_layer( + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size) + ) + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( input_dim, - tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), - bias=True, + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), + bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) self.layer_2 = Linear( - tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), - bias=True, + bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) + self.aud_downsampling_k = config.aud_downsampling_k + def forward( self, input_: torch.Tensor, @@ -46,9 +56,25 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims( kwargs[TransformerKwargs.hidden_dims], - tensor_name="Vision adapter output", + tensor_name="Audio adapter output", dtype=input_.dtype, ) - return self.layer_2( - torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + batch_size, seq_len, dim = input_.size() + + # Check if sequence length is divisible by downsampling rate. + if seq_len % self.aud_downsampling_k != 0: + # If not divisible, trim the end of the sequence. + trimmed_seq_len = seq_len - (seq_len % self.aud_downsampling_k) + input_ = input_[:, :trimmed_seq_len, :] + seq_len = trimmed_seq_len + + # Reshape: group every k frames together (concatenate along feature dimension). 
+ new_seq_len = seq_len // self.aud_downsampling_k + input_ = input_.contiguous().view(batch_size, new_seq_len, dim * self.aud_downsampling_k) + + res = self.layer_2( + self.norm_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) ) + return res diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py index 52a8673e4..3e09b39f9 100644 --- a/fast_llm/layers/audio_encoder/config.py +++ b/fast_llm/layers/audio_encoder/config.py @@ -11,100 +11,65 @@ class AudioEncoderDimNames: in_channels = "audio_in_channels" out_channels = "audio_out_channels" kernel_size = "audio_kernel_size" + adapter_input = "audio_adapter_input" adapter_size = "audio_adapter_size" audio_channels = "audio_kv_channels" - - -class AudioTransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "audio_batch" - # TODO: Distinguish micro-sequence? - sequence_q = "audio_sequence_q" - sequence_q_tp = "audio_sequence_q_tp" - sequence_k = "audio_sequence_k" - hidden = "audio_hidden" - # Self-attention dimensions - head_groups = "audio_head_groups" - group_heads = "audio_group_heads" - key_and_value = "audio_key_value" - kv_channels = "audio_kv_channels" - composite_heads = "audio_composite_heads" - composite_query = "audio_composite_query" - composite_key_value = "audio_composite_key_value" - composite_dense = "audio_composite_dense" - # MLP dimensions - mlp = "audio_mlp" - gate_and_up = "audio_gate_and_up" - composite_gated_mlp = "audio_composite_gated_mlp" - experts = "audio_experts" - top_experts = "audio_top_experts" - shared_experts = "audio_shared_experts" - unshared_experts = "audio_unshared_experts" - composite_expert_mlp = "audio_composite_expert_mlp" - composite_gated_expert_mlp = "audio_composite_gated_expert_mlp" - composite_shared_expert_mlp = "audio_composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "audio_composite_gated_shared_expert_mlp" + max_source_positions = "audio_max_source_positions" class AudioEncoderKwargs: audio = "audio" audio_mel = "audio_mel" audio_positions = "audio_positions" - kv_channels = "audio_kv_channels" + + kv_channels = "audio_kv_channels" # TODO: check this + out_channels = "audio_out_channels" hidden_dims = "audio_hidden_dims" + # TODO: used for backup attention + sequence_length = "audio_sequence_length" + sequence_k_dim = "audio_sequence_k_dim" + sequence_q_dim = "audio_sequence_q_dim" + class AudioEncoderType(str, enum.Enum): none = "none" whisper = "whisper" -# # TODO Toby: do we need all of them? 
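The k-frame stacking added to `AudioAdapter.forward` above can be sanity-checked in isolation; a minimal standalone sketch, assuming `aud_downsampling_k = 5` and a hypothetical encoder hidden size of 1280:

import torch

k = 5                                  # aud_downsampling_k
x = torch.randn(2, 1503, 1280)         # (batch, encoder frames, hidden); sizes assumed for illustration
trimmed = x[:, : x.size(1) - x.size(1) % k, :]                    # drop the ragged tail -> 1500 frames
stacked = trimmed.contiguous().view(2, trimmed.size(1) // k, 1280 * k)
print(stacked.shape)                   # torch.Size([2, 300, 6400])

The stacked feature size (hidden_size * k) is what `setup_tensor_space` registers as `AudioEncoderDimNames.adapter_input`, so `layer_1` consumes it directly.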
-class AudioTransformerKwargs: - rotary_freq_q = "audio_rotary_freq_q" - rotary_freq_k = "audio_rotary_freq_k" - attention_mask = "audio_attention_mask" - attention_mask_value = "audio_attention_mask_value" - sequence_lengths = "audio_sequence_lengths" - cu_seqlens_q = "audio_cu_seqlens_q" - cu_seqlens_k = "audio_cu_seqlens_k" - max_seqlen_q = "audio_max_seqlen_q" - max_seqlen_k = "audio_max_seqlen_k" - # TODO: Review these - presents = "audio_presents" - past_key_values = "audio_past_key_values" - sequence_first = "audio_sequence_first" - hidden_dims = "audio_hidden_dims" - sequence_q_dim = "audio_sequence_q_dim" - sequence_k_dim = "audio_sequence_k_dim" - sequence_length = "audio_sequence_length" - micro_batch_size = "audio_micro_batch_size" - # TODO: Move - grad_output = "audio_grad_output" - patch_position_ids = "patch_position_ids" - - @config_class() class AudioEncoderConfig(BaseModelConfig): _abstract = False - transformer: AudioTransformerConfig = Field( - default_factory=AudioTransformerConfig, - desc="Configuration for the audio transformer architecture.", - hint=FieldHint.core, - ) type: AudioEncoderType = Field( default=AudioEncoderType.none, desc="Type of the audio encoder. Choices: none, whisper.", hint=FieldHint.architecture, ) + transformer: AudioTransformerConfig = Field( + default_factory=AudioTransformerConfig, + desc="Configuration for the audio transformer architecture.", + hint=FieldHint.core, + ) + + # encoder configs conv_bias: bool = Field( - default=False, + default=True, desc="Whether to use bias in the convolutional layer.", hint=FieldHint.optional, ) + encoder_dropout: float = Field( + default=0.0, + desc="Dropout for encoder.", + hint=FieldHint.core, + ) + kernel_size: int = Field( + default=3, + desc="Encoder convolution layer kernel size.", + hint=FieldHint.core, + ) + + # adapter configs adapter_size: int = Field( default=5120, desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", @@ -115,6 +80,18 @@ class AudioEncoderConfig(BaseModelConfig): desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter layer.", + hint=FieldHint.optional, + ) + + # audio configs + num_mel_bins: int = Field( + default=80, + desc="Number of bins for mel spectogram.", + hint=FieldHint.core, + ) aud_downsampling_k: int = Field( default=5, desc="Audio downsampling k parameter.", @@ -127,16 +104,24 @@ class AudioEncoderConfig(BaseModelConfig): ) def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels, self.num_mel_bins)) tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.kernel_size, self.kernel_size)) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.adapter_input, self.transformer.hidden_size * self.aud_downsampling_k) + ) tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.adapter_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels)) - # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.max_source_positions, 1500) + ) # TODO: configure later + tensor_space.add_tensor_dim( TensorDim( - AudioEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + AudioEncoderDimNames.audio_channels, + self.transformer.hidden_size // self.transformer.num_attention_heads, ) ) - self.transformer.setup_tensor_space(tensor_space, type="audio") + self.transformer.setup_tensor_space(tensor_space) @property def enabled(self) -> bool: diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py index 8cd071bd4..20c7d5078 100644 --- a/fast_llm/layers/audio_encoder/encoder.py +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -4,7 +4,8 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames, AudioEncoderKwargs +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames +from fast_llm.layers.transformer.config import AudioTransformerKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -12,6 +13,8 @@ class AudioConv(Layer): def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space + self.dropout_p = config.encoder_dropout + # TODO Toby: lr_scale self.conv1_weight = ParameterMeta.from_dims( ( @@ -21,24 +24,36 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): ), init_method=init_normal_(), ) - self.conv1_stride = 1 + self.conv1_stride = 1 # TODO: parameterize? self.conv2_weight = ParameterMeta.from_dims( ( - self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), # in/out channels are the same - self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), ), init_method=init_normal_(), ) - self.conv2_stride = 2 + self.conv2_stride = 2 # TODO: parameterize? 
if config.conv_bias: - self.bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),) + self.conv1_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() + ) + self.conv2_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() ) else: - self.bias = None + self.conv1_bias = None + self.conv2_bias = None + + self.positional_embeddings = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.max_source_positions), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + ), + init_method=init_normal_(), + ) def forward( self, @@ -47,15 +62,24 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, ) -> torch.Tensor: - hidden_dims = kwargs[AudioEncoderKwargs.hidden_dims] + hidden_dims = kwargs[AudioTransformerKwargs.hidden_dims] # TODO: check seq q if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) - input_ = torch.nn.functional.conv1d(input_, self.conv1_weight, self.bias, stride=self.conv1_stride) + + # TODO: check how to best cast dtype + input_ = input_.to(self.conv1_weight.dtype) + + input_ = torch.nn.functional.conv1d( + input_, self.conv1_weight, self.conv1_bias, stride=self.conv1_stride, padding=1 + ) input_ = torch.nn.functional.gelu(input_) - input_ = torch.nn.functional.conv1d(input_, self.conv2_weight, self.bias, stride=self.conv2_stride) + input_ = torch.nn.functional.conv1d( + input_, self.conv2_weight, self.conv2_bias, stride=self.conv2_stride, padding=1 + ) input_ = torch.nn.functional.gelu(input_) - # TODO Toby: add learned position embeddings and dropout - audio_embeddings = audio_embeddings.reshape(*(x.size for x in hidden_dims)) + audio_embeddings = input_.permute(0, 2, 1) + audio_embeddings = audio_embeddings + self.positional_embeddings + audio_embeddings = torch.nn.functional.dropout(audio_embeddings, p=self.dropout_p, training=self.training) - return audio_embeddings + return audio_embeddings.contiguous() diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 54bfeef65..9d0db1b41 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -33,15 +33,36 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): ) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[AudioEncoderKwargs.audio_mel_meta] = TensorMeta.from_dims( + # ( + # TensorDim( + # VisionTransformerDimNames.batch, + # kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + # ), + # TensorDim(VisionEncoderDimNames.in_channels, 3), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # ), + # dtype=self._distributed_config.training_dtype.torch, + # ) pass def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_raw = kwargs[AudioEncoderKwargs.audio] + flattened_audio = [audio_arr for sequence in audio_raw for audio_arr in sequence] + flattened_audio_tensor = torch.stack(flattened_audio, dim=0) # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") 
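Together, the mel front-end and the two convolutions above fix how many encoder frames, and ultimately LM positions, each clip contributes. A rough sanity check, assuming clips padded to 30 seconds (`aud_padding_duration = 30`) at the default 16 kHz sampling rate:

sample_rate, hop_length = 16_000, 160   # MelSpectrogram settings used above
clip_seconds, k = 30, 5                 # assumed padding duration; aud_downsampling_k

mel_frames = clip_seconds * sample_rate // hop_length   # 3000 (preprocess drops the extra trailing frame)
after_conv1 = mel_frames                                # kernel 3, stride 1, padding 1 preserves the length
after_conv2 = (after_conv1 + 2 - 3) // 2 + 1            # stride 2 -> 1500, i.e. max_source_positions
lm_positions = after_conv2 // k                         # 300 placeholder tokens per clip after the adapter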
self.mel_transform.to(self._tensor_space.distributed.device) - audio_mel = [] - for batch in audio_raw: - batch_stacked = torch.stack(batch).unsqueeze(1) - audio_mel.append(self.mel_transform(batch_stacked)) - kwargs[AudioEncoderKwargs.audio_mel] = torch.cat(audio_mel) + audio_mel = self.mel_transform(flattened_audio_tensor) + audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! + + # # set attention mask # TODO Toby: fix backup attention + # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + # sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size + # kwargs[self._transformer_kwargs.attention_mask] = self._mask[ + # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + # ] + # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel From daf98b3c9eac1c49f12a0f34e5d056b9c9d7a351 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 28 May 2025 00:23:20 +0000 Subject: [PATCH 09/25] audio transformer updates --- fast_llm/layers/language_model/config.py | 2 ++ fast_llm/layers/multi_modal/embedding.py | 18 +++++++++-- .../layers/transformer/audio_transformer.py | 5 ++- fast_llm/layers/transformer/config.py | 31 +++++++++++++++++-- fast_llm/layers/transformer/preprocessing.py | 6 ++-- 5 files changed, 51 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a425da8e..8ba066cb3 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -182,6 +182,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) if self.vision_encoder.enabled: self.vision_encoder.setup_tensor_space(tensor_space) + if self.audio_encoder.enabled: + self.audio_encoder.setup_tensor_space(tensor_space) @property def num_absolute_position_embeddings(self) -> int: diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 8c541e983..3bce539d7 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -5,6 +5,7 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs @@ -34,6 +35,7 @@ def _forward( position_ids: torch.Tensor | None, image_positions: list[torch.Tensor] | None, image_sizes: list[list[tuple[int, int]]] | None, + audio_positions: list[torch.Tensor] | None, ) -> torch.Tensor: """ Forward pass for the multi-modal embedding layer. 
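For context on the `audio_positions` argument threaded through `_forward` here: the sampler records, per clip, the token offset at which it inserted a run of `-100` placeholders, and the loop added below overwrites exactly that span with the audio encoder outputs. A minimal sketch of the bookkeeping, with an assumed 300-token clip:

import numpy as np

audio_tokens = 300                                   # e.g. a 30 s clip after downsampling
token_ids, audio_positions = [np.array([5, 6, 7])], []

# sampler side (sketch): remember where the placeholder run starts, then append it
audio_positions.append(sum(t.size for t in token_ids))
token_ids.append(np.full((audio_tokens,), -100, dtype=np.int64))
flat = np.concatenate(token_ids)

# embedding side (sketch): the span [pos, pos + audio_tokens) is what gets replaced
pos = audio_positions[0]
assert (flat[pos : pos + audio_tokens] == -100).all()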
@@ -57,6 +59,7 @@ def _forward( embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) + # TODO: Toby implement audio for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): @@ -91,6 +94,13 @@ def _forward( ] image_embedding_offset += num_image_tokens + audio_position_idx = 0 + for sample_idx, positions in enumerate(audio_positions): + for position in positions: + num_audio_tokens = input_.shape[1] # TODO: Toby better way to get this? + embeddings[sample_idx, position : position + num_audio_tokens] = input_[audio_position_idx] + audio_position_idx += 1 + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( @@ -114,9 +124,11 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) + # TODO: How do we support both Audio and Vision? position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes, []) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions, []) + audio_positions = kwargs.get(AudioEncoderKwargs.audio_positions, []) tokens = kwargs.get(LanguageModelKwargs.tokens) - return self._forward(input_, tokens, position_ids, image_positions, image_sizes) + return self._forward(input_, tokens, position_ids, image_positions, image_sizes, audio_positions) diff --git a/fast_llm/layers/transformer/audio_transformer.py b/fast_llm/layers/transformer/audio_transformer.py index 43ee3f465..f0fb6d17f 100644 --- a/fast_llm/layers/transformer/audio_transformer.py +++ b/fast_llm/layers/transformer/audio_transformer.py @@ -1,15 +1,14 @@ import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.audio_encoder.config import AudioTransformerDimNames, AudioTransformerKwargs -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import AudioTransformerDimNames, AudioTransformerKwargs, TransformerConfig from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tensor import TensorMeta class AudioTransformerLayer(TransformerLayer): """ - A vision transformer layer to encode image patches + A audio transformer layer to encode image patches """ def __init__( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c9d379abd..45d911a67 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -71,6 +71,10 @@ class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder") pass +class AudioTransformerDimNames(BaseTransformerDimNames, prefix="audio_encoder"): + pass + + class BaseTransformerKwargs: _kwargs_attributes = { "rotary_freq_q": "rotary_freq_q", @@ -110,6 +114,10 @@ class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): patch_position_ids = "patch_position_ids" +class AudioTransformerKwargs(BaseTransformerKwargs, prefix="audio_encoder"): + pass + + class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -127,6 +135,7 @@ class RotaryEmbeddingType(str, enum.Enum): class 
TransformerType(str, enum.Enum): lm_decoder = "lm_decoder" image_encoder = "image_encoder" + audio_encoder = "audio_encoder" @config_class() @@ -317,7 +326,7 @@ class TransformerConfig(BaseModelConfig): _abstract = False transformer_type: TransformerType = Field( default=TransformerType.lm_decoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", hint=FieldHint.architecture, ) normalization: NormalizationConfig = Field( @@ -828,7 +837,7 @@ class VisionTransformerConfig(TransformerConfig): transformer_type: TransformerType = FieldUpdate( default=TransformerType.image_encoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", hint=FieldHint.architecture, ) causal: bool = FieldUpdate( @@ -857,13 +866,31 @@ class AudioTransformerConfig(TransformerConfig): Configuration for the Audio Transformer model. """ + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.audio_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", + hint=FieldHint.architecture, + ) causal: bool = FieldUpdate( default=False, desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Audio Transformer.", hint=FieldHint.feature, ) + gated: bool = FieldUpdate( + default=False, + desc="MLP gating.", + hint=FieldHint.feature, + ) # rotary: AudioRotaryConfig = FieldUpdate( # default_factory=AudioRotaryConfig, # desc="Configuration for the rotary positional embeddings.", # hint=FieldHint.feature, # ) + + @property + def _transformer_kwargs(self) -> AudioTransformerKwargs: + return AudioTransformerKwargs + + @property + def _transformer_dim_names(self) -> AudioTransformerDimNames: + return AudioTransformerDimNames diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index af1a53f68..1b436eba3 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -280,9 +280,9 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + self._create_tensors(kwargs[self._transformer_kwargs.sequence_length]) + sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] From cd167fc81c77f0d78d8b9f77bf168c7416718773 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 28 May 2025 00:24:47 +0000 Subject: [PATCH 10/25] audio conversion --- fast_llm/models/gpt/config.py | 7 + fast_llm/models/gpt/conversion.py | 478 ++++++++++++++++++------------ fast_llm/models/gpt/model.py | 33 +++ fast_llm/models/gpt/trainer.py | 11 +- 4 files changed, 340 insertions(+), 189 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 2bfdf8925..ae6fc6ad8 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -72,6 +72,12 @@ class WhisperGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "whisper" +class 
AyraAudioModelGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "ayra_audio" + audio_name: typing.ClassVar[str] = "whisper" + text_name: typing.ClassVar[str] = "llama" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -156,6 +162,7 @@ class GPTModelConfig(FastLLMModelConfig): LlavaGPTHuggingfaceCheckpointFormat, WhisperGPTHuggingfaceCheckpointFormat, PixtralGPTHuggingfaceCheckpointFormat, + AyraAudioModelGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index da9e897bf..a1d91b2a8 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -32,6 +32,7 @@ from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( + AyraAudioModelGPTHuggingfaceCheckpointFormat, GPTBaseModelConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, @@ -566,240 +567,212 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class WhisperHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +class WhisperHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = WhisperGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - # lm_converters = super()._create_config_converters() - lm_converters = super()._create_config_converters() - for idx, converter in enumerate(lm_converters): - if converter.export_names == (("model_type",),): - continue - elif converter.export_names == (("architectures",),): - ignore_index = idx - if converter.export_names: - converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - - return ( - lm_converters[:ignore_index] - + lm_converters[ignore_index + 1 :] - + [ - ConstantImportParamConverter( - fast_llm_names=(("audio_encoder", "type"),), fast_llm_value=AudioEncoderType.whisper - ), - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] - ), - # Audio Adapter - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "adapter_size"),), - # export_names=(("text_config", "hidden_size"),), - # ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - # fast_llm_value=NormalizationType.rms_norm, - # ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - # fast_llm_value=NormalizationType.rms_norm, - # ), - # Audio Transformer - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "num_layers"),), - export_names=(("encoder_layers",),), + return super()._create_config_converters() + [ + # set default layernorm + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + ), + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] + ), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=AudioEncoderType.whisper), + # make transformer noncasual + ConstantImportParamConverter(fast_llm_names=(("transformer", 
"causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), ), - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "hidden_size"),), - export_names=(("d_model",),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), ), - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "num_attention_heads"),), - export_names=(("encoder_attention_heads",),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), ), - # RenameParamConverter( - # fast_llm_names=(("audio_encoder", "transformer", "head_groups"),), - # export_names=( - # ( - # "encoder_attention_heads", - # ), - # ), - # ), - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "ffn_hidden_size"),), - export_names=(("encoder_ffn_dim",),), + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", + ), ), - MappedConfigParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "activation_type"),), - export_names=(("activation_function",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - # ), - # MappedConfigParamConverter( - # fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - # export_names=(("projector_hidden_act",),), - # fast_llm_value=ActivationType.from_hf_name, - # export_value=lambda activation_type: activation_type.hf_name, - # ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - # export_names=(("vision_config", "rope_theta"),), - # ), - ] - ) + export_names=(("encoder_ffn_dim",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.none + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), + RenameParamConverter( + fast_llm_names=(("num_mel_bins",),), + export_names=(("num_mel_bins",),), + ), + RenameParamConverter( + fast_llm_names=(("aud_downsampling_k",),), + export_names=(("encoder_projector_ds_rate",),), + ), + ] - def _create_vision_transformer_layer_converters( - self, - i: int, - ignore_export: bool = False, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, - type: str | None = None, + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), + WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", 
f"{hf_prefix}fc1.bias"), + WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), + WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + ] + + def _create_audio_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" ) -> list[WeightConverter]: - if type is not None: - if type == "vision": - transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer - else: - transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - names_bias_cls = [ + # Vision transformer layer + transformer_config = self._model.config.base_model.audio_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ # Self-attn ( - f"layers.{i+fast_llm_offset}.self_attn.query", - f"vision_tower.transformer.layers.{i}.attention.q_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.key_value", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", ( - f"vision_tower.transformer.layers.{i}.attention.k_proj", - f"vision_tower.transformer.layers.{i}.attention.v_proj", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.k_proj", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.v_proj", ), - transformer_config.add_attn_qkv_bias, + transformer_config.add_attn_qkv_bias, # TODO Toby: add permanent fix for key bias KeyValueWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.dense", - f"vision_tower.transformer.layers.{i}.attention.o_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.out_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+fast_llm_offset}.norm_1", - f"vision_tower.transformer.layers.{i}.attention_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn_layer_norm", norm_bias, WeightConverter, ), ( - f"layers.{i+fast_llm_offset}.norm_2", - f"vision_tower.transformer.layers.{i}.ffn_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}layers.{transformer_layer_index}.final_layer_norm", norm_bias, WeightConverter, ), ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: converters += self._get_weight_and_bias_converters( fast_llm_prefix, - () if ignore_export else hf_prefix, + hf_prefix, use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, + cls, ) - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_1", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_2", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] - else: 
- converters += self._get_vision_transformer_mlp_converters( - f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" - ) + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}layers.{transformer_layer_index}.", + ) return converters - def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.feed_forward.down_proj.weight", - self._model.config.base_model, - ), + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + + # audio encoder conv + converters += [ + WeightConverter(f"layers.{offset}.conv1_weight", f"{hf_base_prefix}conv1.weight"), + WeightConverter(f"layers.{offset}.conv2_weight", f"{hf_base_prefix}conv2.weight"), ] - def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) - ) + if self._model.config.base_model.audio_encoder.conv_bias: + converters += [ + WeightConverter(f"layers.{offset}.conv1_bias", f"{hf_base_prefix}conv1.bias"), + WeightConverter(f"layers.{offset}.conv2_bias", f"{hf_base_prefix}conv2.bias"), + ] - return vision_transformer_converters + # position embedding + converters.append( + WeightConverter(f"layers.{offset}.positional_embeddings", f"{hf_base_prefix}embed_positions.weight") + ) - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] + # transformer encoder layers + num_layers = self._model.config.base_model.audio_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_audio_transformer_layer_converters(i, offset + 1, hf_base_prefix) - return patch_conv_converters + 
layernorm_converters + vision_transformer_converters + adapter_converters + offset = offset + num_layers + 1 - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # Embeddings - lm_converters = [ - WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") - ] - for i in range(self._model.config.base_model.transformer.num_layers): - lm_converters += self._create_transformer_layer_converters( - fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + # add final layernorm + if self._model.config.base_model.audio_encoder.transformer.normalization.type == NormalizationType.layer_norm: + converters += [ + WeightConverter(f"layers.{offset}.norm_1.weight", f"{hf_base_prefix}layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_2.weight", "encoder_projector.layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_1.bias", f"{hf_base_prefix}layer_norm.bias"), + WeightConverter(f"layers.{offset}.norm_2.bias", "encoder_projector.layer_norm.bias"), + ] + + # multimodal projector + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.weight", "encoder_projector.linear1.weight"), + WeightConverter(f"layers.{offset}.layer_2.weight", "encoder_projector.linear2.weight"), + ] + ) + if self._model.config.base_model.audio_encoder.adapter_bias: + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.bias", "encoder_projector.linear1.bias"), + WeightConverter(f"layers.{offset}.layer_2.bias", "encoder_projector.linear2.bias"), + ] ) - lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.audio_encoder.transformer.num_layers + 2 class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): @@ -1007,6 +980,139 @@ def num_layers(self) -> int: return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 +class AyraAudioModelHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = AyraAudioModelGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "audio_config" in cfg_dict: + audio_kwargs = cls._import_config(cfg_dict["audio_config"]) + audio_kwargs = {tuple(["audio_encoder"] + list(key)): value for key, value in audio_kwargs.items()} + kwargs.update(audio_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "audio_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + 
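For reference, a rough sketch (illustrative, not from this patch) of the key-prefixing that `_load_metadata` performs when merging the imported text and audio sub-configs: audio keys are namespaced under "audio_encoder" so they cannot collide with text keys. The key tuples and values below are made up, not real converter output.

    text_kwargs = {("transformer", "hidden_size"): 4096}
    audio_kwargs = {("transformer", "hidden_size"): 1280, ("num_mel_bins",): 80}

    merged = {}
    merged.update(text_kwargs)
    # Prefix every audio key with "audio_encoder", as in the handler above.
    merged.update({("audio_encoder", *key): value for key, value in audio_kwargs.items()})

    assert merged[("transformer", "hidden_size")] == 4096
    assert merged[("audio_encoder", "transformer", "hidden_size")] == 1280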
@classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["AyraAudioModel"]), + # projector + MappedConfigParamConverter( + fast_llm_names=(("audio_encoder", "adapter_activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "adapter_size"),), + export_names=(("adapter_size",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + # TODO Toby: implement for audio + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def 
_create_weight_converters(self): + audio_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.audio_name) + audio_handler = audio_handler_cls(self._model) # TODO Toby: are we calling this twice? + converters = audio_handler._create_weight_converters(hf_base_prefix="encoder.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="llm.", offset=audio_handler.num_layers) + ) + return converters + + class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -1275,5 +1381,5 @@ class AutoGPTHuggingfaceCheckpointHandler( LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, WhisperGPTHuggingfaceCheckpointFormat.name: WhisperHuggingfaceCheckpointHandler, PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, - # MultiModalGPTHuggingfaceCheckpointFormat.name: MultiModalHuggingfaceCheckpointHandler + AyraAudioModelGPTHuggingfaceCheckpointFormat.name: AyraAudioModelHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 01cf0ec38..330bd8328 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,6 +21,8 @@ from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.audio_transformer import AudioTransformerLayer from fast_llm.layers.transformer.config import ( + AudioTransformerDimNames, + AudioTransformerKwargs, RoutingType, TransformerDimNames, TransformerKwargs, @@ -217,6 +219,18 @@ def preprocess_meta( else: vision_kwargs = {} + if self._config.audio_encoder.enabled: + audio_kwargs = { + AudioEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( + AudioTransformerDimNames.kv_channels + ).size, + AudioEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + AudioEncoderKwargs.out_channels + ).size, + } + else: + audio_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -272,6 +286,22 @@ def preprocess_meta( } ) + if self._config.audio_encoder.enabled: + audio_hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + audio_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, audio_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, audio_hidden_dim) + ) + audio_kwargs.update( + { + AudioTransformerKwargs.hidden_dims: audio_hidden_dims, + AudioTransformerKwargs.sequence_length: 1500, # TODO: Toby Parameterize + AudioTransformerKwargs.sequence_k_dim: 1500, + AudioTransformerKwargs.sequence_q_dim: 1500, + } + ) + common_kwargs = { LanguageModelKwargs.phase: phase, TransformerKwargs.sequence_first: sequence_first, @@ -281,6 +311,7 @@ def preprocess_meta( TransformerKwargs.micro_batch_size: micro_batch_size, } common_kwargs.update(vision_kwargs) + common_kwargs.update(audio_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, @@ -482,6 +513,8 @@ def transformer_layers(self) -> list[TransformerLayer]: def embedding_layer_index(self) -> int: if 
self._config.vision_encoder.enabled: return self._config.vision_encoder.transformer.num_layers + 2 + elif self._config.audio_encoder.enabled: + return self._config.audio_encoder.transformer.num_layers + 2 else: return 0 diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 38264d4ad..3000e9be7 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -30,9 +30,6 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, - "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, - "aud_padding_duration": self._config.batch.aud_padding_duration, - "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, } ) if self._config.model.base_model.vision_encoder.enabled: @@ -42,6 +39,14 @@ def _get_sampling_parameters( "image_size": self._config.batch.image_size, } ) + if self._config.model.base_model.audio_encoder.enabled: + parameters.update( + { + "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, + "aud_padding_duration": self._config.batch.aud_padding_duration, + "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From e0f7dfd7cd23f4974aa1e19c8bdb39e0b5a8f290 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 29 May 2025 16:31:51 +0000 Subject: [PATCH 11/25] mm loss masking spans --- fast_llm/data/dataset/gpt/memmap.py | 61 +++++++----- fast_llm/data/dataset/gpt/sampled.py | 99 ++++++++++++++----- .../data/preparator/gpt_memmap/prepare.py | 76 +++++++------- 3 files changed, 148 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 855026cfc..eb0a14c86 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,6 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -108,7 +107,7 @@ def _init( offset += ( self._num_spans.nbytes + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize - + sum([x.nbytes for x in self._spans]) + # + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 self._image_lengths = [] @@ -141,8 +140,8 @@ def _init( ) images_seen += n_images offset = offset + self._n_images.nbytes + 3 * self._n_images.sum() * np.dtype(np.int32).itemsize - self._audio_lengths = [] - self._audio_positions = [] + self._audio_lengths = [] # list of arrays + self._audio_positions = [] # list of arrays if self._has_audio and self._version >= 5: self._n_audio = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset @@ -267,19 +266,19 @@ def get( if self._has_audio: audio_positions = self._audio_positions[idx] # increment offset by documents and images - offset = ( + aud_offset = ( self._pointers[idx] + offset * np.dtype(self._dtype).itemsize + self._document_sizes[idx] * np.dtype(self._dtype).itemsize ) if 
self._has_images and len(self._image_lengths) > 0: - offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize + aud_offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( self._bin_buffer, dtype=np.dtype(np.float32), count=self._audio_lengths[idx].sum(), - offset=offset, + offset=aud_offset, ) start = 0 for audio_length in self._audio_lengths[idx]: @@ -295,23 +294,37 @@ def get( ] sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - if images: - image_idx = 0 - for span in sample_spans: - additional_tokens = 0 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - ) - additional_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - span[1] += additional_tokens + # if images: + # image_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # while image_position >= span[0] and image_position <= span[1]: + # image_tokens = get_num_image_tokens( + # get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + # patch_size, + # image_break=image_break, + # ) + # additional_tokens += image_tokens + # image_idx += 1 + # image_position = ( + # image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # ) + # span[1] += additional_tokens + # if audio: + # audio_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # audio_position = audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # while audio_position >= span[0] and audio_position <= span[1]: + # audio_tokens = ... + # additional_tokens += audio_tokens + # audio_idx += 1 + # audio_position = ( + # audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # ) + # span[1] += additional_tokens + return GPTSample( token_ids=token_ids, images=images, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 87cbebf3e..c6a7e2bcb 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -144,7 +144,7 @@ def _compute_audio_token_size(self, sizes): sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount # account for mel spectogram, convolution, downsampling k - audio_token_size_arr = sizes // 160 # default hop length TODO: check divisible? + audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? 
audio_token_size_arr = audio_token_size_arr // ( 2 * self._parameters.aud_downsampling_k ) # convolution (2) * downsampling @@ -557,24 +557,27 @@ def __getitem__(self, index: int) -> typing.Any: start_pos = 0 # add tokens and multi modal padding placeholders - multimodal_positions = np.concatenate( - [ - arr.astype(np.int32) - for arr in (sample.image_positions, sample.audio_positions) - if arr is not None - ] - ) or np.array([], dtype=np.int32) - multimodal_positions.sort() - for idx, mm_position in enumerate(multimodal_positions): - if ( - sample.image_positions is not None and mm_position in sample.image_positions - ): # TODO Toby: use enum - mm_type = "image" - elif sample.audio_positions is not None and mm_position in sample.audio_positions: - mm_type = "audio" - else: - assert False - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # multimodal_positions = np.concatenate( + # [ + # arr.astype(np.int32) + # for arr in (sample.image_positions, sample.audio_positions) + # if arr is not None + # ] + # ) or np.array([], dtype=np.int32) + # multimodal_positions.sort() + + multimodal_positions = [] + if sample.image_positions is not None: + multimodal_positions.extend( + [(pos, "image", idx) for idx, pos in enumerate(sample.image_positions)] + ) + if sample.audio_positions is not None: + multimodal_positions.extend( + [(pos, "audio", idx) for idx, pos in enumerate(sample.audio_positions)] + ) + + multimodal_positions.sort(key=lambda x: x[0]) + for global_idx, (mm_position, mm_type, source_idx) in enumerate(multimodal_positions): # Add placeholders for image and audio tokens tokens token_ids.append(sample.token_ids[start_pos:mm_position]) if mm_type == "image": @@ -584,8 +587,8 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image height, width = get_resize_dims( - image_lengths[idx][0], - image_lengths[idx][1], + image_lengths[source_idx][0], + image_lengths[source_idx][1], self._parameters.image_size, self._parameters.image_size, self._parameters.patch_size, @@ -613,11 +616,14 @@ def __getitem__(self, index: int) -> typing.Any: mm_tokens_added += resized_image_tokens else: # Just add placeholders for all image tokens without break tokens - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - mm_tokens_added += image_sizes[idx] + token_ids.append(np.full((image_sizes[source_idx],), -100, dtype=np.int64)) + mm_tokens_added += image_sizes[source_idx] elif mm_type == "audio": - audio_positions.append(sum(t.size for t in token_ids)) - token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) + audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + audio_positions.append(audio_pos) + token_ids.append( + np.full((audio_token_size_arr[source_idx],), -100, dtype=np.int64) + ) # TODO Toby: index doesnt work here mm_tokens_added += audio_tokens start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) @@ -634,7 +640,47 @@ def __getitem__(self, index: int) -> typing.Any: audio.append([]) if self._parameters.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: + mm_idx = 0 + total_mm_tokens = 0 + for loss_masking_span in sample.loss_masking_spans: # TODO: check these must be sorted + mm_tokens_in_span = 0 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + + # increment mm_idx until span is 
reached + while mm_position < loss_masking_span[0]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + total_mm_tokens += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + + # get all multimodal positions within span + while mm_position >= loss_masking_span[0] and mm_position <= loss_masking_span[1]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + mm_tokens_in_span += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + loss_masking_span[0] += total_mm_tokens # increment by all mm tokens before span + loss_masking_span[1] += total_mm_tokens + mm_tokens_in_span + total_mm_tokens += mm_tokens_in_span + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -658,6 +704,7 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index aa2481f06..283a6bf80 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -93,44 +93,44 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ "num_pixels": num_pixels, } - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans, images, image_token_positions = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans, dtype=np.int32).reshape(-1, 2), - np.array(images, dtype=np.uint8), - np.array(image_token_positions, dtype=np.int32), - ) - for input_ids, token_spans, images, image_token_positions in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip( - batch[self._config.dataset.field], - batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), - batch.get(self._config.dataset.images, itertools.repeat(None)), - batch.get(self._config.dataset.image_positions, itertools.repeat(None)), - ) - ] - ] - ), - ) - num_tokens = [len(x) for x in input_ids] - num_pixels = [0] * len(input_ids) - for idx, images in enumerate(images): - for bytes_im in images: - with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: - width, height = im.size - num_pixels[idx] += width * height * 3 - return { - "input_ids": input_ids, - "token_spans": token_spans, - "images": images, - "image_positions": image_token_positions, - "num_tokens": num_tokens, - "num_pixels": num_pixels, - } + # def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + # input_ids, token_spans, images, image_token_positions = map( + # list, + # zip( + # *[ + # ( + # np.array(input_ids, dtype=self._data_type.numpy), + # np.array(token_spans, dtype=np.int32).reshape(-1, 2), + # np.array(images, dtype=np.uint8), + # np.array(image_token_positions, dtype=np.int32), + # ) + # for input_ids, token_spans, images, image_token_positions in [ + # 
self._tokenizer.tokenize_with_spans(text, char_spans) + # for text, char_spans in zip( + # batch[self._config.dataset.field], + # batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + # batch.get(self._config.dataset.images, itertools.repeat(None)), + # batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + # ) + # ] + # ] + # ), + # ) + # num_tokens = [len(x) for x in input_ids] + # num_pixels = [0] * len(input_ids) + # for idx, images in enumerate(images): + # for bytes_im in images: + # with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + # width, height = im.size + # num_pixels[idx] += width * height * 3 + # return { + # "input_ids": input_ids, + # "token_spans": token_spans, + # "images": images, + # "image_positions": image_token_positions, + # "num_tokens": num_tokens, + # "num_pixels": num_pixels, + # } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args From 0ae74d191c99d2d6e31232bfb3e1a67015248109 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 29 May 2025 16:33:00 +0000 Subject: [PATCH 12/25] add lr scale --- fast_llm/layers/audio_encoder/adapter.py | 6 ++++++ fast_llm/layers/audio_encoder/config.py | 21 ++++++++++++++++++++- fast_llm/layers/audio_encoder/encoder.py | 20 ++++++++++++++------ fast_llm/models/gpt/conversion.py | 13 +++++++++---- 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py index 8c0c7175b..b02b4e77e 100644 --- a/fast_llm/layers/audio_encoder/adapter.py +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -22,11 +22,14 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): input_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_input) self._activation_type = config.adapter_activation_type self._use_adapter_bias = config.adapter_bias + self.lr_scale = config.adapter_lr_scale self.norm_1 = config.transformer.normalization.get_layer(audio_hidden_dim) + self.norm_1.lr_scale = self.lr_scale self.norm_2 = config.transformer.normalization.get_layer( tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size) ) + self.norm_2.lr_scale = self.lr_scale # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( @@ -35,6 +38,7 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), + lr_scale=self.lr_scale, ) self.layer_2 = Linear( tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), @@ -42,6 +46,7 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), + lr_scale=self.lr_scale, ) self.aud_downsampling_k = config.aud_downsampling_k @@ -59,6 +64,7 @@ def forward( tensor_name="Audio adapter output", dtype=input_.dtype, ) + input_ = self.norm_1(input_) batch_size, seq_len, dim = input_.size() # Check if sequence length is divisible by downsampling rate. 
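For reference, a small sketch (illustrative, not from this patch) of the k-frame stacking the adapter's forward pass performs, together with the token count it implies for a padded 30 s clip. The numbers used here (16 kHz sampling, hop length 160, conv stride 2, k=4, hidden size 1280) are assumed defaults for illustration only.

    import torch

    k = 4                                  # aud_downsampling_k (illustrative value)
    frames = 30 * 16000 // 160             # 3000 mel frames for a padded 30 s clip at 16 kHz, hop 160
    frames //= 2                           # second conv has stride 2 -> 1500 encoder positions
    assert frames % k == 0                 # the forward pass checks this divisibility

    hidden = 1280
    x = torch.randn(2, frames, hidden)                    # (batch, seq, dim) out of the audio transformer
    x = x.contiguous().view(2, frames // k, hidden * k)   # stack every k frames along the feature dimension
    print(x.shape)                                        # torch.Size([2, 375, 5120]) -> fed to the two-layer adapter MLP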
diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py index 3e09b39f9..9503c60cc 100644 --- a/fast_llm/layers/audio_encoder/config.py +++ b/fast_llm/layers/audio_encoder/config.py @@ -1,10 +1,11 @@ import enum -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.transformer.config import AudioTransformerConfig +from fast_llm.utils import Assert class AudioEncoderDimNames: @@ -68,6 +69,18 @@ class AudioEncoderConfig(BaseModelConfig): desc="Encoder convolution layer kernel size.", hint=FieldHint.core, ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + pos_emb_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the position embedding layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) # adapter configs adapter_size: int = Field( @@ -85,6 +98,12 @@ class AudioEncoderConfig(BaseModelConfig): desc="Whether to use bias in the adapter layer.", hint=FieldHint.optional, ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) # audio configs num_mel_bins: int = Field( diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py index 20c7d5078..b35cc1740 100644 --- a/fast_llm/layers/audio_encoder/encoder.py +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -14,8 +14,9 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space self.dropout_p = config.encoder_dropout + self._conv_lr_scale = config.conv_lr_scale + self._pos_emb_lr_scale = config.pos_emb_lr_scale - # TODO Toby: lr_scale self.conv1_weight = ParameterMeta.from_dims( ( self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), @@ -23,8 +24,9 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), ), init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) - self.conv1_stride = 1 # TODO: parameterize? + self.conv1_stride = 1 # TODO Toby: parameterize? self.conv2_weight = ParameterMeta.from_dims( ( @@ -33,15 +35,20 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), ), init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) - self.conv2_stride = 2 # TODO: parameterize? + self.conv2_stride = 2 # TODO Toby: parameterize? 
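        # Shape note (added commentary, not from the patch): assuming kernel_size=3 and
        # the padding=1 applied in forward(), conv1 (stride 1) preserves the frame count
        # and conv2 (stride 2) halves it, e.g. a padded 30 s clip -> 3000 mel frames ->
        # 1500 encoder positions, which is what the max_source_positions embedding table
        # below is sized for:  out_len = (in_len + 2 * 1 - kernel_size) // stride + 1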
if config.conv_bias: self.conv1_bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) self.conv2_bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) else: self.conv1_bias = None @@ -53,6 +60,7 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), ), init_method=init_normal_(), + lr_scale=self._pos_emb_lr_scale, ) def forward( @@ -66,7 +74,7 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) - # TODO: check how to best cast dtype + # TODO Toby: check how to best cast dtype input_ = input_.to(self.conv1_weight.dtype) input_ = torch.nn.functional.conv1d( diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index a1d91b2a8..6438ce0f9 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -644,10 +644,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("num_mel_bins",),), export_names=(("num_mel_bins",),), ), - RenameParamConverter( - fast_llm_names=(("aud_downsampling_k",),), - export_names=(("encoder_projector_ds_rate",),), - ), ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -1024,6 +1020,15 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("audio_encoder", "adapter_size"),), export_names=(("adapter_size",),), ), + RenameParamConverter( + fast_llm_names=( + ( + "audio_encoder", + "aud_downsampling_k", + ), + ), + export_names=(("encoder_projector_ds_rate",),), + ), ] @classmethod From 438ba80062077b9c0c69653229fd1c04a9a88c91 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 30 May 2025 16:42:54 +0000 Subject: [PATCH 13/25] mel spec changes --- fast_llm/data/dataset/gpt/sampled.py | 3 +- .../layers/audio_encoder/preprocessing.py | 61 ++++++++++++------- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 66ae6a881..b2d303799 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -636,7 +636,8 @@ def __getitem__(self, index: int) -> typing.Any: else: images.append([]) if sample.audio: - audio.append(self.apply_audio_padding(sample.audio)) + # audio.append(self.apply_audio_padding(sample.audio)) + audio.append(sample.audio) else: audio.append([]) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 9d0db1b41..506e026e6 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -1,14 +1,12 @@ import typing import torch -from torchaudio.transforms import MelSpectrogram +from transformers import WhisperFeatureExtractor from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs -# from 
transformers import WhisperFeatureExtractor - class AudioPreprocessor(Preprocessor): def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): @@ -16,21 +14,21 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - # self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) + self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) - self.mel_transform = MelSpectrogram( - sample_rate=self._config.aud_sampling_rate, - n_fft=400, - win_length=400, - hop_length=160, - n_mels=80, - f_min=0.0, - f_max=8000.0, - mel_scale="slaney", - norm="slaney", - center=True, - power=2.0, - ) + # self.mel_transform = MelSpectrogram( + # sample_rate=self._config.aud_sampling_rate, + # n_fft=400, + # win_length=400, + # hop_length=160, + # n_mels=80, + # f_min=0.0, + # f_max=8000.0, + # mel_scale="slaney", + # norm="slaney", + # center=True, + # power=2.0, + # ) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: # kwargs[AudioEncoderKwargs.audio_mel_meta] = TensorMeta.from_dims( @@ -49,13 +47,31 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_raw = kwargs[AudioEncoderKwargs.audio] - flattened_audio = [audio_arr for sequence in audio_raw for audio_arr in sequence] - flattened_audio_tensor = torch.stack(flattened_audio, dim=0) + flattened_audio = [ + audio_arr for sequence in audio_raw for audio_arr in sequence + ] # flatten in the batch dimension + # flattened_audio_tensor = torch.stack(flattened_audio, dim=0) # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") - self.mel_transform.to(self._tensor_space.distributed.device) + # self.mel_transform.to(self._tensor_space.distributed.device) + + # audio_mel = self.mel_transform(flattened_audio_tensor) + # flattened_audio_tensor = np.stack(flattened_audio, axis=0) + # audio_inputs = self.feature_extractor(flattened_audio_tensor, sampling_rate=16000, return_tensors="pt") + # audio_mel = audio_inputs['input_features'] + + audio_mel = [] + for audio in flattened_audio: + audio_mel.append( + self.feature_extractor( + audio, + sampling_rate=self._config.aud_sampling_rate, + return_tensors="pt", + max_length=30 * self._config.aud_sampling_rate, + )["input_features"] + ) + audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) - audio_mel = self.mel_transform(flattened_audio_tensor) - audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! + # audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! 
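        # Shape note (added commentary, not from the patch): with the default
        # WhisperFeatureExtractor settings (feature_size=80, hop length 160, 30 s chunks),
        # each call above returns "input_features" of shape (1, 80, 3000), so the
        # stack + squeeze(1) above yields (num_clips, 80, 3000); shorter clips are
        # zero-padded to max_length by the extractor before the log-mel transform.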
# # set attention mask # TODO Toby: fix backup attention # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size @@ -65,4 +81,5 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # ] # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + audio_mel = audio_mel.to(self._tensor_space.distributed.device) kwargs[AudioEncoderKwargs.audio_mel] = audio_mel From 525543a74bf4d4362f441766f41c07172234d064 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 30 May 2025 16:44:10 +0000 Subject: [PATCH 14/25] updates --- fast_llm/layers/audio_encoder/adapter.py | 13 +++++++------ fast_llm/models/gpt/conversion.py | 23 +++++++++++++++++++---- fast_llm/models/gpt/model.py | 5 +---- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py index b02b4e77e..bc4f8f00f 100644 --- a/fast_llm/layers/audio_encoder/adapter.py +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -77,10 +77,11 @@ def forward( # Reshape: group every k frames together (concatenate along feature dimension). new_seq_len = seq_len // self.aud_downsampling_k input_ = input_.contiguous().view(batch_size, new_seq_len, dim * self.aud_downsampling_k) - - res = self.layer_2( - self.norm_2( - torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) - ) + layer1_res = torch_mlp_activation( + input_=self.layer_1(input_), gated=False, activation_type=self._activation_type ) - return res + torch.manual_seed(0) # TODO Toby: remove after debugging + layer1_res_dropout = torch.nn.functional.dropout(layer1_res, 0.1) + layer1_res_norm = self.norm_2(layer1_res_dropout) + layer2_res = self.layer_2(layer1_res_norm) + return layer2_res diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6438ce0f9..ad348ce97 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -647,11 +647,26 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + # return [ + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + # ] + transformer_config = self._model.config.base_model.audio_encoder.transformer return [ - WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), - WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), - WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), - WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}fc1", + transformer_config.add_mlp_bias, + WeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}fc2", + transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), ] def _create_audio_transformer_layer_converters( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 330bd8328..57ec951b1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -479,10 +479,7 @@ def preprocess( if batch.audio is 
not None: kwargs[AudioEncoderKwargs.audio] = [ - [ - aud.to(device=self._tensor_space.distributed.device, dtype=torch.float32, non_blocking=True) - for aud in audio - ] + [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] for audio in batch.audio ] kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions From 95526a3d674b89e89c8559757f286d1b48ab89a6 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 2 Jun 2025 21:03:45 +0000 Subject: [PATCH 15/25] adding audio start and end tokens --- fast_llm/data/dataset/gpt/config.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 142 +++++++++++------- fast_llm/layers/audio_encoder/config.py | 12 ++ .../layers/audio_encoder/preprocessing.py | 42 ++++++ fast_llm/models/gpt/trainer.py | 2 + 5 files changed, 144 insertions(+), 56 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0b72402f1..9819f4e81 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -81,6 +81,8 @@ class GPTSamplingParameters(SamplingParameters): aud_sampling_rate: int | None = None image_break_token: int | None = None image_end_token: int | None = None + audio_start_token: int | None = None + audio_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index b2d303799..63337e671 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -14,6 +14,7 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.layers.audio_encoder.preprocessing import get_num_audio_tokens from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -132,38 +133,6 @@ def __init__( # No barrier yet to allow running in parallel. # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. - def _compute_audio_token_size(self, sizes): - if len(sizes) == 0: # sample has no audio - return sizes, False - to_filter = False - # account for padding - if self._parameters.aud_padding_duration > 0: - raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate - sizes = sizes.copy() # original is read-only - to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long - sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount - - # account for mel spectogram, convolution, downsampling k - audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? 
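For reference, with the values this code assumes elsewhere (audio padded to a 30 s duration at 16 kHz, a 160-sample hop, a stride-2 convolution, and a downsampling factor k, taken here as 5 for illustration), the per-clip token budget works out as below; the division by 2·k continues on the next removed line of the hunk:

    aud_padding_duration = 30        # seconds, assumed padding duration
    aud_sampling_rate = 16_000       # Hz
    aud_downsampling_k = 5           # assumed downsampling factor
    raw_len = aud_padding_duration * aud_sampling_rate   # 480_000 samples per padded clip
    frames = raw_len // 160                              # 3_000 mel frames (10 ms hop)
    tokens = frames // (2 * aud_downsampling_k)          # 300 audio placeholder tokens per clip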
- audio_token_size_arr = audio_token_size_arr // ( - 2 * self._parameters.aud_downsampling_k - ) # convolution (2) * downsampling - return audio_token_size_arr, to_filter - - def apply_audio_padding(self, audio): - if len(audio) == 0: - return audio - # TODO Toby: check 2d - padded_audio = [] - if self._parameters.aud_padding_duration > 0: - raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate - for aud in audio: - padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) - padded_audio.append(padded) - return padded_audio - else: - return audio - def _sample(self) -> None: """ Create a `GPTSampledDataset` with the requested parameters. @@ -198,7 +167,14 @@ def _sample(self) -> None: audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # longer than audio padding for i, sizes in enumerate(audio_sizes): - audio_token_size_arr, to_filter = self._compute_audio_token_size(sizes) + audio_token_size_arr, to_filter = get_num_audio_tokens( + sizes, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) audio_token_sizes[i] = audio_token_size_arr.sum() long_audio_filter[i] = to_filter @@ -371,7 +347,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() + yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() * unshuffled_epochs self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) @@ -520,7 +496,14 @@ def __getitem__(self, index: int) -> typing.Any: ] image_tokens = sum(image_sizes) - audio_token_size_arr, _ = self._compute_audio_token_size(audio_lengths) + audio_token_size_arr, _ = get_num_audio_tokens( + audio_lengths, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) audio_tokens = audio_token_size_arr.sum() document_size = text_size + image_tokens + audio_tokens @@ -585,14 +568,16 @@ def __getitem__(self, index: int) -> typing.Any: [(pos, "audio", idx) for idx, pos in enumerate(sample.audio_positions)] ) + token_ids_per_sample = [] + special_mm_tok_loss_masking_spans = np.empty((0, 2), dtype=np.int32) multimodal_positions.sort(key=lambda x: x[0]) for global_idx, (mm_position, mm_type, source_idx) in enumerate(multimodal_positions): # Add placeholders for image and audio tokens tokens - token_ids.append(sample.token_ids[start_pos:mm_position]) + token_ids_per_sample.append(sample.token_ids[start_pos:mm_position]) + text_tokens_added += len(token_ids_per_sample[-1]) if mm_type == "image": # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens - text_tokens_added += len(token_ids[-1]) image_positions.append(text_tokens_added + mm_tokens_added) if self._parameters.image_break_token is not None: height, width = resized_image_lengths[source_idx] @@ -616,21 +601,55 @@ def __getitem__(self, index: int) -> typing.Any: image_token_array = np.full((image_sizes[source_idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not None: image_token_array[-1] = self._parameters.image_end_token - token_ids.append(image_token_array) + 
token_ids_per_sample.append(image_token_array) mm_tokens_added += image_sizes[source_idx] elif mm_type == "audio": - audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + # audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + # compute audio position + start_token_offset = int(self._parameters.audio_start_token is not None) + audio_pos = text_tokens_added + mm_tokens_added + start_token_offset audio_positions.append(audio_pos) - token_ids.append( - np.full((audio_token_size_arr[source_idx],), -100, dtype=np.int64) - ) # TODO Toby: index doesnt work here - mm_tokens_added += audio_tokens + + # compute number of special tokens + num_audio_special_tokens = int(self._parameters.audio_start_token is not None) + int( + self._parameters.audio_end_token is not None + ) + + # add start tokens + if self._parameters.audio_start_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_start_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, [[audio_pos - 1, audio_pos - 1]], axis=0 + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos-1, audio_pos-1]], axis=0) + + # add audio pad tokens + num_audio_pad_tokens = audio_token_size_arr[source_idx] + num_audio_pad_tokens -= num_audio_special_tokens # ignore start/end tokens for padding + audio_padding_tokens = np.full((num_audio_pad_tokens,), -100, dtype=np.int64) + token_ids_per_sample.append(audio_padding_tokens) + + # add end token + if self._parameters.audio_end_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_end_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, + [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], + axis=0, + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], axis=0) + + # update mm tokens added + mm_tokens_added += num_audio_special_tokens + num_audio_pad_tokens start_pos = mm_position - token_ids.append(sample.token_ids[start_pos:]) + # add remaining text tokens + token_ids_per_sample.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids_per_sample[-1]) - # TODO Soham: add offsets for loss masking spans - text_tokens_added += len(token_ids[-1]) + token_ids.append(np.concatenate(token_ids_per_sample)) if sample.images: images.append(sample.images) else: @@ -643,22 +662,25 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans: mm_idx = 0 - total_mm_tokens = 0 - for loss_masking_span in sample.loss_masking_spans: # TODO: check these must be sorted - mm_tokens_in_span = 0 + mm_tokens_before_span = 0 + + # sort by start of span + sample.loss_masking_spans = sample.loss_masking_spans[sample.loss_masking_spans[:, 0].argsort()] + for loss_masking_span in sample.loss_masking_spans: + mm_tokens_within_span = 0 mm_position, mm_type, source_idx = ( multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else (float("inf"), _, _) ) - # increment mm_idx until span is reached + # increment mm_idx until span is reached, track mm tokens before span while mm_position < loss_masking_span[0]: if mm_type == "image": num_mm_tokens = image_sizes[source_idx] elif mm_type == "audio": num_mm_tokens = audio_token_size_arr[source_idx] - total_mm_tokens += num_mm_tokens + mm_tokens_before_span 
+= num_mm_tokens mm_idx += 1 mm_position, mm_type, source_idx = ( multimodal_positions[mm_idx] @@ -672,25 +694,33 @@ def __getitem__(self, index: int) -> typing.Any: num_mm_tokens = image_sizes[source_idx] elif mm_type == "audio": num_mm_tokens = audio_token_size_arr[source_idx] - mm_tokens_in_span += num_mm_tokens + mm_tokens_within_span += num_mm_tokens mm_idx += 1 mm_position, mm_type, source_idx = ( multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else (float("inf"), _, _) ) - loss_masking_span[0] += total_mm_tokens # increment by all mm tokens before span - loss_masking_span[1] += total_mm_tokens + mm_tokens_in_span - total_mm_tokens += mm_tokens_in_span + loss_masking_span[0] += mm_tokens_before_span # increment by all mm tokens before span + loss_masking_span[1] += mm_tokens_before_span + mm_tokens_within_span + mm_tokens_before_span += mm_tokens_within_span span = np.clip( loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if span[1] >= span[0]: loss_masking_spans.append(span) + for span in special_mm_tok_loss_masking_spans: + # span = np.clip( + # loss_masking_span + token_count - token_start, + # 0, + # self._parameters.sequence_length + self._parameters.extra_tokens, + # ) + if span[1] >= span[0]: + loss_masking_spans.append(span) # Go to the next document. document_sampling_index += 1 token_count += document_size diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py index 9503c60cc..95665901e 100644 --- a/fast_llm/layers/audio_encoder/config.py +++ b/fast_llm/layers/audio_encoder/config.py @@ -122,6 +122,18 @@ class AudioEncoderConfig(BaseModelConfig): hint=FieldHint.feature, ) + # audio start/end tokens + audio_start_token: int | None = Field( + default=None, + desc="Token id for audio start.", + hint=FieldHint.optional, + ) + audio_end_token: int | None = Field( + default=None, + desc="Token id for audio end.", + hint=FieldHint.optional, + ) + def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels, self.num_mel_bins)) tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 506e026e6..8959837d6 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -1,5 +1,6 @@ import typing +import numpy as np import torch from transformers import WhisperFeatureExtractor @@ -8,6 +9,47 @@ from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs +def get_num_audio_tokens( + sizes, aud_padding_duration, aud_sampling_rate, aud_downsampling_k, audio_start_token, audio_end_token +): + if len(sizes) == 0: # sample has no audio + return sizes, False + to_filter = False + # account for padding + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + sizes = sizes.copy() # original is read-only + to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long + sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount + + # account for mel spectogram, convolution, downsampling k + audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? 
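The new get_num_audio_tokens continues just below with the same division by 2·k and then reserves one extra slot each when audio_start_token and audio_end_token are configured. A rough sketch, with made-up token ids, of the placeholder run that __getitem__ then splices into token_ids for a clip whose budget is n tokens:

    import numpy as np

    audio_start_token, audio_end_token = 32000, 32001   # hypothetical ids, set in AudioEncoderConfig
    n = 300                                             # budget returned by get_num_audio_tokens
    pads = n - 2                                        # start/end tokens are carved out of the budget
    placeholder = np.concatenate([
        np.array([audio_start_token]),
        np.full((pads,), -100, dtype=np.int64),         # -100 marks positions later filled by audio embeddings
        np.array([audio_end_token]),
    ])
    assert placeholder.size == n

The start and end tokens are also appended to the loss-masking spans so that the model is never asked to predict them.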
+ audio_token_size_arr = audio_token_size_arr // ( + 2 * aud_downsampling_k + ) # convolution (2 stride) * downsampling TODO Toby: make configurable convolution + + if audio_start_token is not None: + audio_token_size_arr += 1 + if audio_end_token is not None: + audio_token_size_arr += 1 + return audio_token_size_arr, to_filter + + +def apply_audio_padding(audio, aud_padding_duration, aud_sampling_rate): + if len(audio) == 0: + return audio + # TODO Toby: check 2d + padded_audio = [] + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + for aud in audio: + padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) + padded_audio.append(padded) + return padded_audio + else: + return audio + + class AudioPreprocessor(Preprocessor): def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._config = config diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index c9025f7f1..b4a3036fe 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -47,6 +47,8 @@ def _get_sampling_parameters( "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, "aud_padding_duration": self._config.batch.aud_padding_duration, "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, + "audio_start_token": self._config.model.base_model.audio_encoder.audio_start_token, + "audio_end_token": self._config.model.base_model.audio_encoder.audio_end_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From fb23ef8cceb15d43612c042c8bcf43070a68d9ed Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 3 Jun 2025 17:04:29 +0000 Subject: [PATCH 16/25] conversion changes --- fast_llm/models/gpt/conversion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad348ce97..568c78080 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -1075,19 +1075,19 @@ def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: # TODO Toby: implement for audio exported_config = {} - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + audio_handler_class = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.audio_name) text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) - for converter in vision_handler_cls._create_config_converters(): + for converter in audio_handler_class._create_config_converters(): try: values = converter.export_params( tuple( - cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + cls._get_fast_llm_attribute(config, ("audio_encoder",) + fast_llm_name) for fast_llm_name in converter.fast_llm_names ) ) for export_name, value in zip(converter.export_names, values, strict=True): if value is not MISSING: - set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + set_nested_dict_value(exported_config, ("audio_config",) + export_name, value) except Exception as e: raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) From d7d11352bfa828352f2d9d338390d33b95693683 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 3 Jun 2025 22:43:32 +0000 Subject: [PATCH 17/25] adding data prep sharding --- 
fast_llm/data/preparator/gpt_memmap/prepare.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 283a6bf80..cb34cc919 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -84,6 +84,11 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ width, height = im.size num_pixels[idx] += width * height * 3 + num_audio = [0] * len(input_ids) + for idx, audio_lst in enumerate(batch.get(self._config.dataset.audio, [])): + for audio in audio_lst: + num_audio[idx] += len(audio) + return { "input_ids": input_ids, "image_positions": image_token_positions, @@ -91,6 +96,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ "token_spans": token_spans, "num_tokens": num_tokens, "num_pixels": num_pixels, + "num_audio": num_audio, } # def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -296,6 +302,7 @@ def run(self) -> None: batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", + # load_from_cache_file=False # TODO Toby: remove ) # Calculate total number of tokens @@ -305,7 +312,13 @@ def run(self) -> None: if self._config.dataset.images else 0 ) + total_audio = ( + sum(tqdm.tqdm(tokenized_dataset["num_audio"], desc="Counting audio", unit="audio")) + if self._config.dataset.audio + else 0 + ) total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize + total_tokens += total_audio * np.float32().itemsize // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) From 012a6364dfa98c54e7424ab748856568a0c21eae Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 6 Jun 2025 02:24:33 +0000 Subject: [PATCH 18/25] faster mel sepc --- fast_llm/layers/audio_encoder/preprocessing.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 8959837d6..916c97bb6 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -109,6 +109,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: sampling_rate=self._config.aud_sampling_rate, return_tensors="pt", max_length=30 * self._config.aud_sampling_rate, + device=self._tensor_space.distributed.device, )["input_features"] ) audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) @@ -124,4 +125,11 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value audio_mel = audio_mel.to(self._tensor_space.distributed.device) + + # PAD_TO = 100 + # padding_size = PAD_TO - audio_mel.size(0) + # padding = torch.zeros(padding_size, audio_mel.size(1), audio_mel.size(2), dtype=audio_mel.dtype, device=audio_mel.device) + + # audio_mel = torch.cat((audio_mel, padding), dim=0) + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel From c664444bf323740a3d2488956bb9e5a933c2aa9b Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 12 Jun 2025 21:04:01 +0000 Subject: [PATCH 19/25] adding num audio to config --- fast_llm/data/preparator/gpt_memmap/prepare.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py 
index cb34cc919..832af202e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -171,6 +171,7 @@ def _document_generator(): "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), + "num_audio": sum(doc["num_audio"] for doc in shard_dataset), } ) From ba7393970f0d9afbcd4fd241beaf7a5ec1ee1834 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 12 Jun 2025 22:41:47 +0000 Subject: [PATCH 20/25] audio encoder padding updates --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 10 ++- .../layers/audio_encoder/preprocessing.py | 76 ++++++++++--------- fast_llm/models/gpt/model.py | 19 ++--- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a4b183ca0..8f7009784 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -64,7 +64,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling has_audio = False batch_audio = [] for sample in batch: - if sample.audio is not None and len(sample.audio_positions) > 0: + if sample.audio is not None and sample.audio_positions is not None: batch_audio.append([torch.from_numpy(audio) for audio in sample.audio]) has_audio = True else: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 628de35ff..2c64c47ec 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -243,7 +243,7 @@ def _sample(self) -> None: shuffled_documents = documents_per_epoch * shuffled_epochs unshuffled_epochs = num_epochs - shuffled_epochs - yaml_data = { + yaml_data = { # TODO Toby: add audio "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, @@ -504,7 +504,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.audio_start_token, self._parameters.audio_end_token, ) - audio_tokens = audio_token_size_arr.sum() + audio_tokens = int(audio_token_size_arr.sum()) document_size = text_size + image_tokens + audio_tokens @@ -705,7 +705,7 @@ def __getitem__(self, index: int) -> typing.Any: mm_tokens_before_span += mm_tokens_within_span span = np.clip( - loss_masking_span + token_count - token_start, + loss_masking_span + int(token_count) - int(token_start), 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) @@ -743,13 +743,15 @@ def __getitem__(self, index: int) -> typing.Any: audio_positions = np.array(audio_positions) if audio_positions else None # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + # # TODO: Toby remove/comment after testing (for testing only first sample) + # loss_masking_spans = np.append(loss_masking_spans, [[sequence_lengths[0], sequence_lengths[:-1].sum()]], axis=0) return GPTSample( token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, - audio=audio, + audio=audio if len(audio) > 0 else None, audio_positions=audio_positions, ) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 916c97bb6..f6a696d7b 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -1,3 +1,4 @@ +import math import 
typing import numpy as np @@ -13,7 +14,7 @@ def get_num_audio_tokens( sizes, aud_padding_duration, aud_sampling_rate, aud_downsampling_k, audio_start_token, audio_end_token ): if len(sizes) == 0: # sample has no audio - return sizes, False + return np.array(sizes), False to_filter = False # account for padding if aud_padding_duration > 0: @@ -88,33 +89,43 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: - audio_raw = kwargs[AudioEncoderKwargs.audio] - flattened_audio = [ - audio_arr for sequence in audio_raw for audio_arr in sequence - ] # flatten in the batch dimension - # flattened_audio_tensor = torch.stack(flattened_audio, dim=0) - # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") - # self.mel_transform.to(self._tensor_space.distributed.device) - - # audio_mel = self.mel_transform(flattened_audio_tensor) - # flattened_audio_tensor = np.stack(flattened_audio, axis=0) - # audio_inputs = self.feature_extractor(flattened_audio_tensor, sampling_rate=16000, return_tensors="pt") - # audio_mel = audio_inputs['input_features'] - + # check if audio is in batch audio_mel = [] - for audio in flattened_audio: - audio_mel.append( - self.feature_extractor( - audio, - sampling_rate=self._config.aud_sampling_rate, - return_tensors="pt", - max_length=30 * self._config.aud_sampling_rate, - device=self._tensor_space.distributed.device, - )["input_features"] - ) - audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) - - # audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! + if AudioEncoderKwargs.audio in kwargs: + audio_raw = kwargs[AudioEncoderKwargs.audio] + flattened_audio = [ + audio_arr for sequence in audio_raw for audio_arr in sequence + ] # flatten in the batch dimension + + for audio in flattened_audio: + audio_mel.append( + self.feature_extractor( + audio, + sampling_rate=self._config.aud_sampling_rate, + return_tensors="pt", + max_length=30 * self._config.aud_sampling_rate, + device=self._tensor_space.distributed.device, + )["input_features"] + ) + audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) + curr_size = audio_mel.size(0) + else: + audio_mel = torch.tensor(audio_mel, dtype=torch.float32) + curr_size = 0 + + max_pad = math.ceil(kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // 5)) + padding_size = max_pad - curr_size + padding = torch.zeros( + padding_size, + self.feature_extractor.feature_size, + self.feature_extractor.nb_max_frames, + dtype=audio_mel.dtype, + device=audio_mel.device, + ) + audio_mel = torch.cat((audio_mel, padding), dim=0) + audio_mel = audio_mel.to(self._tensor_space.distributed.device) + + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel # # set attention mask # TODO Toby: fix backup attention # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size @@ -123,13 +134,4 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k # ] # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value - - audio_mel = audio_mel.to(self._tensor_space.distributed.device) - - # PAD_TO = 100 - # padding_size = PAD_TO - audio_mel.size(0) - # padding = torch.zeros(padding_size, audio_mel.size(1), audio_mel.size(2), dtype=audio_mel.dtype, device=audio_mel.device) - - # audio_mel = torch.cat((audio_mel, padding), dim=0) - - kwargs[AudioEncoderKwargs.audio_mel] = audio_mel + # audio_mel = 
torch.rand(len(flattened_audio), 80, 3000) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bc04fc6e7..05b15e4d2 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -484,22 +484,23 @@ def preprocess( ) kwargs[LanguageModelKwargs.tokens] = tokens - if batch.audio is not None: - kwargs[AudioEncoderKwargs.audio] = [ - [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] - for audio in batch.audio - ] - kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions + if self._config.audio_encoder.enabled: + if batch.audio is not None: + kwargs[AudioEncoderKwargs.audio] = [ + [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] + for audio in batch.audio + ] + kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions kwargs[LanguageModelKwargs.tokens] = tokens for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) audio_mel = kwargs.get(AudioEncoderKwargs.audio_mel, None) - if image_patches is not None: - preprocessed.append((image_patches, kwargs)) - elif audio_mel is not None: + if audio_mel is not None: preprocessed.append((audio_mel, kwargs)) + elif image_patches is not None: + preprocessed.append((image_patches, kwargs)) else: preprocessed.append((tokens, kwargs)) From 5667a0a3d1a8cdcf7f36253976c40e0cbc8bf7c6 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 12 Jun 2025 23:02:37 +0000 Subject: [PATCH 21/25] configurable max pad --- fast_llm/layers/audio_encoder/preprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index f6a696d7b..a326dc609 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -113,7 +113,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_mel = torch.tensor(audio_mel, dtype=torch.float32) curr_size = 0 - max_pad = math.ceil(kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // 5)) + max_pad = math.ceil( + kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // self._config.aud_downsampling_k) + ) padding_size = max_pad - curr_size padding = torch.zeros( padding_size, From 9f68a5e47361f10bf8429c529d8d0d61b008c142 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 16 Jun 2025 20:18:47 +0000 Subject: [PATCH 22/25] small fix --- fast_llm/data/dataset/gpt/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9819f4e81..357623b11 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -209,6 +209,11 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of pixels in the dataset.", hint=FieldHint.optional, ) + num_audio: int | None = Field( + default=None, + desc="Expected number of audio in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset From c286f8d649965128ec284a8324255bb487415fe9 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 18 Jun 2025 23:34:07 +0000 Subject: [PATCH 23/25] debugging updates --- fast_llm/data/dataset/gpt/memmap.py | 23 +++++++++---- .../data/preparator/gpt_memmap/prepare.py | 12 +++---- .../layers/audio_encoder/preprocessing.py | 32 
+++++++++++++------ 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 74a7b420a..c0353a42d 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -5,6 +5,8 @@ import numpy as np import PIL.Image +import torchaudio +import soundfile as sf from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -285,6 +287,10 @@ def get( for audio_length in self._audio_lengths[idx]: audio.append(all_audio[start : start + audio_length]) start += audio_length + + print("Memmap audio length: ", self._audio_lengths[idx]) + print("Memmap audio pos: ", self._audio_positions[idx]) + print("Memmap get audio: ", audio) # TODO Soham: return loss_masking_spans sample_spans = None @@ -427,13 +433,18 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP total_im_size += pixels.size im_positions.append(document.image_positions) if document.audio is not None: - n_audio.append(len(document.audio)) - total_audio += len(document.audio) + num_audio = 0 for audio in document.audio: - audio_lengths.append(len(audio)) - bin_stream.write(audio.tobytes(order="C")) - total_aud_size += audio.size - if len(document.audio) > 0: + # audio_arr, _ = torchaudio.load(io.BytesIO(audio["bytes"])) + audio_arr, _ = sf.read(io.BytesIO(audio["bytes"])) + if len(audio_arr) > 0: + num_audio += 1 + audio_lengths.append(len(audio_arr)) + bin_stream.write(audio_arr.tobytes(order="C")) + total_aud_size += audio_arr.size + n_audio.append(num_audio) + total_audio += num_audio + if num_audio > 0: aud_positions += document.audio_positions # Update metadata diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 832af202e..888e1b634 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -154,11 +154,7 @@ def _document_generator(): ), item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, - ( - np.array(item[self._config.dataset.audio], dtype=np.float32) - if self._config.dataset.audio - else None - ), + item[self._config.dataset.audio] if self._config.dataset.audio else None, item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) @@ -296,6 +292,8 @@ def run(self) -> None: # decoding bytes to images is slow and should be done only when needed if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) + if self._config.dataset.audio is not None: + dataset = dataset.cast_column("audio", datasets.Sequence(datasets.Audio(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( @@ -303,7 +301,7 @@ def run(self) -> None: batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", - # load_from_cache_file=False # TODO Toby: remove + load_from_cache_file=False # TODO Toby: remove ) # Calculate total number of tokens @@ -321,6 +319,8 @@ def run(self) -> None: total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize total_tokens += total_audio * np.float32().itemsize // np.dtype(self._data_type.numpy).itemsize + tokenized_dataset = tokenized_dataset.shuffle(seed=42) + # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / 
self._config.tokens_per_shard)) shards = [ diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index a326dc609..21262fe92 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -92,10 +92,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # check if audio is in batch audio_mel = [] if AudioEncoderKwargs.audio in kwargs: + print("Preprocessing Contains Audio") audio_raw = kwargs[AudioEncoderKwargs.audio] flattened_audio = [ audio_arr for sequence in audio_raw for audio_arr in sequence ] # flatten in the batch dimension + print("Preprocessing Flattened Audio: ", flattened_audio) for audio in flattened_audio: audio_mel.append( @@ -110,23 +112,35 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) curr_size = audio_mel.size(0) else: + print("Preprocessing No Audio") audio_mel = torch.tensor(audio_mel, dtype=torch.float32) curr_size = 0 + + print("Preprocessing Audio Mel Raw: ", audio_mel) + # compute max pad max_pad = math.ceil( kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // self._config.aud_downsampling_k) ) + max_pad = 1 + max_pad = max(max_pad, curr_size) + + # add padding padding_size = max_pad - curr_size - padding = torch.zeros( - padding_size, - self.feature_extractor.feature_size, - self.feature_extractor.nb_max_frames, - dtype=audio_mel.dtype, - device=audio_mel.device, - ) - audio_mel = torch.cat((audio_mel, padding), dim=0) + if padding_size > 0: + padding = torch.zeros( + padding_size, + self.feature_extractor.feature_size, + self.feature_extractor.nb_max_frames, + dtype=audio_mel.dtype, + device=audio_mel.device, + ) + audio_mel = torch.cat((audio_mel, padding), dim=0) + + print("Preprocessing Audio Mel Final: ", audio_mel) + + # move to device audio_mel = audio_mel.to(self._tensor_space.distributed.device) - kwargs[AudioEncoderKwargs.audio_mel] = audio_mel # # set attention mask # TODO Toby: fix backup attention From eb39e7e9a9c4ca78c25a0d072ae513ff35a89cc1 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 23 Jun 2025 20:09:52 +0000 Subject: [PATCH 24/25] working 5b changes --- fast_llm/data/dataset/gpt/memmap.py | 5 +---- fast_llm/data/dataset/gpt/sampled.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c0353a42d..c47d3cf64 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -287,10 +287,6 @@ def get( for audio_length in self._audio_lengths[idx]: audio.append(all_audio[start : start + audio_length]) start += audio_length - - print("Memmap audio length: ", self._audio_lengths[idx]) - print("Memmap audio pos: ", self._audio_positions[idx]) - print("Memmap get audio: ", audio) # TODO Soham: return loss_masking_spans sample_spans = None @@ -437,6 +433,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for audio in document.audio: # audio_arr, _ = torchaudio.load(io.BytesIO(audio["bytes"])) audio_arr, _ = sf.read(io.BytesIO(audio["bytes"])) + audio_arr = audio_arr.astype(np.float32) if len(audio_arr) > 0: num_audio += 1 audio_lengths.append(len(audio_arr)) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2c64c47ec..ea8eed402 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ 
-347,7 +347,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() * unshuffled_epochs + yaml_data["unshuffled_tokens"] = unshuffled_tokens * unshuffled_epochs self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) From a53c89a454decdec70356ffac601dc3e4fe6dc09 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 23 Jun 2025 23:23:10 +0000 Subject: [PATCH 25/25] small fixes --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/models/gpt/model.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ea8eed402..42cc07298 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -751,7 +751,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, - audio=audio if len(audio) > 0 else None, + audio=audio if audio is not None and len(audio) > 0 else None, audio_positions=audio_positions, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 05b15e4d2..48f5760b6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -444,7 +444,7 @@ def preprocess( if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -457,9 +457,9 @@ def preprocess( loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + loss_mask[start : end + 1, idx] = False else: - loss_mask[i, start : end + 1] = False + loss_mask[idx, start : end + 1] = False if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100)
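Taken together, the loss-masking change in the final hunk reduces to the pattern below; this standalone sketch assumes batch-first labels and inclusive [start, end] spans, as in the patched code:

    import torch

    labels = torch.arange(12).reshape(2, 6)                         # toy batch of label ids, batch-first
    spans = [torch.tensor([[1, 2]]), torch.tensor([[0, 0], [4, 5]])]  # per-sample inclusive spans

    loss_mask = torch.ones_like(labels, dtype=torch.bool)
    for idx, sample_spans in enumerate(spans):
        for start, end in sample_spans:
            loss_mask[idx, start : end + 1] = False                 # mask the span, inclusive on both ends
    masked_labels = torch.where(loss_mask, labels, torch.full_like(labels, -100))  # -100 is ignored by the loss

Renaming the loop variable from i to idx, as the patch does, avoids shadowing the enclosing loop index while keeping the masking behaviour unchanged.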