Skip to content

Commit 5c3c88f

Browse files
committed
pygmt.grdhisteq: Improve performance by storing output in virtual files
1 parent 3a507a8 commit 5c3c88f

File tree

1 file changed

+33
-41
lines changed

1 file changed

+33
-41
lines changed

pygmt/src/grdhisteq.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
grdhisteq - Perform histogram equalization for a grid.
33
"""
44

5+
from typing import Literal
6+
57
import numpy as np
68
import pandas as pd
79
from pygmt.clib import Session
@@ -135,15 +137,19 @@ def equalize_grid(grid, **kwargs):
135137
@fmt_docstring
136138
@use_alias(
137139
C="divisions",
138-
D="outfile",
139140
R="region",
140141
N="gaussian",
141142
Q="quadratic",
142143
V="verbose",
143144
h="header",
144145
)
145146
@kwargs_to_strings(R="sequence")
146-
def compute_bins(grid, output_type="pandas", **kwargs):
147+
def compute_bins(
148+
grid,
149+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
150+
outfile: str | None = None,
151+
**kwargs,
152+
) -> pd.DataFrame | np.ndarray | None:
147153
r"""
148154
Perform histogram equalization for a grid.
149155
@@ -168,16 +174,8 @@ def compute_bins(grid, output_type="pandas", **kwargs):
168174
Parameters
169175
----------
170176
{grid}
171-
outfile : str or bool or None
172-
The name of the output ASCII file to store the results of the
173-
histogram equalization in.
174-
output_type : str
175-
Determine the format the xyz data will be returned in [Default is
176-
``pandas``]:
177-
178-
- ``numpy`` - :class:`numpy.ndarray`
179-
- ``pandas``- :class:`pandas.DataFrame`
180-
- ``file`` - ASCII file (requires ``outfile``)
177+
{output_type}
178+
{outfile}
181179
divisions : int
182180
Set the number of divisions of the data range.
183181
quadratic : bool
@@ -188,13 +186,13 @@ def compute_bins(grid, output_type="pandas", **kwargs):
188186
189187
Returns
190188
-------
191-
ret : pandas.DataFrame or numpy.ndarray or None
189+
ret
192190
Return type depends on ``outfile`` and ``output_type``:
193191
194192
- None if ``outfile`` is set (output will be stored in file set by
195193
``outfile``)
196-
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if
197-
``outfile`` is not set (depends on ``output_type``)
194+
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not
195+
set (depends on ``output_type``)
198196
199197
Example
200198
-------
@@ -225,39 +223,33 @@ def compute_bins(grid, output_type="pandas", **kwargs):
225223
This method does a weighted histogram equalization for geographic
226224
grids to account for node area varying with latitude.
227225
"""
228-
outfile = kwargs.get("D")
229226
output_type = validate_output_table_type(output_type, outfile=outfile)
230227

231228
if kwargs.get("h") is not None and output_type != "file":
232229
raise GMTInvalidInput("'header' is only allowed with output_type='file'.")
233230

234-
with GMTTempFile(suffix=".txt") as tmpfile:
235-
with Session() as lib:
236-
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
237-
if outfile is None:
238-
kwargs["D"] = outfile = tmpfile.name # output to tmpfile
239-
lib.call_module(
240-
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
241-
)
231+
with Session() as lib:
232+
with (
233+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
234+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
235+
):
236+
kwargs["D"] = vouttbl # -D for output file name
237+
lib.call_module(
238+
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
239+
)
242240

243-
if outfile == tmpfile.name:
244-
# if user did not set outfile, return pd.DataFrame
245-
result = pd.read_csv(
246-
filepath_or_buffer=outfile,
247-
sep="\t",
248-
header=None,
249-
names=["start", "stop", "bin_id"],
250-
dtype={
241+
result = lib.virtualfile_to_dataset(
242+
output_type=output_type,
243+
vfile=vouttbl,
244+
column_names=["start", "stop", "bin_id"],
245+
)
246+
if output_type == "pandas":
247+
result = result.astype(
248+
{
251249
"start": np.float32,
252250
"stop": np.float32,
253251
"bin_id": np.uint32,
254-
},
252+
}
255253
)
256-
elif outfile != tmpfile.name:
257-
# return None if outfile set, output in outfile
258-
return None
259-
260-
if output_type == "numpy":
261-
return result.to_numpy()
262-
263-
return result.set_index("bin_id")
254+
return result.set_index("bin_id")
255+
return result

0 commit comments

Comments
 (0)