aymuos15
Contributor

As proposed in #1358, it would be nice to have segmentation-based losses as well. A quick Google search (and a Perplexity search) suggests that there isn't a standard implementation of Dice loss in JAX.

The issue mentions DSC + CE as well. However, I am not sure exactly how this should be implemented, given that there is no good reference implementation to go by. My suggestion would be to focus on Dice first, and then open a separate follow-up PR to add more variants. I would love some advice on the best way to proceed. Happy to include or remove anything.

For a first pass, I tried to cover the basic cases I usually handle, with a reasonable number of tests for them.

Just to note, in case it is useful to anyone here: there isn't a standard, widely followed implementation in the torch ecosystem either.
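For readers unfamiliar with the loss under discussion, here is a minimal sketch of a soft Dice loss in JAX. It is not the code from this PR: the function name `soft_dice_loss`, the `smooth` parameter, and the channels-last layout are assumptions made for illustration, and the inputs are assumed to already be probabilities.

```python
import jax.numpy as jnp


def soft_dice_loss(probs, targets, smooth=1.0):
  """Hypothetical soft Dice loss; probs and targets have shape [..., C]."""
  # Reduce over every axis except the trailing class axis.
  axes = tuple(range(probs.ndim - 1))
  intersection = jnp.sum(probs * targets, axis=axes)
  denominator = jnp.sum(probs, axis=axes) + jnp.sum(targets, axis=axes)
  # Per-class Dice coefficient; `smooth` avoids division by zero.
  dice = (2.0 * intersection + smooth) / (denominator + smooth)
  # Average the per-class coefficients and turn them into a loss.
  return 1.0 - jnp.mean(dice)


# Four pixels, two classes, every pixel labelled as class 0.
targets = jnp.stack([jnp.ones(4), jnp.zeros(4)], axis=-1)
perfect = soft_dice_loss(targets, targets)
```

A prediction identical to the one-hot target gives a loss of zero, and the loss grows as the overlap between prediction and target shrinks.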

@aymuos15
Contributor Author

aymuos15 commented Jun 27, 2025

Hi @rdyro. This is my first pass. I'll fix the lint issue (I wasn't aware of this specific one and will keep it in mind) based on what we decide here, if that's okay. Thanks again :)

@vroulet
Collaborator

vroulet commented Jun 27, 2025

Thank you very much @aymuos15! Looks pretty good to me.

@selamw1 or @rajasekharporeddy, would you have time to thoroughly review this PR? Thanks!

@aymuos15 force-pushed the seg_losses branch 2 times, most recently from b9b2aba to a484018 on July 11, 2025 04:43
@aymuos15
Contributor Author

The doctests were failing, hence the repeated commits.

Thank you very much @rajasekharporeddy for the review.

@rajasekharporeddy
Collaborator

@vroulet The PR looks good to me.

      apply_softmax,
      # True branch: apply sigmoid or softmax based on number of classes
      lambda: jax.lax.cond(
          predictions.shape[-1] == 1,
Collaborator

This should probably be Python control flow, since it's a static condition.

Contributor Author

@aymuos15 Jul 23, 2025

Ah okay. The pylint test was failing here and that's why this was done. Will try to revert this in a way the tests don't fail. Thanks!

Collaborator

@rdyro Jul 23, 2025

Please revert. If pylint still complains about too many conditionals you can try refactoring a little (it just counts if statements and their nesting within a single parent scope).

If it's too annoying still, feel free to add a pylint ignore comment https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
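To illustrate the review point, here is a minimal sketch of the same branching written as plain Python control flow. It is an assumption-laden illustration, not the PR's final code: the helper name `to_probs` is hypothetical, and `apply_softmax` is assumed to be a Python bool marked static for `jax.jit`. Array shapes are always static under `jit`, so the `if` statements are resolved at trace time and `jax.lax.cond` is not needed.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("apply_softmax",))
def to_probs(logits, apply_softmax=True):
  """Hypothetical helper: converts logits to probabilities."""
  if not apply_softmax:  # Static Python bool, resolved while tracing.
    return logits
  if logits.shape[-1] == 1:  # Shapes are static, so this is a Python bool.
    return jax.nn.sigmoid(logits)  # Single channel: binary case.
  return jax.nn.softmax(logits, axis=-1)  # Multi-class case.


binary = to_probs(jnp.zeros((2, 1)))  # Sigmoid branch.
multi = to_probs(jnp.zeros((2, 4)))  # Softmax branch.
```

Each distinct value of `apply_softmax` and each distinct input shape triggers its own trace, which is the usual cost of static arguments.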

Contributor Author

Thank you very much for the info and the suggestions.


  # Convert logits to probabilities using jax.lax.cond to avoid control flow
  probs = jax.lax.cond(
      apply_softmax,
Collaborator

Likewise, this is a static condition, so Python control flow should be enough.



def dice_loss(
    predictions: chex.Array,

@Logon27 Aug 20, 2025

If these are the logits, I think you should actually call the input logits. A lot of other optax functions name the parameter that way. Just my opinion. I am going to try out the code from this PR for my own uses and see how it goes.

Contributor Author

That's fair. Thanks! Please update here if you face any issues. Then I will accordingly update this. Otherwise I will just make that small change.

@Logon27 Aug 21, 2025

Oh I see you have an apply_softmax input parameter. Is there a scenario where you wouldn't want softmax applied? Just curious.

Also I tried out the binary dice loss yesterday and it seems to be working fine.

Contributor Author

So I had another look at the code (sorry, it's been kind of long since this started, haha).

Given that apply_softmax exists, I think I'll make no change at this point.

There are many non-standard scenarios where you would need the logits themselves, so I guess it is good practice?

@Logon27

That's fair. I was just wondering what kind of scenarios those would be. No need for changes.

Collaborator

@rdyro left a comment

Thanks so much for this and for incorporating all the feedback!

@rdyro
Collaborator

rdyro commented Aug 23, 2025

@aymuos15 could you reindent your changes to use indent-width=2?

This is needed to pass the internal checks, but I want you to get proper credit in git history.

@Logon27

Logon27 commented Aug 23, 2025

Also, it does not have to be in this PR, but I think we should add Tversky loss as well. It's basically Dice loss, but you can weight false positives and false negatives using input parameters.
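As a rough illustration of that suggestion, here is a minimal sketch of a Tversky loss in JAX. Nothing here comes from the PR: the name `tversky_loss` and the parameters `alpha`, `beta`, and `smooth` are assumptions, and the inputs are assumed to be probabilities with a trailing class axis. With `alpha = beta = 0.5` the index reduces to the Dice coefficient, up to the smoothing term.

```python
import jax.numpy as jnp


def tversky_loss(probs, targets, alpha=0.5, beta=0.5, smooth=1.0):
  """Hypothetical Tversky loss; probs and targets have shape [..., C]."""
  # Reduce over every axis except the trailing class axis.
  axes = tuple(range(probs.ndim - 1))
  true_pos = jnp.sum(probs * targets, axis=axes)
  false_pos = jnp.sum(probs * (1.0 - targets), axis=axes)
  false_neg = jnp.sum((1.0 - probs) * targets, axis=axes)
  # alpha weights false positives, beta weights false negatives.
  index = (true_pos + smooth) / (
      true_pos + alpha * false_pos + beta * false_neg + smooth
  )
  return 1.0 - jnp.mean(index)


# One class, four foreground pixels, two of them missed by the prediction.
targets = jnp.ones((4, 1))
probs = jnp.array([[1.0], [1.0], [0.0], [0.0]])
lenient = tversky_loss(probs, targets, alpha=0.7, beta=0.3)
strict = tversky_loss(probs, targets, alpha=0.3, beta=0.7)
```

Because the example prediction only contains false negatives, raising `beta` increases the loss, which is the weighting behaviour described in the comment.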

@aymuos15
Contributor Author

aymuos15 commented Aug 24, 2025

> @aymuos15 could you reindent your changes to use indent-width=2?
>
> This is needed to pass the internal checks, but I want you to get proper credit in git history.

@rdyro Really appreciate the proper credit, thank you very much.

However, apologies, this may be an extremely naive question: could you please let me know a way to check for this? Looking around, I could not find a linter that confirms the indent-width. I found yapf and autopep8, but they show that the other files in the repo are not consistently using indent-width=2, so I'm a bit confused.

@Logon27 I was planning to do another very small PR just for Dice + CE anyway (I think that should be more than enough, because these two are the most common losses). I will include this as well. Thank you!

@rdyro
Collaborator

rdyro commented Aug 25, 2025

> However, apologies, this may be an extremely naive question: could you please let me know a way to check for this? Looking around, I could not find a linter that confirms the indent-width. I found yapf and autopep8, but they show that the other files in the repo are not consistently using indent-width=2, so I'm a bit confused.

Unfortunately, I don't know how to set up a linter for this. You can try formatting your code with pyink (a Google-compatible code formatter based on Black):

$ pip install pyink
$ pyink --pyink --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 {FILENAME}

@aymuos15

@aymuos15
Contributor Author

This is what I get with the current state of the commit:

pyink --check --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 _segmentation.py 
pyink --check --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 _segmentation_test.py
pyink --check --pyink-indentation=2 --pyink-use-majority-quotes --line-length=80 __init__.py
All done! ✨ 🍰 ✨
1 file would be left unchanged.
All done! ✨ 🍰 ✨
1 file would be left unchanged.
All done! ✨ 🍰 ✨
1 file would be left unchanged.

@copybara-service bot merged commit f13212d into google-deepmind:main Aug 25, 2025
14 of 15 checks passed
@aymuos15 deleted the seg_losses branch August 26, 2025 11:40

5 participants