Description
From the SigLIP paper, my understanding is that the loss doesn't require any all_gather: it always performs a local b x b computation iteratively, where b is micro_batch_size (see this section from the paper).

So if I can fit, say, a micro_batch_size of 10 on 8 GPUs and then increase the number of GPUs to 16, 32, 64, 128, ..., my memory consumption should remain more or less the same (just like normal DDP). Put simply, we should (in theory) be able to scale world_batch_size, i.e. the number of nodes, while keeping micro_batch_size constant, right?
But what I've observed is that memory consumption spikes as I increase world_batch_size (number of nodes), and I need to lower my micro_batch_size (to as low as 2 for 128 devices).
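To make the expectation concrete, here is a rough back-of-the-envelope sketch of how many logit entries a single device would hold at once under the two schemes. The function name is made up for illustration, and it only counts the logits matrix, not activations, gradients, or communication buffers, so it is not a full memory model.

```python
def logits_elements_per_device(micro_batch_size: int, world_size: int, gathered: bool) -> int:
    """Number of logit entries one device materializes at a time.

    gathered=True  -> features from all devices are all_gathered: b x (W * b)
    gathered=False -> chunked, one neighbor chunk at a time:      b x b
    """
    b = micro_batch_size
    if gathered:
        return b * b * world_size
    return b * b


if __name__ == "__main__":
    for world_size in (8, 16, 32, 64, 128):
        gathered = logits_elements_per_device(10, world_size, gathered=True)
        chunked = logits_elements_per_device(10, world_size, gathered=False)
        print(f"world_size={world_size}: gathered={gathered}, chunked={chunked}")
```

If the chunked picture is right, the second number stays flat as world_size grows, which is why I expected memory to stay roughly constant.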
- Is my understanding of SigLIP correct, i.e. that keeping micro_batch_size constant allows you to scale world_batch_size? It could also be that they use some sort of TPU trick (I don't have much insight into that).
- I have only skimmed through the SigLIP implementation here, but it could also be that the memory isn't freed while swapping neighbors, and that's why the consumption accumulates?
I could be totally wrong on both of these, so I'd be glad to hear from anyone who has tried scaling world_batch_size and seen similar results, so I can validate my hypothesis.
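For reference, this is how I picture the chunked computation, written as a single-process simulation in plain Python. It is not the open_clip code: all names are hypothetical, the temperature and bias values are arbitrary, and the loss is left as an unnormalized sum. Each simulated device keeps its own image shard and sees one text shard per ring step, so only one b x b block is processed at a time.

```python
import math
import random


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def block_loss(img, txt, t, bias, positives_on_diagonal):
    """Sum of -log sigmoid(z * (t * <img_i, txt_j> + bias)) over one block."""
    total = 0.0
    for i, x in enumerate(img):
        for j, y in enumerate(txt):
            logit = t * sum(a * c for a, c in zip(x, y)) + bias
            z = 1.0 if (positives_on_diagonal and i == j) else -1.0
            total -= log_sigmoid(z * logit)
    return total


def chunked_loss(img_shards, txt_shards, t, bias):
    """Simulate W devices passing text shards around a ring."""
    world_size = len(img_shards)
    total = 0.0
    for step in range(world_size):
        for rank in range(world_size):
            # Text shard that `rank` holds after `step` ring rotations.
            src = (rank + step) % world_size
            # Only the local pairing (step 0) contains positive pairs.
            total += block_loss(img_shards[rank], txt_shards[src], t, bias,
                                positives_on_diagonal=(step == 0))
    return total


def full_loss(img_shards, txt_shards, t, bias):
    """Reference: one big (W * b) x (W * b) block, as with an all_gather."""
    img = [x for shard in img_shards for x in shard]
    txt = [y for shard in txt_shards for y in shard]
    return block_loss(img, txt, t, bias, positives_on_diagonal=True)


random.seed(0)
W, b, d = 4, 3, 5  # simulated devices, micro batch size, feature dim
img_shards = [[[random.gauss(0, 1) for _ in range(d)] for _ in range(b)] for _ in range(W)]
txt_shards = [[[random.gauss(0, 1) for _ in range(d)] for _ in range(b)] for _ in range(W)]

chunked = chunked_loss(img_shards, txt_shards, t=10.0, bias=-10.0)
full = full_loss(img_shards, txt_shards, t=10.0, bias=-10.0)
print(chunked, full)
```

The two totals should agree up to floating-point rounding, since the chunked version covers every image/text pair exactly once. If that matches what the real implementation does, the received neighbor chunk from the previous step should be freeable before the next one arrives.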