Merged
22 changes: 19 additions & 3 deletions medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py
@@ -7,7 +7,6 @@
import unittest
import tempfile
import os
from functools import partial

import transformers

@@ -38,13 +37,28 @@ def guard(*args, **kwargs):
# in such a situation
@contextmanager
def force_hf_download():
    with tempfile.TemporaryDirectory() as temp_dir:
Collaborator

OK, it feels like I'm officially out of my depth here.

  • Forces the from_pretrained method to use force_download=True during the test
    Yep I get this one

  • Asserts that a network call is attempted (but refused)
    Yep I see the assert for sure

But when do you actually expect it to call "transformers.BertModel.from_pretrained" inside the code next?

I just see serialise to dill and then deserialise, but the deserialise doesn't call from_pretrained, at least not in itself...

My gut feeling is that it should do something like "mc = deserialise(...); now do something with mc which calls from_pretrained", if this is anything like Java/C# serialisation anyway: the saved file is just the object state, and loading it back won't trigger a constructor or anything.

But anyway - I'm sure there's a magic line that's likely to be found by people that know what they're doing :D

Collaborator

Forgot to press send on the above...

Looking at the new lines around the wait for async - I'd hope there's some function like "do something on mc, that waits until it's ready internally". Feels like there should be some way to use the deserialise function and rely on the async calls having finished, else anyone using this (not in a test) would be equally stuck.

Would always want to avoid the waits as it implies an underlying design issue, esp if we can fix it inside the library itself.

Collaborator Author

Within deserialisation it'll get to MetaCATAddon.load_existing, which in turn calls the MetaCATAddon.load method. That deserialises the underlying object, which should then call MetaCAT.__init__. That calls MetaCAT.get_model, which (in the case of a Bert-based MetaCAT like here) inits BertForMetaAnnotation. And that finally calls BertModel.from_pretrained.

Collaborator Author

The thing is, this is built on a custom serialisation that is designed to trigger the init. The trivial implementation of pickling stuff would indeed avoid the calls to __init__. But we don't really want that since:

  • some things break when you do that
  • we don't want to save everything to disk in the same format that it is in the memory (i.e. some things know how to save their bits better than I do)
  • this allows us to be more backwards compatible in terms of loading older models (otherwise pickling would preserve the state of the class as well - not just its attributes)
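
To illustrate the difference described above, here is a minimal sketch using a toy class. `Model`, `to_state` and `from_state` are hypothetical names made up for this example and are not the MedCAT serialisation API; the sketch only assumes that a custom format rebuilds objects through their constructor, while plain pickle restores attributes directly.

```python
import pickle


class Model:
    """Toy stand-in for an object whose __init__ does expensive setup."""

    init_calls = 0  # counts how many times __init__ has run

    def __init__(self, variant):
        Model.init_calls += 1
        self.variant = variant

    # Hypothetical custom format: persist only the constructor arguments.
    def to_state(self):
        return {"variant": self.variant}

    @classmethod
    def from_state(cls, state):
        # Rebuilding through the constructor re-runs __init__.
        return cls(**state)


original = Model("bert-small")  # the only __init__ call so far

# Plain pickle restores the attributes without calling __init__.
via_pickle = pickle.loads(pickle.dumps(original))
calls_after_pickle = Model.init_calls  # still 1

# The custom round trip goes through __init__ again.
via_custom = Model.from_state(original.to_state())
calls_after_custom = Model.init_calls  # now 2
```

Under that assumption, anything `__init__` does (such as fetching pretrained weights) happens again on the custom load path, but not on the plain pickle path.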

Collaborator Author

> Looking at the new lines around the wait for async - I'd hope there's some function like "do something on mc, that waits until it's ready internally". Feels like there should be some way to use the deserialise function and rely on the async calls having finished, else anyone using this (not in a test) would be equally stuck.
>
> Would always want to avoid the waits as it implies an underlying design issue, esp if we can fix it inside the library itself.

I'm pretty sure this is an async issue because (when testing locally) the network call is done from the same process, but on a different thread.

Now, this isn't anything we've designed, it's something on transformers side. I don't know this for certain, but my best guess is that they will wait for completion if/when the bit that's being downloaded is needed. Because - like you said - otherwise people would fail to use a model they've initialised / loaded.

Collaborator

Can we then call something on mc? Basically assert that it works, not just that it is the right class.

    mc = deserialise(self.mc_save_path)
    cat = CAT(...)
    cat.add_addon(mc)

    result = cat.get_entities(..)
    # Assert on result just to confirm it's worked

    # (Noting the real assertion is that this also hasn't made any network calls)

Feels like:

  • Either the above works and something somewhere waits for the async calls
  • OR no user can technically use this without risking it not being ready when they get_entities

Collaborator Author

That's certainly a good idea!

I can have it run through a document and that should work if my assumption about the lazy loading from above is correct.

Collaborator Author

Added something to workflow that runs through a document and an entity.

But it still fails.

So clearly it's actually getting the model from somewhere. And it does so without doing a network call (at least not in the way I'm guarding against).
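
For context on what "guarding against" a network call can mean, here is a hypothetical sketch of a socket-level guard. The test's actual `guard` is not shown in this diff, so `block_network` and `NetworkBlocked` are assumed names for illustration only, and the approach (patching `socket.socket.connect`) is an assumption rather than what the test necessarily does.

```python
import socket
from contextlib import contextmanager
from unittest import mock


class NetworkBlocked(RuntimeError):
    """Raised when code tries to open a connection inside the guard."""


@contextmanager
def block_network():
    # Hypothetical guard: swap out socket.socket.connect for the duration
    # and record every address a caller tried to reach.
    attempts = []

    def refuse(self, address):
        attempts.append(address)
        raise NetworkBlocked(f"blocked connection to {address!r}")

    with mock.patch.object(socket.socket, "connect", refuse):
        yield attempts


blocked = False
with block_network() as attempts:
    sock = socket.socket()
    try:
        # Refused by the guard before any packet leaves the machine.
        sock.connect(("127.0.0.1", 9))
    except NetworkBlocked:
        blocked = True
    finally:
        sock.close()
```

A guard of this kind only sees code paths that actually reach `connect`. Anything served from a local cache or file never gets that far, so the list of attempts stays empty.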

        with _force_hf_download(temp_dir):
            yield


@contextmanager
def _force_hf_download(temp_dir_path: str):
    orig_from_pretrained = transformers.BertModel.from_pretrained
    transformers.BertModel.from_pretrained = partial(
        orig_from_pretrained, force_download=True)

    method_calls = []

    def replacement_method(*args, **kwargs):
        method_calls.append((len(args), len(kwargs)))
        return orig_from_pretrained(
            *args, force_download=True,
            cache_dir=temp_dir_path, **kwargs)
    transformers.BertModel.from_pretrained = replacement_method
    try:
        yield
    finally:
        transformers.BertModel.from_pretrained = orig_from_pretrained
    assert method_calls, "BertModel.from_pretrained should be called"


class BERTMetaCATTests(unittest.TestCase):
@@ -64,6 +78,8 @@ def setUpClass(cls):
        cls.meta_cat = meta_cat.MetaCATAddon.create_new(cls.cnf, cls.tokenizer)

        cls.temp_dir = tempfile.TemporaryDirectory()
        # change model variant to force a network call upon load
        cls.cnf.model.model_variant = 'prajjwal1/bert-small'
        cls.mc_save_path = os.path.join(cls.temp_dir.name, "bert_meta_cat")
        serialise('dill', cls.meta_cat, cls.mc_save_path)
