Add softmax_focal_loss() to allow multi-class focal loss #7676
Conversation
torchvision/ops/focal_loss.py
Outdated
alpha (float): Weighting factor in range (0,1) to balance
    positive vs negative examples or -1 for ignore. Default: ``0.25``.
I don't see a conditional checking whether `alpha` is -1 anywhere.
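For illustration only, such a guard could mirror the `alpha >= 0` check that `sigmoid_focal_loss` uses. In this sketch, `ce_loss`, `pt`, `gamma`, and `alpha` are assumed to be the variables defined in this PR's function; it is not the PR's code:

```python
# Illustrative sketch: a negative alpha disables the alpha weighting,
# mirroring the alpha >= 0 guard in sigmoid_focal_loss.
focal_loss = ((1 - pt) ** gamma) * ce_loss
if alpha >= 0:
    focal_loss = alpha * focal_loss
```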
torchvision/ops/focal_loss.py
Outdated
# need to compute the softmax manually anyway. We don't implement that
# here for brevity, but this code can be extended for such a use-case.
pt = torch.exp(-ce_loss)
focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss
I have my doubts that the `alpha` term used here is correct. In fact we want `alpha_t`. Although the paper doesn't explicitly give the formula for `alpha_t`, it states in its definition: "For notational convenience, we define $\alpha_t$ analogously to how we defined $p_t$."

So given that $p_t = p$ if $y = 1$ and $p_t = 1 - p$ otherwise, we have $\alpha_t = \alpha$ if $y = 1$ and $\alpha_t = 1 - \alpha$ otherwise. Therefore I believe the loss should be $\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$, i.e. $\alpha_t$ rather than a bare $\alpha$ should multiply the focal term.

Torchvision's `sigmoid_focal_loss` does something like this in its implementation:
vision/torchvision/ops/focal_loss.py
Lines 43 to 45 in cc0f9d0
if alpha >= 0:
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss
However, I've seen another implementation in the wild pass `alpha` through to the `weight` argument of `nll_loss`:
AdeelH/pytorch-multi-class-focal-loss/focal_loss.py
This strikes me as the correct approach since it allows one to weigh each class separately in a multi-class setting where there are no "negative" classes.
In other words, there's a subtle difference in intent here versus the sigmoid BCE approach (`sigmoid_focal_loss`), where every class is effectively split into separate positive / negative predictions.
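For concreteness, here is a minimal sketch of that per-class-weight variant. The function name, shapes, and defaults are illustrative assumptions, not the linked code or torchvision's API:

```python
import torch
import torch.nn.functional as F

def multiclass_focal_loss(logits, targets, alpha=None, gamma=2.0):
    # Sketch only. logits: (N, C); targets: (N,) class indices;
    # alpha: optional (C,) tensor of per-class weights; gamma: focusing parameter.
    log_p = F.log_softmax(logits, dim=-1)
    # Per-class weighting enters through nll_loss's `weight` argument.
    ce = F.nll_loss(log_p, targets, weight=alpha, reduction="none")
    # p_t is the probability of the true class, taken from the unweighted log-probs.
    pt = log_p.gather(-1, targets.unsqueeze(-1)).squeeze(-1).exp()
    # The focal term down-weights well-classified examples.
    return ((1 - pt) ** gamma * ce).mean()
```

The key difference from the $\alpha_t$ formulation above is that `alpha` here is a length-C vector of class weights rather than a scalar balancing positives against negatives.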
I agree with your assessment above.
I have put up a revised PR at #7760 since I switched local branches.
In image segmentation tasks, focal loss is useful when classifying each image pixel as one of N classes. Unfortunately, sigmoid_focal_loss() isn't useful in such cases. I found that others have been asking for this as well (see #3250), so I decided to submit a PR for it.
I'm opening this PR to check if this is something the pytorch-vision team is interested in merging.
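As a hedged illustration of that segmentation use case (the shapes and the gamma value are made up, and this is not the exact function signature proposed in this PR), a softmax-based focal loss can be applied per pixel like this:

```python
import torch
import torch.nn.functional as F

# N images, C classes, H x W pixels; labels are per-pixel class indices.
N, C, H, W = 2, 4, 8, 8
logits = torch.randn(N, C, H, W)
labels = torch.randint(0, C, (N, H, W))

ce_loss = F.cross_entropy(logits, labels, reduction="none")  # per-pixel CE, shape (N, H, W)
pt = torch.exp(-ce_loss)                   # probability assigned to the true class at each pixel
loss = ((1 - pt) ** 2.0 * ce_loss).mean()  # focal modulation with gamma = 2.0
```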