Skip to content

Commit a7e1e0e

Browse files
authored
Merge pull request #1111 from effigies/enh/cifti_dtype_arg
ENH: Add dtype argument to Cifti2Image
2 parents d0532ec + 7933fae commit a7e1e0e

File tree

4 files changed

+103
-22
lines changed

4 files changed

+103
-22
lines changed

nibabel/analyze.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,13 +1020,24 @@ def to_file_map(self, file_map=None, dtype=None):
10201020
inter = hdr['scl_inter'].item() if hdr.has_data_intercept else np.nan
10211021
# Check whether to calculate slope / inter
10221022
scale_me = np.all(np.isnan((slope, inter)))
1023-
if scale_me:
1024-
arr_writer = make_array_writer(data,
1025-
out_dtype,
1026-
hdr.has_data_slope,
1027-
hdr.has_data_intercept)
1028-
else:
1029-
arr_writer = ArrayWriter(data, out_dtype, check_scaling=False)
1023+
try:
1024+
if scale_me:
1025+
arr_writer = make_array_writer(data,
1026+
out_dtype,
1027+
hdr.has_data_slope,
1028+
hdr.has_data_intercept)
1029+
else:
1030+
arr_writer = ArrayWriter(data, out_dtype, check_scaling=False)
1031+
except WriterError:
1032+
# Restore any changed consumable values, in case caller catches
1033+
# Should match cleanup at the end of the method
1034+
hdr.set_data_offset(offset)
1035+
hdr.set_data_dtype(data_dtype)
1036+
if hdr.has_data_slope:
1037+
hdr['scl_slope'] = slope
1038+
if hdr.has_data_intercept:
1039+
hdr['scl_inter'] = inter
1040+
raise
10301041
hdr_fh, img_fh = self._get_fileholders(file_map)
10311042
# Check if hdr and img refer to same file; this can happen with odd
10321043
# analyze images but most often this is because it's a single nifti

nibabel/cifti2/cifti2.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@
1919
import re
2020
from collections.abc import MutableSequence, MutableMapping, Iterable
2121
from collections import OrderedDict
22+
from warnings import warn
23+
24+
import numpy as np
25+
2226
from .. import xmlutils as xml
2327
from ..filebasedimages import FileBasedHeader, SerializableImage
2428
from ..dataobj_images import DataobjImage
2529
from ..nifti1 import Nifti1Extensions
2630
from ..nifti2 import Nifti2Image, Nifti2Header
2731
from ..arrayproxy import reshape_dataobj
2832
from ..caret import CaretMetaData
29-
from warnings import warn
33+
from ..volumeutils import make_dt_codes
3034

3135

3236
def _float_01(val):
@@ -41,6 +45,22 @@ class Cifti2HeaderError(Exception):
4145
"""
4246

4347

48+
_dtdefs = ( # code, label, dtype definition, niistring
49+
(2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),
50+
(4, 'int16', np.int16, "NIFTI_TYPE_INT16"),
51+
(8, 'int32', np.int32, "NIFTI_TYPE_INT32"),
52+
(16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),
53+
(64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),
54+
(256, 'int8', np.int8, "NIFTI_TYPE_INT8"),
55+
(512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),
56+
(768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),
57+
(1024, 'int64', np.int64, "NIFTI_TYPE_INT64"),
58+
(1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),
59+
)
60+
61+
# Make full code alias bank, including dtype column
62+
data_type_codes = make_dt_codes(_dtdefs)
63+
4464
CIFTI_MAP_TYPES = ('CIFTI_INDEX_TYPE_BRAIN_MODELS',
4565
'CIFTI_INDEX_TYPE_PARCELS',
4666
'CIFTI_INDEX_TYPE_SERIES',
@@ -103,6 +123,10 @@ def _underscore(string):
103123
return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', string).lower()
104124

105125

126+
class LimitedNifti2Header(Nifti2Header):
127+
_data_type_codes = data_type_codes
128+
129+
106130
class Cifti2MetaData(CaretMetaData):
107131
""" A list of name-value pairs
108132
@@ -1363,7 +1387,8 @@ def __init__(self,
13631387
header=None,
13641388
nifti_header=None,
13651389
extra=None,
1366-
file_map=None):
1390+
file_map=None,
1391+
dtype=None):
13671392
""" Initialize image
13681393
13691394
The image is a combination of (dataobj, header), with optional metadata
@@ -1392,12 +1417,13 @@ def __init__(self,
13921417
header = Cifti2Header.from_axes(header)
13931418
super(Cifti2Image, self).__init__(dataobj, header=header,
13941419
extra=extra, file_map=file_map)
1395-
self._nifti_header = Nifti2Header.from_header(nifti_header)
1420+
self._nifti_header = LimitedNifti2Header.from_header(nifti_header)
13961421

13971422
# if NIfTI header not specified, get data type from input array
1398-
if nifti_header is None:
1399-
if hasattr(dataobj, 'dtype'):
1400-
self._nifti_header.set_data_dtype(dataobj.dtype)
1423+
if dtype is not None:
1424+
self.set_data_dtype(dtype)
1425+
elif nifti_header is None and hasattr(dataobj, 'dtype'):
1426+
self.set_data_dtype(dataobj.dtype)
14011427
self.update_headers()
14021428

14031429
if self._dataobj.shape != self.header.matrix.get_data_shape():

nibabel/cifti2/tests/test_cifti2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313

1414
from nibabel.tests.test_dataobj_images import TestDataobjAPI as _TDA
15-
from nibabel.tests.test_image_api import SerializeMixin
15+
from nibabel.tests.test_image_api import SerializeMixin, DtypeOverrideMixin
1616

1717

1818
def compare_xml_leaf(str1, str2):
@@ -415,7 +415,7 @@ def test_underscoring():
415415
assert ci.cifti2._underscore(camel) == underscored
416416

417417

418-
class TestCifti2ImageAPI(_TDA, SerializeMixin):
418+
class TestCifti2ImageAPI(_TDA, SerializeMixin, DtypeOverrideMixin):
419419
""" Basic validation for Cifti2Image instances
420420
"""
421421
# A callable returning an image from ``image_maker(data, header)``
@@ -426,6 +426,8 @@ class TestCifti2ImageAPI(_TDA, SerializeMixin):
426426
ni_header_maker = Nifti2Header
427427
example_shapes = ((2,), (2, 3), (2, 3, 4))
428428
standard_extension = '.nii'
429+
storable_dtypes = (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32,
430+
np.int64, np.uint64, np.float32, np.float64)
429431

430432
def make_imaker(self, arr, header=None, ni_header=None):
431433
for idx, sz in enumerate(arr.shape):

nibabel/tests/test_image_api.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
from .test_parrec import EXAMPLE_IMAGES as PARREC_EXAMPLE_IMAGES
5656
from .test_brikhead import EXAMPLE_IMAGES as AFNI_EXAMPLE_IMAGES
5757

58+
from nibabel.arraywriters import WriterError
59+
5860

5961
def maybe_deprecated(meth_name):
6062
return pytest.deprecated_call() if meth_name == 'get_data' else nullcontext()
@@ -181,7 +183,7 @@ def validate_get_data_deprecated(self, imaker, params):
181183
assert_array_equal(np.asanyarray(img.dataobj), data)
182184

183185

184-
class GetSetDtypeMixin(object):
186+
class GetSetDtypeMixin:
185187
""" Adds dtype tests
186188
187189
Add this one if your image has ``get_data_dtype`` and ``set_data_dtype``.
@@ -666,6 +668,46 @@ def prox_imaker():
666668
yield make_prox_imaker(arr.copy(), aff, hdr), params
667669

668670

671+
class DtypeOverrideMixin(GetSetDtypeMixin):
672+
""" Test images that can accept ``dtype`` arguments to ``__init__`` and
673+
``to_file_map``
674+
"""
675+
676+
def validate_init_dtype_override(self, imaker, params):
677+
img = imaker()
678+
klass = img.__class__
679+
for dtype in self.storable_dtypes:
680+
if hasattr(img, 'affine'):
681+
new_img = klass(img.dataobj, img.affine, header=img.header, dtype=dtype)
682+
else: # XXX This is for CIFTI-2, these validators might need refactoring
683+
new_img = klass(img.dataobj, header=img.header, dtype=dtype)
684+
assert new_img.get_data_dtype() == dtype
685+
686+
if self.has_scaling and self.can_save:
687+
with np.errstate(invalid='ignore'):
688+
rt_img = bytesio_round_trip(new_img)
689+
assert rt_img.get_data_dtype() == dtype
690+
691+
def validate_to_file_dtype_override(self, imaker, params):
692+
if not self.can_save:
693+
raise unittest.SkipTest
694+
img = imaker()
695+
orig_dtype = img.get_data_dtype()
696+
fname = 'image' + self.standard_extension
697+
with InTemporaryDirectory():
698+
for dtype in self.storable_dtypes:
699+
try:
700+
img.to_filename(fname, dtype=dtype)
701+
except WriterError:
702+
# It's possible to try to save to a dtype that requires
703+
# scaling, and images without scale factors will fail.
704+
# We're not testing that here.
705+
continue
706+
rt_img = img.__class__.from_filename(fname)
707+
assert rt_img.get_data_dtype() == dtype
708+
assert img.get_data_dtype() == orig_dtype
709+
710+
669711
class ImageHeaderAPI(MakeImageAPI):
670712
""" When ``self.image_maker`` is an image class, make header from class
671713
"""
@@ -674,7 +716,12 @@ def header_maker(self):
674716
return self.image_maker.header_class()
675717

676718

677-
class TestAnalyzeAPI(ImageHeaderAPI):
719+
class TestSpatialImageAPI(ImageHeaderAPI):
720+
klass = image_maker = SpatialImage
721+
can_save = False
722+
723+
724+
class TestAnalyzeAPI(TestSpatialImageAPI, DtypeOverrideMixin):
678725
""" General image validation API instantiated for Analyze images
679726
"""
680727
klass = image_maker = AnalyzeImage
@@ -685,11 +732,6 @@ class TestAnalyzeAPI(ImageHeaderAPI):
685732
storable_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.float64)
686733

687734

688-
class TestSpatialImageAPI(TestAnalyzeAPI):
689-
klass = image_maker = SpatialImage
690-
can_save = False
691-
692-
693735
class TestSpm99AnalyzeAPI(TestAnalyzeAPI):
694736
# SPM-type analyze need scipy for mat file IO
695737
klass = image_maker = Spm99AnalyzeImage

0 commit comments

Comments
 (0)