diff --git a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py index 0e82610a..cd6b811e 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py @@ -7,7 +7,6 @@ import unittest import tempfile import os -from functools import partial import transformers @@ -38,13 +37,28 @@ def guard(*args, **kwargs): # in such a situation @contextmanager def force_hf_download(): + with tempfile.TemporaryDirectory() as temp_dir: + with _force_hf_download(temp_dir): + yield + + +@contextmanager +def _force_hf_download(temp_dir_path: str): orig_from_pretrained = transformers.BertModel.from_pretrained - transformers.BertModel.from_pretrained = partial( - orig_from_pretrained, force_download=True) + + method_calls = [] + + def replacement_method(*args, **kwargs): + method_calls.append((len(args), len(kwargs))) + return orig_from_pretrained( + *args, force_download=True, + cache_dir=temp_dir_path, **kwargs) + transformers.BertModel.from_pretrained = replacement_method try: yield finally: transformers.BertModel.from_pretrained = orig_from_pretrained + assert method_calls, "BertModel.from_pretrained should be called" class BERTMetaCATTests(unittest.TestCase): @@ -64,6 +78,8 @@ def setUpClass(cls): cls.meta_cat = meta_cat.MetaCATAddon.create_new(cls.cnf, cls.tokenizer) cls.temp_dir = tempfile.TemporaryDirectory() + # change model variant to force a network call upon load + cls.cnf.model.model_variant = 'prajjwal1/bert-small' cls.mc_save_path = os.path.join(cls.temp_dir.name, "bert_meta_cat") serialise('dill', cls.meta_cat, cls.mc_save_path)