This is a JAX implementation of GPT-2 with a naive (almost too naive) version of ZeRO-3, or Fully Sharded Data Parallelism (FSDP). See the ZeRO paper for more details. Technically, this implementation is not FSDP, but it's a step in that direction. In FSDP, the optimizer state, gradients, and parameters are all sharded across devices, while activations remain unsharded. Parameters are all-gathered just before each layer is used in the forward pass and freed immediately afterward, and the same happens during backpropagation. This is only semi-FSDP: instead of gathering the parameters layer by layer, I gather all of them at once, run the forward and backward passes, and then free them all.
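The semi-FSDP step described above can be sketched in JAX roughly as follows. This is a minimal, hypothetical illustration, not the code in this repo: it swaps GPT-2 for a tiny linear model, flattens all parameters into one vector with `jax.flatten_util.ravel_pytree`, shards that vector (and the momentum-SGD optimizer state) across devices, and uses `jax.pmap` with `jax.lax.all_gather` and `jax.lax.psum_scatter`. The names `train_step`, `loss_fn`, and the `"dev"` axis are assumptions, and forcing 4 simulated CPU devices through `XLA_FLAGS` is only there so the sketch runs on a single machine.

```python
import os
# Assumption: simulate 4 devices on CPU. Only takes effect if set before jax is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

N_DEV = jax.local_device_count()
D = 8
LR, BETA = 0.1, 0.9

def loss_fn(params, x, y):
    # Hypothetical stand-in for GPT-2: a linear model with MSE loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Flatten all parameters into one vector, pad it, and split it into one shard per device.
params = {"w": jnp.zeros((D, 1)), "b": jnp.zeros((1,))}
flat, unravel = ravel_pytree(params)
n_params = flat.size
pad = (-n_params) % N_DEV
shards = jnp.pad(flat, (0, pad)).reshape(N_DEV, -1)  # row i lives on device i
momentum = jnp.zeros_like(shards)                     # optimizer state is sharded the same way

def train_step(p_shard, m_shard, x, y):
    # 1. All-gather every shard into the full (padded) parameter vector.
    full = jax.lax.all_gather(p_shard, "dev", tiled=True)

    # 2. Forward and backward pass on this device's local batch using the full params.
    def flat_loss(f):
        return loss_fn(unravel(f[:n_params]), x, y)
    loss, g_full = jax.value_and_grad(flat_loss)(full)

    # 3. Reduce-scatter: sum gradients across devices; each device keeps only its slice.
    g_shard = jax.lax.psum_scatter(g_full, "dev", tiled=True) / N_DEV

    # 4. Update only the local shard. The gathered full vector is not returned,
    #    so it is freed once the step finishes.
    m_shard = BETA * m_shard + g_shard
    p_shard = p_shard - LR * m_shard
    return p_shard, m_shard, jax.lax.pmean(loss, "dev")

step = jax.pmap(train_step, axis_name="dev")

# Synthetic, noise-free data: one local batch of 32 examples per device.
true_w = jnp.arange(1.0, D + 1.0).reshape(D, 1) / D
x = jax.random.normal(jax.random.PRNGKey(0), (N_DEV, 32, D))
y = x @ true_w + 0.5

losses = []
for _ in range(200):
    shards, momentum, loss = step(shards, momentum, x, y)
    losses.append(float(loss[0]))

# Reassemble the shards only to inspect the result.
final = unravel(shards.reshape(-1)[:n_params])
print(losses[0], losses[-1])
```

Because gradients are reduce-scattered rather than all-reduced, each device only persists its own slice of the parameters, gradients, and optimizer state; the full parameter vector exists only inside `train_step`. True FSDP would instead all-gather per layer inside the forward and backward passes, so only one layer's full weights are live at a time.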
To run the training script:
```bash
cd GPT2-DDP/gpt2ddp/gpt2ddp
uv run scripts/train.py
```
To modify the model or training configuration, see gpt2ddp/core/config.py.
Here's a memory profile of 16 training steps. Compared with my GPT2-ZeRO2 experiments, average memory use is slightly higher, but peak memory use drops by ~250 MB: