Skip to content
This repository was archived by the owner on Sep 28, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions nowcasting_dataloader/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,45 +82,46 @@ class BatchML(Example):
def fake(configuration: Configuration = Configuration()):
"""Create fake batch"""
process = configuration.process
input_data = configuration.input_data

t0_dt, time_5, time_30 = make_random_time_vectors(
batch_size=process.batch_size,
seq_length_5_minutes=process.seq_length_5_minutes,
seq_length_30_minutes=process.seq_length_30_minutes,
seq_length_5_minutes=input_data.default_seq_length_5_minutes,
seq_length_30_minutes=input_data.default_seq_length_5_minutes // 6,
)

return BatchML(
batch_size=process.batch_size,
metadata=MetadataML.fake(batch_size=process.batch_size, t0_dt=t0_dt),
satellite=SatelliteML.fake(
process.batch_size,
process.seq_length_5_minutes,
process.satellite_image_size_pixels,
len(process.sat_channels),
input_data.default_seq_length_5_minutes,
input_data.satellite.satellite_image_size_pixels,
len(input_data.satellite.sat_channels),
time_5=time_5,
),
topographic=TopographicML.fake(
batch_size=process.batch_size,
image_size_pixels=process.satellite_image_size_pixels,
image_size_pixels=input_data.satellite.satellite_image_size_pixels,
),
pv=PVML.fake(
batch_size=process.batch_size,
seq_length_5=process.seq_length_5_minutes,
seq_length_5=input_data.default_seq_length_5_minutes,
n_pv_systems_per_batch=128,
time_5=time_5,
),
sun=SunML.fake(
batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes
batch_size=process.batch_size, seq_length_5=input_data.default_seq_length_5_minutes
),
nwp=NWPML.fake(
batch_size=process.batch_size,
seq_length_5=process.seq_length_5_minutes,
image_size_pixels=process.nwp_image_size_pixels,
number_nwp_channels=len(process.nwp_channels),
seq_length_5=input_data.default_seq_length_5_minutes,
image_size_pixels=input_data.nwp.nwp_image_size_pixels,
number_nwp_channels=len(input_data.nwp.nwp_channels),
time_5=time_5,
),
datetime=DatetimeML.fake(
batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes
batch_size=process.batch_size, seq_length_5=input_data.default_seq_length_5_minutes
),
)

Expand All @@ -145,7 +146,7 @@ def from_batch(batch: Batch) -> BatchML:

def normalize(self):
""" Normalize the batch """

# loop over all data sources and normalize
for data_sources in self.data_sources:
data_sources.normalize()
56 changes: 28 additions & 28 deletions nowcasting_dataloader/data_sources/satellite/satellite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,34 @@
logger = logging.getLogger(__name__)

SAT_MEAN = [
93.23458,
131.71373,
843.7779,
736.6148,
771.1189,
589.66034,
862.29816,
927.69586,
90.70885,
107.58985,
618.4583,
532.47394,
]
93.23458,
131.71373,
843.7779,
736.6148,
771.1189,
589.66034,
862.29816,
927.69586,
90.70885,
107.58985,
618.4583,
532.47394,
]

SAT_STD = [
115.34247,
139.92636,
36.99538,
57.366386,
30.346825,
149.68007,
51.70631,
35.872967,
115.77212,
120.997154,
98.57828,
99.76469,
]
115.34247,
139.92636,
36.99538,
57.366386,
30.346825,
149.68007,
51.70631,
35.872967,
115.77212,
120.997154,
98.57828,
99.76469,
]


class SatelliteML(DataSourceOutputML):
Expand All @@ -56,7 +56,7 @@ class SatelliteML(DataSourceOutputML):
)
x: Array = Field(
...,
description="aThe x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
description="The x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
)
y: Array = Field(
...,
Expand Down Expand Up @@ -116,7 +116,7 @@ def from_xr_dataset(xr_dataset: xr.Dataset):
satellite_batch_ml = xr_dataset.torch.to_tensor(["data", "time", "x", "y", "channels"])

return SatelliteML(**satellite_batch_ml)

def normalize(self):
"""Normalize the satellite data """
if not self.normalized:
Expand Down
30 changes: 20 additions & 10 deletions nowcasting_dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nowcasting_dataloader.subset import subselect_data
from nowcasting_dataset.filesystem.utils import download_to_local, delete_all_files_in_temp_path
from nowcasting_dataset.utils import set_fsspec_for_multiprocess
from nowcasting_dataloader.utils.position_encoding import generate_position_encodings_for_batch

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,7 +88,8 @@ def __init__(
required_keys: Union[Tuple[str], List[str]] = None,
history_minutes: Optional[int] = None,
forecast_minutes: Optional[int] = None,
normalize: bool = False
normalize: bool = False,
add_position_encoding: bool = False,
):
"""
Netcdf Dataset
Expand All @@ -105,6 +107,7 @@ def __init__(
configuration: configuration object
cloud: which cloud is used, can be "gcp", "aws" or "local".
normalize: normalize the batch data
add_position_encoding: Whether to add position encoding or not
"""
self.n_batches = n_batches
self.src_path = src_path
Expand All @@ -114,24 +117,27 @@ def __init__(
self.forecast_minutes = forecast_minutes
self.configuration = configuration
self.normalize = normalize
self.add_position_encoding = add_position_encoding

logger.info(f"Setting up NetCDFDataset for {src_path}")

if self.forecast_minutes is None:
self.forecast_minutes = configuration.process.forecast_minutes
self.forecast_minutes = configuration.input_data.default_forecast_minutes
if self.history_minutes is None:
self.history_minutes = configuration.process.history_minutes
self.history_minutes = configuration.input_data.default_history_minutes

# Check whether we need to select a subset of the data. If enabled,
# only history_minutes + current time + forecast_minutes of data is used.
self.select_subset_data = False
if self.forecast_minutes != configuration.process.forecast_minutes:
if self.forecast_minutes != configuration.input_data.default_forecast_minutes:
self.select_subset_data = True
if self.history_minutes != configuration.process.history_minutes:
if self.history_minutes != configuration.input_data.default_history_minutes:
self.select_subset_data = True

# Index into either sat_datetime_index or nwp_target_time indicating the current time.
self.current_timestep_5_index = int(configuration.process.history_minutes // 5) + 1
self.current_timestep_5_index = (
int(configuration.input_data.default_history_minutes // 5) + 1
)

if required_keys is None:
required_keys = DEFAULT_REQUIRED_KEYS
Expand Down Expand Up @@ -184,7 +190,7 @@ def __getitem__(self, batch_idx: int) -> dict:
else:
local_netcdf_folder = self.src_path

batch = Batch.load_netcdf(local_netcdf_folder, batch_idx=batch_idx)
batch: Batch = Batch.load_netcdf(local_netcdf_folder, batch_idx=batch_idx)

if self.select_subset_data:
batch = subselect_data(
Expand All @@ -193,8 +199,8 @@ def __getitem__(self, batch_idx: int) -> dict:
forecast_minutes=self.forecast_minutes,
current_timestep_index=self.current_timestep_5_index,
)

# TODO Add positional encodings here https://github.com/openclimatefix/nowcasting_dataloader/issues/4
if self.add_position_encoding:
position_encodings = generate_position_encodings_for_batch(batch)
# Convert the batch into an ML batch ready for training
batch: BatchML = BatchML.from_batch(batch=batch)

Expand All @@ -207,7 +213,11 @@ def __getitem__(self, batch_idx: int) -> dict:
if self.normalize:
batch.normalize()

return batch.dict()
batch: dict = batch.dict()
if self.add_position_encoding:
# Add position encodings
batch.update(position_encodings)
return batch


def worker_init_fn(worker_id):
Expand Down
2 changes: 1 addition & 1 deletion nowcasting_dataloader/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, configuration: Configuration, length: int = 10):
configuration: configuration object
length: length of dataset
"""
self.number_nwp_channels = len(configuration.process.nwp_channels)
self.number_nwp_channels = len(configuration.input_data.nwp.nwp_channels)
self.length = length
self.configuration = configuration

Expand Down
Loading