SigLip memory consumption increases as we scale the number of GPUs #942

@khalidsaifullaah

From the SigLIP paper, my understanding is that it doesn't require any all_gather and always performs local b × b computations iteratively, where b is micro_batch_size (see this section from the paper).
[Image: screenshot of the referenced section of the SigLIP paper]
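The chunked computation described above can be sketched in pure Python by simulating D devices in one process. This is a minimal sketch under stated assumptions, not the implementation linked in the issue: it uses the paper's pairwise sigmoid loss with a fixed temperature `t` and bias `bias`, puts positives on the diagonal of each device's own block, and passes text chunks one hop around a ring. All function names are hypothetical. It checks that summing the b × b blocks reproduces the loss computed from the full B × B logits matrix.

```python
import math
import random


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def block_loss(imgs, txts, positives_on_diag, t, bias):
    """Sum of -log sigmoid(z_ij * (t * <x_i, y_j> + bias)) over one block.

    z_ij = +1 only for matching pairs (the diagonal of a device's own block),
    and -1 for every other pair.
    """
    total = 0.0
    for i, x in enumerate(imgs):
        for j, y in enumerate(txts):
            z = 1.0 if (positives_on_diag and i == j) else -1.0
            total -= log_sigmoid(z * (t * dot(x, y) + bias))
    return total


def full_loss(imgs, txts, t, bias):
    """Reference: score every image against every text (one B x B block)."""
    return block_loss(imgs, txts, True, t, bias) / len(imgs)


def chunked_loss(device_imgs, device_txts, t, bias):
    """Simulate D devices that only ever score b x b blocks.

    Step 0 pairs each device's images with its own texts (positives on the
    diagonal). Each later step passes text chunks one hop around a ring, so
    after D steps every device has scored every chunk exactly once.
    """
    num_devices = len(device_imgs)
    global_batch = sum(len(chunk) for chunk in device_imgs)
    held = list(device_txts)  # text chunk currently held by each device
    total = 0.0
    for step in range(num_devices):
        if step > 0:
            # device k now holds the chunk device k + 1 held (mod D)
            held = held[1:] + held[:1]
        for k in range(num_devices):
            total += block_loss(device_imgs[k], held[k], step == 0, t, bias)
    return total / global_batch


def split(xs, micro_batch, num_devices):
    """Cut a global batch into per-device chunks of size micro_batch."""
    return [xs[k * micro_batch:(k + 1) * micro_batch] for k in range(num_devices)]


if __name__ == "__main__":
    random.seed(0)
    dim, micro_batch, num_devices = 4, 3, 4
    global_batch = micro_batch * num_devices
    imgs = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(global_batch)]
    txts = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(global_batch)]
    ref = full_loss(imgs, txts, t=10.0, bias=-10.0)
    got = chunked_loss(split(imgs, micro_batch, num_devices),
                       split(txts, micro_batch, num_devices), t=10.0, bias=-10.0)
    print(ref, got)  # equal up to floating-point rounding
```

Each simulated device scores only one b × b block per step, which is why, in principle, the per-step logits memory does not depend on the number of devices.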

So if I can fit, say, micro_batch_size 10 on 8 GPUs and then increase the number of GPUs to 16, 32, 64, 128, ..., my per-GPU memory consumption should stay more or less the same (just as with ordinary DDP). Put simply, in theory we should be able to scale world_batch_size (or the number of nodes) while keeping micro_batch_size constant, right?
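The expectation above can be made concrete by counting logits per device. This is a minimal sketch, assuming the only quantity of interest is the logits buffer for a single loss step, and comparing SigLIP-style chunking with one common all-gather scheme in which each device scores its b local images against all b·D gathered texts. The function name is hypothetical.

```python
def logits_elements_per_device(micro_batch: int, num_devices: int,
                               all_gather: bool) -> int:
    """Elements in one device's logits buffer for a single loss step."""
    if all_gather:
        # b local images scored against all b * D gathered texts at once
        return micro_batch * micro_batch * num_devices
    # chunked: b local images against one b-sized text chunk per step
    return micro_batch * micro_batch


if __name__ == "__main__":
    for devices in (8, 16, 32, 64, 128):
        print(devices,
              logits_elements_per_device(10, devices, all_gather=False),
              logits_elements_per_device(10, devices, all_gather=True))
```

With micro_batch_size 10, the chunked count stays at 100 elements from 8 to 128 devices, while the all-gather count grows from 800 to 12,800. This covers only one step's logits, not buffers that may be kept alive across steps.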

But what I've observed is that memory consumption spikes as I increase world_batch_size (the number of nodes), and I have to lower micro_batch_size (to as low as 2 for 128 devices).

  1. I'm wondering whether my understanding of SigLIP is correct, i.e., that keeping micro_batch_size constant lets you scale world_batch_size. Or could it be that they rely on some TPU-specific trick (I don't have much insight there)?
  2. I've only skimmed the SigLIP implementation here, but it could also be that memory isn't freed while swapping neighbors, which would make consumption accumulate...?
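Hypothesis 2 can be illustrated with a toy accounting model of one device's memory during the neighbor exchange. It counts only text-feature chunks (b × d elements, where the embedding dimension d is a hypothetical parameter) and b × b logits blocks, ignoring model activations, gradients, and optimizer state. `peak_elements` and `free_each_step` are hypothetical names, and this is an assumption rather than a description of the linked implementation: whether per-step buffers are actually released depends on how the loss and backward pass are structured, since reverse-mode autodiff keeps tensors needed for the backward pass alive until backward runs.

```python
def peak_elements(micro_batch: int, dim: int, num_devices: int,
                  free_each_step: bool) -> int:
    """Peak number of elements one device holds during the ring exchange.

    Hypothetical model: each step allocates a b x b logits block, and every
    step after the first also receives a b x d text chunk from a neighbor.
    """
    live = micro_batch * dim  # the device's own text chunk, always held
    peak = live
    for step in range(num_devices):
        step_alloc = micro_batch * micro_batch  # this step's b x b logits
        if step > 0:
            step_alloc += micro_batch * dim  # chunk received from the neighbor
        live += step_alloc
        peak = max(peak, live)
        if free_each_step:
            # buffers released before the next exchange
            live -= step_alloc
        # otherwise they stay alive, e.g. saved for one backward pass at the end
    return peak


if __name__ == "__main__":
    for devices in (8, 16, 32, 64, 128):
        print(devices,
              peak_elements(10, 768, devices, free_each_step=True),
              peak_elements(10, 768, devices, free_each_step=False))
```

For example, with b = 10 and d = 768 the freed variant peaks at 15,460 elements for any D ≥ 2, while the retained variant grows from 62,240 elements at D = 8 to 995,840 at D = 128.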

I could be wrong on both counts, so I'd be glad to hear whether anyone has tried scaling world_batch_size and seen similar results, so I can validate my hypotheses.