Skip to content

Commit 7dfef14

Browse files
committed
TEST: Refactor NetCDF tests to be more pytest friendly
1 parent ee4b703 commit 7dfef14

File tree

3 files changed

+55
-72
lines changed

3 files changed

+55
-72
lines changed

.azure-pipelines/windows.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ jobs:
8080
-I test_minc2 ^
8181
-I test_minc2_data ^
8282
-I test_mriutils ^
83+
-I test_netcdf ^
8384
-I test_nibabel_data ^
8485
-I test_nifti1 ^
8586
-I test_nifti2 ^

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ script:
173173
-I test_minc2 \
174174
-I test_minc2_data \
175175
-I test_mriutils \
176+
-I test_netcdf \
176177
-I test_nibabel_data \
177178
-I test_nifti1 \
178179
-I test_nifti2 \

nibabel/externals/tests/test_netcdf.py

Lines changed: 53 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,15 @@
22

33
import os
44
from os.path import join as pjoin, dirname
5-
import shutil
6-
import tempfile
7-
import time
8-
import sys
95
from io import BytesIO
106
from glob import glob
117
from contextlib import contextmanager
128

139
import numpy as np
14-
from numpy.testing import dec, assert_
1510

16-
from ..netcdf import netcdf_file
11+
import pytest
1712

18-
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
13+
from ..netcdf import netcdf_file
1914

2015
TEST_DATA_PATH = pjoin(dirname(__file__), 'data')
2116

@@ -36,54 +31,41 @@ def make_simple(*args, **kwargs):
3631
f.close()
3732

3833

39-
def gen_for_simple(ncfileobj):
40-
''' Generator for example fileobj tests '''
41-
yield assert_equal, ncfileobj.history, b'Created for a test'
34+
def assert_simple_truths(ncfileobj):
35+
assert ncfileobj.history == b'Created for a test'
4236
time = ncfileobj.variables['time']
43-
yield assert_equal, time.units, b'days since 2008-01-01'
44-
yield assert_equal, time.shape, (N_EG_ELS,)
45-
yield assert_equal, time[-1], N_EG_ELS-1
46-
47-
48-
def test_read_write_files():
49-
# test round trip for example file
50-
cwd = os.getcwd()
51-
try:
52-
tmpdir = tempfile.mkdtemp()
53-
os.chdir(tmpdir)
54-
with make_simple('simple.nc', 'w') as f:
55-
pass
56-
# To read the NetCDF file we just created::
57-
with netcdf_file('simple.nc') as f:
58-
# Using mmap is the default
59-
yield assert_true, f.use_mmap
60-
for testargs in gen_for_simple(f):
61-
yield testargs
62-
63-
# Now without mmap
64-
with netcdf_file('simple.nc', mmap=False) as f:
65-
# Using mmap is the default
66-
yield assert_false, f.use_mmap
67-
for testargs in gen_for_simple(f):
68-
yield testargs
69-
70-
# To read the NetCDF file we just created, as file object, no
71-
# mmap. When n * n_bytes(var_type) is not divisible by 4, this
72-
# raised an error in pupynere 1.0.12 and scipy rev 5893, because
73-
# calculated vsize was rounding up in units of 4 - see
74-
# https://www.unidata.ucar.edu/software/netcdf/docs/netcdf.html
75-
fobj = open('simple.nc', 'rb')
76-
with netcdf_file(fobj) as f:
77-
# by default, don't use mmap for file-like
78-
yield assert_false, f.use_mmap
79-
for testargs in gen_for_simple(f):
80-
yield testargs
81-
except:
82-
os.chdir(cwd)
83-
shutil.rmtree(tmpdir)
84-
raise
85-
os.chdir(cwd)
86-
shutil.rmtree(tmpdir)
37+
assert time.units == b'days since 2008-01-01'
38+
assert time.shape == (N_EG_ELS,)
39+
assert time[-1] == N_EG_ELS - 1
40+
41+
42+
def test_read_write_files(tmp_path):
43+
os.chdir(str(tmp_path))
44+
45+
with make_simple('simple.nc', 'w') as f:
46+
pass
47+
# To read the NetCDF file we just created::
48+
with netcdf_file('simple.nc') as f:
49+
# Using mmap is the default
50+
assert f.use_mmap
51+
assert_simple_truths(f)
52+
53+
# Now without mmap
54+
with netcdf_file('simple.nc', mmap=False) as f:
55+
# Using mmap is the default
56+
assert not f.use_mmap
57+
assert_simple_truths(f)
58+
59+
# To read the NetCDF file we just created, as file object, no
60+
# mmap. When n * n_bytes(var_type) is not divisible by 4, this
61+
# raised an error in pupynere 1.0.12 and scipy rev 5893, because
62+
# calculated vsize was rounding up in units of 4 - see
63+
# https://www.unidata.ucar.edu/software/netcdf/docs/netcdf.html
64+
fobj = open('simple.nc', 'rb')
65+
with netcdf_file(fobj) as f:
66+
# by default, don't use mmap for file-like
67+
assert not f.use_mmap
68+
assert_simple_truths(f)
8769

8870

8971
def test_read_write_sio():
@@ -93,28 +75,26 @@ def test_read_write_sio():
9375

9476
eg_sio2 = BytesIO(str_val)
9577
with netcdf_file(eg_sio2) as f2:
96-
for testargs in gen_for_simple(f2):
97-
yield testargs
78+
assert_simple_truths(f2)
9879

9980
# Test that error is raised if attempting mmap for sio
10081
eg_sio3 = BytesIO(str_val)
101-
yield assert_raises, ValueError, netcdf_file, eg_sio3, 'r', True
82+
with pytest.raises(ValueError):
83+
netcdf_file(eg_sio3, 'r', True)
10284
# Test 64-bit offset write / read
10385
eg_sio_64 = BytesIO()
10486
with make_simple(eg_sio_64, 'w', version=2) as f_64:
10587
str_val = eg_sio_64.getvalue()
10688

10789
eg_sio_64 = BytesIO(str_val)
10890
with netcdf_file(eg_sio_64) as f_64:
109-
for testargs in gen_for_simple(f_64):
110-
yield testargs
111-
yield assert_equal, f_64.version_byte, 2
91+
assert_simple_truths(f_64)
92+
assert f_64.version_byte == 2
11293
# also when version 2 explicitly specified
11394
eg_sio_64 = BytesIO(str_val)
11495
with netcdf_file(eg_sio_64, version=2) as f_64:
115-
for testargs in gen_for_simple(f_64):
116-
yield testargs
117-
yield assert_equal, f_64.version_byte, 2
96+
assert_simple_truths(f_64)
97+
assert f_64.version_byte == 2
11898

11999

120100
def test_read_example_data():
@@ -134,7 +114,8 @@ def test_itemset_no_segfault_on_readonly():
134114
time_var = f.variables['time']
135115

136116
# time_var.assignValue(42) should raise a RuntimeError--not seg. fault!
137-
assert_raises(RuntimeError, time_var.assignValue, 42)
117+
with pytest.raises(RuntimeError):
118+
time_var.assignValue(42)
138119

139120

140121
def test_write_invalid_dtype():
@@ -147,22 +128,22 @@ def test_write_invalid_dtype():
147128
with netcdf_file(BytesIO(), 'w') as f:
148129
f.createDimension('time', N_EG_ELS)
149130
for dt in dtypes:
150-
yield assert_raises, ValueError, \
151-
f.createVariable, 'time', dt, ('time',)
131+
with pytest.raises(ValueError):
132+
f.createVariable('time', dt, ('time',))
152133

153134

154135
def test_flush_rewind():
155136
stream = BytesIO()
156137
with make_simple(stream, mode='w') as f:
157-
x = f.createDimension('x',4)
138+
x = f.createDimension('x', 4)
158139
v = f.createVariable('v', 'i2', ['x'])
159140
v[:] = 1
160141
f.flush()
161142
len_single = len(stream.getvalue())
162143
f.flush()
163144
len_double = len(stream.getvalue())
164145

165-
assert_(len_single == len_double)
146+
assert len_single == len_double
166147

167148

168149
def test_dtype_specifiers():
@@ -192,8 +173,8 @@ def test_ticket_1720():
192173

193174
io = BytesIO(contents)
194175
with netcdf_file(io, 'r') as f:
195-
assert_equal(f.history, b'Created for a test')
176+
assert f.history == b'Created for a test'
196177
float_var = f.variables['float_var']
197-
assert_equal(float_var.units, b'metres')
198-
assert_equal(float_var.shape, (10,))
199-
assert_(np.allclose(float_var[:], items))
178+
assert float_var.units == b'metres'
179+
assert float_var.shape == (10,)
180+
assert np.allclose(float_var[:], items)

0 commit comments

Comments
 (0)