Skip to content

Commit 094c510

Browse files
committed
Fix context NaN handling and add tests
1 parent 09c0520 commit 094c510

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

deepsensor/model/convnp.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -172,11 +172,13 @@ def array_modify_fn(arr):
172172

173173
arr = arr.astype(np.float32) # Cast to float32
174174

175-
# Find NaNs and keep size-1 variable dim
176-
mask = np.any(np.isnan(arr), axis=1, keepdims=True)
175+
# Find NaNs
176+
mask = np.isnan(arr)
177177
if np.any(mask):
178178
# Set NaNs to zero - necessary for `neuralprocesses` (can't have NaNs)
179179
arr[mask] = 0.0
180+
# Mask array (True for observed, False for missing) - keep size 1 variable dim
181+
mask = ~np.any(mask, axis=1, keepdims=True)
180182

181183
# Convert to tensor object based on deep learning backend
182184
arr = backend.convert_to_tensor(arr)

tests/test_model.py

Lines changed: 46 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs):
5252
self.df = _gen_data_pandas()
5353

5454
self.dp = DataProcessor()
55-
_ = self.dp([self.da, self.df]) # Compute normalization parameters
55+
_ = self.dp([self.da, self.df]) # Compute normalisation parameters
5656

5757
def _gen_task_loader_call_args(self, n_context, n_target):
5858
"""Generate arguments for TaskLoader.__call__
@@ -194,6 +194,48 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
194194
x = B.to_numpy(model.loss_fn(task))
195195
assert x.size == 1 and x.shape == ()
196196

197+
@parameterized.expand(range(1, 4))
198+
def test_nans_offgrid_context(self, ndim):
199+
"""Test that `ConvNP` can handle NaNs in offgrid context"""
200+
201+
tl = TaskLoader(
202+
context=_gen_data_xr(data_vars=range(ndim)),
203+
target=self.da,
204+
)
205+
206+
# All NaNs
207+
task = tl("2020-01-01", context_sampling=10, target_sampling=10)
208+
task["Y_c"][0][:, 0] = np.nan
209+
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
210+
_ = model(task)
211+
212+
# One NaN
213+
task = tl("2020-01-01", context_sampling=10, target_sampling=10)
214+
task["Y_c"][0][0, 0] = np.nan
215+
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
216+
_ = model(task)
217+
218+
@parameterized.expand(range(1, 4))
219+
def test_nans_gridded_context(self, ndim):
220+
"""Test that `ConvNP` can handle NaNs in gridded context"""
221+
222+
tl = TaskLoader(
223+
context=_gen_data_xr(data_vars=range(ndim)),
224+
target=self.da,
225+
)
226+
227+
# All NaNs
228+
task = tl("2020-01-01", context_sampling="all", target_sampling=10)
229+
task["Y_c"][0][:, 0, 0] = np.nan
230+
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
231+
_ = model(task)
232+
233+
# One NaN
234+
task = tl("2020-01-01", context_sampling="all", target_sampling=10)
235+
task["Y_c"][0][0, 0, 0] = np.nan
236+
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
237+
_ = model(task)
238+
197239
@parameterized.expand(range(1, 4))
198240
def test_prediction_shapes_highlevel(self, target_dim):
199241
"""Test high-level `.predict` interface over a range of number of target sets
@@ -226,7 +268,9 @@ def test_prediction_shapes_highlevel(self, target_dim):
226268
tasks,
227269
X_t=self.da,
228270
n_samples=n_samples,
229-
unnormalise=False if target_dim > 1 else True,
271+
unnormalise=True
272+
if target_dim == 1
273+
else False, # TODO fix unnormalising for multiple equally named targets
230274
)
231275
assert [isinstance(ds, xr.Dataset) for ds in [mean_ds, std_ds, samples_ds]]
232276
assert_shape(

0 commit comments

Comments
 (0)