@@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs):
52
52
self .df = _gen_data_pandas ()
53
53
54
54
self .dp = DataProcessor ()
55
- _ = self .dp ([self .da , self .df ]) # Compute normalization parameters
55
+ _ = self .dp ([self .da , self .df ]) # Compute normalisation parameters
56
56
57
57
def _gen_task_loader_call_args (self , n_context , n_target ):
58
58
"""Generate arguments for TaskLoader.__call__
@@ -194,6 +194,48 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
194
194
x = B .to_numpy (model .loss_fn (task ))
195
195
assert x .size == 1 and x .shape == ()
196
196
197
+ @parameterized .expand (range (1 , 4 ))
198
+ def test_nans_offgrid_context (self , ndim ):
199
+ """Test that `ConvNP` can handle NaNs in offgrid context"""
200
+
201
+ tl = TaskLoader (
202
+ context = _gen_data_xr (data_vars = range (ndim )),
203
+ target = self .da ,
204
+ )
205
+
206
+ # All NaNs
207
+ task = tl ("2020-01-01" , context_sampling = 10 , target_sampling = 10 )
208
+ task ["Y_c" ][0 ][:, 0 ] = np .nan
209
+ model = ConvNP (self .dp , tl , unet_channels = (5 , 5 , 5 ), verbose = False )
210
+ _ = model (task )
211
+
212
+ # One NaN
213
+ task = tl ("2020-01-01" , context_sampling = 10 , target_sampling = 10 )
214
+ task ["Y_c" ][0 ][0 , 0 ] = np .nan
215
+ model = ConvNP (self .dp , tl , unet_channels = (5 , 5 , 5 ), verbose = False )
216
+ _ = model (task )
217
+
218
+ @parameterized .expand (range (1 , 4 ))
219
+ def test_nans_gridded_context (self , ndim ):
220
+ """Test that `ConvNP` can handle NaNs in gridded context"""
221
+
222
+ tl = TaskLoader (
223
+ context = _gen_data_xr (data_vars = range (ndim )),
224
+ target = self .da ,
225
+ )
226
+
227
+ # All NaNs
228
+ task = tl ("2020-01-01" , context_sampling = "all" , target_sampling = 10 )
229
+ task ["Y_c" ][0 ][:, 0 , 0 ] = np .nan
230
+ model = ConvNP (self .dp , tl , unet_channels = (5 , 5 , 5 ), verbose = False )
231
+ _ = model (task )
232
+
233
+ # One NaN
234
+ task = tl ("2020-01-01" , context_sampling = "all" , target_sampling = 10 )
235
+ task ["Y_c" ][0 ][0 , 0 , 0 ] = np .nan
236
+ model = ConvNP (self .dp , tl , unet_channels = (5 , 5 , 5 ), verbose = False )
237
+ _ = model (task )
238
+
197
239
@parameterized .expand (range (1 , 4 ))
198
240
def test_prediction_shapes_highlevel (self , target_dim ):
199
241
"""Test high-level `.predict` interface over a range of number of target sets
@@ -226,7 +268,9 @@ def test_prediction_shapes_highlevel(self, target_dim):
226
268
tasks ,
227
269
X_t = self .da ,
228
270
n_samples = n_samples ,
229
- unnormalise = False if target_dim > 1 else True ,
271
+ unnormalise = True
272
+ if target_dim == 1
273
+ else False , # TODO fix unnormalising for multiple equally named targets
230
274
)
231
275
assert [isinstance (ds , xr .Dataset ) for ds in [mean_ds , std_ds , samples_ds ]]
232
276
assert_shape (
0 commit comments