Skip to content

Commit b5ff2de

Browse files
authored
Merge pull request #806 from effigies/fix/gifti_types
FIX: Coerce data types on writing GIFTI DataArrays
2 parents 5b4d03a + 2838c49 commit b5ff2de

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

nibabel/gifti/gifti.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,16 +270,21 @@ def _to_xml_element(self):
270270
return DataTag(dataarray, encoding, datatype, ordering).to_xml()
271271

272272

273-
def _data_tag_element(dataarray, encoding, datatype, ordering):
273+
def _data_tag_element(dataarray, encoding, dtype, ordering):
274274
""" Creates data tag with given `encoding`, returns as XML element
275275
"""
276276
import zlib
277-
ord = array_index_order_codes.npcode[ordering]
277+
order = array_index_order_codes.npcode[ordering]
278278
enclabel = gifti_encoding_codes.label[encoding]
279279
if enclabel == 'ASCII':
280-
da = _arr2txt(dataarray, datatype)
280+
# XXX Accommodating data_tag API
281+
# On removal (nibabel 4.0) drop str case
282+
da = _arr2txt(dataarray, dtype if isinstance(dtype, str) else KIND2FMT[dtype.kind])
281283
elif enclabel in ('B64BIN', 'B64GZ'):
282-
out = dataarray.tostring(ord)
284+
# XXX Accommodating data_tag API - don't try to fix dtype
285+
if isinstance(dtype, str):
286+
dtype = dataarray.dtype
287+
out = np.asanyarray(dataarray, dtype).tostring(order)
283288
if enclabel == 'B64GZ':
284289
out = zlib.compress(out)
285290
da = base64.b64encode(out).decode()
@@ -462,11 +467,10 @@ def _to_xml_element(self):
462467
if self.coordsys is not None:
463468
data_array.append(self.coordsys._to_xml_element())
464469
# write data array depending on the encoding
465-
dt_kind = data_type_codes.dtype[self.datatype].kind
466470
data_array.append(
467471
_data_tag_element(self.data,
468472
gifti_encoding_codes.specs[self.encoding],
469-
KIND2FMT[dt_kind],
473+
data_type_codes.dtype[self.datatype],
470474
self.ind_ord))
471475

472476
return data_array

nibabel/gifti/tests/test_gifti.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nibabel.testing import clear_and_catch_warnings
2222
from .test_parse_gifti_fast import (DATA_FILE1, DATA_FILE2, DATA_FILE3,
2323
DATA_FILE4, DATA_FILE5, DATA_FILE6)
24+
import itertools
2425

2526

2627
def test_gifti_image():
@@ -400,3 +401,20 @@ def test_data_array_round_trip():
400401
gio = GiftiImage.from_file_map(fmap)
401402
vertices = gio.darrays[0].data
402403
assert_array_equal(vertices, verts)
404+
405+
406+
def test_darray_dtype_coercion_failures():
407+
dtypes = (np.uint8, np.int32, np.int64, np.float32, np.float64)
408+
encodings = ('ASCII', 'B64BIN', 'B64GZ')
409+
for data_dtype, darray_dtype, encoding in itertools.product(dtypes,
410+
dtypes,
411+
encodings):
412+
da = GiftiDataArray(np.arange(10).astype(data_dtype),
413+
encoding=encoding,
414+
intent='NIFTI_INTENT_NODE_INDEX',
415+
datatype=darray_dtype)
416+
gii = GiftiImage(darrays=[da])
417+
gii_copy = GiftiImage.from_bytes(gii.to_bytes())
418+
da_copy = gii_copy.darrays[0]
419+
assert_equal(np.dtype(da_copy.data.dtype), np.dtype(darray_dtype))
420+
assert_array_equal(da_copy.data, da.data)

0 commit comments

Comments
 (0)