Skip to content

Commit 25309d4

Browse files
authored
CU-8699wc4zb Port offline BERT MetaCAT load to v2 (#85)
* CU-8699wc4zb: Port PR 67 (offline BERT-based MetaCAT load) to v2
* CU-8699wc4zb: Add a simple test to make sure that offline loading uses correct path to load model
* CU-8699wc4zb: Fix offline BERT based model load
* CU-8699wc4zb: Add test to make sure BERT MetaCATs can be loaded when offline
* CU-8699wc4zb: Update Bert MetaCAT online test to include checking that online calls were made
* CU-8699wc4zb: Force HF model download during test time
* CU-8699wc4zb: Make sure not to overwrite save_dir_path when loading MetaCAT
* CU-8699wc4zb: Simplify MetaCAT save_dir_path checking test
* CU-8699wc4zb: Fix typo in assert method
1 parent 7589a22 commit 25309d4

File tree

4 files changed

+146
-10
lines changed

4 files changed

+146
-10
lines changed

medcat-v2/medcat/components/addons/meta_cat/meta_cat.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
125125
def load(self, folder_path: str) -> 'MetaCAT':
126126
mc_path, tokenizer_folder = self._get_meta_cat_and_tokenizer_paths(
127127
folder_path)
128-
mc = cast(MetaCAT, deserialise(mc_path))
128+
mc = cast(MetaCAT, deserialise(mc_path, save_dir_path=folder_path))
129129
mc.tokenizer = self._load_tokenizer(self.config, tokenizer_folder)
130130
return mc
131131

@@ -150,6 +150,11 @@ def save(self, folder_path: str) -> None:
150150
raise MisconfiguredMetaCATException(
151151
"Unable to save MetaCAT without a tokenizer")
152152
self.mc.tokenizer.save(tokenizer_folder)
153+
if self.config.model.model_name == 'bert':
154+
model_config_save_path = os.path.join(
155+
folder_path, 'bert_config.json')
156+
self._mc.model.bert_config.to_json_file( # type: ignore
157+
model_config_save_path)
153158

154159
def _init_data_paths(self, base_tokenizer: BaseTokenizer):
155160
# a dictionary like {category_name: value, ...}
@@ -293,7 +298,7 @@ def get_init_attrs(cls) -> list[str]:
293298

294299
@classmethod
295300
def ignore_attrs(cls) -> list[str]:
296-
return ['model']
301+
return ['model', 'save_dir_path']
297302

298303
@classmethod
299304
def include_properties(cls) -> list[str]:
@@ -308,10 +313,12 @@ def __init__(self,
308313
tokenizer: Optional[TokenizerWrapperBase] = None,
309314
embeddings: Optional[Union[Tensor, numpy.ndarray]] = None,
310315
config: Optional[ConfigMetaCAT] = None,
311-
_model_state_dict: Optional[dict[str, Any]] = None) -> None:
316+
_model_state_dict: Optional[dict[str, Any]] = None,
317+
save_dir_path: Optional[str] = None) -> None:
312318
if config is None:
313319
config = ConfigMetaCAT()
314320
self.config = config
321+
self.save_dir_path = save_dir_path
315322
set_all_seeds(config.general.seed)
316323

317324
self.tokenizer = tokenizer
@@ -355,7 +362,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
355362
elif config.model.model_name == 'bert':
356363
from medcat.components.addons.meta_cat.models import (
357364
BertForMetaAnnotation)
358-
model = BertForMetaAnnotation(config)
365+
model = BertForMetaAnnotation(config, self.save_dir_path)
359366

360367
if not config.model.model_freeze_layers:
361368
peft_config = LoraConfig(

medcat-v2/medcat/components/addons/meta_cat/models.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,54 @@ def forward(self,
9898
class BertForMetaAnnotation(nn.Module):
9999
_keys_to_ignore_on_load_unexpected: list[str] = [r"pooler"] # type: ignore
100100

101-
def __init__(self, config: ConfigMetaCAT):
101+
def __init__(self, config: ConfigMetaCAT,
102+
save_dir_path: Optional[str] = None):
102103
super(BertForMetaAnnotation, self).__init__()
103-
_bertconfig = AutoConfig.from_pretrained(
104-
config.model.model_variant,
105-
num_hidden_layers=config.model.num_layers)
104+
if save_dir_path:
105+
try:
106+
_bertconfig = AutoConfig.from_pretrained(
107+
save_dir_path + "/bert_config.json",
108+
num_hidden_layers=config.model.num_layers)
109+
except Exception as e:
110+
_bertconfig = AutoConfig.from_pretrained(
111+
config.model.model_variant,
112+
num_hidden_layers=config.model.num_layers)
113+
logger.info("BERT config not found locally — "
114+
"downloaded successfully from Hugging Face.")
115+
raise e
116+
else:
117+
_bertconfig = AutoConfig.from_pretrained(
118+
config.model.model_variant,
119+
num_hidden_layers=config.model.num_layers)
120+
106121
if config.model.input_size != _bertconfig.hidden_size:
107122
logger.warning(
108123
"Input size for %s model should be %d, provided input size is "
109124
"%d. Input size changed to %d", config.model.model_variant,
110125
_bertconfig.hidden_size, config.model.input_size,
111126
_bertconfig.hidden_size)
112127

113-
bert = BertModel.from_pretrained(config.model.model_variant,
114-
config=_bertconfig)
128+
try:
129+
bert = BertModel.from_pretrained(
130+
config.model.model_variant,
131+
config=_bertconfig)
132+
except Exception as e:
133+
bert = BertModel(_bertconfig)
134+
if save_dir_path:
135+
logger.info(
136+
"Could not load BERT pretrained weights from Hugging Face."
137+
" BERT model was loaded with random weights.\n"
138+
"This will work the weights will be loaded off disk.")
139+
else:
140+
logger.warning(
141+
"Could not load BERT pretrained weights from Hugging Face."
142+
" BERT model was loaded with random weights.\n"
143+
"DO NOT use this model without loading the model state!",
144+
exc_info=e)
145+
115146
self.config = config
116147
self.bert = bert
148+
self.bert_config = _bertconfig
117149
self.num_labels = config.model.nclasses
118150
for param in self.bert.parameters():
119151
param.requires_grad = not config.model.model_freeze_layers
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import socket
2+
from contextlib import contextmanager
3+
4+
from medcat.components.addons.meta_cat import meta_cat
5+
from medcat.storage.serialisers import serialise, deserialise
6+
7+
import unittest
8+
import tempfile
9+
import os
10+
from functools import partial
11+
12+
import transformers
13+
14+
from .test_meta_cat import FakeTokenizer
15+
16+
17+
@contextmanager
def assert_tries_network():
    """Disable socket creation for the managed block and require an attempt.

    While the context is active, ``socket.socket`` is replaced by a stub
    that records every call and raises ``OSError``. The real constructor
    is always restored on exit, whether or not the block raised.

    Raises:
        AssertionError: If the managed block completed normally without
            ever trying to create a socket.
    """
    real_socket = socket.socket
    calls = []

    def guard(*args, **kwargs):
        # Only the argument counts are kept; any entry proves an attempt.
        calls.append((len(args), len(kwargs)))
        raise OSError("Network disabled for test")

    socket.socket = guard
    try:
        yield
    finally:
        socket.socket = real_socket
    # Deliberately outside ``finally``: if the managed block itself raised,
    # that exception must propagate unchanged instead of being replaced by
    # an unrelated "no network calls" failure.
    if not calls:
        raise AssertionError("No network calls were made during the test")
32+
33+
34+
# NOTE: need to disable the usage of the cache
35+
# otherwise other parts of the test suite
36+
# might have already downloaded and cached
37+
# the model and no network calls may be made
38+
# in such a situation
39+
@contextmanager
def force_hf_download():
    """Make ``transformers.BertModel.from_pretrained`` pass ``force_download``.

    For the duration of the context, ``BertModel.from_pretrained`` is
    replaced by a partial of the original that always passes
    ``force_download=True``, so a previously cached model cannot satisfy
    the call without a download attempt.

    The override is installed with ``mock.patch.object`` rather than by
    plain assignment: assigning the captured bound method back would leave
    a permanent ``from_pretrained`` entry on ``BertModel`` itself, whereas
    the patcher puts the class back exactly as it found it.
    """
    from unittest import mock

    orig_from_pretrained = transformers.BertModel.from_pretrained
    forced = partial(orig_from_pretrained, force_download=True)
    with mock.patch.object(
            transformers.BertModel, "from_pretrained", forced):
        yield
48+
49+
50+
class BERTMetaCATTests(unittest.TestCase):
    """Checks that a saved BERT-based MetaCAT addon loads without network."""

    @classmethod
    def setUpClass(cls):
        """Create a BERT MetaCAT addon and serialise it to a temp folder."""
        config = meta_cat.ConfigMetaCAT()
        cls.cnf = config
        config.model.model_name = 'bert'
        config.general.vocab_size = 10
        config.model.padding_idx = 5
        config.general.tokenizer_name = 'bert-tokenizer'
        config.model.model_variant = 'prajjwal1/bert-tiny'
        config.general.category_name = 'FAKE_category'
        config.general.category_value2id = {
            'Future': 0, 'Past': 2, 'Recent': 1}

        cls.tokenizer = FakeTokenizer()
        cls.meta_cat = meta_cat.MetaCATAddon.create_new(
            cls.cnf, cls.tokenizer)

        # The directory object is kept on the class so that
        # tearDownClass can remove it again.
        cls.temp_dir = tempfile.TemporaryDirectory()
        cls.mc_save_path = os.path.join(
            cls.temp_dir.name, "bert_meta_cat")
        serialise('dill', cls.meta_cat, cls.mc_save_path)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary folder created in setUpClass."""
        cls.temp_dir.cleanup()

    def test_no_network_load(self):
        """Loading must still succeed when every socket attempt fails."""
        # Network guard is entered first and exited last, wrapping the
        # forced-download override.
        with assert_tries_network(), force_hf_download():
            loaded = deserialise(self.mc_save_path)
        self.assertIsInstance(loaded, meta_cat.MetaCATAddon)

medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from medcat.storage.serialisers import serialise, AvailableSerialisers
77
from medcat.config.config_meta_cat import ConfigMetaCAT
88
from medcat.config.config import Config
9+
from medcat.utils.defaults import COMPONENTS_FOLDER
910

11+
import os
12+
import unittest.mock
1013
import unittest
1114
import tempfile
1215

@@ -104,6 +107,22 @@ def test_can_save_and_load(self):
104107
cat2 = CAT.load_model_pack(file_name)
105108
self.assert_has_meta_cat(cat2, False)
106109

110+
def test_loading_uses_save_dir_path(self):
    """A loaded MetaCAT must record the folder it was loaded from.

    Saves the model pack to a temporary directory, loads it back and
    checks that the single MetaCAT addon's ``save_dir_path`` equals the
    addon's folder under the components folder of the unpacked pack.
    """
    with tempfile.TemporaryDirectory() as pack_dir:
        pack_path = self.cat.save_model_pack(
            pack_dir, serialiser_type=self.SER_TYPE)
        # The expected folder is derived from the pack path with the
        # ".zip" suffix removed.
        unpacked_root = pack_path.removesuffix(".zip")
        expected_dir = os.path.join(
            unpacked_root,
            COMPONENTS_FOLDER,
            self.meta_cat.get_folder_name(),
        )
        loaded_cat = CAT.load_model_pack(pack_path)
        addons = loaded_cat.get_addons_of_type(meta_cat.MetaCATAddon)
        self.assertEqual(len(addons), 1)
        actual_dir = addons[0].mc.save_dir_path
        self.assertIsNotNone(actual_dir)
        self.assertEqual(actual_dir, expected_dir)
125+
107126
def test_turns_up_in_output(self):
108127
ents = self.cat.get_entities(
109128
"This is a fit text for rich and chronic disease like fittest.")

0 commit comments

Comments
 (0)