Skip to content

FIX: Disable direct creation of non-conformant GiftiDataArrays #1199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 13, 2023
59 changes: 51 additions & 8 deletions nibabel/gifti/gifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import base64
import sys
import warnings
from typing import Type
from copy import copy
from typing import Type, cast

import numpy as np

Expand All @@ -27,6 +28,12 @@
from ..nifti1 import data_type_codes, intent_codes, xform_codes
from .util import KIND2FMT, array_index_order_codes, gifti_encoding_codes, gifti_endian_codes

GIFTI_DTYPES = (
data_type_codes['NIFTI_TYPE_UINT8'],
data_type_codes['NIFTI_TYPE_INT32'],
data_type_codes['NIFTI_TYPE_FLOAT32'],
)


class _GiftiMDList(list):
"""List view of GiftiMetaData object that will translate most operations"""
Expand Down Expand Up @@ -81,7 +88,8 @@ def _sanitize(args, kwargs):
<GiftiMetaData {'key': 'val'}>
>>> GiftiMetaData({"key": "val"})
<GiftiMetaData {'key': 'val'}>
>>> nvpairs = GiftiNVPairs(name='key', value='val')
>>> with pytest.deprecated_call():
... nvpairs = GiftiNVPairs(name='key', value='val')
>>> with pytest.warns(FutureWarning):
... GiftiMetaData(nvpairs)
<GiftiMetaData {'key': 'val'}>
Expand Down Expand Up @@ -460,7 +468,17 @@ def __init__(
self.data = None if data is None else np.asarray(data)
self.intent = intent_codes.code[intent]
if datatype is None:
datatype = 'none' if self.data is None else self.data.dtype
if self.data is None:
datatype = 'none'
elif data_type_codes[self.data.dtype] in GIFTI_DTYPES:
datatype = self.data.dtype
else:
raise ValueError(
f'Data array has type {self.data.dtype}. '
'The GIFTI standard only supports uint8, int32 and float32 arrays.\n'
'Explicitly cast the data array to a supported dtype or pass an '
'explicit "datatype" parameter to GiftiDataArray().'
)
self.datatype = data_type_codes.code[datatype]
self.encoding = gifti_encoding_codes.code[encoding]
self.endian = gifti_endian_codes.code[endian]
Expand Down Expand Up @@ -834,20 +852,45 @@ def _to_xml_element(self):
GIFTI.append(dar._to_xml_element())
return GIFTI

def to_xml(self, enc='utf-8') -> bytes:
def to_xml(self, enc='utf-8', *, mode='strict') -> bytes:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, what is the * for here? Do you expect this function to be called with more kwargs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The * means you have to call mode= as a keyword argument. So to_xml('utf-8', 'force') will fail, to_xml('utf-8', mode='force') will pass. I would be inclined to make enc keyword-only as well, I just didn't want to make it part of this PR...

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's interesting, thanks!

"""Return XML corresponding to image content"""
if mode == 'strict':
if any(arr.datatype not in GIFTI_DTYPES for arr in self.darrays):
raise ValueError(
'GiftiImage contains data arrays with invalid data types; '
'use mode="compat" to automatically cast to conforming types'
)
elif mode == 'compat':
darrays = []
for arr in self.darrays:
if arr.datatype not in GIFTI_DTYPES:
arr = copy(arr)
# TODO: Better typing for recoders
dtype = cast(np.dtype, data_type_codes.dtype[arr.datatype])
if np.issubdtype(dtype, np.floating):
arr.datatype = data_type_codes['float32']
elif np.issubdtype(dtype, np.integer):
arr.datatype = data_type_codes['int32']
else:
raise ValueError(f'Cannot convert {dtype} to float32/int32')
Comment on lines +866 to +875

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get how this part corrects the darray's datatype attribute, but I don't understand where the actual data is casted to the new datatype

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The data is only cast at write time. Serialization to XML will call DataArray._to_xml_element():

def _to_xml_element(self):
# fix endianness to machine endianness
self.endian = gifti_endian_codes.code[sys.byteorder]
# All attribute values must be strings
data_array = xml.Element(
'DataArray',
attrib={
'Intent': intent_codes.niistring[self.intent],
'DataType': data_type_codes.niistring[self.datatype],
'ArrayIndexingOrder': array_index_order_codes.label[self.ind_ord],
'Dimensionality': str(self.num_dim),
'Encoding': gifti_encoding_codes.specs[self.encoding],
'Endian': gifti_endian_codes.specs[self.endian],
'ExternalFileName': self.ext_fname,
'ExternalFileOffset': str(self.ext_offset),
},
)
for di, dn in enumerate(self.dims):
data_array.attrib['Dim%d' % di] = str(dn)
if self.meta is not None:
data_array.append(self.meta._to_xml_element())
if self.coordsys is not None:
data_array.append(self.coordsys._to_xml_element())
# write data array depending on the encoding
data_array.append(
_data_tag_element(
self.data,
gifti_encoding_codes.specs[self.encoding],
data_type_codes.dtype[self.datatype],
self.ind_ord,
)
)
return data_array

Which calls _data_tag_element():

def _data_tag_element(dataarray, encoding, dtype, ordering):
"""Creates data tag with given `encoding`, returns as XML element"""
import zlib
order = array_index_order_codes.npcode[ordering]
enclabel = gifti_encoding_codes.label[encoding]
if enclabel == 'ASCII':
da = _arr2txt(dataarray, KIND2FMT[dtype.kind])
elif enclabel in ('B64BIN', 'B64GZ'):
out = np.asanyarray(dataarray, dtype).tobytes(order)
if enclabel == 'B64GZ':
out = zlib.compress(out)
da = base64.b64encode(out).decode()
elif enclabel == 'External':
raise NotImplementedError('In what format are the external files?')
else:
da = ''
data = xml.Element('Data')
data.text = da
return data

L380 is the one that finally does it: out = np.asanyarray(dataarray, dtype).tobytes(order)

darrays.append(arr)
gii = copy(self)
gii.darrays = darrays
return gii.to_xml(enc=enc, mode='strict')
elif mode != 'force':
raise TypeError(f'Unknown mode {mode}')
header = b"""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE GIFTI SYSTEM "http://www.nitrc.org/frs/download.php/115/gifti.dtd">
"""
return header + super().to_xml(enc)

# Avoid the indirection of going through to_file_map
def to_bytes(self, enc='utf-8'):
return self.to_xml(enc=enc)
def to_bytes(self, enc='utf-8', *, mode='strict'):
return self.to_xml(enc=enc, mode=mode)

to_bytes.__doc__ = SerializableImage.to_bytes.__doc__

def to_file_map(self, file_map=None, enc='utf-8'):
def to_file_map(self, file_map=None, enc='utf-8', *, mode='strict'):
"""Save the current image to the specified file_map

Parameters
Expand All @@ -863,7 +906,7 @@ def to_file_map(self, file_map=None, enc='utf-8'):
if file_map is None:
file_map = self.file_map
with file_map['image'].get_prepare_fileobj('wb') as f:
f.write(self.to_xml(enc=enc))
f.write(self.to_xml(enc=enc, mode=mode))

@classmethod
def from_file_map(klass, file_map, buffer_size=35000000, mmap=True):
Expand Down
101 changes: 92 additions & 9 deletions nibabel/gifti/tests/test_gifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
DATA_FILE6,
)

rng = np.random.default_rng()


def test_agg_data():
surf_gii_img = load(get_test_data('gifti', 'ascii.gii'))
Expand Down Expand Up @@ -81,7 +83,7 @@ def test_gifti_image():
assert gi.numDA == 0

# Test from numpy numeric array
data = np.random.random((5,))
data = rng.random(5, dtype=np.float32)
da = GiftiDataArray(data)
gi.add_gifti_data_array(da)
assert gi.numDA == 1
Expand All @@ -98,7 +100,7 @@ def test_gifti_image():

# Remove one
gi = GiftiImage()
da = GiftiDataArray(np.zeros((5,)), intent=0)
da = GiftiDataArray(np.zeros((5,), np.float32), intent=0)
gi.add_gifti_data_array(da)

gi.remove_gifti_data_array_by_intent(3)
Expand Down Expand Up @@ -126,6 +128,42 @@ def assign_metadata(val):
pytest.raises(TypeError, assign_metadata, 'not-a-meta')


@pytest.mark.parametrize('label', data_type_codes.value_set('label'))
def test_image_typing(label):
dtype = data_type_codes.dtype[label]
if dtype == np.void:
return
arr = 127 * rng.random(20)
try:
cast = arr.astype(label)
except TypeError:
return
darr = GiftiDataArray(cast, datatype=label)
img = GiftiImage(darrays=[darr])

# Force-write always works
force_rt = img.from_bytes(img.to_bytes(mode='force'))
assert np.array_equal(cast, force_rt.darrays[0].data)

# Compatibility mode does its best
if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating):
compat_rt = img.from_bytes(img.to_bytes(mode='compat'))
compat_darr = compat_rt.darrays[0].data
assert np.allclose(cast, compat_darr)
assert compat_darr.dtype in ('uint8', 'int32', 'float32')
else:
with pytest.raises(ValueError):
img.to_bytes(mode='compat')

# Strict mode either works or fails
if label in ('uint8', 'int32', 'float32'):
strict_rt = img.from_bytes(img.to_bytes(mode='strict'))
assert np.array_equal(cast, strict_rt.darrays[0].data)
else:
with pytest.raises(ValueError):
img.to_bytes(mode='strict')


def test_dataarray_empty():
# Test default initialization of DataArray
null_da = GiftiDataArray()
Expand Down Expand Up @@ -195,6 +233,38 @@ def test_dataarray_init():
assert gda(ext_offset=12).ext_offset == 12


@pytest.mark.parametrize('label', data_type_codes.value_set('label'))
def test_dataarray_typing(label):
dtype = data_type_codes.dtype[label]
code = data_type_codes.code[label]
arr = np.zeros((5,), dtype=dtype)

# Default interface: accept standards-conformant arrays, reject else
if dtype in ('uint8', 'int32', 'float32'):
assert GiftiDataArray(arr).datatype == code
else:
with pytest.raises(ValueError):
GiftiDataArray(arr)

# Explicit override - permit for now, may want to warn or eventually
# error
assert GiftiDataArray(arr, datatype=label).datatype == code
assert GiftiDataArray(arr, datatype=code).datatype == code
# Void is how we say we don't know how to do something, so it's not unique
if dtype != np.dtype('void'):
assert GiftiDataArray(arr, datatype=dtype).datatype == code

# Side-load data array (as in parsing)
# We will probably always want this to load legacy images, but it's
# probably not ideal to make it easy to silently propagate nonconformant
# arrays
gda = GiftiDataArray()
gda.data = arr
gda.datatype = data_type_codes.code[label]
assert gda.data.dtype == dtype
assert gda.datatype == data_type_codes.code[label]


def test_labeltable():
img = GiftiImage()
assert len(img.labeltable.labels) == 0
Expand Down Expand Up @@ -303,7 +373,7 @@ def test_metadata_list_interface():


def test_gifti_label_rgba():
rgba = np.random.rand(4)
rgba = rng.random(4)
kwargs = dict(zip(['red', 'green', 'blue', 'alpha'], rgba))

gl1 = GiftiLabel(**kwargs)
Expand Down Expand Up @@ -332,13 +402,17 @@ def assign_rgba(gl, val):
assert np.all([elem is None for elem in gl4.rgba])


def test_print_summary():
for fil in [DATA_FILE1, DATA_FILE2, DATA_FILE3, DATA_FILE4, DATA_FILE5, DATA_FILE6]:
gimg = load(fil)
gimg.print_summary()
@pytest.mark.parametrize(
'fname', [DATA_FILE1, DATA_FILE2, DATA_FILE3, DATA_FILE4, DATA_FILE5, DATA_FILE6]
)
def test_print_summary(fname, capsys):
gimg = load(fname)
gimg.print_summary()
captured = capsys.readouterr()
assert captured.out.startswith('----start----\n')


def test_gifti_coord():
def test_gifti_coord(capsys):
from ..gifti import GiftiCoordSystem

gcs = GiftiCoordSystem()
Expand All @@ -347,6 +421,15 @@ def test_gifti_coord():
# Smoke test
gcs.xform = None
gcs.print_summary()
captured = capsys.readouterr()
assert captured.out == '\n'.join(
[
'Dataspace: NIFTI_XFORM_UNKNOWN',
'XFormSpace: NIFTI_XFORM_UNKNOWN',
'Affine Transformation Matrix: ',
' None\n',
]
)
gcs.to_xml()


Expand Down Expand Up @@ -471,7 +554,7 @@ def test_darray_dtype_coercion_failures():
datatype=darray_dtype,
)
gii = GiftiImage(darrays=[da])
gii_copy = GiftiImage.from_bytes(gii.to_bytes())
gii_copy = GiftiImage.from_bytes(gii.to_bytes(mode='force'))
da_copy = gii_copy.darrays[0]
assert np.dtype(da_copy.data.dtype) == np.dtype(darray_dtype)
assert_array_equal(da_copy.data, da.data)
Expand Down