add segmentation based (dice) loss #1366
Hi @rdyro. This is my first pass. I'll fix the lint issue (I wasn't aware of this specific one, will keep it in mind) based on what we decide here, if that's okay. Thanks again :)
Thank you very much @aymuos15! Looks pretty good to me. @selamw1 or @rajasekharporeddy, would you have time to thoroughly review this PR? Thanks!
b9b2aba
to
a484018
Compare
The doctests were failing, hence the repeated commits. Thank you very much @rajasekharporeddy for the review.
@vroulet The PR looks good to me.
optax/losses/_segmentation.py
Outdated
apply_softmax,
# True branch: apply sigmoid or softmax based on number of classes
lambda: jax.lax.cond(
predictions.shape[-1] == 1,
This should probably be Python control flow since it's a static condition.
Ah okay. The pylint test was failing here and that's why this was done. Will try to revert this in a way the tests don't fail. Thanks!
Please revert. If pylint still complains about too many conditionals you can try refactoring a little (it just counts if statements and their nesting within a single parent scope).
If it's too annoying still, feel free to add a pylint ignore comment https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
Thank you very much for the info and the suggestions.
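For readers following the thread, here is a minimal sketch of what the reviewers mean by a static condition. It is not the code from this PR: `logits_to_probs` is a hypothetical helper, and treating `apply_softmax` as a static argument under `jax.jit` is an assumption made for the example. Because both the flag and `logits.shape[-1]` are known at trace time, plain `if` statements are enough and `jax.lax.cond` is not needed.

```python
import jax
import jax.numpy as jnp


def logits_to_probs(logits, apply_softmax=True):
  """Hypothetical helper (not from the PR): converts logits to probabilities."""
  # `apply_softmax` is a Python bool and `logits.shape` is a static tuple,
  # so both conditions are resolved while tracing, not at run time.
  if not apply_softmax:
    return logits
  if logits.shape[-1] == 1:
    # Single channel: treat as binary segmentation.
    return jax.nn.sigmoid(logits)
  # Several channels: normalize over the class axis.
  return jax.nn.softmax(logits, axis=-1)


# Marking the flag as static is an assumption of this sketch.
jitted = jax.jit(logits_to_probs, static_argnames="apply_softmax")
binary = jitted(jnp.zeros((2, 4, 4, 1)))
multi = jitted(jnp.zeros((2, 4, 4, 3)))
```

If pylint still flags such a function, a scoped `# pylint: disable=too-many-branches` comment is one option, assuming that is the message being raised.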
optax/losses/_segmentation.py
Outdated
# Convert logits to probabilities using jax.lax.cond to avoid control flow
probs = jax.lax.cond(
apply_softmax,
Likewise, this is a static condition; Python control flow should be enough.
def dice_loss(
predictions: chex.Array,
If these are the logits, I think you should actually call the input logits. A lot of other optax functions have the parameter named like that. Just my opinion. I am gonna try out the code from this PR for my own uses and see how it goes.
That's fair. Thanks! Please update here if you face any issues, and I will update this accordingly. Otherwise I will just make that small change.
Oh I see you have an apply_softmax input parameter. Is there a scenario where you wouldn't want softmax applied? Just curious.
Also I tried out the binary dice loss yesterday and it seems to be working fine.
So I had another look at the code (sorry, it's been kinda long since this started haha).
Given that apply_softmax exists, I think I'll resort to no change at this point.
There are many non-standard scenarios where you would need the logits themselves, and hence I guess it is good practice?
That's fair. I was just wondering what kind of scenarios that would be. No need for changes.
Thanks so much for this and for incorporating all the feedback!
@aymuos15 could you reindent your changes to use indent-width=2? This is needed to pass the internal checks, but I want you to get proper credit in git history.
Also, it does not have to be this PR, but I think we should add Tversky loss as well. It's basically dice loss, but you can weight false positives and false negatives using input parameters.
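To make the Tversky suggestion concrete, here is a rough sketch for the binary case. Nothing here comes from this PR or from optax: the name `soft_tversky_loss`, the `alpha`/`beta`/`smooth` parameters, and the choice of sigmoid on single-channel logits are all assumptions for illustration. With `alpha = beta = 0.5` and `smooth = 0` the index reduces to the Dice coefficient.

```python
import jax
import jax.numpy as jnp


def soft_tversky_loss(logits, targets, alpha=0.5, beta=0.5, smooth=1.0):
  """Illustrative binary soft Tversky loss (hypothetical, not from the PR).

  alpha weights (soft) false positives, beta weights (soft) false negatives.
  """
  probs = jax.nn.sigmoid(logits)
  axes = tuple(range(1, logits.ndim))  # reduce over everything but batch
  tp = jnp.sum(probs * targets, axis=axes)
  fp = jnp.sum(probs * (1.0 - targets), axis=axes)
  fn = jnp.sum((1.0 - probs) * targets, axis=axes)
  tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
  return 1.0 - tversky  # one loss value per batch element


# One sample with a missed positive (a false negative) at index 1.
targets = jnp.array([[1.0, 1.0, 0.0, 0.0]])
logits = jnp.array([[10.0, -10.0, -10.0, -10.0]])
loss_fn_heavy = soft_tversky_loss(logits, targets, alpha=0.5, beta=0.9)
loss_fn_light = soft_tversky_loss(logits, targets, alpha=0.5, beta=0.1)
```

Raising `beta` makes the missed positive cost more, which is the weighting behavior described in the comment above.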
@rdyro Really appreciate the proper credit, thank you very much. Apologies if this is an extremely naive question, but could you please let me know a way to check for this? Looking around, I could not find a way to confirm the indent-width through a linter. I found yapf and autopep8, but they show that the other files in the repo are not completely using indent-width=2, so I'm just a bit confused.
@Logon27 I was planning to do another very small PR just for Dice + CE anyway (I think that should be more than enough, because these two are the most common losses). I will include this as well. Thank you!
I unfortunately don't know how to set up a linter for this. You can try formatting your code with pyink (a Google-compatible code formatter based on Black):
$ pip install pyink
$ pyink --pyink --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 {FILENAME}
This is what I get with the current state of the commit:
pyink --check --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 _segmentation.py
pyink --check --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 _segmentation_test.py
pyink --check --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 __init__.py
All done! ✨ 🍰 ✨
1 file would be left unchanged.
All done! ✨ 🍰 ✨
1 file would be left unchanged.
All done! ✨ 🍰 ✨
1 file would be left unchanged.
f13212d
into
google-deepmind:main
As proposed in #1358, it would be nice to have segmentation-based losses as well. A quick Google search (and a Perplexity search) shows that there really isn't a standard implementation of dice loss in JAX.
The issue mentions DSC + CE as well. However, I am not sure exactly how this should be implemented, given that there is no good reference implementation to go by. My personal suggestion would be to focus on dice first, and then do a separate follow-up PR that adds more variants. I would love some advice on the best way to proceed with this. Happy to include or remove anything.
For a first pass, I tried to cover the basic cases I usually handle, with a reasonable number of tests to cover them.
Just to note, if it means anything to anyone here: there really isn't a standard, widely followed implementation in the torch ecosystem either.
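For context, one common formulation of the soft Dice loss for the binary case can be sketched as below. This is an illustration only, not the implementation merged in this PR: the function name `soft_dice_loss`, the `smooth` term, and the reduction over all non-batch axes are assumptions.

```python
import jax
import jax.numpy as jnp


def soft_dice_loss(logits, targets, smooth=1.0):
  """Illustrative binary soft Dice loss (hypothetical, not the merged code).

  logits: raw scores with a leading batch axis.
  targets: same shape as logits, values in {0, 1}.
  """
  probs = jax.nn.sigmoid(logits)
  axes = tuple(range(1, logits.ndim))  # reduce over everything but batch
  intersection = jnp.sum(probs * targets, axis=axes)
  denom = jnp.sum(probs, axis=axes) + jnp.sum(targets, axis=axes)
  dice = (2.0 * intersection + smooth) / (denom + smooth)
  return 1.0 - dice  # one loss value per batch element


targets = jnp.array([[[1.0, 0.0], [0.0, 1.0]]])
confident = 20.0 * (2.0 * targets - 1.0)  # large logits that match targets
good = soft_dice_loss(confident, targets)
bad = soft_dice_loss(-confident, targets)
```

Here `good` is close to zero because the predictions match the mask, while `bad` is large because every pixel is predicted wrongly.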