From a7f5a67223b0e7d45e57d842f21186cd99fadd14 Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Thu, 14 Aug 2025 07:20:38 -0700
Subject: [PATCH 1/7] consolidate on every rank

---
 torchtitan/components/checkpoint.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 478062e8e1..f6c92f528e 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -23,6 +23,7 @@
     HuggingFaceStorageReader,
     HuggingFaceStorageWriter,
 )
+from torch.distributed.checkpoint._consolidate_hf_safetensors import consolidate_safetensors_files_on_every_rank
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -363,11 +364,9 @@ def dcp_save(
             fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
 
             storage_writer = HuggingFaceStorageWriter(
-                path=checkpoint_id,
+                path=os.path.join(checkpoint_id, "sharded"),
                 save_distributed=True,
                 fqn_to_index_mapping=fqn_to_index_mapping,
-                enable_consolidation=True,
-                thread_count_consolidation=5,
             )
 
         else:
@@ -396,6 +395,9 @@ def dcp_save(
             checkpoint_id=checkpoint_save_id,
         )
 
+        if to_hf:
+            consolidate_safetensors_files_on_every_rank(input_path=os.path.join(checkpoint_id, "sharded"), output_path=checkpoint_id, fqn_to_index_mapping=fqn_to_index_mapping, num_threads=5)
+
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
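
[Editor's note] Patch 1 replaces the writer's built-in consolidation (rank 0
merging everything via enable_consolidation=True) with a two-step flow: every
rank first writes its shards under <checkpoint_id>/sharded, then all ranks
cooperate to merge them into the final safetensors files. A minimal sketch of
that flow, written with the keyword names the series settles on in patch 2
(patch 1 still passes input_path/output_path, which is one of the bugs patch 2
fixes). _consolidate_hf_safetensors is a private DCP module, so treat this as
illustrative rather than a stable API:

    import os

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import HuggingFaceStorageWriter
    from torch.distributed.checkpoint._consolidate_hf_safetensors import (
        consolidate_safetensors_files_on_every_rank,
    )

    def save_hf_checkpoint(state_dict, checkpoint_id, fqn_to_index_mapping):
        # Assumes torch.distributed is already initialized (e.g. via torchrun).
        # Step 1: each rank streams its shards into a staging directory.
        writer = HuggingFaceStorageWriter(
            path=os.path.join(checkpoint_id, "sharded"),
            save_distributed=True,
            fqn_to_index_mapping=fqn_to_index_mapping,
        )
        dcp.save(state_dict, storage_writer=writer)
        # Step 2: all ranks split the merge work instead of rank 0 doing it all.
        consolidate_safetensors_files_on_every_rank(
            input_dir=os.path.join(checkpoint_id, "sharded"),
            output_dir=checkpoint_id,
            fqn_to_index_mapping=fqn_to_index_mapping,
            num_threads=5,
        )
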
From a29f1ea7865bee4c3659b7ba9d81a8d8c8572076 Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Fri, 22 Aug 2025 07:39:29 -0700
Subject: [PATCH 2/7] it works

---
 run_train.sh                                          | 2 +-
 torchtitan/components/checkpoint.py                   | 6 +++++-
 torchtitan/models/llama3/train_configs/llama3_8b.toml | 7 ++++---
 torchtitan/protocols/state_dict_adapter.py            | 2 +-
 4 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/run_train.sh b/run_train.sh
index 01dddd0abd..822a163b2a 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -11,7 +11,7 @@ set -ex
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+export LOG_RANK=${LOG_RANK:-0,1,2,3,4}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index f6c92f528e..56b4862a1e 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -367,6 +367,8 @@ def dcp_save(
                 path=os.path.join(checkpoint_id, "sharded"),
                 save_distributed=True,
                 fqn_to_index_mapping=fqn_to_index_mapping,
+                # enable_consolidation=True,
+                # thread_count_consolidation=5,
             )
 
         else:
@@ -389,14 +391,16 @@ def dcp_save(
                 async_stager=self.stager,
             )
         else:
+            start = time.monotonic()
             ret = dcp.save(
                 state_dict,
                 storage_writer=storage_writer,
                 checkpoint_id=checkpoint_save_id,
             )
+            print("Time to save: ", time.monotonic() - start, " seconds")
 
         if to_hf:
-            consolidate_safetensors_files_on_every_rank(input_path=os.path.join(checkpoint_id, "sharded"), output_path=checkpoint_id, fqn_to_index_mapping=fqn_to_index_mapping, num_threads=5)
+            consolidate_safetensors_files_on_every_rank(input_dir=os.path.join(checkpoint_id, "sharded"), output_dir=checkpoint_id, fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping, num_threads=5)
 
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")

diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index f3c2931a55..bfe8732102 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ warmup_steps = 200 # lr scheduler warm up
 local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 10
 compile = false
 dataset = "c4"
 
@@ -45,12 +45,13 @@ pipeline_parallel_degree = 1
 context_parallel_degree = 1
 
 [checkpoint]
-enable_checkpoint = false
-folder = "checkpoint"
+enable_checkpoint = true
+folder = "hf__checkpoint"
 interval = 500
 last_save_model_only = true
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+last_save_in_hf = true
 
 [activation_checkpoint]
 mode = "selective" # ["none", "selective", "full"]

diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py
index ce03d732d6..c371dcfb42 100644
--- a/torchtitan/protocols/state_dict_adapter.py
+++ b/torchtitan/protocols/state_dict_adapter.py
@@ -75,6 +75,6 @@ def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None):
             self.fqn_to_index_mapping = {}
             for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items():
                 indx = re.search(r"\d+", raw_indx).group(0)
-                self.fqn_to_index_mapping[hf_key] = indx
+                self.fqn_to_index_mapping[hf_key] = int(indx)
         else:
             self.fqn_to_index_mapping = None
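
[Editor's note] Besides debugging scaffolding (log fan-out to ranks 0-4, a save
timer, a 10-step debug config), patch 2 fixes two real bugs: the consolidation
helper takes input_dir/output_dir rather than input_path/output_path, and the
shard indices parsed from the HF index file must be ints, not strings. A sketch
of how state_dict_adapter.py builds the mapping after this patch; index_path is
a hypothetical argument here, and the JSON layout is the standard HF
safetensors index format:

    import json
    import re

    def build_fqn_to_index_mapping(index_path: str) -> dict[str, int]:
        # e.g. index_path = "./assets/hf/Llama-3.1-8B/model.safetensors.index.json"
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        mapping = {}
        for hf_key, filename in weight_map.items():
            # "model-00001-of-00004.safetensors" -> "00001" -> 1
            indx = re.search(r"\d+", filename).group(0)
            # as a str, "00001" breaks the file-index arithmetic downstream
            mapping[hf_key] = int(indx)
        return mapping
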
From 6f80fcb1ae68b3088eb5c182f879a2d172c219c7 Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Fri, 22 Aug 2025 07:59:49 -0700
Subject: [PATCH 3/7] clean ups

---
 run_train.sh                                       |  6 +-----
 torchtitan/components/checkpoint.py                | 14 +++++++++-----
 .../models/llama3/train_configs/llama3_8b.toml     | 15 +++------------
 3 files changed, 13 insertions(+), 22 deletions(-)

diff --git a/run_train.sh b/run_train.sh
index 822a163b2a..1592e56cc3 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -1,21 +1,17 @@
 #!/usr/bin/bash
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 set -ex
-
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0,1,2,3,4}
+export LOG_RANK=${LOG_RANK:-0}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
-
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 56b4862a1e..b0a1caa318 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -23,7 +23,9 @@
     HuggingFaceStorageReader,
     HuggingFaceStorageWriter,
 )
-from torch.distributed.checkpoint._consolidate_hf_safetensors import consolidate_safetensors_files_on_every_rank
+from torch.distributed.checkpoint._consolidate_hf_safetensors import (
+    consolidate_safetensors_files_on_every_rank,
+)
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -367,8 +369,6 @@ def dcp_save(
                 path=os.path.join(checkpoint_id, "sharded"),
                 save_distributed=True,
                 fqn_to_index_mapping=fqn_to_index_mapping,
-                # enable_consolidation=True,
-                # thread_count_consolidation=5,
             )
 
         else:
@@ -397,10 +397,14 @@ def dcp_save(
                 storage_writer=storage_writer,
                 checkpoint_id=checkpoint_save_id,
             )
-            print("Time to save: ", time.monotonic() - start, " seconds")
 
         if to_hf:
-            consolidate_safetensors_files_on_every_rank(input_dir=os.path.join(checkpoint_id, "sharded"), output_dir=checkpoint_id, fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping, num_threads=5)
+            consolidate_safetensors_files_on_every_rank(
+                input_dir=os.path.join(checkpoint_id, "sharded"),
+                output_dir=checkpoint_id,
+                fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
+                num_threads=5,
+            )
 
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")

diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index bfe8732102..597e69cb72 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -1,39 +1,32 @@
 # torchtitan Config.toml
 # NOTE: this toml config is a preset for 64 A100 GPUs.
-
 [job]
 dump_folder = "./outputs"
 description = "Llama 3 8B training"
-
 [profiling]
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
-
 [metrics]
 log_freq = 10
 enable_tensorboard = true
 save_tb_folder = "tb"
-
 [model]
 name = "llama3"
 flavor = "8B"
 hf_assets_path = "./assets/hf/Llama-3.1-8B"
 # converters = ["float8"]
-
 [optimizer]
 name = "AdamW"
 lr = 3e-4
 eps = 1e-8
-
 [lr_scheduler]
 warmup_steps = 200 # lr scheduler warm up
-
 [training]
 local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 1000
 compile = false
 dataset = "c4"
 
@@ -45,18 +38,16 @@ pipeline_parallel_degree = 1
 context_parallel_degree = 1
 
 [checkpoint]
-enable_checkpoint = true
-folder = "hf__checkpoint"
+enable_checkpoint = false
+folder = "checkpoint"
 interval = 500
 last_save_model_only = true
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
-last_save_in_hf = true
 
 [activation_checkpoint]
 mode = "selective" # ["none", "selective", "full"]
 selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
-
 [float8]
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
From 85e8453b58508bf4fffdf0ba49c777f87471a82f Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Fri, 22 Aug 2025 08:02:11 -0700
Subject: [PATCH 4/7] clean ups

---
 run_train.sh                                          | 6 +++++-
 torchtitan/components/checkpoint.py                   | 1 -
 torchtitan/models/llama3/train_configs/llama3_8b.toml | 8 ++++++++
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/run_train.sh b/run_train.sh
index 1592e56cc3..0f9d1829b7 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -1,9 +1,12 @@
 #!/usr/bin/bash
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 set -ex
+
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
@@ -12,7 +15,8 @@ export LOG_RANK=${LOG_RANK:-0}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
-PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+
+PYTORCH_ALLOC_CONF="expandable_segments:True" \
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index b0a1caa318..2f9e04b09a 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -391,7 +391,6 @@ def dcp_save(
                 async_stager=self.stager,
             )
         else:
-            start = time.monotonic()
             ret = dcp.save(
                 state_dict,
                 storage_writer=storage_writer,
                 checkpoint_id=checkpoint_save_id,

diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index 597e69cb72..f3c2931a55 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -1,27 +1,34 @@
 # torchtitan Config.toml
 # NOTE: this toml config is a preset for 64 A100 GPUs.
+
 [job]
 dump_folder = "./outputs"
 description = "Llama 3 8B training"
+
 [profiling]
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
+
 [metrics]
 log_freq = 10
 enable_tensorboard = true
 save_tb_folder = "tb"
+
 [model]
 name = "llama3"
 flavor = "8B"
 hf_assets_path = "./assets/hf/Llama-3.1-8B"
 # converters = ["float8"]
+
 [optimizer]
 name = "AdamW"
 lr = 3e-4
 eps = 1e-8
+
 [lr_scheduler]
 warmup_steps = 200 # lr scheduler warm up
+
 [training]
 local_batch_size = 1
 seq_len = 8192
@@ -48,6 +55,7 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 [activation_checkpoint]
 mode = "selective" # ["none", "selective", "full"]
 selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
+
 [float8]
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
From 08b6be8cd95b3f62faf596335cd6f325308cbd27 Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Fri, 22 Aug 2025 08:03:00 -0700
Subject: [PATCH 5/7] clean ups

---
 run_train.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/run_train.sh b/run_train.sh
index 0f9d1829b7..01dddd0abd 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -16,7 +16,7 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
 
 
-PYTORCH_ALLOC_CONF="expandable_segments:True" \
+PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \

From 58c51a9805789b7494482624eff9a60e83004891 Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Fri, 22 Aug 2025 08:50:43 -0700
Subject: [PATCH 6/7] no fqn mapping case

---
 torchtitan/components/checkpoint.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 2f9e04b09a..be4ac7508b 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -369,8 +369,12 @@ def dcp_save(
                 path=os.path.join(checkpoint_id, "sharded"),
                 save_distributed=True,
                 fqn_to_index_mapping=fqn_to_index_mapping,
+                # the reason for only enabling consolidation if there is
+                # no mapping is because no mapping implies that we save all fqns
+                # to one file. This means we only need one rank to consolidate.
+                # Otherwise we should use consolidate_safetensors_files_on_every_rank
+                enable_consolidation=fqn_to_index_mapping is None,
             )
-
         else:
             checkpoint_save_id = checkpoint_id
 
@@ -397,7 +401,7 @@ def dcp_save(
             checkpoint_id=checkpoint_save_id,
         )
 
-        if to_hf:
+        if to_hf and self.sd_adapter.fqn_to_index_mapping:
             consolidate_safetensors_files_on_every_rank(
                 input_dir=os.path.join(checkpoint_id, "sharded"),
                 output_dir=checkpoint_id,
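
[Editor's note] Patch 6 adds the rule that patch 7 then turns into an explicit
branch: no fqn_to_index_mapping means every fqn is written to a single output
file, so the writer's built-in single-rank consolidation suffices; with a
mapping, the output spans several files and the merge is fanned out to every
rank after the save. The end state, condensed into a sketch (writer_for is a
hypothetical helper name, not code from the PR):

    import os

    from torch.distributed.checkpoint import HuggingFaceStorageWriter

    def writer_for(checkpoint_id: str, fqn_to_index_mapping: dict | None):
        if fqn_to_index_mapping:
            # Multi-file output: stage shards under sharded/ and let every
            # rank take a slice of the merge afterwards (see the sketch
            # after patch 1).
            return HuggingFaceStorageWriter(
                path=os.path.join(checkpoint_id, "sharded"),
                save_distributed=True,
                fqn_to_index_mapping=fqn_to_index_mapping,
                enable_consolidation=False,
            )
        # Single-file output: one rank consolidating inside the writer is
        # enough, and the result can be written directly to checkpoint_id.
        return HuggingFaceStorageWriter(
            path=checkpoint_id,
            save_distributed=True,
            enable_consolidation=True,
        )
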
From 6eec347ea61f14206f280bf4aab0e604c2d07b0f Mon Sep 17 00:00:00 2001
From: ankitageorge
Date: Fri, 22 Aug 2025 10:33:10 -0700
Subject: [PATCH 7/7] test should pass

---
 torchtitan/components/checkpoint.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 573839f429..2c374c4cbe 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -355,17 +355,24 @@ def dcp_save(
             state_dict = self.sd_adapter.to_hf(state_dict)
 
             fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
-
-            storage_writer = HuggingFaceStorageWriter(
-                path=os.path.join(checkpoint_id, "sharded"),
-                save_distributed=True,
-                fqn_to_index_mapping=fqn_to_index_mapping,
+            if fqn_to_index_mapping:
+                storage_writer = HuggingFaceStorageWriter(
+                    path=os.path.join(checkpoint_id, "sharded"),
+                    save_distributed=True,
+                    fqn_to_index_mapping=fqn_to_index_mapping,
+                    enable_consolidation=False,
+                )
+            else:
                 # the reason for only enabling consolidation if there is
                 # no mapping is because no mapping implies that we save all fqns
                 # to one file. This means we only need one rank to consolidate.
                 # Otherwise we should use consolidate_safetensors_files_on_every_rank
-                enable_consolidation=fqn_to_index_mapping is None,
-            )
+                storage_writer = HuggingFaceStorageWriter(
+                    path=checkpoint_id,
+                    save_distributed=True,
+                    enable_consolidation=True,
+                )
+
         else:
             checkpoint_save_id = checkpoint_id
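
[Editor's note] After patch 7, the post-save call to
consolidate_safetensors_files_on_every_rank stays guarded by
"if to_hf and self.sd_adapter.fqn_to_index_mapping:", matching the branch
above. The point of consolidating on every rank is parallelism: with N output
files and R ranks, each rank merges roughly N/R files (each using num_threads
workers) instead of rank 0 merging all N. A toy round-robin partition, for
intuition only; files_for_rank is hypothetical and DCP's actual work
assignment may differ:

    def files_for_rank(rank: int, world_size: int, num_files: int) -> list[int]:
        # 1-based output-file indices handled by this rank, round-robin:
        # 8 files over 4 ranks -> 2 files per rank; 4 files over 8 ranks
        # -> ranks 4..7 sit idle during consolidation.
        return [i for i in range(1, num_files + 1) if (i - 1) % world_size == rank]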