Skip to content

Commit b2cdf13

Browse files
committed
Fix Y_c dim when catching 0 context sampling case
1 parent 67249a5 commit b2cdf13

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

deepsensor/data/loader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,13 @@ def sample_da(
305305
elif not self.discrete_xarray_sampling:
306306
if N == 0:
307307
# Catch zero-context edge case before interp fails
308-
return np.array([[], []], dtype=self.dtype), np.array(
309-
[], dtype=self.dtype
310-
)
308+
X_c = np.zeros((2, 0), dtype=self.dtype)
309+
if isinstance(da, xr.Dataset):
310+
dim = len(da.data_vars) # Multiple data variables
311+
elif isinstance(da, xr.DataArray):
312+
dim = 1 # Single data variable
313+
Y_c = np.zeros((dim, 0), dtype=self.dtype)
314+
return X_c, Y_c
311315
x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
312316
x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
313317
Y_c = da.interp(

0 commit comments

Comments
 (0)