Skip to content

ENH: Add dtype argument to Cifti2Image #1111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions nibabel/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,13 +1020,24 @@ def to_file_map(self, file_map=None, dtype=None):
inter = hdr['scl_inter'].item() if hdr.has_data_intercept else np.nan
# Check whether to calculate slope / inter
scale_me = np.all(np.isnan((slope, inter)))
if scale_me:
arr_writer = make_array_writer(data,
out_dtype,
hdr.has_data_slope,
hdr.has_data_intercept)
else:
arr_writer = ArrayWriter(data, out_dtype, check_scaling=False)
try:
if scale_me:
arr_writer = make_array_writer(data,
out_dtype,
hdr.has_data_slope,
hdr.has_data_intercept)
else:
arr_writer = ArrayWriter(data, out_dtype, check_scaling=False)
except WriterError:
# Restore any changed consumable values, in case caller catches
# Should match cleanup at the end of the method
hdr.set_data_offset(offset)
hdr.set_data_dtype(data_dtype)
if hdr.has_data_slope:
hdr['scl_slope'] = slope
if hdr.has_data_intercept:
hdr['scl_inter'] = inter
raise
hdr_fh, img_fh = self._get_fileholders(file_map)
# Check if hdr and img refer to same file; this can happen with odd
# analyze images but most often this is because it's a single nifti
Expand Down
38 changes: 32 additions & 6 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
import re
from collections.abc import MutableSequence, MutableMapping, Iterable
from collections import OrderedDict
from warnings import warn

import numpy as np

from .. import xmlutils as xml
from ..filebasedimages import FileBasedHeader, SerializableImage
from ..dataobj_images import DataobjImage
from ..nifti1 import Nifti1Extensions
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..caret import CaretMetaData
from warnings import warn
from ..volumeutils import make_dt_codes


def _float_01(val):
Expand All @@ -41,6 +45,22 @@ class Cifti2HeaderError(Exception):
"""


_dtdefs = ( # code, label, dtype definition, niistring
(2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),
(4, 'int16', np.int16, "NIFTI_TYPE_INT16"),
(8, 'int32', np.int32, "NIFTI_TYPE_INT32"),
(16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),
(64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),
(256, 'int8', np.int8, "NIFTI_TYPE_INT8"),
(512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),
(768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),
(1024, 'int64', np.int64, "NIFTI_TYPE_INT64"),
(1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),
)

# Make full code alias bank, including dtype column
data_type_codes = make_dt_codes(_dtdefs)

CIFTI_MAP_TYPES = ('CIFTI_INDEX_TYPE_BRAIN_MODELS',
'CIFTI_INDEX_TYPE_PARCELS',
'CIFTI_INDEX_TYPE_SERIES',
Expand Down Expand Up @@ -103,6 +123,10 @@ def _underscore(string):
return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', string).lower()


class LimitedNifti2Header(Nifti2Header):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly there's a better name. _Cifti2AsNiftiHeader exists to do some validation during parsing, but isn't the type of nifti_header.

class _Cifti2AsNiftiHeader(Nifti2Header):
""" Class for Cifti2 header extension """
@classmethod
def _valid_intent_code(klass, intent_code):
""" Return True if `intent_code` matches our class `klass`
"""
return intent_code >= 3000 and intent_code < 3100
@classmethod
def may_contain_header(klass, binaryblock):
if not super(_Cifti2AsNiftiHeader, klass).may_contain_header(binaryblock):
return False
hdr = klass(binaryblock=binaryblock[:klass.sizeof_hdr])
return klass._valid_intent_code(hdr.get_intent('code')[0])
@staticmethod
def _chk_qfac(hdr, fix=False):
# Allow qfac of 0 without complaint for CIFTI-2
rep = Report(HeaderDataError)
if hdr['pixdim'][0] in (-1, 0, 1):
return hdr, rep
rep.problem_level = 20
rep.problem_msg = 'pixdim[0] (qfac) should be 1 (default) or 0 or -1'
if fix:
hdr['pixdim'][0] = 1
rep.fix_msg = 'setting qfac to 1'
return hdr, rep
@staticmethod
def _chk_pixdims(hdr, fix=False):
rep = Report(HeaderDataError)
pixdims = hdr['pixdim']
spat_dims = pixdims[1:4]
if not np.any(spat_dims < 0):
return hdr, rep
rep.problem_level = 35
rep.problem_msg = 'pixdim[1,2,3] should be zero or positive'
if fix:
hdr['pixdim'][1:4] = np.abs(spat_dims)
rep.fix_msg = 'setting to abs of pixdim values'
return hdr, rep

_data_type_codes = data_type_codes


class Cifti2MetaData(CaretMetaData):
""" A list of name-value pairs

Expand Down Expand Up @@ -1363,7 +1387,8 @@ def __init__(self,
header=None,
nifti_header=None,
extra=None,
file_map=None):
file_map=None,
dtype=None):
""" Initialize image

The image is a combination of (dataobj, header), with optional metadata
Expand Down Expand Up @@ -1392,12 +1417,13 @@ def __init__(self,
header = Cifti2Header.from_axes(header)
super(Cifti2Image, self).__init__(dataobj, header=header,
extra=extra, file_map=file_map)
self._nifti_header = Nifti2Header.from_header(nifti_header)
self._nifti_header = LimitedNifti2Header.from_header(nifti_header)

# if NIfTI header not specified, get data type from input array
if nifti_header is None:
if hasattr(dataobj, 'dtype'):
self._nifti_header.set_data_dtype(dataobj.dtype)
if dtype is not None:
self.set_data_dtype(dtype)
elif nifti_header is None and hasattr(dataobj, 'dtype'):
self.set_data_dtype(dataobj.dtype)
self.update_headers()

if self._dataobj.shape != self.header.matrix.get_data_shape():
Expand Down
6 changes: 4 additions & 2 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest

from nibabel.tests.test_dataobj_images import TestDataobjAPI as _TDA
from nibabel.tests.test_image_api import SerializeMixin
from nibabel.tests.test_image_api import SerializeMixin, DtypeOverrideMixin


def compare_xml_leaf(str1, str2):
Expand Down Expand Up @@ -415,7 +415,7 @@ def test_underscoring():
assert ci.cifti2._underscore(camel) == underscored


class TestCifti2ImageAPI(_TDA, SerializeMixin):
class TestCifti2ImageAPI(_TDA, SerializeMixin, DtypeOverrideMixin):
""" Basic validation for Cifti2Image instances
"""
# A callable returning an image from ``image_maker(data, header)``
Expand All @@ -426,6 +426,8 @@ class TestCifti2ImageAPI(_TDA, SerializeMixin):
ni_header_maker = Nifti2Header
example_shapes = ((2,), (2, 3), (2, 3, 4))
standard_extension = '.nii'
storable_dtypes = (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32,
np.int64, np.uint64, np.float32, np.float64)

def make_imaker(self, arr, header=None, ni_header=None):
for idx, sz in enumerate(arr.shape):
Expand Down
56 changes: 49 additions & 7 deletions nibabel/tests/test_image_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
from .test_parrec import EXAMPLE_IMAGES as PARREC_EXAMPLE_IMAGES
from .test_brikhead import EXAMPLE_IMAGES as AFNI_EXAMPLE_IMAGES

from nibabel.arraywriters import WriterError


def maybe_deprecated(meth_name):
return pytest.deprecated_call() if meth_name == 'get_data' else nullcontext()
Expand Down Expand Up @@ -181,7 +183,7 @@ def validate_get_data_deprecated(self, imaker, params):
assert_array_equal(np.asanyarray(img.dataobj), data)


class GetSetDtypeMixin(object):
class GetSetDtypeMixin:
""" Adds dtype tests

Add this one if your image has ``get_data_dtype`` and ``set_data_dtype``.
Expand Down Expand Up @@ -666,6 +668,46 @@ def prox_imaker():
yield make_prox_imaker(arr.copy(), aff, hdr), params


class DtypeOverrideMixin(GetSetDtypeMixin):
""" Test images that can accept ``dtype`` arguments to ``__init__`` and
``to_file_map``
"""

def validate_init_dtype_override(self, imaker, params):
img = imaker()
klass = img.__class__
for dtype in self.storable_dtypes:
if hasattr(img, 'affine'):
new_img = klass(img.dataobj, img.affine, header=img.header, dtype=dtype)
else: # XXX This is for CIFTI-2, these validators might need refactoring
new_img = klass(img.dataobj, header=img.header, dtype=dtype)
assert new_img.get_data_dtype() == dtype

if self.has_scaling and self.can_save:
with np.errstate(invalid='ignore'):
rt_img = bytesio_round_trip(new_img)
assert rt_img.get_data_dtype() == dtype

def validate_to_file_dtype_override(self, imaker, params):
if not self.can_save:
raise unittest.SkipTest
img = imaker()
orig_dtype = img.get_data_dtype()
fname = 'image' + self.standard_extension
with InTemporaryDirectory():
for dtype in self.storable_dtypes:
try:
img.to_filename(fname, dtype=dtype)
except WriterError:
# It's possible to try to save to a dtype that requires
# scaling, and images without scale factors will fail.
# We're not testing that here.
continue
rt_img = img.__class__.from_filename(fname)
assert rt_img.get_data_dtype() == dtype
assert img.get_data_dtype() == orig_dtype


class ImageHeaderAPI(MakeImageAPI):
""" When ``self.image_maker`` is an image class, make header from class
"""
Expand All @@ -674,7 +716,12 @@ def header_maker(self):
return self.image_maker.header_class()


class TestAnalyzeAPI(ImageHeaderAPI):
class TestSpatialImageAPI(ImageHeaderAPI):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This subclass/superclass swap occurs because not all spatial images have dtype overriding.

klass = image_maker = SpatialImage
can_save = False


class TestAnalyzeAPI(TestSpatialImageAPI, DtypeOverrideMixin):
""" General image validation API instantiated for Analyze images
"""
klass = image_maker = AnalyzeImage
Expand All @@ -685,11 +732,6 @@ class TestAnalyzeAPI(ImageHeaderAPI):
storable_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.float64)


class TestSpatialImageAPI(TestAnalyzeAPI):
klass = image_maker = SpatialImage
can_save = False


class TestSpm99AnalyzeAPI(TestAnalyzeAPI):
# SPM-type analyze need scipy for mat file IO
klass = image_maker = Spm99AnalyzeImage
Expand Down