This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Use multiprocessing.Pool instead of ProcessPoolExecutor #453

Merged 4 commits on Nov 18, 2021
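In outline, the change swaps concurrent.futures' submit/Future API for multiprocessing.Pool's apply_async with completion callbacks. Below is a minimal sketch of the two styles, not taken from this PR; the work function and its arguments are illustrative only.

import multiprocessing
from concurrent import futures

def work(x):
    return x * 2

if __name__ == "__main__":
    # Before: ProcessPoolExecutor. Results and exceptions come back via Future objects.
    with futures.ProcessPoolExecutor(max_workers=2) as executor:
        future = executor.submit(work, 21)
        assert future.exception() is None  # waits for the worker; returns any exception
        assert future.result() == 42

    # After: multiprocessing.Pool. Callbacks fire in the parent process as tasks complete.
    with multiprocessing.Pool(processes=2) as pool:
        async_result = pool.apply_async(
            work,
            args=(21,),
            callback=lambda result: print(f"done: {result}"),
            error_callback=lambda exc: print(f"failed: {exc}"),
        )
        async_result.wait()  # blocks until done; use get() to fetch the result or re-raise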
63 changes: 38 additions & 25 deletions nowcasting_dataset/manager.py
@@ -1,7 +1,7 @@
 """Manager class."""
 
 import logging
-from concurrent import futures
+import multiprocessing
 from pathlib import Path
 from typing import Optional, Union
@@ -419,24 +419,21 @@ def create_batches(self, overwrite_batches: bool) -> None:
             locations_for_each_example["t0_datetime_UTC"] = pd.to_datetime(
                 locations_for_each_example["t0_datetime_UTC"]
             )
-            locations_for_each_example_of_each_split[split_name] = locations_for_each_example
+            if len(locations_for_each_example) > 0:
+                locations_for_each_example_of_each_split[split_name] = locations_for_each_example
 
         # Fire up a separate process for each DataSource, and pass it a list of batches to
         # create, and whether to utils.upload_and_delete_local_files().
         # TODO: Issue 321: Split this up into separate functions!!!
         n_data_sources = len(self.data_sources)
         nd_utils.set_fsspec_for_multiprocess()
-        for split_name in splits_which_need_more_batches:
-            locations_for_split = locations_for_each_example_of_each_split[split_name]
-            with futures.ProcessPoolExecutor(max_workers=n_data_sources) as executor:
-                future_create_batches_jobs = []
+        for split_name, locations_for_split in locations_for_each_example_of_each_split.items():
+            with multiprocessing.Pool(processes=n_data_sources) as pool:
+                async_results_from_create_batches = []
                 for worker_id, (data_source_name, data_source) in enumerate(
                     self.data_sources.items()
                 ):
-
-                    if len(locations_for_split) == 0:
-                        break
-
                     # Get indexes of first batch and example. And subset locations_for_split.
                     idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
                     idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
@@ -446,6 +443,8 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     dst_path = (
                         self.config.output_data.filepath / split_name.value / data_source_name
                     )
+
+                    # TODO: Issue 455: Guarantee that local temp path is unique and empty.
                     local_temp_path = (
                         self.local_temp_path
                         / split_name.value
@@ -458,27 +457,41 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     if self.save_batches_locally_and_upload:
                         nd_fs_utils.makedirs(local_temp_path, exist_ok=True)
 
-                    # Submit data_source.create_batches task to the worker process.
-                    future = executor.submit(
-                        data_source.create_batches,
+                    # Keyword arguments to be passed into data_source.create_batches():
+                    kwargs_for_create_batches = dict(
                         spatial_and_temporal_locations_of_each_example=locations,
                         idx_of_first_batch=idx_of_first_batch,
                         batch_size=self.config.process.batch_size,
                         dst_path=dst_path,
                         local_temp_path=local_temp_path,
                         upload_every_n_batches=self.config.process.upload_every_n_batches,
                     )
-                    future_create_batches_jobs.append(future)
-
-            # Wait for all futures to finish:
-            for future, data_source_name in zip(
-                future_create_batches_jobs, self.data_sources.keys()
-            ):
-                # Call exception() to propagate any exceptions raised by the worker process into
-                # the main process, and to wait for the worker to finish.
-                exception = future.exception()
-                if exception is not None:
-                    logger.exception(
-                        f"Worker process {data_source_name} raised exception!\n{exception}"
-                    )
-                    raise exception
+                    # Logger messages for the callbacks:
+                    callback_msg = (
+                        f"{data_source_name} has finished creating batches for {split_name}!"
+                    )
+                    error_callback_msg = (
+                        f"Exception raised by {data_source_name} whilst creating batches for"
+                        f" {split_name}:\n"
+                    )
+
+                    # Submit data_source.create_batches task to the worker process. The messages
+                    # are bound as lambda default arguments so each callback keeps its own copy;
+                    # a bare closure would see only the last iteration's values.
+                    logger.debug(
+                        f"About to submit create_batches task for {data_source_name}, {split_name}"
+                    )
+                    async_result = pool.apply_async(
+                        data_source.create_batches,
+                        kwds=kwargs_for_create_batches,
+                        callback=lambda result, msg=callback_msg: logger.info(msg),
+                        error_callback=lambda exception, msg=error_callback_msg: logger.error(
+                            msg + str(exception)
+                        ),
+                    )
+                    async_results_from_create_batches.append(async_result)
+
+                # Wait for all async_results to finish:
+                for async_result in async_results_from_create_batches:
+                    async_result.wait()
 
             logger.info(f"Finished creating batches for {split_name}!")

A Contributor left an inline review comment on the new kwargs_for_create_batches dict: "thats how you do it! nice"
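One behavioural difference is worth noting: the old code's future.exception() both waited for the worker and returned any exception so it could be re-raised in the main process, whereas AsyncResult.wait() only blocks and never raises. With multiprocessing.Pool, failures surface through error_callback (as in the diff above) or by calling get(). A minimal sketch, with an illustrative boom function that is not part of this PR:

import multiprocessing

def boom():
    raise ValueError("worker failed")

if __name__ == "__main__":
    with multiprocessing.Pool(processes=1) as pool:
        async_result = pool.apply_async(
            boom,
            error_callback=lambda exc: print(f"error_callback saw: {exc!r}"),
        )
        async_result.wait()  # returns once the task fails; does NOT raise
        print(async_result.successful())  # -> False
        # async_result.get()  # would re-raise the ValueError in the parent process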