diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 76f06dff0e..d21e88e0f3 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -18,6 +18,7 @@ def setup(): global DATA + DATA['rng'] = np.random.RandomState(1234) DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] @@ -373,6 +374,13 @@ def test_tractogram_getitem(self): DATA['tractogram'].data_per_streamline[::-1], DATA['tractogram'].data_per_point[::-1]) + # Make sure slicing conserves the affine_to_rasmm property. + tractogram = DATA['tractogram'].copy() + tractogram.affine_to_rasmm = DATA['rng'].rand(4, 4) + tractogram_view = tractogram[::2] + assert_array_equal(tractogram_view.affine_to_rasmm, + tractogram.affine_to_rasmm) + def test_tractogram_add_new_data(self): # Tractogram with only streamlines t = DATA['simple_tractogram'].copy() diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index c33f707d1c..b386f27fa5 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -335,7 +335,8 @@ def __getitem__(self, idx): if isinstance(idx, (numbers.Integral, np.integer)): return TractogramItem(pts, data_per_streamline, data_per_point) - return Tractogram(pts, data_per_streamline, data_per_point) + return Tractogram(pts, data_per_streamline, data_per_point, + affine_to_rasmm=self.affine_to_rasmm) def __len__(self): return len(self.streamlines)