Skip to content

Commit 040e2a0

Browse files
committed
Supports extending empty tractograms
1 parent 4f5e4f3 commit 040e2a0

File tree

2 files changed

+43
-42
lines changed

2 files changed

+43
-42
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import copy
23
import unittest
34
import numpy as np
45
import warnings
@@ -264,10 +265,13 @@ def test_extend(self):
264265
assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
265266
new_data[k])
266267

267-
# Test incompatible PerArrayDicts.
268-
# Other dict is missing entries.
269-
assert_raises(ValueError, sdict.extend, PerArrayDict())
268+
# Extending with an empty PerArrayDicts should change nothing.
269+
sdict_orig = copy.deepcopy(sdict)
270+
sdict.extend(PerArrayDict())
271+
for k in sdict_orig.keys():
272+
assert_arrays_equal(sdict[k], sdict_orig[k])
270273

274+
# Test incompatible PerArrayDicts.
271275
# Other dict has more entries.
272276
new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
273277
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
@@ -356,10 +360,13 @@ def test_extend(self):
356360
assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
357361
new_data[k])
358362

359-
# Test incompatible PerArrayDicts.
360-
# Other dict is missing entries.
361-
assert_raises(ValueError, sdict.extend, PerArraySequenceDict())
363+
# Extending with an empty PerArrayDicts should change nothing.
364+
sdict_orig = copy.deepcopy(sdict)
365+
sdict.extend(PerArraySequenceDict())
366+
for k in sdict_orig.keys():
367+
assert_arrays_equal(sdict[k], sdict_orig[k])
362368

369+
# Test incompatible PerArrayDicts.
363370
# Other dict has more entries.
364371
data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
365372
"fa": DATA['fa'][0].shape[1:],
@@ -728,6 +735,16 @@ def test_tractogram_extend(self):
728735
assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
729736
assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
730737

738+
# Test extending an empty Tractogram.
739+
t = Tractogram()
740+
t += DATA['tractogram']
741+
assert_tractogram_equal(t, DATA['tractogram'])
742+
743+
# and the other way around.
744+
t = DATA['tractogram'].copy()
745+
t += Tractogram()
746+
assert_tractogram_equal(t, DATA['tractogram'])
747+
731748

732749
class TestLazyTractogram(unittest.TestCase):
733750

nibabel/streamlines/tractogram.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def __getitem__(self, key):
5959
# Key was not a valid index/slice after all.
6060
return self.store[key] # Will raise the proper error.
6161

62+
def __contains__(self, key):
63+
return key in self.store
64+
6265
def __delitem__(self, key):
6366
del self.store[key]
6467

@@ -90,7 +93,7 @@ class PerArrayDict(SliceableDataDict):
9093
Positional and keyword arguments, passed straight through the ``dict``
9194
constructor.
9295
"""
93-
def __init__(self, n_rows=None, *args, **kwargs):
96+
def __init__(self, n_rows=0, *args, **kwargs):
9497
self.n_rows = n_rows
9598
super(PerArrayDict, self).__init__(*args, **kwargs)
9699

@@ -105,13 +108,17 @@ def __setitem__(self, key, value):
105108
raise ValueError("data_per_streamline must be a 2D array.")
106109

107110
# We make sure there is the right amount of values
108-
if self.n_rows is not None and len(value) != self.n_rows:
111+
if self.n_rows > 0 and len(value) != self.n_rows:
109112
msg = ("The number of values ({0}) should match n_elements "
110113
"({1}).").format(len(value), self.n_rows)
111114
raise ValueError(msg)
112115

113116
self.store[key] = value
114117

118+
def _extend_entry(self, key, value):
119+
""" Appends the `value` to the entry specified by `key`. """
120+
self[key] = np.concatenate([self[key], value])
121+
115122
def extend(self, other):
116123
""" Appends the elements of another :class:`PerArrayDict`.
117124
@@ -131,16 +138,20 @@ def extend(self, other):
131138
-----
132139
The keys in both dictionaries must be the same.
133140
"""
134-
if sorted(self.keys()) != sorted(other.keys()):
141+
if (len(self) > 0 and len(other) > 0
142+
and sorted(self.keys()) != sorted(other.keys())):
135143
msg = ("Entry mismatched between the two PerArrayDict objects."
136144
" This PerArrayDict contains '{0}' whereas the other "
137145
" contains '{1}'.").format(sorted(self.keys()),
138146
sorted(other.keys()))
139147
raise ValueError(msg)
140148

141149
self.n_rows += other.n_rows
142-
for key in self.keys():
143-
self[key] = np.concatenate([self[key], other[key]])
150+
for key in other.keys():
151+
if key not in self:
152+
self[key] = other[key]
153+
else:
154+
self._extend_entry(key, other[key])
144155

145156

146157
class PerArraySequenceDict(PerArrayDict):
@@ -158,43 +169,16 @@ def __setitem__(self, key, value):
158169
value = ArraySequence(value)
159170

160171
# We make sure there is the right amount of data.
161-
if (self.n_rows is not None and
162-
value.total_nb_rows != self.n_rows):
172+
if self.n_rows > 0 and value.total_nb_rows != self.n_rows:
163173
msg = ("The number of values ({0}) should match "
164174
"({1}).").format(value.total_nb_rows, self.n_rows)
165175
raise ValueError(msg)
166176

167177
self.store[key] = value
168178

169-
def extend(self, other):
170-
""" Appends the elements of another :class:`PerArraySequenceDict`.
171-
172-
That is, for each entry in this dictionary, we append the elements
173-
coming from the other dictionary at the corresponding entry.
174-
175-
Parameters
176-
----------
177-
other : :class:`PerArraySequenceDict` object
178-
Its data will be appended to the data of this dictionary.
179-
180-
Returns
181-
-------
182-
None
183-
184-
Notes
185-
-----
186-
The keys in both dictionaries must be the same.
187-
"""
188-
if sorted(self.keys()) != sorted(other.keys()):
189-
msg = ("Key mismatched between the two PerArrayDict objects."
190-
" This PerArrayDict contains '{0}' whereas the other "
191-
" contains '{1}'.").format(sorted(self.keys()),
192-
sorted(other.keys()))
193-
raise ValueError(msg)
194-
195-
self.n_rows += other.n_rows
196-
for key in self.keys():
197-
self[key].extend(other[key])
179+
def _extend_entry(self, key, value):
180+
""" Appends the `value` to the entry specified by `key`. """
181+
self[key].extend(value)
198182

199183

200184
class LazyDict(collections.MutableMapping):

0 commit comments

Comments
 (0)