GPT2-semi-FSDP

This is a JAX implementation of GPT-2 with a naive (almost too naive) version of ZeRO-3, also known as Fully Sharded Data Parallelism (FSDP). See the ZeRO paper for more details. Strictly speaking, this implementation is not FSDP, but it's a step in that direction. In FSDP, the optimizer state, gradients, and parameters are sharded, while activations remain unsharded. Each layer's parameters are all-gathered right before that layer runs in the forward pass and freed immediately afterward, and the same happens during backpropagation. This is a semi-FSDP because instead of gathering the parameters layer by layer, I gather them all at once, run the forward and backward passes, and then free them all.
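As a rough illustration of the gather-everything-then-free idea (this is not the code in this repo: the toy model, `loss_fn`, `train_step`, and the `"shard"` mesh axis below are made-up stand-ins, and the sharded optimizer state is omitted for brevity), a minimal JAX sketch might look like this:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all available devices. Parameters (and the batch) are
# split along it; a fully replicated layout is used for the gathered copies.
mesh = Mesh(np.array(jax.devices()), axis_names=("shard",))
sharded = NamedSharding(mesh, P("shard"))   # each device holds a slice
replicated = NamedSharding(mesh, P())       # full copy on every device

def loss_fn(params, batch):
    # Toy linear model standing in for the GPT-2 forward pass.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def train_step(params, batch, lr=1e-3):
    # Semi-FSDP: gather *all* parameters up front, instead of gathering each
    # layer's parameters right before that layer runs (true FSDP).
    full = jax.tree_util.tree_map(
        lambda p: jax.lax.with_sharding_constraint(p, replicated), params
    )
    loss, grads = jax.value_and_grad(loss_fn)(full, batch)
    # Push the gradients back into the sharded layout; once the backward pass
    # has produced them, the gathered parameter copies are no longer needed.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.with_sharding_constraint(g, sharded), grads
    )
    # Plain SGD update on the sharded parameters (a real run would use a
    # sharded optimizer state here).
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Toy setup: sharded parameters and a sharded batch.
params = jax.device_put({"w": jnp.ones((512, 512))}, sharded)
batch = jax.device_put(
    {"x": jnp.ones((64, 512)), "y": jnp.zeros((64, 512))}, sharded
)
params, loss = train_step(params, batch)
```

The intent is that the fully gathered parameter copies only live for the duration of the forward and backward passes, which is exactly what the memory profile further down is meant to measure.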

To run the training script:

cd GPT2-DDP/gpt2ddp/gpt2ddp
uv run scripts/train.py

To modify the model/training configuration, see gpt2ddp/core/config.py.
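The actual fields live in gpt2ddp/core/config.py and are not reproduced here; purely as a hypothetical illustration (every field name and default below is made up, not read from the repo), a GPT-2 training config typically looks something like:

```python
# Hypothetical illustration only; see gpt2ddp/core/config.py for the real fields.
from dataclasses import dataclass

@dataclass
class GPT2Config:
    # Model size (GPT-2 small defaults).
    vocab_size: int = 50257
    context_length: int = 1024
    n_layer: int = 12
    n_head: int = 12
    d_model: int = 768
    # Training.
    batch_size: int = 8
    learning_rate: float = 3e-4
    num_steps: int = 16
```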

Here's a memory profile of 16 training steps. Compared to my experiments with GPT2-ZeRO2, average memory use is slightly higher, but there is a ~250 MB reduction in max memory use:

[Memory profile of 16 training steps]
