Skip to content

Commit 663ade6

Browse files
committed
Merge pull request #554 from shoyer/real-and-imag
Fixes for complex numbers
2 parents 04f4e88 + 0531c37 commit 663ade6

10 files changed

+82
-4
lines changed

doc/api-hidden.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
Dataset.clip
4141
Dataset.conj
4242
Dataset.conjugate
43+
Dataset.imag
4344
Dataset.round
45+
Dataset.real
4446
Dataset.T
4547

4648
DataArray.ndim
@@ -80,8 +82,10 @@
8082
DataArray.clip
8183
DataArray.conj
8284
DataArray.conjugate
85+
DataArray.imag
8386
DataArray.searchsorted
8487
DataArray.round
88+
DataArray.real
8589
DataArray.T
8690

8791
ufuncs.angle

doc/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,9 @@ Computation
139139
:py:attr:`~Dataset.clip`
140140
:py:attr:`~Dataset.conj`
141141
:py:attr:`~Dataset.conjugate`
142+
:py:attr:`~Dataset.imag`
142143
:py:attr:`~Dataset.round`
144+
:py:attr:`~Dataset.real`
143145
:py:attr:`~Dataset.T`
144146

145147
**Grouped operations**:
@@ -253,8 +255,10 @@ Computation
253255
:py:attr:`~DataArray.clip`
254256
:py:attr:`~DataArray.conj`
255257
:py:attr:`~DataArray.conjugate`
258+
:py:attr:`~DataArray.imag`
256259
:py:attr:`~DataArray.searchsorted`
257260
:py:attr:`~DataArray.round`
261+
:py:attr:`~DataArray.real`
258262
:py:attr:`~DataArray.T`
259263

260264
**Grouped operations**:

doc/whats-new.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,19 @@ API Changes
1919
:py:meth:`~xray.DataArray.plot` was changed to provide more compatibility
2020
with matplotlib's `contour` and `contourf` functions (:issue:`538`).
2121
Now discrete lists of colors should be specified using `colors` keyword,
22-
rather than `cmap`.
22+
rather than `cmap`.
23+
24+
Enhancements
25+
~~~~~~~~~~~~
26+
27+
- Add :py:attr:`~xray.Dataset.real` and :py:attr:`~xray.Dataset.imag`
28+
attributes to Dataset and DataArray (:issue:`553`).
29+
30+
Bug fixes
31+
~~~~~~~~~
32+
33+
- Aggregation functions now correctly skip ``NaN`` for data for ``complex128``
34+
dtype (:issue:`554`).
2335

2436
v0.6.0 (21 August 2015)
2537
-----------------------

xray/core/dataarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,5 +1171,13 @@ def diff(self, dim, n=1, label='upper'):
11711171
ds = self._dataset.diff(n=n, dim=dim, label=label)
11721172
return self._with_replaced_dataset(ds)
11731173

1174+
@property
1175+
def real(self):
1176+
return self._with_replaced_dataset(self._dataset.real)
1177+
1178+
@property
1179+
def imag(self):
1180+
return self._with_replaced_dataset(self._dataset.imag)
1181+
11741182
# priority most be higher than Variable to properly work with binary ufuncs
11751183
ops.inject_all_ops_and_reduce_methods(DataArray, priority=60)

xray/core/dataset.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1855,12 +1855,14 @@ def from_dataframe(cls, dataframe):
18551855
return obj
18561856

18571857
@staticmethod
1858-
def _unary_op(f):
1858+
def _unary_op(f, keep_attrs=False):
18591859
@functools.wraps(f)
18601860
def func(self, *args, **kwargs):
18611861
ds = self.coords.to_dataset()
18621862
for k in self.data_vars:
18631863
ds._variables[k] = f(self._variables[k], *args, **kwargs)
1864+
if keep_attrs:
1865+
ds._attrs = self._attrs
18641866
return ds
18651867
return func
18661868

@@ -2019,5 +2021,13 @@ def diff(self, dim, n=1, label='upper'):
20192021
else:
20202022
return difference
20212023

2024+
@property
2025+
def real(self):
2026+
return self._unary_op(lambda x: x.real, keep_attrs=True)(self)
2027+
2028+
@property
2029+
def imag(self):
2030+
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)
2031+
20222032

20232033
ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)

xray/core/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ def f(values, axis=None, skipna=None, **kwargs):
292292
if coerce_strings and values.dtype.kind in 'SU':
293293
values = values.astype(object)
294294

295-
if skipna or (skipna is None and values.dtype.kind == 'f'):
296-
if values.dtype.kind not in ['i', 'f']:
295+
if skipna or (skipna is None and values.dtype.kind in 'cf'):
296+
if values.dtype.kind not in ['i', 'f', 'c']:
297297
raise NotImplementedError(
298298
'skipna=True not yet implemented for %s with dtype %s'
299299
% (name, values.dtype))

xray/core/variable.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,14 @@ def identical(self, other):
751751
except (TypeError, AttributeError):
752752
return False
753753

754+
@property
755+
def real(self):
756+
return type(self)(self.dims, self.data.real, self._attrs)
757+
758+
@property
759+
def imag(self):
760+
return type(self)(self.dims, self.data.imag, self._attrs)
761+
754762
def __array_wrap__(self, obj, context=None):
755763
return Variable(self.dims, obj)
756764

xray/test/test_dataarray.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,3 +1460,8 @@ def test_dataarray_diff_n1(self):
14601460
[da['x'].values, da['y'].values[1:]],
14611461
['x', 'y'])
14621462
self.assertDataArrayEqual(expected, actual)
1463+
1464+
def test_real_and_imag(self):
1465+
array = DataArray(1 + 2j)
1466+
self.assertDataArrayIdentical(array.real, DataArray(1))
1467+
self.assertDataArrayIdentical(array.imag, DataArray(2))

xray/test/test_dataset.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,3 +2115,13 @@ def test_dataset_diff_exception_label_str(self):
21152115
ds = create_test_data(seed=1)
21162116
with self.assertRaisesRegexp(ValueError, '\'label\' argument has to'):
21172117
ds.diff('dim2', label='raise_me')
2118+
2119+
def test_real_and_imag(self):
2120+
attrs = {'foo': 'bar'}
2121+
ds = Dataset({'x': ((), 1 + 2j, attrs)}, attrs=attrs)
2122+
2123+
expected_re = Dataset({'x': ((), 1, attrs)}, attrs=attrs)
2124+
self.assertDatasetIdentical(ds.real, expected_re)
2125+
2126+
expected_im = Dataset({'x': ((), 2, attrs)}, attrs=attrs)
2127+
self.assertDatasetIdentical(ds.imag, expected_im)

xray/test/test_variable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,23 @@ def test_copy(self):
371371
source_ndarray(w.values))
372372
self.assertVariableIdentical(v, copy(v))
373373

374+
def test_real_and_imag(self):
375+
v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 'bar'})
376+
expected_re = self.cls('x', np.arange(3), {'foo': 'bar'})
377+
self.assertVariableIdentical(v.real, expected_re)
378+
379+
expected_im = self.cls('x', -np.arange(3), {'foo': 'bar'})
380+
self.assertVariableIdentical(v.imag, expected_im)
381+
382+
expected_abs = self.cls('x', np.sqrt(2 * np.arange(3) ** 2))
383+
self.assertVariableAllClose(abs(v), expected_abs)
384+
385+
def test_aggregate_complex(self):
386+
# should skip NaNs
387+
v = self.cls('x', [1, 2j, np.nan])
388+
expected = Variable((), 0.5 + 1j)
389+
self.assertVariableAllClose(v.mean(), expected)
390+
374391

375392
class TestVariable(TestCase, VariableSubclassTestCases):
376393
cls = staticmethod(Variable)

0 commit comments

Comments
 (0)