Skip to content

Commit 4a09065

Browse files
DN6donhardman
authored andcommitted
Fix clearing backend cache from device agnostic testing (huggingface#6075)
update
1 parent 43ff401 commit 4a09065

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

tests/models/test_models_prior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def tearDown(self):
164164
# clean up the VRAM after each test
165165
super().tearDown()
166166
gc.collect()
167-
backend_empty_cache()
167+
backend_empty_cache(torch_device)
168168

169169
@parameterized.expand(
170170
[

tests/models/test_models_unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ def tearDown(self):
869869
# clean up the VRAM after each test
870870
super().tearDown()
871871
gc.collect()
872-
backend_empty_cache()
872+
backend_empty_cache(torch_device)
873873

874874
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
875875
dtype = torch.float16 if fp16 else torch.float32

tests/models/test_models_vae.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def tearDown(self):
485485
# clean up the VRAM after each test
486486
super().tearDown()
487487
gc.collect()
488-
backend_empty_cache()
488+
backend_empty_cache(torch_device)
489489

490490
def get_file_format(self, seed, shape):
491491
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -565,7 +565,7 @@ def tearDown(self):
565565
# clean up the VRAM after each test
566566
super().tearDown()
567567
gc.collect()
568-
backend_empty_cache()
568+
backend_empty_cache(torch_device)
569569

570570
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
571571
dtype = torch.float16 if fp16 else torch.float32
@@ -820,7 +820,7 @@ def tearDown(self):
820820
# clean up the VRAM after each test
821821
super().tearDown()
822822
gc.collect()
823-
backend_empty_cache()
823+
backend_empty_cache(torch_device)
824824

825825
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
826826
dtype = torch.float16 if fp16 else torch.float32

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
310310
def tearDown(self):
311311
super().tearDown()
312312
gc.collect()
313-
backend_empty_cache()
313+
backend_empty_cache(torch_device)
314314

315315
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
316316
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
@@ -531,7 +531,7 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
531531
def tearDown(self):
532532
super().tearDown()
533533
gc.collect()
534-
backend_empty_cache()
534+
backend_empty_cache(torch_device)
535535

536536
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
537537
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"

0 commit comments

Comments
 (0)