18
18
from medcat .storage .serialisers import serialise , AvailableSerialisers
19
19
from medcat .storage .serialisers import deserialise
20
20
from medcat .storage .serialisables import AbstractSerialisable
21
+ from medcat .storage .mp_ents_save import BatchAnnotationSaver
21
22
from medcat .utils .fileutils import ensure_folder_if_parent
22
23
from medcat .utils .hasher import Hasher
23
24
from medcat .pipeline .pipeline import Pipeline
@@ -159,7 +160,7 @@ def get_entities(self,
159
160
def _mp_worker_func (
160
161
self ,
161
162
texts_and_indices : list [tuple [str , str , bool ]]
162
- ) -> list [tuple [str , str , Union [dict , Entities , OnlyCUIEntities ]]]:
163
+ ) -> list [tuple [str , Union [dict , Entities , OnlyCUIEntities ]]]:
163
164
# NOTE: this is needed for subprocess as otherwise they wouldn't have
164
165
# any of these set
165
166
# NOTE: these need to by dynamic in case the extra's aren't included
@@ -180,7 +181,7 @@ def _mp_worker_func(
180
181
elif has_rel_cat and isinstance (addon , RelCATAddon ):
181
182
addon ._rel_cat ._init_data_paths ()
182
183
return [
183
- (text , text_index , self .get_entities (text , only_cui = only_cui ))
184
+ (text_index , self .get_entities (text , only_cui = only_cui ))
184
185
for text , text_index , only_cui in texts_and_indices ]
185
186
186
187
def _generate_batches_by_char_length (
@@ -256,7 +257,8 @@ def _mp_one_batch_per_process(
256
257
self ,
257
258
executor : ProcessPoolExecutor ,
258
259
batch_iter : Iterator [list [tuple [str , str , bool ]]],
259
- external_processes : int
260
+ external_processes : int ,
261
+ saver : Optional [BatchAnnotationSaver ],
260
262
) -> Iterator [tuple [str , Union [dict , Entities , OnlyCUIEntities ]]]:
261
263
futures : list [Future ] = []
262
264
# submit batches, one for each external processes
@@ -269,16 +271,16 @@ def _mp_one_batch_per_process(
269
271
break
270
272
if not futures :
271
273
# NOTE: if there wasn't any data, we didn't process anything
272
- return
274
+ raise OutOfDataException ()
273
275
# Main process works on next batch while workers are busy
274
276
main_batch : Optional [list [tuple [str , str , bool ]]]
275
277
try :
276
278
main_batch = next (batch_iter )
277
279
main_results = self ._mp_worker_func (main_batch )
278
-
280
+ if saver :
281
+ saver (main_results )
279
282
# Yield main process results immediately
280
- for result in main_results :
281
- yield result [1 ], result [2 ]
283
+ yield from main_results
282
284
283
285
except StopIteration :
284
286
main_batch = None
@@ -295,20 +297,12 @@ def _mp_one_batch_per_process(
295
297
done_future = next (as_completed (futures ))
296
298
futures .remove (done_future )
297
299
298
- # Yield all results from this batch
299
- for result in done_future . result () :
300
- yield result [ 1 ], result [ 2 ]
300
+ cur_results = done_future . result ()
301
+ if saver :
302
+ saver ( cur_results )
301
303
302
- # Submit next batch to keep workers busy
303
- try :
304
- batch = next (batch_iter )
305
- futures .append (
306
- executor .submit (self ._mp_worker_func , batch ))
307
- except StopIteration :
308
- # NOTE: if there's nothing to batch, we've got nothing
309
- # to submit in terms of new work to the workers,
310
- # but we may still have some futures to wait for
311
- pass
304
+ # Yield all results from this batch
305
+ yield from cur_results
312
306
313
307
def get_entities_multi_texts (
314
308
self ,
@@ -317,6 +311,8 @@ def get_entities_multi_texts(
317
311
n_process : int = 1 ,
318
312
batch_size : int = - 1 ,
319
313
batch_size_chars : int = 1_000_000 ,
314
+ save_dir_path : Optional [str ] = None ,
315
+ batches_per_save : int = 20 ,
320
316
) -> Iterator [tuple [str , Union [dict , Entities , OnlyCUIEntities ]]]:
321
317
"""Get entities from multiple texts (potentially in parallel).
322
318
@@ -343,6 +339,16 @@ def get_entities_multi_texts(
343
339
Each process will be given batch of texts with a total
344
340
number of characters not exceeding this value. Defaults
345
341
to 1,000,000 characters. Set to -1 to disable.
342
+ save_dir_path (Optional[str]):
343
+ The path to where (if specified) the results are saved.
344
+ The directory will have an `annotated_ids.pickle` file
345
+ containing the tuple[list[str], int] with a list of
346
+ indices already saved and the number of parts already saved.
347
+ In addition there will be (usually multiple) files in the
348
+ `part_<num>.pickle` format with the partial outputs.
349
+ batches_per_save (int):
350
+ The number of batches to save (if `save_dir_path` is specified)
351
+ at once. Defaults to 20.
346
352
347
353
Yields:
348
354
Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
@@ -352,15 +358,27 @@ def get_entities_multi_texts(
352
358
Union [Iterator [str ], Iterator [tuple [str , str ]]], iter (texts ))
353
359
batch_iter = self ._generate_batches (
354
360
text_iter , batch_size , batch_size_chars , only_cui )
361
+ if save_dir_path :
362
+ saver = BatchAnnotationSaver (save_dir_path , batches_per_save )
363
+ else :
364
+ saver = None
355
365
if n_process == 1 :
356
366
# just do in series
357
367
for batch in batch_iter :
358
- for _ , text_index , result in self ._mp_worker_func (batch ):
359
- yield text_index , result
368
+ batch_results = self ._mp_worker_func (batch )
369
+ if saver is not None :
370
+ saver (batch_results )
371
+ yield from batch_results
372
+ if saver :
373
+ # save remainder
374
+ saver ._save_cache ()
360
375
return
361
376
362
377
with self ._no_usage_monitor_exit_flushing ():
363
- yield from self ._multiprocess (n_process , batch_iter )
378
+ yield from self ._multiprocess (n_process , batch_iter , saver )
379
+ if saver :
380
+ # save remainder
381
+ saver ._save_cache ()
364
382
365
383
@contextmanager
366
384
def _no_usage_monitor_exit_flushing (self ):
@@ -379,7 +397,8 @@ def _no_usage_monitor_exit_flushing(self):
379
397
380
398
def _multiprocess (
381
399
self , n_process : int ,
382
- batch_iter : Iterator [list [tuple [str , str , bool ]]]
400
+ batch_iter : Iterator [list [tuple [str , str , bool ]]],
401
+ saver : Optional [BatchAnnotationSaver ],
383
402
) -> Iterator [tuple [str , Union [dict , Entities , OnlyCUIEntities ]]]:
384
403
external_processes = n_process - 1
385
404
if self .FORCE_SPAWN_MP :
@@ -390,8 +409,12 @@ def _multiprocess(
390
409
"libraries using threads or native extensions." )
391
410
mp .set_start_method ("spawn" , force = True )
392
411
with ProcessPoolExecutor (max_workers = external_processes ) as executor :
393
- yield from self ._mp_one_batch_per_process (
394
- executor , batch_iter , external_processes )
412
+ while True :
413
+ try :
414
+ yield from self ._mp_one_batch_per_process (
415
+ executor , batch_iter , external_processes , saver = saver )
416
+ except OutOfDataException :
417
+ break
395
418
396
419
def _get_entity (self , ent : MutableEntity ,
397
420
doc_tokens : list [str ],
@@ -737,7 +760,6 @@ def load_addons(
737
760
]
738
761
return [(addon .full_name , addon ) for addon in loaded_addons ]
739
762
740
-
741
763
@overload
742
764
def get_model_card (self , as_dict : Literal [True ]) -> ModelCard :
743
765
pass
@@ -794,3 +816,7 @@ def __eq__(self, other: Any) -> bool:
794
816
def add_addon (self , addon : AddonComponent ) -> None :
795
817
self .config .components .addons .append (addon .config )
796
818
self ._pipeline .add_addon (addon )
819
+
820
+
821
+ class OutOfDataException (ValueError ):
822
+ pass
0 commit comments