Skip to content

Add an 'axis' parameter to concat_images, plus two tests. #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 27, 2015
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions nibabel/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,43 +88,81 @@ def squeeze_image(img):
img.extra)


def concat_images(images, check_affines=True):
''' Concatenate images in list to single image, along last dimension
def _shape_equal_excluding(shape1, shape2, exclude_axes):
""" Helper function to compare two array shapes, excluding any
axis specified."""

if len(shape1) != len(shape2):
return False

idx_mask = np.ones((len(shape1),), dtype=bool)
idx_mask[exclude_axes] = False
return np.array_equal(np.asarray(shape1)[idx_mask],
np.asarray(shape2)[idx_mask])


def concat_images(images, check_affines=True, axis=None):
''' Concatenate images in list to single image, along specified dimension

Parameters
----------
images : sequence
sequence of ``SpatialImage`` or of filenames\s
sequence of ``SpatialImage`` or filenames of the same dimensionality\s
check_affines : {True, False}, optional
If True, then check that all the affines for `images` are nearly
the same, raising a ``ValueError`` otherwise. Default is True

axis : None or int, optional
If None, concatenates on a new dimension. This rrequires all images
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo rrequires

to be the same shape).
If not None, concatenates on the specified dimension. This requires
all images to be the same shape, except on the specified dimension.
For 4D images, axis must be between -2 and 3.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this last sentence necessary - I mean - won't the concatenate raise an informative error in cases outside this range?

Returns
-------
concat_img : ``SpatialImage``
New image resulting from concatenating `images` across last
dimension
'''

n_imgs = len(images)
img0 = images[0]
is_filename = False
if not hasattr(img0, 'get_data'):
img0 = load(img0)
is_filename = True
i0shape = img0.shape
affine = img0.affine
header = img0.header
out_shape = (n_imgs, ) + i0shape
out_data = np.empty(out_shape)
if n_imgs == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop this check? I guess if they pass in an empty list they can expect an empty list back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the past, and currently, this throws an error. I added the check because the error did not indicate the issue clearly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.

raise ValueError("Cannot concatenate an empty list of images.")

for i, img in enumerate(images):
if is_filename:
if not hasattr(img, 'get_data'):
img = load(img)
if check_affines:
if not np.all(img.affine == affine):
raise ValueError('Affines do not match')

if i == 0: # first image, initialize data from loaded image
affine = img.affine
header = img.header
shape = img.shape
klass = img.__class__

if axis is None: # collect images in output array for efficiency
out_shape = (n_imgs, ) + shape
out_data = np.empty(out_shape)
else: # collect images in list for use with np.concatenate
out_data = [None] * n_imgs

elif check_affines and not np.all(img.affine == affine):
raise ValueError('Affines do not match')

elif ((axis is None and not np.array_equal(shape, img.shape)) or
(axis is not None and not _shape_equal_excluding(shape, img.shape,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry - this is probably a dumb question, but won't np.concatenate error in this second situation anyway?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will error, but the error message is pretty cryptic. I added this logic in after seeing the error.

exclude_axes=[axis]))):
# shape mismatch; numpy broadcast / concatenate can hide these.
raise ValueError("Image #%d (shape=%s) does not match the first "
"image shape (%s)." % (i, shape, img.shape))

out_data[i] = img.get_data()
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
klass = img0.__class__

del img

if axis is None:
out_data = np.rollaxis(out_data, 0, out_data.ndim)
else:
out_data = np.concatenate(out_data, axis=axis)

return klass(out_data, affine, header)


Expand Down
116 changes: 89 additions & 27 deletions nibabel/tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,96 @@ def _as_fname(img):


def test_concat():
shape = (1,2,5)
data0 = np.arange(10).reshape(shape)
# Smoke test: concat empty list.
assert_raises(ValueError, concat_images, [])

# Build combinations of 3D, 4D w/size[3] == 1, and 4D w/size[3] == 3
all_shapes_5D = ((1, 4, 5, 3, 3),
(7, 3, 1, 4, 5),
(0, 2, 1, 4, 5))

affine = np.eye(4)
img0_mem = Nifti1Image(data0, affine)
data1 = data0 - 10
img1_mem = Nifti1Image(data1, affine)
img2_mem = Nifti1Image(data1, affine+1)
img3_mem = Nifti1Image(data1.T, affine)
all_data = np.concatenate(
[data0[:,:,:,np.newaxis],data1[:,:,:,np.newaxis]],3)
# Check filenames and in-memory images work
with InTemporaryDirectory():
imgs = [img0_mem, img1_mem, img2_mem, img3_mem]
img_files = [_as_fname(img) for img in imgs]
for img0, img1, img2, img3 in (imgs, img_files):
all_imgs = concat_images([img0, img1])
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)
# check that not-matching affines raise error
assert_raises(ValueError, concat_images, [img0, img2])
assert_raises(ValueError, concat_images, [img0, img3])
# except if check_affines is False
all_imgs = concat_images([img0, img1])
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)
# Delete images as prophylaxis for windows access errors
for img in imgs:
del(img)
for dim in range(2, 6):
all_shapes_ND = tuple((shape[:dim] for shape in all_shapes_5D))
all_shapes_N1D_unary = tuple((shape + (1,) for shape in all_shapes_ND))
all_shapes = all_shapes_ND + all_shapes_N1D_unary

# Loop over all possible combinations of images, in first and
# second position.
for data0_shape in all_shapes:
data0_numel = np.asarray(data0_shape).prod()
data0 = np.arange(data0_numel).reshape(data0_shape)
img0_mem = Nifti1Image(data0, affine)

for data1_shape in all_shapes:
data1_numel = np.asarray(data1_shape).prod()
data1 = np.arange(data1_numel).reshape(data1_shape)
img1_mem = Nifti1Image(data1, affine)
img2_mem = Nifti1Image(data1, affine+1) # bad affine

# Loop over every possible axis, including None (explicit and implied)
for axis in (list(range(-(dim-2), (dim-1))) + [None, '__default__']):

# Allow testing default vs. passing explicit param
if axis == '__default__':
np_concat_kwargs = dict(axis=-1)
concat_imgs_kwargs = dict()
axis = None # Convert downstream
elif axis is None:
np_concat_kwargs = dict(axis=-1)
concat_imgs_kwargs = dict(axis=axis)
else:
np_concat_kwargs = dict(axis=axis)
concat_imgs_kwargs = dict(axis=axis)

# Create expected output
try:
# Error will be thrown if the np.concatenate fails.
# However, when axis=None, the concatenate is possible
# but our efficient logic (where all images are
# 3D and the same size) fails, so we also
# have to expect errors for those.
expect_error = data0.ndim != data1.ndim
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this outside try ... except ? Do you need the try .. except if expect_error is True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally try to keep indent levels to be a logical block, so sometimes I include things inside a try...except that may not throw, so that the logical block is maintained. Here, the computation of expect_error is all at the same indent level, which I find to be semantically helpful.

With that said, I'm happy to move it out; doing that now!

if axis is None: # 3D from here and below
all_data = np.concatenate([data0[..., np.newaxis],
data1[..., np.newaxis]],
**np_concat_kwargs)
else: # both 3D, appending on final axis
all_data = np.concatenate([data0, data1],
**np_concat_kwargs)
except ValueError:
# Shapes are not combinable
expect_error = True

# Check filenames and in-memory images work
with InTemporaryDirectory():
# Try mem-based, file-based, and mixed
imgs = [img0_mem, img1_mem, img2_mem]
img_files = [_as_fname(img) for img in imgs]
imgs_mixed = [imgs[0], img_files[1], imgs[2]]
for img0, img1, img2 in (imgs, img_files, imgs_mixed):
try:
all_imgs = concat_images([img0, img1],
**concat_imgs_kwargs)
except ValueError as ve:
assert_true(expect_error, str(ve))
else:
assert_false(expect_error, "Expected a concatenation error, but got none.")
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)

# check that not-matching affines raise error
assert_raises(ValueError, concat_images, [img0, img2], **concat_imgs_kwargs)

# except if check_affines is False
try:
all_imgs = concat_images([img0, img1], **concat_imgs_kwargs)
except ValueError as ve:
assert_true(expect_error, str(ve))
else:
assert_false(expect_error, "Expected a concatenation error, but got none.")
assert_array_equal(all_imgs.get_data(), all_data)
assert_array_equal(all_imgs.affine, affine)


def test_closest_canonical():
Expand Down