80 changes: 53 additions & 27 deletions medcat-v2/medcat/cat.py
@@ -18,6 +18,7 @@
from medcat.storage.serialisers import serialise, AvailableSerialisers
from medcat.storage.serialisers import deserialise
from medcat.storage.serialisables import AbstractSerialisable
from medcat.storage.mp_ents_save import BatchAnnotationSaver
from medcat.utils.fileutils import ensure_folder_if_parent
from medcat.utils.hasher import Hasher
from medcat.pipeline.pipeline import Pipeline
@@ -159,7 +160,7 @@ def get_entities(self,
def _mp_worker_func(
self,
texts_and_indices: list[tuple[str, str, bool]]
) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]:
) -> list[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
# NOTE: this is needed for subprocess as otherwise they wouldn't have
# any of these set
# NOTE: these need to be dynamic in case the extras aren't included
@@ -180,7 +181,7 @@ def _mp_worker_func(
elif has_rel_cat and isinstance(addon, RelCATAddon):
addon._rel_cat._init_data_paths()
return [
(text, text_index, self.get_entities(text, only_cui=only_cui))
(text_index, self.get_entities(text, only_cui=only_cui))
for text, text_index, only_cui in texts_and_indices]

def _generate_batches_by_char_length(
@@ -256,7 +257,8 @@ def _mp_one_batch_per_process(
self,
executor: ProcessPoolExecutor,
batch_iter: Iterator[list[tuple[str, str, bool]]],
external_processes: int
external_processes: int,
saver: Optional[BatchAnnotationSaver],
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
futures: list[Future] = []
# submit batches, one for each external process
@@ -269,16 +271,16 @@
break
if not futures:
# NOTE: if there wasn't any data, we didn't process anything
return
raise OutOfDataException()
# Main process works on next batch while workers are busy
main_batch: Optional[list[tuple[str, str, bool]]]
try:
main_batch = next(batch_iter)
main_results = self._mp_worker_func(main_batch)

if saver:
saver(main_results)
# Yield main process results immediately
for result in main_results:
yield result[1], result[2]
yield from main_results

except StopIteration:
main_batch = None
@@ -295,20 +297,12 @@
done_future = next(as_completed(futures))
futures.remove(done_future)

# Yield all results from this batch
for result in done_future.result():
yield result[1], result[2]
cur_results = done_future.result()
if saver:
saver(cur_results)

# Submit next batch to keep workers busy
try:
batch = next(batch_iter)
futures.append(
executor.submit(self._mp_worker_func, batch))
except StopIteration:
# NOTE: if there's nothing to batch, we've got nothing
# to submit in terms of new work to the workers,
# but we may still have some futures to wait for
pass
# Yield all results from this batch
yield from cur_results

def get_entities_multi_texts(
self,
@@ -317,6 +311,8 @@ def get_entities_multi_texts(
n_process: int = 1,
batch_size: int = -1,
batch_size_chars: int = 1_000_000,
save_dir_path: Optional[str] = None,
batches_per_save: int = 20,
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
"""Get entities from multiple texts (potentially in parallel).

@@ -343,6 +339,16 @@
Each process will be given batch of texts with a total
number of characters not exceeding this value. Defaults
to 1,000,000 characters. Set to -1 to disable.
save_dir_path (Optional[str]):
The path to the directory where (if specified) the results are saved.
The directory will have an `annotated_ids.pickle` file
containing a tuple[list[str], int] with the list of
indices already saved and the number of parts already saved.
In addition, there will be (usually multiple) files in the
`part_<num>.pickle` format with the partial outputs.
batches_per_save (int):
The number of batches to save (if `save_dir_path` is specified)
at once. Defaults to 20.

Yields:
Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
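
For reference, a minimal sketch (not part of this diff) of how a directory written via `save_dir_path` could be read back, assuming the layout described in the docstring above: `annotated_ids.pickle` holds a (saved indices, last part number) tuple and each `part_<num>.pickle` holds a dict of text index to output. `load_saved_annotations` is a hypothetical helper name.

import os
import pickle


def load_saved_annotations(save_dir_path: str) -> dict:
    # annotated_ids.pickle stores (list of saved text indices, last part number written)
    with open(os.path.join(save_dir_path, "annotated_ids.pickle"), "rb") as f:
        saved_ids, last_part = pickle.load(f)
    results: dict = {}
    # parts are numbered 0..last_part; each holds a dict of text index -> entities
    for part in range(last_part + 1):
        with open(os.path.join(save_dir_path, f"part_{part}.pickle"), "rb") as f:
            results.update(pickle.load(f))
    return results  # saved_ids lists the same indices, in the order they were saved
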
@@ -352,15 +358,27 @@
Union[Iterator[str], Iterator[tuple[str, str]]], iter(texts))
batch_iter = self._generate_batches(
text_iter, batch_size, batch_size_chars, only_cui)
if save_dir_path:
saver = BatchAnnotationSaver(save_dir_path, batches_per_save)
else:
saver = None
if n_process == 1:
# just do in series
for batch in batch_iter:
for _, text_index, result in self._mp_worker_func(batch):
yield text_index, result
batch_results = self._mp_worker_func(batch)
if saver is not None:
saver(batch_results)
yield from batch_results
if saver:
# save remainder
saver._save_cache()
return

with self._no_usage_monitor_exit_flushing():
yield from self._multiprocess(n_process, batch_iter)
yield from self._multiprocess(n_process, batch_iter, saver)
if saver:
# save remainder
saver._save_cache()

@contextmanager
def _no_usage_monitor_exit_flushing(self):
@@ -379,7 +397,8 @@ def _no_usage_monitor_exit_flushing(self):

def _multiprocess(
self, n_process: int,
batch_iter: Iterator[list[tuple[str, str, bool]]]
batch_iter: Iterator[list[tuple[str, str, bool]]],
saver: Optional[BatchAnnotationSaver],
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
external_processes = n_process - 1
if self.FORCE_SPAWN_MP:
@@ -390,8 +409,12 @@ def _multiprocess(
"libraries using threads or native extensions.")
mp.set_start_method("spawn", force=True)
with ProcessPoolExecutor(max_workers=external_processes) as executor:
yield from self._mp_one_batch_per_process(
executor, batch_iter, external_processes)
while True:
try:
yield from self._mp_one_batch_per_process(
executor, batch_iter, external_processes, saver=saver)
except OutOfDataException:
break

def _get_entity(self, ent: MutableEntity,
doc_tokens: list[str],
@@ -737,7 +760,6 @@ def load_addons(
]
return [(addon.full_name, addon) for addon in loaded_addons]


@overload
def get_model_card(self, as_dict: Literal[True]) -> ModelCard:
pass
Expand Down Expand Up @@ -794,3 +816,7 @@ def __eq__(self, other: Any) -> bool:
def add_addon(self, addon: AddonComponent) -> None:
self.config.components.addons.append(addon.config)
self._pipeline.add_addon(addon)


class OutOfDataException(ValueError):
pass
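
A short usage sketch of the new saving parameters (assumptions: `cat` is an already-loaded CAT instance, `texts` is an iterable of input texts accepted by this method, and the directory name is made up):

out_dir = "annotations_out"  # hypothetical output directory; parts get written here
for text_index, entities in cat.get_entities_multi_texts(
        texts,
        n_process=4,               # main process plus 3 worker processes
        batch_size_chars=500_000,  # cap each batch at roughly 500k characters
        save_dir_path=out_dir,     # also persist results to part_<num>.pickle files
        batches_per_save=10):      # flush to disk every 10 batches
    ...  # results are still yielded as (text_index, entities) while being saved
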
60 changes: 60 additions & 0 deletions medcat-v2/medcat/storage/mp_ents_save.py
@@ -0,0 +1,60 @@
from typing import Union
import os
import logging

import pickle

from medcat.data.entities import Entities, OnlyCUIEntities


logger = logging.getLogger(__name__)


class BatchAnnotationSaver:
def __init__(self, save_dir: str, batches_per_save: int):
self.save_dir = save_dir
self.batches_per_save = batches_per_save
self._batch_cache: list[list[
tuple[str, Union[dict, Entities, OnlyCUIEntities]]]] = []
os.makedirs(save_dir, exist_ok=True)
self.part_number = 0
self.annotated_ids_path = os.path.join(
save_dir, "annotated_ids.pickle")

def _load_existing_ids(self) -> tuple[list[str], int]:
if not os.path.exists(self.annotated_ids_path):
return [], -1
with open(self.annotated_ids_path, 'rb') as f:
return pickle.load(f)

def _save_cache(self):
annotated_ids, prev_part_num = self._load_existing_ids()
if (prev_part_num + 1) != self.part_number:
logger.info(
"Found part number %d off disk. Previously %d was kept track "
"of in code. Updating to %d off disk.",
prev_part_num, self.part_number, prev_part_num)
self.part_number = prev_part_num
for batch in self._batch_cache:
for doc_id, _ in batch:
annotated_ids.append(doc_id)
logger.debug("Saving part %d with %d batches",
self.part_number, len(self._batch_cache))
with open(self.annotated_ids_path, 'wb') as f:
pickle.dump((annotated_ids, self.part_number), f)
# Save batch as part_<num>.pickle
part_path = os.path.join(self.save_dir,
f"part_{self.part_number}.pickle")
part_dict = {id: val for
batch in self._batch_cache for
id, val in batch}
with open(part_path, 'wb') as f:
pickle.dump(part_dict, f)
self._batch_cache.clear()
self.part_number += 1

def __call__(self, batch: list[
tuple[str, Union[dict, Entities, OnlyCUIEntities]]]):
self._batch_cache.append(batch)
if len(self._batch_cache) >= self.batches_per_save:
self._save_cache()
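
A minimal standalone sketch of how the saver is driven (mirroring its use in `cat.py` above); the document ids and entity dicts are made-up placeholders:

saver = BatchAnnotationSaver("annotations_out", batches_per_save=2)
saver([("doc1", {"entities": {}}), ("doc2", {"entities": {}})])  # first batch is cached
saver([("doc3", {"entities": {}})])  # second batch reaches the threshold -> writes part_0.pickle
saver([("doc4", {"entities": {}})])  # cached again
saver._save_cache()                  # flush the remainder -> writes part_1.pickle
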