11"""Manager class."""
22
33import logging
4- from concurrent import futures
4+ import multiprocessing
55from pathlib import Path
66from typing import Optional , Union
77
@@ -419,24 +419,21 @@ def create_batches(self, overwrite_batches: bool) -> None:
             locations_for_each_example["t0_datetime_UTC"] = pd.to_datetime(
                 locations_for_each_example["t0_datetime_UTC"]
             )
-            locations_for_each_example_of_each_split[split_name] = locations_for_each_example
+            if len(locations_for_each_example) > 0:
+                locations_for_each_example_of_each_split[split_name] = locations_for_each_example

         # Fire up a separate process for each DataSource, and pass it a list of batches to
         # create, and whether to utils.upload_and_delete_local_files().
         # TODO: Issue 321: Split this up into separate functions!!!
         n_data_sources = len(self.data_sources)
         nd_utils.set_fsspec_for_multiprocess()
-        for split_name in splits_which_need_more_batches:
-            locations_for_split = locations_for_each_example_of_each_split[split_name]
-            with futures.ProcessPoolExecutor(max_workers=n_data_sources) as executor:
-                future_create_batches_jobs = []
+        for split_name, locations_for_split in locations_for_each_example_of_each_split.items():
+            with multiprocessing.Pool(processes=n_data_sources) as pool:
+                async_results_from_create_batches = []
                 for worker_id, (data_source_name, data_source) in enumerate(
                     self.data_sources.items()
                 ):

-                    if len(locations_for_split) == 0:
-                        break
-
                     # Get indexes of first batch and example. And subset locations_for_split.
                     idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
                     idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
@@ -446,6 +443,8 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     dst_path = (
                         self.config.output_data.filepath / split_name.value / data_source_name
                     )
+
+                    # TODO: Issue 455: Guarantee that local temp path is unique and empty.
                     local_temp_path = (
                         self.local_temp_path
                         / split_name.value
@@ -458,27 +457,41 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     if self.save_batches_locally_and_upload:
                         nd_fs_utils.makedirs(local_temp_path, exist_ok=True)

-                    # Submit data_source.create_batches task to the worker process.
-                    future = executor.submit(
-                        data_source.create_batches,
+                    # Key word arguments to be passed into data_source.create_batches():
+                    kwargs_for_create_batches = dict(
                         spatial_and_temporal_locations_of_each_example=locations,
                         idx_of_first_batch=idx_of_first_batch,
                         batch_size=self.config.process.batch_size,
                         dst_path=dst_path,
                         local_temp_path=local_temp_path,
                         upload_every_n_batches=self.config.process.upload_every_n_batches,
                     )
-                    future_create_batches_jobs.append(future)

-                # Wait for all futures to finish:
-                for future, data_source_name in zip(
-                    future_create_batches_jobs, self.data_sources.keys()
-                ):
-                    # Call exception() to propagate any exceptions raised by the worker process into
-                    # the main process, and to wait for the worker to finish.
-                    exception = future.exception()
-                    if exception is not None:
-                        logger.exception(
-                            f"Worker process {data_source_name} raised exception!\n{exception}"
-                        )
-                        raise exception
+                    # Logger messages for callbacks:
+                    callback_msg = (
+                        f"{data_source_name} has finished creating batches for {split_name}!"
+                    )
+                    error_callback_msg = (
+                        f"Exception raised by {data_source_name} whilst creating batches for"
+                        f" {split_name}:\n"
+                    )
+
+                    # Submit data_source.create_batches task to the worker process.
+                    logger.debug(
+                        f"About to submit create_batches task for {data_source_name}, {split_name}"
+                    )
+                    async_result = pool.apply_async(
+                        data_source.create_batches,
+                        kwds=kwargs_for_create_batches,
+                        callback=lambda result, msg=callback_msg: logger.info(msg),
+                        error_callback=lambda exception, msg=error_callback_msg: logger.error(
+                            msg + str(exception)
+                        ),
+                    )
+                    async_results_from_create_batches.append(async_result)
+
+                # Wait for all async_results to finish:
+                for async_result in async_results_from_create_batches:
+                    async_result.wait()
+
+                logger.info(f"Finished creating batches for {split_name}!")
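
For context on the new pattern: the hunks above replace concurrent.futures.ProcessPoolExecutor with multiprocessing.Pool, submit each data_source.create_batches() call via pool.apply_async(), log completion and failure through callback/error_callback, and then wait on the AsyncResult objects. Below is a minimal, self-contained sketch of that pattern only; the worker function, data-source names and batch counts are invented for illustration and are not part of the repository's code. The sketch also highlights two details: lambdas created in a loop capture variables by reference, so the log messages are bound as default arguments, and AsyncResult.wait() does not re-raise worker exceptions, unlike the future.exception()/raise logic in the removed code (use get() if the main process should fail on worker errors).

import logging
import multiprocessing

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_batches_for(data_source_name, n_batches):
    """Stand-in for data_source.create_batches(); runs in a worker process."""
    if n_batches < 0:
        raise ValueError(f"{data_source_name}: n_batches must be >= 0, got {n_batches}")
    return f"{data_source_name}: created {n_batches} batches"


def main():
    data_sources = {"satellite": 4, "nwp": 4, "broken": -1}  # "broken" triggers error_callback.
    async_results = []
    with multiprocessing.Pool(processes=len(data_sources)) as pool:
        for name, n_batches in data_sources.items():
            # Bind the per-task messages as default arguments: a plain
            # `lambda result: logger.info(msg)` captures the variable itself, so every
            # callback would log the message from the last loop iteration.
            msg = f"{name} has finished creating batches!"
            err_msg = f"Exception raised by {name} whilst creating batches:\n"
            async_result = pool.apply_async(
                create_batches_for,
                kwds=dict(data_source_name=name, n_batches=n_batches),
                callback=lambda result, msg=msg: logger.info(msg),
                error_callback=lambda exc, msg=err_msg: logger.error(msg + str(exc)),
            )
            async_results.append(async_result)

        # Wait for every task before the `with` block exits (exit terminates the pool).
        # Note: wait() does not re-raise worker exceptions; call get() instead if the
        # main process should fail on a worker error.
        for async_result in async_results:
            async_result.wait()


if __name__ == "__main__":
    main()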