Skip to content

Commit a0224fe

Browse files
Merge pull request #25 from polpel/dptest
Fix rounding errors in `DeepSensorModel.predict` coordinates from normalise-unnormalise operations
2 parents 094c510 + ece593a commit a0224fe

File tree

4 files changed

+207
-84
lines changed

4 files changed

+207
-84
lines changed

deepsensor/active_learning/algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _model_infill_at_search_points(
198198
infill_ds, _ = self.model.predict(
199199
self.tasks,
200200
X_s,
201-
X_t_normalised=True,
201+
X_t_is_normalised=True,
202202
unnormalise=False,
203203
progress_bar=self.progress_bar >= 4,
204204
)

deepsensor/data/processor.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -301,54 +301,64 @@ def map_coords(
301301
data = data.set_index(indexes)
302302
return data
303303

304+
def map_array(
305+
self,
306+
data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray],
307+
var_ID: str,
308+
method: str = "mean_std",
309+
unnorm: bool = False,
310+
add_offset=True,
311+
):
312+
"""Normalise or unnormalise the data values in an xarray, pandas, or numpy object"""
313+
param1, param2 = self.get_norm_params(var_ID, data, method, unnorm)
314+
if method == "mean_std":
315+
if not unnorm:
316+
scale = 1 / param2
317+
offset = -param1 / param2
318+
else:
319+
scale = param2
320+
offset = param1
321+
elif method == "min_max":
322+
if not unnorm:
323+
scale = 2 / (param2 - param1)
324+
offset = -(param2 + param1) / (param2 - param1)
325+
else:
326+
scale = (param2 - param1) / 2
327+
offset = (param2 + param1) / 2
328+
else:
329+
raise ValueError(
330+
f"Method {method} not recognised. Use 'mean_std' or 'min_max'."
331+
)
332+
data = data * scale
333+
if add_offset:
334+
data = data + offset
335+
return data
336+
304337
def map(
305338
self,
306339
data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series],
307340
method: str = "mean_std",
308341
add_offset: bool = True,
309342
unnorm: bool = False,
310343
):
311-
"""Normalise or unnormalise data"""
344+
"""Normalise or unnormalise the data values and coords in an xarray or pandas object"""
312345
if self.deepcopy:
313346
data = deepcopy(data)
314347

315-
def mapper(data, param1, param2, method):
316-
if method == "mean_std":
317-
if not unnorm:
318-
scale = 1 / param2
319-
offset = -param1 / param2
320-
else:
321-
scale = param2
322-
offset = param1
323-
elif method == "min_max":
324-
if not unnorm:
325-
scale = 2 / (param2 - param1)
326-
offset = -(param2 + param1) / (param2 - param1)
327-
else:
328-
scale = (param2 - param1) / 2
329-
offset = (param2 + param1) / 2
330-
data = data * scale
331-
if add_offset:
332-
data = data + offset
333-
return data
334-
335348
if isinstance(data, (xr.DataArray, xr.Dataset)) and not unnorm:
336349
self._validate_xr(data)
337350
elif isinstance(data, (pd.DataFrame, pd.Series)) and not unnorm:
338351
self._validate_pandas(data)
339352

340353
if isinstance(data, (xr.DataArray, pd.Series)):
341354
# Single var
342-
var_ID = data.name
343-
param1, param2 = self.get_norm_params(var_ID, data, method, unnorm)
344-
data = mapper(data, param1, param2, method)
355+
data = self.map_array(data, data.name, method, unnorm, add_offset)
345356
elif isinstance(data, (xr.Dataset, pd.DataFrame)):
346357
# Multiple vars
347358
for var_ID in data:
348-
param1, param2 = self.get_norm_params(
349-
var_ID, data[var_ID], method, unnorm
359+
data[var_ID] = self.map_array(
360+
data[var_ID], var_ID, method, unnorm, add_offset
350361
)
351-
data[var_ID] = mapper(data[var_ID], param1, param2, method)
352362

353363
data = self.map_coords(data, unnorm=unnorm)
354364

deepsensor/model/model.py

Lines changed: 99 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def predict(
194194
X_t: Union[
195195
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index, np.ndarray
196196
],
197-
X_t_normalised: bool = False,
197+
X_t_is_normalised: bool = False,
198198
resolution_factor=1,
199199
n_samples=0,
200200
ar_sample=False,
@@ -210,24 +210,25 @@ def predict(
210210
TODO:
211211
- Test with multiple targets model
212212
213-
:param tasks: List of tasks containing context data.
214-
:param X_t: Target locations to predict at. Can be an xarray object containing
215-
on-grid locations or a pandas object containing off-grid locations.
216-
:param X_t_normalised: Whether the `X_t` coords are normalised.
217-
If False, will normalise the coords before passing to model. Default False.
218-
:param resolution_factor: Optional factor to increase the resolution of the
219-
target grid by. E.g. 2 will double the target resolution, 0.5 will halve it.
220-
Applies to on-grid predictions only. Default 1.
221-
:param n_samples: Number of joint samples to draw from the model.
222-
If 0, will not draw samples. Default 0.
223-
:param ar_sample: Whether to use autoregressive sampling. Default False.
224-
:param unnormalise: Whether to unnormalise the predictions. Only works if
225-
`self` has a `data_processor` and `task_loader` attribute. Default True.
226-
:param seed: Random seed for deterministic sampling. Default 0.
227-
:param append_indexes: Dictionary of index metadata to append to pandas indexes
228-
in the off-grid case. Default None.
229-
:param progress_bar: Whether to display a progress bar over tasks. Default 0.
230-
:param verbose: Whether to print time taken for prediction. Default False.
213+
Args:
214+
tasks: List of tasks containing context data.
215+
X_t: Target locations to predict at. Can be an xarray object containing
216+
on-grid locations or a pandas object containing off-grid locations.
217+
X_t_is_normalised: Whether the `X_t` coords are normalised.
218+
If False, will normalise the coords before passing to model. Default False.
219+
resolution_factor: Optional factor to increase the resolution of the
220+
target grid by. E.g. 2 will double the target resolution, 0.5 will halve it.
221+
Applies to on-grid predictions only. Default 1.
222+
n_samples: Number of joint samples to draw from the model.
223+
If 0, will not draw samples. Default 0.
224+
ar_sample: Whether to use autoregressive sampling. Default False.
225+
unnormalise: Whether to unnormalise the predictions. Only works if
226+
`self` has a `data_processor` and `task_loader` attribute. Default True.
227+
seed: Random seed for deterministic sampling. Default 0.
228+
append_indexes: Dictionary of index metadata to append to pandas indexes
229+
in the off-grid case. Default None.
230+
progress_bar: Whether to display a progress bar over tasks. Default 0.
231+
verbose: Whether to print time taken for prediction. Default False.
231232
232233
Returns:
233234
- If X_t is a pandas object, returns pandas objects containing off-grid predictions.
@@ -242,12 +243,25 @@ def predict(
242243
raise ValueError(
243244
"resolution_factor can only be used with on-grid predictions."
244245
)
246+
if ar_subsample_factor != 1:
247+
raise ValueError(
248+
"ar_subsample_factor can only be used with on-grid predictions."
249+
)
245250
if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
246251
if append_indexes is not None:
247252
raise ValueError(
248253
"append_indexes can only be used with off-grid predictions."
249254
)
250255

256+
if isinstance(X_t, (xr.DataArray, xr.Dataset)):
257+
mode = "on-grid"
258+
elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
259+
mode = "off-grid"
260+
else:
261+
raise ValueError(
262+
f"X_t must be an xarray, pandas or numpy object. Got {type(X_t)}."
263+
)
264+
251265
if type(tasks) is Task:
252266
tasks = [tasks]
253267

@@ -262,59 +276,78 @@ def predict(
262276
var_ID for set in self.task_loader.target_var_IDs for var_ID in set
263277
]
264278

279+
# Pre-process X_t if necessary
265280
if isinstance(X_t, pd.Index):
266281
X_t = pd.DataFrame(index=X_t)
267282
elif isinstance(X_t, np.ndarray):
268283
# Convert to empty dataframe with normalised or unnormalised coord names
269-
if X_t_normalised:
284+
if X_t_is_normalised:
270285
index_names = ["x1", "x2"]
271286
else:
272287
index_names = self.data_processor.raw_spatial_coord_names
273288
X_t = pd.DataFrame(X_t.T, columns=index_names)
274289
X_t = X_t.set_index(index_names)
290+
if mode == "off-grid" and append_indexes is not None:
291+
# Check append_indexes are all same length as X_t
292+
if append_indexes is not None:
293+
for idx, vals in append_indexes.items():
294+
if len(vals) != len(X_t):
295+
raise ValueError(
296+
f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively."
297+
)
298+
X_t = X_t.reset_index()
299+
X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
300+
X_t = X_t.set_index(list(X_t.columns))
275301

276-
if not X_t_normalised:
277-
X_t = self.data_processor.map_coords(X_t) # Normalise
302+
if X_t_is_normalised:
303+
X_t_normalised = X_t
278304

279-
if isinstance(X_t, (xr.DataArray, xr.Dataset)):
280-
mode = "on-grid"
281-
elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)):
282-
mode = "off-grid"
283-
if append_indexes is not None:
284-
# Check append_indexes are all same length as X_t
285-
if append_indexes is not None:
286-
for idx, vals in append_indexes.items():
287-
if len(vals) != len(X_t):
288-
raise ValueError(
289-
f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively."
290-
)
291-
X_t = X_t.reset_index()
292-
X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
293-
X_t = X_t.set_index(list(X_t.columns))
305+
# Unnormalise coords to use for xarray/pandas objects for storing predictions
306+
X_t = self.data_processor.map_coords(X_t, unnorm=True)
294307
else:
295-
raise ValueError(
296-
f"X_t must be an xarray object or a pandas object, not {type(X_t)}"
297-
)
308+
# Normalise coords to use for model
309+
X_t_normalised = self.data_processor.map_coords(X_t)
298310

311+
if mode == "on-grid":
312+
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)
313+
elif mode == "off-grid":
314+
X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T
315+
316+
if not unnormalise:
317+
X_t = X_t_normalised
318+
coord_names = {"x1": "x1", "x2": "x2"}
319+
elif unnormalise:
320+
coord_names = {
321+
"x1": self.data_processor.raw_spatial_coord_names[0],
322+
"x2": self.data_processor.raw_spatial_coord_names[1],
323+
}
324+
325+
# Create empty xarray/pandas objects to store predictions
299326
if mode == "on-grid":
300327
mean = create_empty_spatiotemporal_xarray(
301-
X_t, dates, resolution_factor, data_vars=target_var_IDs
328+
X_t,
329+
dates,
330+
resolution_factor,
331+
data_vars=target_var_IDs,
332+
coord_names=coord_names,
302333
).to_array(dim="data_var")
303334
std = create_empty_spatiotemporal_xarray(
304-
X_t, dates, resolution_factor, data_vars=target_var_IDs
335+
X_t,
336+
dates,
337+
resolution_factor,
338+
data_vars=target_var_IDs,
339+
coord_names=coord_names,
305340
).to_array(dim="data_var")
306341
if n_samples >= 1:
307342
samples = create_empty_spatiotemporal_xarray(
308343
X_t,
309344
dates,
310345
resolution_factor,
311346
data_vars=target_var_IDs,
347+
coord_names=coord_names,
312348
prepend_dims=["sample"],
313349
prepend_coords={"sample": np.arange(n_samples)},
314350
).to_array(dim="data_var")
315-
316-
X_t_arr = (mean["x1"].values, mean["x2"].values)
317-
318351
elif mode == "off-grid":
319352
# Repeat target locs for each date to create multiindex
320353
idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
@@ -333,7 +366,18 @@ def predict(
333366
)
334367
samples = pd.DataFrame(index=index_samples, columns=target_var_IDs)
335368

336-
X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T
369+
def unnormalise_pred_array(arr, **kwargs):
370+
var_IDs_flattened = [
371+
var_ID
372+
for var_IDs in self.task_loader.target_var_IDs
373+
for var_ID in var_IDs
374+
]
375+
assert arr.shape[0] == len(var_IDs_flattened)
376+
for i, var_ID in enumerate(var_IDs_flattened):
377+
arr[i] = self.data_processor.map_array(
378+
arr[i], var_ID, method="mean_std", unnorm=True, **kwargs
379+
)
380+
return arr
337381

338382
# Don't change tasks by reference when overriding target locations
339383
tasks = copy.deepcopy(tasks)
@@ -385,6 +429,15 @@ def predict(
385429
if n_samples >= 1:
386430
samples_arr = np.concatenate(samples_arr, axis=0)
387431

432+
if unnormalise:
433+
mean_arr = unnormalise_pred_array(mean_arr)
434+
std_arr = unnormalise_pred_array(std_arr, add_offset=False)
435+
if n_samples >= 1:
436+
for sample_i in range(n_samples):
437+
samples_arr[sample_i] = unnormalise_pred_array(
438+
samples_arr[sample_i]
439+
)
440+
388441
if mode == "on-grid":
389442
mean.loc[:, task["time"], :, :] = mean_arr
390443
std.loc[:, task["time"], :, :] = std_arr
@@ -407,16 +460,6 @@ def predict(
407460
if n_samples >= 1:
408461
samples = samples.to_dataset(dim="data_var")
409462

410-
if (
411-
self.task_loader is not None
412-
and self.data_processor is not None
413-
and unnormalise == True
414-
):
415-
mean = self.data_processor.unnormalise(mean)
416-
std = self.data_processor.unnormalise(std, add_offset=False)
417-
if n_samples >= 1:
418-
samples = self.data_processor.unnormalise(samples)
419-
420463
if verbose:
421464
dur = time.time() - tic
422465
print(f"Done in {np.floor(dur / 60)}m:{dur % 60:.0f}s.\n")

0 commit comments

Comments
 (0)