diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index bf6f0b3ae..2c374c4cb 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -23,6 +23,9 @@
     HuggingFaceStorageReader,
     HuggingFaceStorageWriter,
 )
+from torch.distributed.checkpoint._consolidate_hf_safetensors import (
+    consolidate_safetensors_files_on_every_rank,
+)
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -352,14 +355,23 @@ def dcp_save(
             state_dict = self.sd_adapter.to_hf(state_dict)
 
             fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
-
-            storage_writer = HuggingFaceStorageWriter(
-                path=checkpoint_id,
-                save_distributed=True,
-                fqn_to_index_mapping=fqn_to_index_mapping,
-                enable_consolidation=True,
-                thread_count_consolidation=5,
-            )
+            if fqn_to_index_mapping:
+                storage_writer = HuggingFaceStorageWriter(
+                    path=os.path.join(checkpoint_id, "sharded"),
+                    save_distributed=True,
+                    fqn_to_index_mapping=fqn_to_index_mapping,
+                    enable_consolidation=False,
+                )
+            else:
+                # The reason for only enabling consolidation when there is
+                # no mapping is that no mapping implies we save all fqns
+                # to one file, so only one rank is needed to consolidate.
+                # Otherwise we use consolidate_safetensors_files_on_every_rank.
+                storage_writer = HuggingFaceStorageWriter(
+                    path=checkpoint_id,
+                    save_distributed=True,
+                    enable_consolidation=True,
+                )
 
         else:
             checkpoint_save_id = checkpoint_id
@@ -387,6 +399,14 @@ def dcp_save(
                 checkpoint_id=checkpoint_save_id,
             )
 
+        if to_hf and self.sd_adapter.fqn_to_index_mapping:
+            consolidate_safetensors_files_on_every_rank(
+                input_dir=os.path.join(checkpoint_id, "sharded"),
+                output_dir=checkpoint_id,
+                fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
+                num_threads=5,
+            )
+
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
 
diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py
index 106a7937e..3cd8dc037 100644
--- a/torchtitan/protocols/state_dict_adapter.py
+++ b/torchtitan/protocols/state_dict_adapter.py
@@ -75,6 +75,6 @@ def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
             self.fqn_to_index_mapping = {}
             for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items():
                 indx = re.search(r"\d+", raw_indx).group(0)
-                self.fqn_to_index_mapping[hf_key] = indx
+                self.fqn_to_index_mapping[hf_key] = int(indx)
         else:
             self.fqn_to_index_mapping = None
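
For context, here is a minimal sketch of the two-phase save flow the checkpoint.py change introduces, written as a standalone helper. The function name `save_hf_checkpoint` and its surrounding setup are illustrative, not part of this PR; it assumes an initialized process group and a sharded `state_dict`, and only uses the writer/consolidation arguments that appear in the diff above.

```python
import os

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import HuggingFaceStorageWriter
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files_on_every_rank,
)


def save_hf_checkpoint(state_dict, checkpoint_id, fqn_to_index_mapping=None):
    """Illustrative sketch mirroring the two branches of the new dcp_save()."""
    if fqn_to_index_mapping:
        # Phase 1: every rank writes its unconsolidated shards into a
        # "sharded" subdirectory; the writer's own consolidation is disabled.
        writer = HuggingFaceStorageWriter(
            path=os.path.join(checkpoint_id, "sharded"),
            save_distributed=True,
            fqn_to_index_mapping=fqn_to_index_mapping,
            enable_consolidation=False,
        )
        dcp.save(state_dict, storage_writer=writer)
        # Phase 2: consolidation work is spread across all ranks, writing the
        # final safetensors files (one per index) directly into checkpoint_id.
        consolidate_safetensors_files_on_every_rank(
            input_dir=os.path.join(checkpoint_id, "sharded"),
            output_dir=checkpoint_id,
            fqn_to_index_mapping=fqn_to_index_mapping,
            num_threads=5,
        )
    else:
        # No mapping means all fqns land in a single output file, so the
        # writer's built-in single-rank consolidation is sufficient.
        writer = HuggingFaceStorageWriter(
            path=checkpoint_id,
            save_distributed=True,
            enable_consolidation=True,
        )
        dcp.save(state_dict, storage_writer=writer)
```

The state_dict_adapter.py change supports the same flow: `weight_map` values such as `"model-00001-of-00042.safetensors"` yield string indices from `re.search`, and converting them with `int(...)` keeps `fqn_to_index_mapping` in the integer form the consolidation path appears to expect.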