From 1be9f6858f86f4c7ff054cdd7c6ef7b94bbbae25 Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 17 Jul 2025 14:50:17 +0100
Subject: [PATCH 1/6] CU-8699upt9a: Add option to save multiprocessing output

---
 medcat-v2/medcat/cat.py                  | 56 +++++++++++++++++-----
 medcat-v2/medcat/storage/mp_ents_save.py | 60 ++++++++++++++++++++++++
 2 files changed, 103 insertions(+), 13 deletions(-)
 create mode 100644 medcat-v2/medcat/storage/mp_ents_save.py

diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py
index 4abd6a425..287d72852 100644
--- a/medcat-v2/medcat/cat.py
+++ b/medcat-v2/medcat/cat.py
@@ -18,6 +18,7 @@
 from medcat.storage.serialisers import serialise, AvailableSerialisers
 from medcat.storage.serialisers import deserialise
 from medcat.storage.serialisables import AbstractSerialisable
+from medcat.storage.mp_ents_save import BatchAnnotationSaver
 from medcat.utils.fileutils import ensure_folder_if_parent
 from medcat.utils.hasher import Hasher
 from medcat.pipeline.pipeline import Pipeline
@@ -159,7 +160,7 @@ def get_entities(self,
     def _mp_worker_func(
             self,
             texts_and_indices: list[tuple[str, str, bool]]
-    ) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]:
+    ) -> list[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         # NOTE: this is needed for subprocesses as otherwise they wouldn't
         #       have any of these set
         # NOTE: these need to be dynamic in case the extras aren't included
@@ -180,7 +181,7 @@ def _mp_worker_func(
             elif has_rel_cat and isinstance(addon, RelCATAddon):
                 addon._rel_cat._init_data_paths()
         return [
-            (text, text_index, self.get_entities(text, only_cui=only_cui))
+            (text_index, self.get_entities(text, only_cui=only_cui))
             for text, text_index, only_cui in texts_and_indices]
 
     def _generate_batches_by_char_length(
@@ -256,7 +257,8 @@ def _mp_one_batch_per_process(
             self,
             executor: ProcessPoolExecutor,
             batch_iter: Iterator[list[tuple[str, str, bool]]],
-            external_processes: int
+            external_processes: int,
+            saver: Optional[BatchAnnotationSaver],
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         futures: list[Future] = []
         # submit batches, one for each external process
@@ -275,10 +277,10 @@
         try:
             main_batch = next(batch_iter)
             main_results = self._mp_worker_func(main_batch)
-
+            if saver:
+                saver(main_results)
             # Yield main process results immediately
-            for result in main_results:
-                yield result[1], result[2]
+            yield from main_results
         except StopIteration:
             main_batch = None
 
@@ -295,9 +297,12 @@
             done_future = next(as_completed(futures))
             futures.remove(done_future)
 
+            cur_results = done_future.result()
+            if saver:
+                saver(main_results)
+
             # Yield all results from this batch
-            for result in done_future.result():
-                yield result[1], result[2]
+            yield from cur_results
 
             # Submit next batch to keep workers busy
             try:
@@ -317,6 +322,8 @@ def get_entities_multi_texts(
             n_process: int = 1,
             batch_size: int = -1,
             batch_size_chars: int = 1_000_000,
+            save_dir_path: Optional[str] = None,
+            batches_per_save: int = 20,
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         """Get entities from multiple texts (potentially in parallel).
 
@@ -343,6 +350,16 @@ def get_entities_multi_texts(
                 Each process will be given a batch of texts with a total
                 number of characters not exceeding this value. Defaults
                 to 1,000,000 characters. Set to -1 to disable.
+            save_dir_path (Optional[str]):
+                The path to the directory where (if specified) the results
+                are saved. The directory will contain an `annotated_ids.pickle`
+                file holding a tuple[list[str], int] with the list of indices
+                already saved and the number of parts already saved. In
+                addition, there will be (usually multiple) files named
+                `part_<part number>.pickle` that hold the partial outputs.
+            batches_per_save (int):
+                The number of batches to save (if `save_dir_path` is
+                specified) at once. Defaults to 20.
 
         Yields:
             Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
@@ -352,15 +369,27 @@ def get_entities_multi_texts(
             Union[Iterator[str], Iterator[tuple[str, str]]], iter(texts))
         batch_iter = self._generate_batches(
             text_iter, batch_size, batch_size_chars, only_cui)
+        if save_dir_path:
+            saver = BatchAnnotationSaver(save_dir_path, batches_per_save)
+        else:
+            saver = None
         if n_process == 1:
             # just do in series
             for batch in batch_iter:
-                for _, text_index, result in self._mp_worker_func(batch):
-                    yield text_index, result
+                batch_results = self._mp_worker_func(batch)
+                if saver is not None:
+                    saver(batch_results)
+                yield from batch_results
+            if saver:
+                # save remainder
+                saver._save_cache()
             return
         with self._no_usage_monitor_exit_flushing():
-            yield from self._multiprocess(n_process, batch_iter)
+            yield from self._multiprocess(n_process, batch_iter, saver)
+        if saver:
+            # save remainder
+            saver._save_cache()
 
     @contextmanager
     def _no_usage_monitor_exit_flushing(self):
@@ -379,7 +408,8 @@ def _no_usage_monitor_exit_flushing(self):
     def _multiprocess(
             self,
             n_process: int,
-            batch_iter: Iterator[list[tuple[str, str, bool]]]
+            batch_iter: Iterator[list[tuple[str, str, bool]]],
+            saver: Optional[BatchAnnotationSaver],
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         external_processes = n_process - 1
         if self.FORCE_SPAWN_MP:
@@ -391,7 +421,7 @@ def _multiprocess(
             mp.set_start_method("spawn", force=True)
         with ProcessPoolExecutor(max_workers=external_processes) as executor:
             yield from self._mp_one_batch_per_process(
-                executor, batch_iter, external_processes)
+                executor, batch_iter, external_processes, saver=saver)
 
     def _get_entity(self, ent: MutableEntity,
                     doc_tokens: list[str],
diff --git a/medcat-v2/medcat/storage/mp_ents_save.py b/medcat-v2/medcat/storage/mp_ents_save.py
new file mode 100644
index 000000000..07556f7a7
--- /dev/null
+++ b/medcat-v2/medcat/storage/mp_ents_save.py
@@ -0,0 +1,60 @@
+from typing import Union
+import os
+import logging
+
+import pickle
+
+from medcat.data.entities import Entities, OnlyCUIEntities
+
+
+logger = logging.getLogger(__name__)
+
+
+class BatchAnnotationSaver:
+    def __init__(self, save_dir: str, batches_per_save: int):
+        self.save_dir = save_dir
+        self.batches_per_save = batches_per_save
+        self._batch_cache: list[list[
+            tuple[str, Union[dict, Entities, OnlyCUIEntities]]]] = []
+        os.makedirs(save_dir, exist_ok=True)
+        self.part_number = 0
+        self.annotated_ids_path = os.path.join(
+            save_dir, "annotated_ids.pickle")
+
+    def _load_existing_ids(self) -> tuple[list[str], int]:
+        if not os.path.exists(self.annotated_ids_path):
+            return [], -1
+        with open(self.annotated_ids_path, 'rb') as f:
+            return pickle.load(f)
+
+    def _save_cache(self):
+        annotated_ids, prev_part_num = self._load_existing_ids()
+        if (prev_part_num + 1) != self.part_number:
+            logger.info(
+                "Found part number %d on disk while %d was being tracked "
+                "in code. Updating part number to %d based on disk.",
+                prev_part_num, self.part_number, prev_part_num)
+            self.part_number = prev_part_num
+        for batch in self._batch_cache:
+            for doc_id, _ in batch:
+                annotated_ids.append(doc_id)
+        logger.debug("Saving part %d with %d batches",
+                     self.part_number, len(self._batch_cache))
+        with open(self.annotated_ids_path, 'wb') as f:
+            pickle.dump((annotated_ids, self.part_number), f)
+        # Save the cached batches as part_<part number>.pickle
+        part_path = os.path.join(self.save_dir,
+                                 f"part_{self.part_number}.pickle")
+        part_dict = {id: val for
+                     batch in self._batch_cache for
+                     id, val in batch}
+        with open(part_path, 'wb') as f:
+            pickle.dump(part_dict, f)
+        self._batch_cache.clear()
+        self.part_number += 1
+
+    def __call__(self, batch: list[
+            tuple[str, Union[dict, Entities, OnlyCUIEntities]]]):
+        self._batch_cache.append(batch)
+        if len(self._batch_cache) >= self.batches_per_save:
+            self._save_cache()
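
For orientation before the tests in the next patch: the files written by
BatchAnnotationSaver are plain pickles, so a run's output can be reassembled
afterwards with the standard library alone. A minimal sketch, assuming the
layout produced above (`annotated_ids.pickle` holding an
`(ids, last_part_number)` tuple and one `part_<n>.pickle` dict per flushed
cache); the `load_saved_annotations` helper is illustrative, not part of the
patch:

    import os
    import pickle


    def load_saved_annotations(save_dir: str) -> dict:
        """Reassemble the partial outputs written by BatchAnnotationSaver."""
        with open(os.path.join(save_dir, "annotated_ids.pickle"), "rb") as f:
            annotated_ids, last_part_num = pickle.load(f)
        results: dict = {}
        # Part files are numbered 0..last_part_num (inclusive); each holds a
        # dict mapping text index -> entity output for one flushed cache.
        for part_num in range(last_part_num + 1):
            part_path = os.path.join(save_dir, f"part_{part_num}.pickle")
            with open(part_path, "rb") as f:
                results.update(pickle.load(f))
        return results

Every text index recorded in annotated_ids should then appear as a key in
exactly one part file; that is precisely what the tests below verify.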
From 38a049620789eea20812b0585e3dddf3217233a3 Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 17 Jul 2025 15:42:22 +0100
Subject: [PATCH 2/6] CU-8699upt9a: Add a test for multiprocessing saved data

Make sure all the data is saved, that all the files are present, and
that the saved data is equal to the returned data.
---
 medcat-v2/tests/test_cat.py | 105 +++++++++++++++++++++++++++++++++++-
 1 file changed, 104 insertions(+), 1 deletion(-)

diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py
index 5c4de7cd5..417e11ba6 100644
--- a/medcat-v2/tests/test_cat.py
+++ b/medcat-v2/tests/test_cat.py
@@ -2,7 +2,7 @@
 import unittest.mock
 import pandas as pd
 import json
-from typing import Optional
+from typing import Optional, Any
 from collections import Counter
 
 from medcat import cat
@@ -21,6 +21,7 @@
 
 import unittest
 import tempfile
+import pickle
 
 from . import EXAMPLE_MODEL_PACK_ZIP
 from . import V1_MODEL_PACK_PATH, UNPACKED_V1_MODEL_PACK_PATH
@@ -441,6 +442,108 @@ def test_can_get_multiprocess(self):
         ents = list(self.cat.get_entities_multi_texts(texts, n_process=3))
         self.assert_ents(ents, texts)
 
+    def _do_mp_run_with_save(
+            self, save_to: str,
+            chars_per_batch: int = 165,
+            batches_per_save: int = 5,
+            exp_parts: int = 8
+    ) -> tuple[list[str], list[tuple], dict[str, Any], int]:
+        in_data = [
+            f"The patient presented with {name} and "
+            f"did not have {negname}"
+            for name in self.cdb.name2info
+            for negname in self.cdb.name2info if name != negname
+        ]
+        out_data = list(self.cat.get_entities_multi_texts(
+            in_data,
+            save_dir_path=save_to,
+            batch_size_chars=chars_per_batch,
+            batches_per_save=batches_per_save))
+        out_dict_all = {
+            key: cdata for key, cdata in out_data
+        }
+        return in_data, out_data, out_dict_all, exp_parts
+
+    def assert_mp_runs_with_save_and_load(
+            self, save_to: str,
+            chars_per_batch: int = 165,
+            batches_per_save: int = 5,
+            exp_parts: int = 8
+    ) -> tuple[
+        tuple[list[str], list[tuple], dict[str, Any], int],
+        tuple[tuple[list[str], int], list[str], int],
+    ]:
+        in_data, out_data, out_dict_all, exp_parts = (
+            self._do_mp_run_with_save(
+                save_to, chars_per_batch, batches_per_save, exp_parts))
+        anns_file = os.path.join(save_to, 'annotated_ids.pickle')
+        self.assertTrue(os.path.exists(anns_file))
+        with open(anns_file, 'rb') as f:
+            loaded_data = pickle.load(f)
+        self.assertEqual(len(loaded_data), 2)
+        ids, last_part_num = loaded_data
+        return (in_data, out_data, out_dict_all, exp_parts), (
+            loaded_data, ids, last_part_num)
+
+    def assert_mp_runs_save_load_gather(
+            self, save_to: str,
+            chars_per_batch: int = 165,
+            batches_per_save: int = 5,
+            exp_parts: int = 8
+    ) -> tuple[
+        tuple[list[str], list[tuple], dict[str, Any], int],
+        tuple[tuple[list[str], int], list[str], int],
+        dict[str, Any]
+    ]:
+        (in_data, out_data, out_dict_all, exp_parts), (
+            loaded_data, ids, num_last_part
+        ) = self.assert_mp_runs_with_save_and_load(
+            save_to, chars_per_batch, batches_per_save, exp_parts)
+        all_loaded_output = {}
+        for num in range(num_last_part + 1):
+            with self.subTest(f"Part {num}"):
+                part_name = f"part_{num}.pickle"
+                part_path = os.path.join(save_to, part_name)
+                self.assertTrue(os.path.exists(part_path))
+                with open(part_path, 'rb') as f:
+                    part_data = pickle.load(f)
+                self.assertIsInstance(part_data, dict)
+                self.assertTrue(
+                    all(key not in all_loaded_output for key in part_data))
+                all_loaded_output.update(part_data)
+        return (in_data, out_data, out_dict_all, exp_parts), (
+            loaded_data, ids, num_last_part), all_loaded_output
+
+    def test_multiprocessing_can_save_indices(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            (in_data, out_data,
+             out_dict_all, exp_parts), (
+                loaded_data, ids, num_last_part
+            ) = self.assert_mp_runs_with_save_and_load(temp_dir)
+            self.assertEqual(len(out_data), len(in_data))
+            self.assertEqual(len(in_data), len(ids))
+
+    def test_mp_saves_all_parts(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            (in_data, out_data,
+             out_dict_all, exp_parts), (
+                loaded_data, ids, num_last_part
+            ), all_loaded_output = self.assert_mp_runs_save_load_gather(
+                temp_dir)
+            # NOTE: the number of parts is 1 greater than the last part number
+            self.assertEqual(num_last_part + 1, exp_parts)
+
+    def test_mp_saves_correct_data(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            (in_data, out_data,
+             out_dict_all, exp_parts), (
+                loaded_data, ids, num_last_part
+            ), all_loaded_output = self.assert_mp_runs_save_load_gather(
+                temp_dir)
+            self.assertEqual(len(all_loaded_output), len(in_data))
+            self.assertEqual(all_loaded_output.keys(), out_dict_all.keys())
+            self.assertEqual(all_loaded_output, out_dict_all)
+
 
 class CATWithDocAddonTests(CATIncludingTests):
     EXAMPLE_TEXT = "Example text to tokenize"
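
A side note on what the saved index enables: none of these patches implement
resuming, but since `annotated_ids.pickle` records which text indices are
already on disk, a caller could skip those on a re-run. A rough sketch under
stated assumptions — the `(id, text)` input ordering, the `documents` toy
data, the model pack path, and the `mp_output` directory are all
illustrative, not part of the API:

    import os
    import pickle

    from medcat.cat import CAT

    cat = CAT.load_model_pack("<model pack path>")  # placeholder path

    save_dir = "mp_output"  # hypothetical directory from an earlier run
    ids_path = os.path.join(save_dir, "annotated_ids.pickle")

    already_done: set[str] = set()
    if os.path.exists(ids_path):
        with open(ids_path, "rb") as f:
            # the first element of the saved tuple is the list of saved ids
            already_done = set(pickle.load(f)[0])

    documents = {"doc-0": "first text", "doc-1": "second text"}  # toy data
    remaining = [(doc_id, text) for doc_id, text in documents.items()
                 if doc_id not in already_done]

    # Annotate only the remainder while pointing save_dir_path at the same
    # directory; the saver reconciles its part counter with what it finds
    # on disk before writing.
    for doc_id, ents in cat.get_entities_multi_texts(
            remaining, save_dir_path=save_dir):
        pass  # consume the newly annotated results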
From 074f2fce363177858183ba365db875438379665f Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 17 Jul 2025 16:46:50 +0100
Subject: [PATCH 3/6] CU-8699upt9a: Fix typo in output saving

---
 medcat-v2/medcat/cat.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py
index 287d72852..16c63632f 100644
--- a/medcat-v2/medcat/cat.py
+++ b/medcat-v2/medcat/cat.py
@@ -299,7 +299,7 @@
 
             cur_results = done_future.result()
             if saver:
-                saver(main_results)
+                saver(cur_results)
 
             # Yield all results from this batch
             yield from cur_results
@@ -767,7 +767,6 @@
         ]
         return [(addon.full_name, addon) for addon in loaded_addons]
 
-
     @overload
     def get_model_card(self, as_dict: Literal[True]) -> ModelCard:
         pass

From 18d9f181cde5eda29827dd2d3feb737073b8892c Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 17 Jul 2025 16:49:03 +0100
Subject: [PATCH 4/6] CU-8699upt9a: Add more comprehensive multiprocessing
 tests with proper batching

---
 medcat-v2/tests/test_cat.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py
index 417e11ba6..9b5a2c2f1 100644
--- a/medcat-v2/tests/test_cat.py
+++ b/medcat-v2/tests/test_cat.py
@@ -417,8 +417,9 @@ def test_can_get_multiple_entities(self):
         texts = [
             "The fittest most fit of chronic kidney failure",
             "The dog is sitting outside the house."
-        ]
-        ents = list(self.cat.get_entities_multi_texts(texts))
+        ]*10
+        ents = list(self.cat.get_entities_multi_texts(
+            texts, batch_size=2, batch_size_chars=-1))
         self.assert_ents(ents, texts)
 
     def assert_ents(self, ents: list[tuple], texts: list[str]):
@@ -438,8 +439,9 @@ def test_can_get_multiprocess(self):
         texts = [
             "The fittest most fit of chronic kidney failure",
             "The dog is sitting outside the house."
-        ]
-        ents = list(self.cat.get_entities_multi_texts(texts, n_process=3))
+        ]*10
+        ents = list(self.cat.get_entities_multi_texts(
+            texts, n_process=3, batch_size=2, batch_size_chars=-1))
         self.assert_ents(ents, texts)
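
These tests pin down the two batching knobs: `batch_size` caps the number of
texts per batch, while `batch_size_chars` (implemented via
`_generate_batches_by_char_length`, whose body is not shown in this series)
budgets characters per batch, with -1 disabling the character budget as the
docstring in the first patch notes. For intuition, a simplified standalone
sketch of character-budgeted batching; treating an oversized single text as
its own batch is an assumption, not taken from the patch:

    from collections.abc import Iterator


    def batches_by_chars(texts: list[str],
                         max_chars: int) -> Iterator[list[str]]:
        """Yield batches whose combined character count stays in budget."""
        batch: list[str] = []
        used = 0
        for text in texts:
            # start a new batch once adding this text would blow the budget
            if batch and used + len(text) > max_chars:
                yield batch
                batch, used = [], 0
            batch.append(text)
            used += len(text)
        if batch:
            yield batch


    # list(batches_by_chars(["aa", "bbb", "c"], max_chars=4))
    # -> [['aa'], ['bbb', 'c']]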
From c599a1cdb9f83b12b7fbcf5254750bd3c8bf71ee Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 17 Jul 2025 17:18:10 +0100
Subject: [PATCH 5/6] CU-8699upt9a: Fix issue with limited number of jobs
 submitted per process

---
 medcat-v2/medcat/cat.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py
index 16c63632f..775732cf4 100644
--- a/medcat-v2/medcat/cat.py
+++ b/medcat-v2/medcat/cat.py
@@ -271,7 +271,7 @@ def _mp_one_batch_per_process(
                 break
         if not futures:
             # NOTE: if there wasn't any data, we didn't process anything
-            return
+            raise OutOfDataException()
         # Main process works on next batch while workers are busy
         main_batch: Optional[list[tuple[str, str, bool]]]
         try:
@@ -304,17 +304,6 @@ def _mp_one_batch_per_process(
             # Yield all results from this batch
             yield from cur_results
 
-            # Submit next batch to keep workers busy
-            try:
-                batch = next(batch_iter)
-                futures.append(
-                    executor.submit(self._mp_worker_func, batch))
-            except StopIteration:
-                # NOTE: if there's nothing to batch, we've got nothing
-                #       to submit in terms of new work to the workers,
-                #       but we may still have some futures to wait for
-                pass
-
     def get_entities_multi_texts(
             self,
             texts: Union[Iterable[str], Iterable[tuple[str, str]]],
@@ -420,8 +409,12 @@ def _multiprocess(
                 "libraries using threads or native extensions.")
             mp.set_start_method("spawn", force=True)
         with ProcessPoolExecutor(max_workers=external_processes) as executor:
-            yield from self._mp_one_batch_per_process(
-                executor, batch_iter, external_processes, saver=saver)
+            while True:
+                try:
+                    yield from self._mp_one_batch_per_process(
+                        executor, batch_iter, external_processes, saver=saver)
+                except OutOfDataException:
+                    break
 
     def _get_entity(self, ent: MutableEntity,
                     doc_tokens: list[str],
@@ -823,3 +816,7 @@ def __eq__(self, other: Any) -> bool:
     def add_addon(self, addon: AddonComponent) -> None:
         self.config.components.addons.append(addon.config)
         self._pipeline.add_addon(addon)
+
+
+class OutOfDataException(ValueError):
+    pass
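
The shape of this fix deserves a note. `_mp_one_batch_per_process` is a
generator that now performs one bounded round of work per call, and
`_multiprocess` re-invokes it in a loop; a plain `return` on empty input
would be indistinguishable from a normal end-of-round, so the loop would
spin forever. Raising a sentinel exception, which propagates out through
`yield from`, gives the loop an unambiguous stop signal. A toy reduction of
the pattern (the squaring workload is made up):

    from collections.abc import Iterator


    class OutOfDataException(ValueError):
        """Signals that the shared batch iterator is exhausted."""


    def process_round(batch_iter: Iterator[list[int]]) -> Iterator[int]:
        # One bounded round of work per call, like one batch per worker
        # in the patch above.
        try:
            batch = next(batch_iter)
        except StopIteration:
            # an empty round must be signalled out-of-band: a plain return
            # also happens after every successfully completed round
            raise OutOfDataException() from None
        yield from (x * x for x in batch)


    def process_all(batch_iter: Iterator[list[int]]) -> Iterator[int]:
        while True:
            try:
                yield from process_round(batch_iter)
            except OutOfDataException:
                break


    print(list(process_all(iter([[1, 2], [3]]))))  # -> [1, 4, 9]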
From ef03148901b149dd00f862eea19e9768a90afa9d Mon Sep 17 00:00:00 2001
From: mart-r
Date: Thu, 17 Jul 2025 17:21:40 +0100
Subject: [PATCH 6/6] CU-8699upt9a: Add a few more tests regarding
 multiprocessing with batches for saved data

---
 medcat-v2/tests/test_cat.py | 54 ++++++++++++++++++++++++++++++-------
 1 file changed, 44 insertions(+), 10 deletions(-)

diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py
index 9b5a2c2f1..743618ebf 100644
--- a/medcat-v2/tests/test_cat.py
+++ b/medcat-v2/tests/test_cat.py
@@ -68,7 +68,6 @@ class TrainedModelTests(unittest.TestCase):
     def setUpClass(cls):
         cls.model = cat.CAT.load_model_pack(cls.TRAINED_MODEL_PATH)
         if cls.model.config.components.linking.train:
-            print("TRAINING WAS ENABLE! NEED TO DISABLE")
             cls.model.config.components.linking.train = False
 
@@ -448,7 +447,8 @@ def _do_mp_run_with_save(
             self, save_to: str,
             chars_per_batch: int = 165,
             batches_per_save: int = 5,
-            exp_parts: int = 8
+            exp_parts: int = 8,
+            n_process: int = 1,
     ) -> tuple[list[str], list[tuple], dict[str, Any], int]:
         in_data = [
             f"The patient presented with {name} and "
@@ -460,7 +460,9 @@ def _do_mp_run_with_save(
             in_data,
             save_dir_path=save_to,
             batch_size_chars=chars_per_batch,
-            batches_per_save=batches_per_save))
+            batches_per_save=batches_per_save,
+            n_process=n_process,
+        ))
         out_dict_all = {
             key: cdata for key, cdata in out_data
         }
@@ -470,14 +472,16 @@ def assert_mp_runs_with_save_and_load(
             self, save_to: str,
             chars_per_batch: int = 165,
             batches_per_save: int = 5,
-            exp_parts: int = 8
+            exp_parts: int = 8,
+            n_process: int = 1,
     ) -> tuple[
         tuple[list[str], list[tuple], dict[str, Any], int],
         tuple[tuple[list[str], int], list[str], int],
     ]:
         in_data, out_data, out_dict_all, exp_parts = (
             self._do_mp_run_with_save(
-                save_to, chars_per_batch, batches_per_save, exp_parts))
+                save_to, chars_per_batch, batches_per_save, exp_parts,
+                n_process=n_process))
         anns_file = os.path.join(save_to, 'annotated_ids.pickle')
         self.assertTrue(os.path.exists(anns_file))
         with open(anns_file, 'rb') as f:
@@ -491,7 +495,8 @@ def assert_mp_runs_save_load_gather(
             self, save_to: str,
             chars_per_batch: int = 165,
             batches_per_save: int = 5,
-            exp_parts: int = 8
+            exp_parts: int = 8,
+            n_process: int = 1,
     ) -> tuple[
         tuple[list[str], list[tuple], dict[str, Any], int],
         tuple[tuple[list[str], int], list[str], int],
@@ -500,7 +505,8 @@ def assert_mp_runs_save_load_gather(
         (in_data, out_data, out_dict_all, exp_parts), (
             loaded_data, ids, num_last_part
         ) = self.assert_mp_runs_with_save_and_load(
-            save_to, chars_per_batch, batches_per_save, exp_parts)
+            save_to, chars_per_batch, batches_per_save, exp_parts,
+            n_process=n_process)
         all_loaded_output = {}
         for num in range(num_last_part + 1):
             with self.subTest(f"Part {num}"):
@@ -535,6 +541,15 @@ def test_mp_saves_all_parts(self):
             # NOTE: the number of parts is 1 greater than the last part number
             self.assertEqual(num_last_part + 1, exp_parts)
 
+    def assert_correct_loaded_output(
+            self,
+            in_data: list[str],
+            out_dict_all: dict[str, Any],
+            all_loaded_output: dict[str, Any]):
+        self.assertEqual(len(all_loaded_output), len(in_data))
+        self.assertEqual(all_loaded_output.keys(), out_dict_all.keys())
+        self.assertEqual(all_loaded_output, out_dict_all)
+
     def test_mp_saves_correct_data(self):
         with tempfile.TemporaryDirectory() as temp_dir:
             (in_data, out_data,
@@ -542,9 +557,28 @@ def test_mp_saves_correct_data(self):
              out_dict_all, exp_parts), (
                 loaded_data, ids, num_last_part
             ), all_loaded_output = self.assert_mp_runs_save_load_gather(
                 temp_dir)
-            self.assertEqual(len(all_loaded_output), len(in_data))
-            self.assertEqual(all_loaded_output.keys(), out_dict_all.keys())
-            self.assertEqual(all_loaded_output, out_dict_all)
+            self.assert_correct_loaded_output(
+                in_data, out_dict_all, all_loaded_output)
+
+    def test_mp_saves_correct_data_with_2_proc(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            (in_data, out_data,
+             out_dict_all, exp_parts), (
+                loaded_data, ids, num_last_part
+            ), all_loaded_output = self.assert_mp_runs_save_load_gather(
+                temp_dir, n_process=2)
+            self.assert_correct_loaded_output(
+                in_data, out_dict_all, all_loaded_output)
+
+    def test_mp_saves_correct_data_with_3_proc(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            (in_data, out_data,
+             out_dict_all, exp_parts), (
+                loaded_data, ids, num_last_part
+            ), all_loaded_output = self.assert_mp_runs_save_load_gather(
+                temp_dir, n_process=3)
+            self.assert_correct_loaded_output(
+                in_data, out_dict_all, all_loaded_output)
 
 
 class CATWithDocAddonTests(CATIncludingTests):
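
Taken together, the series lets a long annotation run checkpoint its output
as it goes. A hedged end-to-end sketch of the resulting API; the model pack
path, the input texts, and the output directory are placeholders, and the
`{"entities": ...}` result shape is assumed from the usual `get_entities`
output rather than shown in these patches:

    from medcat.cat import CAT

    cat = CAT.load_model_pack("<path to model pack>")  # placeholder

    texts = ["The patient presented with chronic kidney failure."] * 100

    # Results stream back as (text_index, entities) pairs; every 20 cached
    # batches the saver also pickles them under ./mp_output, so a crashed
    # run leaves its completed parts behind on disk.
    for text_index, ents in cat.get_entities_multi_texts(
            texts,
            n_process=3,
            batch_size_chars=100_000,
            save_dir_path="mp_output",
            batches_per_save=20):
        print(text_index, len(ents["entities"]))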