Skip to content

Commit e1a5375

Browse files
authored
Resolve SentenceTransformer resetting devices after moving a SetFitModel (#283)
* Update the SentenceTransformer target device when moving SetFitModel
* Add regression test for moving SetFitModel
* Parametrize regression test for different head types
1 parent 9b7f74e commit e1a5375

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/setfit/modeling.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -445,6 +445,9 @@ def to(self, device: Union[str, torch.device]) -> "SetFitModel":
445445
Returns:
446446
SetFitModel: Returns the original model, but now on the desired device.
447447
"""
448+
# Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
449+
# the body location
450+
self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
448451
self.model_body = self.model_body.to(device)
449452

450453
if self.has_differentiable_head:

tests/test_modeling.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from unittest import TestCase
22

33
import numpy as np
4+
import pytest
45
import torch
56
from datasets import load_dataset
67
from sentence_transformers import SentenceTransformer
@@ -12,6 +13,9 @@
1213
from setfit.modeling import MODEL_HEAD_NAME, sentence_pairs_generation, sentence_pairs_generation_multilabel
1314

1415

16+
torch_cuda_available = pytest.mark.skipif(not torch.cuda.is_available(), reason="PyTorch must be compiled with CUDA")
17+
18+
1519
def test_sentence_pairs_generation():
1620
sentences = np.array(["sent 1", "sent 2", "sent 3"])
1721
labels = np.array(["label 1", "label 2", "label 3"])
@@ -255,3 +259,20 @@ def test_to_torch_head():
255259
model.to(device)
256260
assert model.model_body.device == device
257261
assert model.model_head.device == device
262+
263+
264+
@torch_cuda_available
265+
@pytest.mark.parametrize("use_differentiable_head", [True, False])
266+
def test_to_sentence_transformer_device_reset(use_differentiable_head):
267+
# This should initialize SentenceTransformer() without a specific device,
268+
# which sets the model to CUDA iff CUDA is available.
269+
model = SetFitModel.from_pretrained(
270+
"sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=use_differentiable_head
271+
)
272+
# If we move the entire model to CPU, we expect it to stay on CPU forever,
273+
# even after encoding or fitting
274+
model.to("cpu")
275+
assert model.model_body.device == torch.device("cpu")
276+
277+
model.model_body.encode("This is a test sample to encode")
278+
assert model.model_body.device == torch.device("cpu")

0 commit comments

Comments
 (0)