Skip to content

Commit 550c35f

Browse files
committed
wip add RasterIndex
Based on coordinate transform examples copied and adapted from pydata/xarray#9543.
1 parent 3c8512c commit 550c35f

File tree

1 file changed

+362
-0
lines changed

1 file changed

+362
-0
lines changed

rioxarray/raster_index.py

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
from collections.abc import Hashable, Mapping
2+
from typing import Any
3+
4+
import numpy as np
5+
from affine import Affine
6+
from xarray import DataArray, Index, Variable
7+
# TODO: import from public API once it is available
8+
from xarray.core.indexes import CoordinateTransformIndex, PandasIndex
9+
from xarray.core.indexing import IndexSelResult, merge_sel_results
10+
from xarray.core.coordinate_transform import CoordinateTransform
11+
12+
13+
class AffineTransform(CoordinateTransform):
14+
"""Affine 2D transform wrapper."""
15+
16+
affine: Affine
17+
xy_dims: tuple[str, str]
18+
19+
def __init__(
20+
self,
21+
affine: Affine,
22+
width: int,
23+
height: int,
24+
x_dim: str = "x",
25+
y_dim: str = "y",
26+
dtype: Any = np.dtype(np.float64),
27+
):
28+
super().__init__((x_dim, y_dim), {x_dim: width, y_dim: height}, dtype=dtype)
29+
self.affine = affine
30+
31+
# array dimensions in reverse order (y = rows, x = cols)
32+
self.xy_dims = self.dims[0], self.dims[1]
33+
self.dims = self.dims[1], self.dims[0]
34+
35+
def forward(self, dim_positions):
36+
positions = tuple(dim_positions[dim] for dim in self.xy_dims)
37+
x_labels, y_labels = self.affine * positions
38+
39+
results = {}
40+
for name, labels in zip(self.coord_names, [x_labels, y_labels]):
41+
results[name] = labels
42+
43+
return results
44+
45+
def reverse(self, coord_labels):
46+
labels = tuple(coord_labels[name] for name in self.coord_names)
47+
x_positions, y_positions = ~self.affine * labels
48+
49+
results = {}
50+
for dim, positions in zip(self.xy_dims, [x_positions, y_positions]):
51+
results[dim] = positions
52+
53+
return results
54+
55+
def equals(self, other):
56+
if not isinstance(other, AffineTransform):
57+
return False
58+
return self.affine == other.affine and self.dim_size == other.dim_size
59+
60+
61+
class AxisAffineTransform(CoordinateTransform):
62+
"""Axis-independent wrapper of an affine 2D transform with no skew/rotation."""
63+
64+
affine: Affine
65+
is_xaxis: bool
66+
coord_name: Hashable
67+
dim: str
68+
size: int
69+
70+
def __init__(
71+
self,
72+
affine: Affine,
73+
size: int,
74+
dim: str,
75+
is_xaxis: bool,
76+
dtype: Any = np.dtype(np.float64),
77+
):
78+
assert (affine.is_rectilinear and (affine.b == affine.d == 0))
79+
80+
super().__init__((dim), {dim: size}, dtype=dtype)
81+
self.affine = affine
82+
self.is_xaxis = is_xaxis
83+
self.coord_name = dim
84+
self.dim = dim
85+
self.size = size
86+
87+
def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
88+
positions = dim_positions[self.dim]
89+
90+
if self.is_xaxis:
91+
labels, _ = self.affine * (positions, np.zeros_like(positions))
92+
else:
93+
_, labels = self.affine * (np.zeros_like(positions), positions)
94+
95+
return {self.coord_name: labels}
96+
97+
def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
98+
labels = coord_labels[self.coord_name]
99+
100+
if self.is_xaxis:
101+
positions, _ = ~self.affine * (labels, np.zeros_like(labels))
102+
else:
103+
_, positions = ~self.affine * (np.zeros_like(labels), labels)
104+
105+
return {self.dim: positions}
106+
107+
def equals(self, other):
108+
if not isinstance(other, AxisAffineTransform):
109+
return False
110+
111+
# only compare the affine parameters of the relevant axis
112+
if self.is_xaxis:
113+
affine_match = self.affine.a == other.affine.a and self.affine.c == other.affine.c
114+
else:
115+
affine_match = self.affine.e == other.affine.e and self.affine.f == other.affine.f
116+
117+
return affine_match and self.size == other.size
118+
119+
def generate_coords(
120+
self, dims: tuple[str, ...] | None = None
121+
) -> dict[Hashable, Any]:
122+
assert dims is None or dims == self.dims
123+
return self.forward({self.dim: np.arange(self.size)})
124+
125+
def slice(self, slice: slice) -> "AxisAffineTransform":
126+
start = max(slice.start or 0, 0)
127+
stop = min(slice.stop or self.size, self.size)
128+
step = slice.step or 1
129+
130+
# TODO: support reverse transform (i.e., start > stop)?
131+
assert slice.start < slice.stop
132+
133+
size = stop - start // step
134+
scale = 1. / step
135+
136+
if self.is_xaxis:
137+
affine = self.affine * Affine.translation(start, 0.) * Affine.scale(scale, 1.)
138+
else:
139+
affine = self.affine * Affine.translation(0., start) * Affine.scale(1., scale)
140+
141+
return type(self)(affine, size, self.dim, is_xaxis=self.is_xaxis, dtype=self.dtype)
142+
143+
144+
class AxisAffineTransformIndex(CoordinateTransformIndex):
145+
"""Axis-independent Xarray Index for an affine 2D transform with no
146+
skew/rotation.
147+
148+
For internal use only.
149+
150+
This Index class provides specific behavior on top of
151+
Xarray's `CoordinateTransformIndex`:
152+
153+
- Data slicing computes a new affine transform and returns a new
154+
`AxisAffineTransformIndex` object
155+
156+
- Otherwise data selection creates and returns a new Xarray
157+
`PandasIndex` object for non-scalar indexers
158+
159+
"""
160+
axis_transform: AxisAffineTransform
161+
dim: str
162+
163+
def __init__(self, transform: AxisAffineTransform):
164+
assert isinstance(transform, AxisAffineTransform)
165+
super().__init__(transform)
166+
self.axis_transform = transform
167+
self.dim = transform.dim
168+
169+
def isel( # type: ignore[override]
170+
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
171+
) -> "AxisAffineTransformIndex | PandasIndex | None":
172+
idxer = indexers[self.dim]
173+
174+
# generate a new index with updated transform if a slice is given
175+
if isinstance(idxer, slice):
176+
return AxisAffineTransformIndex(self.axis_transform.slice(idxer))
177+
# no index for scalar value
178+
elif np.isscalar(idxer):
179+
return None
180+
# otherwise return a PandasIndex with values computed by forward transformation
181+
else:
182+
values = np.asarray(self.axis_transform.forward({self.dim: idxer}))
183+
return PandasIndex(values, self.dim, coord_dtype=values.dtype)
184+
185+
def sel(self, labels, method=None, tolerance=None):
186+
coord_name = self.axis_transform.coord_name
187+
label = labels[coord_name]
188+
189+
if isinstance(label, slice):
190+
if label.step is None:
191+
# continuous interval slice indexing (preserves the index)
192+
pos = self.transform.reverse({coord_name: np.array([label.start, label.stop])})
193+
pos = np.round(pos[self.dim]).astype("int")
194+
new_start = max(pos[0], 0)
195+
new_stop = min(pos[1], self.axis_transform.size)
196+
return IndexSelResult({self.dim: slice(new_start, new_stop)})
197+
else:
198+
# otherwise convert to basic (array) indexing
199+
label = np.arange(label.start, label.stop, label.step)
200+
201+
# support basic indexing (in the 1D case basic vs. vectorized indexing
202+
# are pretty much similar)
203+
unwrap_xr = False
204+
if not isinstance(label, Variable | DataArray):
205+
# basic indexing -> either scalar or 1-d array
206+
try:
207+
var = Variable("_", label)
208+
except ValueError:
209+
var = Variable((), label)
210+
labels = {self.dim: var}
211+
unwrap_xr = True
212+
213+
result = super().sel(labels, method=method, tolerance=tolerance)
214+
215+
if unwrap_xr:
216+
dim_indexers = {self.dim: result.dim_indexers[self.dim].values}
217+
result = IndexSelResult(dim_indexers)
218+
219+
return result
220+
221+
222+
class RectilinearAffineTransformIndex(Index):
223+
"""Xarray index for 2D rectilinear affine transform (no skew/rotation).
224+
225+
For internal use only.
226+
227+
"""
228+
def __init__(
229+
self,
230+
x_index: AxisAffineTransformIndex,
231+
y_index: AxisAffineTransformIndex,
232+
):
233+
self.x_index = x_index
234+
self.y_index = y_index
235+
236+
def sel(
237+
self, labels: dict[Any, Any], method=None, tolerance=None
238+
) -> IndexSelResult:
239+
results = []
240+
241+
for axis_index in (self.x_index, self.y_index):
242+
coord_name = axis_index.axis_transform.coord_name
243+
if coord_name in labels:
244+
results.append(axis_index.sel({coord_name: labels[coord_name]}, method=method, tolerance=tolerance))
245+
246+
return merge_sel_results(results)
247+
248+
def equals(self, other: "RectilinearAffineTransformIndex") -> bool:
249+
return self.x_index.equals(other.x_index) and self.y_index.equals(other.y_index)
250+
251+
252+
class RasterIndex(Index):
253+
"""Xarray custom index for raster coordinates."""
254+
255+
_x_index: AxisAffineTransformIndex | PandasIndex | None
256+
_y_index: AxisAffineTransformIndex | PandasIndex | None
257+
_xy_index: CoordinateTransformIndex | None
258+
259+
def __init__(
260+
self,
261+
x_index: AxisAffineTransformIndex | PandasIndex | None = None,
262+
y_index: AxisAffineTransformIndex | PandasIndex | None = None,
263+
xy_index: CoordinateTransformIndex | None = None,
264+
):
265+
# must at least have one index passed
266+
assert any(idx is not None for idx in (x_index, y_index, xy_index))
267+
# either 1D x/y coordinates with x_index/y_index or 2D x/y coordinates with xy_index
268+
if xy_index is not None:
269+
assert x_index is None and y_index is None
270+
271+
self._x_index = x_index
272+
self._y_index = y_index
273+
self._xy_index = xy_index
274+
275+
def _get_subindexes(self) -> tuple[Index | None, ...]:
276+
return (self._xy_index, self._x_index, self._y_index)
277+
278+
@classmethod
279+
def from_transform(cls, affine: Affine, width: int, height: int, x_dim: str = "x", y_dim: str = "y") -> "RasterIndex":
280+
if affine.is_rectilinear and affine.b == affine.d == 0:
281+
x_transform = AxisAffineTransform(affine, width, x_dim, is_xaxis=True)
282+
y_transform = AxisAffineTransform(affine, height, y_dim, is_xaxis=False)
283+
return cls(
284+
x_index=AxisAffineTransformIndex(x_transform),
285+
y_index=AxisAffineTransformIndex(y_transform),
286+
)
287+
else:
288+
xy_transform = AffineTransform(affine, width, height, x_dim=x_dim, y_dim=y_dim)
289+
return cls(xy_index=CoordinateTransformIndex(xy_transform))
290+
291+
@classmethod
292+
def from_variables(
293+
cls,
294+
variables: Mapping[Any, Variable],
295+
*,
296+
options: Mapping[str, Any],
297+
) -> "RasterIndex":
298+
# TODO: compute bounds, resolution and affine transform from explicit coordinates.
299+
raise NotImplementedError(
300+
"Creating a RasterIndex from existing coordinates is not yet supported."
301+
)
302+
303+
def create_variables(
304+
self, variables: Mapping[Any, Variable] | None = None
305+
) -> dict[Hashable, Variable]:
306+
new_variables: dict[Hashable, Variable] = {}
307+
308+
for index in (self._x_index, self._y_index, self._xy_index):
309+
if index is not None:
310+
new_variables.update(index.create_variables())
311+
312+
return new_variables
313+
314+
def isel(
315+
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
316+
) -> "RasterIndex | None":
317+
indexes: dict[str, Any] = {}
318+
319+
if self._xy_index is not None:
320+
indexes["xy_index"] = self._xy_index.isel(indexers)
321+
322+
if self._x_index is not None and self._x_index.dim in indexers:
323+
dim = self._x_index.dim
324+
indexes["x_index"] = self._x_index.isel(indexers={dim: indexers[dim]})
325+
326+
if self._y_index is not None and self._y_index.dim in indexers:
327+
dim = self._y_index.dim
328+
indexes["x_index"] = self._y_index.isel(indexers={dim: indexers[dim]})
329+
330+
if any(idx is not None for idx in indexes.values()):
331+
return RasterIndex(**indexes)
332+
else:
333+
return None
334+
335+
def sel(
336+
self, labels: dict[Any, Any], method=None, tolerance=None
337+
) -> IndexSelResult:
338+
results = []
339+
340+
if self._xy_index is not None:
341+
results.append(self._xy_index.sel(labels, method=method, tolerance=tolerance))
342+
343+
if self._x_index is not None and self._x_index.dim in labels:
344+
dim = self._x_index.dim
345+
results.append(self._x_index.sel(labels={dim: labels[dim]}, method=method, tolerance=tolerance))
346+
347+
if self._y_index is not None and self._y_index.dim in labels:
348+
dim = self._y_index.dim
349+
results.append(self._y_index.sel(labels={dim: labels[dim]}, method=method, tolerance=tolerance))
350+
351+
return merge_sel_results(results)
352+
353+
def equals(self, other: "RasterIndex") -> bool:
354+
if not isinstance(other, RasterIndex):
355+
return False
356+
357+
for (idx, oidx) in zip(self._get_subindexes(), other._get_subindexes()):
358+
if idx is not None and not idx.equals(oidx)
359+
if self._xy_index is not None and not self._xy_index.equals(other._xy_index):
360+
return False
361+
362+
return self.x_index.equals(other.x_index) and self.y_index.equals(other.y_index)

0 commit comments

Comments
 (0)