Skip to content

TEST: Refactor NetCDF tests to be more pytest friendly #879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .azure-pipelines/windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
-I test_minc2 ^
-I test_minc2_data ^
-I test_mriutils ^
-I test_netcdf ^
-I test_nibabel_data ^
-I test_nifti1 ^
-I test_nifti2 ^
Expand Down
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ script:
-I test_minc2 \
-I test_minc2_data \
-I test_mriutils \
-I test_netcdf \
-I test_nibabel_data \
-I test_nifti1 \
-I test_nifti2 \
Expand Down
125 changes: 53 additions & 72 deletions nibabel/externals/tests/test_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@

import os
from os.path import join as pjoin, dirname
import shutil
import tempfile
import time
import sys
from io import BytesIO
from glob import glob
from contextlib import contextmanager

import numpy as np
from numpy.testing import dec, assert_

from ..netcdf import netcdf_file
import pytest

from nose.tools import assert_true, assert_false, assert_equal, assert_raises
from ..netcdf import netcdf_file

TEST_DATA_PATH = pjoin(dirname(__file__), 'data')

Expand All @@ -36,54 +31,41 @@ def make_simple(*args, **kwargs):
f.close()


def gen_for_simple(ncfileobj):
''' Generator for example fileobj tests '''
yield assert_equal, ncfileobj.history, b'Created for a test'
def assert_simple_truths(ncfileobj):
assert ncfileobj.history == b'Created for a test'
time = ncfileobj.variables['time']
yield assert_equal, time.units, b'days since 2008-01-01'
yield assert_equal, time.shape, (N_EG_ELS,)
yield assert_equal, time[-1], N_EG_ELS-1


def test_read_write_files():
# test round trip for example file
cwd = os.getcwd()
try:
tmpdir = tempfile.mkdtemp()
os.chdir(tmpdir)
with make_simple('simple.nc', 'w') as f:
pass
# To read the NetCDF file we just created::
with netcdf_file('simple.nc') as f:
# Using mmap is the default
yield assert_true, f.use_mmap
for testargs in gen_for_simple(f):
yield testargs

# Now without mmap
with netcdf_file('simple.nc', mmap=False) as f:
# Using mmap is the default
yield assert_false, f.use_mmap
for testargs in gen_for_simple(f):
yield testargs

# To read the NetCDF file we just created, as file object, no
# mmap. When n * n_bytes(var_type) is not divisible by 4, this
# raised an error in pupynere 1.0.12 and scipy rev 5893, because
# calculated vsize was rounding up in units of 4 - see
# https://www.unidata.ucar.edu/software/netcdf/docs/netcdf.html
fobj = open('simple.nc', 'rb')
with netcdf_file(fobj) as f:
# by default, don't use mmap for file-like
yield assert_false, f.use_mmap
for testargs in gen_for_simple(f):
yield testargs
except:
os.chdir(cwd)
shutil.rmtree(tmpdir)
raise
os.chdir(cwd)
shutil.rmtree(tmpdir)
assert time.units == b'days since 2008-01-01'
assert time.shape == (N_EG_ELS,)
assert time[-1] == N_EG_ELS - 1


def test_read_write_files(tmp_path):
fname = str(tmp_path / 'simple.nc')

with make_simple(fname, 'w') as f:
pass
# To read the NetCDF file we just created::
with netcdf_file(fname) as f:
# Using mmap is the default
assert f.use_mmap
assert_simple_truths(f)

# Now without mmap
with netcdf_file(fname, mmap=False) as f:
# Using mmap is the default
assert not f.use_mmap
assert_simple_truths(f)

# To read the NetCDF file we just created, as file object, no
# mmap. When n * n_bytes(var_type) is not divisible by 4, this
# raised an error in pupynere 1.0.12 and scipy rev 5893, because
# calculated vsize was rounding up in units of 4 - see
# https://www.unidata.ucar.edu/software/netcdf/docs/netcdf.html
fobj = open(fname, 'rb')
with netcdf_file(fobj) as f:
# by default, don't use mmap for file-like
assert not f.use_mmap
assert_simple_truths(f)


def test_read_write_sio():
Expand All @@ -93,28 +75,26 @@ def test_read_write_sio():

eg_sio2 = BytesIO(str_val)
with netcdf_file(eg_sio2) as f2:
for testargs in gen_for_simple(f2):
yield testargs
assert_simple_truths(f2)

# Test that error is raised if attempting mmap for sio
eg_sio3 = BytesIO(str_val)
yield assert_raises, ValueError, netcdf_file, eg_sio3, 'r', True
with pytest.raises(ValueError):
netcdf_file(eg_sio3, 'r', True)
# Test 64-bit offset write / read
eg_sio_64 = BytesIO()
with make_simple(eg_sio_64, 'w', version=2) as f_64:
str_val = eg_sio_64.getvalue()

eg_sio_64 = BytesIO(str_val)
with netcdf_file(eg_sio_64) as f_64:
for testargs in gen_for_simple(f_64):
yield testargs
yield assert_equal, f_64.version_byte, 2
assert_simple_truths(f_64)
assert f_64.version_byte == 2
# also when version 2 explicitly specified
eg_sio_64 = BytesIO(str_val)
with netcdf_file(eg_sio_64, version=2) as f_64:
for testargs in gen_for_simple(f_64):
yield testargs
yield assert_equal, f_64.version_byte, 2
assert_simple_truths(f_64)
assert f_64.version_byte == 2


def test_read_example_data():
Expand All @@ -134,7 +114,8 @@ def test_itemset_no_segfault_on_readonly():
time_var = f.variables['time']

# time_var.assignValue(42) should raise a RuntimeError--not seg. fault!
assert_raises(RuntimeError, time_var.assignValue, 42)
with pytest.raises(RuntimeError):
time_var.assignValue(42)


def test_write_invalid_dtype():
Expand All @@ -147,22 +128,22 @@ def test_write_invalid_dtype():
with netcdf_file(BytesIO(), 'w') as f:
f.createDimension('time', N_EG_ELS)
for dt in dtypes:
yield assert_raises, ValueError, \
f.createVariable, 'time', dt, ('time',)
with pytest.raises(ValueError):
f.createVariable('time', dt, ('time',))


def test_flush_rewind():
stream = BytesIO()
with make_simple(stream, mode='w') as f:
x = f.createDimension('x',4)
x = f.createDimension('x', 4)
v = f.createVariable('v', 'i2', ['x'])
v[:] = 1
f.flush()
len_single = len(stream.getvalue())
f.flush()
len_double = len(stream.getvalue())

assert_(len_single == len_double)
assert len_single == len_double


def test_dtype_specifiers():
Expand Down Expand Up @@ -192,8 +173,8 @@ def test_ticket_1720():

io = BytesIO(contents)
with netcdf_file(io, 'r') as f:
assert_equal(f.history, b'Created for a test')
assert f.history == b'Created for a test'
float_var = f.variables['float_var']
assert_equal(float_var.units, b'metres')
assert_equal(float_var.shape, (10,))
assert_(np.allclose(float_var[:], items))
assert float_var.units == b'metres'
assert float_var.shape == (10,)
assert np.allclose(float_var[:], items)