Skip to content

Add an 'axis' parameter to concat_images, plus two tests. #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 27, 2015
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions nibabel/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def squeeze_image(img):
img.extra)


def concat_images(images, check_affines=True):
''' Concatenate images in list to single image, along last dimension
def concat_images(images, check_affines=True, axis=None):
''' Concatenate images in list to single image, along specified dimension

Parameters
----------
Expand All @@ -98,7 +98,9 @@ def concat_images(images, check_affines=True):
check_affines : {True, False}, optional
If True, then check that all the affines for `images` are nearly
the same, raising a ``ValueError`` otherwise. Default is True

axis : int, optional
If None, concatenates on the last dimension.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be:

    axis : None or int, optional
        If None, creates an extra final dimension, and concatenates along the new dimension.

If not None, concatenates on the specified dimension.
Returns
-------
concat_img : ``SpatialImage``
Expand All @@ -122,8 +124,13 @@ def concat_images(images, check_affines=True):
if check_affines:
if not np.all(img.affine == affine):
raise ValueError('Affines do not match')
out_data[i] = img.get_data()
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
out_data[i] = img.get_data().copy()
if axis is not None:
out_data = np.concatenate(out_data, axis=axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First - I think this makes another copy of the array, which may well be large by this point.

Second - isn't this the wrong 'axis'? For example, the most common use here would be to concatenate a (i, j, k, N1) and an (i, j, k, N2) image. That would be axis=-1 or axis=3. I think the code above would first - fail if N1 != N2 in the out_data[i] = line and if N1 == N2, would fail for axis=3 because the last allowed axis in the call above would be 2 (it is iterating over the first axis generating 3D arrays). For axis=-1 (corresponding to axis == 2) it would generate something of shape (i, j, k * (N1 + N2)) which is still not what you want, I don't think.

I think the solution is to use concatenate without out_data if axis != None - so have separate code paths for the previous case and the new case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthew-brett For some reason, I had it in my head that what we actually had was a list. I'm shocked that the few tests I did worked at all.

I'll start over on this.

elif np.all([d.shape[-1] == 1 for d in out_data]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer not to have a special case for ones on the last axis - the user can always do (when fixed) axis=-1 for that case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. :)

out_data = np.concatenate(out_data, axis=d.ndim-1)
else:
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
klass = img0.__class__
return klass(out_data, affine, header)

Expand Down
17 changes: 17 additions & 0 deletions nibabel/tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def test_concat():
for img in imgs:
del(img)

# Test axis parameter and trailing unary dimension
shape_4D = np.asarray(shape + (1,))
data0 = np.arange(10).reshape(shape_4D)
affine = np.eye(4)
img0_mem = Nifti1Image(data0, affine)
img1_mem = Nifti1Image(data0 - 10, affine)

concat_img1 = concat_images([img0_mem, img1_mem])
expected_shape1 = shape_4D.copy()
expected_shape1[-1] *= 2
assert_array_equal(concat_img1.shape, expected_shape1)

concat_img2 = concat_images([img0_mem, img1_mem], axis=0)
expected_shape2 = shape_4D.copy()
expected_shape2[0] *= 2
assert_array_equal(concat_img2.shape, expected_shape2)


def test_closest_canonical():
arr = np.arange(24).reshape((2,3,4,1))
Expand Down