Skip to content

Session.virtualfile_in: Refactor the 'check_kind' parameter #3941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,9 +1755,10 @@ def virtualfile_from_stringio(
@deprecate_parameter(
"required_data", "required", "v0.16.0", remove_version="v0.20.0"
)
def virtualfile_in( # noqa: PLR0912
def virtualfile_in(
self,
check_kind=None,
kind=None,
data=None,
x=None,
y=None,
Expand Down Expand Up @@ -1847,7 +1848,9 @@ def virtualfile_in( # noqa: PLR0912
)
mincols = 3

kind = data_kind(data, required=required)
# Determine the data kind if not given.
if kind is None:
kind = data_kind(data, required=required, check_kind=check_kind)
_validate_data_input(
data=data,
x=x,
Expand All @@ -1858,16 +1861,6 @@ def virtualfile_in( # noqa: PLR0912
kind=kind,
)

if check_kind:
valid_kinds = ("file", "arg") if required is False else ("file",)
if check_kind == "raster":
valid_kinds += ("grid", "image")
elif check_kind == "vector":
valid_kinds += ("empty", "matrix", "vectors", "geojson")
if kind not in valid_kinds:
msg = f"Unrecognized data type for {check_kind}: {type(data)}."
raise GMTInvalidInput(msg)

# Decide which virtualfile_from_ function to use
_virtualfile_from = {
"arg": contextlib.nullcontext,
Expand Down
41 changes: 36 additions & 5 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
"ISO-8859-16",
]

# Type hints for the list of data kinds.
Kind = Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]


def _validate_data_input( # noqa: PLR0912
data=None, x=None, y=None, z=None, required=True, mincols=2, kind=None
Expand Down Expand Up @@ -272,11 +277,11 @@ def _check_encoding(argstr: str) -> Encoding:
return "ISOLatin1+"


def data_kind(
data: Any, required: bool = True
) -> Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]:
def data_kind( # noqa: PLR0912
data: Any,
required: bool = True,
check_kind: Kind | Sequence[Kind] | Literal["raster", "vector"] | None = None,
) -> Kind:
r"""
Check the kind of data that is provided to a module.

Expand Down Expand Up @@ -307,6 +312,14 @@ def data_kind(
required
Whether 'data' is required. Set to ``False`` when dealing with optional virtual
files.
check_kind
Used to validate the type of data that can be passed in. Valid values are:

- Any recognized data kind
- A list/tuple of recognized data kinds
- ``"raster"``: shorthand for a sequence of raster-like data kinds
- ``"vector"``: shorthand for a sequence of vector-like data kinds
- ``None``: means no validatation.

Returns
-------
Expand Down Expand Up @@ -414,6 +427,24 @@ def data_kind(
kind = "matrix"
case _: # Fall back to "vectors" if data is None and required=True.
kind = "vectors"

# Now start to check if the data kind is valid.
if check_kind is not None:
valid_kinds = ("file", "arg") if required is False else ("file",)
match check_kind:
case "raster":
valid_kinds += ("grid", "image")
case "vector":
valid_kinds += ("empty", "matrix", "vectors", "geojson")
case str():
valid_kinds = (check_kind,)
case list() | tuple():
valid_kinds = check_kind

if kind not in valid_kinds:
msg = f"Unrecognized data type: {type(data)}."
raise GMTInvalidInput(msg)

return kind # type: ignore[return-value]


Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/grdcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def grdcut(
raise GMTInvalidInput(msg)

# Determine the output data kind based on the input data kind.
match inkind := data_kind(grid):
match inkind := data_kind(grid, check_kind="raster"):
case "grid" | "image":
outkind = inkind
case "file":
Expand All @@ -128,7 +128,7 @@ def grdcut(

with Session() as lib:
with (
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
lib.virtualfile_in(data=grid, kind=inkind) as vingrd,
lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd,
):
kwargs["G"] = voutgrd
Expand Down
7 changes: 2 additions & 5 deletions pygmt/src/legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,11 @@ def legend(
if kwargs.get("F") is None:
kwargs["F"] = box

kind = data_kind(spec)
if kind not in {"empty", "file", "stringio"}:
msg = f"Unrecognized data type: {type(spec)}"
raise GMTInvalidInput(msg)
kind = data_kind(spec, check_kind=("empty", "file", "stringio"))
if kind == "file" and is_nonstr_iter(spec):
msg = "Only one legend specification file is allowed."
raise GMTInvalidInput(msg)

with Session() as lib:
with lib.virtualfile_in(data=spec, required=False) as vintbl:
with lib.virtualfile_in(data=spec, required=False, kind=kind) as vintbl:
lib.call_module(module="legend", args=build_arg_list(kwargs, infile=vintbl))
4 changes: 2 additions & 2 deletions pygmt/src/meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _preprocess_spec(spec, colnames, override_cols):
Dictionary of column names and values to override in the input data. Only makes
sense if ``spec`` is a dict or :class:`pandas.DataFrame`.
"""
kind = data_kind(spec) # Determine the kind of the input data.
kind = data_kind(spec, check_kind="vector") # Determine the kind of the input data.

# Convert pandas.DataFrame and numpy.ndarray to dict.
if isinstance(spec, pd.DataFrame):
Expand Down Expand Up @@ -359,5 +359,5 @@ def meca( # noqa: PLR0913
kwargs["A"] = _auto_offset(spec)
kwargs["S"] = f"{_convention.code}{scale}"
with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=spec) as vintbl:
with lib.virtualfile_in(data=spec) as vintbl:
lib.call_module(module="meca", args=build_arg_list(kwargs, infile=vintbl))
5 changes: 3 additions & 2 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,9 @@ def plot( # noqa: PLR0912
# parameter.
self._activate_figure()

kind = data_kind(data)
kind = data_kind(data, check_kind="vector")
if kind == "empty": # Data is given via a series of vectors.
kind = "vectors"
data = {"x": x, "y": y}
# Parameters for vector styles
if (
Expand Down Expand Up @@ -280,5 +281,5 @@ def plot( # noqa: PLR0912
kwargs["S"] = "s0.2c"

with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
with lib.virtualfile_in(data=data, kind=kind) as vintbl:
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))
5 changes: 3 additions & 2 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def plot3d( # noqa: PLR0912
# parameter.
self._activate_figure()

kind = data_kind(data)
kind = data_kind(data, check_kind="vector")
if kind == "empty": # Data is given via a series of vectors.
kind = "vectors"
data = {"x": x, "y": y, "z": z}
# Parameters for vector styles
if (
Expand Down Expand Up @@ -259,5 +260,5 @@ def plot3d( # noqa: PLR0912
kwargs["S"] = "u0.2c"

with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data, mincols=3) as vintbl:
with lib.virtualfile_in(data=data, mincols=3, kind=kind) as vintbl:
lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl))
9 changes: 6 additions & 3 deletions pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence")
def text_( # noqa: PLR0912
def text_( # noqa: PLR0912, PLR0915
self,
textfiles: PathLike | TableLike | None = None,
x=None,
Expand Down Expand Up @@ -191,7 +191,7 @@ def text_( # noqa: PLR0912
raise GMTInvalidInput(msg)

data_is_required = position is None
kind = data_kind(textfiles, required=data_is_required)
kind = data_kind(textfiles, required=data_is_required, check_kind="vector")

if position is not None and (text is None or is_nonstr_iter(text)):
msg = "'text' can't be None or array when 'position' is given."
Expand Down Expand Up @@ -225,6 +225,7 @@ def text_( # noqa: PLR0912
confdict = {}
data = None
if kind == "empty":
kind = "vectors"
data = {"x": x, "y": y}

for arg, flag, name in array_args:
Expand Down Expand Up @@ -261,7 +262,9 @@ def text_( # noqa: PLR0912

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=textfiles or data, required=data_is_required
data=textfiles or data,
required=data_is_required,
kind=kind,
) as vintbl:
lib.call_module(
module="text",
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def x2sys_cross(

file_contexts: list[contextlib.AbstractContextManager[Any]] = []
for track in tracks:
match data_kind(track):
match data_kind(track, check_kind="vector"):
case "file":
file_contexts.append(contextlib.nullcontext(track))
case "vectors":
Expand Down
Loading