Merged
medcat-v2/medcat/components/addons/meta_cat/meta_cat.py (34 additions, 18 deletions)
@@ -27,6 +27,9 @@
 from medcat.cdb import CDB
 from medcat.vocab import Vocab
 from medcat.utils.defaults import COMPONENTS_FOLDER
+from medcat.utils.defaults import (
+    avoid_legacy_conversion, doing_legacy_conversion_message,
+    LegacyConversionDisabledError)
 from peft import get_peft_model, LoraConfig, TaskType
 
 # It should be safe to do this always, as all other multiprocessing
@@ -173,9 +176,39 @@ def serialise_to(self, folder_path: str) -> None:
             os.mkdir(folder_path)
         self.save(folder_path)
 
+    @classmethod
+    def _create_throwaway_tokenizer(cls) -> BaseTokenizer:
+        from medcat.tokenizing.tokenizers import create_tokenizer
+        from medcat.config import Config
+        logger.warning(
+            "A base tokenizer was not provided during the loading of a "
+            "MetaCAT. The tokenizer is used to register the required data "
+            "paths for MetaCAT to function. Using the default of '%s'. If "
+            "this is not the tokenizer you will end up using, MetaCAT may "
+            "be unable to recover unless a) the paths are registered "
+            "explicitly, or b) there are other MetaCATs created with the "
+            "correct tokenizer. Do note that this will also create "
+            "another instance of the tokenizer, though it should be "
+            "garbage collected soon.", cls.DEFAULT_TOKENIZER
+        )
+        # NOTE: the use of a (mostly) default config here probably won't
+        # affect anything since the tokenizer itself won't be used
+        gcnf = Config()
+        gcnf.general.nlp.provider = 'spacy'
+        return create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf)
+
     @classmethod
     def deserialise_from(cls, folder_path: str, **init_kwargs
                          ) -> 'MetaCATAddon':
+        if "model.dat" in os.listdir(folder_path):
+            if not avoid_legacy_conversion():
+                doing_legacy_conversion_message(
+                    logger, cls.__name__, folder_path)
+                from medcat.utils.legacy.convert_meta_cat import (
+                    get_meta_cat_from_old)
+                return get_meta_cat_from_old(
+                    folder_path, cls._create_throwaway_tokenizer())
+            raise LegacyConversionDisabledError(cls.__name__,)
         if 'cnf' in init_kwargs:
             cnf = init_kwargs['cnf']
         else:
@@ -191,24 +224,7 @@ def deserialise_from(cls, folder_path: str, **init_kwargs
         if 'tokenizer' in init_kwargs:
             tokenizer = init_kwargs['tokenizer']
         else:
-            from medcat.tokenizing.tokenizers import create_tokenizer
-            from medcat.config import Config
-            logger.warning(
-                "A base tokenizer was not provided during the loading of a "
-                "MetaCAT. The tokenizer is used to register the required data "
-                "paths for MetaCAT to function. Using the default of '%s'. If "
-                "this it not the tokenizer you will end up using, MetaCAT may "
-                "be unable to recover unless a) the paths are registered "
-                "explicitly, or b) there are other MetaCATs created with the "
-                "correct tokenizer. Do note that this will also create "
-                "another instance of the tokenizer, though it should be "
-                "garbage collected soon.", cls.DEFAULT_TOKENIZER
-            )
-            # NOTE: the use of a (mostly) default config here probably won't
-            # affect anything since the tokenizer itself won't be used
-            gcnf = Config()
-            gcnf.general.nlp.provider = 'spacy'
-            tokenizer = create_tokenizer(cls.DEFAULT_TOKENIZER, gcnf)
+            tokenizer = cls._create_throwaway_tokenizer()
         return cls.load_existing(
             load_path=folder_path,
             cnf=cnf,
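
For reference, a minimal usage sketch of the loading behaviour introduced in this diff (the folder path and `my_tokenizer` are hypothetical; only what the diff above shows is assumed):

from medcat.components.addons.meta_cat.meta_cat import MetaCATAddon

# v2 folders load directly; without a 'tokenizer' kwarg the throwaway
# default tokenizer is created and the warning above is logged.
meta_cat = MetaCATAddon.deserialise_from("/path/to/meta_cat")

# Passing the tokenizer you will actually use avoids the throwaway
# instance, e.g. (with `my_tokenizer` being a BaseTokenizer from your
# pipeline; hypothetical name):
#   meta_cat = MetaCATAddon.deserialise_from(
#       "/path/to/meta_cat", tokenizer=my_tokenizer)

# Legacy (v1) folders are recognised by the presence of 'model.dat' and
# converted on the fly via get_meta_cat_from_old; if legacy conversion
# is disabled, LegacyConversionDisabledError is raised instead.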