import logging
from dataclasses import InitVar, dataclass
from numbers import Number
+ from pathlib import Path
from typing import Iterable, List, Tuple

import pandas as pd
import xarray as xr

+ import nowcasting_dataset.filesystem.utils as nd_fs_utils
+
+ # nowcasting_dataset imports
import nowcasting_dataset.time as nd_time
+ import nowcasting_dataset.utils as nd_utils
from nowcasting_dataset import square
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset
@@ -99,8 +104,7 @@ def sample_period_minutes(self) -> int:
99104 """
100105 This is the default sample period in minutes.
101106
102- This functions may be overwritten if
103- the sample period of the data source is not 5 minutes.
107+ This functions may be overwritten if the sample period of the data source is not 5 minutes.
104108 """
105109 logging .debug (
106110 "Getting sample_period_minutes default of 5 minutes. "
@@ -112,13 +116,79 @@ def open(self):
112116 """Open the data source, if necessary.
113117
114118 Called from each worker process. Useful for data sources where the
115- underlying data source cannot be forked (like Zarr on GCP! ).
119+ underlying data source cannot be forked (like Zarr).
116120
117- Data sources which can be forked safely should call open()
118- from __init__().
121+ Data sources which can be forked safely should call open() from __init__().
119122 """
120123 pass
121124
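
A hedged sketch of such an override, assuming a data source backed by a Zarr store; the subclass name and path are hypothetical, not part of this commit:

    import xarray as xr

    from nowcasting_dataset.data_sources.data_source import DataSource  # assumed path

    class SatelliteZarrDataSource(DataSource):  # hypothetical subclass
        def open(self) -> None:
            # Open the Zarr store *inside* each worker process, because an open
            # Zarr handle cannot safely be shared across forked processes.
            self._data = xr.open_zarr("path/to/satellite.zarr")  # hypothetical path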
+    def create_batches(
+        self,
+        spatial_and_temporal_locations_of_each_example: pd.DataFrame,
+        idx_of_first_batch: int,
+        batch_size: int,
+        dst_path: Path,
+        temp_path: Path,
+        upload_every_n_batches: int,
+    ) -> None:
134+ """Create multiple batches and save them to disk.
135+
136+ Args:
137+ spatial_and_temporal_locations_of_each_example: A DataFrame where each row specifies
138+ the spatial and temporal location of an example. The number of rows must be
139+ an exact multiple of `batch_size`.
140+ Columns are: t0_datetime_UTC, x_center_OSGB, y_center_OSGB.
141+ idx_of_first_batch: The batch number of the first batch to create.
142+ batch_size: The number of examples per batch.
143+ dst_path: The final destination path for the batches. Must exist.
144+ temp_path: The local temporary path. This is only required when dst_path is a
145+ cloud storage bucket, so files must first be created on the VM's local disk in temp_path
146+ and then uploaded to dst_path every upload_every_n_batches. Must exist. Will be emptied.
147+ upload_every_n_batches: Upload the contents of temp_path to dst_path after this number
148+ of batches have been created. If 0 then will write directly to dst_path.
149+ """
+        # Sanity checks:
+        assert idx_of_first_batch >= 0
+        assert batch_size > 0
+        assert len(spatial_and_temporal_locations_of_each_example) % batch_size == 0
+        assert upload_every_n_batches >= 0
+
+        # Figure out where to write batches to:
+        save_batches_locally_and_upload = upload_every_n_batches > 0
+        if save_batches_locally_and_upload:
+            nd_fs_utils.delete_all_files_in_temp_path(temp_path)
+        path_to_write_to = temp_path if save_batches_locally_and_upload else dst_path
+
+        # Loop round each batch:
+        examples_for_batch = spatial_and_temporal_locations_of_each_example.iloc[:batch_size]
+        n_batches_processed = 0
+        while not examples_for_batch.empty:
+            # Generate batch.
+            batch = self.get_batch(
+                t0_datetimes=examples_for_batch.t0_datetime_UTC,
+                x_locations=examples_for_batch.x_center_OSGB,
+                y_locations=examples_for_batch.y_center_OSGB,
+            )
+
+            # Save batch to disk.
+            batch_idx = idx_of_first_batch + n_batches_processed
+            netcdf_filename = path_to_write_to / nd_utils.get_netcdf_filename(batch_idx)
+            batch.to_netcdf(netcdf_filename)
+
+            # Upload if necessary.
+            if (
+                save_batches_locally_and_upload
+                and n_batches_processed > 0
+                and n_batches_processed % upload_every_n_batches == 0
+            ):
+                nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to)
+
+            n_batches_processed += 1
+
+            # Advance to the examples for the next batch (without this the
+            # loop would never terminate):
+            start_example_idx = n_batches_processed * batch_size
+            examples_for_batch = spatial_and_temporal_locations_of_each_example.iloc[
+                start_example_idx : start_example_idx + batch_size
+            ]
+
+        # Upload last few batches, if necessary:
+        if save_batches_locally_and_upload:
+            nd_fs_utils.upload_and_delete_local_files(dst_path, path_to_write_to)
+
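A hedged usage sketch, not part of this commit: a manager script might call create_batches roughly as below. The DataFrame contents, the paths, and the `data_source` instance (some concrete DataSource subclass) are all hypothetical:

    from pathlib import Path

    import pandas as pd

    # Hypothetical locations DataFrame; column names follow the docstring above.
    # 64 rows with batch_size=32 gives exactly 2 batches.
    locations = pd.DataFrame(
        {
            "t0_datetime_UTC": pd.date_range("2021-01-01 12:00", periods=64, freq="5T"),
            "x_center_OSGB": [123_456.0] * 64,
            "y_center_OSGB": [654_321.0] * 64,
        }
    )

    data_source.create_batches(  # data_source: instance of a DataSource subclass
        spatial_and_temporal_locations_of_each_example=locations,
        idx_of_first_batch=0,
        batch_size=32,
        dst_path=Path("/home/user/prepared_data"),  # must already exist
        temp_path=Path("/home/user/temp"),  # ignored when upload_every_n_batches=0
        upload_every_n_batches=0,  # 0 means write directly to dst_path
    )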
    def get_batch(
        self,
        t0_datetimes: pd.DatetimeIndex,
@@ -141,14 +211,9 @@ def get_batch(
        zipped = zip(t0_datetimes, x_locations, y_locations)
        for t0_datetime, x_location, y_location in zipped:
            output: xr.Dataset = self.get_example(t0_datetime, x_location, y_location)
-
            examples.append(output)

-        # could add option here, to save each data source using
-        # 1. # DataSourceOutput.to_xr_dataset() to make it a dataset
-        # 2. DataSourceOutput.save_netcdf(), save to netcdf
-
-        # get the name of the cls, this could be one of the data sources like Sun
+        # Get the DataSource class; this could be one of the data sources, like Sun.
        cls = examples[0].__class__

        # join the examples together, and cast them to the cls, so that validation can occur
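
The join step is handled by join_dataset_to_batch_dataset, imported at the top of the file. Conceptually it amounts to stacking the per-example Datasets along a new leading dimension; the sketch below is a rough illustration of that idea, not the helper's actual implementation, and the "example" dimension name is an assumption:

    import xarray as xr

    def join_examples_naively(examples: list) -> xr.Dataset:
        # Concatenate per-example Datasets along a new "example" dimension,
        # so each variable gains shape (n_examples, ...).
        return xr.concat(examples, dim="example")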