Enable multi rank safetensor consolidation #1625
base: main
Conversation
This is good. One general question: can the work be split across all ranks on different nodes? Or does this require the assumption that the underlying file system is distributed?
    output_dir=checkpoint_id,
    fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
    num_threads=5,
)
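For context, a minimal sketch of what `fqn_to_index_mapping` (visible in the diff above) plausibly encodes: a map from each parameter FQN to the index of the consolidated safetensors file it should land in. The example FQNs and the helper below are illustrative assumptions, not the actual torchtitan or PyTorch code; only the HF `model-0000x-of-0000N.safetensors` naming convention is standard.

```python
from collections import defaultdict

# Illustrative mapping: each FQN is assigned to a 1-based output-file index.
fqn_to_index_mapping = {
    "model.embed_tokens.weight": 1,
    "model.layers.0.self_attn.q_proj.weight": 1,
    "model.layers.31.mlp.down_proj.weight": 2,
    "lm_head.weight": 2,
}

def group_fqns_by_output_file(mapping: dict[str, int]) -> dict[str, list[str]]:
    """Group FQNs by the consolidated safetensors file they belong to."""
    num_files = max(mapping.values())
    groups: dict[str, list[str]] = defaultdict(list)
    for fqn, index in mapping.items():
        filename = f"model-{index:05d}-of-{num_files:05d}.safetensors"
        groups[filename].append(fqn)
    return dict(groups)

print(group_fqns_by_output_file(fqn_to_index_mapping))
# {'model-00001-of-00002.safetensors': [...], 'model-00002-of-00002.safetensors': [...]}
```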
Does this API take PG as an argument?
No, it just infers it: https://github.com/pytorch/pytorch/blob/e20f6d798606f3245686e950c43635bbe526232d/torch/distributed/checkpoint/_consolidate_hf_safetensors.py#L650. Do you think it should?
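A minimal sketch of the behavior under discussion, assuming the helper derives its coordinates from the default process group rather than taking a `process_group` argument. The function name and the round-robin split are illustrative, not the linked PyTorch code.

```python
import torch.distributed as dist

def my_share_of_files(files: list[str]) -> list[str]:
    # Assumes every rank in the default PG participates in the save.
    # If only a sub-group of ranks saves, get_rank()/get_world_size()
    # over-count, and some output files would never get consolidated.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return [f for i, f in enumerate(sorted(files)) if i % world_size == rank]
```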
It would be better if it did. I don't know what post-training will look like or how people are going to split the nodes, but if not all ranks join the checkpoint save and load, dist.get_world_size() is not correct. I'm not familiar with the post-training use case though. cc @tianyu-l
@fegin I don't know either. cc @allenwang28 if you know.
Do dist.get_rank() and dist.get_world_size() rely on the NCCL PG? It sounds a bit strange and unnecessary that CPU consolidation relies on GPU info.
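A sketch of the alternative being discussed here, as an assumption rather than the current API: accept an optional process group so the CPU-side consolidation can run over a Gloo group instead of depending on the default (typically NCCL) group.

```python
from typing import Optional
import torch.distributed as dist

def consolidation_coords(pg: Optional[dist.ProcessGroup] = None) -> tuple[int, int]:
    # A CPU-only Gloo group keeps consolidation independent of GPU/NCCL state.
    # Note: new_group() must be called collectively by all ranks.
    if pg is None:
        pg = dist.new_group(backend="gloo")
    return dist.get_rank(group=pg), dist.get_world_size(group=pg)
```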
@fegin If multiple nodes are used, then we have to assume that some distributed filesystem is being used. This is true for both the single-rank and multi-rank consolidation.
On saves, we were relying on rank 0 to consolidate the sharded safetensors files, since consolidation happened in the DCP finish step, which runs only on rank 0. We can instead split this work across all available ranks, speeding up the overall save. For the 8B model, the save without consolidation took ~40s on my server with 8 ranks, and consolidation added another ~20s; this change brings that down to ~10s. For larger models with more files split across more ranks, I would expect larger gains.
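A rough sketch of the before/after flow described above, where `consolidate_my_shard` stands in for whatever per-rank consolidation work this PR distributes; it is a hypothetical callable, not a real API.

```python
import time
import torch.distributed as dist

def save_then_consolidate(consolidate_my_shard, rank_zero_only: bool) -> float:
    """Time the consolidation step under the old and new strategies."""
    start = time.perf_counter()
    if rank_zero_only:
        # Old behavior: rank 0 consolidates every output file in the DCP finish step.
        if dist.get_rank() == 0:
            consolidate_my_shard(all_files=True)
    else:
        # New behavior: every rank consolidates only its slice of the output files.
        consolidate_my_shard(all_files=False)
    dist.barrier()  # the save is complete once every rank has finished
    return time.perf_counter() - start
```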