Add gradient checkpointing support for AutoencoderKLWan #11105
base: main
Conversation
Thanks, very nice!
@bot /style
My PR failed the tests. Could you take a look at this link? Failing test case: test_effective_gradient_checkpointing
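For context, tests like this generally build two identically initialized models, enable gradient checkpointing on one, and check that the gradients agree after a backward pass. A minimal sketch under that assumption (the `make_model` helper, the `.sample` return field, and the tolerance are all illustrative, not the actual test code):

```python
import torch

def check_gradient_checkpointing_equivalence(make_model, sample, atol=5e-5):
    # Build two models with identical weights; enable checkpointing on one only.
    model_ref = make_model()
    model_ckpt = make_model()
    model_ckpt.load_state_dict(model_ref.state_dict())
    model_ckpt.enable_gradient_checkpointing()  # diffusers ModelMixin API

    # Same input through both, then backprop a simple scalar loss
    # (assuming a DecoderOutput-style return object with a .sample tensor).
    model_ref(sample).sample.mean().backward()
    model_ckpt(sample).sample.mean().backward()

    # Gradients should agree parameter by parameter, up to the tolerance.
    for (name, p_ref), (_, p_ckpt) in zip(
        model_ref.named_parameters(), model_ckpt.named_parameters()
    ):
        assert torch.allclose(p_ref.grad, p_ckpt.grad, atol=atol), f"grad mismatch: {name}"
```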
@victolee0 I implemented it the same way you did. It works fine for the forward pass, but during the backward pass the cache index gets mixed up and goes out of bounds. I might need to use a dictionary for the cache mechanism instead of an index-based list.
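To illustrate the failure mode with a toy example (not the actual VAE code): the checkpointed function is re-executed during the backward pass to recompute activations, so any in-place bump of a shared index happens a second time and later lookups run past the end of an index-based cache list.

```python
import torch
from torch.utils.checkpoint import checkpoint

feat_idx = [0]  # shared, mutable cache pointer, as in an index-based feat_cache

def block(x):
    feat_idx[0] += 1  # side effect on the shared index
    return x * 2

x = torch.randn(2, 4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()   # block() is re-run here to recompute activations
print(feat_idx[0])   # 2, not 1: the index has drifted past the cached entries
```

Keying the cache by a stable name instead of a running index sidesteps this, since the recomputation simply overwrites the same entry.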
@quickdahuk Code
test case error
@victolee0 I've implemented gradient checkpointing for the decoder; I don't need it for the encoder right now. Training is working fine for me. I implemented checkpointing for each frame rather than inside the decoder operations.
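Roughly, that approach wraps each frame's decoder call in `torch.utils.checkpoint` rather than checkpointing the blocks inside the decoder. A hedged sketch, with illustrative names (`decoder`, `feat_cache`, and `feat_idx` only loosely mirror the real signatures):

```python
import torch
from torch.utils.checkpoint import checkpoint

def decode_frames(decoder, z, feat_cache, feat_idx, gradient_checkpointing=True):
    frames = []
    for t in range(z.shape[2]):  # iterate over the temporal axis of the latent
        frame_latent = z[:, :, t : t + 1]
        if gradient_checkpointing and torch.is_grad_enabled():
            # Checkpoint the whole per-frame decoder call; the reentrancy mode
            # matters here (see the use_reentrant discussion below).
            out = checkpoint(decoder, frame_latent, feat_cache, feat_idx, use_reentrant=False)
        else:
            out = decoder(frame_latent, feat_cache, feat_idx)
        frames.append(out)
    return torch.cat(frames, dim=2)
```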
@quickdahuk
@victolee0 The calculated gradient differs slightly (<= 1.1367e-04). The difference may not be significant, but a much lower difference would be better. @a-r-r-o-w, do you see anything suspicious in our implementation?
@victolee0 I just found that if I use use_reentrant=False, the gradients match perfectly. I've updated the code.
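For reference, that flag belongs to `torch.utils.checkpoint.checkpoint`: the non-reentrant variant recomputes activations during backward through saved-tensor hooks instead of re-entering autograd, and per the comment above it is what made the checkpointed gradients match exactly here. A minimal standalone illustration:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Non-reentrant checkpointing: the wrapped call is recomputed during backward,
# but without the reentrant autograd machinery of the legacy implementation.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```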
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The differing result is definitely suspicious. I would actually prefer a refactor of the VAE so that we don't have to work with cache indexing the way it's done here (and instead have it behave similarly to what's done in CogVideoX and Mochi). If that's not possible without indexing, we could consider removing the cache completely. The speed difference from using the cache vs. not was minimal in my past benchmarks with the CogVideoX VAE, and removing it, given that it's complicated to work with and probably doesn't save much time here, is a tradeoff we could make. cc @hlky @yiyixuxu Would you be able to take a look?
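For illustration, a name-keyed cache in the spirit of that refactor might look roughly like the toy module below (a sketch of the idea only, not the CogVideoX/Mochi or Wan implementation):

```python
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedCausalConv3d(nn.Module):
    """Toy causal temporal conv whose padding frames live in a caller-owned dict."""

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv3d(channels, channels, (kernel_size, 1, 1))

    def forward(self, x: torch.Tensor, cache: Optional[Dict[str, torch.Tensor]] = None, key: str = "conv"):
        if cache is not None and key in cache:
            x = torch.cat([cache[key], x], dim=2)       # reuse frames from the previous chunk
        else:
            x = F.pad(x, (0, 0, 0, 0, self.pad, 0))     # causal zero-padding on the temporal axis
        if cache is not None:
            cache[key] = x[:, :, -self.pad :].detach()  # stable key: recomputation just overwrites it
        return self.conv(x)
```

Because every entry is addressed by module name, re-running a checkpointed chunk overwrites the same slot instead of advancing a shared index.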
quickdahuk's updated code passes the test successfully. Should I create a new commit based on this code? I plan to commit by replacing the
If the test passes just for the decoder, let's add support for just the decoder for now. We can revisit if someone wants encoder support too. I'm not sure anyone on the team has the bandwidth to refactor the implementation at the moment, so it may take a while until that's possible.
I've made the commit.
else:
    return self._decode(x, feat_cache, feat_idx)

def _decode(self, x, in_cache=None, feat_idx=[0], latent=0):
What's this new argument `latent`? It is not used here.
I've removed the `latent` argument as it wasn't being used in this implementation. Thank you for pointing that out.
@a-r-r-o-w @yiyixuxu
thanks!
What does this PR do?
Fixes #11071 (issue)
Before submitting
Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@a-r-r-o-w