@@ -88,6 +88,19 @@ def squeeze_image(img):
88
88
img .extra )
89
89
90
90
91
+ def _shape_equal_excluding (shape1 , shape2 , exclude_axes = None ):
92
+ """ Helper function to compare two array shapes, excluding any
93
+ axis specified."""
94
+
95
+ if len (shape1 ) != len (shape2 ):
96
+ return False
97
+
98
+ idx_mask = np .ones ((len (shape1 ),), dtype = bool )
99
+ idx_mask [exclude_axes ] = False
100
+ return np .array_equal (np .asarray (shape1 )[idx_mask ],
101
+ np .asarray (shape2 )[idx_mask ])
102
+
103
+
91
104
def concat_images (images , check_affines = True , axis = None ):
92
105
''' Concatenate images in list to single image, along specified dimension
93
106
@@ -134,10 +147,12 @@ def concat_images(images, check_affines=True, axis=None):
134
147
elif check_affines and not np .all (img .affine == affine ):
135
148
raise ValueError ('Affines do not match' )
136
149
137
- elif axis is None and not np .array_equal (shape , img .shape ):
138
- # shape mismatch; numpy broadcasting can hide these.
139
- raise ValueError ("Image %d (shape=%s) does not match first image "
140
- " shape (%s)." % (i , shape , img .shape ))
150
+ elif ((axis is None and not np .array_equal (shape , img .shape )) or
151
+ (axis is not None and not _shape_equal_excluding (shape , img .shape ,
152
+ exclude_axes = [axis ]))):
153
+ # shape mismatch; numpy broadcast / concatenate can hide these.
154
+ raise ValueError ("Image #%d (shape=%s) does not match the first "
155
+ "image shape (%s)." % (i , shape , img .shape ))
141
156
142
157
out_data [i ] = img .get_data ()
143
158
0 commit comments