Skip to content
This repository was archived by the owner on Sep 28, 2023. It is now read-only.

Commit 96c9adb

Browse files
authored
Add Absolute Positional Encoding in Dataloader (#7)
#minor
* Add stub version of encoding for Batch
* Simplify the encodings for the batch
* Add in SEVIRI RSS OSGB bounds
* Topographic works
* Almost working GSP and PV. Issues on merging the spatial and temporal values together: the x and y shapes do not match. Spatial ones are the 32 ID x and y coords in an array, and the actual spatial features would be along the diagonal. So should just need to slice that and make the spatial one smaller
* Add getting diagonal. Have to add support for 4D tensor as well now to support GSP and PV correctly
* Fix position encodings for GSP, PV
* Remove prints
* Add asserts to test
* Add position encodings to dataloader
* Change tests
* Fix updating Batch dictionary
* Update dataset to fix #25
* Add unit test for shape encoding
* Add comment
* Fix FakeDataset issue
* Split out combining features and add test
* Move if statement to own line
* Fix more of FakeDataset
* Change SEVIRI RSS bounds to function
* Change SEVIRI RSS bounds to function
* Update test data to fix issue
* Add config option
1 parent d28c96a commit 96c9adb

File tree

15 files changed

+278
-74
lines changed

15 files changed

+278
-74
lines changed

nowcasting_dataloader/batch.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,45 +82,46 @@ class BatchML(Example):
8282
def fake(configuration: Configuration = Configuration()):
8383
"""Create fake batch"""
8484
process = configuration.process
85+
input_data = configuration.input_data
8586

8687
t0_dt, time_5, time_30 = make_random_time_vectors(
8788
batch_size=process.batch_size,
88-
seq_length_5_minutes=process.seq_length_5_minutes,
89-
seq_length_30_minutes=process.seq_length_30_minutes,
89+
seq_length_5_minutes=input_data.default_seq_length_5_minutes,
90+
seq_length_30_minutes=input_data.default_seq_length_5_minutes // 6,
9091
)
9192

9293
return BatchML(
9394
batch_size=process.batch_size,
9495
metadata=MetadataML.fake(batch_size=process.batch_size, t0_dt=t0_dt),
9596
satellite=SatelliteML.fake(
9697
process.batch_size,
97-
process.seq_length_5_minutes,
98-
process.satellite_image_size_pixels,
99-
len(process.sat_channels),
98+
input_data.default_seq_length_5_minutes,
99+
input_data.satellite.satellite_image_size_pixels,
100+
len(input_data.satellite.sat_channels),
100101
time_5=time_5,
101102
),
102103
topographic=TopographicML.fake(
103104
batch_size=process.batch_size,
104-
image_size_pixels=process.satellite_image_size_pixels,
105+
image_size_pixels=input_data.satellite.satellite_image_size_pixels,
105106
),
106107
pv=PVML.fake(
107108
batch_size=process.batch_size,
108-
seq_length_5=process.seq_length_5_minutes,
109+
seq_length_5=input_data.default_seq_length_5_minutes,
109110
n_pv_systems_per_batch=128,
110111
time_5=time_5,
111112
),
112113
sun=SunML.fake(
113-
batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes
114+
batch_size=process.batch_size, seq_length_5=input_data.default_seq_length_5_minutes
114115
),
115116
nwp=NWPML.fake(
116117
batch_size=process.batch_size,
117-
seq_length_5=process.seq_length_5_minutes,
118-
image_size_pixels=process.nwp_image_size_pixels,
119-
number_nwp_channels=len(process.nwp_channels),
118+
seq_length_5=input_data.default_seq_length_5_minutes,
119+
image_size_pixels=input_data.nwp.nwp_image_size_pixels,
120+
number_nwp_channels=len(input_data.nwp.nwp_channels),
120121
time_5=time_5,
121122
),
122123
datetime=DatetimeML.fake(
123-
batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes
124+
batch_size=process.batch_size, seq_length_5=input_data.default_seq_length_5_minutes
124125
),
125126
)
126127

@@ -145,7 +146,7 @@ def from_batch(batch: Batch) -> BatchML:
145146

146147
def normalize(self):
147148
""" Normalize the batch """
148-
149+
149150
# loop over all data sources and normalize
150151
for data_sources in self.data_sources:
151152
data_sources.normalize()

nowcasting_dataloader/data_sources/satellite/satellite_model.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,34 @@
1616
logger = logging.getLogger(__name__)
1717

1818
SAT_MEAN = [
19-
93.23458,
20-
131.71373,
21-
843.7779,
22-
736.6148,
23-
771.1189,
24-
589.66034,
25-
862.29816,
26-
927.69586,
27-
90.70885,
28-
107.58985,
29-
618.4583,
30-
532.47394,
31-
]
19+
93.23458,
20+
131.71373,
21+
843.7779,
22+
736.6148,
23+
771.1189,
24+
589.66034,
25+
862.29816,
26+
927.69586,
27+
90.70885,
28+
107.58985,
29+
618.4583,
30+
532.47394,
31+
]
3232

3333
SAT_STD = [
34-
115.34247,
35-
139.92636,
36-
36.99538,
37-
57.366386,
38-
30.346825,
39-
149.68007,
40-
51.70631,
41-
35.872967,
42-
115.77212,
43-
120.997154,
44-
98.57828,
45-
99.76469,
46-
]
34+
115.34247,
35+
139.92636,
36+
36.99538,
37+
57.366386,
38+
30.346825,
39+
149.68007,
40+
51.70631,
41+
35.872967,
42+
115.77212,
43+
120.997154,
44+
98.57828,
45+
99.76469,
46+
]
4747

4848

4949
class SatelliteML(DataSourceOutputML):
@@ -56,7 +56,7 @@ class SatelliteML(DataSourceOutputML):
5656
)
5757
x: Array = Field(
5858
...,
59-
description="aThe x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
59+
description="The x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
6060
)
6161
y: Array = Field(
6262
...,
@@ -116,7 +116,7 @@ def from_xr_dataset(xr_dataset: xr.Dataset):
116116
satellite_batch_ml = xr_dataset.torch.to_tensor(["data", "time", "x", "y", "channels"])
117117

118118
return SatelliteML(**satellite_batch_ml)
119-
119+
120120
def normalize(self):
121121
"""Normalize the satellite data """
122122
if not self.normalized:

nowcasting_dataloader/datasets.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nowcasting_dataloader.subset import subselect_data
2121
from nowcasting_dataset.filesystem.utils import download_to_local, delete_all_files_in_temp_path
2222
from nowcasting_dataset.utils import set_fsspec_for_multiprocess
23+
from nowcasting_dataloader.utils.position_encoding import generate_position_encodings_for_batch
2324

2425
logger = logging.getLogger(__name__)
2526

@@ -87,7 +88,8 @@ def __init__(
8788
required_keys: Union[Tuple[str], List[str]] = None,
8889
history_minutes: Optional[int] = None,
8990
forecast_minutes: Optional[int] = None,
90-
normalize: bool = False
91+
normalize: bool = False,
92+
add_position_encoding: bool = False,
9193
):
9294
"""
9395
Netcdf Dataset
@@ -105,6 +107,7 @@ def __init__(
105107
configuration: configuration object
106108
cloud: which cloud is used, can be "gcp", "aws" or "local".
107109
normalize: normalize the batch data
110+
add_position_encoding: Whether to add position encoding or not
108111
"""
109112
self.n_batches = n_batches
110113
self.src_path = src_path
@@ -114,24 +117,27 @@ def __init__(
114117
self.forecast_minutes = forecast_minutes
115118
self.configuration = configuration
116119
self.normalize = normalize
120+
self.add_position_encoding = add_position_encoding
117121

118122
logger.info(f"Setting up NetCDFDataset for {src_path}")
119123

120124
if self.forecast_minutes is None:
121-
self.forecast_minutes = configuration.process.forecast_minutes
125+
self.forecast_minutes = configuration.input_data.default_forecast_minutes
122126
if self.history_minutes is None:
123-
self.history_minutes = configuration.process.history_minutes
127+
self.history_minutes = configuration.input_data.default_history_minutes
124128

125129
# see if we need to select the subset of data. If turned on -
126130
# only history_minutes + current time + forecast_minutes data is used.
127131
self.select_subset_data = False
128-
if self.forecast_minutes != configuration.process.forecast_minutes:
132+
if self.forecast_minutes != configuration.input_data.default_forecast_minutes:
129133
self.select_subset_data = True
130-
if self.history_minutes != configuration.process.history_minutes:
134+
if self.history_minutes != configuration.input_data.default_history_minutes:
131135
self.select_subset_data = True
132136

133137
# Index into either sat_datetime_index or nwp_target_time indicating the current time,
134-
self.current_timestep_5_index = int(configuration.process.history_minutes // 5) + 1
138+
self.current_timestep_5_index = (
139+
int(configuration.input_data.default_history_minutes // 5) + 1
140+
)
135141

136142
if required_keys is None:
137143
required_keys = DEFAULT_REQUIRED_KEYS
@@ -184,7 +190,7 @@ def __getitem__(self, batch_idx: int) -> dict:
184190
else:
185191
local_netcdf_folder = self.src_path
186192

187-
batch = Batch.load_netcdf(local_netcdf_folder, batch_idx=batch_idx)
193+
batch: Batch = Batch.load_netcdf(local_netcdf_folder, batch_idx=batch_idx)
188194

189195
if self.select_subset_data:
190196
batch = subselect_data(
@@ -193,8 +199,8 @@ def __getitem__(self, batch_idx: int) -> dict:
193199
forecast_minutes=self.forecast_minutes,
194200
current_timestep_index=self.current_timestep_5_index,
195201
)
196-
197-
# TODO Add positional encodings here https://github.com/openclimatefix/nowcasting_dataloader/issues/4
202+
if self.add_position_encoding:
203+
position_encodings = generate_position_encodings_for_batch(batch)
198204
# change batch into ML learning batch ready for training
199205
batch: BatchML = BatchML.from_batch(batch=batch)
200206

@@ -207,7 +213,11 @@ def __getitem__(self, batch_idx: int) -> dict:
207213
if self.normalize:
208214
batch.normalize()
209215

210-
return batch.dict()
216+
batch: dict = batch.dict()
217+
if self.add_position_encoding:
218+
# Add position encodings
219+
batch.update(position_encodings)
220+
return batch
211221

212222

213223
def worker_init_fn(worker_id):

nowcasting_dataloader/fake.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, configuration: Configuration, length: int = 10):
1616
configuration: configuration object
1717
length: length of dataset
1818
"""
19-
self.number_nwp_channels = len(configuration.process.nwp_channels)
19+
self.number_nwp_channels = len(configuration.input_data.nwp.nwp_channels)
2020
self.length = length
2121
self.configuration = configuration
2222

0 commit comments

Comments
 (0)