Skip to content

Commit 8f8fd5a

Browse files
committed
Addressed @matthew-brett's comments
1 parent fc51e76 commit 8f8fd5a

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_extend(self):
265265
assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
266266
new_data[k])
267267

268-
# Extending with an empty PerArrayDicts should change nothing.
268+
# Extending with an empty PerArrayDict should change nothing.
269269
sdict_orig = copy.deepcopy(sdict)
270270
sdict.extend(PerArrayDict())
271271
for k in sdict_orig.keys():
@@ -280,10 +280,17 @@ def test_extend(self):
280280
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
281281
assert_raises(ValueError, sdict.extend, sdict2)
282282

283+
# Other dict has not the same entries (key mistmached).
284+
new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
285+
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
286+
'other': 4 * np.array(DATA['mean_colors'])}
287+
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
288+
assert_raises(ValueError, sdict.extend, sdict2)
289+
283290
# Other dict has the right number of entries but wrong shape.
284291
new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
285292
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
286-
'other': 4 * np.array(DATA['mean_torsion'])}
293+
'mean_colors': 4 * np.array(DATA['mean_torsion'])}
287294
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
288295
assert_raises(ValueError, sdict.extend, sdict2)
289296

@@ -343,7 +350,7 @@ def test_extend(self):
343350
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
344351
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
345352

346-
# Test compatible PerArrayDicts.
353+
# Test compatible PerArraySequenceDicts.
347354
list_nb_points = [2, 7, 4]
348355
data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
349356
"fa": DATA['fa'][0].shape[1:]}
@@ -360,13 +367,13 @@ def test_extend(self):
360367
assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
361368
new_data[k])
362369

363-
# Extending with an empty PerArrayDicts should change nothing.
370+
# Extending with an empty PerArraySequenceDicts should change nothing.
364371
sdict_orig = copy.deepcopy(sdict)
365372
sdict.extend(PerArraySequenceDict())
366373
for k in sdict_orig.keys():
367374
assert_arrays_equal(sdict[k], sdict_orig[k])
368375

369-
# Test incompatible PerArrayDicts.
376+
# Test incompatible PerArraySequenceDicts.
370377
# Other dict has more entries.
371378
data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
372379
"fa": DATA['fa'][0].shape[1:],
@@ -377,6 +384,15 @@ def test_extend(self):
377384
sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
378385
assert_raises(ValueError, sdict.extend, sdict2)
379386

387+
# Other dict has not the same entries (key mistmached).
388+
data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
389+
"other": DATA['fa'][0].shape[1:]}
390+
_, new_data, _ = make_fake_tractogram(list_nb_points,
391+
data_per_point_shapes,
392+
rng=DATA['rng'])
393+
sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
394+
assert_raises(ValueError, sdict.extend, sdict2)
395+
380396
# Other dict has the right number of entries but wrong shape.
381397
data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
382398
"fa": DATA['fa'][0].shape[1:] + (3,)}

0 commit comments

Comments
 (0)