Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit f896a5e

Browse files
committed
Remove n_timesteps_per_batch and _cache from DataSources.
1 parent 7004973 commit f896a5e

File tree

9 files changed

+74
-411
lines changed

9 files changed

+74
-411
lines changed

conftest.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
register_xr_data_set_to_tensor()
2323

2424

25-
def pytest_addoption(parser):
25+
def pytest_addoption(parser): # noqa: D103
2626
parser.addoption(
2727
"--use_cloud_data",
2828
action="store_true",
@@ -32,12 +32,12 @@ def pytest_addoption(parser):
3232

3333

3434
@pytest.fixture
35-
def use_cloud_data(request):
35+
def use_cloud_data(request): # noqa: D103
3636
return request.config.getoption("--use_cloud_data")
3737

3838

3939
@pytest.fixture
40-
def sat_filename(use_cloud_data: bool) -> Path:
40+
def sat_filename(use_cloud_data: bool) -> Path: # noqa: D103
4141
if use_cloud_data:
4242
return consts.SAT_FILENAME
4343
else:
@@ -47,24 +47,23 @@ def sat_filename(use_cloud_data: bool) -> Path:
4747

4848

4949
@pytest.fixture
50-
def sat_data_source(sat_filename: Path):
50+
def sat_data_source(sat_filename: Path): # noqa: D103
5151
return SatelliteDataSource(
5252
image_size_pixels=pytest.IMAGE_SIZE_PIXELS,
5353
zarr_path=sat_filename,
5454
history_minutes=0,
5555
forecast_minutes=5,
5656
channels=("HRV",),
57-
n_timesteps_per_batch=2,
5857
)
5958

6059

6160
@pytest.fixture
62-
def general_data_source():
61+
def general_data_source(): # noqa: D103
6362
return MetadataDataSource(history_minutes=0, forecast_minutes=5, object_at_center="GSP")
6463

6564

6665
@pytest.fixture
67-
def gsp_data_source():
66+
def gsp_data_source(): # noqa: D103
6867
return GSPDataSource(
6968
image_size_pixels=16,
7069
meters_per_pixel=2000,
@@ -75,13 +74,13 @@ def gsp_data_source():
7574

7675

7776
@pytest.fixture
78-
def configuration():
77+
def configuration(): # noqa: D103
7978
filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "config", "gcp.yaml")
8079
configuration = load_yaml_configuration(filename)
8180

8281
return configuration
8382

8483

8584
@pytest.fixture
86-
def test_data_folder():
85+
def test_data_folder(): # noqa: D103
8786
return os.path.join(os.path.dirname(nowcasting_dataset.__file__), "../tests/data")

notebooks/2021-09/2021-09-07/sat_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Notebook"""
12
from datetime import datetime
23

34
from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource
@@ -9,7 +10,6 @@
910
forecast_len=12,
1011
image_size_pixels=64,
1112
meters_per_pixel=2000,
12-
n_timesteps_per_batch=32,
1313
)
1414

1515
s.open()

nowcasting_dataset/data_sources/data_source.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,20 @@
66
from pathlib import Path
77
from typing import Iterable, List, Tuple
88

9+
from concurrent import futures
910
import pandas as pd
1011
import xarray as xr
1112

12-
import nowcasting_dataset.filesystem.utils as nd_fs_utils
13-
1413
# nowcasting_dataset imports
14+
import nowcasting_dataset.filesystem.utils as nd_fs_utils
1515
import nowcasting_dataset.time as nd_time
1616
import nowcasting_dataset.utils as nd_utils
1717
from nowcasting_dataset import square
1818
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
19-
from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset
19+
from nowcasting_dataset.dataset.xr_utils import (
20+
join_dataset_to_batch_dataset,
21+
join_list_data_array_to_batch_dataset,
22+
)
2023

2124
logger = logging.getLogger(__name__)
2225

@@ -310,20 +313,12 @@ class ZarrDataSource(ImageDataSource):
310313
"""
311314

312315
channels: Iterable[str]
313-
#: Mustn't be None, but cannot have a non-default arg in this position :)
314-
n_timesteps_per_batch: int = None
315316
consolidated: bool = True
316317

317318
def __post_init__(self, image_size_pixels: int, meters_per_pixel: int):
318319
""" Post init """
319320
super().__post_init__(image_size_pixels, meters_per_pixel)
320321
self._data = None
321-
if self.n_timesteps_per_batch is None:
322-
# Using hacky default for now. The whole concept of n_timesteps_per_batch
323-
# will be removed when #213 is completed.
324-
# TODO: Remove n_timesteps_per_batch when #213 is completed!
325-
self.n_timesteps_per_batch = 16
326-
logger.warning("n_timesteps_per_batch is not set! Using default!")
327322

328323
@property
329324
def data(self):
@@ -378,7 +373,42 @@ def get_example(
378373
# rename 'variable' to 'channels'
379374
selected_data = selected_data.rename({"variable": "channels"})
380375

381-
return selected_data
376+
return selected_data.load()
377+
378+
def get_batch(
    self,
    t0_datetimes: pd.DatetimeIndex,
    x_locations: Iterable[Number],
    y_locations: Iterable[Number],
) -> DataSourceOutput:
    """
    Get batch data

    Each example is fetched with `get_example` in its own worker thread. The examples are then
    joined into a single batch dataset, which is passed to `_dataset_to_data_source_output`.

    Args:
        t0_datetimes: list of timestamps for the datetime of the batches. The batch will also
            include data for historic and future depending on `history_minutes` and
            `forecast_minutes`.
        x_locations: x center batch locations
        y_locations: y center batch locations

    Returns: Batch data

    Raises:
        ValueError: if `t0_datetimes` is empty.

    """
    batch_size = len(t0_datetimes)
    if batch_size == 0:
        # ThreadPoolExecutor rejects max_workers=0 with an unrelated-looking message,
        # so fail early with an explicit one (same exception type as before).
        raise ValueError("Cannot get a batch for zero t0_datetimes")

    coords = zip(t0_datetimes, x_locations, y_locations)

    with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
        future_examples = [
            executor.submit(self.get_example, t0_datetime, x_location, y_location)
            for t0_datetime, x_location, y_location in coords
        ]
        # Collect results in submission order so examples stay aligned with the inputs.
        examples = [future_example.result() for future_example in future_examples]

    output = join_list_data_array_to_batch_dataset(examples)
    return self._dataset_to_data_source_output(output)
382412

383413
def geospatial_border(self) -> List[Tuple[Number, Number]]:
384414
"""
@@ -415,3 +445,6 @@ def open(self) -> None:
415445

416446
def _open_data(self) -> xr.DataArray:
417447
raise NotImplementedError()
448+
449+
def _dataset_to_data_source_output(output: xr.Dataset) -> DataSourceOutput:
450+
raise NotImplementedError()

nowcasting_dataset/data_sources/nwp/nwp_data_source.py

Lines changed: 5 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
""" NWP Data Source """
22
import logging
3-
from concurrent import futures
43
from dataclasses import InitVar, dataclass
5-
from numbers import Number
64
from typing import Iterable, Optional
75

86
import numpy as np
97
import pandas as pd
108
import xarray as xr
119

1210
from nowcasting_dataset import utils
11+
from nowcasting_dataset.consts import NWP_VARIABLE_NAMES
1312
from nowcasting_dataset.data_sources.data_source import ZarrDataSource
1413
from nowcasting_dataset.data_sources.nwp.nwp_model import NWP
15-
from nowcasting_dataset.dataset.xr_utils import join_list_data_array_to_batch_dataset
16-
from nowcasting_dataset.consts import NWP_VARIABLE_NAMES
1714

1815
_LOG = logging.getLogger(__name__)
1916

@@ -78,72 +75,12 @@ def open(self) -> None:
7875
data = self._open_data()
7976
self._data = data["UKV"].sel(variable=list(self.channels))
8077

81-
def get_batch(
82-
self,
83-
t0_datetimes: pd.DatetimeIndex,
84-
x_locations: Iterable[Number],
85-
y_locations: Iterable[Number],
86-
) -> NWP:
87-
"""
88-
Get batch data
89-
90-
Args:
91-
t0_datetimes: list of timstamps
92-
x_locations: list of x locations, where the batch data is for
93-
y_locations: list of y locations, where the batch data is for
94-
95-
Returns: batch data
96-
97-
"""
98-
# Lazily select time slices.
99-
selections = []
100-
for t0_dt in t0_datetimes[: self.n_timesteps_per_batch]:
101-
selections.append(self._get_time_slice(t0_dt))
102-
103-
# Load entire time slices from disk in multiple threads.
104-
data = []
105-
with futures.ThreadPoolExecutor(max_workers=self.n_timesteps_per_batch) as executor:
106-
data_futures = []
107-
# Submit tasks.
108-
for selection in selections:
109-
future = executor.submit(selection.load)
110-
data_futures.append(future)
111-
112-
# Grab tasks
113-
for future in data_futures:
114-
d = future.result()
115-
data.append(d)
116-
117-
# Select squares from pre-loaded time slices.
118-
examples = []
119-
for i, (x_meters_center, y_meters_center) in enumerate(zip(x_locations, y_locations)):
120-
selected_data = data[i % self.n_timesteps_per_batch]
121-
bounding_box = self._square.bounding_box_centered_on(
122-
x_meters_center=x_meters_center, y_meters_center=y_meters_center
123-
)
124-
selected_data = selected_data.sel(
125-
x=slice(bounding_box.left, bounding_box.right),
126-
y=slice(bounding_box.top, bounding_box.bottom),
127-
)
128-
129-
# selected_sat_data is likely to have 1 too many pixels in x and y
130-
# because sel(x=slice(a, b)) is [a, b], not [a, b). So trim:
131-
selected_data = selected_data.isel(
132-
x=slice(0, self._square.size_pixels), y=slice(0, self._square.size_pixels)
133-
)
134-
135-
t0_dt = t0_datetimes[i]
136-
selected_data = self._post_process_example(selected_data, t0_dt)
137-
138-
examples.append(selected_data)
139-
140-
output = join_list_data_array_to_batch_dataset(examples)
141-
142-
return NWP(output)
143-
14478
def _open_data(self) -> xr.DataArray:
14579
return open_nwp(self.zarr_path, consolidated=self.consolidated)
14680

81+
def _dataset_to_data_source_output(self, output: xr.Dataset) -> NWP:
    """Wrap the joined batch dataset in the `NWP` output model and return it."""
    return NWP(output)
83+
14784
def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray:
14885
"""
14986
Select the numerical weather predictions for a single time slice.
@@ -177,20 +114,9 @@ def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray:
177114
selected["target_time"] = init_time + selected.step
178115
return selected
179116

180-
def _post_process_example(
181-
self, selected_data: xr.DataArray, t0_dt: pd.Timestamp
182-
) -> xr.DataArray:
183-
"""Resamples to 5 minutely."""
184-
start_dt = self._get_start_dt(t0_dt)
185-
end_dt = self._get_end_dt(t0_dt)
186-
selected_data = selected_data.resample({"target_time": "5T"})
187-
selected_data = selected_data.interpolate()
188-
selected_data = selected_data.sel(target_time=slice(start_dt, end_dt))
117+
def _post_process_example(self, selected_data: xr.DataArray) -> xr.DataArray:
189118
selected_data = selected_data.rename({"target_time": "time"})
190119
selected_data = selected_data.rename({"variable": "channels"})
191-
192-
selected_data.data = selected_data.data.astype(np.float32)
193-
194120
return selected_data
195121

196122
def datetime_index(self) -> pd.DatetimeIndex:

0 commit comments

Comments
 (0)