Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Issue/166 batch pydantic #195

Merged
merged 60 commits into from
Oct 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
1a866a7
pydantic test
peterdudfield Oct 1, 2021
1afa092
begun pydantic model, lots of dummy functions
peterdudfield Oct 1, 2021
a866439
add models for other data sources
peterdudfield Oct 1, 2021
24dbe93
update more model
peterdudfield Oct 4, 2021
1598b2b
add prints, for looking at size of files
peterdudfield Oct 4, 2021
df9ad12
fix
peterdudfield Oct 4, 2021
10fc202
fix
peterdudfield Oct 4, 2021
a121c12
tidy
peterdudfield Oct 4, 2021
95c9a70
do size test on actually .nc
peterdudfield Oct 4, 2021
f79f398
fix tests - not dataset ones
peterdudfield Oct 5, 2021
2c13637
fix tests
peterdudfield Oct 5, 2021
9c8b014
add pytest fixture, of general data source
peterdudfield Oct 5, 2021
67b238c
add general datasource
peterdudfield Oct 5, 2021
96e317e
update batch 0.nc
peterdudfield Oct 5, 2021
5d45631
udpate dataset (more todo)
peterdudfield Oct 5, 2021
9d547fa
option for not retuning xr data if not there
peterdudfield Oct 5, 2021
68aeacb
remove subset of data from dataset
peterdudfield Oct 5, 2021
887484c
get one subselect test working
peterdudfield Oct 5, 2021
c0df51d
re introduce tests
peterdudfield Oct 5, 2021
a346abf
fix for other sub selections
peterdudfield Oct 5, 2021
ac18891
remove example
peterdudfield Oct 6, 2021
2ae2558
remove old code
peterdudfield Oct 6, 2021
e58a00b
tidy up folder structure
peterdudfield Oct 6, 2021
90f7bcb
tidy and add README
peterdudfield Oct 6, 2021
ba7d1c5
pylint
peterdudfield Oct 6, 2021
ff8aa60
get ready for scritps working
peterdudfield Oct 6, 2021
2a47f4a
Merge commit 'b11b8ccbbfb22111d679ccf8282014183e63454f' into issue/16…
peterdudfield Oct 6, 2021
ab4500a
PR comments
peterdudfield Oct 6, 2021
06a4bdb
fix, add some test that accidentally got deleted
peterdudfield Oct 6, 2021
e943489
fix
peterdudfield Oct 6, 2021
fcd272c
typos
peterdudfield Oct 6, 2021
dc75cd2
PR comment
peterdudfield Oct 6, 2021
a934ceb
PR comments
peterdudfield Oct 6, 2021
fc1614b
PR comment
peterdudfield Oct 6, 2021
02e1179
PR comment - tidy
peterdudfield Oct 6, 2021
bde20c4
pylint
peterdudfield Oct 6, 2021
b620b01
PR comment
peterdudfield Oct 6, 2021
f51db97
tidy
peterdudfield Oct 6, 2021
7f12c86
put back in padding tests, and PR comments on Sub select data
peterdudfield Oct 6, 2021
cb7067c
pylint
peterdudfield Oct 6, 2021
9554896
rename variable
peterdudfield Oct 6, 2021
6d75d04
PR comment for sub selecting some data
peterdudfield Oct 6, 2021
b2ac635
fix for script prepare_ml_data
peterdudfield Oct 7, 2021
109f02e
add types to xarray datasets
peterdudfield Oct 7, 2021
e0774db
fix for validation ml data
peterdudfield Oct 7, 2021
235b1a7
Merge branch 'main' into issue/166-batch-pydantic
peterdudfield Oct 7, 2021
018f367
rename General to Metadata (PR comment)
peterdudfield Oct 7, 2021
725a4a2
update for 'object_at_center_label' label not string, helps with torc…
peterdudfield Oct 7, 2021
1d7d0e5
fix linting
peterdudfield Oct 7, 2021
7c837e2
fix test,
peterdudfield Oct 7, 2021
59b3378
remake 0.nc test data
peterdudfield Oct 7, 2021
a70c3a3
Apply suggestions from code review
peterdudfield Oct 7, 2021
695f9dc
fix
peterdudfield Oct 7, 2021
5e26337
remove torch from fake Batch
peterdudfield Oct 7, 2021
439c967
Merge branch 'main' into issue/166-batch-pydantic
peterdudfield Oct 7, 2021
3166a61
PR comment JK
peterdudfield Oct 7, 2021
7c6e5a3
batch_index in to xr_dataset
peterdudfield Oct 7, 2021
0af228f
PR comment JK.-
peterdudfield Oct 7, 2021
6960df8
some small notebooks
peterdudfield Oct 8, 2021
ee09b9a
update notebook
peterdudfield Oct 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from nowcasting_dataset.config.load import load_yaml_configuration
from nowcasting_dataset.data_sources import SatelliteDataSource
from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource

pytest.IMAGE_SIZE_PIXELS = 128

Expand Down Expand Up @@ -50,6 +51,14 @@ def sat_data_source(sat_filename: Path):
)


@pytest.fixture
def general_data_source():
    """Build a MetadataDataSource with no history, a 5-minute forecast and a GSP at the centre.

    ``convert_to_numpy`` is switched on for the returned data source.
    """
    source_kwargs = {
        "history_minutes": 0,
        "forecast_minutes": 5,
        "object_at_center": "GSP",
        "convert_to_numpy": True,
    }
    return MetadataDataSource(**source_kwargs)


@pytest.fixture
def gsp_data_source():
return GSPDataSource(
Expand All @@ -65,9 +74,9 @@ def gsp_data_source():
@pytest.fixture
def configuration():
filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "config", "gcp.yaml")
config = load_yaml_configuration(filename)
configuration = load_yaml_configuration(filename)

return config
return configuration


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion notebooks/2021-09/2021-09-07/sat_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime

from nowcasting_dataset.data_sources.satellite_data_source import SatelliteDataSource
from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource

s = SatelliteDataSource(
filename="gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/"
Expand Down
129 changes: 129 additions & 0 deletions notebooks/2021-10/2021-10-01/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from pydantic import BaseModel, Field, validator
from typing import Union
import numpy as np
import xarray as xr
import torch
from nowcasting_dataset.config.model import Configuration


Array = Union[xr.DataArray, np.ndarray, torch.Tensor]


class Satellite(BaseModel):
    """Satellite images together with their x / y geo-spatial coordinates.

    Every field accepts any of the types in ``Array`` (xarray DataArray, numpy
    array or torch Tensor). A leading ``batch_size`` dimension is optional.
    """

    # Sketch of explicit size fields that the commented-out validators below rely on
    # (not enabled):
    # width: int = Field(..., gt=0, description="The width of the satellite image")
    # height: int = Field(..., gt=0, description="The height of the satellite image")
    # num_channels: int = Field(..., gt=0, description="The number of satellite channels")

    # Shape: [batch_size,] seq_length, width, height, channel
    image_data: Array = Field(
        ...,
        description="Satellites images. Shape: [batch_size,] seq_length, width, height, channel",
    )
    x_coords: Array = Field(
        ...,
        description="The x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width",
    )
    # The y coordinates run along the image height, not the width.
    y_coords: Array = Field(
        ...,
        description="The y (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] height",
    )

    # Sketch of shape validators (not enabled; they depend on the size fields above):
    # @validator("sat_data")
    # def image_shape(cls, v):
    #     assert v.shape[-1] == cls.num_channels
    #     assert v.shape[-2] == cls.height
    #     assert v.shape[-3] == cls.width
    #
    # @validator("x_coords")
    # def x_coords_shape(cls, v):
    #     assert v.shape[-1] == cls.width
    #
    # @validator("y_coords")
    # def y_coords_shape(cls, v):
    #     assert v.shape[-1] == cls.height

    class Config:
        # The Array member types are not pydantic-native, so arbitrary types must be allowed.
        arbitrary_types_allowed = True


class Batch(BaseModel):
    """Top-level container holding the batch size and the satellite data."""

    batch_size: int = Field(
        ...,
        # ``ge`` (not the unrecognised ``g``) so the bound is actually enforced;
        # 0 stays valid because it marks a single, un-batched data item.
        ge=0,
        description="The size of this batch. If the batch size is 0, "
        "then this item stores one data item",
    )

    satellite: Satellite


class FakeDataset(torch.utils.data.Dataset):
    """Fake dataset that returns randomly generated satellite batches."""

    def __init__(self, configuration: "Union[Configuration, None]" = None, length: int = 10):
        """
        Init

        Args:
            configuration: configuration object; a new ``Configuration()`` is created
                when this is omitted
            length: length of dataset
        """
        # Create the default per call instead of once at definition time, so that
        # instances never share one default configuration object.
        if configuration is None:
            configuration = Configuration()

        process = configuration.process
        self.batch_size = process.batch_size
        # the sequence data in 5 minute steps
        self.seq_length_5 = process.seq_len_5_minutes
        # the sequence data in 30 minute steps
        self.seq_length_30 = process.seq_len_30_minutes
        self.satellite_image_size_pixels = process.satellite_image_size_pixels
        self.nwp_image_size_pixels = process.nwp_image_size_pixels
        self.number_sat_channels = len(process.sat_channels)
        self.number_nwp_channels = len(process.nwp_channels)
        self.length = length

    def __len__(self):
        """Number of pieces of data"""
        return self.length

    def per_worker_init(self, worker_id: int):
        """Not needed"""

    def __getitem__(self, idx):
        """
        Get item, use for iter and next method

        Args:
            idx: batch index (not used; every item is freshly generated random data)

        Returns: Nested dictionary of random data, as produced by ``Batch(...).dict()``

        """
        image_size = self.satellite_image_size_pixels

        sat = Satellite(
            image_data=np.random.randn(
                self.batch_size,
                self.seq_length_5,
                image_size,
                image_size,
                self.number_sat_channels,
            ),
            # x coordinates sorted ascending, y coordinates sorted descending
            x_coords=torch.sort(torch.randn(self.batch_size, image_size))[0],
            y_coords=torch.sort(torch.randn(self.batch_size, image_size), descending=True)[0],
        )

        # Note need to return as nested dict
        return Batch(satellite=sat, batch_size=self.batch_size).dict()


# Pull one item from the fake dataset through a DataLoader and rebuild the pydantic model
# from the nested dict that comes out.
train = torch.utils.data.DataLoader(FakeDataset())
i = iter(train)
x = next(i)

x = Batch(**x)
# IT WORKS
# image_data is created as a numpy array in FakeDataset.__getitem__; after the DataLoader
# it is expected to be a torch.Tensor (presumably converted by the default collate step).
assert isinstance(x.satellite.image_data, torch.Tensor)
Empty file.
97 changes: 97 additions & 0 deletions notebooks/2021-10/2021-10-08/xr_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os

import numpy as np
import xarray as xr
from nowcasting_dataset.utils import coord_to_range


def get_satellite_xrarray_data_array(
    batch_size, seq_length_5, satellite_image_size_pixels, number_sat_channels=10
):
    """Create a random ``xr.DataArray`` with dims ``time``, ``x``, ``y`` and ``channels``.

    Args:
        batch_size: accepted but not used; a single un-batched array is built
        seq_length_5: length of the ``time`` dimension
        satellite_image_size_pixels: length of both the ``x`` and the ``y`` dimension
        number_sat_channels: length of the ``channels`` dimension

    Returns:
        The DataArray, with sorted random time values, ascending random integer x
        coordinates and descending random integer y coordinates (drawn with
        replacement, so coordinates may repeat).
    """
    side = satellite_image_size_pixels

    # The random draws happen in the order: data, time, x, y.
    values = np.random.randn(seq_length_5, side, side, number_sat_channels)
    time_values = np.sort(np.random.randn(seq_length_5))
    x_values = np.sort(np.random.randint(0, 1000, side))
    y_values = np.sort(np.random.randint(0, 1000, side))[::-1].copy()

    coordinates = {
        "x": list(x_values),
        "y": list(y_values),
        "time": list(time_values),
        "channels": range(number_sat_channels),
    }
    # NOTE(review): these attrs and the "sata_data" name look like leftovers from an
    # example / a typo rather than real satellite metadata - confirm before relying on them.
    metadata = {"description": "Ambient temperature.", "units": "degC"}

    return xr.DataArray(
        data=values,
        dims=["time", "x", "y", "channels"],
        coords=coordinates,
        attrs=metadata,
        name="sata_data",
    )


def sat_data_array_to_dataset(sat_xr):
    """Turn a satellite DataArray into a Dataset with ``sat_``-prefixed dimension names.

    The array is stored under the variable name "sat_data", ``coord_to_range`` is applied
    to the time, x and y dimensions, and the channels / x / y dimensions are renamed.

    Args:
        sat_xr: DataArray with "time", "x", "y" and "channels" dimensions

    Returns:
        The resulting Dataset.
    """
    dataset = sat_xr.to_dataset(name="sat_data")
    # Down-casting that was tried and left disabled:
    # dataset["sat_data"] = dataset["sat_data"].astype(np.int16)

    # This does seem like the right way to do it
    # https://ecco-v4-python-tutorial.readthedocs.io/ECCO_v4_Saving_Datasets_and_DataArrays_to_NetCDF.html
    for dim_name in ("time", "x", "y"):
        dataset = coord_to_range(dataset, dim_name, prefix="sat")

    new_names = {"channels": "sat_channels", "x": "sat_x", "y": "sat_y"}
    dataset = dataset.rename(new_names)

    # Coordinate down-casting that was tried and left disabled:
    # dataset["sat_x_coords"] = dataset["sat_x_coords"].astype(np.int32)
    # dataset["sat_y_coords"] = dataset["sat_y_coords"].astype(np.int32)

    return dataset


def to_netcdf(batch_xr, local_filename):
    """Write ``batch_xr`` to ``local_filename`` with the h5netcdf engine.

    Every data variable is given "lzf" compression; an existing file is overwritten
    (mode "w").
    """
    lzf_encoding = {}
    for variable_name in batch_xr.data_vars:
        lzf_encoding[variable_name] = {"compression": "lzf"}

    batch_xr.to_netcdf(
        local_filename,
        engine="h5netcdf",
        mode="w",
        encoding=lzf_encoding,
    )


# 1. try to save netcdf files without using the coord_to_range function
sat_xrs = [get_satellite_xrarray_data_array(4, 19, 32) for _ in range(0, 10)]

### error ###
# can't do this step, as the x/y index has duplicate values
sat_dataset = xr.merge(sat_xrs)
to_netcdf(sat_dataset, "test_no_alignment.nc")
###

# but can save it as separate files
# exist_ok so that re-running the experiment does not fail on the existing folder
os.makedirs("test_no_alignment", exist_ok=True)
for i, sat_xr in enumerate(sat_xrs):
    sat_xr.to_netcdf(f"test_no_alignment/{i}.nc", engine="h5netcdf")
# 10 files about 1.5MB

# 2. convert each array with sat_data_array_to_dataset, then join along a new "example" dim
sat_xrs = [get_satellite_xrarray_data_array(4, 19, 32) for _ in range(0, 10)]
sat_xrs = [sat_data_array_to_dataset(sat_xr) for sat_xr in sat_xrs]

sat_dataset = xr.concat(sat_xrs, dim="example")
to_netcdf(sat_dataset, "test_alignment.nc")
# this is 15 MB


# conclusion
# no major improvement in compression by joining datasets together, but joining the arrays
# together does make it easier to get the arrays ready for ML
99 changes: 99 additions & 0 deletions notebooks/2021-10/2021-10-08/xr_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field, validator
from typing import Union, List
import numpy as np
import xarray as xr
import torch
from nowcasting_dataset.config.model import Configuration


Array = Union[xr.DataArray, np.ndarray, torch.Tensor]


class Satellite(BaseModel):
    """Satellite images held as a single ``xr.DataArray`` field."""

    class Config:
        # xr.DataArray is not a pydantic-native type, so arbitrary types must be allowed.
        arbitrary_types_allowed = True

    # Shape: [batch_size,] seq_length, width, height, channel
    image_data: xr.DataArray = Field(
        ...,
        description=(
            "Satellites images. Shape: [batch_size,] seq_length, width, height, channel"
        ),
    )

    @validator("image_data")
    def v_image_data(cls, v):
        """Return ``image_data`` unchanged, printing a message to show the validator ran."""
        print("validating image data")
        return v


class Batch(BaseModel):
    """Batch size plus the satellite data; ``batch_size`` defaults to 0."""

    batch_size: int = 0
    satellite: Satellite

    # Named after the field it validates (previously carried the name of
    # Satellite's image-data validator).
    @validator("batch_size")
    def v_batch_size(cls, v):
        """Return ``batch_size`` unchanged, printing a message to show the validator ran."""
        print("validating batch size")
        return v


# Build a Satellite from an empty DataArray and dump it to a plain dict.
s = Satellite(image_data=xr.DataArray())
s_dict = s.dict()

# Rebuild the model from the dict twice: with the normal constructor and with construct().
# NOTE(review): presumably this compares when the printing validators above run;
# pydantic documents construct() as skipping validation - confirm.
# NOTE(review): ``__fields_set__`` is read from the class here, not from an instance -
# confirm this is what construct() should receive.
x = Satellite(**s_dict)
x = Satellite.construct(Satellite.__fields_set__, **s_dict)


# Same round trip for a Batch that wraps the Satellite built above.
batch = Batch(batch_size=5, satellite=s)

b_dict = batch.dict()

x = Batch(**b_dict)
x = Batch.construct(Batch.__fields_set__, **b_dict)


# class Satellite(BaseModel):
#
# image_data: xr.DataArray
#
# # validate
#
# def to_dataset(self):
# pass
#
# def from_dataset(self):
# pass
#
# def to_numpy(self) -> SatelliteNumpy:
# pass
#
#
# class SatelliteNumpy(BaseModel):
#
# image_data: np.ndarray
# x: np.ndarray
# # more
#
#
# class Example(BaseModel):
#
# satellite: Satellite
# # more
#
#
# class Batch(BaseModel):
#
# batch_size: int = 0
# examples: List[Example]
#
# def to/from_netcdf():
# pass
#
#
# class BatchNumpy(BaseModel):
#
# batch_size: int = 0
# satellite: SatelliteNumpy
# # more
#
# def from_batch(self) -> BatchNumpy:
# """ change to Batch numpy structure """
4 changes: 2 additions & 2 deletions nowcasting_dataset/config/gcp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ input_data:
satellite_zarr_path: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr
solar_pv_data_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc
solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv
gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/PVOutput.org/PV/GSP/v1/pv_gsp.zarr
gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/GSP/v1/pv_gsp.zarr
topographic_filename: gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif
output_data:
filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v6/
filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/
process:
local_temp_path: ~/temp/
seed: 1234
Expand Down
Loading