Improving format of object returned by DeepSensorModel.predict #53

@tom-andersson

I'd like to get feedback on potential breaking changes in the objects returned by model.predict in DeepSensor. As a reminder, model.predict allows you to take a bunch of Tasks (containing context data) and get unnormalised model predictions directly in xarray (for gridded predictions) or pandas (for non-gridded predictions). This is one of the key selling points of DeepSensor because it is very convenient for downstream analysis of model predictions.

There are two key challenges that make it tricky to decide what to return from .predict:

  • We aim to support non-Gaussian distributions in future, so the set of prediction parameters will depend on the model's likelihood/distribution.
  • There can be multiple target variables, leading to N-D predictions. Not all models will support this, but the default (and only) model in DeepSensor right now, the ConvNP, does.

Therefore, we can have a varying number of distribution parameters, and also a varying number of variables.

Currently, .predict by default returns two objects: a mean and a stddev of the model's Gaussian distributions. If the user passes the n_samples kwarg, there is also a third samples object returned. One can obtain the predictions for a specific variable like mean["var1"] and mean["var2"] (assuming a 2D target), for example.

When we start to support non-Gaussian distributions in future, it would be silly to return each new distribution parameter as an additional return object. For example, if we used a Bernoulli-Gamma distribution for a non-negative variable like precip, model.predict would have to return separate bernoulli_prob, gamma_alpha, and gamma_beta objects. Existing user code that unpacks two return values would then fail with ValueError: too many values to unpack.
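To make the unpacking risk concrete, here is a minimal sketch. The two functions are hypothetical stand-ins for model.predict under different likelihoods (they are not DeepSensor functions), and plain lists stand in for the xarray/pandas objects.

```python
# Hypothetical stand-ins for model.predict; not part of DeepSensor.
def predict_gaussian():
    """Mimics the current behaviour: one return object per parameter."""
    mean, std = [1.0, 2.0], [0.1, 0.2]
    return mean, std


def predict_bernoulli_gamma():
    """Mimics returning every new distribution parameter as its own object."""
    bernoulli_prob = [0.3, 0.9]
    gamma_alpha = [2.0, 2.5]
    gamma_beta = [1.0, 0.5]
    return bernoulli_prob, gamma_alpha, gamma_beta


# User code written against the two-object Gaussian return works today.
mean, std = predict_gaussian()

# The same unpacking pattern breaks once three objects come back.
try:
    mean, std = predict_bernoulli_gamma()
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```

The failure happens in user code, not in the library, which is why growing the return tuple would be a silent breaking change for anyone who switches likelihood.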

Here is my proposal: model.predict instead returns a single dictionary where keys are variable IDs (e.g. "t2m") and values are single xarray/pandas objects containing distribution parameters. For example:

pred = model.predict(tasks, ..., n_samples=5)
print(pred["t2m"]["mean"])  # the model mean for the t2m var (xarray.DataArray or pandas.Series)
print(pred["t2m"]["sample1"])  # the first sample for the t2m var (xarray.DataArray or pandas.Series)
print(pred["sst"]["std"])  # the model std dev for the sst var (xarray.DataArray or pandas.Series)

This allows the number of target variables and number of distribution parameters to change, while wrapping everything in a single return object.
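As a rough illustration of why a single dictionary copes with both kinds of variation, the sketch below runs the same downstream code on a Gaussian prediction and on a Bernoulli-Gamma one. Everything here is an assumption for illustration: the `summarise` helper is hypothetical, the parameter names are taken from this issue rather than from any settled API, and plain lists stand in for the xarray/pandas objects.

```python
# Plain lists stand in for xarray.DataArray / pandas.Series objects.
gaussian_pred = {
    "t2m": {"mean": [280.1, 281.4], "std": [0.5, 0.7]},
    "sst": {"mean": [290.0, 289.2], "std": [0.2, 0.3]},
}

# Parameter names follow this issue's example; they are not a settled API.
bernoulli_gamma_pred = {
    "precip": {
        "bernoulli_prob": [0.3, 0.9],
        "gamma_alpha": [2.0, 2.5],
        "gamma_beta": [1.0, 0.5],
    },
}


def summarise(pred):
    """Hypothetical helper: list the parameters available per variable."""
    return {var_id: sorted(params) for var_id, params in pred.items()}


# The same call works for any number of variables and parameters.
gaussian_summary = summarise(gaussian_pred)
bernoulli_gamma_summary = summarise(bernoulli_gamma_pred)
print(gaussian_summary)
print(bernoulli_gamma_summary)
```

Code that needs a specific parameter still has to know its name, but code that only iterates over variables and parameters does not change when the likelihood does.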

You can see in the above snippet that it gets a little awkward when drawing model samples with n_samples>0, because samples have a 'sample number' dimension as well as spatial and temporal dimensions. Forcing the samples into the same xarray/pandas objects as the distribution parameters means we need sample<n>-like variables. Perhaps we can get away with putting the samples in a separate object, e.g. pred, samples = model.predict(tasks, ..., n_samples=5), and assume the user will be happy to handle the second returned object in this case.
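One way to picture the separate-object option is to pull the sample<n> entries out of a variable's parameters and stack them along a leading sample axis. This is only a sketch: `split_samples` is a hypothetical helper, and nested lists stand in for what would more naturally be an xarray object with its own sample dimension.

```python
import re


def split_samples(params):
    """Split {"mean": ..., "sample1": ...} into (parameters, stacked samples)."""
    distribution_params = {}
    numbered_samples = []
    for name, values in params.items():
        match = re.fullmatch(r"sample(\d+)", name)
        if match:
            numbered_samples.append((int(match.group(1)), values))
        else:
            distribution_params[name] = values
    # Sort numerically so that sample10 follows sample9, not sample1.
    numbered_samples.sort(key=lambda item: item[0])
    # Outer list index is the sample number; inner lists are the values.
    samples = [values for _, values in numbered_samples]
    return distribution_params, samples


flat_t2m = {
    "mean": [280.1, 281.4],
    "std": [0.5, 0.7],
    "sample2": [280.6, 281.0],
    "sample1": [279.8, 281.9],
}
pred_t2m, samples_t2m = split_samples(flat_t2m)
print(pred_t2m)
print(samples_t2m)
```

With the samples stacked this way, the number of samples is just the length of the leading axis, and the parameter object keeps the same keys whether or not samples were requested.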
