
Commit 827fb3d

CU-8699upt9a Allow saving output onto disk when multiprocessing (#52)

* CU-8699upt9a: Add option to save multiprocessing output
* CU-8699upt9a: Add a test for multiprocessing saved data. Make sure all the data is saved, that all the files are present, and that the saved data is equal to the returned data.
* CU-8699upt9a: Fix typo in output saving
* CU-8699upt9a: Add more comprehensive multiprocessing tests with proper batching
* CU-8699upt9a: Fix issue with limited number of jobs submitted per process
* CU-8699upt9a: Add a few more tests regarding multiprocessing with batches for saved data
1 parent d6b00e8 commit 827fb3d

File tree

3 files changed: +258 -33 lines changed

medcat-v2/medcat/cat.py

Lines changed: 53 additions & 27 deletions
@@ -18,6 +18,7 @@
 from medcat.storage.serialisers import serialise, AvailableSerialisers
 from medcat.storage.serialisers import deserialise
 from medcat.storage.serialisables import AbstractSerialisable
+from medcat.storage.mp_ents_save import BatchAnnotationSaver
 from medcat.utils.fileutils import ensure_folder_if_parent
 from medcat.utils.hasher import Hasher
 from medcat.pipeline.pipeline import Pipeline
@@ -159,7 +160,7 @@ def get_entities(self,
     def _mp_worker_func(
             self,
             texts_and_indices: list[tuple[str, str, bool]]
-            ) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]:
+            ) -> list[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
        # NOTE: this is needed for subprocess as otherwise they wouldn't have
        # any of these set
        # NOTE: these need to by dynamic in case the extra's aren't included
@@ -180,7 +181,7 @@ def _mp_worker_func(
         elif has_rel_cat and isinstance(addon, RelCATAddon):
             addon._rel_cat._init_data_paths()
         return [
-            (text, text_index, self.get_entities(text, only_cui=only_cui))
+            (text_index, self.get_entities(text, only_cui=only_cui))
             for text, text_index, only_cui in texts_and_indices]
 
     def _generate_batches_by_char_length(
@@ -256,7 +257,8 @@ def _mp_one_batch_per_process(
             self,
             executor: ProcessPoolExecutor,
             batch_iter: Iterator[list[tuple[str, str, bool]]],
-            external_processes: int
+            external_processes: int,
+            saver: Optional[BatchAnnotationSaver],
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         futures: list[Future] = []
         # submit batches, one for each external processes
@@ -269,16 +271,16 @@ def _mp_one_batch_per_process(
                 break
         if not futures:
             # NOTE: if there wasn't any data, we didn't process anything
-            return
+            raise OutOfDataException()
         # Main process works on next batch while workers are busy
         main_batch: Optional[list[tuple[str, str, bool]]]
         try:
             main_batch = next(batch_iter)
             main_results = self._mp_worker_func(main_batch)
-
+            if saver:
+                saver(main_results)
             # Yield main process results immediately
-            for result in main_results:
-                yield result[1], result[2]
+            yield from main_results
 
         except StopIteration:
             main_batch = None
@@ -295,20 +297,12 @@ def _mp_one_batch_per_process(
             done_future = next(as_completed(futures))
             futures.remove(done_future)
 
-            # Yield all results from this batch
-            for result in done_future.result():
-                yield result[1], result[2]
+            cur_results = done_future.result()
+            if saver:
+                saver(cur_results)
 
-            # Submit next batch to keep workers busy
-            try:
-                batch = next(batch_iter)
-                futures.append(
-                    executor.submit(self._mp_worker_func, batch))
-            except StopIteration:
-                # NOTE: if there's nothing to batch, we've got nothing
-                # to submit in terms of new work to the workers,
-                # but we may still have some futures to wait for
-                pass
+            # Yield all results from this batch
+            yield from cur_results
 
     def get_entities_multi_texts(
             self,
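The reworked `_mp_one_batch_per_process` keeps the original pipelining idea: each worker process gets one batch, the main process annotates its own batch while the workers are busy, and completed futures are drained with `as_completed`; the only change is that every completed batch now also passes through the saver before being yielded. Below is a minimal, self-contained sketch of that submit/work/drain pattern; the `work` function and the data are made up for illustration and are not MedCAT code.

from concurrent.futures import ProcessPoolExecutor, as_completed


def work(batch: list[int]) -> list[int]:
    # stand-in for CAT._mp_worker_func
    return [x * x for x in batch]


def one_round(executor, batch_iter, external_processes=2):
    # hand one batch to each worker process
    futures = [executor.submit(work, batch)
               for batch, _ in zip(batch_iter, range(external_processes))]
    # the main process works on its own batch while the workers are busy
    try:
        yield from work(next(batch_iter))
    except StopIteration:
        pass
    # drain the workers as they finish
    while futures:
        done = next(as_completed(futures))
        futures.remove(done)
        yield from done.result()


if __name__ == "__main__":
    batches = iter([[1, 2], [3, 4], [5, 6]])
    with ProcessPoolExecutor(max_workers=2) as executor:
        print(list(one_round(executor, batches)))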
@@ -317,6 +311,8 @@ def get_entities_multi_texts(
             n_process: int = 1,
             batch_size: int = -1,
             batch_size_chars: int = 1_000_000,
+            save_dir_path: Optional[str] = None,
+            batches_per_save: int = 20,
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         """Get entities from multiple texts (potentially in parallel).
 
@@ -343,6 +339,16 @@ def get_entities_multi_texts(
                 Each process will be given batch of texts with a total
                 number of characters not exceeding this value. Defaults
                 to 1,000,000 characters. Set to -1 to disable.
+            save_dir_path (Optional[str]):
+                The path to where (if specified) the results are saved.
+                The directory will have an `annotated_ids.pickle` file
+                containing a tuple[list[str], int] with the list of
+                indices already saved and the number of parts already saved.
+                In addition there will be (usually multiple) files in the
+                `part_<num>.pickle` format with the partial outputs.
+            batches_per_save (int):
+                The number of batches to save (if `save_dir_path` is
+                specified) at once. Defaults to 20.
 
         Yields:
             Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
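The new docstring pins down the on-disk layout: `annotated_ids.pickle` holds a `(annotated_ids, part_number)` tuple, and each `part_<num>.pickle` holds one chunk of results keyed by text index. A hedged sketch of reading such a directory back together follows; the helper name is ours, and it assumes the stored part number is the index of the last part written, consistent with the saver implementation later in this commit.

import os
import pickle


def load_saved_annotations(save_dir_path: str) -> dict:
    """Merge every part_<num>.pickle written by get_entities_multi_texts."""
    with open(os.path.join(save_dir_path, "annotated_ids.pickle"), "rb") as f:
        annotated_ids, last_part = pickle.load(f)
    merged: dict = {}
    for part in range(last_part + 1):
        part_path = os.path.join(save_dir_path, f"part_{part}.pickle")
        with open(part_path, "rb") as f:
            merged.update(pickle.load(f))
    # every id recorded as annotated should have an entry in some part file
    missing = set(annotated_ids) - set(merged)
    if missing:
        raise ValueError(f"Missing saved results for: {sorted(missing)}")
    return merged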
@@ -352,15 +358,27 @@
             Union[Iterator[str], Iterator[tuple[str, str]]], iter(texts))
         batch_iter = self._generate_batches(
             text_iter, batch_size, batch_size_chars, only_cui)
+        if save_dir_path:
+            saver = BatchAnnotationSaver(save_dir_path, batches_per_save)
+        else:
+            saver = None
         if n_process == 1:
             # just do in series
             for batch in batch_iter:
-                for _, text_index, result in self._mp_worker_func(batch):
-                    yield text_index, result
+                batch_results = self._mp_worker_func(batch)
+                if saver is not None:
+                    saver(batch_results)
+                yield from batch_results
+            if saver:
+                # save remainder
+                saver._save_cache()
             return
 
         with self._no_usage_monitor_exit_flushing():
-            yield from self._multiprocess(n_process, batch_iter)
+            yield from self._multiprocess(n_process, batch_iter, saver)
+            if saver:
+                # save remainder
+                saver._save_cache()
 
     @contextmanager
     def _no_usage_monitor_exit_flushing(self):
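With these changes the serial and multiprocessing paths behave the same from the caller's point of view: results stream back as `(text_index, entities)` pairs and, when `save_dir_path` is given, are also pickled to disk every `batches_per_save` batches. A hedged usage sketch, assuming `cat` is an already-loaded `CAT` instance and that `(id, text)` pairs are accepted as input (neither is shown in this diff):

texts = [
    ("doc-0", "Patient presents with chest pain and shortness of breath."),
    ("doc-1", "No prior history of diabetes mellitus."),
]

# Annotate with one extra worker process; write partial results to ./mp_output
# as part_<num>.pickle files, flushing every 2 batches.
for text_index, entities in cat.get_entities_multi_texts(
        texts,
        n_process=2,
        save_dir_path="mp_output",
        batches_per_save=2):
    print(text_index, entities)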
@@ -379,7 +397,8 @@ def _no_usage_monitor_exit_flushing(self):
 
     def _multiprocess(
             self, n_process: int,
-            batch_iter: Iterator[list[tuple[str, str, bool]]]
+            batch_iter: Iterator[list[tuple[str, str, bool]]],
+            saver: Optional[BatchAnnotationSaver],
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         external_processes = n_process - 1
         if self.FORCE_SPAWN_MP:
@@ -390,8 +409,12 @@
                 "libraries using threads or native extensions.")
             mp.set_start_method("spawn", force=True)
         with ProcessPoolExecutor(max_workers=external_processes) as executor:
-            yield from self._mp_one_batch_per_process(
-                executor, batch_iter, external_processes)
+            while True:
+                try:
+                    yield from self._mp_one_batch_per_process(
+                        executor, batch_iter, external_processes, saver=saver)
+                except OutOfDataException:
+                    break
 
     def _get_entity(self, ent: MutableEntity,
                     doc_tokens: list[str],
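Because `_mp_one_batch_per_process` now handles only one round of batches per call, `_multiprocess` loops over it until the batch iterator is exhausted; a plain `return` inside the generator would end a round silently, so the `OutOfDataException` raised when no futures could be submitted is the explicit signal to stop the outer loop. A stripped-down illustration of that control flow (all names here are ours, not MedCAT's):

class OutOfData(ValueError):
    """Raised by one_round when the batch iterator is exhausted."""
    pass


def one_round(batch_iter):
    try:
        batch = next(batch_iter)
    except StopIteration:
        # nothing left to submit: tell the caller explicitly
        raise OutOfData()
    yield from (item * 2 for item in batch)


def run_all(batches):
    batch_iter = iter(batches)
    while True:
        try:
            yield from one_round(batch_iter)
        except OutOfData:
            break


print(list(run_all([[1, 2], [3]])))  # [2, 4, 6]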
@@ -737,7 +760,6 @@ def load_addons(
         ]
         return [(addon.full_name, addon) for addon in loaded_addons]
 
-
     @overload
     def get_model_card(self, as_dict: Literal[True]) -> ModelCard:
         pass
@@ -794,3 +816,7 @@ def __eq__(self, other: Any) -> bool:
     def add_addon(self, addon: AddonComponent) -> None:
         self.config.components.addons.append(addon.config)
         self._pipeline.add_addon(addon)
+
+
+class OutOfDataException(ValueError):
+    pass

medcat-v2/medcat/storage/mp_ents_save.py (new file)

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+from typing import Union
+import os
+import logging
+
+import pickle
+
+from medcat.data.entities import Entities, OnlyCUIEntities
+
+
+logger = logging.getLogger(__name__)
+
+
+class BatchAnnotationSaver:
+    def __init__(self, save_dir: str, batches_per_save: int):
+        self.save_dir = save_dir
+        self.batches_per_save = batches_per_save
+        self._batch_cache: list[list[
+            tuple[str, Union[dict, Entities, OnlyCUIEntities]]]] = []
+        os.makedirs(save_dir, exist_ok=True)
+        self.part_number = 0
+        self.annotated_ids_path = os.path.join(
+            save_dir, "annotated_ids.pickle")
+
+    def _load_existing_ids(self) -> tuple[list[str], int]:
+        if not os.path.exists(self.annotated_ids_path):
+            return [], -1
+        with open(self.annotated_ids_path, 'rb') as f:
+            return pickle.load(f)
+
+    def _save_cache(self):
+        annotated_ids, prev_part_num = self._load_existing_ids()
+        if (prev_part_num + 1) != self.part_number:
+            logger.info(
+                "Found part number %d off disk. Previously %d was kept track "
+                "of in code. Updating to %d off disk.",
+                prev_part_num, self.part_number, prev_part_num)
+            self.part_number = prev_part_num
+        for batch in self._batch_cache:
+            for doc_id, _ in batch:
+                annotated_ids.append(doc_id)
+        logger.debug("Saving part %d with %d batches",
+                     self.part_number, len(self._batch_cache))
+        with open(self.annotated_ids_path, 'wb') as f:
+            pickle.dump((annotated_ids, self.part_number), f)
+        # Save batch as part_<num>.pickle
+        part_path = os.path.join(self.save_dir,
+                                 f"part_{self.part_number}.pickle")
+        part_dict = {id: val for
+                     batch in self._batch_cache for
+                     id, val in batch}
+        with open(part_path, 'wb') as f:
+            pickle.dump(part_dict, f)
+        self._batch_cache.clear()
+        self.part_number += 1
+
+    def __call__(self, batch: list[
+            tuple[str, Union[dict, Entities, OnlyCUIEntities]]]):
+        self._batch_cache.append(batch)
+        if len(self._batch_cache) >= self.batches_per_save:
+            self._save_cache()
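For completeness, a hedged sketch of driving `BatchAnnotationSaver` by hand, using only the interface shown above (the constructor, `__call__`, and the `_save_cache()` flush that `cat.py` calls for the remainder); the directory name and the dummy entity dicts are purely illustrative:

saver = BatchAnnotationSaver("out_dir", batches_per_save=2)

# Each call caches one batch of (text_index, entities) pairs; once
# batches_per_save batches are cached, a part_<num>.pickle is written
# and annotated_ids.pickle is updated.
saver([("doc-0", {"entities": {}}), ("doc-1", {"entities": {}})])
saver([("doc-2", {"entities": {}})])   # second batch -> writes part_0.pickle
saver([("doc-3", {"entities": {}})])
saver._save_cache()                    # flush the remainder -> part_1.pickle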
