diff --git a/conftest.py b/conftest.py
index b47b2c31..cb14753f 100644
--- a/conftest.py
+++ b/conftest.py
@@ -10,9 +10,17 @@
 from nowcasting_dataset.data_sources import SatelliteDataSource
 from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
 from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource
+from nowcasting_dataset.dataset.xr_utils import (
+    register_xr_data_array_to_tensor,
+    register_xr_data_set_to_tensor,
+)

 pytest.IMAGE_SIZE_PIXELS = 128

+# need to run these to ensure that xarray DataArray and Dataset have torch functions
+register_xr_data_array_to_tensor()
+register_xr_data_set_to_tensor()
+

 def pytest_addoption(parser):
     parser.addoption(

diff --git a/nowcasting_dataset/config/gcp.yaml b/nowcasting_dataset/config/gcp.yaml
index 7974cfd8..8c979d43 100644
--- a/nowcasting_dataset/config/gcp.yaml
+++ b/nowcasting_dataset/config/gcp.yaml
@@ -8,6 +8,7 @@ input_data:
   solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv
   gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/GSP/v1/pv_gsp.zarr
   topographic_filename: gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif
+  sun_zarr_path: gs://solar-pv-nowcasting-data/Sun/v0/sun.zarr
 output_data:
   filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/
 process:

diff --git a/nowcasting_dataset/data_sources/README.md b/nowcasting_dataset/data_sources/README.md
index 90ff7c7a..215ef139 100644
--- a/nowcasting_dataset/data_sources/README.md
+++ b/nowcasting_dataset/data_sources/README.md
@@ -38,5 +38,12 @@ General pydantic model of output of the data source. Contains the following meth
 Roughly each of the data source folders follows this pattern
 - A class which defines how to load the data source, how to select for batches etc. This inherits from 'data_source.DataSource',
-- A class which contains the output model of the data source. This is the information used in the batches.
+- A class which contains the output model of the data source, built from an xarray Dataset. This is the information used in the batches.
 This inherits from 'datasource_output.DataSourceOutput'.
+- A second class (pydantic) which moves the xarray Dataset to tensor fields. This will be used for training in ML models.
+
+
+# fake
+
+`fake.py` has several functions to create fake `Batch` data. This is useful for testing,
+and hopefully useful outside this module too.

diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py
index 949fa101..a4019ea9 100644
--- a/nowcasting_dataset/data_sources/data_source.py
+++ b/nowcasting_dataset/data_sources/data_source.py
@@ -12,6 +12,7 @@
 import nowcasting_dataset.time as nd_time
 from nowcasting_dataset import square
 from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
+from nowcasting_dataset.dataset.xr_utils import join_dataset_to_batch_dataset

 logger = logging.getLogger(__name__)

@@ -122,16 +123,19 @@ def get_batch(
         examples = []
         zipped = zip(t0_datetimes, x_locations, y_locations)
         for t0_datetime, x_location, y_location in zipped:
-            output: DataSourceOutput = self.get_example(t0_datetime, x_location, y_location)
+            output: xr.Dataset = self.get_example(t0_datetime, x_location, y_location)

-            if self.convert_to_numpy:
-                output.to_numpy()
             examples.append(output)

         # could add option here, to save each data source using
         # 1. # DataSourceOutput.to_xr_dataset() to make it a dataset
         # 2.
DataSourceOutput.save_netcdf(), save to netcdf - return DataSourceOutput.create_batch_from_examples(examples) + + # get the name of the cls, this could be one of the data sources like Sun + cls = examples[0].__class__ + + # join the examples together, and cast them to the cls, so that validation can occur + return cls(join_dataset_to_batch_dataset(examples)) def datetime_index(self) -> pd.DatetimeIndex: """Returns a complete list of all available datetimes.""" @@ -203,7 +207,7 @@ def get_example( t0_dt: pd.Timestamp, #: Datetime of "now": The most recent obs. x_meters_center: Number, #: Centre, in OSGB coordinates. y_meters_center: Number, #: Centre, in OSGB coordinates. - ) -> DataSourceOutput: + ) -> xr.Dataset: """Must be overridden by child classes.""" raise NotImplementedError() @@ -305,7 +309,10 @@ def get_example( f"actual shape {selected_data.shape}" ) - return self._put_data_into_example(selected_data) + # rename 'variable' to 'channels' + selected_data = selected_data.rename({"variable": "channels"}) + + return selected_data def geospatial_border(self) -> List[Tuple[Number, Number]]: """ @@ -342,6 +349,3 @@ def open(self) -> None: def _open_data(self) -> xr.DataArray: raise NotImplementedError() - - def _put_data_into_example(self, selected_data: xr.DataArray) -> DataSourceOutput: - raise NotImplementedError() diff --git a/nowcasting_dataset/data_sources/datasource_output.py b/nowcasting_dataset/data_sources/datasource_output.py index 3f375e42..7a4e59ef 100644 --- a/nowcasting_dataset/data_sources/datasource_output.py +++ b/nowcasting_dataset/data_sources/datasource_output.py @@ -1,115 +1,40 @@ """ General Data Source output pydantic class. """ from __future__ import annotations -import os -from nowcasting_dataset.filesystem.utils import make_folder -from nowcasting_dataset.utils import get_netcdf_filename +import logging +import os from pathlib import Path -from pydantic import BaseModel, Field -import pandas as pd -import xarray as xr +from typing import List + import numpy as np -from typing import List, Union -import logging -from datetime import datetime +from pydantic import BaseModel, Field -from nowcasting_dataset.utils import to_numpy +from nowcasting_dataset.dataset.xr_utils import PydanticXArrayDataSet +from nowcasting_dataset.filesystem.utils import make_folder +from nowcasting_dataset.utils import get_netcdf_filename logger = logging.getLogger(__name__) -class DataSourceOutput(BaseModel): +class DataSourceOutput(PydanticXArrayDataSet): """General Data Source output pydantic class. Data source output classes should inherit from this class """ - class Config: - """ Allowed classes e.g. tensor.Tensor""" - - # TODO maybe there is a better way to do this - arbitrary_types_allowed = True - - batch_size: int = Field( - 0, - ge=0, - description="The size of this batch. If the batch size is 0, " - "then this item stores one data item i.e Example", - ) + __slots__ = [] def get_name(self) -> str: - """ Get the name of the class """ + """Get the name of the class""" return self.__class__.__name__.lower() - def to_numpy(self): - """Change to numpy""" - for k, v in self.dict().items(): - self.__setattr__(k, to_numpy(v)) - - def to_xr_data_array(self): - """ Change to xr DataArray""" - raise NotImplementedError() - - @staticmethod - def create_batch_from_examples(data): - """ - Join a list of data source items to a batch. 
- - Note that this only works for numpy objects, so objects are changed into numpy - """ - _ = [d.to_numpy() for d in data] - - # use the first item in the list, and then update each item - batch = data[0] - for k in batch.dict().keys(): - - # set batch size to the list of the items - if k == "batch_size": - batch.batch_size = len(data) - else: - - # get list of one variable from the list of data items. - one_variable_list = [d.__getattribute__(k) for d in data] - batch.__setattr__(k, np.stack(one_variable_list, axis=0)) - - return batch - - def split(self) -> List[DataSourceOutput]: - """ - Split the datasource from a batch to a list of items - - Returns: List of single data source items - """ - cls = self.__class__ - - items = [] - for batch_idx in range(self.batch_size): - d = {k: v[batch_idx] for k, v in self.dict().items() if k != "batch_size"} - d["batch_size"] = 0 - items.append(cls(**d)) - - return items - - def to_xr_dataset(self, **kwargs): - """ Make a xr dataset. Each data source needs to define this """ - raise NotImplementedError - - def from_xr_dataset(self): - """ Load from xr dataset. Each data source needs to define this """ - raise NotImplementedError - - def get_datetime_index(self): - """ Datetime index for the data """ - pass - - def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset): + def save_netcdf(self, batch_i: int, path: Path): """ Save batch to netcdf file Args: batch_i: the batch id, used to make the filename path: the path where it will be saved. This can be local or in the cloud. - xr_dataset: xr dataset that has batch information in it """ filename = get_netcdf_filename(batch_i) @@ -124,77 +49,46 @@ def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset): # make file local_filename = os.path.join(folder, filename) - encoding = {name: {"compression": "lzf"} for name in xr_dataset.data_vars} - xr_dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) - - def select_time_period( - self, - keys: List[str], - history_minutes: int, - forecast_minutes: int, - t0_dt_of_first_example: Union[datetime, pd.Timestamp], - ): - """ - Selects a subset of data between the indicies of [start, end] for each key in keys - - Note that class is edited so nothing is returned. - - Args: - keys: Keys in batch to use - t0_dt_of_first_example: datetime of the current time (t0) in the first example of the batch - history_minutes: How many minutes of history to use - forecast_minutes: How many minutes of future data to use for forecasting - - """ - logger.debug( - f"Taking a sub-selection of the batch data based on a history minutes of {history_minutes} " - f"and forecast minutes of {forecast_minutes}" - ) + encoding = {name: {"compression": "lzf"} for name in self.data_vars} + self.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) - start_time_of_first_batch = t0_dt_of_first_example - pd.to_timedelta( - f"{history_minutes} minute 30 second" - ) - end_time_of_first_example = t0_dt_of_first_example + pd.to_timedelta( - f"{forecast_minutes} minute 30 second" - ) - logger.debug(f"New start time for first batch is {start_time_of_first_batch}") - logger.debug(f"New end time for first batch is {end_time_of_first_example}") +class DataSourceOutputML(BaseModel): + """General Data Source output pydantic class. 
- start_time_of_first_example = to_numpy(start_time_of_first_batch) - end_time_of_first_example = to_numpy(end_time_of_first_example) + Data source output classes should inherit from this class + """ - if self.get_datetime_index() is not None: + class Config: + """Allowed classes e.g. tensor.Tensor""" - time_of_first_example = to_numpy(pd.to_datetime(self.get_datetime_index()[0])) + # TODO maybe there is a better way to do this + arbitrary_types_allowed = True - # find the start and end index, that we will then use to slice the data - start_i, end_i = np.searchsorted( - time_of_first_example, [start_time_of_first_example, end_time_of_first_example] - ) + batch_size: int = Field( + 0, + ge=0, + description="The size of this batch. If the batch size is 0, " + "then this item stores one data item i.e Example", + ) - # slice all the data - for key in keys: - if "time" in self.__getattribute__(key).dims: - self.__setattr__( - key, self.__getattribute__(key).isel(time=slice(start_i, end_i)) - ) - elif "time_30" in self.__getattribute__(key).dims: - self.__setattr__( - key, self.__getattribute__(key).isel(time_30=slice(start_i, end_i)) - ) + def get_name(self) -> str: + """Get the name of the class""" + return self.__class__.__name__.lower() - logger.debug(f"{self.__class__.__name__} {key}: {self.__getattribute__(key).shape}") + def get_datetime_index(self): + """Datetime index for the data""" + pass def pad_nans(array, pad_width) -> np.ndarray: - """ Pad nans with nans""" + """Pad nans with nans""" array = array.astype(np.float32) return np.pad(array, pad_width, constant_values=np.NaN) def pad_data( - data: DataSourceOutput, + data: DataSourceOutputML, pad_size: int, one_dimensional_arrays: List[str], two_dimensional_arrays: List[str], diff --git a/nowcasting_dataset/data_sources/datetime/datetime_data_source.py b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py index 9d1088a2..b6d900e3 100644 --- a/nowcasting_dataset/data_sources/datetime/datetime_data_source.py +++ b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py @@ -8,6 +8,7 @@ from nowcasting_dataset import time as nd_time from nowcasting_dataset.data_sources.data_source import DataSource from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime +from nowcasting_dataset.dataset.xr_utils import make_dim_index @dataclass @@ -36,7 +37,13 @@ def get_example( start_dt = self._get_start_dt(t0_dt) end_dt = self._get_end_dt(t0_dt) index = pd.date_range(start_dt, end_dt, freq="5T") - return nd_time.datetime_features_in_example(index) + + datetime_xr_dataset = nd_time.datetime_features_in_example(index).rename({"index": "time"}) + + # make sure time is indexes in the correct way + datetime_xr_dataset = make_dim_index(datetime_xr_dataset) + + return Datetime(datetime_xr_dataset) def get_locations_for_batch( self, t0_datetimes: pd.DatetimeIndex diff --git a/nowcasting_dataset/data_sources/datetime/datetime_model.py b/nowcasting_dataset/data_sources/datetime/datetime_model.py index 9629e9e6..c64d318d 100644 --- a/nowcasting_dataset/data_sources/datetime/datetime_model.py +++ b/nowcasting_dataset/data_sources/datetime/datetime_model.py @@ -1,13 +1,29 @@ """ Model for output of datetime data """ -from pydantic import validator -import xarray as xr import numpy as np -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +import xarray as xr +from pydantic import validator + from nowcasting_dataset.consts import Array, DATETIME_FEATURE_NAMES +from 
nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, +) + from nowcasting_dataset.utils import coord_to_range class Datetime(DataSourceOutput): + """ Class to store Datetime data as a xr.Dataset with some validation """ + + # Use to store xr.Dataset data + + __slots__ = () + + # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233 + _expected_dimensions = ("time",) + + +class DatetimeML(DataSourceOutputML): """ Model for output of datetime data """ hour_of_day_sin: Array #: Shape: [batch_size,] seq_length @@ -42,7 +58,7 @@ def v_day_of_year_cos(cls, v, values): @staticmethod def fake(batch_size, seq_length_5): """ Make a fake Datetime object """ - return Datetime( + return DatetimeML( batch_size=batch_size, hour_of_day_sin=np.random.randn( batch_size, @@ -88,7 +104,7 @@ def to_xr_dataset(self, _): def from_xr_dataset(xr_dataset): """ Change xr dataset to model. If data does not exist, then return None """ if "hour_of_day_sin" in xr_dataset.keys(): - return Datetime( + return DatetimeML( batch_size=xr_dataset["hour_of_day_sin"].shape[0], hour_of_day_sin=xr_dataset["hour_of_day_sin"], hour_of_day_cos=xr_dataset["hour_of_day_cos"], diff --git a/nowcasting_dataset/data_sources/fake.py b/nowcasting_dataset/data_sources/fake.py new file mode 100644 index 00000000..c2f6ffe8 --- /dev/null +++ b/nowcasting_dataset/data_sources/fake.py @@ -0,0 +1,320 @@ +""" To make fake Datasets + +Wanted to keep this out of the testing frame works, as other repos, might want to use this +""" +import numpy as np +import pandas as pd +import xarray as xr + +from nowcasting_dataset.dataset.xr_utils import ( + convert_data_array_to_dataset, + join_dataset_to_batch_dataset, + join_list_data_array_to_batch_dataset, +) + +from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime +from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata +from nowcasting_dataset.data_sources.gsp.gsp_model import GSP +from nowcasting_dataset.data_sources.nwp.nwp_model import NWP +from nowcasting_dataset.data_sources.pv.pv_model import PV +from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite +from nowcasting_dataset.data_sources.sun.sun_model import Sun +from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic + + +def datetime_fake(batch_size, seq_length_5): + """ Create fake data """ + xr_arrays = [create_datetime_dataset(seq_length=seq_length_5) for _ in range(batch_size)] + + # make dataset + xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + + return Datetime(xr_dataset) + + +def gsp_fake( + batch_size, + seq_length_30, + n_gsp_per_batch, +): + """ Create fake data """ + # make batch of arrays + xr_arrays = [ + create_gsp_pv_dataset( + seq_length=seq_length_30, + freq="30T", + number_of_systems=n_gsp_per_batch, + ) + for _ in range(batch_size) + ] + + # make dataset + xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + + return GSP(xr_dataset) + + +def metadata_fake(batch_size): + """Make a xr dataset""" + xr_arrays = [create_metadata_dataset() for _ in range(batch_size)] + + # make dataset + xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + + return Metadata(xr_dataset) + + +def nwp_fake( + batch_size=32, + seq_length_5=19, + image_size_pixels=64, + number_nwp_channels=7, +) -> NWP: + """ Create fake data """ + # make batch of arrays + xr_arrays = [ + create_image_array( + seq_length_5=seq_length_5, + image_size_pixels=image_size_pixels, 
+ number_channels=number_nwp_channels, + ) + for _ in range(batch_size) + ] + + # make dataset + xr_dataset = join_list_data_array_to_batch_dataset(xr_arrays) + + xr_dataset = xr_dataset.rename({"time": "target_time"}) + xr_dataset["init_time"] = xr_dataset.target_time[:, 0] + + return NWP(xr_dataset) + + +def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch): + """ Create fake data """ + # make batch of arrays + xr_arrays = [ + create_gsp_pv_dataset( + seq_length=seq_length_5, + freq="5T", + number_of_systems=n_pv_systems_per_batch, + ) + for _ in range(batch_size) + ] + + # make dataset + xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + + return PV(xr_dataset) + + +def satellite_fake( + batch_size=32, + seq_length_5=19, + satellite_image_size_pixels=64, + number_sat_channels=7, +) -> Satellite: + """ Create fake data """ + # make batch of arrays + xr_arrays = [ + create_image_array( + seq_length_5=seq_length_5, + image_size_pixels=satellite_image_size_pixels, + number_channels=number_sat_channels, + ) + for _ in range(batch_size) + ] + + # make dataset + xr_dataset = join_list_data_array_to_batch_dataset(xr_arrays) + + return Satellite(xr_dataset) + + +def sun_fake(batch_size, seq_length_5): + """ Create fake data """ + # create dataset with both azimuth and elevation, index with time + # make batch of arrays + xr_arrays = [ + create_sun_dataset( + seq_length=seq_length_5, + ) + for _ in range(batch_size) + ] + + # make dataset + xr_dataset = join_dataset_to_batch_dataset(xr_arrays) + + return Sun(xr_dataset) + + +def topographic_fake(batch_size, image_size_pixels): + """ Create fake data """ + # make batch of arrays + xr_arrays = [ + xr.DataArray( + data=np.random.randn( + image_size_pixels, + image_size_pixels, + ), + dims=["x", "y"], + coords=dict( + x=np.sort(np.random.randn(image_size_pixels)), + y=np.sort(np.random.randn(image_size_pixels))[::-1].copy(), + ), + ) + for _ in range(batch_size) + ] + + # make dataset + xr_dataset = join_list_data_array_to_batch_dataset(xr_arrays) + + return Topographic(xr_dataset) + + +def create_image_array( + dims=("time", "x", "y", "channels"), + seq_length_5=19, + image_size_pixels=64, + number_channels=7, +): + """ Create Satellite or NWP fake image data""" + ALL_COORDS = { + "time": pd.date_range("2021-01-01", freq="5T", periods=seq_length_5), + "x": np.random.randint(low=0, high=1000, size=image_size_pixels), + "y": np.random.randint(low=0, high=1000, size=image_size_pixels), + "channels": np.arange(number_channels), + } + coords = [(dim, ALL_COORDS[dim]) for dim in dims] + image_data_array = xr.DataArray( + abs( + np.random.randn( + seq_length_5, + image_size_pixels, + image_size_pixels, + number_channels, + ) + ), + coords=coords, + ) # Fake data for testing! + return image_data_array + + +def create_gsp_pv_dataset( + dims=("time", "id"), + freq="5T", + seq_length=19, + number_of_systems=128, +): + """ Create gsp or pv fake dataset """ + ALL_COORDS = { + "time": pd.date_range("2021-01-01", freq=freq, periods=seq_length), + "id": np.random.randint(low=0, high=1000, size=number_of_systems), + } + coords = [(dim, ALL_COORDS[dim]) for dim in dims] + data_array = xr.DataArray( + np.random.randn( + seq_length, + number_of_systems, + ), + coords=coords, + ) # Fake data for testing! 
+ + data = convert_data_array_to_dataset(data_array) + + x_coords = xr.DataArray( + data=np.sort(np.random.randn(number_of_systems)), + dims=["id_index"], + coords=dict( + id_index=range(number_of_systems), + ), + ) + + y_coords = xr.DataArray( + data=np.sort(np.random.randn(number_of_systems)), + dims=["id_index"], + coords=dict( + id_index=range(number_of_systems), + ), + ) + + data["x_coords"] = x_coords + data["y_coords"] = y_coords + + return data + + +def create_sun_dataset( + dims=("time",), + freq="5T", + seq_length=19, +) -> xr.Dataset: + """ + Create sun fake dataset + + Args: + dims: # TODO + freq: # TODO + seq_length: # TODO + + Returns: # TODO + + """ + ALL_COORDS = { + "time": pd.date_range("2021-01-01", freq=freq, periods=seq_length), + } + coords = [(dim, ALL_COORDS[dim]) for dim in dims] + data_array = xr.DataArray( + np.random.randn( + seq_length, + ), + coords=coords, + ) # Fake data for testing! + + data = convert_data_array_to_dataset(data_array) + sun = data.rename({"data": "elevation"}) + sun["azimuth"] = data.data + + return sun + + +def create_metadata_dataset() -> xr.Dataset: + """ Create fake metadata dataset""" + d = { + "dims": ("t0_dt",), + "data": pd.date_range("2021-01-01", freq="5T", periods=1) + pd.Timedelta("30T"), + } + + data = convert_data_array_to_dataset(xr.DataArray.from_dict(d)) + + for v in ["x_meters_center", "y_meters_center", "object_at_center_label"]: + d: dict = {"dims": ("t0_dt",), "data": [np.random.randint(0, 1000)]} + d: xr.Dataset = convert_data_array_to_dataset(xr.DataArray.from_dict(d)).rename({"data": v}) + data[v] = getattr(d, v) + + return data + + +def create_datetime_dataset( + seq_length=19, +) -> xr.Dataset: + """ Create fake datetime dataset""" + ALL_COORDS = { + "time": pd.date_range("2021-01-01", freq="5T", periods=seq_length), + } + coords = [("time", ALL_COORDS["time"])] + data_array = xr.DataArray( + np.random.randn( + seq_length, + ), + coords=coords, + ) # Fake data + + data = convert_data_array_to_dataset(data_array) + + ds = data.rename({"data": "day_of_year_cos"}) + ds["day_of_year_sin"] = data.rename({"data": "day_of_year_sin"}).day_of_year_sin + ds["hour_of_day_cos"] = data.rename({"data": "hour_of_day_cos"}).hour_of_day_cos + ds["hour_of_day_sin"] = data.rename({"data": "hour_of_day_sin"}).hour_of_day_sin + + return data diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index fe9c9212..6b6cc825 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -19,12 +19,12 @@ ) from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso +from nowcasting_dataset.data_sources.gsp.gsp_model import GSP +from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.geospatial import lat_lon_to_osgb from nowcasting_dataset.square import get_bounding_box_mask -# from nowcasting_dataset.utils import scale_to_0_to_1, pad_data from nowcasting_dataset.utils import scale_to_0_to_1 -from nowcasting_dataset.data_sources.gsp.gsp_model import GSP logger = logging.getLogger(__name__) @@ -202,19 +202,45 @@ def get_example( gsp_x_coords = self.metadata[self.metadata["gsp_id"].isin(all_gsp_ids)].location_x gsp_y_coords = self.metadata[self.metadata["gsp_id"].isin(all_gsp_ids)].location_y - # Save data into the Example dict... 
+ # convert to data array + da = xr.DataArray( + data=selected_gsp_power.values, + dims=["time", "id"], + coords=dict( + id=all_gsp_ids.values.astype(int), + time=selected_gsp_power.index.values, + ), + ) + + # convert to dataset + gsp = convert_data_array_to_dataset(da) - gsp = GSP( - gsp_id=all_gsp_ids.values, - gsp_yield=selected_gsp_power.values, - gsp_x_coords=gsp_x_coords.values, - gsp_y_coords=gsp_y_coords.values, - gsp_datetime_index=selected_gsp_power.index.values, + # add gsp x coords + gsp_x_coords = xr.DataArray( + data=gsp_x_coords.values, + dims=["id_index"], + coords=dict( + id_index=range(len(all_gsp_ids.values)), + ), ) - gsp.pad() + gsp_y_coords = xr.DataArray( + data=gsp_y_coords.values, + dims=["id_index"], + coords=dict( + id_index=range(len(all_gsp_ids.values)), + ), + ) + gsp["x_coords"] = gsp_x_coords + gsp["y_coords"] = gsp_y_coords + + # pad out so that there are always 32 gsp + pad_n = self.n_gsp_per_example - len(gsp.id_index) + gsp = gsp.pad(id_index=(0, pad_n), data=((0, 0), (0, pad_n))) + + gsp.__setitem__("id_index", range(self.n_gsp_per_example)) - return gsp + return GSP(gsp) def _get_central_gsp_id( self, diff --git a/nowcasting_dataset/data_sources/gsp/gsp_model.py b/nowcasting_dataset/data_sources/gsp/gsp_model.py index 350e2761..41a65789 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_model.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_model.py @@ -1,26 +1,36 @@ """ Model for output of GSP data """ -from pydantic import Field, validator +import logging + import numpy as np -import xarray as xr +from pydantic import Field, validator -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput, pad_data from nowcasting_dataset.consts import Array - from nowcasting_dataset.consts import ( GSP_ID, GSP_YIELD, GSP_X_COORDS, GSP_Y_COORDS, GSP_DATETIME_INDEX, - DEFAULT_N_GSP_PER_EXAMPLE, +) +from nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, ) from nowcasting_dataset.time import make_random_time_vectors -import logging logger = logging.getLogger(__name__) class GSP(DataSourceOutput): + """ Class to store GSP data as a xr.Dataset with some validation """ + + __slots__ = () + _expected_dimensions = ("time", "id") + + # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233 + + +class GSPML(DataSourceOutputML): """ Model for output of GSP data """ # Shape: [batch_size,] seq_length, width, height, channel @@ -92,7 +102,7 @@ def fake(batch_size, seq_length_30, n_gsp_per_batch, time_30=None): batch_size=batch_size, seq_length_5_minutes=0, seq_length_30_minutes=seq_length_30 ) - return GSP( + return GSPML( batch_size=batch_size, gsp_yield=np.random.randn( batch_size, @@ -106,73 +116,15 @@ def fake(batch_size, seq_length_30, n_gsp_per_batch, time_30=None): ) # copy is needed as torch doesnt not support negative strides - def pad(self, n_gsp_per_example: int = DEFAULT_N_GSP_PER_EXAMPLE): - """ - Pad out data - - Args: - n_gsp_per_example: The number of gsp's there are per example. - - Note that nothing is returned as the changes are made inplace. 
- """ - assert self.batch_size == 0, "Padding only works for batch_size=0, i.e one Example" - - pad_size = n_gsp_per_example - self.gsp_yield.shape[-1] - pad_data( - data=self, - one_dimensional_arrays=[GSP_ID, GSP_X_COORDS, GSP_Y_COORDS], - two_dimensional_arrays=[GSP_YIELD], - pad_size=pad_size, - ) - def get_datetime_index(self) -> Array: """ Get the datetime index of this data """ return self.gsp_datetime_index - def to_xr_dataset(self, i): - """ Make a xr dataset """ - logger.debug(f"Making xr dataset for batch {i}") - assert self.batch_size == 0 - - example_dim = {"example": np.array([i], dtype=np.int32)} - - # GSP - n_gsp = len(self.gsp_id) - - one_dataset = xr.DataArray(self.gsp_yield, dims=["time_30", "gsp"], name="gsp_yield") - one_dataset = one_dataset.to_dataset(name="gsp_yield") - one_dataset[GSP_DATETIME_INDEX] = xr.DataArray( - self.gsp_datetime_index, - dims=["time_30"], - coords=[np.arange(len(self.gsp_datetime_index))], - ) - - # GSP - for name in [GSP_ID, GSP_X_COORDS, GSP_Y_COORDS]: - - var = self.__getattribute__(name) - - one_dataset[name] = xr.DataArray( - var[None, :], - coords={ - **example_dim, - **{"gsp": np.arange(n_gsp, dtype=np.int32)}, - }, - dims=["example", "gsp"], - ) - - one_dataset[GSP_YIELD] = one_dataset[GSP_YIELD].astype(np.float32) - one_dataset[GSP_ID] = one_dataset[GSP_ID].astype(np.float32) - one_dataset[GSP_X_COORDS] = one_dataset[GSP_X_COORDS].astype(np.float32) - one_dataset[GSP_Y_COORDS] = one_dataset[GSP_Y_COORDS].astype(np.float32) - - return one_dataset - @staticmethod def from_xr_dataset(xr_dataset): """ Change xr dataset to model. If data does not exist, then return None """ if "gsp_yield" in xr_dataset.keys(): - return GSP( + return GSPML( batch_size=xr_dataset["gsp_yield"].shape[0], gsp_yield=xr_dataset[GSP_YIELD], gsp_id=xr_dataset[GSP_ID], diff --git a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py index 1af77596..efdd7ff6 100644 --- a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py +++ b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py @@ -3,11 +3,13 @@ from numbers import Number from typing import List, Tuple -import pandas as pd import numpy as np +import pandas as pd +import xarray as xr from nowcasting_dataset.data_sources.data_source import DataSource from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata +from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.utils import to_numpy @@ -42,13 +44,31 @@ def get_example( else: object_at_center_label = 0 - return Metadata( + data_dict = dict( t0_dt=to_numpy(t0_dt), #: Shape: [batch_size,] x_meters_center=np.array(x_meters_center), y_meters_center=np.array(y_meters_center), object_at_center_label=object_at_center_label, ) + d_all = { + "t0_dt": {"dims": ("t0_dt"), "data": [t0_dt]}, + "x_meters_center": {"dims": ("t0_dt_index"), "data": [x_meters_center]}, + "y_meters_center": {"dims": ("t0_dt_index"), "data": [y_meters_center]}, + "object_at_center_label": {"dims": ("t0_dt_index"), "data": [object_at_center_label]}, + } + + data = convert_data_array_to_dataset(xr.DataArray.from_dict(d_all["t0_dt"])) + + for v in ["x_meters_center", "y_meters_center", "object_at_center_label"]: + d: dict = d_all[v] + d: xr.Dataset = convert_data_array_to_dataset(xr.DataArray.from_dict(d)).rename( + {"data": v} + ) + data[v] = getattr(d, v) + + return Metadata(data) + def get_locations_for_batch( self, t0_datetimes: 
pd.DatetimeIndex ) -> Tuple[List[Number], List[Number]]: diff --git a/nowcasting_dataset/data_sources/metadata/metadata_model.py b/nowcasting_dataset/data_sources/metadata/metadata_model.py index 57be042c..eb20b24b 100644 --- a/nowcasting_dataset/data_sources/metadata/metadata_model.py +++ b/nowcasting_dataset/data_sources/metadata/metadata_model.py @@ -1,17 +1,32 @@ """ Model for output of general/metadata data, useful for a batch """ -from typing import Union, List +from typing import Union + import numpy as np -import xarray as xr import torch -from pydantic import validator, Field +import xarray as xr +from pydantic import Field + +from nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, +) -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput from nowcasting_dataset.time import make_random_time_vectors + # seems to be a pandas dataseries class Metadata(DataSourceOutput): + """ Class to store metedata data as a xr.Dataset with some validation """ + + __slots__ = () + _expected_dimensions = ("t0_dt",) + + # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233 + + +class MetadataML(DataSourceOutputML): """Model for output of general/metadata data""" # TODO add descriptions @@ -34,7 +49,7 @@ def fake(batch_size, t0_dt=None): batch_size=batch_size, seq_length_5_minutes=0, seq_length_30_minutes=0 ) - return Metadata( + return MetadataML( batch_size=batch_size, t0_dt=t0_dt, x_meters_center=np.random.randn( @@ -46,29 +61,13 @@ def fake(batch_size, t0_dt=None): object_at_center_label=np.array([1] * batch_size), ) - def to_xr_dataset(self, i): - """Make a xr dataset""" - individual_datasets = [] - for name in ["t0_dt", "x_meters_center", "y_meters_center", "object_at_center_label"]: - - var = self.__getattribute__(name) - - example_dim = {"example": np.array([i], dtype=np.int32)} - - data = xr.DataArray([var], coords=example_dim, dims=["example"], name=name) - - ds = data.to_dataset() - individual_datasets.append(ds) - - return xr.merge(individual_datasets) - @staticmethod def from_xr_dataset(xr_dataset): """Change xr dataset to model. 
If data does not exist, then return None""" - return Metadata( - batch_size=xr_dataset["t0_dt"].shape[0], - t0_dt=xr_dataset["t0_dt"], - x_meters_center=xr_dataset["x_meters_center"], - y_meters_center=xr_dataset["y_meters_center"], - object_at_center_label=xr_dataset["object_at_center_label"], + return MetadataML( + batch_size=xr_dataset.t0_dt.shape[0], + t0_dt=xr_dataset.t0_dt.values, + x_meters_center=xr_dataset.x_meters_center.values, + y_meters_center=xr_dataset.y_meters_center.values, + object_at_center_label=xr_dataset.object_at_center_label.values, ) diff --git a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py index a506ab36..bd647101 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py @@ -12,7 +12,7 @@ from nowcasting_dataset import utils from nowcasting_dataset.data_sources.data_source import ZarrDataSource from nowcasting_dataset.data_sources.nwp.nwp_model import NWP -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.dataset.xr_utils import join_list_data_array_to_batch_dataset _LOG = logging.getLogger(__name__) @@ -178,27 +178,15 @@ def get_batch( t0_dt = t0_datetimes[i] selected_data = self._post_process_example(selected_data, t0_dt) - output: DataSourceOutput = self._put_data_into_example(selected_data) - if self.convert_to_numpy: - output.to_numpy() - examples.append(output) + examples.append(selected_data) - return DataSourceOutput.create_batch_from_examples(examples) + output = join_list_data_array_to_batch_dataset(examples) + + return NWP(output) def _open_data(self) -> xr.DataArray: return open_nwp(self.filename, consolidated=self.consolidated) - def _put_data_into_example(self, selected_data: xr.DataArray) -> NWP: - - return NWP( - nwp=selected_data, - nwp_x_coords=selected_data.x, - nwp_y_coords=selected_data.y, - nwp_target_time=selected_data.target_time, - nwp_init_time=np.array(selected_data.init_time.data), - nwp_channel_names=self.channels, # TODO perhaps could get this from selected data instead - ) - def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: """ Select the numerical weather predictions for a single time slice. 
@@ -243,6 +231,11 @@ def _post_process_example( selected_data = selected_data.resample({"target_time": "5T"}) selected_data = selected_data.interpolate() selected_data = selected_data.sel(target_time=slice(start_dt, end_dt)) + selected_data = selected_data.rename({"target_time": "time"}) + selected_data = selected_data.rename({"variable": "channels"}) + + selected_data.data = selected_data.data.astype(np.float32) + return selected_data def datetime_index(self) -> pd.DatetimeIndex: diff --git a/nowcasting_dataset/data_sources/nwp/nwp_model.py b/nowcasting_dataset/data_sources/nwp/nwp_model.py index 09626ebb..3c36087c 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_model.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_model.py @@ -1,175 +1,112 @@ """ Model for output of NWP data """ -from pydantic import Field, validator -from typing import Union, List +from __future__ import annotations + +import logging + import numpy as np import xarray as xr -import torch +from pydantic import Field -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.consts import ( - Array, - NWP_VARIABLE_NAMES, - NWP_DATA, +from nowcasting_dataset.consts import Array +from nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, ) -from nowcasting_dataset.utils import coord_to_range from nowcasting_dataset.time import make_random_time_vectors -import logging logger = logging.getLogger(__name__) class NWP(DataSourceOutput): + """ Class to store NWP data as a xr.Dataset with some validation """ + + # Use to store xr.Dataset data + + __slots__ = () + _expected_dimensions = ("time", "x", "y", "channels") + + @classmethod + def model_validation(cls, v): + """ Check that all values are not NaNs """ + assert (v.data != np.nan).all(), "Some nwp data values are NaNs" + return v + + +class NWPML(DataSourceOutputML): """ Model for output of NWP data """ # Shape: [batch_size,] seq_length, width, height, channel - nwp: Array = Field( + data: Array = Field( ..., description=" Numerical weather predictions (NWPs) \ : Shape: [batch_size,] channel, seq_length, width, height", ) - nwp_x_coords: Array = Field( + x: Array = Field( ..., description="The x (OSGB geo-spatial) coordinates of the NWP data. " "Shape: [batch_size,] width", ) - nwp_y_coords: Array = Field( + y: Array = Field( ..., description="The y (OSGB geo-spatial) coordinates of the NWP data. " "Shape: [batch_size,] height", ) - nwp_target_time: Array = Field( + target_time: Array = Field( ..., description="Time index of nwp data at 5 minutes past the hour {0, 5, ..., 55}. " "Datetimes become Unix epochs (UTC) represented as int64 just before being" "passed into the ML model. 
The 'target time' is the time the NWP is _about_.", ) - nwp_init_time: Union[xr.DataArray, np.ndarray, torch.Tensor, int] = Field( - ..., description="The time when the nwp forecast was made" - ) - - nwp_channel_names: Union[List[List[str]], List[str], np.ndarray] = Field( - ..., description="List of the nwp channels" - ) - - @property - def width(self): - """The width of the nwp data""" - return self.nwp.shape[-2] - - @property - def height(self): - """The width of the nwp data""" - return self.nwp.shape[-1] - - @property - def sequence_length(self): - """The sequence length of the NWP timeseries""" - return self.nwp.shape[-3] - - @validator("nwp_x_coords") - def x_coordinates_shape(cls, v, values): - """ Validate 'nwp_x_coords' """ - assert v.shape[-1] == values["nwp"].shape[-2] - return v + init_time: Array = Field(..., description="The time when the nwp forecast was made") - @validator("nwp_y_coords") - def y_coordinates_shape(cls, v, values): - """ Validate 'nwp_y_coords' """ - assert v.shape[-1] == values["nwp"].shape[-1] - return v + channels: Array = Field(..., description="List of the nwp channels") @staticmethod - def fake(batch_size, seq_length_5, nwp_image_size_pixels, number_nwp_channels, time_5=None): - """ Create fake data """ + def fake( + batch_size=32, + seq_length_5=19, + image_size_pixels=64, + number_nwp_channels=7, + time_5=None, + ): + """Create fake data""" if time_5 is None: _, time_5, _ = make_random_time_vectors( batch_size=batch_size, seq_length_5_minutes=seq_length_5, seq_length_30_minutes=0 ) - return NWP( + s = NWPML( batch_size=batch_size, - nwp=np.random.randn( + data=np.random.randn( batch_size, - number_nwp_channels, seq_length_5, - nwp_image_size_pixels, - nwp_image_size_pixels, + image_size_pixels, + image_size_pixels, + number_nwp_channels, ), - nwp_x_coords=np.sort(np.random.randn(batch_size, nwp_image_size_pixels)), - nwp_y_coords=np.sort(np.random.randn(batch_size, nwp_image_size_pixels))[ - :, ::-1 - ].copy(), + x=np.sort(np.random.randn(batch_size, image_size_pixels)), + y=np.sort(np.random.randn(batch_size, image_size_pixels))[:, ::-1].copy() # copy is needed as torch doesnt not support negative strides - nwp_target_time=time_5, - nwp_init_time=np.sort( - np.random.randn( - batch_size, - ) - ), - nwp_channel_names=[ - NWP_VARIABLE_NAMES[0:number_nwp_channels] for _ in range(batch_size) - ], + , + target_time=time_5, + init_time=time_5[0], + channels=np.array([list(range(number_nwp_channels)) for _ in range(batch_size)]), ) + return s + def get_datetime_index(self) -> Array: - """ Get the datetime index of this data """ - return self.nwp_target_time - - def to_xr_data_array(self): - """ Change to data_array. 
Sets the nwp field in-place.""" - self.nwp = xr.DataArray( - self.nwp, - dims=["variable", "target_time", "x", "y"], - coords={ - "variable": self.nwp_channel_names, - "target_time": self.nwp_target_time, - "init_time": self.nwp_init_time, - "x": self.nwp_x_coords, - "y": self.nwp_y_coords, - }, - ) + """Get the datetime index of this data""" + return self.target_time - def to_xr_dataset(self, i): - """ Make a xr dataset """ - logger.debug(f"Making xr dataset for batch {i}") - if type(self.nwp) != xr.DataArray: - self.to_xr_data_array() - - ds = self.nwp.to_dataset(name="nwp") - ds["nwp"] = ds["nwp"].astype(np.float32) - ds = ds.round(2) - - ds = ds.rename({"target_time": "time"}) - for dim in ["time", "x", "y"]: - ds = coord_to_range(ds, dim, prefix="nwp") - ds = ds.rename( - { - "variable": f"nwp_variable", - "x": "nwp_x", - "y": "nwp_y", - } + @staticmethod + def from_xr_dataset(xr_dataset: xr.Dataset): + """Change xr dataset to model with tensors""" + nwp_batch_ml = xr_dataset.torch.to_tensor( + ["data", "target_time", "init_time", "x", "y", "channels"] ) - ds["nwp_x_coords"] = ds["nwp_x_coords"].astype(np.float32) - ds["nwp_y_coords"] = ds["nwp_y_coords"].astype(np.float32) - - return ds - - @staticmethod - def from_xr_dataset(xr_dataset): - """ Change xr dataset to model. If data does not exist, then return None """ - if NWP_DATA in xr_dataset.keys(): - return NWP( - batch_size=xr_dataset[NWP_DATA].shape[0], - nwp=xr_dataset[NWP_DATA], - nwp_channel_names=xr_dataset[NWP_DATA].nwp_variable.values, - nwp_init_time=xr_dataset[NWP_DATA].init_time, - nwp_target_time=xr_dataset["nwp_time_coords"], - nwp_x_coords=xr_dataset[NWP_DATA].nwp_x, - nwp_y_coords=xr_dataset[NWP_DATA].nwp_y, - ) - else: - return None + return NWPML(**nwp_batch_ml) diff --git a/nowcasting_dataset/data_sources/pv/pv_data_source.py b/nowcasting_dataset/data_sources/pv/pv_data_source.py index 8f8bf558..bff3c9e9 100644 --- a/nowcasting_dataset/data_sources/pv/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv/pv_data_source.py @@ -20,8 +20,9 @@ DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, ) from nowcasting_dataset.data_sources.data_source import ImageDataSource -from nowcasting_dataset.square import get_bounding_box_mask from nowcasting_dataset.data_sources.pv.pv_model import PV +from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset +from nowcasting_dataset.square import get_bounding_box_mask logger = logging.getLogger(__name__) @@ -235,18 +236,45 @@ def get_example( # Save data into the PV object... 
- pv = PV( - pv_system_id=all_pv_system_ids.values, - pv_system_row_number=pv_system_row_number, - pv_yield=selected_pv_power.values, - pv_system_x_coords=pv_system_x_coords.values, - pv_system_y_coords=pv_system_y_coords.values, - pv_datetime_index=selected_pv_power.index.values, + # convert to data array + da = xr.DataArray( + data=selected_pv_power.values, + dims=["time", "id"], + coords=dict( + id=all_pv_system_ids.values.astype(int), + time=selected_pv_power.index.values, + ), + ) + + # convert to dataset + pv = convert_data_array_to_dataset(da) + + # add pv x coords + x_coords = xr.DataArray( + data=pv_system_x_coords.values, + dims=["id_index"], + coords=dict( + id_index=range(len(all_pv_system_ids.values)), + ), + ) + + y_coords = xr.DataArray( + data=pv_system_y_coords.values, + dims=["id_index"], + coords=dict( + id_index=range(len(all_pv_system_ids.values)), + ), ) + pv["x_coords"] = x_coords + pv["y_coords"] = y_coords + + # pad out so that there are always 32 gsp + pad_n = self.n_pv_systems_per_example - len(pv.id_index) + pv = pv.pad(id_index=(0, pad_n), data=((0, 0), (0, pad_n))) - pv.pad() + pv.__setitem__("id_index", range(self.n_pv_systems_per_example)) - return pv + return PV(pv) def get_locations_for_batch( self, t0_datetimes: pd.DatetimeIndex diff --git a/nowcasting_dataset/data_sources/pv/pv_model.py b/nowcasting_dataset/data_sources/pv/pv_model.py index fdf9c8da..c31a8394 100644 --- a/nowcasting_dataset/data_sources/pv/pv_model.py +++ b/nowcasting_dataset/data_sources/pv/pv_model.py @@ -1,9 +1,9 @@ """ Model for output of PV data """ -from pydantic import Field, validator +import logging + import numpy as np -import xarray as xr +from pydantic import Field, validator -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput, pad_data from nowcasting_dataset.consts import ( Array, PV_YIELD, @@ -12,15 +12,28 @@ PV_SYSTEM_X_COORDS, PV_SYSTEM_ROW_NUMBER, PV_SYSTEM_ID, - DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, +) +from nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, ) from nowcasting_dataset.time import make_random_time_vectors -import logging logger = logging.getLogger(__name__) class PV(DataSourceOutput): + """ Class to store PV data as a xr.Dataset with some validation """ + + # Use to store xr.Dataset data + + __slots__ = () + _expected_dimensions = ("time", "id") + + # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233 + + +class PVML(DataSourceOutputML): """ Model for output of PV data """ # Shape: [batch_size,] seq_length, width, height, channel @@ -82,7 +95,7 @@ def fake(batch_size, seq_length_5, n_pv_systems_per_batch, time_5=None): batch_size=batch_size, seq_length_5_minutes=seq_length_5, seq_length_30_minutes=0 ) - return PV( + return PVML( batch_size=batch_size, pv_yield=np.random.randn( batch_size, @@ -100,87 +113,15 @@ def fake(batch_size, seq_length_5, n_pv_systems_per_batch, time_5=None): ].copy(), # copy is needed as torch doesnt not support negative strides ) - def pad(self, n_pv_systems_per_example: int = DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE): - """ - Pad out data - - Args: - n_pv_systems_per_example: The number of pv systems there are per example. - - Note that nothing is returned as the changes are made inplace. 
- """ - assert self.batch_size == 0, "Padding only works for batch_size=0, i.e one Example" - - pad_size = n_pv_systems_per_example - self.pv_yield.shape[-1] - # Pad (if necessary) so returned arrays are always of size - pad_shape = (0, pad_size) # (before, after) - - one_dimensional_arrays = [ - PV_SYSTEM_ID, - PV_SYSTEM_ROW_NUMBER, - PV_SYSTEM_X_COORDS, - PV_SYSTEM_Y_COORDS, - ] - - pad_data( - data=self, - pad_size=pad_size, - one_dimensional_arrays=one_dimensional_arrays, - two_dimensional_arrays=[PV_YIELD], - ) - def get_datetime_index(self) -> Array: """ Get the datetime index of this data """ return self.pv_datetime_index - def to_xr_dataset(self, i): - """ Make a xr dataset """ - logger.debug(f"Making xr dataset for batch {i}") - assert self.batch_size == 0 - - example_dim = {"example": np.array([i], dtype=np.int32)} - - # PV - one_dataset = xr.DataArray(self.pv_yield, dims=["time", "pv_system"]) - one_dataset = one_dataset.to_dataset(name="pv_yield") - n_pv_systems = len(self.pv_system_id) - - one_dataset[PV_DATETIME_INDEX] = xr.DataArray( - self.pv_datetime_index, - dims=["time"], - coords=[np.arange(len(self.pv_datetime_index))], - ) - - # 1D - for name in [ - PV_SYSTEM_ID, - PV_SYSTEM_ROW_NUMBER, - PV_SYSTEM_X_COORDS, - PV_SYSTEM_Y_COORDS, - ]: - var = self.__getattribute__(name) - - one_dataset[name] = xr.DataArray( - var[None, :], - coords={ - **example_dim, - **{"pv_system": np.arange(n_pv_systems, dtype=np.int32)}, - }, - dims=["example", "pv_system"], - ) - - one_dataset["pv_system_id"] = one_dataset["pv_system_id"].astype(np.float32) - one_dataset["pv_system_row_number"] = one_dataset["pv_system_row_number"].astype(np.float32) - one_dataset["pv_system_x_coords"] = one_dataset["pv_system_x_coords"].astype(np.float32) - one_dataset["pv_system_y_coords"] = one_dataset["pv_system_y_coords"].astype(np.float32) - - return one_dataset - @staticmethod def from_xr_dataset(xr_dataset): """ Change xr dataset to model. 
If data does not exist, then return None """ if PV_YIELD in xr_dataset.keys(): - return PV( + return PVML( batch_size=xr_dataset[PV_YIELD].shape[0], pv_yield=xr_dataset[PV_YIELD], pv_system_id=xr_dataset[PV_SYSTEM_ID], diff --git a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py index 58a18b13..3cb3590d 100644 --- a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py @@ -11,7 +11,7 @@ from nowcasting_dataset.data_sources.data_source import ZarrDataSource from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.dataset.xr_utils import join_list_data_array_to_batch_dataset import nowcasting_dataset.time as nd_time _LOG = logging.getLogger("nowcasting_dataset") @@ -152,22 +152,11 @@ def get_batch( example = self.get_example(t0_datetime, x_location, y_location) examples.append(example) - output = DataSourceOutput.create_batch_from_examples(examples) + output = join_list_data_array_to_batch_dataset(examples) - if self.convert_to_numpy: - output.to_numpy() self._cache = {} - return output - - def _put_data_into_example(self, selected_data: xr.DataArray) -> Satellite: - return Satellite( - sat_data=selected_data, - sat_x_coords=selected_data.x, - sat_y_coords=selected_data.y, - sat_datetime_index=selected_data.time, - sat_channel_names=self.channels, - ) + return Satellite(output) def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: try: @@ -186,6 +175,9 @@ def _post_process_example( if self.normalise: selected_data = selected_data - SAT_MEAN selected_data = selected_data / SAT_STD + + selected_data.data = selected_data.data.astype(np.float32) + return selected_data def datetime_index(self, remove_night: bool = True) -> pd.DatetimeIndex: diff --git a/nowcasting_dataset/data_sources/satellite/satellite_model.py b/nowcasting_dataset/data_sources/satellite/satellite_model.py index 83b346a3..e6392d6d 100644 --- a/nowcasting_dataset/data_sources/satellite/satellite_model.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_model.py @@ -1,36 +1,55 @@ """ Model for output of satellite data """ -from pydantic import Field, validator -from typing import Union, List +from __future__ import annotations + +import logging + import numpy as np import xarray as xr +from pydantic import Field -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.consts import Array, SAT_VARIABLE_NAMES -from nowcasting_dataset.utils import coord_to_range +from nowcasting_dataset.consts import Array +from nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, +) from nowcasting_dataset.time import make_random_time_vectors -import logging logger = logging.getLogger(__name__) class Satellite(DataSourceOutput): - """ Model for output of satellite data """ + """ Class to store satellite data as a xr.Dataset with some validation """ + + # Use to store xr.Dataset data + + __slots__ = () + _expected_dimensions = ("time", "x", "y", "channels") + + @classmethod + def model_validation(cls, v): + """ Check that all values are non negative """ + assert (v.data != np.NaN).all(), f"Some satellite data values are NaNs" + return v + + +class SatelliteML(DataSourceOutputML): + """Model for output of satellite data""" # Shape: 
[batch_size,] seq_length, width, height, channel - sat_data: Array = Field( + data: Array = Field( ..., description="Satellites images. Shape: [batch_size,] seq_length, width, height, channel", ) - sat_x_coords: Array = Field( + x: Array = Field( ..., description="aThe x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width", ) - sat_y_coords: Array = Field( + y: Array = Field( ..., description="The y (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] height", ) - sat_datetime_index: Array = Field( + time: Array = Field( ..., description="Time index of satellite data at 5 minutes past the hour {0, 5, ..., 55}. " "*not* the {4, 9, ..., 59} timings of the satellite imagery. " @@ -38,21 +57,7 @@ class Satellite(DataSourceOutput): "passed into the ML model.", ) - sat_channel_names: Union[List[List[str]], List[str], np.ndarray] = Field( - ..., description="List of the satellite channels" - ) - - @validator("sat_x_coords") - def x_coordinates_shape(cls, v, values): - """ Validate 'sat_x_coords' """ - assert v.shape[-1] == values["sat_data"].shape[-3] - return v - - @validator("sat_y_coords") - def y_coordinates_shape(cls, v, values): - """ Validate 'sat_y_coords' """ - assert v.shape[-1] == values["sat_data"].shape[-2] - return v + channels: Array = Field(..., description="List of the satellite channels") @staticmethod def fake( @@ -62,83 +67,38 @@ def fake( number_sat_channels=7, time_5=None, ): - """ Create fake data """ + """Create fake data""" if time_5 is None: _, time_5, _ = make_random_time_vectors( batch_size=batch_size, seq_length_5_minutes=seq_length_5, seq_length_30_minutes=0 ) - s = Satellite( + s = SatelliteML( batch_size=batch_size, - sat_data=np.random.randn( + data=np.random.randn( batch_size, seq_length_5, satellite_image_size_pixels, satellite_image_size_pixels, number_sat_channels, ), - sat_x_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels)), - sat_y_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels))[ - :, ::-1 - ].copy() + x=np.sort(np.random.randn(batch_size, satellite_image_size_pixels)), + y=np.sort(np.random.randn(batch_size, satellite_image_size_pixels))[:, ::-1].copy() # copy is needed as torch doesnt not support negative strides , - sat_datetime_index=time_5, - sat_channel_names=[ - SAT_VARIABLE_NAMES[0:number_sat_channels] for _ in range(batch_size) - ], + time=time_5, + channels=np.array([list(range(number_sat_channels)) for _ in range(batch_size)]), ) return s def get_datetime_index(self) -> Array: - """ Get the datetime index of this data """ - return self.sat_datetime_index - - def to_xr_dataset(self, i): - """ Make a xr dataset """ - logger.debug(f"Making xr dataset for batch {i}") - if type(self.sat_data) != xr.DataArray: - self.sat_data = xr.DataArray( - self.sat_data, - coords={ - "time": self.sat_datetime_index, - "x": self.sat_x_coords, - "y": self.sat_y_coords, - "variable": self.sat_channel_names, # assume all channels are the same - }, - ) - - ds = self.sat_data.to_dataset(name="sat_data") - ds["sat_data"] = ds["sat_data"].astype(np.int16) - ds = ds.round(2) - - for dim in ["time", "x", "y"]: - ds = coord_to_range(ds, dim, prefix="sat") - ds = ds.rename( - { - "variable": f"sat_variable", - "x": f"sat_x", - "y": f"sat_y", - } - ) - - ds["sat_x_coords"] = ds["sat_x_coords"].astype(np.int32) - ds["sat_y_coords"] = ds["sat_y_coords"].astype(np.int32) - - return ds + """Get the datetime index of this data""" + return self.time @staticmethod - def 
from_xr_dataset(xr_dataset): - """ Change xr dataset to model. If data does not exist, then return None """ - if "sat_data" in xr_dataset.keys(): - return Satellite( - batch_size=xr_dataset["sat_data"].shape[0], - sat_data=xr_dataset["sat_data"], - sat_x_coords=xr_dataset["sat_x_coords"], - sat_y_coords=xr_dataset["sat_y_coords"], - sat_datetime_index=xr_dataset["sat_time_coords"], - sat_channel_names=xr_dataset["sat_data"].sat_variable.values, - ) - else: - return None + def from_xr_dataset(xr_dataset: xr.Dataset): + """Change xr dataset to model. """ + satellite_batch_ml = xr_dataset.torch.to_tensor(["data", "time", "x", "y", "channels"]) + + return SatelliteML(**satellite_batch_ml) diff --git a/nowcasting_dataset/data_sources/sun/raw_data_load_save.py b/nowcasting_dataset/data_sources/sun/raw_data_load_save.py index 8ec1b531..015b366d 100644 --- a/nowcasting_dataset/data_sources/sun/raw_data_load_save.py +++ b/nowcasting_dataset/data_sources/sun/raw_data_load_save.py @@ -1,18 +1,15 @@ """ Sun Data Source """ import datetime -import time -import io import logging +import time from concurrent import futures +from pathlib import Path from typing import List, Union, Optional -import fsspec import numcodecs import pandas as pd -from tqdm import tqdm import xarray as xr -import numpy as np -from pathlib import Path +from tqdm import tqdm from nowcasting_dataset import geospatial diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index 755ec736..439f0150 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -1,16 +1,18 @@ """ Loading Raw data """ -from nowcasting_dataset.data_sources.data_source import DataSource from dataclasses import dataclass -import pandas as pd +from datetime import datetime from numbers import Number -from typing import List, Tuple, Union, Optional from pathlib import Path +from typing import List, Tuple, Union, Optional + import numpy as np -from datetime import datetime +import pandas as pd +import xarray as xr +from nowcasting_dataset.data_sources.data_source import DataSource from nowcasting_dataset.data_sources.sun.raw_data_load_save import load_from_zarr, x_y_to_name - from nowcasting_dataset.data_sources.sun.sun_model import Sun +from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset @dataclass @@ -66,13 +68,14 @@ def get_example( azimuth = self.azimuth.loc[start_dt:end_dt][name] elevation = self.elevation.loc[start_dt:end_dt][name] - sun = Sun( - sun_azimuth_angle=azimuth.values, - sun_elevation_angle=elevation.values, - sun_datetime_index=azimuth.index.values, - ) + azimuth = azimuth.to_xarray().rename({"index": "time"}) + elevation = elevation.to_xarray().rename({"index": "time"}) + + sun = convert_data_array_to_dataset(azimuth).rename({"data": "azimuth"}) + elevation = convert_data_array_to_dataset(elevation) + sun["elevation"] = elevation.data - return sun + return Sun(sun) def _load(self): diff --git a/nowcasting_dataset/data_sources/sun/sun_model.py b/nowcasting_dataset/data_sources/sun/sun_model.py index e7620936..4d13d138 100644 --- a/nowcasting_dataset/data_sources/sun/sun_model.py +++ b/nowcasting_dataset/data_sources/sun/sun_model.py @@ -1,18 +1,30 @@ """ Model for Sun features """ -from pydantic import Field, validator +import logging + import numpy as np -import xarray as xr +from pydantic import Field, validator -from nowcasting_dataset.data_sources.datasource_output import 
DataSourceOutput from nowcasting_dataset.consts import Array, SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE -from nowcasting_dataset.utils import coord_to_range +from nowcasting_dataset.data_sources.datasource_output import ( + DataSourceOutputML, + DataSourceOutput, +) from nowcasting_dataset.time import make_random_time_vectors -import logging logger = logging.getLogger(__name__) class Sun(DataSourceOutput): + """ Class to store Sun data as a xr.Dataset with some validation """ + + # Use to store xr.Dataset data + __slots__ = () + _expected_dimensions = ("time",) + + # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233 + + +class SunML(DataSourceOutputML): """ Model for Sun features """ sun_azimuth_angle: Array = Field( @@ -55,7 +67,7 @@ def fake(batch_size, seq_length_5, time_5=None): batch_size=batch_size, seq_length_5_minutes=seq_length_5, seq_length_30_minutes=0 ) - return Sun( + return SunML( batch_size=batch_size, sun_azimuth_angle=np.random.randn( batch_size, @@ -72,40 +84,11 @@ def get_datetime_index(self): """ Get the datetime index of this data """ return self.sun_datetime_index - def to_xr_dataset(self, i): - """ Make a xr dataset """ - logger.debug(f"Making xr dataset for batch {i}") - individual_datasets = [] - for name in [SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE]: - - var = self.__getattribute__(name) - - data = xr.DataArray( - var, - dims=["time"], - coords={"time": self.sun_datetime_index}, - name=name, - ) - - ds = data.to_dataset() - ds = coord_to_range(ds, "time", prefix=None) - individual_datasets.append(ds) - - data = xr.DataArray( - self.sun_datetime_index, - dims=["time"], - coords=[np.arange(len(self.sun_datetime_index))], - ) - ds = data.to_dataset(name="sun_datetime_index") - individual_datasets.append(ds) - - return xr.merge(individual_datasets) - @staticmethod def from_xr_dataset(xr_dataset): """ Change xr dataset to model. 
If data does not exist, then return None """ if SUN_AZIMUTH_ANGLE in xr_dataset.keys(): - return Sun( + return SunML( batch_size=xr_dataset[SUN_AZIMUTH_ANGLE].shape[0], sun_azimuth_angle=xr_dataset[SUN_AZIMUTH_ANGLE], sun_elevation_angle=xr_dataset[SUN_ELEVATION_ANGLE], diff --git a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py index e396dd26..5d0cd123 100644 --- a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py +++ b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py @@ -10,10 +10,10 @@ from nowcasting_dataset.consts import TOPOGRAPHIC_DATA from nowcasting_dataset.data_sources.data_source import ImageDataSource -from nowcasting_dataset.geospatial import OSGB - from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic +from nowcasting_dataset.geospatial import OSGB from nowcasting_dataset.utils import OpenData +from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset # Means computed with # out_fp = "europe_dem_1km.tif" @@ -112,23 +112,9 @@ def get_example( f"actual shape {selected_data.shape}" ) - return self._put_data_into_example(selected_data) - - def _put_data_into_example(self, selected_data: xr.DataArray) -> Topographic: - """ - Insert the data and coordinates into an Example - - Args: - selected_data: DataArray containing the data to insert + topo_xd = convert_data_array_to_dataset(selected_data) - Returns: - Example containing the Topographic data - """ - return Topographic( - topo_data=selected_data, - topo_x_coords=selected_data.x, - topo_y_coords=selected_data.y, - ) + return Topographic(topo_xd) def _post_process_example( self, selected_data: xr.DataArray, t0_dt: pd.Timestamp diff --git a/nowcasting_dataset/data_sources/topographic/topographic_model.py b/nowcasting_dataset/data_sources/topographic/topographic_model.py index 06f4350f..0bc44148 100644 --- a/nowcasting_dataset/data_sources/topographic/topographic_model.py +++ b/nowcasting_dataset/data_sources/topographic/topographic_model.py @@ -1,18 +1,27 @@ """ Model for Topogrpahic features """ -from pydantic import Field, validator -import xarray as xr -import numpy as np import logging -from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.consts import Array -from nowcasting_dataset.consts import TOPOGRAPHIC_DATA, TOPOGRAPHIC_X_COORDS, TOPOGRAPHIC_Y_COORDS -from nowcasting_dataset.utils import coord_to_range +import numpy as np +from pydantic import Field, validator + +from nowcasting_dataset.consts import Array +from nowcasting_dataset.consts import TOPOGRAPHIC_DATA +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutputML, DataSourceOutput logger = logging.getLogger(__name__) class Topographic(DataSourceOutput): + """ Class to store topographic data as a xr.Dataset with some validation """ + + # Use to store xr.Dataset data + __slots__ = () + _expected_dimensions = ("x", "y") + + # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233 + + +class TopographicML(DataSourceOutputML): """ Topographic/elevation map features. 
""" @@ -55,54 +64,25 @@ def y_coordinates_shape(cls, v, values): return v @staticmethod - def fake(batch_size, satellite_image_size_pixels): + def fake(batch_size, image_size_pixels): """ Create fake data """ - return Topographic( + return TopographicML( batch_size=batch_size, topo_data=np.random.randn( batch_size, - satellite_image_size_pixels, - satellite_image_size_pixels, + image_size_pixels, + image_size_pixels, ), - topo_x_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels)), - topo_y_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels))[ - :, ::-1 - ].copy(), + topo_x_coords=np.sort(np.random.randn(batch_size, image_size_pixels)), + topo_y_coords=np.sort(np.random.randn(batch_size, image_size_pixels))[:, ::-1].copy(), # copy is needed as torch doesnt not support negative strides ) - def to_xr_dataset(self, i): - """ Make a xr dataset """ - logger.debug(f"Making xr dataset for batch {i}") - data = xr.DataArray( - self.topo_data, - coords={ - "x": self.topo_x_coords, - "y": self.topo_y_coords, - }, - ) - - ds = data.to_dataset(name=TOPOGRAPHIC_DATA) - for dim in ["x", "y"]: - ds = coord_to_range(ds, dim, prefix="topo") - ds = ds.rename( - { - "x": f"topo_x", - "y": f"topo_y", - } - ) - - ds[TOPOGRAPHIC_DATA] = ds[TOPOGRAPHIC_DATA].astype(np.float32) - ds[TOPOGRAPHIC_X_COORDS] = ds[TOPOGRAPHIC_X_COORDS].astype(np.float32) - ds[TOPOGRAPHIC_Y_COORDS] = ds[TOPOGRAPHIC_Y_COORDS].astype(np.float32) - - return ds - @staticmethod def from_xr_dataset(xr_dataset): """ Change xr dataset to model. If data does not exist, then return None """ if TOPOGRAPHIC_DATA in xr_dataset.keys(): - return Topographic( + return TopographicML( batch_size=xr_dataset[TOPOGRAPHIC_DATA].shape[0], topo_data=xr_dataset[TOPOGRAPHIC_DATA], topo_x_coords=xr_dataset[TOPOGRAPHIC_DATA].topo_x, diff --git a/nowcasting_dataset/dataset/README.md b/nowcasting_dataset/dataset/README.md index d7f59549..040a47f2 100644 --- a/nowcasting_dataset/dataset/README.md +++ b/nowcasting_dataset/dataset/README.md @@ -5,6 +5,7 @@ This folder contains the following files ## batch.py 'Batch' pydantic class, to hold batch data in. An 'Example' is one item in the batch. +'BatchML' pydantic class, holds data for a batch, ready for ML models. ## datamodule.py @@ -23,6 +24,10 @@ NetCDFDataset - torch.utils.data.Dataset: Use for loading pre-made batches NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making batches -## validate.py +## subset.py -Contains a class that can validate the prepare ml dataset +Function to subset the 'Batch' + +## fake.py + +A fake dataset, perhaps useful outside this repo. 
diff --git a/nowcasting_dataset/dataset/batch.py b/nowcasting_dataset/dataset/batch.py index cde65ad8..1cb63ef9 100644 --- a/nowcasting_dataset/dataset/batch.py +++ b/nowcasting_dataset/dataset/batch.py @@ -1,34 +1,70 @@ """ batch functions """ +from __future__ import annotations + import logging import os from pathlib import Path -from typing import List, Optional, Union, Dict +from typing import Optional, Union import xarray as xr from pydantic import BaseModel, Field -from nowcasting_dataset.filesystem.utils import make_folder - from nowcasting_dataset.config.model import Configuration - -from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime -from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata -from nowcasting_dataset.data_sources.gsp.gsp_model import GSP -from nowcasting_dataset.data_sources.nwp.nwp_model import NWP -from nowcasting_dataset.data_sources.pv.pv_model import PV -from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite -from nowcasting_dataset.data_sources.sun.sun_model import Sun -from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic +from nowcasting_dataset.data_sources.datetime.datetime_model import DatetimeML, Datetime +from nowcasting_dataset.data_sources.gsp.gsp_model import GSPML, GSP +from nowcasting_dataset.data_sources.metadata.metadata_model import MetadataML, Metadata +from nowcasting_dataset.data_sources.nwp.nwp_model import ( + NWPML, + NWP, +) +from nowcasting_dataset.data_sources.pv.pv_model import PVML, PV +from nowcasting_dataset.data_sources.satellite.satellite_model import SatelliteML, Satellite +from nowcasting_dataset.data_sources.sun.sun_model import SunML, Sun +from nowcasting_dataset.data_sources.topographic.topographic_model import TopographicML, Topographic +from nowcasting_dataset.dataset.xr_utils import ( + register_xr_data_array_to_tensor, + register_xr_data_set_to_tensor, +) from nowcasting_dataset.time import make_random_time_vectors -from nowcasting_dataset.utils import get_netcdf_filename +from nowcasting_dataset.data_sources.fake import ( + datetime_fake, + metadata_fake, + gsp_fake, + pv_fake, + satellite_fake, + sun_fake, + topographic_fake, + nwp_fake, +) _LOG = logging.getLogger(__name__) +register_xr_data_array_to_tensor() +register_xr_data_set_to_tensor() + +data_sources = [Metadata, Satellite, Topographic, PV, Sun, GSP, NWP, Datetime] -class Example(BaseModel): - """Single Data item""" - metadata: Metadata +class Batch(BaseModel): + """ + Batch data object + + Contains the following data sources + - gsp, satellite, topographic, sun, pv, nwp and datetime. Also contains the metadata of the batch. + + All data sources are xr.Datasets + + """ + + batch_size: int = Field( + ..., + ge=0, + description="The size of this batch. 
If the batch size is 0, " + "then this item stores one data item", + ) + + metadata: Optional[Metadata] satellite: Optional[Satellite] topographic: Optional[Topographic] pv: Optional[PV] @@ -37,15 +73,110 @@ class Example(BaseModel): nwp: Optional[NWP] datetime: Optional[Datetime] - def change_type_to_numpy(self): - """Change data to numpy""" + @property + def data_sources(self): + """The different data sources""" + return [ + self.satellite, + self.topographic, + self.pv, + self.sun, + self.gsp, + self.nwp, + self.datetime, + self.metadata, + ] + + @staticmethod + def fake(configuration: Configuration = Configuration()): + """ Make fake batch object """ + batch_size = configuration.process.batch_size + seq_length_5 = configuration.process.seq_length_5_minutes + seq_length_30 = configuration.process.seq_length_30_minutes + image_size_pixels = configuration.process.satellite_image_size_pixels + + return Batch( + batch_size=batch_size, + satellite=satellite_fake( + batch_size=batch_size, + seq_length_5=seq_length_5, + satellite_image_size_pixels=image_size_pixels, + number_sat_channels=len(configuration.process.sat_channels), + ), + nwp=nwp_fake( + batch_size=batch_size, + seq_length_5=seq_length_5, + image_size_pixels=image_size_pixels, + number_nwp_channels=len(configuration.process.nwp_channels), + ), + metadata=metadata_fake(batch_size=batch_size), + pv=pv_fake( + batch_size=batch_size, seq_length_5=seq_length_5, n_pv_systems_per_batch=128 + ), + gsp=gsp_fake(batch_size=batch_size, seq_length_30=seq_length_30, n_gsp_per_batch=32), + sun=sun_fake(batch_size=batch_size, seq_length_5=seq_length_5), + topographic=topographic_fake( + batch_size=batch_size, image_size_pixels=image_size_pixels + ), + datetime=datetime_fake(batch_size=batch_size, seq_length_5=seq_length_5), + ) + + def save_netcdf(self, batch_i: int, path: Path): + """ + Save batch to netcdf file + + Args: + batch_i: the batch id, used to make the filename + path: the path where it will be saved. This can be local or in the cloud. + + """ for data_source in self.data_sources: if data_source is not None: - data_source.to_numpy() + data_source.save_netcdf(batch_i=batch_i, path=path) + + @staticmethod + def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int): + """Load batch from netcdf file""" + data_sources_names = Example.__fields__.keys() + + # collect data sources + batch_dict = {} + for data_source_name in data_sources_names: + + local_netcdf_filename = os.path.join( + local_netcdf_path, data_source_name, f"{batch_idx}.nc" + ) + if os.path.exists(local_netcdf_filename): + xr_dataset = xr.load_dataset(local_netcdf_filename) + else: + xr_dataset = None + + batch_dict[data_source_name] = xr_dataset + + batch_dict["batch_size"] = len(batch_dict["metadata"].example) + + return Batch(**batch_dict) + + +class Example(BaseModel): + """ + Single Data item + + Note that this is currently not really used + """ + + metadata: Optional[MetadataML] + satellite: Optional[SatelliteML] + topographic: Optional[TopographicML] + pv: Optional[PVML] + sun: Optional[SunML] + gsp: Optional[GSPML] + nwp: Optional[NWPML] + datetime: Optional[DatetimeML] @property def data_sources(self): - """ The different data sources """ + """The different data sources""" return [ self.satellite, self.topographic, @@ -58,7 +189,7 @@ def data_sources(self): ] -class Batch(Example): +class BatchML(Example): """ Batch data object. 
@@ -75,45 +206,6 @@ class Batch(Example): "then this item stores one data item", ) - def batch_to_dict_dataset(self) -> Dict[str, xr.Dataset]: - """Change batch to xr.Dataset so it can be saved and compressed""" - return batch_to_dict_dataset(batch=self) - - @staticmethod - def load_batch_from_dict_dataset(xr_dataset: Dict[str, xr.Dataset]): - """Change dictionary of xr.Datatset to Batch object""" - # get a list of data sources - data_sources_names = Example.__fields__.keys() - - # collect data sources - data_sources_dict = {} - for data_source_name in data_sources_names: - cls = Example.__fields__[data_source_name].type_ - data_sources_dict[data_source_name] = cls.from_xr_dataset( - xr_dataset=xr_dataset[data_source_name] - ) - - data_sources_dict["batch_size"] = data_sources_dict["metadata"].batch_size - - return Batch(**data_sources_dict) - - def split(self) -> List[Example]: - """Split batch into list of data items""" - # collect split data - split_data_dict = {} - for data_source in self.data_sources: - if data_source is not None: - cls = data_source.__class__.__name__.lower() - split_data_dict[cls] = data_source.split() - - # make in to Example objects - data_items = [] - for batch_idx in range(self.batch_size): - split_data_one_example_dict = {k: v[batch_idx] for k, v in split_data_dict.items()} - data_items.append(Example(**split_data_one_example_dict)) - - return data_items - @staticmethod def fake(configuration: Configuration = Configuration()): """Create fake batch""" @@ -125,117 +217,56 @@ def fake(configuration: Configuration = Configuration()): seq_length_30_minutes=process.seq_length_30_minutes, ) - return Batch( + return BatchML( batch_size=process.batch_size, - metadata=Metadata.fake(batch_size=process.batch_size, t0_dt=t0_dt), - satellite=Satellite.fake( + metadata=MetadataML.fake(batch_size=process.batch_size, t0_dt=t0_dt), + satellite=SatelliteML.fake( process.batch_size, process.seq_length_5_minutes, process.satellite_image_size_pixels, len(process.sat_channels), time_5=time_5, ), - topographic=Topographic.fake( + topographic=TopographicML.fake( batch_size=process.batch_size, - satellite_image_size_pixels=process.satellite_image_size_pixels, + image_size_pixels=process.satellite_image_size_pixels, ), - pv=PV.fake( + pv=PVML.fake( batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes, n_pv_systems_per_batch=128, time_5=time_5, ), - sun=Sun.fake(batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes), - gsp=GSP.fake( - batch_size=process.batch_size, - seq_length_30=process.seq_length_30_minutes, - n_gsp_per_batch=32, - time_30=time_30, + sun=SunML.fake( + batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes ), - nwp=NWP.fake( + nwp=NWPML.fake( batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes, - nwp_image_size_pixels=process.nwp_image_size_pixels, + image_size_pixels=process.nwp_image_size_pixels, number_nwp_channels=len(process.nwp_channels), time_5=time_5, ), - datetime=Datetime.fake( + datetime=DatetimeML.fake( batch_size=process.batch_size, seq_length_5=process.seq_length_5_minutes ), ) - def save_netcdf(self, batch_i: int, path: Path): - """ - Save batch to netcdf file - - Args: - batch_i: the batch id, used to make the filename - path: the path where it will be saved. This can be local or in the cloud. 
- - """ - batch_xr = self.batch_to_dict_dataset() - - for data_source in self.data_sources: - xr_dataset = batch_xr[data_source.get_name()] - data_source.save_netcdf(batch_i=batch_i, path=path, xr_dataset=xr_dataset) - @staticmethod - def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int): - """Load batch from netcdf file""" + def from_batch(batch: Batch) -> BatchML: + """ Change batch to ML batch """ data_sources_names = Example.__fields__.keys() - # collect data sources - batch_dict = {} + data_sources_dict = {} for data_source_name in data_sources_names: - local_netcdf_filename = os.path.join( - local_netcdf_path, data_source_name, f"{batch_idx}.nc" - ) - xr_dataset = xr.load_dataset(local_netcdf_filename) - - batch_dict[data_source_name] = xr_dataset - - return Batch.load_batch_from_dict_dataset(batch_dict) - - -def batch_to_dict_dataset(batch: Batch) -> Dict[str, xr.Dataset]: - """Concat all the individual fields in an Example into a dictionary of Datasets. - - Args: - batch: List of Example objects, which together constitute a single batch. - """ - individual_datasets = {} - split_batch = batch.split() - - # loop over each data source - for data_source in split_batch[0].data_sources: + data_source = BatchML.__fields__[data_source_name].type_ - datasets = [] - name = data_source.get_name() + xr_dataset = getattr(batch, data_source_name) + if xr_dataset is not None: - # loop over each item in the batch - for i, example in enumerate(split_batch): + data_sources_dict[data_source_name] = data_source.from_xr_dataset(xr_dataset) - if data_source is not None: - datasets.append(getattr(split_batch[i], name).to_xr_dataset(i)) - - # Merge - merged_ds = xr.concat(datasets, dim="example") - individual_datasets[name] = merged_ds - - return individual_datasets - - -def write_batch_locally(batch: Union[Batch, dict], batch_i: int, path: Path): - """ - Write a batch to a locally file - - Args: - batch: A batch of data - batch_i: The number of the batch - path: The directory to write the batch into. 
- """ - if type(batch): - batch = Batch(**batch) + data_sources_dict["batch_size"] = data_sources_dict["satellite"].batch_size - batch.save_netcdf(batch_i=batch_i, path=path) + return BatchML(**data_sources_dict) diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index 3c583476..896227fb 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -1,7 +1,6 @@ """ Data Modules """ import logging import warnings -from copy import deepcopy from dataclasses import dataclass from pathlib import Path from typing import Union, Optional, Iterable, Dict, Callable @@ -14,8 +13,8 @@ from nowcasting_dataset import time as nd_time from nowcasting_dataset import utils from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource -from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource +from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.dataset import datasets from nowcasting_dataset.dataset.split.split import split_data, SplitMethod diff --git a/nowcasting_dataset/dataset/datasets.py b/nowcasting_dataset/dataset/datasets.py index 6bdca87f..f1bff1aa 100644 --- a/nowcasting_dataset/dataset/datasets.py +++ b/nowcasting_dataset/dataset/datasets.py @@ -15,25 +15,16 @@ from nowcasting_dataset import data_sources from nowcasting_dataset import utils as nd_utils -from nowcasting_dataset.filesystem.utils import download_to_local, delete_all_files_in_temp_path from nowcasting_dataset.config.model import Configuration from nowcasting_dataset.consts import ( - GSP_YIELD, - GSP_DATETIME_INDEX, SATELLITE_DATA, - NWP_DATA, - PV_YIELD, - SUN_ELEVATION_ANGLE, - SUN_AZIMUTH_ANGLE, - SATELLITE_DATETIME_INDEX, - NWP_TARGET_TIME, - PV_DATETIME_INDEX, DEFAULT_REQUIRED_KEYS, ) from nowcasting_dataset.data_sources.satellite.satellite_data_source import SAT_VARIABLE_NAMES - -from nowcasting_dataset.utils import set_fsspec_for_multiprocess, to_numpy -from nowcasting_dataset.dataset.batch import Batch +from nowcasting_dataset.dataset.batch import BatchML, Batch +from nowcasting_dataset.dataset.subset import subselect_data +from nowcasting_dataset.filesystem.utils import download_to_local, delete_all_files_in_temp_path +from nowcasting_dataset.utils import set_fsspec_for_multiprocess logger = logging.getLogger(__name__) @@ -168,7 +159,7 @@ def __len__(self): """ Length of dataset """ return self.n_batches - def __getitem__(self, batch_idx: int) -> Batch: + def __getitem__(self, batch_idx: int) -> dict: """Returns a whole batch at once. Args: @@ -184,12 +175,9 @@ def __getitem__(self, batch_idx: int) -> Batch: raise IndexError( "batch_idx must be in the range" f" [0, {self.n_batches}), not {batch_idx}!" 
) - netcdf_filename = nd_utils.get_netcdf_filename(batch_idx) - # remote_netcdf_folder = os.path.join(self.src_path, netcdf_filename) - # local_netcdf_filename = os.path.join(self.tmp_path, netcdf_filename) if self.cloud in ["gcp", "aws"]: - # TODO check this works for mulitple files + # TODO check this works for multiple files download_to_local( remote_filename=self.src_path, local_filename=self.tmp_path, @@ -199,34 +187,33 @@ def __getitem__(self, batch_idx: int) -> Batch: local_netcdf_folder = self.src_path batch = Batch.load_netcdf(local_netcdf_folder, batch_idx=batch_idx) + + if self.select_subset_data: + batch = subselect_data( + batch=batch, + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, + current_timestep_index=self.current_timestep_5_index, + ) + + # change batch into ML learning batch ready for training + batch = BatchML.from_batch(batch=batch) + # netcdf_batch = xr.load_dataset(local_netcdf_filename) if self.cloud != "local": # remove files in a folder, but not the folder itself delete_all_files_in_temp_path(self.src_path) - # batch = example.xr_to_example(batch_xr=netcdf_batch, required_keys=self.required_keys) - - # Todo this may should be done when the data is created + # Todo issue - https://github.com/openclimatefix/nowcasting_dataset/issues/231 if SATELLITE_DATA in self.required_keys: - sat_data = batch.satellite.sat_data + sat_data = batch.satellite.data if sat_data.dtype == np.int16: sat_data = sat_data.astype(np.float32) sat_data = sat_data - SAT_MEAN sat_data = sat_data / SAT_STD - batch.satellite.sat_data = sat_data - - if self.select_subset_data: - batch = subselect_data( - batch=batch, - required_keys=self.required_keys, - history_minutes=self.history_minutes, - forecast_minutes=self.forecast_minutes, - current_timestep_index=self.current_timestep_5_index, - ) - - batch.change_type_to_numpy() + batch.satellite.data = sat_data - return batch + return batch.dict() @dataclass @@ -296,7 +283,7 @@ def __iter__(self): for _ in range(self.n_batches_per_epoch_per_worker): yield self._get_batch() - def _get_batch(self) -> torch.Tensor: + def _get_batch(self) -> Batch: _LOG.debug(f"Getting batch {self.batch_index}") @@ -329,14 +316,11 @@ def _get_batch(self) -> torch.Tensor: # print(type(examples_from_source)) name = type(examples_from_source).__name__.lower() - examples[name] = examples_from_source.dict() + examples[name] = examples_from_source examples["batch_size"] = len(t0_datetimes) - b = Batch(**examples) - - # return as dictionary because .... # TODO - return b.dict() + return Batch(**examples) def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex: # Pick random datetimes. @@ -370,95 +354,3 @@ def worker_init_fn(worker_id): # The NowcastingDataset copy in this worker process. dataset_obj = worker_info.dataset dataset_obj.per_worker_init(worker_info.id) - - -def subselect_data( - batch: Batch, - required_keys: Union[Tuple[str], List[str]], - history_minutes: int, - forecast_minutes: int, - current_timestep_index: Optional[int] = None, -) -> Batch: - """ - Subselects the data temporally. 
This function selects all data within the time range [t0 - history_minutes, t0 + forecast_minutes] - - Args: - batch: Example dictionary containing at least the required_keys - required_keys: The required keys present in the dictionary to use - current_timestep_index: The index into either SATELLITE_DATETIME_INDEX or NWP_TARGET_TIME giving the current timestep - history_minutes: How many minutes of history to use - forecast_minutes: How many minutes of future data to use for forecasting - - Returns: - Example with only data between [t0 - history_minutes, t0 + forecast_minutes] remaining - """ - _LOG.debug( - f"Select sub data with new historic minutes of {history_minutes} " - f"and forecast minutes if {forecast_minutes}" - ) - - # We are subsetting the data, so we need to select the t0_dt, i.e the time now for eahc Example. - # We infact only need this from the first example in each batch - if current_timestep_index is None: - # t0_dt or if not available use a different datetime index - t0_dt_of_first_example = batch.metadata.t0_dt[0].values - else: - if SATELLITE_DATA in required_keys: - t0_dt_of_first_example = batch.satellite.sat_datetime_index[ - 0, current_timestep_index - ].values - else: - t0_dt_of_first_example = batch.satellite.sat_datetime_index[ - 0, current_timestep_index - ].values - - # make this a datetime object - t0_dt_of_first_example = pd.to_datetime(t0_dt_of_first_example) - - if batch.satellite is not None: - batch.satellite.select_time_period( - keys=[SATELLITE_DATA, SATELLITE_DATETIME_INDEX], - history_minutes=history_minutes, - forecast_minutes=forecast_minutes, - t0_dt_of_first_example=t0_dt_of_first_example, - ) - - # Now for NWP, if used - if batch.nwp is not None: - batch.nwp.select_time_period( - keys=[NWP_DATA, NWP_TARGET_TIME], - history_minutes=history_minutes, - forecast_minutes=forecast_minutes, - t0_dt_of_first_example=t0_dt_of_first_example, - ) - # - # Now for GSP, if used - if batch.gsp is not None: - batch.gsp.select_time_period( - keys=[GSP_DATETIME_INDEX, GSP_YIELD], - history_minutes=history_minutes, - forecast_minutes=forecast_minutes, - t0_dt_of_first_example=t0_dt_of_first_example, - ) - - # Now for PV, if used - if batch.pv is not None: - batch.pv.select_time_period( - keys=[PV_DATETIME_INDEX, PV_YIELD], - history_minutes=history_minutes, - forecast_minutes=forecast_minutes, - t0_dt_of_first_example=t0_dt_of_first_example, - ) - - # Now for SUN, if used - if batch.sun is not None: - batch.sun.select_time_period( - keys=[SUN_ELEVATION_ANGLE, SUN_AZIMUTH_ANGLE], - history_minutes=history_minutes, - forecast_minutes=forecast_minutes, - t0_dt_of_first_example=t0_dt_of_first_example, - ) - - # DATETIME TODO - - return batch diff --git a/nowcasting_dataset/dataset/fake.py b/nowcasting_dataset/dataset/fake.py new file mode 100644 index 00000000..ca258d7f --- /dev/null +++ b/nowcasting_dataset/dataset/fake.py @@ -0,0 +1,43 @@ +""" A class to create a fake dataset """ +import torch + +from nowcasting_dataset.config.model import Configuration +from nowcasting_dataset.dataset.batch import BatchML + + +class FakeDataset(torch.utils.data.Dataset): + """Fake dataset.""" + + def __init__(self, configuration: Configuration, length: int = 10): + """ + Init + + Args: + configuration: configuration object + length: length of dataset + """ + self.number_nwp_channels = len(configuration.process.nwp_channels) + self.length = length + self.configuration = configuration + + def __len__(self): + """ Number of pieces of data """ + return self.length + + def 
per_worker_init(self, worker_id: int): + """ Nothing to do for FakeDataset """ + pass + + def __getitem__(self, idx): + """ + Get item, used for the iter and next methods + + Args: + idx: batch index + + Returns: Dictionary of random data + + """ + x = BatchML.fake(configuration=self.configuration) + + return x.dict() diff --git a/nowcasting_dataset/dataset/subset.py b/nowcasting_dataset/dataset/subset.py new file mode 100644 index 00000000..2e70a895 --- /dev/null +++ b/nowcasting_dataset/dataset/subset.py @@ -0,0 +1,145 @@ +""" Take subsets of xr.Datasets """ +import logging +from datetime import datetime +from typing import Optional, Union + +import numpy as np +import pandas as pd + +from nowcasting_dataset.dataset.batch import Batch + +logger = logging.getLogger(__name__) + + +def subselect_data( + batch: Batch, + history_minutes: int, + forecast_minutes: int, + current_timestep_index: Optional[int] = None, +) -> Batch: + """ + Subselects the data temporally. This function selects all data within the time range [t0 - history_minutes, t0 + forecast_minutes] + + Args: + batch: Batch object containing the data sources that should be selected temporally + (satellite, nwp, gsp, pv, sun and metadata) + current_timestep_index: The index into the satellite (or NWP) time dimension giving the current timestep (t0) + history_minutes: How many minutes of history to use + forecast_minutes: How many minutes of future data to use for forecasting + + Returns: + Batch with only data between [t0 - history_minutes, t0 + forecast_minutes] remaining + """ + logger.debug( + f"Select sub data with new historic minutes of {history_minutes} " + f"and forecast minutes of {forecast_minutes}" + ) + + # We are subsetting the data, so we need to select the t0_dt, i.e. the time now for each Example. 
+ # We in fact only need this from the first example in each batch + if current_timestep_index is None: + # t0_dt or if not available use a different datetime index + t0_dt_of_first_example = batch.metadata.t0_dt[0].values + else: + if batch.satellite is not None: + t0_dt_of_first_example = batch.satellite.time[0, current_timestep_index].values + else: + t0_dt_of_first_example = batch.nwp.time[0, current_timestep_index].values + + if batch.satellite is not None: + batch.satellite = select_time_period( + x=batch.satellite, + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, + ) + + # Now for NWP, if used + if batch.nwp is not None: + batch.nwp = select_time_period( + x=batch.nwp, + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, + ) + + # Now for GSP, if used + if batch.gsp is not None: + batch.gsp = select_time_period( + x=batch.gsp, + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, + ) + + # Now for PV, if used + if batch.pv is not None: + batch.pv = select_time_period( + x=batch.pv, + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, + ) + + # Now for SUN, if used + if batch.sun is not None: + batch.sun = select_time_period( + x=batch.sun, + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, + ) + + # DATETIME TODO + + return batch + + +def select_time_period( + x, + history_minutes: int, + forecast_minutes: int, + t0_dt_of_first_example: Union[datetime, pd.Timestamp], +): + """ + Selects the subset of data between the time indices [start, end] derived from t0 and the history/forecast durations + + Note that a new, time-sliced dataset is returned; the input dataset is not modified in place. 
+ + Args: + x: dataset that is ot be reduced + t0_dt_of_first_example: datetime of the current time (t0) in the first example of the batch + history_minutes: How many minutes of history to use + forecast_minutes: How many minutes of future data to use for forecasting + + """ + logger.debug( + f"Taking a sub-selection of the batch data based on a history minutes of {history_minutes} " + f"and forecast minutes of {forecast_minutes}" + ) + + start_time_of_first_example = t0_dt_of_first_example - pd.to_timedelta( + f"{history_minutes} minute 30 second" + ) + end_time_of_first_example = t0_dt_of_first_example + pd.to_timedelta( + f"{forecast_minutes} minute 30 second" + ) + + logger.debug(f"New start time for first example is {start_time_of_first_example}") + logger.debug(f"New end time for first example is {end_time_of_first_example}") + + if hasattr(x, "time"): + + time_of_first_example = pd.to_datetime(x.time[0]) + + else: + # for nwp, maybe reaname + time_of_first_example = pd.to_datetime(x.target_time[0]) + + # find the start and end index, that we will then use to slice the data + start_i, end_i = np.searchsorted( + time_of_first_example, [start_time_of_first_example, end_time_of_first_example] + ) + + # slice all the data + return x.where(((x.time_index >= start_i) & (x.time_index < end_i)), drop=True) diff --git a/nowcasting_dataset/dataset/validate.py b/nowcasting_dataset/dataset/validate.py deleted file mode 100644 index 0e23202e..00000000 --- a/nowcasting_dataset/dataset/validate.py +++ /dev/null @@ -1,137 +0,0 @@ -""" A class to validate the prepare ml dataset """ -from typing import Union - -import numpy as np -import pandas as pd -import torch - -from nowcasting_dataset.config.model import Configuration -from nowcasting_dataset.consts import ( - GSP_DATETIME_INDEX, - DEFAULT_N_GSP_PER_EXAMPLE, - DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, - GSP_ID, - GSP_YIELD, - GSP_X_COORDS, - GSP_Y_COORDS, - OBJECT_AT_CENTER, - PV_SYSTEM_ID, - PV_YIELD, - PV_SYSTEM_X_COORDS, - PV_SYSTEM_Y_COORDS, - PV_SYSTEM_ROW_NUMBER, - SUN_AZIMUTH_ANGLE, - SUN_ELEVATION_ANGLE, - DATETIME_FEATURE_NAMES, - TOPOGRAPHIC_X_COORDS, - TOPOGRAPHIC_DATA, - TOPOGRAPHIC_Y_COORDS, -) -from nowcasting_dataset.dataset.datasets import NetCDFDataset, logger - -# from nowcasting_dataset.dataset.example import Example -from nowcasting_dataset.dataset.batch import Batch - - -class ValidatorDataset: - """ - Validation of a dataset - """ - - def __init__( - self, - batches: Union[NetCDFDataset, torch.utils.data.DataLoader], - configuration: Configuration, - ): - """ - Initialize class and run validation - - Args: - batches: Dataset that needs validating - configuration: Configuration file - """ - self.batches = batches - self.configuration = configuration - - self.validate() - - def validate(self): - """ - This validates the batches, and calculates unique days that are in the all the batches - - """ - logger.debug("Validating dataset") - assert self.configuration is not None - - day_datetimes = None - for batch_idx, batch in enumerate(self.batches): - logger.info(f"Validating batch {batch_idx}") - - # change dict to Batch, this does some validation - if type(batch) == dict: - batch = Batch(**batch) - - all_day_from_batch_unique = self.validate_and_get_day_datetimes_for_one_batch( - batch=batch - ) - if day_datetimes is None: - day_datetimes = all_day_from_batch_unique - else: - day_datetimes = day_datetimes.join(all_day_from_batch_unique) - - self.day_datetimes = day_datetimes - - def validate_and_get_day_datetimes_for_one_batch(self, 
batch: Batch): - """ - For one batch, validate, and return the day datetimes in that batch - - Args: - batch: batch data - - Returns: list of days that the batch has data for - - """ - if type(batch.metadata.t0_dt) == torch.Tensor: - batch.metadata.t0_dt = batch.metadata.t0_dt.detach().numpy() - - all_datetimes_from_batch = pd.to_datetime(batch.metadata.t0_dt.reshape(-1), unit="s") - return pd.DatetimeIndex(all_datetimes_from_batch.date).unique() - - -class FakeDataset(torch.utils.data.Dataset): - """Fake dataset.""" - - def __init__(self, configuration: Configuration, length: int = 10): - """ - Init - - Args: - configuration: configuration object - length: length of dataset - """ - self.number_nwp_channels = len(configuration.process.nwp_channels) - self.length = length - self.configuration = configuration - - def __len__(self): - """ Number of pieces of data """ - return self.length - - def per_worker_init(self, worker_id: int): - """ Not needed """ - pass - - def __getitem__(self, idx): - """ - Get item, use for iter and next method - - Args: - idx: batch index - - Returns: Dictionary of random data - - """ - x = Batch.fake(configuration=self.configuration) - x.change_type_to_numpy() - - return x.dict() diff --git a/nowcasting_dataset/dataset/xr_utils.py b/nowcasting_dataset/dataset/xr_utils.py new file mode 100644 index 00000000..b0ec41d0 --- /dev/null +++ b/nowcasting_dataset/dataset/xr_utils.py @@ -0,0 +1,154 @@ +""" Useful functions for xarray objects + +1. joining data arrays to datasets +2. pydantic exentions model of xr.Dataset +3. xr array and xr dataset --> to torch functions +""" +from typing import List, Any + +import numpy as np +import torch +import xarray as xr + + +def join_list_data_array_to_batch_dataset(image_data_arrays: List[xr.DataArray]) -> xr.Dataset: + """ Join a list of data arrays to a dataset byt expanding dims """ + image_data_arrays = [ + convert_data_array_to_dataset(image_data_arrays[i]) for i in range(len(image_data_arrays)) + ] + + return join_dataset_to_batch_dataset(image_data_arrays) + + +def join_dataset_to_batch_dataset(image_data_arrays: List[xr.Dataset]) -> xr.Dataset: + """ Join a list of data arrays to a dataset byt expanding dims """ + image_data_arrays = [ + image_data_arrays[i].expand_dims(dim="example").assign_coords(example=("example", [i])) + for i in range(len(image_data_arrays)) + ] + + return xr.concat(image_data_arrays, dim="example") + + +def convert_data_array_to_dataset(data_xarray): + """ Convert data array to dataset. Reindex dim so that it can be merged with batch""" + data = xr.Dataset({"data": data_xarray}) + + return make_dim_index(data_xarray_dataset=data) + + +def make_dim_index(data_xarray_dataset: xr.Dataset) -> xr.Dataset: + """ Reindex dataset dims so that it can be merged with batch""" + + dims = data_xarray_dataset.dims + + for dim in dims: + coord = data_xarray_dataset[dim] + data_xarray_dataset[dim] = np.arange(len(coord)) + + data_xarray_dataset = data_xarray_dataset.rename({dim: f"{dim}_index"}) + + data_xarray_dataset[dim] = xr.DataArray( + coord, coords=data_xarray_dataset[f"{dim}_index"].coords, dims=[f"{dim}_index"] + ) + + return data_xarray_dataset + + +class PydanticXArrayDataSet(xr.Dataset): + """Pydantic Xarray Dataset Class + + Adapted from https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__ + + """ + + _expected_dimensions = () # Subclasses should set this. 
+ + # xarray doesnt support sub classing at the moment - https://github.com/pydata/xarray/issues/3980 + __slots__ = () + + @classmethod + def model_validation(cls, v): + """ Specific model validation, to be overwritten by class """ + return v + + @classmethod + def __get_validators__(cls): + """Get validators""" + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Any: + """Do validation""" + v = cls.validate_dims(v) + v = cls.validate_coords(v) + v = cls.model_validation(v) + return v + + @classmethod + def validate_dims(cls, v: Any) -> Any: + """Validate the dims""" + assert all( + dim.replace("_index", "") in cls._expected_dimensions + for dim in v.dims + if dim != "example" + ), ( + f"{cls.__name__}.dims is wrong! " + f"{cls.__name__}.dims is {v.dims}. " + f"But we expected {cls._expected_dimensions}. Note that '_index' is removed, and 'example' is ignored" + ) + return v + + @classmethod + def validate_coords(cls, v: Any) -> Any: + """Validate the coords""" + for dim in cls._expected_dimensions: + coord = v.coords[f"{dim}_index"] + assert len(coord) > 0, f"{dim}_index is empty in {cls.__name__}!" + return v + + +def register_xr_data_array_to_tensor(): + """ Add torch object to data array """ + if not hasattr(xr.DataArray, "torch"): + + @xr.register_dataarray_accessor("torch") + class TorchAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def to_tensor(self): + """Convert this DataArray to a torch.Tensor""" + return torch.tensor(self._obj.data) + + # torch tensor names does not working in dataloader yet - 2021-10-15 + # https://discuss.pytorch.org/t/collating-named-tensors/78650 + # https://github.com/openclimatefix/nowcasting_dataset/issues/25 + # def to_named_tensor(self): + # """Convert this DataArray to a torch.Tensor with named dimensions""" + # import torch + # + # return torch.tensor(self._obj.data, names=self._obj.dims) + + +def register_xr_data_set_to_tensor(): + """ Add torch object to dataset """ + if not hasattr(xr.Dataset, "torch"): + + @xr.register_dataset_accessor("torch") + class TorchAccessor: + def __init__(self, xdataset_obj: xr.Dataset): + self._obj = xdataset_obj + + def to_tensor(self, dims: List[str]) -> dict: + """Convert this Dataset to dictionary of torch tensors""" + torch_dict = {} + + for dim in dims: + v = getattr(self._obj, dim) + if dim.find("time") != -1: + v = v.astype(np.int32) + + torch_dict[dim] = v.torch.to_tensor() + + return torch_dict diff --git a/nowcasting_dataset/filesystem/utils.py b/nowcasting_dataset/filesystem/utils.py index 756586d4..51319825 100644 --- a/nowcasting_dataset/filesystem/utils.py +++ b/nowcasting_dataset/filesystem/utils.py @@ -5,7 +5,6 @@ import fsspec - _LOG = logging.getLogger("nowcasting_dataset") @@ -30,6 +29,11 @@ def get_maximum_batch_id(path: str): """ _LOG.debug(f"Looking for maximum batch id in {path}") + filesystem = fsspec.open(path).fs + if not filesystem.exists(path): + _LOG.debug(f"{path} does not exists") + return None + filenames = get_all_filenames_in_path(path=path) # just take filename @@ -53,17 +57,24 @@ def get_maximum_batch_id(path: str): return maximum_batch_id -def delete_all_files_in_temp_path(path: Union[Path, str]): +def delete_all_files_in_temp_path(path: Union[Path, str], delete_dirs: bool = False): """ - Delete all the files in a temporary path + Delete all the files in a temporary path. 
Option to delete the folders or not """ filesystem = fsspec.open(path).fs filenames = get_all_filenames_in_path(path=path) _LOG.info(f"Deleting {len(filenames)} files from {path}.") - for file in filenames: - filesystem.rm(file, recursive=True) + if delete_dirs: + for file in filenames: + filesystem.rm(file, recursive=True) + else: + # loop over folder structure, but only delete files + for root, dirs, files in filesystem.walk(path): + + for f in files: + filesystem.rm(f"{root}/{f}") def check_path_exists(path: Union[str, Path]): @@ -132,4 +143,5 @@ def upload_one_file( def make_folder(path: Union[str, Path]): """ Make folder """ filesystem = fsspec.open(path).fs - filesystem.mkdir(path) + if not filesystem.exists(path): + filesystem.mkdir(path) diff --git a/nowcasting_dataset/time.py b/nowcasting_dataset/time.py index fb04c51c..a60d1f20 100644 --- a/nowcasting_dataset/time.py +++ b/nowcasting_dataset/time.py @@ -1,15 +1,15 @@ """ Time functions """ import logging +import random import warnings from typing import Iterable, Tuple, List, Dict import numpy as np import pandas as pd +import xarray as xr import pvlib -import random from nowcasting_dataset import geospatial, utils -from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime logger = logging.getLogger(__name__) @@ -267,7 +267,7 @@ def datetime_features(index: pd.DatetimeIndex) -> pd.DataFrame: return pd.DataFrame(features, index=index).astype(np.float32) -def datetime_features_in_example(index: pd.DatetimeIndex) -> Datetime: +def datetime_features_in_example(index: pd.DatetimeIndex) -> xr.Dataset: """ Make datetime features with sin and cos @@ -282,13 +282,7 @@ def datetime_features_in_example(index: pd.DatetimeIndex) -> Datetime: dt_features["day_of_year"] /= 365 dt_features = utils.sin_and_cos(dt_features) - datetime_dict = {} - for col_name, series in dt_features.iteritems(): - datetime_dict[col_name] = series.values - - datetime_dict["datetime_index"] = series.index.values - - return Datetime(**datetime_dict) + return dt_features.to_xarray() def make_random_time_vectors(batch_size, seq_length_5_minutes, seq_length_30_minutes): @@ -324,8 +318,8 @@ def make_random_time_vectors(batch_size, seq_length_5_minutes, seq_length_30_min - int(seq_length_30_minutes / 2) * delta_5 ) - t0_dt = utils.to_numpy(t0_dt) - time_5 = utils.to_numpy(time_5.T) - time_30 = utils.to_numpy(time_30.T) + t0_dt = utils.to_numpy(t0_dt).astype(np.int32) + time_5 = utils.to_numpy(time_5.T).astype(np.int32) + time_30 = utils.to_numpy(time_30.T).astype(np.int32) return t0_dt, time_5, time_30 diff --git a/nowcasting_dataset/utils.py b/nowcasting_dataset/utils.py index 57ea8aa4..5ba4b258 100644 --- a/nowcasting_dataset/utils.py +++ b/nowcasting_dataset/utils.py @@ -1,16 +1,16 @@ """ utils functions """ import hashlib import logging +import tempfile from pathlib import Path -from typing import List, Optional +from typing import Optional import fsspec.asyn +import gcsfs import numpy as np import pandas as pd import torch import xarray as xr -import tempfile -import gcsfs from nowcasting_dataset.consts import Array @@ -109,12 +109,6 @@ def get_netcdf_filename(batch_idx: int, add_hash: bool = False) -> Path: return filename -def pad_nans(array, pad_width) -> np.ndarray: - """ Pad nans with nans""" - array = array.astype(np.float32) - return np.pad(array, pad_width, constant_values=np.NaN) - - def to_numpy(value): """ Change generic data to numpy""" if isinstance(value, xr.DataArray): diff --git a/scripts/generate_data_for_tests/get_test_data.py 
b/scripts/generate_data_for_tests/get_test_data.py index 3bd1d86c..7b81f671 100644 --- a/scripts/generate_data_for_tests/get_test_data.py +++ b/scripts/generate_data_for_tests/get_test_data.py @@ -144,6 +144,7 @@ ######## c = Configuration() +c.process.batch_size = 4 c.process.nwp_channels = c.process.nwp_channels[0:1] c.process.sat_channels = c.process.sat_channels[0:1] diff --git a/scripts/prepare_ml_data.py b/scripts/prepare_ml_data.py index c9082e01..06883b03 100755 --- a/scripts/prepare_ml_data.py +++ b/scripts/prepare_ml_data.py @@ -21,9 +21,11 @@ from nowcasting_dataset.filesystem.utils import check_path_exists from nowcasting_dataset.dataset.datamodule import NowcastingDataModule -from nowcasting_dataset.dataset.batch import write_batch_locally + +# from nowcasting_dataset.dataset.batch import write_batch_locally from nowcasting_dataset.data_sources.satellite.satellite_data_source import SAT_VARIABLE_NAMES from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWP_VARIABLE_NAMES +from nowcasting_dataset.dataset.batch import Batch from pathy import Pathy from pathlib import Path import fsspec @@ -109,9 +111,11 @@ def get_data_module(): num_workers = 4 # get the batch id already made - maximum_batch_id_train = utils.get_maximum_batch_id(DST_TRAIN_PATH) - maximum_batch_id_validation = utils.get_maximum_batch_id(DST_VALIDATION_PATH) - maximum_batch_id_test = utils.get_maximum_batch_id(DST_TEST_PATH) + maximum_batch_id_train = utils.get_maximum_batch_id(os.path.join(DST_TRAIN_PATH, "metadata")) + maximum_batch_id_validation = utils.get_maximum_batch_id( + os.path.join(DST_VALIDATION_PATH, "metadata") + ) + maximum_batch_id_test = utils.get_maximum_batch_id(os.path.join(DST_TEST_PATH, "metadata")) if maximum_batch_id_train is None: maximum_batch_id_train = 0 @@ -136,6 +140,7 @@ def get_data_module(): nwp_base_path=NWP_ZARR_PATH, gsp_filename=GSP_ZARR_PATH, topographic_filename=TOPO_TIFF_PATH, + sun_filename=config.input_data.sun_zarr_path, pin_memory=False, #: Passed to DataLoader. num_workers=num_workers, #: Passed to DataLoader. prefetch_factor=8, #: Passed to DataLoader. @@ -169,14 +174,15 @@ def iterate_over_dataloader_and_write_to_disk( for batch_i, batch in enumerate(dataloader): _LOG.info(f"Got batch {batch_i}") - if len(batch) > 0: - write_batch_locally(batch, batch_i, local_output_path) + + batch.save_netcdf(batch_i=batch_i, path=local_output_path) + if UPLOAD_EVERY_N_BATCHES > 0 and batch_i > 0 and batch_i % UPLOAD_EVERY_N_BATCHES == 0: - utils.upload_and_delete_local_files(dst_path, LOCAL_TEMP_PATH, cloud=CLOUD) + utils.upload_and_delete_local_files(dst_path, LOCAL_TEMP_PATH) # Make sure we upload the last few batches, if necessary. 
if UPLOAD_EVERY_N_BATCHES > 0: - utils.upload_and_delete_local_files(dst_path, LOCAL_TEMP_PATH, cloud=CLOUD) + utils.upload_and_delete_local_files(dst_path, LOCAL_TEMP_PATH) def main(): diff --git a/tests/data/0.nc b/tests/data/0.nc deleted file mode 100644 index b1bdcd18..00000000 Binary files a/tests/data/0.nc and /dev/null differ diff --git a/tests/data/batch/datetime/0.nc b/tests/data/batch/datetime/0.nc index d4fe9154..ec37ee13 100644 Binary files a/tests/data/batch/datetime/0.nc and b/tests/data/batch/datetime/0.nc differ diff --git a/tests/data/batch/gsp/0.nc b/tests/data/batch/gsp/0.nc index 295d0bb7..e26b131b 100644 Binary files a/tests/data/batch/gsp/0.nc and b/tests/data/batch/gsp/0.nc differ diff --git a/tests/data/batch/metadata/0.nc b/tests/data/batch/metadata/0.nc index a89328a0..dde5a8de 100644 Binary files a/tests/data/batch/metadata/0.nc and b/tests/data/batch/metadata/0.nc differ diff --git a/tests/data/batch/nwp/0.nc b/tests/data/batch/nwp/0.nc index 2b8a58b4..659b2c57 100644 Binary files a/tests/data/batch/nwp/0.nc and b/tests/data/batch/nwp/0.nc differ diff --git a/tests/data/batch/pv/0.nc b/tests/data/batch/pv/0.nc index 2a942331..6ba7b5a5 100644 Binary files a/tests/data/batch/pv/0.nc and b/tests/data/batch/pv/0.nc differ diff --git a/tests/data/batch/satellite/0.nc b/tests/data/batch/satellite/0.nc index 181aa098..a232771a 100644 Binary files a/tests/data/batch/satellite/0.nc and b/tests/data/batch/satellite/0.nc differ diff --git a/tests/data/batch/sun/0.nc b/tests/data/batch/sun/0.nc index 397b737a..83e6f1de 100644 Binary files a/tests/data/batch/sun/0.nc and b/tests/data/batch/sun/0.nc differ diff --git a/tests/data/batch/topographic/0.nc b/tests/data/batch/topographic/0.nc index 2487645f..8af4bbd5 100644 Binary files a/tests/data/batch/topographic/0.nc and b/tests/data/batch/topographic/0.nc differ diff --git a/tests/data_sources/gsp/test_gsp_data_source.py b/tests/data_sources/gsp/test_gsp_data_source.py index 86b5be11..1fc7bbdd 100644 --- a/tests/data_sources/gsp/test_gsp_data_source.py +++ b/tests/data_sources/gsp/test_gsp_data_source.py @@ -69,9 +69,9 @@ def test_gsp_pv_data_source_get_example(): t0_dt=gsp.gsp_power.index[0], x_meters_center=x_locations[0], y_meters_center=y_locations[0] ) - assert len(l.gsp_id) == len(l.gsp_yield[0]) - assert len(l.gsp_x_coords) == len(l.gsp_y_coords) - assert len(l.gsp_x_coords) > 0 + assert len(l.id) == len(l.data[0]) + assert len(l.x_coords) == len(l.y_coords) + assert len(l.x_coords) > 0 # assert type(l[T0_DT]) == pd.Timestamp @@ -102,9 +102,10 @@ def test_gsp_pv_data_source_get_batch(): y_locations=y_locations[0:batch_size], ) - assert batch.batch_size == batch_size - assert len(batch.gsp_yield[0]) == 4 - assert len(batch.gsp_id[0]) == len(batch.gsp_x_coords[0]) - assert len(batch.gsp_x_coords[1]) == len(batch.gsp_y_coords[1]) - assert len(batch.gsp_x_coords[2]) > 0 + print(batch.data[0]) + + assert len(batch.data[0]) == 4 + assert len(batch.id[0]) == len(batch.x_coords[0]) + assert len(batch.x_coords[1]) == len(batch.y_coords[1]) + assert len(batch.x_coords[2]) > 0 # assert T0_DT in batch[3].keys() diff --git a/tests/data_sources/test_satellite_data_source.py b/tests/data_sources/satellite/test_satellite_data_source.py similarity index 90% rename from tests/data_sources/test_satellite_data_source.py rename to tests/data_sources/satellite/test_satellite_data_source.py index 12b817a7..86e20798 100644 --- a/tests/data_sources/test_satellite_data_source.py +++ b/tests/data_sources/satellite/test_satellite_data_source.py 
@@ -17,7 +17,7 @@ def test_datetime_index(sat_data_source): assert isinstance(datetimes, pd.DatetimeIndex) assert len(datetimes) > 0 assert len(np.unique(datetimes)) == len(datetimes) - assert np.all(np.diff(datetimes.astype(int)) > 0) + assert np.all(np.diff(datetimes.view(int)) > 0) @pytest.mark.parametrize( @@ -37,8 +37,8 @@ def test_datetime_index(sat_data_source): def test_get_example(sat_data_source, x, y, left, right, top, bottom): sat_data_source.open() t0_dt = pd.Timestamp("2019-01-01T13:00") - example = sat_data_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) - sat_data = example.sat_data + sat_data = sat_data_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) + assert left == sat_data.x.values[0] assert right == sat_data.x.values[-1] # sat_data.y is top-to-bottom. diff --git a/tests/data_sources/satellite/test_satellite_model.py b/tests/data_sources/satellite/test_satellite_model.py new file mode 100644 index 00000000..b02ead5f --- /dev/null +++ b/tests/data_sources/satellite/test_satellite_model.py @@ -0,0 +1,28 @@ +import os +import tempfile + +from nowcasting_dataset.data_sources.satellite.satellite_model import SatelliteML + +from nowcasting_dataset.data_sources.fake import satellite_fake + + +def test_satellite_init(): + _ = satellite_fake + + +def test_satellite_save(): + + with tempfile.TemporaryDirectory() as dirpath: + satellite_fake().save_netcdf(path=dirpath, batch_i=0) + + assert os.path.exists(f"{dirpath}/satellite/0.nc") + + +def test_satellite_to_ml(): + sat = satellite_fake() + + _ = SatelliteML.from_xr_dataset(sat) + + +def test_satellite_ml_fake(): + _ = satellite_fake diff --git a/tests/data_sources/sun/test_sun_data_source.py b/tests/data_sources/sun/test_sun_data_source.py index 8902548f..af73a40c 100644 --- a/tests/data_sources/sun/test_sun_data_source.py +++ b/tests/data_sources/sun/test_sun_data_source.py @@ -27,8 +27,8 @@ def test_get_example(test_data_folder): example = sun_data_source.get_example(t0_dt=start_dt, x_meters_center=x, y_meters_center=y) - assert len(example.sun_elevation_angle) == 19 - assert len(example.sun_azimuth_angle) == 19 + assert len(example.elevation) == 19 + assert len(example.azimuth) == 19 def test_get_example_different_year(test_data_folder): @@ -44,5 +44,5 @@ def test_get_example_different_year(test_data_folder): example = sun_data_source.get_example(t0_dt=start_dt, x_meters_center=x, y_meters_center=y) - assert len(example.sun_elevation_angle) == 19 - assert len(example.sun_azimuth_angle) == 19 + assert len(example.elevation) == 19 + assert len(example.azimuth) == 19 diff --git a/tests/data_sources/test_datasource_output.py b/tests/data_sources/test_datasource_output.py index 934d3222..66db06fb 100644 --- a/tests/data_sources/test_datasource_output.py +++ b/tests/data_sources/test_datasource_output.py @@ -1,15 +1,17 @@ -from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime -from nowcasting_dataset.data_sources.gsp.gsp_model import GSP -from nowcasting_dataset.data_sources.pv.pv_model import PV -from nowcasting_dataset.data_sources.nwp.nwp_model import NWP -from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite -from nowcasting_dataset.data_sources.sun.sun_model import Sun -from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic +from nowcasting_dataset.data_sources.fake import ( + sun_fake, + topographic_fake, + gsp_fake, + datetime_fake, + nwp_fake, + satellite_fake, + pv_fake, +) def test_datetime(): - s = 
Datetime.fake( + s = datetime_fake( batch_size=4, seq_length_5=13, ) @@ -17,78 +19,47 @@ def test_datetime(): def test_gsp(): - s = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=32) + s = gsp_fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=32) - -def test_gsp_pad(): - - s = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=7).split()[0] - s.to_numpy() - s.pad(n_gsp_per_example=32) - - assert s.gsp_yield.shape == (13, 32) - - -def test_gsp_split(): - - s = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=32) - split = s.split() - - assert len(split) == 4 - assert type(split[0]) == GSP - assert (split[0].gsp_yield == s.gsp_yield[0]).all() - - -def test_gsp_join(): - - s = GSP.fake(batch_size=2, seq_length_30=13, n_gsp_per_batch=32).split() - - s: GSP = GSP.create_batch_from_examples(s) - - assert s.batch_size == 2 - assert len(s.gsp_yield.shape) == 3 - assert s.gsp_yield.shape[0] == 2 - assert s.gsp_yield.shape[1] == 13 - assert s.gsp_yield.shape[2] == 32 + assert s.data.shape == (4, 13, 32) def test_nwp(): - s = NWP.fake(batch_size=4, seq_length_5=13, nwp_image_size_pixels=64, number_nwp_channels=8) - - -def test_nwp_split(): - - s = NWP.fake(batch_size=4, seq_length_5=13, nwp_image_size_pixels=64, number_nwp_channels=8) - s = s.split() + s = nwp_fake( + batch_size=4, + seq_length_5=13, + image_size_pixels=64, + number_nwp_channels=8, + ) def test_pv(): - s = PV.fake(batch_size=4, seq_length_5=13, n_pv_systems_per_batch=128) - - -def test_nwp_pad(): + s = pv_fake(batch_size=4, seq_length_5=13, n_pv_systems_per_batch=128) - s = PV.fake(batch_size=4, seq_length_5=13, n_pv_systems_per_batch=37).split()[0] - s.to_numpy() - s.pad(n_pv_systems_per_example=128) - assert s.pv_yield.shape == (13, 128) +# def test_nwp_pad(): +# +# s = PV.fake(batch_size=4, seq_length_5=13, n_pv_systems_per_batch=37).split()[0] +# s.to_numpy() +# s.pad(n_pv_systems_per_example=128) +# +# assert s.pv_yield.shape == (13, 128) def test_satellite(): - s = Satellite.fake( + s = satellite_fake( batch_size=4, seq_length_5=13, satellite_image_size_pixels=64, number_sat_channels=7 ) - assert s.sat_x_coords is not None + assert s.x is not None def test_sun(): - s = Sun.fake( + s = sun_fake( batch_size=4, seq_length_5=13, ) @@ -96,7 +67,7 @@ def test_sun(): def test_topo(): - s = Topographic.fake( + s = topographic_fake( batch_size=4, - satellite_image_size_pixels=64, + image_size_pixels=64, ) diff --git a/tests/data_sources/test_datetime.py b/tests/data_sources/test_datetime.py new file mode 100644 index 00000000..3156299a --- /dev/null +++ b/tests/data_sources/test_datetime.py @@ -0,0 +1,23 @@ +import pandas as pd + +from nowcasting_dataset.data_sources.datetime.datetime_data_source import DatetimeDataSource + + +def test_datetime_source(): + datetime_source = DatetimeDataSource( + convert_to_numpy=True, + forecast_minutes=300, + history_minutes=10, + ) + t0_dt = pd.Timestamp("2019-01-01T13:00") + _ = datetime_source.get_example(t0_dt=t0_dt, x_meters_center=0, y_meters_center=0) + + +def test_datetime_source_batch(): + datetime_source = DatetimeDataSource( + convert_to_numpy=True, + forecast_minutes=300, + history_minutes=10, + ) + t0_dt = pd.Timestamp("2019-01-01T13:00") + _ = datetime_source.get_batch(t0_datetimes=[t0_dt], x_locations=[0], y_locations=[0]) diff --git a/tests/data_sources/test_nwp_data_source.py b/tests/data_sources/test_nwp_data_source.py index e698c492..81c8c781 100644 --- a/tests/data_sources/test_nwp_data_source.py +++ b/tests/data_sources/test_nwp_data_source.py @@ -63,4 
+63,4 @@ def test_nwp_data_source_batch(): batch = nwp.get_batch(t0_datetimes=t0_datetimes, x_locations=x, y_locations=y) - assert batch.batch_size == 4 + assert batch.data.shape == (4, 1, 19, 2, 2) diff --git a/tests/data_sources/test_pv_data_source.py b/tests/data_sources/test_pv_data_source.py index 93dbc41c..3d863b47 100644 --- a/tests/data_sources/test_pv_data_source.py +++ b/tests/data_sources/test_pv_data_source.py @@ -44,7 +44,7 @@ def test_get_example_and_batch(): batch = pv_data_source.get_batch( pv_data_source.pv_power.index[6:11], x_locations[0:10], y_locations[0:10] ) - assert batch.batch_size == 5 + assert batch.data.shape == (5, 19, 128) def test_drop_pv_systems_which_produce_overnight(): diff --git a/tests/data_sources/test_topographic_data_source.py b/tests/data_sources/test_topographic_data_source.py index 168c42e3..e3f14bfb 100644 --- a/tests/data_sources/test_topographic_data_source.py +++ b/tests/data_sources/test_topographic_data_source.py @@ -31,20 +31,19 @@ def test_get_example_2km(x, y, left, right, top, bottom): history_minutes=10, ) t0_dt = pd.Timestamp("2019-01-01T13:00") - example = topo_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) - topo_data = example.topo_data - assert topo_data.shape == (128, 128) + topo_data = topo_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) + assert topo_data.data.shape == (128, 128) assert len(topo_data.x) == 128 assert len(topo_data.y) == 128 - assert not np.isnan(topo_data).any() + assert not np.isnan(topo_data.data).any() # Topo x and y coords are not exactly set on the edges, but the center of the pixels assert np.isclose(left, topo_data.x.values[0], atol=size) assert np.isclose(right, topo_data.x.values[-1], atol=size) assert np.isclose(top, topo_data.y.values[0], atol=size) assert np.isclose(bottom, topo_data.y.values[-1], atol=size) # Check normalization works - assert np.max(topo_data) <= 1.0 - assert np.min(topo_data) >= -1.0 + assert np.max(topo_data.data) <= 1.0 + assert np.min(topo_data.data) >= -1.0 @pytest.mark.skip("CD does not have access to GCS") diff --git a/tests/dataset/test_batch.py b/tests/dataset/test_batch.py index 75a3efb8..3fbb4685 100644 --- a/tests/dataset/test_batch.py +++ b/tests/dataset/test_batch.py @@ -1,14 +1,11 @@ -from nowcasting_dataset.data_sources.gsp.gsp_model import GSP -import numpy as np import tempfile -from pathlib import Path - -from nowcasting_dataset.dataset.batch import Batch, GSP -from nowcasting_dataset.dataset.validate import FakeDataset +import os import torch + +from nowcasting_dataset.dataset.batch import BatchML, Batch from nowcasting_dataset.config.model import Configuration -import xarray as xr +from nowcasting_dataset.dataset.fake import FakeDataset def test_model(): @@ -16,65 +13,27 @@ def test_model(): _ = Batch.fake() -def test_model_to_numpy(): - - f = Batch.fake() - - f.change_type_to_numpy() - - assert type(f.gsp) == GSP - - -def test_model_split(): - - f = Batch.fake() - - data = f.split() - - assert len(data) == f.batch_size - assert type(data[0].gsp) == GSP +def test_model_save_to_netcdf(): + with tempfile.TemporaryDirectory() as dirpath: + Batch.fake().save_netcdf(path=dirpath, batch_i=0) -def test_model_to_xr_dataset(configuration): - - f = Batch.fake(configuration=configuration) - f_xr = f.batch_to_dict_dataset() - - assert type(f_xr) == dict - assert type(f_xr["metadata"]) == xr.Dataset - - -def test_model_from_xr_dataset(): - - f = Batch.fake() - - f_xr = f.batch_to_dict_dataset() - - _ = 
Batch.load_batch_from_dict_dataset(xr_dataset=f_xr) + assert os.path.exists(f"{dirpath}/satellite/0.nc") -def test_model_save_to_netcdf(test_data_folder): +def test_model_load_from_netcdf(): with tempfile.TemporaryDirectory() as dirpath: Batch.fake().save_netcdf(path=dirpath, batch_i=0) + batch = Batch.load_netcdf(batch_idx=0, local_netcdf_path=dirpath) -def test_model_from_test_data(test_data_folder): - x = Batch.load_netcdf(local_netcdf_path=f"{test_data_folder}/batch", batch_idx=0) - + assert batch.satellite is not None -def test_model_from_xr_dataset_to_numpy(): - f = Batch.fake() +def test_batch_to_batch_ml(): - f_xr = f.batch_to_dict_dataset() - fs = Batch.load_batch_from_dict_dataset(xr_dataset=f_xr) - # check they are the same - fs.change_type_to_numpy() - f.gsp.to_numpy() - assert f.gsp.gsp_yield.shape == fs.gsp.gsp_yield.shape - assert (f.gsp.gsp_yield[0].astype(np.float32) == fs.gsp.gsp_yield[0]).all() - assert (f.gsp.gsp_yield.astype(np.float32) == fs.gsp.gsp_yield).all() + _ = BatchML.from_batch(batch=Batch.fake()) def test_fake_dataset(): @@ -82,6 +41,6 @@ def test_fake_dataset(): i = iter(train) x = next(i) - x = Batch(**x) + x = BatchML(**x) # IT WORKS - assert type(x.satellite.sat_data) == torch.Tensor + assert type(x.satellite.data) == torch.Tensor diff --git a/tests/dataset/test_subselect.py b/tests/dataset/test_subselect.py new file mode 100644 index 00000000..8efbfb8e --- /dev/null +++ b/tests/dataset/test_subselect.py @@ -0,0 +1,64 @@ +import os +import tempfile +from pathlib import Path + +import pandas as pd +import plotly +import plotly.graph_objects as go +import pytest +import torch +import xarray as xr + +import nowcasting_dataset +import nowcasting_dataset.dataset.batch +from nowcasting_dataset.config.model import Configuration +from nowcasting_dataset.consts import ( + SATELLITE_X_COORDS, + SATELLITE_Y_COORDS, + SATELLITE_DATA, + NWP_DATA, + SATELLITE_DATETIME_INDEX, + NWP_TARGET_TIME, + NWP_Y_COORDS, + NWP_X_COORDS, + PV_YIELD, + GSP_YIELD, + GSP_DATETIME_INDEX, + T0_DT, +) + +# from nowcasting_dataset.dataset import example +from nowcasting_dataset.dataset.batch import Batch +from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn +from nowcasting_dataset.dataset.subset import subselect_data + + +def test_subselect_date(test_data_folder): + + x = Batch.fake() + + batch = subselect_data( + x, + current_timestep_index=7, + history_minutes=10, + forecast_minutes=10, + ) + + assert batch.satellite.data.shape == (32, 5, 64, 64, 12) + assert batch.nwp.data.shape == (32, 5, 64, 64, 10) + + +# +def test_subselect_date_with_to_dt(test_data_folder): + + # x = Batch.load_netcdf(f"{test_data_folder}/0.nc") + x = Batch.fake() + + batch = subselect_data( + x, + history_minutes=10, + forecast_minutes=10, + ) + + assert batch.satellite.data.shape == (32, 5, 64, 64, 12) + assert batch.nwp.data.shape == (32, 5, 64, 64, 10) diff --git a/tests/dataset/test_validate.py b/tests/dataset/test_validate.py deleted file mode 100644 index 275c725e..00000000 --- a/tests/dataset/test_validate.py +++ /dev/null @@ -1,37 +0,0 @@ -import os - -import torch - -import nowcasting_dataset -from nowcasting_dataset.config.load import load_yaml_configuration -from nowcasting_dataset.dataset.datasets import worker_init_fn -from nowcasting_dataset.dataset.validate import FakeDataset, ValidatorDataset - - -def test_validate(): - - local_path = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "../") - - # load configuration, this can be changed to a different filename as 
needed - filename = os.path.join(local_path, "tests", "config", "test.yaml") - config = load_yaml_configuration(filename) - - train_dataset = FakeDataset( - configuration=config, - length=10, - ) - - dataloader_config = dict( - pin_memory=True, - num_workers=1, - prefetch_factor=1, - worker_init_fn=worker_init_fn, - persistent_workers=True, - # Disable automatic batching because dataset - # returns complete batches. - batch_size=None, - ) - - train_dataset = torch.utils.data.DataLoader(train_dataset, **dataloader_config) - - ValidatorDataset(configuration=config, batches=train_dataset) diff --git a/tests/filesystem/test_local.py b/tests/filesystem/test_local.py index 2a42179c..13ec847e 100644 --- a/tests/filesystem/test_local.py +++ b/tests/filesystem/test_local.py @@ -60,6 +60,7 @@ def test_make_folder(): def test_delete_local_files(): file1 = "test_file1.txt" + folder1 = "test_dir" file2 = "test_dir/test_file2.txt" with tempfile.TemporaryDirectory() as tmpdirname: @@ -71,7 +72,8 @@ def test_delete_local_files(): pass # add fake file to dir - os.mkdir(f"{tmpdirname}/test_dir") + path_and_folder_1 = os.path.join(local_path, folder1) + os.mkdir(path_and_folder_1) path_and_filename_2 = os.path.join(local_path, file2) with open(os.path.join(local_path, file2), "w"): pass @@ -82,6 +84,37 @@ def test_delete_local_files(): # check the object are not there assert not os.path.exists(path_and_filename_1) assert not os.path.exists(path_and_filename_2) + assert os.path.exists(path_and_folder_1) + + +def test_delete_local_files_and_folder(): + + file1 = "test_file1.txt" + folder1 = "test_dir" + file2 = "test_dir/test_file2.txt" + + with tempfile.TemporaryDirectory() as tmpdirname: + local_path = Path(tmpdirname) + + # add fake file to dir + path_and_filename_1 = os.path.join(local_path, file1) + with open(path_and_filename_1, "w"): + pass + + # add fake file to dir + path_and_folder_1 = os.path.join(local_path, folder1) + os.mkdir(path_and_folder_1) + path_and_filename_2 = os.path.join(local_path, file2) + with open(os.path.join(local_path, file2), "w"): + pass + + # run function + delete_all_files_in_temp_path(path=local_path, delete_dirs=True) + + # check the object are not there + assert not os.path.exists(path_and_filename_1) + assert not os.path.exists(path_and_filename_2) + assert not os.path.exists(path_and_folder_1) def test_download(): diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 30b28026..e7a1e883 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -116,9 +116,17 @@ def test_data_module(config_filename): data_generator = iter(data_module.train_dataset) batch = next(data_generator) - assert batch["batch_size"] == config.process.batch_size - - _ = Batch(**batch) + assert batch.batch_size == config.process.batch_size + assert type(batch) == Batch + + assert batch.satellite is not None + assert batch.nwp is not None + assert batch.sun is not None + assert batch.topographic is not None + assert batch.pv is not None + assert batch.gsp is not None + assert batch.metadata is not None + assert batch.datetime is not None # for key in list(Example.__annotations__.keys()): # assert key in batch[0].keys() @@ -182,4 +190,4 @@ def test_batch_to_batch_to_dataset(): data_generator = iter(data_module.train_dataset) batch = next(data_generator) - _ = Batch(**batch) + assert type(batch) == Batch diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c7183582..17af6c29 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -5,6 +5,7 @@ import 
nowcasting_dataset.time as nd_time from nowcasting_dataset.consts import GSP_DATETIME_INDEX from nowcasting_dataset.dataset.datasets import NowcastingDataset +from nowcasting_dataset.dataset.batch import Batch @pytest.fixture @@ -55,10 +56,10 @@ def test_per_worker_init(dataset: NowcastingDataset): def test_get_batch(dataset: NowcastingDataset): dataset.per_worker_init(worker_id=1) - example = dataset._get_batch() - assert isinstance(example, dict) - assert "satellite" in example - assert example["satellite"]["sat_data"].shape == ( + batch = dataset._get_batch() + assert isinstance(batch, Batch) + assert batch.satellite is not None + assert batch.satellite.data.shape == ( 8, 2, pytest.IMAGE_SIZE_PIXELS, @@ -69,7 +70,7 @@ def test_get_batch(dataset: NowcastingDataset): def test_get_batch_gsp(dataset_gsp: NowcastingDataset): dataset_gsp.per_worker_init(worker_id=1) - example = dataset_gsp._get_batch() - assert isinstance(example, dict) + batch = dataset_gsp._get_batch() + assert isinstance(batch, Batch) - assert "gsp" in example.keys() + assert batch.gsp is not None diff --git a/tests/test_netcdf_dataset.py b/tests/test_netcdf_dataset.py index 7cbd9bb0..cdadc8d0 100644 --- a/tests/test_netcdf_dataset.py +++ b/tests/test_netcdf_dataset.py @@ -28,45 +28,8 @@ ) # from nowcasting_dataset.dataset import example -from nowcasting_dataset.dataset.batch import Batch -from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn, subselect_data - - -def test_subselect_date(test_data_folder): - - # x = Batch.load_netcdf(f"{test_data_folder}/0.nc") - x = Batch.fake() - x = x.batch_to_dict_dataset() - x = Batch.load_batch_from_dict_dataset(x) - - batch = subselect_data( - x, - required_keys=(NWP_DATA, NWP_TARGET_TIME, SATELLITE_DATA, SATELLITE_DATETIME_INDEX), - current_timestep_index=7, - history_minutes=10, - forecast_minutes=10, - ) - - assert batch.satellite.sat_data.shape[1] == 5 - assert batch.nwp.nwp.shape[2] == 5 - - -def test_subselect_date_with_to_dt(test_data_folder): - - # x = Batch.load_netcdf(f"{test_data_folder}/0.nc") - x = Batch.fake() - x = x.batch_to_dict_dataset() - x = Batch.load_batch_from_dict_dataset(x) - - batch = subselect_data( - x, - required_keys=(NWP_DATA, NWP_TARGET_TIME, SATELLITE_DATA, SATELLITE_DATETIME_INDEX), - history_minutes=10, - forecast_minutes=10, - ) - - assert batch.satellite.sat_data.shape[1] == 5 - assert batch.nwp.nwp.shape[2] == 5 +from nowcasting_dataset.dataset.batch import BatchML +from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn def test_netcdf_dataset_local_using_configuration(configuration: Configuration): @@ -105,12 +68,14 @@ def test_netcdf_dataset_local_using_configuration(configuration: Configuration): t = iter(train_dataset) data = next(t) - sat_data = data.satellite.sat_data + batch_ml = BatchML(**data) + + sat_data = batch_ml.satellite.data # TODO # Sat is in 5min increments, so should have 2 history + current + 2 future assert sat_data.shape[1] == 5 - assert data.nwp.nwp.shape[2] == 5 + assert batch_ml.nwp.data.shape == (4, 5, 64, 64, 1) # Make sure file isn't deleted! 
     assert os.path.exists(os.path.join(DATA_PATH, "metadata/0.nc"))
@@ -143,7 +108,7 @@ def test_get_dataloaders_gcp(configuration: Configuration):
     train_dataset.per_worker_init(1)
     t = iter(train_dataset)
 
-    data: Batch = next(t)
+    data: BatchML = next(t)
 
     # image
     z = data.satellite.sat_data[0][0][:, :, 0]
diff --git a/tests/test_time.py b/tests/test_time.py
index 60199fa3..dcc3b09d 100644
--- a/tests/test_time.py
+++ b/tests/test_time.py
@@ -98,8 +98,8 @@ def test_datetime_features_in_example():
     assert len(example.hour_of_day_sin) == len(index)
     for col_name in ["hour_of_day_sin", "hour_of_day_cos"]:
         np.testing.assert_array_almost_equal(
-            example.__getattribute__(col_name),
-            np.tile(example.__getattribute__(col_name)[:24], reps=6),
+            getattr(example, col_name),
+            np.tile(getattr(example, col_name)[:24], reps=6),
         )
 
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 2b8bf5cb..994d49da 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,8 +4,6 @@
 from nowcasting_dataset import utils
 
-# from nowcasting_dataset.dataset.example import Example
-
 
 
 def test_is_monotically_increasing():
     assert utils.is_monotonically_increasing([1, 2, 3, 4])