Skip to content

GMTDataArrayAccessor: Enable grid operations on the current xarray.DataArray object directly #3854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
116 changes: 113 additions & 3 deletions pygmt/xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,19 @@
import xarray as xr
from pygmt.enums import GridRegistration, GridType
from pygmt.exceptions import GMTInvalidInput
from pygmt.src.grdinfo import grdinfo
from pygmt.src import (
dimfilter,
grdclip,
grdcut,
grdfill,
grdfilter,
grdgradient,
grdhisteq,
grdinfo,
grdproject,
grdsample,
grdtrack,
)


@xr.register_dataarray_accessor("gmt")
Expand All @@ -17,12 +29,17 @@ class GMTDataArrayAccessor:
GMT accessor for :class:`xarray.DataArray`.

The *gmt* accessor extends :class:`xarray.DataArray` to store GMT-specific
properties for grids, which are important for PyGMT to correctly process and plot
the grids. The *gmt* accessor contains the following properties:
properties for grids or images, which are important for PyGMT to correctly process
and plot them. The *gmt* accessor contains the following properties:

- ``registration``: Grid registration type :class:`pygmt.enums.GridRegistration`.
- ``gtype``: Grid coordinate system type :class:`pygmt.enums.GridType`.

The *gmt* accessor also provides a set of grid-operation methods that enables
applying GMT's grid processing functionalities directly to the current
:class:`xarray.DataArray` object. See the summary table below for the list of
available methods.

Examples
--------
For GMT's built-in remote datasets, these GMT-specific properties are automatically
Expand Down Expand Up @@ -67,6 +84,19 @@ class GMTDataArrayAccessor:
>>> grid.gmt.gtype
<GridType.GEOGRAPHIC: 1>

Instead of calling a grid-processing function and passing the
:class:`xarray.DataArray` object as an input, you can call the corresponding method
directly on the object. For example, the following two are equivalent:

>>> from pygmt.datasets import load_earth_relief
>>> grid = load_earth_relief(resolution="30m", region=[10, 30, 15, 25])
>>> # Create a new grid from an input grid. Set all values below 1,000 to
>>> # 0 and all values above 1,500 to 10,000.
>>> # Option 1:
>>> new_grid = pygmt.grdclip(grid=grid, below=[1000, 0], above=[1500, 10000])
>>> # Option 2:
>>> new_grid = grid.gmt.clip(below=[1000, 0], above=[1500, 10000])

Notes
-----
Due to the limitations of xarray accessors, the GMT accessors are created once per
Expand Down Expand Up @@ -178,3 +208,83 @@ def gtype(self, value: GridType | int):
)
raise GMTInvalidInput(msg)
self._gtype = GridType(value)

def dimfilter(self, **kwargs) -> xr.DataArray:
"""
Directional filtering of a grid in the space domain.

See the :func:`pygmt.dimfilter` function for available parameters.
"""
return dimfilter(grid=self._obj, **kwargs)
Comment on lines +212 to +218
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shorter way to declare new methods is to put this in the __init__ method:

import functools

def __init__(self, xarray_obj: xr.DataArray):
    ...
    self.dimfilter = functools.partial(dimfilter, grid=self._obj)
    self.dimfilter.__doc__ = dimfilter.__doc__

This would preserve the full docs too, e.g. output from help(grid.gmt.dimfilter)

Help on partial in module functools:

functools.partial(<function dimfilter at 0x7fd07...e:  [190. 981.]
    long_name:     elevation (m))
    Directional filtering of grids in the space domain.
    
    Filter a grid in the space (or time) domain by
    dividing the given filter circle into the given number of sectors,
    applying one of the selected primary convolution or non-convolution
    filters to each sector, and choosing the final outcome according to the
    selected secondary filter. It computes distances using Cartesian or
    Spherical geometries. The output grid can optionally be generated as a
    subregion of the input and/or with a new increment using ``spacing``,
    which may add an "extra space" in the input data to prevent edge
    effects for the output grid. If the filter is low-pass, then the output
    may be less frequently sampled than the input. :func:`pygmt.dimfilter`
    will not produce a smooth output as other spatial filters
    do because it returns a minimum median out of *N* medians of *N*
    sectors. The output can be rough unless the input data are noise-free.
    Thus, an additional filtering (e.g., Gaussian via :func:`pygmt.grdfilter`)
    of the DiM-filtered data is generally recommended.
    
    Full option list at :gmt-docs:`dimfilter.html`
    
    **Aliases:**
    
    .. hlist::
       :columns: 3
    
       - D = distance
       - F = filter
       - I = spacing
       - N = sectors
       - R = region
       - V = verbose

There might be a nicer way to wrap things (maybe using https://docs.python.org/3/library/functools.html#functools.update_wrapper?), but haven't played around with it too much.

Copy link
Member Author

@seisman seisman May 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried your solution above and found two issues:

  1. self.dimfilter is an attribute of the accessor, so it's not shown on the documentation (https://pygmt-dev--3854.org.readthedocs.build/en/3854/api/generated/pygmt.GMTDataArrayAccessor.html)
  2. The help docs still show grid as its first parameter, which may be more confusing?


def clip(self, **kwargs) -> xr.DataArray:
"""
Clip the range of grid values.

See the :func:`pygmt.grdclip` function for available parameters.
"""
return grdclip(grid=self._obj, **kwargs)

def cut(self, **kwargs) -> xr.DataArray:
"""
Extract subregion from a grid or image or a slice from a cube.

See the :func:`pygmt.grdcut` function for available parameters.
"""
return grdcut(grid=self._obj, **kwargs)

def equalize_hist(self, **kwargs) -> xr.DataArray:
"""
Perform histogram equalization for a grid.

See the :meth:`pygmt.grdhisteq.equalize_grid` method for available parameters.
"""
return grdhisteq.equalize_grid(grid=self._obj, **kwargs)

def fill(self, **kwargs) -> xr.DataArray:
"""
Interpolate across holes in the grid.

See the :func:`pygmt.grdfill` function for available parameters.
"""
return grdfill(grid=self._obj, **kwargs)

def filter(self, **kwargs) -> xr.DataArray:
"""
Filter a grid in the space (or time) domain.

See the :func:`pygmt.grdfilter` function for available parameters.
"""
return grdfilter(grid=self._obj, **kwargs)

def gradient(self, **kwargs) -> xr.DataArray:
"""
Compute directional gradients from a grid.

See the :func:`pygmt.grdgradient` function for available parameters.
"""
return grdgradient(grid=self._obj, **kwargs)

def project(self, **kwargs) -> xr.DataArray:
"""
Forward and inverse map transformation of grids.

See the :func:`pygmt.grdproject` function for available parameters.
"""
return grdproject(grid=self._obj, **kwargs)

def sample(self, **kwargs) -> xr.DataArray:
"""
Resample a grid onto a new lattice.

See the :func:`pygmt.grdsample` function for available parameters.
"""
return grdsample(grid=self._obj, **kwargs)

def track(self, **kwargs) -> xr.DataArray:
"""
Sample a grid at specified locations.

See the :func:`pygmt.grdtrack` function for available parameters.
"""
return grdtrack(grid=self._obj, **kwargs)