Skip to content

Commit a395d68

Browse files
author
Ben Cipollini
committed
Modify save, remove vestigates of class_map / ext_map
1 parent 77fe652 commit a395d68

File tree

4 files changed

+45
-38
lines changed

4 files changed

+45
-38
lines changed

nibabel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from .orientations import (io_orientation, orientation_affine,
6262
flip_axis, OrientationError,
6363
apply_orientation, aff2axcodes)
64-
from .imageclasses import class_map, ext_map
64+
from .imageclasses import class_map, ext_map, all_image_classes
6565
from . import trackvis
6666
from . import mriutils
6767

nibabel/loadsave.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313

1414
from .filename_parser import splitext_addext
1515
from .volumeutils import BinOpener
16-
from .nifti1 import Nifti1Image, Nifti1Pair
17-
from .nifti2 import Nifti2Image, Nifti2Pair
1816
from .spatialimages import ImageFileError
19-
from .imageclasses import class_map, ext_map, all_image_classes
17+
from .imageclasses import all_image_classes
2018
from .arrayproxy import is_proxy
2119

2220

@@ -60,14 +58,22 @@ def save(img, filename):
6058
-------
6159
None
6260
'''
61+
62+
# Save the type as expected
6363
try:
6464
img.to_filename(filename)
6565
except ImageFileError:
6666
pass
6767
else:
6868
return
69-
froot, ext, trailing = splitext_addext(filename, ('.gz', '.bz2'))
69+
70+
# Be nice to users by making common implicit conversions
71+
froot, ext, trailing = splitext_addext(filename, img._compressed_exts)
72+
lext = ext.lower()
73+
7074
# Special-case Nifti singles and Pairs
75+
from .nifti1 import Nifti1Image, Nifti1Pair # Inline imports, as this file
76+
from .nifti2 import Nifti2Image, Nifti2Pair # really shouldn't reference any image type
7177
if type(img) == Nifti1Image and ext in ('.img', '.hdr'):
7278
klass = Nifti1Pair
7379
elif type(img) == Nifti2Image and ext in ('.img', '.hdr'):
@@ -76,9 +82,14 @@ def save(img, filename):
7682
klass = Nifti1Image
7783
elif type(img) == Nifti2Pair and ext == '.nii':
7884
klass = Nifti2Image
79-
else:
80-
img_type = ext_map[ext]
81-
klass = class_map[img_type]['class']
85+
else: # arbitrary conversion
86+
valid_klasses = filter(lambda klass: klass.is_valid_extension(lext),
87+
all_image_classes)
88+
if len(valid_klasses) > 0:
89+
klass = valid_klasses[0]
90+
else:
91+
raise ImageFileError('Cannot work out file type of "%s"' %
92+
filename)
8293
converted = klass.from_image(img)
8394
converted.to_filename(filename)
8495

nibabel/tests/test_files_interface.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414

15-
from .. import class_map, Nifti1Image, Nifti1Pair, MGHImage
15+
from .. import Nifti1Image, Nifti1Pair, MGHImage, all_image_classes
1616
from ..externals.six import BytesIO
1717
from ..fileholders import FileHolderError
1818

@@ -25,15 +25,14 @@ def test_files_images():
2525
# test files creation in image classes
2626
arr = np.zeros((2,3,4))
2727
aff = np.eye(4)
28-
for img_def in class_map.values():
29-
klass = img_def['class']
28+
for klass in all_image_classes:
3029
file_map = klass.make_file_map()
3130
for key, value in file_map.items():
3231
assert_equal(value.filename, None)
3332
assert_equal(value.fileobj, None)
3433
assert_equal(value.pos, 0)
3534
# If we can't create new images in memory without loading, bail here
36-
if not img_def['makeable']:
35+
if not klass.makeable:
3736
continue
3837
# MGHImage accepts only a few datatypes
3938
# so we force a type change to float32
@@ -83,22 +82,21 @@ def test_files_interface():
8382

8483

8584
def test_round_trip():
86-
# write an image to files
87-
data = np.arange(24, dtype='i4').reshape((2,3,4))
88-
aff = np.eye(4)
89-
klasses = [val['class'] for key, val in class_map.items()
90-
if val['rw']]
91-
for klass in klasses:
92-
file_map = klass.make_file_map()
93-
for key in file_map:
94-
file_map[key].fileobj = BytesIO()
95-
img = klass(data, aff)
96-
img.file_map = file_map
97-
img.to_file_map()
98-
# read it back again from the written files
99-
img2 = klass.from_file_map(file_map)
100-
assert_array_equal(img2.get_data(), data)
101-
# write, read it again
102-
img2.to_file_map()
103-
img3 = klass.from_file_map(file_map)
104-
assert_array_equal(img3.get_data(), data)
85+
# write an image to files
86+
data = np.arange(24, dtype='i4').reshape((2,3,4))
87+
aff = np.eye(4)
88+
klasses = filter(lambda klass: klass.rw, all_image_classes)
89+
for klass in klasses:
90+
file_map = klass.make_file_map()
91+
for key in file_map:
92+
file_map[key].fileobj = BytesIO()
93+
img = klass(data, aff)
94+
img.file_map = file_map
95+
img.to_file_map()
96+
# read it back again from the written files
97+
img2 = klass.from_file_map(file_map)
98+
assert_array_equal(img2.get_data(), data)
99+
# write, read it again
100+
img2.to_file_map()
101+
img3 = klass.from_file_map(file_map)
102+
assert_array_equal(img3.get_data(), data)

nibabel/tests/test_image_load_save.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .. import loadsave as nils
2727
from .. import (Nifti1Image, Nifti1Header, Nifti1Pair, Nifti2Image, Nifti2Pair,
2828
Minc1Image, Minc2Image, Spm2AnalyzeImage, Spm99AnalyzeImage,
29-
AnalyzeImage, MGHImage, class_map)
29+
AnalyzeImage, MGHImage, all_image_classes)
3030

3131
from ..tmpdirs import InTemporaryDirectory
3232

@@ -53,16 +53,14 @@ def test_conversion():
5353
affine = np.diag([1, 2, 3, 1])
5454
for npt in np.float32, np.int16:
5555
data = np.arange(np.prod(shape), dtype=npt).reshape(shape)
56-
for r_class_def in class_map.values():
57-
r_class = r_class_def['class']
58-
if not r_class_def['makeable']:
56+
for r_class in all_image_classes:
57+
if not r_class.makeable:
5958
continue
6059
img = r_class(data, affine)
6160
img.set_data_dtype(npt)
62-
for w_class_def in class_map.values():
63-
if not w_class_def['makeable']:
61+
for w_class in all_image_classes:
62+
if not w_class.makeable:
6463
continue
65-
w_class = w_class_def['class']
6664
img2 = w_class.from_image(img)
6765
assert_array_equal(img2.get_data(), data)
6866
assert_array_equal(img2.affine, affine)

0 commit comments

Comments
 (0)