Enable multi rank safetensor consolidation #1625

Open

wants to merge 8 commits into main

Conversation

ankitageorge (Contributor)

On saves, we were relying on rank-0 to consolidate the sharded safetensors files, since consolidation happened in the DCP finish step, which runs only on rank-0. We can instead split this work across all available ranks, speeding up the overall save. For the 8B model on my server with 8 ranks, the save without consolidation took ~40s, and consolidation added an extra ~20s; with this change, the consolidation step drops to ~10s. For larger models, with more files split across more ranks, I would expect larger gains.
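
For illustration, here is a minimal sketch of the idea, assuming an already-initialized process group; consolidate_one_file is a hypothetical helper standing in for the real consolidation logic, and this is not the exact code in the PR:

import torch.distributed as dist

def consolidate_on_every_rank(all_files, output_dir, consolidate_one_file):
    # Each rank takes a round-robin slice of the sharded safetensors files.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for path in all_files[rank::world_size]:
        consolidate_one_file(path, output_dir)  # hypothetical helper, not a DCP API
    # No rank should report the save as finished before every slice is done.
    dist.barrier()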

(titan) [[email protected] /data/users/ankitageorge/torchtitan (multi-rank-safetensor-consolidation)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0,1,2,3,4
+ LOG_RANK=0,1,2,3,4
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2,3,4 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] 
W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] *****************************************
W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] *****************************************
[rank3]:[titan] 2025-08-22 07:52:40,492 - root - INFO - Starting job: Llama 3 8B training
[rank3]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.)
[rank3]:  torch._C._cuda_init()
[rank2]:[titan] 2025-08-22 07:52:40,768 - root - INFO - Starting job: Llama 3 8B training
[rank2]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.)
[rank2]:  torch._C._cuda_init()
[rank0]:[titan] 2025-08-22 07:52:40,796 - root - INFO - Starting job: Llama 3 8B training
[rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.)
[rank0]:  torch._C._cuda_init()
[rank1]:[titan] 2025-08-22 07:52:40,788 - root - INFO - Starting job: Llama 3 8B training
[rank1]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.)
[rank1]:  torch._C._cuda_init()
[rank4]:[titan] 2025-08-22 07:52:40,728 - root - INFO - Starting job: Llama 3 8B training
[rank4]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.)
[rank4]:  torch._C._cuda_init()
[rank3]:[titan] 2025-08-22 07:52:42,690 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank3]:[titan] 2025-08-22 07:52:42,692 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank3]:[titan] 2025-08-22 07:52:42,698 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank4]:[titan] 2025-08-22 07:52:44,081 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank4]:[titan] 2025-08-22 07:52:44,084 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank4]:[titan] 2025-08-22 07:52:44,090 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank1]:[titan] 2025-08-22 07:52:44,137 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank1]:[titan] 2025-08-22 07:52:44,139 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank1]:[titan] 2025-08-22 07:52:44,145 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:[titan] 2025-08-22 07:52:44,215 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-08-22 07:52:44,217 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:NCCL version 2.27.5+cuda12.9
[rank2]:[titan] 2025-08-22 07:52:44,238 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank2]:[titan] 2025-08-22 07:52:44,240 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank2]:[titan] 2025-08-22 07:52:44,246 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:[titan] 2025-08-22 07:52:44,223 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank1]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json
[rank2]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json
[rank4]:[titan] 2025-08-22 07:52:50,421 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json
[rank3]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json
[rank1]:[titan] 2025-08-22 07:52:50,696 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-08-22 07:52:50,718 - root - INFO - Preparing c4 dataset from allenai/c4
[rank4]:[titan] 2025-08-22 07:52:50,696 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-08-22 07:52:50,718 - root - INFO - Preparing c4 dataset from allenai/c4
[rank3]:[titan] 2025-08-22 07:52:50,717 - root - INFO - Preparing c4 dataset from allenai/c4
[rank4]:[titan] 2025-08-22 07:52:56,570 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank4]:[titan] 2025-08-22 07:52:56,707 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank4]:[titan] 2025-08-22 07:52:56,723 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank4]:[titan] 2025-08-22 07:52:56,724 - root - INFO - Applied selective activation checkpointing to the model
[rank4]:[titan] 2025-08-22 07:52:56,802 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-08-22 07:52:56,923 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank1]:[titan] 2025-08-22 07:52:56,872 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank3]:[titan] 2025-08-22 07:52:56,875 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank1]:[titan] 2025-08-22 07:52:57,012 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank1]:[titan] 2025-08-22 07:52:57,028 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank1]:[titan] 2025-08-22 07:52:57,029 - root - INFO - Applied selective activation checkpointing to the model
[rank2]:[titan] 2025-08-22 07:52:57,137 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-08-22 07:52:57,066 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250822-0752
[rank0]:[titan] 2025-08-22 07:52:57,067 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-08-22 07:52:57,081 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-08-22 07:52:57,082 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-08-22 07:52:57,104 - root - INFO - Applied FSDP to the model
[rank3]:[titan] 2025-08-22 07:52:57,047 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank3]:[titan] 2025-08-22 07:52:57,065 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank3]:[titan] 2025-08-22 07:52:57,066 - root - INFO - Applied selective activation checkpointing to the model
[rank4]:[titan] 2025-08-22 07:52:57,048 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank4]:[titan] 2025-08-22 07:52:57,049 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank4]:[titan] 2025-08-22 07:52:57,050 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10.
[rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint
[rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Mixed precision training is handled by fully_shard
[rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200)
[rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Training starts at step 1
[rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-08-22 07:52:57,159 - root - INFO - Applied FSDP to the model
[rank3]:[titan] 2025-08-22 07:52:57,160 - root - INFO - Applied FSDP to the model
[rank2]:[titan] 2025-08-22 07:52:57,280 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank2]:[titan] 2025-08-22 07:52:57,297 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank2]:[titan] 2025-08-22 07:52:57,298 - root - INFO - Applied selective activation checkpointing to the model
[rank2]:[titan] 2025-08-22 07:52:57,374 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-08-22 07:52:57,467 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-08-22 07:52:57,467 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:[titan] 2025-08-22 07:52:57,469 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10.
[rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint
[rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200)
[rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Training starts at step 1
[rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank1]:[titan] 2025-08-22 07:52:57,470 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank1]:[titan] 2025-08-22 07:52:57,470 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank1]:[titan] 2025-08-22 07:52:57,472 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10.
[rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint
[rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Mixed precision training is handled by fully_shard
[rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200)
[rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Training starts at step 1
[rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank3]:[titan] 2025-08-22 07:52:57,459 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank3]:[titan] 2025-08-22 07:52:57,459 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank3]:[titan] 2025-08-22 07:52:57,461 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10.
[rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint
[rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Mixed precision training is handled by fully_shard
[rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200)
[rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Training starts at step 1
[rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank2]:[titan] 2025-08-22 07:52:57,633 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank2]:[titan] 2025-08-22 07:52:57,633 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank2]:[titan] 2025-08-22 07:52:57,635 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10.
[rank2]:[titan] 2025-08-22 07:52:57,663 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint
[rank2]:[titan] 2025-08-22 07:52:57,663 - root - INFO - Mixed precision training is handled by fully_shard
[rank2]:[titan] 2025-08-22 07:52:57,663 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200)
[rank2]:[titan] 2025-08-22 07:52:57,664 - root - INFO - Training starts at step 1
[rank2]:[titan] 2025-08-22 07:52:57,664 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank1]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step:  1  loss: 12.2541  grad_norm:  4.0170  memory: 39.86GiB(41.96%)  tps: 1,492  tflops: 86.40  mfu: 8.74%
[rank1]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step:  1  loss: 12.2541  grad_norm:  4.0170  memory: 39.86GiB(41.96%)  tps: 1,506  tflops: 87.24  mfu: 8.82%
[rank0]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step:  1  loss: 12.2541  grad_norm:  4.0170  memory: 39.86GiB(41.96%)  tps: 1,569  tflops: 90.84  mfu: 9.19%
[rank2]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank3]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step:  1  loss: 12.2541  grad_norm:  4.0170  memory: 39.86GiB(41.96%)  tps: 1,502  tflops: 86.97  mfu: 8.79%
[rank3]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank4]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step:  1  loss: 12.2541  grad_norm:  4.0170  memory: 39.86GiB(41.96%)  tps: 1,413  tflops: 81.85  mfu: 8.28%
[rank4]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank3]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10  loss:  9.8309  grad_norm:  4.4717  memory: 47.38GiB(49.87%)  tps: 7,031  tflops: 407.19  mfu: 41.17%
[rank3]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank3]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10.
[rank0]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10  loss:  9.8309  grad_norm:  4.4717  memory: 47.38GiB(49.87%)  tps: 7,031  tflops: 407.21  mfu: 41.17%
[rank0]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10.
[rank1]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10  loss:  9.8309  grad_norm:  4.4717  memory: 47.38GiB(49.87%)  tps: 7,031  tflops: 407.19  mfu: 41.17%
[rank1]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank1]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10.
[rank2]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10  loss:  9.8309  grad_norm:  4.4717  memory: 47.38GiB(49.87%)  tps: 7,031  tflops: 407.19  mfu: 41.17%
[rank2]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank2]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10.
[rank4]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10  loss:  9.8309  grad_norm:  4.4717  memory: 47.38GiB(49.87%)  tps: 7,031  tflops: 407.19  mfu: 41.17%
[rank4]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank4]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10.
[rank0]:Time to save 1:  39.62316728616133  seconds
[rank4]:Time to save 1:  39.62515217997134  seconds
[rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
[rank0]:  warnings.warn(  # warn only once
[rank1]:Time to save 1:  39.62614830210805  seconds
[rank2]:Time to save 1:  39.62664035195485  seconds
[rank3]:Time to save 1:  39.6256582220085  seconds
[rank4]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
[rank4]:  warnings.warn(  # warn only once
[rank2]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
[rank2]:  warnings.warn(  # warn only once
[rank3]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
[rank3]:  warnings.warn(  # warn only once
[rank4]:[titan] 2025-08-22 07:54:02,586 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds
[rank4]:[titan] 2025-08-22 07:54:02,587 - root - INFO - Training completed
[rank4]:[titan] 2025-08-22 07:54:02,587 - root - INFO - Destroying the purge thread.
[rank0]:Time to save 2:  49.51617495715618  seconds
[rank4]:Time to save 2:  49.517131249886006  seconds
[rank1]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. 
[rank1]:  warnings.warn(  # warn only once
[rank1]:[titan] 2025-08-22 07:54:02,574 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds
[rank1]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Training completed
[rank1]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Destroying the purge thread.
[rank0]:[titan] 2025-08-22 07:54:02,575 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds
[rank0]:[titan] 2025-08-22 07:54:02,576 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank1]:Time to save 2:  49.51811962015927  seconds
[rank2]:[titan] 2025-08-22 07:54:02,574 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds
[rank2]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Training completed
[rank2]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Destroying the purge thread.
[rank2]:Time to save 2:  49.51844248594716  seconds
[rank3]:Time to save 2:  49.51776701770723  seconds
[rank3]:[titan] 2025-08-22 07:54:02,575 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds
[rank3]:[titan] 2025-08-22 07:54:02,576 - root - INFO - Training completed
[rank3]:[titan] 2025-08-22 07:54:02,576 - root - INFO - Destroying the purge thread.
[rank0]:[titan] 2025-08-22 07:54:04,576 - root - INFO - Training completed
[rank0]:[titan] 2025-08-22 07:54:04,577 - root - INFO - Destroying the purge thread.
[rank3]:[titan] 2025-08-22 07:54:05,224 - root - INFO - Process group destroyed
[rank4]:[titan] 2025-08-22 07:54:05,272 - root - INFO - Process group destroyed
[rank2]:[titan] 2025-08-22 07:54:05,272 - root - INFO - Process group destroyed
[rank1]:[titan] 2025-08-22 07:54:05,388 - root - INFO - Process group destroyed
[rank0]:[titan] 2025-08-22 07:54:05,924 - root - INFO - Process group destroyed
[W822 07:54:09.143827219 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Aug 22, 2025
ankitageorge marked this pull request as ready for review on August 22, 2025 at 17:58
fegin (Contributor) commented Aug 22, 2025

This is good. One general question: can the work be split across all ranks on different nodes, or does this require the assumption that the underlying file system is distributed?

output_dir=checkpoint_id,
fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
num_threads=5,
)
Contributor

Does this API take PG as an argument?

Contributor Author

Contributor

It's better if that's the case. I don't know what post-training will look like or how people are going to split the nodes, but if not all ranks join the checkpoint save and load, dist.get_world_size() is not correct. I'm not familiar with the post-training use case, though. cc @tianyu-l

Contributor

@fegin
I don't know either. cc: @allenwang28 if you know.

Do dist.get_rank() and dist.get_world_size() rely on the NCCL PG? It seems a bit strange and unnecessary for CPU-side consolidation to rely on GPU info.
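
For illustration only (not code from this PR, nor a claim about what it implements): CPU-side consolidation could coordinate over a Gloo process group rather than the default NCCL one, so the file work would not depend on GPU state. dist.new_group(backend="gloo") is standard torch.distributed; the actual file assignment is left abstract here.

import torch.distributed as dist

# CPU-only (Gloo) group spanning all ranks; use it for rank / world-size
# queries and the final barrier so consolidation never touches the NCCL PG.
cpu_group = dist.new_group(backend="gloo")
rank = dist.get_rank(cpu_group)
world_size = dist.get_world_size(cpu_group)
# ... assign safetensors files to this rank using rank and world_size ...
dist.barrier(group=cpu_group)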

ankitageorge (Contributor Author)

This is good. One general question: can the work be split across all ranks on different nodes, or does this require the assumption that the underlying file system is distributed?

@fegin If multiple nodes are used, then we have to assume that some distributed filesystem is being used. This is true for both single-rank and multi-rank consolidation.
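
As an aside, a hedged sketch (not part of this PR, and check_shared_filesystem is a hypothetical helper) of how that assumption could be sanity-checked at runtime: each rank drops a marker file in the checkpoint directory and rank 0 verifies it can see all of them, failing fast if the path is node-local rather than shared.

import os
import torch.distributed as dist

def check_shared_filesystem(checkpoint_dir):
    # Hypothetical helper, not in torchtitan: verify every rank's marker file
    # is visible from rank 0, i.e. the directory is on a shared filesystem.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    open(os.path.join(checkpoint_dir, f".fs_check_rank{rank}"), "w").close()
    dist.barrier()
    if rank == 0:
        missing = [
            r for r in range(world_size)
            if not os.path.exists(os.path.join(checkpoint_dir, f".fs_check_rank{r}"))
        ]
        if missing:
            raise RuntimeError(f"Not a shared filesystem; missing markers from ranks {missing}")
    dist.barrier()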
