Skip to content

Commit 8c9106c

Browse files
amitportnoynovice03
authored andcommitted
Update trainer.mdx class_weights example (huggingface#23787)
class_weights tensor should follow model's device
1 parent 3c55a12 commit 8c9106c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

docs/source/en/main_classes/trainer.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
6161
outputs = model(**inputs)
6262
logits = outputs.get("logits")
6363
# compute custom loss (suppose one has 3 labels with different weights)
64-
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
64+
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
6565
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
6666
return (loss, outputs) if return_outputs else loss
6767
```

0 commit comments

Comments
 (0)