Skip to content

Commit 823b97f

Browse files
committed
Add zstd to all; Add tests for zstd
1 parent f231e5c commit 823b97f

File tree

5 files changed

+84
-28
lines changed

5 files changed

+84
-28
lines changed

nibabel/tests/test_analyze.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..casting import as_int
3131
from ..tmpdirs import InTemporaryDirectory
3232
from ..arraywriters import WriterError
33+
from ..openers import HAVE_ZSTD
3334

3435
import pytest
3536
from numpy.testing import (assert_array_equal, assert_array_almost_equal)
@@ -788,6 +789,8 @@ def test_big_offset_exts(self):
788789
aff = np.eye(4)
789790
img_ext = img_klass.files_types[0][1]
790791
compressed_exts = ['', '.gz', '.bz2']
792+
if HAVE_ZSTD:
793+
compressed_exts += ['.zst']
791794
with InTemporaryDirectory():
792795
for offset in (0, 2048):
793796
# Set offset in in-memory image

nibabel/tests/test_minc1.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..deprecated import ModuleProxy
2323
from .. import minc1
2424
from ..minc1 import Minc1File, Minc1Image, MincHeader
25+
from ..openers import HAVE_ZSTD
2526

2627
from ..tmpdirs import InTemporaryDirectory
2728
from ..deprecator import ExpiredDeprecationError
@@ -32,6 +33,10 @@
3233
from . import test_spatialimages as tsi
3334
from .test_fileslice import slicer_samples
3435

36+
# only import ZstdFile, if installed
37+
if HAVE_ZSTD:
38+
from ..openers import ZstdFile
39+
3540
EG_FNAME = pjoin(data_path, 'tiny.mnc')
3641

3742
# Example images in format expected for ``test_image_api``, adding ``zooms``
@@ -170,7 +175,10 @@ def test_compressed(self):
170175
# Not so for MINC2; hence this small sub-class
171176
for tp in self.test_files:
172177
content = open(tp['fname'], 'rb').read()
173-
openers_exts = ((gzip.open, '.gz'), (bz2.BZ2File, '.bz2'))
178+
openers_exts = [(gzip.open, '.gz'),
179+
(bz2.BZ2File, '.bz2')]
180+
if HAVE_ZSTD: # add .zst to test if installed
181+
openers_exts += [(ZstdFile, '.zst')]
174182
with InTemporaryDirectory():
175183
for opener, ext in openers_exts:
176184
fname = 'test.mnc' + ext

nibabel/tests/test_openers.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from distutils.version import StrictVersion
1515

1616
from numpy.compat.py3k import asstr, asbytes
17-
from ..openers import Opener, ImageOpener, HAVE_INDEXED_GZIP, BZ2File
17+
from ..openers import (Opener,
18+
ImageOpener,
19+
HAVE_INDEXED_GZIP,
20+
BZ2File,
21+
HAVE_ZSTD)
1822
from ..tmpdirs import InTemporaryDirectory
1923
from ..volumeutils import BinOpener
2024

@@ -23,6 +27,9 @@
2327
import pytest
2428
from ..testing import error_warnings
2529

30+
if HAVE_ZSTD:
31+
from ..openers import ZstdFile
32+
2633

2734
class Lunk(object):
2835
# bare file-like for testing
@@ -71,10 +78,13 @@ def test_Opener_various():
7178
import indexed_gzip as igzip
7279
with InTemporaryDirectory():
7380
sobj = BytesIO()
74-
for input in ('test.txt',
75-
'test.txt.gz',
76-
'test.txt.bz2',
77-
sobj):
81+
files_to_test = ['test.txt',
82+
'test.txt.gz',
83+
'test.txt.bz2',
84+
sobj]
85+
if HAVE_ZSTD:
86+
files_to_test += ['test.txt.zst']
87+
for input in files_to_test:
7888
with Opener(input, 'wb') as fobj:
7989
fobj.write(message)
8090
assert fobj.tell() == len(message)
@@ -240,6 +250,8 @@ def test_compressed_ext_case():
240250
class StrictOpener(Opener):
241251
compress_ext_icase = False
242252
exts = ('gz', 'bz2', 'GZ', 'gZ', 'BZ2', 'Bz2')
253+
if HAVE_ZSTD:
254+
exts += ('zst', 'ZST', 'Zst')
243255
with InTemporaryDirectory():
244256
# Make a basic file to check type later
245257
with open(__file__, 'rb') as a_file:
@@ -264,6 +276,8 @@ class StrictOpener(Opener):
264276
except ImportError:
265277
IndexedGzipFile = GzipFile
266278
assert isinstance(fobj.fobj, (GzipFile, IndexedGzipFile))
279+
elif lext == 'zst':
280+
assert isinstance(fobj.fobj, ZstdFile)
267281
else:
268282
assert isinstance(fobj.fobj, BZ2File)
269283

@@ -273,11 +287,14 @@ def test_name():
273287
sobj = BytesIO()
274288
lunk = Lunk('in ART')
275289
with InTemporaryDirectory():
276-
for input in ('test.txt',
277-
'test.txt.gz',
278-
'test.txt.bz2',
279-
sobj,
280-
lunk):
290+
files_to_test = ['test.txt',
291+
'test.txt.gz',
292+
'test.txt.bz2',
293+
sobj,
294+
lunk]
295+
if HAVE_ZSTD:
296+
files_to_test += ['test.txt.zst']
297+
for input in files_to_test:
281298
exp_name = input if type(input) == type('') else None
282299
with Opener(input, 'wb') as fobj:
283300
assert fobj.name == exp_name
@@ -329,10 +346,13 @@ def test_iter():
329346
""".split('\n')
330347
with InTemporaryDirectory():
331348
sobj = BytesIO()
332-
for input, does_t in (('test.txt', True),
333-
('test.txt.gz', False),
334-
('test.txt.bz2', False),
335-
(sobj, True)):
349+
files_to_test = [('test.txt', True),
350+
('test.txt.gz', False),
351+
('test.txt.bz2', False),
352+
(sobj, True)]
353+
if HAVE_ZSTD:
354+
files_to_test += [('test.txt.zst', False)]
355+
for input, does_t in files_to_test:
336356
with Opener(input, 'wb') as fobj:
337357
for line in lines:
338358
fobj.write(asbytes(line + os.linesep))

nibabel/tests/test_volumeutils.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
_write_data,
4646
_ftype4scaled_finite,
4747
)
48-
from ..openers import Opener, BZ2File
48+
from ..openers import Opener, BZ2File, HAVE_ZSTD
4949
from ..casting import (floor_log2, type_info, OK_FLOATS, shared_range)
5050

5151
from ..deprecator import ExpiredDeprecationError
@@ -56,6 +56,10 @@
5656

5757
from nibabel.testing import nullcontext, assert_dt_equal, assert_allclose_safely, suppress_warnings
5858

59+
# only import ZstdFile, if installed
60+
if HAVE_ZSTD:
61+
from ..openers import ZstdFile
62+
5963
#: convenience variables for numpy types
6064
FLOAT_TYPES = np.sctypes['float']
6165
COMPLEX_TYPES = np.sctypes['complex']
@@ -68,9 +72,12 @@
6872
def test__is_compressed_fobj():
6973
# _is_compressed helper function
7074
with InTemporaryDirectory():
71-
for ext, opener, compressed in (('', open, False),
72-
('.gz', gzip.open, True),
73-
('.bz2', BZ2File, True)):
75+
file_openers = [('', open, False),
76+
('.gz', gzip.open, True),
77+
('.bz2', BZ2File, True)]
78+
if HAVE_ZSTD:
79+
file_openers += [('.zst', ZstdFile, True)]
80+
for ext, opener, compressed in file_openers:
7481
fname = 'test.bin' + ext
7582
for mode in ('wb', 'rb'):
7683
fobj = opener(fname, mode)
@@ -88,12 +95,15 @@ def make_array(n, bytes):
8895
arr.flags.writeable = True
8996
return arr
9097

91-
# Check whether file, gzip file, bz2 file reread memory from cache
98+
# Check whether file, gzip file, bz2, zst file reread memory from cache
9299
fname = 'test.bin'
93100
with InTemporaryDirectory():
101+
openers = [open, gzip.open, BZ2File]
102+
if HAVE_ZSTD:
103+
openers += [ZstdFile]
94104
for n, opener in itertools.product(
95105
(256, 1024, 2560, 25600),
96-
(open, gzip.open, BZ2File)):
106+
openers):
97107
in_arr = np.arange(n, dtype=dtype)
98108
# Write array to file
99109
fobj_w = opener(fname, 'wb')
@@ -230,7 +240,10 @@ def test_array_from_file_openers():
230240
dtype = np.dtype(np.float32)
231241
in_arr = np.arange(24, dtype=dtype).reshape(shape)
232242
with InTemporaryDirectory():
233-
for ext, offset in itertools.product(('', '.gz', '.bz2'),
243+
extensions = ['', '.gz', '.bz2']
244+
if HAVE_ZSTD:
245+
extensions += ['.zst']
246+
for ext, offset in itertools.product(extensions,
234247
(0, 5, 10)):
235248
fname = 'test.bin' + ext
236249
with Opener(fname, 'wb') as out_buf:
@@ -251,9 +264,12 @@ def test_array_from_file_reread():
251264
offset = 9
252265
fname = 'test.bin'
253266
with InTemporaryDirectory():
267+
openers = [open, gzip.open, bz2.BZ2File, BytesIO]
268+
if HAVE_ZSTD:
269+
openers += [ZstdFile]
254270
for shape, opener, dtt, order in itertools.product(
255271
((64,), (64, 65), (64, 65, 66)),
256-
(open, gzip.open, bz2.BZ2File, BytesIO),
272+
openers,
257273
(np.int16, np.float32),
258274
('F', 'C')):
259275
n_els = np.prod(shape)
@@ -901,7 +917,9 @@ def test_write_zeros():
901917
def test_seek_tell():
902918
# Test seek tell routine
903919
bio = BytesIO()
904-
in_files = bio, 'test.bin', 'test.gz', 'test.bz2'
920+
in_files = [bio, 'test.bin', 'test.gz', 'test.bz2']
921+
if HAVE_ZSTD:
922+
in_files += ['test.zst']
905923
start = 10
906924
end = 100
907925
diff = end - start
@@ -920,9 +938,12 @@ def test_seek_tell():
920938
fobj.write(b'\x01' * start)
921939
assert fobj.tell() == start
922940
# Files other than BZ2Files can seek forward on write, leaving
923-
# zeros in their wake. BZ2Files can't seek when writing, unless
924-
# we enable the write0 flag to seek_tell
925-
if not write0 and in_file == 'test.bz2': # Can't seek write in bz2
941+
# zeros in their wake. BZ2Files can't seek when writing,
942+
# unless we enable the write0 flag to seek_tell
943+
# ZstdFiles also does not support seek forward on write
944+
if (not write0 and
945+
(in_file == 'test.bz2' or
946+
in_file == 'test.zst')): # Can't seek write in bz2, zst
926947
# write the zeros by hand for the read test below
927948
fobj.write(b'\x00' * diff)
928949
else:
@@ -946,7 +967,10 @@ def test_seek_tell():
946967
# Check we have the expected written output
947968
with ImageOpener(in_file, 'rb') as fobj:
948969
assert fobj.read() == b'\x01' * start + b'\x00' * diff + b'\x02' * tail
949-
for in_file in ('test2.gz', 'test2.bz2'):
970+
input_files = ['test2.gz', 'test2.bz2']
971+
if HAVE_ZSTD:
972+
input_files += ['test2.zst']
973+
for in_file in input_files:
950974
# Check failure of write seek backwards
951975
with ImageOpener(in_file, 'wb') as fobj:
952976
fobj.write(b'g' * 10)

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ all =
7070
%(spm)s
7171
%(style)s
7272
%(test)s
73+
%(zstd)s
7374

7475
[options.entry_points]
7576
console_scripts =

0 commit comments

Comments
 (0)