From 5658546334ef3fe7d2eac41a52154ed0ddc8ce00 Mon Sep 17 00:00:00 2001 From: bilzard <36561962+bilzard@users.noreply.github.com> Date: Tue, 17 Aug 2021 21:23:19 +0900 Subject: [PATCH] fix typo --- class_balanced_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/class_balanced_loss.py b/class_balanced_loss.py index 179274d..3d086e8 100644 --- a/class_balanced_loss.py +++ b/class_balanced_loss.py @@ -84,10 +84,10 @@ def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gam if loss_type == "focal": cb_loss = focal_loss(labels_one_hot, logits, weights, gamma) elif loss_type == "sigmoid": - cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights) + cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights) elif loss_type == "softmax": pred = logits.softmax(dim = 1) - cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights) + cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights) return cb_loss @@ -100,5 +100,5 @@ def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gam gamma = 2.0 samples_per_cls = [2,3,1,2,2] loss_type = "focal" - cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes,loss_type, beta, gamma) + cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma) print(cb_loss)