Skip to content

Commit 3331a51

Browse files
author
Ben Cipollini
committed
Allow mixed files and objects.
1 parent afaa5fe commit 3331a51

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

nibabel/funcs.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def concat_images(images, check_affines=True, axis=None):
9494
Parameters
9595
----------
9696
images : sequence
97-
sequence of ``SpatialImage`` or of filenames\s
97+
sequence of ``SpatialImage`` or filenames\s
9898
check_affines : {True, False}, optional
9999
If True, then check that all the affines for `images` are nearly
100100
the same, raising a ``ValueError`` otherwise. Default is True
@@ -109,32 +109,33 @@ def concat_images(images, check_affines=True, axis=None):
109109
'''
110110
n_imgs = len(images)
111111
img0 = images[0]
112-
is_filename = False
113112
if not hasattr(img0, 'get_data'):
114113
img0 = load(img0)
115-
is_filename = True
116114
affine = img0.affine
117115
header = img0.header
116+
i0shape = img0.shape
117+
del img0
118118

119119
if axis is None: # collect images in output array for efficiency
120-
out_shape = (n_imgs, ) + img0.shape[:3]
120+
out_shape = (n_imgs, ) + i0shape[:3]
121121
out_data = np.empty(out_shape)
122122
else: # collect images in list for use with np.concatenate
123123
out_data = [None] * n_imgs
124124

125125
for i, img in enumerate(images):
126-
if is_filename:
126+
if not hasattr(img, 'get_data'):
127127
img = load(img)
128+
128129
if check_affines and not np.all(img.affine == affine):
129130
raise ValueError('Affines do not match')
130131

132+
# Special case for 4D image with size[3] == 1; reshape to work!
131133
if axis is None and img.get_data().ndim == 4 and img.get_data().shape[3] == 1:
132134
out_data[i] = np.reshape(img.get_data(), img.get_data().shape[:-1])
133135
else:
134136
out_data[i] = img.get_data()
135137

136-
if is_filename:
137-
del img
138+
del img
138139

139140
if axis is None:
140141
out_data = np.rollaxis(out_data, 0, out_data.ndim)

nibabel/tests/test_funcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def test_concat():
103103
# Try mem-based, file-based, and mixed
104104
imgs = [img0_mem, img1_mem, img2_mem, img3_mem]
105105
img_files = [_as_fname(img) for img in imgs]
106-
for img0, img1, img2, img3 in (imgs, img_files):
106+
imgs_mixed = [imgs[0], img_files[1], imgs[2], img_files[3]]
107+
for img0, img1, img2, img3 in (imgs, img_files, imgs_mixed):
107108
try:
108109
all_imgs = concat_images([img0, img1],
109110
**concat_imgs_kwargs)

0 commit comments

Comments
 (0)