diff --git a/nowcasting_dataset/data_sources/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite_data_source.py
index 657ee5c0..19b22a53 100644
--- a/nowcasting_dataset/data_sources/satellite_data_source.py
+++ b/nowcasting_dataset/data_sources/satellite_data_source.py
@@ -9,7 +9,6 @@
 import pandas as pd
 import xarray as xr
 
-from nowcasting_dataset import utils
 from nowcasting_dataset.data_sources.data_source import ZarrDataSource
 from nowcasting_dataset.dataset.example import Example, to_numpy
 
@@ -203,7 +202,6 @@ def open_sat_data(filename: str, consolidated: bool) -> xr.DataArray:
         consolidated: Whether or not the Zarr metadata is consolidated.
     """
     _LOG.debug("Opening satellite data: %s", filename)
-    utils.set_fsspec_for_multiprocess()
 
     # We load using chunks=None so xarray *doesn't* use Dask to
     # load the Zarr chunks from disk.  Using Dask to load the data
diff --git a/nowcasting_dataset/dataset/datasets.py b/nowcasting_dataset/dataset/datasets.py
index 9038cdd3..6739b356 100644
--- a/nowcasting_dataset/dataset/datasets.py
+++ b/nowcasting_dataset/dataset/datasets.py
@@ -35,6 +35,7 @@
 )
 from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES
 from nowcasting_dataset.dataset import example
+from nowcasting_dataset.utils import set_fsspec_for_multiprocess
 
 logger = logging.getLogger(__name__)
 
@@ -127,6 +128,8 @@ def __init__(
         self.forecast_minutes = forecast_minutes
         self.configuration = configuration
 
+        logger.info(f"Setting up NetCDFDataset for {src_path}")
+
         if self.forecast_minutes is None:
             self.forecast_minutes = configuration.process.forecast_minutes
         if self.history_minutes is None:
@@ -285,6 +288,9 @@ def per_worker_init(self, worker_id: int) -> None:
             _LOG.debug(f"Opening {type(data_source).__name__}")
             data_source.open()
 
+        # Fix for fsspec when using multiprocessing.
+        set_fsspec_for_multiprocess()
+
         self._per_worker_init_has_run = True
 
     def __iter__(self):
@@ -349,8 +355,12 @@ def _get_locations_for_batch(
 def worker_init_fn(worker_id):
     """Configures each dataset worker process.
 
-    Just has one job! To call NowcastingDataset.per_worker_init().
+    1. Get fsspec ready for multiprocessing.
+    2. Call NowcastingDataset.per_worker_init().
     """
+    # Fix for fsspec when using multiprocessing.
+    set_fsspec_for_multiprocess()
+
     # get_worker_info() returns information specific to each worker process.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
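
For context, a minimal sketch of how worker_init_fn is typically wired into a PyTorch DataLoader so that set_fsspec_for_multiprocess() runs once in every worker process before any data source is opened. The dataset construction and the DataLoader arguments (num_workers, batch_size) below are illustrative placeholders, not taken from this diff; only the worker_init_fn import path and the DataLoader hook reflect the code above.

    # Sketch only: wiring worker_init_fn into a PyTorch DataLoader.
    # The dataset instance is hypothetical; the DataLoader parameters used
    # here are standard torch.utils.data.DataLoader arguments.
    import torch

    from nowcasting_dataset.dataset.datasets import worker_init_fn

    dataset = ...  # e.g. an instance of the dataset defined in datasets.py

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,               # assumed: the dataset yields ready-made batches
        num_workers=4,                 # each worker process runs worker_init_fn once
        worker_init_fn=worker_init_fn,
    )

    for batch in data_loader:
        ...  # fsspec is already configured for multiprocessing in every worker

Applying the fsspec fix in the worker-init path (rather than inside open_sat_data) means it runs exactly once per worker process, regardless of which data sources that worker later opens.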