-
Notifications
You must be signed in to change notification settings - Fork 262
ENH: Support + and += operators for Tractogram #495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Current coverage is 94.02% (diff: 94.70%)@@ master #495 diff @@
==========================================
Files 166 166
Lines 21832 21992 +160
Methods 0 0
Messages 0 0
Branches 2325 2343 +18
==========================================
+ Hits 20527 20679 +152
- Misses 875 878 +3
- Partials 430 435 +5
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry to be slow to review - see comments.
other : :class:`PerArrayDict` object | ||
Its data will be appended to the data of this dictionary. | ||
|
||
Notes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add the method returns None.
|
||
Notes | ||
----- | ||
The entries in both dictionaries must match. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More specifically, the keys in each dictionary must be the same.
other : :class:`PerArraySequenceDict` object | ||
Its data will be appended to the data of this dictionary. | ||
|
||
Notes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returns None, keys must match.
@@ -136,6 +162,32 @@ def __setitem__(self, key, value): | |||
|
|||
self.store[key] = value | |||
|
|||
def extend(self, other): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't you just inherit this method? I guess you'd have make the docstrings and message a bit more generic, but it seems a shame to duplicate the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to duplicate it (almost every line except one) because one method is using self[key] = np.concatenate([self[key], other[key]])
and the other self[key].extend(other[key])
. If you think of an alternative, I'm all ears.
other : :class:`Tractogram` object | ||
Its data will be appended to the data of this tractogram. | ||
|
||
Notes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returns None, keys much match.
t = DATA['tractogram'].copy() | ||
|
||
# Double the tractogram. | ||
new_t = t + t |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could use:
def extender(a, b):
a.extend(b)
return a
import operator
for op, in_place in ((operator.add, False, (operator.iadd, True), (extender, True)):
first_arg = copy(t)
new_t = op(first_arg, t)
assert_equal(new_t is first_arg, in_place)
# etc
def test_extend(self): | ||
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows | ||
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) | ||
sdict2 = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be nice to check for situation where data was not the same.
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v) | ||
|
||
# Test incompatible PerArrayDicts. | ||
assert_raises(ValueError, sdict.extend, PerArraySequenceDict()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check for situation where there are extra keys in one or other? Other than empty case here?
@@ -233,6 +248,20 @@ def test_getitem(self): | |||
assert_arrays_equal(sdict[-1][k], v[-1]) | |||
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) | |||
|
|||
def test_extend(self): | |||
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows | |||
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid duplication by inheriting this test class from TestPerArrayDict
, and this in TestPerArrayDict
?
tested_cls = PerArraySequenceDict
def test_extend(self):
sdict = self.tested_cls(total_nb_rows, DATA['data_per_point'])
# etc
Or a mixin with just this method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure it will work as you intend or maybe I don't get your point. There is a couple of differences between testing a PerArrayDict
and a PerArraySequenceDict
. For instance, the constructor of the first class takes the number of streamlines whereas the second takes the total number of points in a ArraySequence
. Also, one checks data_per_streamline
where the other data_per_point
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I got it to work :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the complications now. What do you mean by "I got it to work" ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is related to my previous comment. I meant I succeeded in reducing code duplication.
@@ -181,6 +181,21 @@ def test_getitem(self): | |||
assert_arrays_equal(sdict[-1][k], v[-1]) | |||
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) | |||
|
|||
def test_extend(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test for +
and +=
? See below for general suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By choice, PerArrayDict
and PerArraySequenceDict
don't support +
and +=
. Only Tractogram
objects have it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for my slow understanding, but why no +
and +=
? Just because native dict
objects don't support these? In which case, why extend
which doesn't exist for dict
either?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't see the need for them att the time. I don't think these dicts are going to be used extensively but I can add them if you want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's fine, just trying to work out how you were thinking of these.
@matthew-brett thanks for the feedback. I addressed most of your comments except those related to code duplication (see my replies above). |
040e2a0
to
5d98758
Compare
@matthew-brett this PR is ready for a second round of reviews. |
Thanks for your patience, I should be able to get to this on Monday. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some small comments
'mean_torsion': mean_torsion_func, | ||
'mean_colors': mean_colors_func} | ||
DATA['data_per_point_func'] = { | ||
'colors': lambda: (e for e in DATA['colors']), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indentation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found the code cleaner when I break the line this way compared to in the middle of the generator comprehension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure - I was wondering if the indentation you got here was PEP8 compatible - fine if so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is passing the flake8 test of one of the Travis bots :).
'colors': lambda: (e for e in DATA['colors']), | ||
'fa': lambda: (e for e in DATA['fa'])} | ||
DATA['data_per_streamline_func'] = { | ||
'mean_curvature': lambda: (e for e in DATA['mean_curvature']), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indentation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found the code cleaner when I break the line this way compared to in the middle of the generator comprehension.
@@ -181,6 +181,21 @@ def test_getitem(self): | |||
assert_arrays_equal(sdict[-1][k], v[-1]) | |||
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) | |||
|
|||
def test_extend(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for my slow understanding, but why no +
and +=
? Just because native dict
objects don't support these? In which case, why extend
which doesn't exist for dict
either?
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], | ||
new_data[k]) | ||
|
||
# Extending with an empty PerArrayDicts should change nothing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be PerArrayDict
(no 's' at end)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) | ||
assert_raises(ValueError, sdict.extend, sdict2) | ||
|
||
# Other dict has the right number of entries but wrong shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the check is, that the keys must be the same. Is that what you mean by "shape" here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By entries, I mean keys. By shape, I mean the shape (except for the first dimension) of the ndarray or ArraySequence that will be appended to the the value at dict[k]
where k is one of the entry. This is because we know these dict are dictionaries of ndarray or ArraySequence.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hum - but isn't the error in fact coming from the fact that mean_color
!= other
, rather than the shape difference? Maybe you need two tests here, one for the keys and one for the shapes, where the shapes test has an entry with the same name, but a different shape.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes,you are right. I'll make another test.
@@ -233,6 +248,20 @@ def test_getitem(self): | |||
assert_arrays_equal(sdict[-1][k], v[-1]) | |||
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) | |||
|
|||
def test_extend(self): | |||
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows | |||
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the complications now. What do you mean by "I got it to work" ?
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows | ||
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) | ||
|
||
# Test compatible PerArrayDicts. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PerArraySequenceDicts
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], | ||
new_data[k]) | ||
|
||
# Extending with an empty PerArrayDicts should change nothing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PerArraySequenceDict
sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data) | ||
assert_raises(ValueError, sdict.extend, sdict2) | ||
|
||
# Other dict has the right number of entries but wrong shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrong keys? (As above).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep.
|
||
for op, in_place in ((operator.add, False), (operator.iadd, True), | ||
(extender, True)): | ||
first_arg = t.copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs deepcopy
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you asking if we need a deepcopy or you are suggesting me to use deepcopy?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, asking if you need deepcopy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, because .copy()
is doing a deepcopy (https://github.com/nipy/nibabel/blob/master/nibabel/streamlines/tractogram.py#L346).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Response to replies.
'mean_torsion': mean_torsion_func, | ||
'mean_colors': mean_colors_func} | ||
DATA['data_per_point_func'] = { | ||
'colors': lambda: (e for e in DATA['colors']), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure - I was wondering if the indentation you got here was PEP8 compatible - fine if so.
@@ -181,6 +181,21 @@ def test_getitem(self): | |||
assert_arrays_equal(sdict[-1][k], v[-1]) | |||
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) | |||
|
|||
def test_extend(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's fine, just trying to work out how you were thinking of these.
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) | ||
assert_raises(ValueError, sdict.extend, sdict2) | ||
|
||
# Other dict has the right number of entries but wrong shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hum - but isn't the error in fact coming from the fact that mean_color
!= other
, rather than the shape difference? Maybe you need two tests here, one for the keys and one for the shapes, where the shapes test has an entry with the same name, but a different shape.
@matthew-brett should be ready to be merged if you don't have any additional comments. |
Thanks for the edits - and sorry about the wait. |
No worry. Thanks for the review. |
This PR adds the functionality of concatenating two Tractogram objects using either
tractogram += other_tractogram
ortractogram = tractogram1 + tractogram2
.This will definitively interest @jchoude, @arnaudbore, @Garyfallidis, @FrancoisRheaultUS and many others.