I'd like to get feedback on potential breaking changes in the objects returned by `model.predict` in DeepSensor. As a reminder, `model.predict` allows you to take a bunch of Tasks (containing context data) and get unnormalised model predictions directly in xarray (for gridded predictions) or pandas (for non-gridded predictions). This is one of the key selling points of DeepSensor because it is very convenient for downstream analysis of model predictions.
There are two key challenges that make it tricky to decide what to return from `.predict`:
- We aim to support non-Gaussian distributions in future, so the model's prediction parameters will depend on the model's likelihood/distribution.
- There can be multiple target variables, leading to N-D predictions. Not all models will support this, but the default (and only) model in DeepSensor right now, the `ConvNP`, does.

Therefore, we can have a varying number of distribution parameters, and also a varying number of variables.
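To make the two axes of variation concrete, here is a minimal sketch. The likelihood and parameter names are illustrative, not DeepSensor's actual API:

```python
# Illustrative only: a hypothetical mapping from likelihood to its
# per-variable prediction parameters.
LIKELIHOOD_PARAMS = {
    "gaussian": ["mean", "std"],
    "bernoulli_gamma": ["bernoulli_prob", "gamma_alpha", "gamma_beta"],
}

target_vars = ["t2m", "sst"]  # the number of target variables also varies

# Any return structure must accommodate both axes of variation:
n_objects = {
    lik: len(params) * len(target_vars)
    for lik, params in LIKELIHOOD_PARAMS.items()
}
print(n_objects)  # {'gaussian': 4, 'bernoulli_gamma': 6}
```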
Currently, `.predict` by default returns two objects: a `mean` and a `stddev` of the model's Gaussian distributions. If the user passes the `n_samples` kwarg, a third `samples` object is also returned. One can obtain the predictions for a specific variable like `mean["var1"]` and `mean["var2"]` (assuming a 2D target), for example.
When we start to support non-Gaussian distributions in future, it would be silly to return these as additional objects. For example, if we used a Bernoulli-Gamma distribution for a non-negative variable like precipitation, we would have to return extra `bernoulli_prob`, `gamma_alpha`, and `gamma_beta` objects. This would risk `ValueError: too many values to unpack` in user code.
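To illustrate the unpacking hazard, here is a sketch with stand-in functions (not the real API) mimicking the current and a hypothetical future return signature:

```python
# Stand-in for the current two-object Gaussian return:
def predict_gaussian(tasks):
    return "mean", "stddev"

# Stand-in for a hypothetical three-parameter Bernoulli-Gamma return:
def predict_bernoulli_gamma(tasks):
    return "bernoulli_prob", "gamma_alpha", "gamma_beta"

mean, stddev = predict_gaussian(None)  # existing user code: fine

try:
    mean, stddev = predict_bernoulli_gamma(None)  # same code, new likelihood
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)
```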
Here is my proposal: `model.predict` instead returns a single dictionary where keys are variable IDs (e.g. `"t2m"`) and values are single xarray/pandas objects containing distribution parameters. For example:
```python
pred = model.predict(tasks, ..., n_samples=5)
print(pred["t2m"]["mean"])     # the model mean for the t2m var (xarray.DataArray or pandas.Series)
print(pred["t2m"]["sample1"])  # the first sample for the t2m var (xarray.DataArray or pandas.Series)
print(pred["sst"]["std"])      # the model std dev for the sst var (xarray.DataArray or pandas.Series)
```
This allows the number of target variables and number of distribution parameters to change, while wrapping everything in a single return object.
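One way to realise this for the gridded case is a sketch like the following, assuming one `xarray.Dataset` per variable with one data variable per distribution parameter (coordinate and parameter names are illustrative):

```python
import numpy as np
import xarray as xr

coords = {"time": np.arange(2), "x1": np.arange(3), "x2": np.arange(3)}

def make_var_pred(params):
    """One Dataset per target variable, one data variable per parameter."""
    return xr.Dataset(
        {p: (("time", "x1", "x2"), np.random.rand(2, 3, 3)) for p in params},
        coords=coords,
    )

# The outer dict varies with the target variables; the inner data
# variables vary with the likelihood's parameter set.
pred = {
    "t2m": make_var_pred(["mean", "std"]),  # Gaussian
    "precip": make_var_pred(["bernoulli_prob", "gamma_alpha", "gamma_beta"]),
}

print(pred["t2m"]["mean"].dims)         # ('time', 'x1', 'x2')
print(list(pred["precip"].data_vars))   # ['bernoulli_prob', 'gamma_alpha', 'gamma_beta']
```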
You can see in the above snippet that it gets a little awkward when drawing model samples with `n_samples > 0`, because samples have a 'sample number' dimension as well as spatial and temporal dimensions. Forcing the samples into the same xarray/pandas objects means we need `sample<n>`-like variables. Perhaps we can get away with putting the samples in a separate object, e.g. `pred, samples = model.predict(tasks, ..., n_samples=5)`, and assume the user will be happy to handle the second returned object in this case.
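If the samples do go in a separate object, they could carry an explicit sample dimension instead of `sample<n>`-style variables. A sketch for the gridded case, with illustrative dimension names:

```python
import numpy as np
import xarray as xr

n_samples = 5
# One DataArray per variable, with a leading "sample" dimension:
samples_t2m = xr.DataArray(
    np.random.rand(n_samples, 2, 3, 3),
    dims=("sample", "time", "x1", "x2"),
    coords={"sample": np.arange(n_samples)},
)

first = samples_t2m.sel(sample=0)  # instead of pred["t2m"]["sample1"]
print(samples_t2m.dims)  # ('sample', 'time', 'x1', 'x2')
print(first.dims)        # ('time', 'x1', 'x2')
```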