diff --git a/class_balanced_loss.py b/class_balanced_loss.py index 179274d..664ad8b 100644 --- a/class_balanced_loss.py +++ b/class_balanced_loss.py @@ -68,13 +68,17 @@ def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gam Returns: cb_loss: A float tensor representing class balanced loss """ + assert labels.get_device() == logits.get_device() + device = labels.get_device() + device = torch.device("cuda:" + str(device)) if device >= 0 else torch.device("cpu") + effective_num = 1.0 - np.power(beta, samples_per_cls) weights = (1.0 - beta) / np.array(effective_num) weights = weights / np.sum(weights) * no_of_classes labels_one_hot = F.one_hot(labels, no_of_classes).float() - weights = torch.tensor(weights).float() + weights = torch.tensor(weights).float().to(device) weights = weights.unsqueeze(0) weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot weights = weights.sum(1)