diff --git a/nibabel/funcs.py b/nibabel/funcs.py index 645fe09b2b..fc69b6c780 100644 --- a/nibabel/funcs.py +++ b/nibabel/funcs.py @@ -88,43 +88,70 @@ def squeeze_image(img): img.extra) -def concat_images(images, check_affines=True): - ''' Concatenate images in list to single image, along last dimension +def concat_images(images, check_affines=True, axis=None): + ''' Concatenate images in list to single image, along specified dimension Parameters ---------- images : sequence - sequence of ``SpatialImage`` or of filenames\s + sequence of ``SpatialImage`` or filenames of the same dimensionality\s check_affines : {True, False}, optional If True, then check that all the affines for `images` are nearly the same, raising a ``ValueError`` otherwise. Default is True - + axis : None or int, optional + If None, concatenates on a new dimension. This requires all images to + be the same shape. If not None, concatenates on the specified + dimension. This requires all images to be the same shape, except on + the specified dimension. Returns ------- concat_img : ``SpatialImage`` New image resulting from concatenating `images` across last dimension ''' + images = [load(img) if not hasattr(img, 'get_data') + else img for img in images] n_imgs = len(images) + if n_imgs == 0: + raise ValueError("Cannot concatenate an empty list of images.") img0 = images[0] - is_filename = False - if not hasattr(img0, 'get_data'): - img0 = load(img0) - is_filename = True - i0shape = img0.shape affine = img0.affine header = img0.header - out_shape = (n_imgs, ) + i0shape - out_data = np.empty(out_shape) - for i, img in enumerate(images): - if is_filename: - img = load(img) - if check_affines: - if not np.all(img.affine == affine): - raise ValueError('Affines do not match') - out_data[i] = img.get_data() - out_data = np.rollaxis(out_data, 0, len(i0shape)+1) klass = img0.__class__ + shape0 = img0.shape + n_dim = len(shape0) + if axis is None: + # collect images in output array for efficiency + out_shape = (n_imgs, ) + shape0 + out_data = np.empty(out_shape) + else: + # collect images in list for use with np.concatenate + out_data = [None] * n_imgs + # Get part of shape we need to check inside loop + idx_mask = np.ones((n_dim,), dtype=bool) + if axis is not None: + idx_mask[axis] = False + masked_shape = np.array(shape0)[idx_mask] + for i, img in enumerate(images): + if len(img.shape) != n_dim: + raise ValueError( + 'Image {0} has {1} dimensions, image 0 has {2}'.format( + i, len(img.shape), n_dim)) + if not np.all(np.array(img.shape)[idx_mask] == masked_shape): + raise ValueError('shape {0} for image {1} not compatible with ' + 'first image shape {2} with axis == {0}'.format( + img.shape, i, shape0, axis)) + if check_affines and not np.all(img.affine == affine): + raise ValueError('Affine for image {0} does not match affine ' + 'for first image'.format(i)) + # Do not fill cache in image if it is empty + out_data[i] = img.get_data(caching='unchanged') + + if axis is None: + out_data = np.rollaxis(out_data, 0, out_data.ndim) + else: + out_data = np.concatenate(out_data, axis=axis) + return klass(out_data, affine, header) diff --git a/nibabel/tests/test_funcs.py b/nibabel/tests/test_funcs.py index 8a18afd739..20d11578b3 100644 --- a/nibabel/tests/test_funcs.py +++ b/nibabel/tests/test_funcs.py @@ -30,34 +30,96 @@ def _as_fname(img): def test_concat(): - shape = (1,2,5) - data0 = np.arange(10).reshape(shape) + # Smoke test: concat empty list. + assert_raises(ValueError, concat_images, []) + + # Build combinations of 3D, 4D w/size[3] == 1, and 4D w/size[3] == 3 + all_shapes_5D = ((1, 4, 5, 3, 3), + (7, 3, 1, 4, 5), + (0, 2, 1, 4, 5)) + affine = np.eye(4) - img0_mem = Nifti1Image(data0, affine) - data1 = data0 - 10 - img1_mem = Nifti1Image(data1, affine) - img2_mem = Nifti1Image(data1, affine+1) - img3_mem = Nifti1Image(data1.T, affine) - all_data = np.concatenate( - [data0[:,:,:,np.newaxis],data1[:,:,:,np.newaxis]],3) - # Check filenames and in-memory images work - with InTemporaryDirectory(): - imgs = [img0_mem, img1_mem, img2_mem, img3_mem] - img_files = [_as_fname(img) for img in imgs] - for img0, img1, img2, img3 in (imgs, img_files): - all_imgs = concat_images([img0, img1]) - assert_array_equal(all_imgs.get_data(), all_data) - assert_array_equal(all_imgs.affine, affine) - # check that not-matching affines raise error - assert_raises(ValueError, concat_images, [img0, img2]) - assert_raises(ValueError, concat_images, [img0, img3]) - # except if check_affines is False - all_imgs = concat_images([img0, img1]) - assert_array_equal(all_imgs.get_data(), all_data) - assert_array_equal(all_imgs.affine, affine) - # Delete images as prophylaxis for windows access errors - for img in imgs: - del(img) + for dim in range(2, 6): + all_shapes_ND = tuple((shape[:dim] for shape in all_shapes_5D)) + all_shapes_N1D_unary = tuple((shape + (1,) for shape in all_shapes_ND)) + all_shapes = all_shapes_ND + all_shapes_N1D_unary + + # Loop over all possible combinations of images, in first and + # second position. + for data0_shape in all_shapes: + data0_numel = np.asarray(data0_shape).prod() + data0 = np.arange(data0_numel).reshape(data0_shape) + img0_mem = Nifti1Image(data0, affine) + + for data1_shape in all_shapes: + data1_numel = np.asarray(data1_shape).prod() + data1 = np.arange(data1_numel).reshape(data1_shape) + img1_mem = Nifti1Image(data1, affine) + img2_mem = Nifti1Image(data1, affine+1) # bad affine + + # Loop over every possible axis, including None (explicit and implied) + for axis in (list(range(-(dim-2), (dim-1))) + [None, '__default__']): + + # Allow testing default vs. passing explicit param + if axis == '__default__': + np_concat_kwargs = dict(axis=-1) + concat_imgs_kwargs = dict() + axis = None # Convert downstream + elif axis is None: + np_concat_kwargs = dict(axis=-1) + concat_imgs_kwargs = dict(axis=axis) + else: + np_concat_kwargs = dict(axis=axis) + concat_imgs_kwargs = dict(axis=axis) + + # Create expected output + try: + # Error will be thrown if the np.concatenate fails. + # However, when axis=None, the concatenate is possible + # but our efficient logic (where all images are + # 3D and the same size) fails, so we also + # have to expect errors for those. + if axis is None: # 3D from here and below + all_data = np.concatenate([data0[..., np.newaxis], + data1[..., np.newaxis]], + **np_concat_kwargs) + else: # both 3D, appending on final axis + all_data = np.concatenate([data0, data1], + **np_concat_kwargs) + expect_error = False + except ValueError: + # Shapes are not combinable + expect_error = True + + # Check filenames and in-memory images work + with InTemporaryDirectory(): + # Try mem-based, file-based, and mixed + imgs = [img0_mem, img1_mem, img2_mem] + img_files = [_as_fname(img) for img in imgs] + imgs_mixed = [imgs[0], img_files[1], imgs[2]] + for img0, img1, img2 in (imgs, img_files, imgs_mixed): + try: + all_imgs = concat_images([img0, img1], + **concat_imgs_kwargs) + except ValueError as ve: + assert_true(expect_error, str(ve)) + else: + assert_false(expect_error, "Expected a concatenation error, but got none.") + assert_array_equal(all_imgs.get_data(), all_data) + assert_array_equal(all_imgs.affine, affine) + + # check that not-matching affines raise error + assert_raises(ValueError, concat_images, [img0, img2], **concat_imgs_kwargs) + + # except if check_affines is False + try: + all_imgs = concat_images([img0, img1], **concat_imgs_kwargs) + except ValueError as ve: + assert_true(expect_error, str(ve)) + else: + assert_false(expect_error, "Expected a concatenation error, but got none.") + assert_array_equal(all_imgs.get_data(), all_data) + assert_array_equal(all_imgs.affine, affine) def test_closest_canonical():