Skip to content

Commit 1b74259

Browse files
Meghan Jonesweiji14michaelgrundseisman
authored
Refactor pygmt.surface tests (#1568)
* Refactor surface tests to use preprocessed data and xarray testing Co-authored-by: Wei Ji <[email protected]> Co-authored-by: Michael Grund <[email protected]> Co-authored-by: Dongdong Tian <[email protected]>
1 parent f228cdf commit 1b74259

File tree

2 files changed

+113
-63
lines changed

2 files changed

+113
-63
lines changed

pygmt/helpers/testing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def download_test_data():
172172
"@EGM96_to_36.txt",
173173
"@MaunaLoa_CO2.txt",
174174
"@Table_5_11.txt",
175+
"@Table_5_11_mean.xyz",
175176
"@fractures_06.txt",
176177
"@hotspots.txt",
177178
"@ridge.txt",

pygmt/tests/test_surface.py

Lines changed: 112 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,118 +3,167 @@
33
"""
44
import os
55

6+
import pandas as pd
67
import pytest
78
import xarray as xr
89
from pygmt import surface, which
9-
from pygmt.datasets import load_sample_bathymetry
1010
from pygmt.exceptions import GMTInvalidInput
11-
from pygmt.helpers import data_kind
11+
from pygmt.helpers import GMTTempFile, data_kind
1212

13-
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
14-
TEMP_GRID = os.path.join(TEST_DATA_DIR, "tmp_grid.nc")
1513

14+
@pytest.fixture(scope="module", name="data")
15+
def fixture_data():
16+
"""
17+
Load Table 5.11 in Davis: Statistics and Data Analysis in Geology.
18+
"""
19+
fname = which("@Table_5_11_mean.xyz", download="c")
20+
return pd.read_csv(
21+
fname, sep=r"\s+", header=None, names=["x", "y", "z"], skiprows=1
22+
)
23+
24+
25+
@pytest.fixture(scope="module", name="region")
26+
def fixture_region():
27+
"""
28+
Define the region.
29+
"""
30+
return [0, 4, 0, 8]
1631

17-
@pytest.fixture(scope="module", name="ship_data")
18-
def fixture_ship_data():
32+
33+
@pytest.fixture(scope="module", name="spacing")
34+
def fixture_spacing():
1935
"""
20-
Load the data from the sample bathymetry dataset.
36+
Define the spacing.
2137
"""
22-
return load_sample_bathymetry()
38+
return "1"
2339

2440

25-
def test_surface_input_file():
41+
@pytest.fixture(scope="module", name="expected_grid")
42+
def fixture_grid_result():
43+
"""
44+
Load the expected grdcut grid result.
45+
"""
46+
return xr.DataArray(
47+
data=[
48+
[962.2361, 909.526, 872.2578, 876.5983, 950.573],
49+
[944.369, 905.8253, 872.1614, 901.8583, 943.6854],
50+
[911.0599, 865.305, 845.5058, 855.7317, 867.1638],
51+
[878.5973, 851.71, 814.6884, 812.1127, 819.9591],
52+
[842.0522, 815.2896, 788.2292, 777.0031, 785.6345],
53+
[854.2515, 813.3035, 781, 742.3641, 735.6497],
54+
[882.972, 818.4636, 773.611, 718.7798, 685.4824],
55+
[897.4532, 822.9642, 756.4472, 687.594, 626.2299],
56+
[910.2932, 823.3307, 737.9952, 651.4994, 565.9981],
57+
],
58+
coords=dict(
59+
y=[0, 1, 2, 3, 4, 5, 6, 7, 8],
60+
x=[0, 1, 2, 3, 4],
61+
),
62+
dims=[
63+
"y",
64+
"x",
65+
],
66+
)
67+
68+
69+
def check_values(grid, expected_grid):
70+
"""
71+
Check the attributes and values of the DataArray returned by surface.
72+
"""
73+
assert isinstance(grid, xr.DataArray)
74+
assert grid.gmt.registration == 0 # Gridline registration
75+
assert grid.gmt.gtype == 0 # Cartesian type
76+
xr.testing.assert_allclose(a=grid, b=expected_grid)
77+
78+
79+
def test_surface_input_file(region, spacing, expected_grid):
2680
"""
2781
Run surface by passing in a filename.
2882
"""
29-
fname = which("@tut_ship.xyz", download="c")
30-
output = surface(data=fname, spacing="5m", region=[245, 255, 20, 30])
31-
assert isinstance(output, xr.DataArray)
32-
assert output.gmt.registration == 0 # Gridline registration
33-
assert output.gmt.gtype == 0 # Cartesian type
83+
output = surface(
84+
data="@Table_5_11_mean.xyz",
85+
spacing=spacing,
86+
region=region,
87+
verbose="e", # Suppress warnings for IEEE 754 rounding
88+
)
89+
check_values(output, expected_grid)
3490

3591

36-
def test_surface_input_data_array(ship_data):
92+
def test_surface_input_data_array(data, region, spacing, expected_grid):
3793
"""
3894
Run surface by passing in a numpy array into data.
3995
"""
40-
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
41-
output = surface(data=data, spacing="5m", region=[245, 255, 20, 30])
42-
assert isinstance(output, xr.DataArray)
96+
data = data.values # convert pandas.DataFrame to numpy.ndarray
97+
output = surface(
98+
data=data,
99+
spacing=spacing,
100+
region=region,
101+
verbose="e", # Suppress warnings for IEEE 754 rounding
102+
)
103+
check_values(output, expected_grid)
43104

44105

45-
def test_surface_input_xyz(ship_data):
106+
def test_surface_input_xyz(data, region, spacing, expected_grid):
46107
"""
47108
Run surface by passing in x, y, z numpy.ndarrays individually.
48109
"""
49110
output = surface(
50-
x=ship_data.longitude,
51-
y=ship_data.latitude,
52-
z=ship_data.bathymetry,
53-
spacing="5m",
54-
region=[245, 255, 20, 30],
111+
x=data.x,
112+
y=data.y,
113+
z=data.z,
114+
spacing=spacing,
115+
region=region,
116+
verbose="e", # Suppress warnings for IEEE 754 rounding
55117
)
56-
assert isinstance(output, xr.DataArray)
118+
check_values(output, expected_grid)
57119

58120

59-
def test_surface_wrong_kind_of_input(ship_data):
121+
def test_surface_wrong_kind_of_input(data, region, spacing):
60122
"""
61123
Run surface using grid input that is not file/matrix/vectors.
62124
"""
63-
data = ship_data.bathymetry.to_xarray() # convert pandas.Series to xarray.DataArray
125+
data = data.z.to_xarray() # convert pandas.Series to xarray.DataArray
64126
assert data_kind(data) == "grid"
65127
with pytest.raises(GMTInvalidInput):
66-
surface(data=data, spacing="5m", region=[245, 255, 20, 30])
128+
surface(data=data, spacing=spacing, region=region)
67129

68130

69-
def test_surface_with_outgrid_param(ship_data):
131+
def test_surface_with_outgrid_param(data, region, spacing, expected_grid):
70132
"""
71133
Run surface with the -Goutputfile.nc parameter.
72134
"""
73-
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
74-
try:
135+
data = data.values # convert pandas.DataFrame to numpy.ndarray
136+
with GMTTempFile(suffix=".nc") as tmpfile:
75137
output = surface(
76-
data=data, spacing="5m", region=[245, 255, 20, 30], outgrid=TEMP_GRID
138+
data=data,
139+
spacing=spacing,
140+
region=region,
141+
outgrid=tmpfile.name,
142+
verbose="e", # Suppress warnings for IEEE 754 rounding
77143
)
78144
assert output is None # check that output is None since outgrid is set
79-
assert os.path.exists(path=TEMP_GRID) # check that outgrid exists at path
80-
with xr.open_dataarray(TEMP_GRID) as grid:
81-
assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok
82-
finally:
83-
os.remove(path=TEMP_GRID)
145+
assert os.path.exists(path=tmpfile.name) # check that outgrid exists at path
146+
with xr.open_dataarray(tmpfile.name) as grid:
147+
check_values(grid, expected_grid)
84148

85149

86-
def test_surface_deprecate_outfile_to_outgrid(ship_data):
150+
def test_surface_deprecate_outfile_to_outgrid(data, region, spacing, expected_grid):
87151
"""
88152
Make sure that the old parameter "outfile" is supported and it reports a
89153
warning.
90154
"""
91155
with pytest.warns(expected_warning=FutureWarning) as record:
92-
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
93-
try:
156+
data = data.values # convert pandas.DataFrame to numpy.ndarray
157+
with GMTTempFile(suffix=".nc") as tmpfile:
94158
output = surface(
95-
data=data, spacing="5m", region=[245, 255, 20, 30], outfile=TEMP_GRID
159+
data=data,
160+
spacing=spacing,
161+
region=region,
162+
outfile=tmpfile.name,
163+
verbose="e", # Suppress warnings for IEEE 754 rounding
96164
)
97165
assert output is None # check that output is None since outfile is set
98-
assert os.path.exists(path=TEMP_GRID) # check that file exists at path
99-
100-
with xr.open_dataarray(TEMP_GRID) as grid:
101-
assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok
102-
finally:
103-
os.remove(path=TEMP_GRID)
166+
assert os.path.exists(path=tmpfile.name) # check that file exists at path
167+
with xr.open_dataarray(tmpfile.name) as grid:
168+
check_values(grid, expected_grid)
104169
assert len(record) == 1 # check that only one warning was raised
105-
106-
107-
def test_surface_short_aliases(ship_data):
108-
"""
109-
Run surface using short aliases -I for spacing, -R for region, -G for
110-
outgrid.
111-
"""
112-
data = ship_data.values # convert pandas.DataFrame to numpy.ndarray
113-
try:
114-
output = surface(data=data, I="5m", R=[245, 255, 20, 30], G=TEMP_GRID)
115-
assert output is None # check that output is None since outgrid is set
116-
assert os.path.exists(path=TEMP_GRID) # check that outgrid exists at path
117-
with xr.open_dataarray(TEMP_GRID) as grid:
118-
assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok
119-
finally:
120-
os.remove(path=TEMP_GRID)

0 commit comments

Comments
 (0)