Skip to content

Commit 40e31e8

Browse files
authored
Merge pull request #1199 from effigies/fix/gifti_dtypes
FIX: Disable direct creation of non-conformant GiftiDataArrays
2 parents cafd0ab + cd1a39a commit 40e31e8

File tree

2 files changed

+143
-17
lines changed

2 files changed

+143
-17
lines changed

nibabel/gifti/gifti.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import base64
1717
import sys
1818
import warnings
19-
from typing import Type
19+
from copy import copy
20+
from typing import Type, cast
2021

2122
import numpy as np
2223

@@ -27,6 +28,12 @@
2728
from ..nifti1 import data_type_codes, intent_codes, xform_codes
2829
from .util import KIND2FMT, array_index_order_codes, gifti_encoding_codes, gifti_endian_codes
2930

31+
GIFTI_DTYPES = (
32+
data_type_codes['NIFTI_TYPE_UINT8'],
33+
data_type_codes['NIFTI_TYPE_INT32'],
34+
data_type_codes['NIFTI_TYPE_FLOAT32'],
35+
)
36+
3037

3138
class _GiftiMDList(list):
3239
"""List view of GiftiMetaData object that will translate most operations"""
@@ -81,7 +88,8 @@ def _sanitize(args, kwargs):
8188
<GiftiMetaData {'key': 'val'}>
8289
>>> GiftiMetaData({"key": "val"})
8390
<GiftiMetaData {'key': 'val'}>
84-
>>> nvpairs = GiftiNVPairs(name='key', value='val')
91+
>>> with pytest.deprecated_call():
92+
... nvpairs = GiftiNVPairs(name='key', value='val')
8593
>>> with pytest.warns(FutureWarning):
8694
... GiftiMetaData(nvpairs)
8795
<GiftiMetaData {'key': 'val'}>
@@ -460,7 +468,17 @@ def __init__(
460468
self.data = None if data is None else np.asarray(data)
461469
self.intent = intent_codes.code[intent]
462470
if datatype is None:
463-
datatype = 'none' if self.data is None else self.data.dtype
471+
if self.data is None:
472+
datatype = 'none'
473+
elif data_type_codes[self.data.dtype] in GIFTI_DTYPES:
474+
datatype = self.data.dtype
475+
else:
476+
raise ValueError(
477+
f'Data array has type {self.data.dtype}. '
478+
'The GIFTI standard only supports uint8, int32 and float32 arrays.\n'
479+
'Explicitly cast the data array to a supported dtype or pass an '
480+
'explicit "datatype" parameter to GiftiDataArray().'
481+
)
464482
self.datatype = data_type_codes.code[datatype]
465483
self.encoding = gifti_encoding_codes.code[encoding]
466484
self.endian = gifti_endian_codes.code[endian]
@@ -834,20 +852,45 @@ def _to_xml_element(self):
834852
GIFTI.append(dar._to_xml_element())
835853
return GIFTI
836854

837-
def to_xml(self, enc='utf-8') -> bytes:
855+
def to_xml(self, enc='utf-8', *, mode='strict') -> bytes:
838856
"""Return XML corresponding to image content"""
857+
if mode == 'strict':
858+
if any(arr.datatype not in GIFTI_DTYPES for arr in self.darrays):
859+
raise ValueError(
860+
'GiftiImage contains data arrays with invalid data types; '
861+
'use mode="compat" to automatically cast to conforming types'
862+
)
863+
elif mode == 'compat':
864+
darrays = []
865+
for arr in self.darrays:
866+
if arr.datatype not in GIFTI_DTYPES:
867+
arr = copy(arr)
868+
# TODO: Better typing for recoders
869+
dtype = cast(np.dtype, data_type_codes.dtype[arr.datatype])
870+
if np.issubdtype(dtype, np.floating):
871+
arr.datatype = data_type_codes['float32']
872+
elif np.issubdtype(dtype, np.integer):
873+
arr.datatype = data_type_codes['int32']
874+
else:
875+
raise ValueError(f'Cannot convert {dtype} to float32/int32')
876+
darrays.append(arr)
877+
gii = copy(self)
878+
gii.darrays = darrays
879+
return gii.to_xml(enc=enc, mode='strict')
880+
elif mode != 'force':
881+
raise TypeError(f'Unknown mode {mode}')
839882
header = b"""<?xml version="1.0" encoding="UTF-8"?>
840883
<!DOCTYPE GIFTI SYSTEM "http://www.nitrc.org/frs/download.php/115/gifti.dtd">
841884
"""
842885
return header + super().to_xml(enc)
843886

844887
# Avoid the indirection of going through to_file_map
845-
def to_bytes(self, enc='utf-8'):
846-
return self.to_xml(enc=enc)
888+
def to_bytes(self, enc='utf-8', *, mode='strict'):
889+
return self.to_xml(enc=enc, mode=mode)
847890

848891
to_bytes.__doc__ = SerializableImage.to_bytes.__doc__
849892

850-
def to_file_map(self, file_map=None, enc='utf-8'):
893+
def to_file_map(self, file_map=None, enc='utf-8', *, mode='strict'):
851894
"""Save the current image to the specified file_map
852895
853896
Parameters
@@ -863,7 +906,7 @@ def to_file_map(self, file_map=None, enc='utf-8'):
863906
if file_map is None:
864907
file_map = self.file_map
865908
with file_map['image'].get_prepare_fileobj('wb') as f:
866-
f.write(self.to_xml(enc=enc))
909+
f.write(self.to_xml(enc=enc, mode=mode))
867910

868911
@classmethod
869912
def from_file_map(klass, file_map, buffer_size=35000000, mmap=True):

nibabel/gifti/tests/test_gifti.py

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
DATA_FILE6,
3434
)
3535

36+
rng = np.random.default_rng()
37+
3638

3739
def test_agg_data():
3840
surf_gii_img = load(get_test_data('gifti', 'ascii.gii'))
@@ -81,7 +83,7 @@ def test_gifti_image():
8183
assert gi.numDA == 0
8284

8385
# Test from numpy numeric array
84-
data = np.random.random((5,))
86+
data = rng.random(5, dtype=np.float32)
8587
da = GiftiDataArray(data)
8688
gi.add_gifti_data_array(da)
8789
assert gi.numDA == 1
@@ -98,7 +100,7 @@ def test_gifti_image():
98100

99101
# Remove one
100102
gi = GiftiImage()
101-
da = GiftiDataArray(np.zeros((5,)), intent=0)
103+
da = GiftiDataArray(np.zeros((5,), np.float32), intent=0)
102104
gi.add_gifti_data_array(da)
103105

104106
gi.remove_gifti_data_array_by_intent(3)
@@ -126,6 +128,42 @@ def assign_metadata(val):
126128
pytest.raises(TypeError, assign_metadata, 'not-a-meta')
127129

128130

131+
@pytest.mark.parametrize('label', data_type_codes.value_set('label'))
132+
def test_image_typing(label):
133+
dtype = data_type_codes.dtype[label]
134+
if dtype == np.void:
135+
return
136+
arr = 127 * rng.random(20)
137+
try:
138+
cast = arr.astype(label)
139+
except TypeError:
140+
return
141+
darr = GiftiDataArray(cast, datatype=label)
142+
img = GiftiImage(darrays=[darr])
143+
144+
# Force-write always works
145+
force_rt = img.from_bytes(img.to_bytes(mode='force'))
146+
assert np.array_equal(cast, force_rt.darrays[0].data)
147+
148+
# Compatibility mode does its best
149+
if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.floating):
150+
compat_rt = img.from_bytes(img.to_bytes(mode='compat'))
151+
compat_darr = compat_rt.darrays[0].data
152+
assert np.allclose(cast, compat_darr)
153+
assert compat_darr.dtype in ('uint8', 'int32', 'float32')
154+
else:
155+
with pytest.raises(ValueError):
156+
img.to_bytes(mode='compat')
157+
158+
# Strict mode either works or fails
159+
if label in ('uint8', 'int32', 'float32'):
160+
strict_rt = img.from_bytes(img.to_bytes(mode='strict'))
161+
assert np.array_equal(cast, strict_rt.darrays[0].data)
162+
else:
163+
with pytest.raises(ValueError):
164+
img.to_bytes(mode='strict')
165+
166+
129167
def test_dataarray_empty():
130168
# Test default initialization of DataArray
131169
null_da = GiftiDataArray()
@@ -195,6 +233,38 @@ def test_dataarray_init():
195233
assert gda(ext_offset=12).ext_offset == 12
196234

197235

236+
@pytest.mark.parametrize('label', data_type_codes.value_set('label'))
237+
def test_dataarray_typing(label):
238+
dtype = data_type_codes.dtype[label]
239+
code = data_type_codes.code[label]
240+
arr = np.zeros((5,), dtype=dtype)
241+
242+
# Default interface: accept standards-conformant arrays, reject else
243+
if dtype in ('uint8', 'int32', 'float32'):
244+
assert GiftiDataArray(arr).datatype == code
245+
else:
246+
with pytest.raises(ValueError):
247+
GiftiDataArray(arr)
248+
249+
# Explicit override - permit for now, may want to warn or eventually
250+
# error
251+
assert GiftiDataArray(arr, datatype=label).datatype == code
252+
assert GiftiDataArray(arr, datatype=code).datatype == code
253+
# Void is how we say we don't know how to do something, so it's not unique
254+
if dtype != np.dtype('void'):
255+
assert GiftiDataArray(arr, datatype=dtype).datatype == code
256+
257+
# Side-load data array (as in parsing)
258+
# We will probably always want this to load legacy images, but it's
259+
# probably not ideal to make it easy to silently propagate nonconformant
260+
# arrays
261+
gda = GiftiDataArray()
262+
gda.data = arr
263+
gda.datatype = data_type_codes.code[label]
264+
assert gda.data.dtype == dtype
265+
assert gda.datatype == data_type_codes.code[label]
266+
267+
198268
def test_labeltable():
199269
img = GiftiImage()
200270
assert len(img.labeltable.labels) == 0
@@ -303,7 +373,7 @@ def test_metadata_list_interface():
303373

304374

305375
def test_gifti_label_rgba():
306-
rgba = np.random.rand(4)
376+
rgba = rng.random(4)
307377
kwargs = dict(zip(['red', 'green', 'blue', 'alpha'], rgba))
308378

309379
gl1 = GiftiLabel(**kwargs)
@@ -332,13 +402,17 @@ def assign_rgba(gl, val):
332402
assert np.all([elem is None for elem in gl4.rgba])
333403

334404

335-
def test_print_summary():
336-
for fil in [DATA_FILE1, DATA_FILE2, DATA_FILE3, DATA_FILE4, DATA_FILE5, DATA_FILE6]:
337-
gimg = load(fil)
338-
gimg.print_summary()
405+
@pytest.mark.parametrize(
406+
'fname', [DATA_FILE1, DATA_FILE2, DATA_FILE3, DATA_FILE4, DATA_FILE5, DATA_FILE6]
407+
)
408+
def test_print_summary(fname, capsys):
409+
gimg = load(fname)
410+
gimg.print_summary()
411+
captured = capsys.readouterr()
412+
assert captured.out.startswith('----start----\n')
339413

340414

341-
def test_gifti_coord():
415+
def test_gifti_coord(capsys):
342416
from ..gifti import GiftiCoordSystem
343417

344418
gcs = GiftiCoordSystem()
@@ -347,6 +421,15 @@ def test_gifti_coord():
347421
# Smoke test
348422
gcs.xform = None
349423
gcs.print_summary()
424+
captured = capsys.readouterr()
425+
assert captured.out == '\n'.join(
426+
[
427+
'Dataspace: NIFTI_XFORM_UNKNOWN',
428+
'XFormSpace: NIFTI_XFORM_UNKNOWN',
429+
'Affine Transformation Matrix: ',
430+
' None\n',
431+
]
432+
)
350433
gcs.to_xml()
351434

352435

@@ -471,7 +554,7 @@ def test_darray_dtype_coercion_failures():
471554
datatype=darray_dtype,
472555
)
473556
gii = GiftiImage(darrays=[da])
474-
gii_copy = GiftiImage.from_bytes(gii.to_bytes())
557+
gii_copy = GiftiImage.from_bytes(gii.to_bytes(mode='force'))
475558
da_copy = gii_copy.darrays[0]
476559
assert np.dtype(da_copy.data.dtype) == np.dtype(darray_dtype)
477560
assert_array_equal(da_copy.data, da.data)

0 commit comments

Comments
 (0)