@@ -1,6 +1,7 @@
 from unittest import TestCase
 
 import numpy as np
+import pytest
 import torch
 from datasets import load_dataset
 from sentence_transformers import SentenceTransformer
@@ -12,6 +13,9 @@
 from setfit.modeling import MODEL_HEAD_NAME, sentence_pairs_generation, sentence_pairs_generation_multilabel
 
 
+torch_cuda_available = pytest.mark.skipif(not torch.cuda.is_available(), reason="PyTorch must be compiled with CUDA")
+
+
 def test_sentence_pairs_generation():
     sentences = np.array(["sent 1", "sent 2", "sent 3"])
     labels = np.array(["label 1", "label 2", "label 3"])
@@ -255,3 +259,20 @@ def test_to_torch_head(): |
     model.to(device)
     assert model.model_body.device == device
     assert model.model_head.device == device
+
+
+@torch_cuda_available
+@pytest.mark.parametrize("use_differentiable_head", [True, False])
+def test_to_sentence_transformer_device_reset(use_differentiable_head):
+    # This initializes SentenceTransformer() without a specific device,
+    # which places the model on CUDA iff CUDA is available.
+    model = SetFitModel.from_pretrained(
+        "sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=use_differentiable_head
+    )
+    # If we move the entire model to CPU, we expect it to stay on CPU,
+    # even after encoding or fitting.
+    model.to("cpu")
+    assert model.model_body.device == torch.device("cpu")
+
+    model.model_body.encode("This is a test sample to encode")
+    assert model.model_body.device == torch.device("cpu")
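
For reference, a minimal sketch (not part of this commit) of how the new torch_cuda_available marker could guard a companion check in the CUDA direction; the test name is hypothetical, and only APIs already exercised in the diff above are used:

@torch_cuda_available
def test_to_cuda_device_persists():
    # Hypothetical companion check: after moving the model to CUDA,
    # encoding should not reset the model body to a different device.
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
    model.to("cuda")
    assert model.model_body.device.type == "cuda"

    model.model_body.encode("This is a test sample to encode")
    assert model.model_body.device.type == "cuda"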