@@ -194,7 +194,7 @@ def predict(
194
194
X_t : Union [
195
195
xr .Dataset , xr .DataArray , pd .DataFrame , pd .Series , pd .Index , np .ndarray
196
196
],
197
- X_t_normalised : bool = False ,
197
+ X_t_is_normalised : bool = False ,
198
198
resolution_factor = 1 ,
199
199
n_samples = 0 ,
200
200
ar_sample = False ,
@@ -210,24 +210,25 @@ def predict(
210
210
TODO:
211
211
- Test with multiple targets model
212
212
213
- :param tasks: List of tasks containing context data.
214
- :param X_t: Target locations to predict at. Can be an xarray object containing
215
- on-grid locations or a pandas object containing off-grid locations.
216
- :param X_t_normalised: Whether the `X_t` coords are normalised.
217
- If False, will normalise the coords before passing to model. Default False.
218
- :param resolution_factor: Optional factor to increase the resolution of the
219
- target grid by. E.g. 2 will double the target resolution, 0.5 will halve it.
220
- Applies to on-grid predictions only. Default 1.
221
- :param n_samples: Number of joint samples to draw from the model.
222
- If 0, will not draw samples. Default 0.
223
- :param ar_sample: Whether to use autoregressive sampling. Default False.
224
- :param unnormalise: Whether to unnormalise the predictions. Only works if
225
- `self` has a `data_processor` and `task_loader` attribute. Default True.
226
- :param seed: Random seed for deterministic sampling. Default 0.
227
- :param append_indexes: Dictionary of index metadata to append to pandas indexes
228
- in the off-grid case. Default None.
229
- :param progress_bar: Whether to display a progress bar over tasks. Default 0.
230
- :param verbose: Whether to print time taken for prediction. Default False.
213
+ Args:
214
+ tasks: List of tasks containing context data.
215
+ X_t: Target locations to predict at. Can be an xarray object containing
216
+ on-grid locations or a pandas object containing off-grid locations.
217
+ X_t_is_normalised: Whether the `X_t` coords are normalised.
218
+ If False, will normalise the coords before passing to model. Default False.
219
+ resolution_factor: Optional factor to increase the resolution of the
220
+ target grid by. E.g. 2 will double the target resolution, 0.5 will halve it.
221
+ Applies to on-grid predictions only. Default 1.
222
+ n_samples: Number of joint samples to draw from the model.
223
+ If 0, will not draw samples. Default 0.
224
+ ar_sample: Whether to use autoregressive sampling. Default False.
225
+ unnormalise: Whether to unnormalise the predictions. Only works if
226
+ `self` has a `data_processor` and `task_loader` attribute. Default True.
227
+ seed: Random seed for deterministic sampling. Default 0.
228
+ append_indexes: Dictionary of index metadata to append to pandas indexes
229
+ in the off-grid case. Default None.
230
+ progress_bar: Whether to display a progress bar over tasks. Default 0.
231
+ verbose: Whether to print time taken for prediction. Default False.
231
232
232
233
Returns:
233
234
- If X_t is a pandas object, returns pandas objects containing off-grid predictions.
@@ -242,12 +243,25 @@ def predict(
242
243
raise ValueError (
243
244
"resolution_factor can only be used with on-grid predictions."
244
245
)
246
+ if ar_subsample_factor != 1 :
247
+ raise ValueError (
248
+ "ar_subsample_factor can only be used with on-grid predictions."
249
+ )
245
250
if not isinstance (X_t , (pd .DataFrame , pd .Series , pd .Index , np .ndarray )):
246
251
if append_indexes is not None :
247
252
raise ValueError (
248
253
"append_indexes can only be used with off-grid predictions."
249
254
)
250
255
256
+ if isinstance (X_t , (xr .DataArray , xr .Dataset )):
257
+ mode = "on-grid"
258
+ elif isinstance (X_t , (pd .DataFrame , pd .Series , pd .Index , np .ndarray )):
259
+ mode = "off-grid"
260
+ else :
261
+ raise ValueError (
262
+ f"X_t must be and xarray, pandas or numpy object. Got { type (X_t )} ."
263
+ )
264
+
251
265
if type (tasks ) is Task :
252
266
tasks = [tasks ]
253
267
@@ -262,59 +276,78 @@ def predict(
262
276
var_ID for set in self .task_loader .target_var_IDs for var_ID in set
263
277
]
264
278
279
+ # Pre-process X_t if necessary
265
280
if isinstance (X_t , pd .Index ):
266
281
X_t = pd .DataFrame (index = X_t )
267
282
elif isinstance (X_t , np .ndarray ):
268
283
# Convert to empty dataframe with normalised or unnormalised coord names
269
- if X_t_normalised :
284
+ if X_t_is_normalised :
270
285
index_names = ["x1" , "x2" ]
271
286
else :
272
287
index_names = self .data_processor .raw_spatial_coord_names
273
288
X_t = pd .DataFrame (X_t .T , columns = index_names )
274
289
X_t = X_t .set_index (index_names )
290
+ if mode == "off-grid" and append_indexes is not None :
291
+ # Check append_indexes are all same length as X_t
292
+ if append_indexes is not None :
293
+ for idx , vals in append_indexes .items ():
294
+ if len (vals ) != len (X_t ):
295
+ raise ValueError (
296
+ f"append_indexes[{ idx } ] must be same length as X_t, got { len (vals )} and { len (X_t )} respectively."
297
+ )
298
+ X_t = X_t .reset_index ()
299
+ X_t = pd .concat ([X_t , pd .DataFrame (append_indexes )], axis = 1 )
300
+ X_t = X_t .set_index (list (X_t .columns ))
275
301
276
- if not X_t_normalised :
277
- X_t = self . data_processor . map_coords ( X_t ) # Normalise
302
+ if X_t_is_normalised :
303
+ X_t_normalised = X_t
278
304
279
- if isinstance (X_t , (xr .DataArray , xr .Dataset )):
280
- mode = "on-grid"
281
- elif isinstance (X_t , (pd .DataFrame , pd .Series , pd .Index )):
282
- mode = "off-grid"
283
- if append_indexes is not None :
284
- # Check append_indexes are all same length as X_t
285
- if append_indexes is not None :
286
- for idx , vals in append_indexes .items ():
287
- if len (vals ) != len (X_t ):
288
- raise ValueError (
289
- f"append_indexes[{ idx } ] must be same length as X_t, got { len (vals )} and { len (X_t )} respectively."
290
- )
291
- X_t = X_t .reset_index ()
292
- X_t = pd .concat ([X_t , pd .DataFrame (append_indexes )], axis = 1 )
293
- X_t = X_t .set_index (list (X_t .columns ))
305
+ # Unnormalise coords to use for xarray/pandas objects for storing predictions
306
+ X_t = self .data_processor .map_coords (X_t , unnorm = True )
294
307
else :
295
- raise ValueError (
296
- f"X_t must be an xarray object or a pandas object, not { type (X_t )} "
297
- )
308
+ # Normalise coords to use for model
309
+ X_t_normalised = self .data_processor .map_coords (X_t )
298
310
311
+ if mode == "on-grid" :
312
+ X_t_arr = (X_t_normalised ["x1" ].values , X_t_normalised ["x2" ].values )
313
+ elif mode == "off-grid" :
314
+ X_t_arr = X_t_normalised .reset_index ()[["x1" , "x2" ]].values .T
315
+
316
+ if not unnormalise :
317
+ X_t = X_t_normalised
318
+ coord_names = {"x1" : "x1" , "x2" : "x2" }
319
+ elif unnormalise :
320
+ coord_names = {
321
+ "x1" : self .data_processor .raw_spatial_coord_names [0 ],
322
+ "x2" : self .data_processor .raw_spatial_coord_names [1 ],
323
+ }
324
+
325
+ # Create empty xarray/pandas objects to store predictions
299
326
if mode == "on-grid" :
300
327
mean = create_empty_spatiotemporal_xarray (
301
- X_t , dates , resolution_factor , data_vars = target_var_IDs
328
+ X_t ,
329
+ dates ,
330
+ resolution_factor ,
331
+ data_vars = target_var_IDs ,
332
+ coord_names = coord_names ,
302
333
).to_array (dim = "data_var" )
303
334
std = create_empty_spatiotemporal_xarray (
304
- X_t , dates , resolution_factor , data_vars = target_var_IDs
335
+ X_t ,
336
+ dates ,
337
+ resolution_factor ,
338
+ data_vars = target_var_IDs ,
339
+ coord_names = coord_names ,
305
340
).to_array (dim = "data_var" )
306
341
if n_samples >= 1 :
307
342
samples = create_empty_spatiotemporal_xarray (
308
343
X_t ,
309
344
dates ,
310
345
resolution_factor ,
311
346
data_vars = target_var_IDs ,
347
+ coord_names = coord_names ,
312
348
prepend_dims = ["sample" ],
313
349
prepend_coords = {"sample" : np .arange (n_samples )},
314
350
).to_array (dim = "data_var" )
315
-
316
- X_t_arr = (mean ["x1" ].values , mean ["x2" ].values )
317
-
318
351
elif mode == "off-grid" :
319
352
# Repeat target locs for each date to create multiindex
320
353
idxs = [(date , * idxs ) for date in dates for idxs in X_t .index ]
@@ -333,7 +366,18 @@ def predict(
333
366
)
334
367
samples = pd .DataFrame (index = index_samples , columns = target_var_IDs )
335
368
336
- X_t_arr = X_t .reset_index ()[["x1" , "x2" ]].values .T
369
+ def unnormalise_pred_array (arr , ** kwargs ):
370
+ var_IDs_flattened = [
371
+ var_ID
372
+ for var_IDs in self .task_loader .target_var_IDs
373
+ for var_ID in var_IDs
374
+ ]
375
+ assert arr .shape [0 ] == len (var_IDs_flattened )
376
+ for i , var_ID in enumerate (var_IDs_flattened ):
377
+ arr [i ] = self .data_processor .map_array (
378
+ arr [i ], var_ID , method = "mean_std" , unnorm = True , ** kwargs
379
+ )
380
+ return arr
337
381
338
382
# Don't change tasks by reference when overriding target locations
339
383
tasks = copy .deepcopy (tasks )
@@ -385,6 +429,15 @@ def predict(
385
429
if n_samples >= 1 :
386
430
samples_arr = np .concatenate (samples_arr , axis = 0 )
387
431
432
+ if unnormalise :
433
+ mean_arr = unnormalise_pred_array (mean_arr )
434
+ std_arr = unnormalise_pred_array (std_arr , add_offset = False )
435
+ if n_samples >= 1 :
436
+ for sample_i in range (n_samples ):
437
+ samples_arr [sample_i ] = unnormalise_pred_array (
438
+ samples_arr [sample_i ]
439
+ )
440
+
388
441
if mode == "on-grid" :
389
442
mean .loc [:, task ["time" ], :, :] = mean_arr
390
443
std .loc [:, task ["time" ], :, :] = std_arr
@@ -407,16 +460,6 @@ def predict(
407
460
if n_samples >= 1 :
408
461
samples = samples .to_dataset (dim = "data_var" )
409
462
410
- if (
411
- self .task_loader is not None
412
- and self .data_processor is not None
413
- and unnormalise == True
414
- ):
415
- mean = self .data_processor .unnormalise (mean )
416
- std = self .data_processor .unnormalise (std , add_offset = False )
417
- if n_samples >= 1 :
418
- samples = self .data_processor .unnormalise (samples )
419
-
420
463
if verbose :
421
464
dur = time .time () - tic
422
465
print (f"Done in { np .floor (dur / 60 )} m:{ dur % 60 :.0f} s.\n " )
0 commit comments