From 604bd4cce94d0ebbdb93d9900f7bc1abe296ec14 Mon Sep 17 00:00:00 2001 From: j0sie Date: Sat, 22 Jul 2023 16:16:12 -0400 Subject: [PATCH] device --- class_balanced_loss.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/class_balanced_loss.py b/class_balanced_loss.py index 179274d..664ad8b 100644 --- a/class_balanced_loss.py +++ b/class_balanced_loss.py @@ -68,13 +68,17 @@ def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gam Returns: cb_loss: A float tensor representing class balanced loss """ + assert labels.get_device() == logits.get_device() + device = labels.get_device() + device = torch.device("cuda:" + str(device)) if device >= 0 else torch.device("cpu") + effective_num = 1.0 - np.power(beta, samples_per_cls) weights = (1.0 - beta) / np.array(effective_num) weights = weights / np.sum(weights) * no_of_classes labels_one_hot = F.one_hot(labels, no_of_classes).float() - weights = torch.tensor(weights).float() + weights = torch.tensor(weights).float().to(device) weights = weights.unsqueeze(0) weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot weights = weights.sum(1)