
Add softmax_focal_loss() to allow multi-class focal loss #7676


Closed

Conversation

dhruvbird

In image segmentation tasks, focal loss is useful when trying to classify an image pixel as one of N classes. Unfortunately, sigmoid_focal_loss() isn't useful in such cases. I found that others have been asking for this as well in #3250, so I decided to submit a PR for it.

I'm opening this PR to check whether this is something the torchvision team is interested in merging.

pytorch-bot commented Jun 18, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7676

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot

Hi @dhruvbird!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but your CLA is no longer valid and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Comment on lines 80 to 81
alpha (float): Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default: ``0.25``.

rehno-lindeque commented Jul 25, 2023

I don't see a conditional anywhere that checks whether alpha is -1.
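
For illustration, the guard the docstring implies could mirror torchvision's sigmoid_focal_loss; the following is a sketch under that assumption, not code from this diff (ce_loss, pt, and gamma are the names used in the PR's excerpt below):

# Sketch only, not code from the PR: treat alpha = -1 as "no alpha
# weighting", mirroring torchvision's sigmoid_focal_loss.
if alpha >= 0:
    focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss
else:
    focal_loss = ((1 - pt) ** gamma) * ce_loss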

# need to compute the softmax manually anyway. We don't implement that
# here for brevity, but this code can be extended for such a use-case.
pt = torch.exp(-ce_loss)
focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss
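
For context, a minimal self-contained sketch of the kind of function this excerpt belongs to; the signature, defaults, and reduction handling are assumptions for illustration, not the PR's exact code:

import torch
import torch.nn.functional as F

def softmax_focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction="none"):
    # Sketch only; names and defaults are assumed, not the PR's code.
    # inputs: (N, C) logits; targets: (N,) integer class indices.
    ce_loss = F.cross_entropy(inputs, targets, reduction="none")
    pt = torch.exp(-ce_loss)  # probability the model assigns to the true class
    loss = alpha * ((1 - pt) ** gamma) * ce_loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss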
rehno-lindeque commented Jul 25, 2023

I have my doubts that the alpha term used here is correct; what we actually want is $\alpha_t$.

Although the paper doesn't explicitly give the formula for $\alpha_t$, it states in its definition:

For notational convenience, we define $α_t$ analogously to how we defined $p_t$

So, given that

$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}$

I therefore believe $\alpha_t$ should analogously be interpreted as

$\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise} \end{cases}$

Torchvision's sigmoid_focal_loss does something like this in its implementation:

if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss

However, I've seen another implementation in the wild that passes alpha through as the weight argument of nll_loss:

AdeelH/pytorch-multi-class-focal-loss/focal_loss.py

This strikes me as the correct approach, since it allows one to weight each class separately in a multi-class setting where there are no "negative" classes.

In other words, there's a subtle difference in the intent here versus the sigmoid BCE approach (sigmoid_focal_loss) where every class is effectively split into separate positive / negative predictions.
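
For concreteness, that per-class-weight variant could look like the sketch below; the function name and signature are illustrative, not taken from the linked file:

import torch
import torch.nn.functional as F

def multiclass_focal_loss(inputs, targets, alpha, gamma=2.0):
    # Sketch only: alpha is a (C,) tensor of per-class weights passed
    # straight through to nll_loss, per the approach described above.
    log_p = F.log_softmax(inputs, dim=-1)  # (N, C) log-probabilities
    ce = F.nll_loss(log_p, targets, weight=alpha, reduction="none")
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()  # probability of the true class
    return ((1 - pt) ** gamma * ce).mean()

Here alpha plays the same role as the weight argument of cross_entropy, so each class gets its own balancing factor rather than a single positive/negative split.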

dhruvbird (Author)

I agree with your assessment above.

I have put up a revised PR at #7760 since I switched local branches.
