Skip to content

Commit 722dae9

Browse files
weiji14Meghan Jonesseisman
authored
Allow passing None explicitly to pygmt functions Part 1 (#1857)
Implements a more robust check for None values in pygmt functions. * Let grdimage shading=None or False work Refactor grdimage to check `if "I" in kwargs` to using `if kwargs.get("I") is not None`. * Let grd2cpt's categorical, cyclic and output work with None input * Let grd2xyz's outcols work with None input Specifically when output_type="pandas" too. * Let grdgradient's tiles, normalize and outgrid work with None input * Let grdview's drapegrid work with None inputs * Let makecpt's categorical, cyclic and output work with None inputs * Let plot's style, color, intensity and transparency work with None input * Let plot3d's style, color, intensity & transparency work with None input * Let solar's T work with None input * Let transparency work with 0, None and False input * Let project's center, convention and generate work with None inputs * Let velo's spec work with None inputs Or rather, catch it properly if someone uses spec=None. * Update pygmt/src/grdgradient.py using walrus operator Co-authored-by: Meghan Jones <[email protected]> Co-authored-by: Dongdong Tian <[email protected]>
1 parent 61781e4 commit 722dae9

16 files changed

+72
-34
lines changed

pygmt/src/grd2cpt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,14 @@ def grd2cpt(grid, **kwargs):
160160
``categorical=True``.
161161
{V}
162162
"""
163-
if "W" in kwargs and "Ww" in kwargs:
163+
if kwargs.get("W") is not None and kwargs.get("Ww") is not None:
164164
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
165165
with Session() as lib:
166166
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
167167
with file_context as infile:
168-
if "H" not in kwargs: # if no output is set
168+
if kwargs.get("H") is None: # if no output is set
169169
arg_str = build_arg_string(kwargs, infile=infile)
170-
if "H" in kwargs: # if output is set
170+
else: # if output is set
171171
outfile, kwargs["H"] = kwargs["H"], True
172172
if not outfile or not isinstance(outfile, str):
173173
raise GMTInvalidInput("'output' should be a proper file name.")

pygmt/src/grd2xyz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
159159
elif outfile is None and output_type == "file":
160160
raise GMTInvalidInput("Must specify 'outfile' for ASCII output.")
161161

162-
if "o" in kwargs and output_type == "pandas":
162+
if kwargs.get("o") is not None and output_type == "pandas":
163163
raise GMTInvalidInput(
164164
"If 'outcols' is specified, 'output_type' must be either 'numpy'"
165165
"or 'file'."

pygmt/src/grdgradient.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def grdgradient(grid, **kwargs):
164164
>>> new_grid = pygmt.grdgradient(grid=grid, azimuth=10)
165165
"""
166166
with GMTTempFile(suffix=".nc") as tmpfile:
167-
if "Q" in kwargs and "N" not in kwargs:
167+
if kwargs.get("Q") is not None and kwargs.get("N") is None:
168168
raise GMTInvalidInput("""Must specify normalize if tiles is specified.""")
169169
if not args_in_kwargs(args=["A", "D", "E"], kwargs=kwargs):
170170
raise GMTInvalidInput(
@@ -174,9 +174,8 @@ def grdgradient(grid, **kwargs):
174174
with Session() as lib:
175175
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
176176
with file_context as infile:
177-
if "G" not in kwargs: # if outgrid is unset, output to tempfile
178-
kwargs.update({"G": tmpfile.name})
179-
outgrid = kwargs["G"]
177+
if (outgrid := kwargs.get("G")) is None:
178+
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
180179
lib.call_module("grdgradient", build_arg_string(kwargs, infile=infile))
181180

182181
return load_dataarray(outgrid) if outgrid == tmpfile.name else None

pygmt/src/grdimage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def grdimage(self, grid, **kwargs):
166166
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
167167
with contextlib.ExitStack() as stack:
168168
# shading using an xr.DataArray
169-
if "I" in kwargs and data_kind(kwargs["I"]) == "grid":
169+
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
170170
shading_context = lib.virtualfile_from_grid(kwargs["I"])
171171
kwargs["I"] = stack.enter_context(shading_context)
172172

pygmt/src/grdview.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def grdview(self, grid, **kwargs):
126126
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
127127

128128
with contextlib.ExitStack() as stack:
129-
if "G" in kwargs: # deal with kwargs["G"] if drapegrid is xr.DataArray
129+
if kwargs.get("G") is not None:
130+
# deal with kwargs["G"] if drapegrid is xr.DataArray
130131
drapegrid = kwargs["G"]
131132
if data_kind(drapegrid) in ("file", "grid"):
132133
if data_kind(drapegrid) == "grid":

pygmt/src/makecpt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def makecpt(**kwargs):
147147
``categorical=True``.
148148
"""
149149
with Session() as lib:
150-
if "W" in kwargs and "Ww" in kwargs:
150+
if kwargs.get("W") is not None and kwargs.get("Ww") is not None:
151151
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
152-
if "H" not in kwargs: # if no output is set
152+
if kwargs.get("H") is None: # if no output is set
153153
arg_str = build_arg_string(kwargs)
154-
elif "H" in kwargs: # if output is set
154+
else: # if output is set
155155
outfile, kwargs["H"] = kwargs.pop("H"), True
156156
if not outfile or not isinstance(outfile, str):
157157
raise GMTInvalidInput("'output' should be a proper file name.")

pygmt/src/plot.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,15 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
218218
kind = data_kind(data, x, y)
219219

220220
extra_arrays = []
221-
if "S" in kwargs and kwargs["S"][0] in "vV" and direction is not None:
221+
if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
222222
extra_arrays.extend(direction)
223223
elif (
224-
"S" not in kwargs
224+
kwargs.get("S") is None
225225
and kind == "geojson"
226226
and data.geom_type.isin(["Point", "MultiPoint"]).all()
227227
): # checking if the geometry of a geoDataFrame is Point or MultiPoint
228228
kwargs["S"] = "s0.2c"
229-
elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"):
229+
elif kwargs.get("S") is None and kind == "file" and data.endswith(".gmt"):
230230
# checking that the data is a file path to set default style
231231
try:
232232
with open(which(data), mode="r", encoding="utf8") as file:
@@ -236,7 +236,7 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
236236
kwargs["S"] = "s0.2c"
237237
except FileNotFoundError:
238238
pass
239-
if "G" in kwargs and is_nonstr_iter(kwargs["G"]):
239+
if kwargs.get("G") is not None and is_nonstr_iter(kwargs["G"]):
240240
if kind != "vectors":
241241
raise GMTInvalidInput(
242242
"Can't use arrays for color if data is matrix or file."
@@ -251,7 +251,7 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
251251
extra_arrays.append(size)
252252

253253
for flag in ["I", "t"]:
254-
if flag in kwargs and is_nonstr_iter(kwargs[flag]):
254+
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
255255
if kind != "vectors":
256256
raise GMTInvalidInput(
257257
f"Can't use arrays for {plot.aliases[flag]} if data is matrix or file."

pygmt/src/plot3d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,15 @@ def plot3d(
188188
kind = data_kind(data, x, y, z)
189189

190190
extra_arrays = []
191-
if "S" in kwargs and kwargs["S"][0] in "vV" and direction is not None:
191+
if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
192192
extra_arrays.extend(direction)
193193
elif (
194-
"S" not in kwargs
194+
kwargs.get("S") is None
195195
and kind == "geojson"
196196
and data.geom_type.isin(["Point", "MultiPoint"]).all()
197197
): # checking if the geometry of a geoDataFrame is Point or MultiPoint
198198
kwargs["S"] = "u0.2c"
199-
elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"):
199+
elif kwargs.get("S") is None and kind == "file" and data.endswith(".gmt"):
200200
# checking that the data is a file path to set default style
201201
try:
202202
with open(which(data), mode="r", encoding="utf8") as file:
@@ -206,7 +206,7 @@ def plot3d(
206206
kwargs["S"] = "u0.2c"
207207
except FileNotFoundError:
208208
pass
209-
if "G" in kwargs and is_nonstr_iter(kwargs["G"]):
209+
if kwargs.get("G") is not None and is_nonstr_iter(kwargs["G"]):
210210
if kind != "vectors":
211211
raise GMTInvalidInput(
212212
"Can't use arrays for color if data is matrix or file."
@@ -221,7 +221,7 @@ def plot3d(
221221
extra_arrays.append(size)
222222

223223
for flag in ["I", "t"]:
224-
if flag in kwargs and is_nonstr_iter(kwargs[flag]):
224+
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
225225
if kind != "vectors":
226226
raise GMTInvalidInput(
227227
f"Can't use arrays for {plot3d.aliases[flag]} if data is matrix or file."

pygmt/src/project.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,13 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
210210
by ``outfile``)
211211
"""
212212

213-
if "C" not in kwargs:
213+
if kwargs.get("C") is None:
214214
raise GMTInvalidInput("The `center` parameter must be specified.")
215-
if "G" not in kwargs and data is None:
215+
if kwargs.get("G") is None and data is None:
216216
raise GMTInvalidInput(
217217
"The `data` parameter must be specified unless `generate` is used."
218218
)
219-
if "G" in kwargs and "F" in kwargs:
219+
if kwargs.get("G") is not None and kwargs.get("F") is not None:
220220
raise GMTInvalidInput(
221221
"The `convention` parameter is not allowed with `generate`."
222222
)
@@ -225,7 +225,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
225225
if outfile is None: # Output to tmpfile if outfile is not set
226226
outfile = tmpfile.name
227227
with Session() as lib:
228-
if "G" not in kwargs:
228+
if kwargs.get("G") is None:
229229
# Choose how data will be passed into the module
230230
table_context = lib.virtualfile_from_data(
231231
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
@@ -240,7 +240,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
240240

241241
# if user did not set outfile, return pd.DataFrame
242242
if outfile == tmpfile.name:
243-
if "G" in kwargs:
243+
if kwargs.get("G") is not None:
244244
column_names = list("rsp")
245245
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
246246
else:

pygmt/src/solar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def solar(self, terminator="d", terminator_datetime=None, **kwargs):
6666
"""
6767

6868
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
69-
if "T" in kwargs:
69+
if kwargs.get("T") is not None:
7070
raise GMTInvalidInput(
7171
"Use 'terminator' and 'terminator_datetime' instead of 'T'."
7272
)

pygmt/src/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def text_(
215215
extra_arrays = []
216216
# If an array of transparency is given, GMT will read it from
217217
# the last numerical column per data record.
218-
if "t" in kwargs and is_nonstr_iter(kwargs["t"]):
218+
if kwargs.get("t") is not None and is_nonstr_iter(kwargs["t"]):
219219
extra_arrays.append(kwargs["t"])
220220
kwargs["t"] = ""
221221

pygmt/src/velo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,12 @@ def velo(self, data=None, **kwargs):
238238
"""
239239
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
240240

241-
if "S" not in kwargs or ("S" in kwargs and not isinstance(kwargs["S"], str)):
242-
raise GMTInvalidInput("Spec is a required argument and has to be a string.")
241+
if kwargs.get("S") is None or (
242+
kwargs.get("S") is not None and not isinstance(kwargs["S"], str)
243+
):
244+
raise GMTInvalidInput(
245+
"The parameter `spec` is required and has to be a string."
246+
)
243247

244248
if isinstance(data, np.ndarray) and not pd.api.types.is_numeric_dtype(data):
245249
raise GMTInvalidInput(

pygmt/tests/test_grd2xyz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_grd2xyz_format(grid):
4545
np.testing.assert_allclose(orig_val, xyz_val)
4646
xyz_array = grd2xyz(grid=grid, output_type="numpy")
4747
assert isinstance(xyz_array, np.ndarray)
48-
xyz_df = grd2xyz(grid=grid, output_type="pandas")
48+
xyz_df = grd2xyz(grid=grid, output_type="pandas", outcols=None)
4949
assert isinstance(xyz_df, pd.DataFrame)
5050
assert list(xyz_df.columns) == ["lon", "lat", "z"]
5151

pygmt/tests/test_grdgradient.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,13 @@ def test_grdgradient_no_outgrid(grid, expected_grid):
5757
"""
5858
Test the azimuth and direction parameters for grdgradient with no set
5959
outgrid.
60+
61+
This is a regression test for
62+
https://github.com/GenericMappingTools/pygmt/issues/1807.
6063
"""
61-
result = grdgradient(grid=grid, azimuth=10, region=[-53, -49, -20, -17])
64+
result = grdgradient(
65+
grid=grid, azimuth=10, region=[-53, -49, -20, -17], outgrid=None
66+
)
6267
# check information of the output grid
6368
assert isinstance(result, xr.DataArray)
6469
assert result.gmt.gtype == 1 # Geographic grid

pygmt/tests/test_grdimage.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,21 @@ def test_grdimage_file():
9292
return fig
9393

9494

95+
@pytest.mark.mpl_image_compare(filename="test_grdimage_slice.png")
96+
@pytest.mark.parametrize("shading", [None, False])
97+
def test_grdimage_default_no_shading(grid, shading):
98+
"""
99+
Plot an image with no shading.
100+
101+
This is a regression test for
102+
https://github.com/GenericMappingTools/pygmt/issues/1852
103+
"""
104+
grid_ = grid.sel(lat=slice(-30, 30))
105+
fig = Figure()
106+
fig.grdimage(grid_, cmap="earth", projection="M6i", shading=shading)
107+
return fig
108+
109+
95110
@check_figures_equal()
96111
@pytest.mark.parametrize(
97112
"shading",

pygmt/tests/test_text.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,20 @@ def test_text_varying_transparency():
337337
return fig
338338

339339

340+
@pytest.mark.mpl_image_compare(filename="test_text_input_single_filename.png")
341+
@pytest.mark.parametrize("transparency", [None, False, 0])
342+
def test_text_no_transparency(transparency):
343+
"""
344+
Add text with no transparency set.
345+
346+
This is a regression test for
347+
https://github.com/GenericMappingTools/pygmt/issues/1852.
348+
"""
349+
fig = Figure()
350+
fig.text(region=[10, 70, -5, 10], textfiles=POINTS_DATA, transparency=transparency)
351+
return fig
352+
353+
340354
@pytest.mark.mpl_image_compare
341355
def test_text_nonstr_text():
342356
"""

0 commit comments

Comments
 (0)