Skip to content

Commit afaa5fe

Browse files
author
Ben Cipollini
committed
Make this work for all 3D and 4D combinations possible, across all axes possible. Test extensively.
1 parent c39caac commit afaa5fe

File tree

2 files changed

+125
-78
lines changed

2 files changed

+125
-78
lines changed

nibabel/funcs.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def concat_images(images, check_affines=True, axis=None):
9898
check_affines : {True, False}, optional
9999
If True, then check that all the affines for `images` are nearly
100100
the same, raising a ``ValueError`` otherwise. Default is True
101-
axis : int, optional
102-
If None, concatenates on the last dimension.
103-
If not None, concatenates on the specified dimension.
101+
axis : None or int, optional
102+
If None, concatenates on the 4th dimension.
103+
If not None, concatenates on the specified dimension [-2 to 3).
104104
Returns
105105
-------
106106
concat_img : ``SpatialImage``
@@ -113,23 +113,40 @@ def concat_images(images, check_affines=True, axis=None):
113113
if not hasattr(img0, 'get_data'):
114114
img0 = load(img0)
115115
is_filename = True
116-
i0shape = img0.shape
117116
affine = img0.affine
118117
header = img0.header
119-
out_shape = (n_imgs, ) + i0shape
120-
out_data = []
118+
119+
if axis is None: # collect images in output array for efficiency
120+
out_shape = (n_imgs, ) + img0.shape[:3]
121+
out_data = np.empty(out_shape)
122+
else: # collect images in list for use with np.concatenate
123+
out_data = [None] * n_imgs
124+
121125
for i, img in enumerate(images):
122126
if is_filename:
123127
img = load(img)
124-
if check_affines:
125-
if not np.all(img.affine == affine):
126-
raise ValueError('Affines do not match')
127-
out_data.append(img.get_data())
128-
if axis is not None:
129-
out_data = np.concatenate(out_data, axis=axis)
128+
if check_affines and not np.all(img.affine == affine):
129+
raise ValueError('Affines do not match')
130+
131+
if axis is None and img.get_data().ndim == 4 and img.get_data().shape[3] == 1:
132+
out_data[i] = np.reshape(img.get_data(), img.get_data().shape[:-1])
133+
else:
134+
out_data[i] = img.get_data()
135+
136+
if is_filename:
137+
del img
138+
139+
if axis is None:
140+
out_data = np.rollaxis(out_data, 0, out_data.ndim)
130141
else:
131-
out_data = np.asarray(out_data)
132-
out_data = np.rollaxis(out_data, 0, len(i0shape)+1)
142+
# Massage the output, to allow combining 3D and 4D images.
143+
is_3D = [len(d.shape) == 3 for d in out_data]
144+
is_4D = [len(d.shape) == 4 for d in out_data]
145+
if np.any(is_3D) and np.any(is_4D):
146+
out_data = [data if is_4D[di] else np.reshape(data, data.shape + (1,))
147+
for di, data in enumerate(out_data)]
148+
out_data = np.concatenate(out_data, axis=axis)
149+
133150
klass = img0.__class__
134151
return klass(out_data, affine, header)
135152

nibabel/tests/test_funcs.py

Lines changed: 94 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -30,71 +30,101 @@ def _as_fname(img):
3030

3131

3232
def test_concat():
33-
for shape in ((1,2,5), (7,3,1), (13,11,11), (0,1,1)):
34-
numel = np.asarray(shape).prod()
35-
data0 = np.arange(numel).reshape(shape)
36-
affine = np.eye(4)
37-
img0_mem = Nifti1Image(data0, affine)
38-
data1 = data0 - 10
39-
img1_mem = Nifti1Image(data1, affine)
40-
img2_mem = Nifti1Image(data1, affine+1)
41-
img3_mem = Nifti1Image(data1.T, affine)
42-
all_data = np.concatenate(
43-
[data0[:,:,:,np.newaxis],data1[:,:,:,np.newaxis]],3)
44-
# Check filenames and in-memory images work
45-
with InTemporaryDirectory():
46-
imgs = [img0_mem, img1_mem, img2_mem, img3_mem]
47-
img_files = [_as_fname(img) for img in imgs]
48-
for img0, img1, img2, img3 in (imgs, img_files):
49-
all_imgs = concat_images([img0, img1])
50-
assert_array_equal(all_imgs.get_data(), all_data)
51-
assert_array_equal(all_imgs.affine, affine)
52-
# check that not-matching affines raise error
53-
assert_raises(ValueError, concat_images, [img0, img2])
54-
assert_raises(ValueError, concat_images, [img0, img3])
55-
# except if check_affines is False
56-
all_imgs = concat_images([img0, img1])
57-
assert_array_equal(all_imgs.get_data(), all_data)
58-
assert_array_equal(all_imgs.affine, affine)
59-
# Delete images as prophylaxis for windows access errors
60-
for img in imgs:
61-
del(img)
62-
63-
# Test axis parameter and trailing unary dimension
64-
shape_4D = np.asarray(shape + (1,))
65-
data0 = np.arange(numel).reshape(shape_4D)
66-
affine = np.eye(4)
33+
34+
# Build combinations of 3D, 4D w/size[3] == 1, and 4D w/size[3] == 3
35+
all_shapes_3D = ((1, 2, 5), (7, 3, 1), (13, 11, 11), (0, 1, 1))
36+
all_shapes_4D_unary = tuple((shape + (1,) for shape in all_shapes_3D))
37+
all_shapes_4D_multi = tuple((shape + (3,) for shape in all_shapes_3D))
38+
all_shapes = all_shapes_3D + all_shapes_4D_unary + all_shapes_4D_multi
39+
40+
affine = np.eye(4)
41+
# Loop over all possible combinations of images, in first and
42+
# second position.
43+
for data0_shape in all_shapes:
44+
data0_numel = np.asarray(data0_shape).prod()
45+
data0 = np.arange(data0_numel).reshape(data0_shape)
6746
img0_mem = Nifti1Image(data0, affine)
68-
img1_mem = Nifti1Image(data0 - 10, affine)
69-
70-
# 4d, same shape, append on axis 3
71-
concat_img1 = concat_images([img0_mem, img1_mem], axis=3)
72-
expected_shape1 = shape_4D.copy()
73-
expected_shape1[-1] *= 2
74-
assert_array_equal(concat_img1.shape, expected_shape1)
75-
76-
# 4d, same shape, append on axis 0
77-
concat_img2 = concat_images([img0_mem, img1_mem], axis=0)
78-
expected_shape2 = shape_4D.copy()
79-
expected_shape2[0] *= 2
80-
assert_array_equal(concat_img2.shape, expected_shape2)
81-
82-
# 4d, same shape, append on axis -1
83-
concat_img3 = concat_images([img0_mem, img1_mem], axis=-1)
84-
expected_shape3 = shape_4D.copy()
85-
expected_shape3[-1] *= 2
86-
assert_array_equal(concat_img3.shape, expected_shape3)
87-
88-
# 4d, different shape, append on axis that's different
89-
print('%s %s' % (str(concat_img3.shape), str(img1_mem.shape)))
90-
concat_img4 = concat_images([concat_img3, img1_mem], axis=-1)
91-
expected_shape4 = shape_4D.copy()
92-
expected_shape4[-1] *= 3
93-
assert_array_equal(concat_img4.shape, expected_shape4)
94-
95-
# 4d, different shape, append on axis that's not different...
96-
# Doesn't work!
97-
assert_raises(ValueError, concat_images, [concat_img3, img1_mem], axis=1)
47+
48+
for data1_shape in all_shapes:
49+
data1_numel = np.asarray(data1_shape).prod()
50+
data1 = np.arange(data1_numel).reshape(data1_shape)
51+
img1_mem = Nifti1Image(data1, affine)
52+
img2_mem = Nifti1Image(data1, affine+1) # bad affine
53+
img3_mem = Nifti1Image(data1.T, affine) # bad data shape
54+
55+
# Loop over every possible axis, including None (explicit and implied)
56+
for axis in (list(range(-2, 3)) + [None, '__default__']):
57+
58+
# Allow testing default vs. passing explicit param
59+
if axis == '__default__':
60+
np_concat_kwargs = dict(axis=-1)
61+
concat_imgs_kwargs = dict()
62+
axis = None # Convert downstream
63+
elif axis is None:
64+
np_concat_kwargs = dict(axis=-1)
65+
concat_imgs_kwargs = dict(axis=axis)
66+
else:
67+
np_concat_kwargs = dict(axis=axis)
68+
concat_imgs_kwargs = dict(axis=axis)
69+
70+
# Create expected output
71+
try:
72+
# Error will be thrown if the np.concatenate fails.
73+
# However, when axis=None, the concatenate is possible
74+
# but our efficient logic (where all images are
75+
# 3D and the same size) fails, so we also
76+
# have to expect errors for those.
77+
expect_error = False
78+
if data0.ndim == 3 and data1.ndim == 4:
79+
expect_error = axis is None and data1.shape[3] != 1
80+
all_data = np.concatenate([data0[..., np.newaxis], data1],
81+
**np_concat_kwargs)
82+
elif data0.ndim == 4 and data1.ndim == 3:
83+
expect_error = axis is None and data0.shape[3] != 1
84+
all_data = np.concatenate([data0, data1[..., np.newaxis]],
85+
**np_concat_kwargs)
86+
elif data0.ndim == 4 and data1.ndim == 4:
87+
expect_error = axis is None and (data0.shape[3] != 1 or
88+
data1.shape[3] != 1)
89+
all_data = np.concatenate([data0, data1],
90+
**np_concat_kwargs)
91+
elif axis is None: # 3D from here and below
92+
all_data = np.concatenate(
93+
[data0[..., np.newaxis], data1[..., np.newaxis]], 3)
94+
else: # both 3D, appending on final axis
95+
all_data = np.concatenate([data0, data1],
96+
**np_concat_kwargs)
97+
except ValueError:
98+
# Shapes are not combinable
99+
expect_error = True
100+
101+
# Check filenames and in-memory images work
102+
with InTemporaryDirectory():
103+
# Try mem-based, file-based, and mixed
104+
imgs = [img0_mem, img1_mem, img2_mem, img3_mem]
105+
img_files = [_as_fname(img) for img in imgs]
106+
for img0, img1, img2, img3 in (imgs, img_files):
107+
try:
108+
all_imgs = concat_images([img0, img1],
109+
**concat_imgs_kwargs)
110+
assert_array_equal(all_imgs.get_data(), all_data)
111+
assert_array_equal(all_imgs.affine, affine)
112+
assert_false(expect_error, "Expected a concatenation error, but got none.")
113+
except ValueError as ve:
114+
assert_true(expect_error, ve.message)
115+
116+
# check that not-matching affines raise error
117+
assert_raises(ValueError, concat_images, [img0, img2], **concat_imgs_kwargs)
118+
assert_raises(ValueError, concat_images, [img0, img3], **concat_imgs_kwargs)
119+
120+
# except if check_affines is False
121+
try:
122+
all_imgs = concat_images([img0, img1], **concat_imgs_kwargs)
123+
assert_array_equal(all_imgs.get_data(), all_data)
124+
assert_array_equal(all_imgs.affine, affine)
125+
assert_false(expect_error, "Expected a concatenation error, but got none.")
126+
except ValueError as ve:
127+
assert_true(expect_error, ve.message)
98128

99129

100130
def test_closest_canonical():

0 commit comments

Comments
 (0)