Skip to content

Commit d49014d

Browse files
authored
Replace 'signature' argument in apply_ufunc with *_core_dims (#1245)
* Replace 'signature' argument to apply_func with *_core_dims This results in a much easier to use function signature than the previous `signature` argument which required triply nested lists. e.g., new: def inner_product(a, b, dim): return apply_ufunc(_inner, a, b, input_core_dims=[[dim], [dim]]) vs. old: def inner_product(a, b, dim): return apply_ufunc(_inner, a, b, signature=[[[dim], [dim]], []]) * Remove use of lambdas to please flake8
1 parent 7761c4b commit d49014d

File tree

2 files changed

+127
-166
lines changed

2 files changed

+127
-166
lines changed

xarray/core/computation.py

Lines changed: 93 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -21,36 +21,21 @@
2121
_DEFAULT_FILL_VALUE = object()
2222

2323

24-
# see http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html
25-
DIMENSION_NAME = r'\w+'
26-
CORE_DIMENSION_LIST = '(?:' + DIMENSION_NAME + '(?:,' + DIMENSION_NAME + ')*)?'
27-
ARGUMENT = r'\(' + CORE_DIMENSION_LIST + r'\)'
28-
ARGUMENT_LIST = ARGUMENT + '(?:,' + ARGUMENT + ')*'
29-
SIGNATURE = '^' + ARGUMENT_LIST + '->' + ARGUMENT_LIST + '$'
30-
31-
32-
def safe_tuple(x):
33-
# type: Iterable -> tuple
34-
if isinstance(x, basestring):
35-
raise ValueError('cannot safely convert %r to a tuple')
36-
return tuple(x)
37-
38-
39-
class UFuncSignature(object):
24+
class _UFuncSignature(object):
4025
"""Core dimensions signature for a given function.
4126
4227
Based on the signature provided by generalized ufuncs in NumPy.
4328
4429
Attributes
4530
----------
46-
input_core_dims : list of tuples
47-
A list of tuples of core dimension names on each input variable.
48-
output_core_dims : list of tuples
49-
A list of tuples of core dimension names on each output variable.
31+
input_core_dims : tuple[tuple]
32+
Core dimension names on each input variable.
33+
output_core_dims : tuple[tuple]
34+
Core dimension names on each output variable.
5035
"""
5136
def __init__(self, input_core_dims, output_core_dims=((),)):
52-
self.input_core_dims = tuple(safe_tuple(a) for a in input_core_dims)
53-
self.output_core_dims = tuple(safe_tuple(a) for a in output_core_dims)
37+
self.input_core_dims = tuple(tuple(a) for a in input_core_dims)
38+
self.output_core_dims = tuple(tuple(a) for a in output_core_dims)
5439
self._all_input_core_dims = None
5540
self._all_output_core_dims = None
5641
self._all_core_dims = None
@@ -103,21 +88,6 @@ def from_sequence(cls, nested):
10388
'both input and output.')
10489
return cls(*nested)
10590

106-
@classmethod
107-
def from_string(cls, string):
108-
"""Create a UFuncSignature object from a NumPy gufunc signature.
109-
110-
Parameters
111-
----------
112-
string : str
113-
Signature string, e.g., (m,n),(n,p)->(m,p).
114-
"""
115-
if not re.match(SIGNATURE, string):
116-
raise ValueError('not a valid gufunc signature: {}'.format(string))
117-
return cls(*[[re.findall(DIMENSION_NAME, arg)
118-
for arg in re.findall(ARGUMENT, arg_list)]
119-
for arg_list in string.split('->')])
120-
12191
def __eq__(self, other):
12292
try:
12393
return (self.input_core_dims == other.input_core_dims and
@@ -148,9 +118,6 @@ def result_name(objects):
148118
return name
149119

150120

151-
_REPEAT_NONE = itertools.repeat(None)
152-
153-
154121
def _get_coord_variables(args):
155122
input_coords = []
156123
for arg in args:
@@ -166,7 +133,7 @@ def _get_coord_variables(args):
166133

167134
def build_output_coords(
168135
args, # type: list
169-
signature, # type: UFuncSignature
136+
signature, # type: _UFuncSignature
170137
exclude_dims=frozenset(), # type: set
171138
):
172139
# type: (...) -> List[OrderedDict[Any, Variable]]
@@ -235,6 +202,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
235202
out._copy_attrs_from(args[0])
236203
return out
237204

205+
238206
def ordered_set_union(all_keys):
239207
# type: List[Iterable] -> Iterable
240208
result_dict = OrderedDict()
@@ -567,51 +535,62 @@ def apply_array_ufunc(func, *args, **kwargs):
567535

568536

569537
def apply_ufunc(func, *args, **kwargs):
570-
"""apply_ufunc(func, *args, signature=None, join='inner',
571-
exclude_dims=frozenset(), dataset_join='inner',
572-
dataset_fill_value=_DEFAULT_FILL_VALUE, keep_attrs=False,
573-
kwargs=None, dask_array='forbidden')
574-
575-
Apply a vectorized function for unlabeled arrays to xarray objects.
576-
577-
The input arguments will be handled using xarray's standard rules for
578-
labeled computation, including alignment, broadcasting, looping over
579-
GroupBy/Dataset variables, and merging of coordinates.
538+
"""apply_ufunc(func : Callable,
539+
*args : Any,
540+
input_core_dims : Optional[Sequence[Sequence]] = None,
541+
output_core_dims : Optional[Sequence[Sequence]] = ((),),
542+
exclude_dims : Collection = frozenset(),
543+
join : str = 'inner',
544+
dataset_join : str = 'inner',
545+
dataset_fill_value : Any = _DEFAULT_FILL_VALUE,
546+
keep_attrs : bool = False,
547+
kwargs : Mapping = None,
548+
dask_array : str = 'forbidden')
549+
550+
Apply a vectorized function for unlabeled arrays on xarray objects.
551+
552+
The function will be mapped over the data variable(s) of the input
553+
arguments using xarray's standard rules for labeled computation, including
554+
alignment, broadcasting, looping over GroupBy/Dataset variables, and
555+
merging of coordinates.
580556
581557
Parameters
582558
----------
583559
func : callable
584560
Function to call like ``func(*args, **kwargs)`` on unlabeled arrays
585-
(``.data``). If multiple arguments with non-matching dimensions are
586-
supplied, this function is expected to vectorize (broadcast) over
587-
axes of positional arguments in the style of NumPy universal
588-
functions [1]_.
561+
(``.data``) that returns an array or tuple of arrays. If multiple
562+
arguments with non-matching dimensions are supplied, this function is
563+
expected to vectorize (broadcast) over axes of positional arguments in
564+
the style of NumPy universal functions [1]_. If this function returns
565+
multiple outputs, you must set ``output_core_dims`` as well.
589566
*args : Dataset, DataArray, GroupBy, Variable, numpy/dask arrays or scalars
590567
Mix of labeled and/or unlabeled arrays to which to apply the function.
591-
signature : string or triply nested sequence, optional
592-
Object indicating core dimensions that should not be broadcast on
593-
the input and outputs arguments. If omitted, inputs will be broadcast
594-
to share all dimensions in common before calling ``func`` on their
595-
values, and the output of ``func`` will be assumed to be a single array
596-
with the same shape as the inputs.
597-
598-
Two forms of signatures are accepted:
599-
(a) A signature string of the form used by NumPy's generalized
600-
universal functions [2]_, e.g., '(),(time)->()' indicating a
601-
function that accepts two arguments and returns a single argument,
602-
on which all dimensions should be broadcast except 'time' on the
603-
second argument.
604-
(a) A triply nested sequence providing lists of core dimensions for
605-
each variable, for both input and output, e.g.,
606-
``([(), ('time',)], [()])``.
607-
608-
Core dimensions are automatically moved to the last axes of any input
609-
variables, which facilitates using NumPy style generalized ufuncs (see
610-
the examples below).
611-
612-
Unlike the NumPy gufunc signature spec, the names of all dimensions
613-
provided in signatures must be the names of actual dimensions on the
614-
xarray objects.
568+
input_core_dims : Sequence[Sequence], optional
569+
List of the same length as ``args`` giving the list of core dimensions
570+
on each input argument that should not be broadcast. By default, we assume
571+
there are no core dimensions on any input arguments.
572+
573+
For example, ``input_core_dims=[[], ['time']]`` indicates that all
574+
dimensions on the first argument and all dimensions other than 'time'
575+
on the second argument should be broadcast.
576+
577+
Core dimensions are automatically moved to the last axes of input
578+
variables before applying ``func``, which facilitates using NumPy style
579+
generalized ufuncs [2]_.
580+
output_core_dims : List[tuple], optional
581+
List of the same length as the number of output arguments from
582+
``func``, giving the list of core dimensions on each output that were
583+
not broadcast on the inputs. By default, we assume that ``func``
584+
outputs exactly one array, with axes corresponding to each broadcast
585+
dimension.
586+
587+
Core dimensions are assumed to appear as the last dimensions of each
588+
output in the provided order.
589+
exclude_dims : set, optional
590+
Core dimensions on the inputs to exclude from alignment and
591+
broadcasting entirely. Any input coordinates along these dimensions
592+
will be dropped. Each excluded dimension must also appear in
593+
``input_core_dims`` for at least one argument.
615594
join : {'outer', 'inner', 'left', 'right'}, optional
616595
Method for joining the indexes of the passed objects along each
617596
dimension, and the variables of Dataset objects with mismatched
@@ -633,10 +612,6 @@ def apply_ufunc(func, *args, **kwargs):
633612
``dataset_join != 'inner'``, otherwise ignored.
634613
keep_attrs: boolean, Optional
635614
Whether to copy attributes from the first argument to the output.
636-
exclude_dims : set, optional
637-
Dimensions to exclude from alignment and broadcasting. Any inputs
638-
coordinates along these dimensions will be dropped. Each excluded
639-
dimension must be a core dimension in the function signature.
640615
kwargs: dict, optional
641616
Optional keyword arguments passed directly on to call ``func``.
642617
dask_array: 'forbidden' or 'allowed', optional
@@ -661,13 +636,13 @@ def magnitude(a, b):
661636
func = lambda x, y: np.sqrt(x ** 2 + y ** 2)
662637
return xr.apply_ufunc(func, a, b)
663638
664-
Compute the mean (``.mean``)::
639+
Compute the mean (``.mean``) over one dimension::
665640
666641
def mean(obj, dim):
667642
# note: apply always moves core dimensions to the end
668-
sig = ([(dim,)], [()])
669-
kwargs = {'axis': -1}
670-
return apply_ufunc(np.mean, obj, signature=sig, kwargs=kwargs)
643+
return apply_ufunc(np.mean, obj,
644+
input_core_dims=[[dim]],
645+
kwargs={'axis': -1})
671646
672647
Inner product over a specific dimension::
673648
@@ -676,16 +651,17 @@ def _inner(x, y):
676651
return result[..., 0, 0]
677652
678653
def inner_product(a, b, dim):
679-
sig = ([(dim,), (dim,)], [()])
680-
return apply_ufunc(_inner, a, b, signature=sig)
654+
return apply_ufunc(_inner, a, b, input_core_dims=[[dim], [dim]])
681655
682656
Stack objects along a new dimension (like ``xr.concat``)::
683657
684658
def stack(objects, dim, new_coord):
685-
sig = ([()] * len(objects), [(dim,)])
659+
# note: this version does not stack coordinates
686660
func = lambda *x: np.stack(x, axis=-1)
687-
result = apply_ufunc(func, *objects, signature=sig,
688-
join='outer', dataset_fill_value=np.nan)
661+
result = apply_ufunc(func, *objects,
662+
output_core_dims=[[dim]],
663+
join='outer',
664+
dataset_fill_value=np.nan)
689665
result[dim] = new_coord
690666
return result
691667
@@ -710,7 +686,8 @@ def stack(objects, dim, new_coord):
710686
from .dataarray import DataArray
711687
from .variable import Variable
712688

713-
signature = kwargs.pop('signature', None)
689+
input_core_dims = kwargs.pop('input_core_dims', None)
690+
output_core_dims = kwargs.pop('output_core_dims', ((),))
714691
join = kwargs.pop('join', 'inner')
715692
dataset_join = kwargs.pop('dataset_join', 'inner')
716693
keep_attrs = kwargs.pop('keep_attrs', False)
@@ -722,12 +699,10 @@ def stack(objects, dim, new_coord):
722699
raise TypeError('apply_ufunc() got unexpected keyword arguments: %s'
723700
% list(kwargs))
724701

725-
if signature is None:
726-
signature = UFuncSignature.default(len(args))
727-
elif isinstance(signature, basestring):
728-
signature = UFuncSignature.from_string(signature)
729-
elif not isinstance(signature, UFuncSignature):
730-
signature = UFuncSignature.from_sequence(signature)
702+
if input_core_dims is None:
703+
input_core_dims = ((),) * (len(args))
704+
705+
signature = _UFuncSignature(input_core_dims, output_core_dims)
731706

732707
if exclude_dims and not exclude_dims <= signature.all_core_dims:
733708
raise ValueError('each dimension in `exclude_dims` must also be a '
@@ -739,28 +714,35 @@ def stack(objects, dim, new_coord):
739714
array_ufunc = functools.partial(
740715
apply_array_ufunc, func, dask_array=dask_array)
741716

742-
variables_ufunc = functools.partial(
743-
apply_variable_ufunc, array_ufunc, signature=signature,
744-
exclude_dims=exclude_dims)
717+
variables_ufunc = functools.partial(apply_variable_ufunc, array_ufunc,
718+
signature=signature,
719+
exclude_dims=exclude_dims)
745720

746721
if any(isinstance(a, GroupBy) for a in args):
747-
this_apply = functools.partial(
748-
apply_ufunc, func, signature=signature, join=join,
749-
dask_array=dask_array, exclude_dims=exclude_dims,
750-
dataset_fill_value=dataset_fill_value,
751-
dataset_join=dataset_join,
752-
keep_attrs=keep_attrs)
722+
# kwargs has already been added into func
723+
this_apply = functools.partial(apply_ufunc, func,
724+
input_core_dims=input_core_dims,
725+
output_core_dims=output_core_dims,
726+
exclude_dims=exclude_dims,
727+
join=join,
728+
dataset_join=dataset_join,
729+
dataset_fill_value=dataset_fill_value,
730+
keep_attrs=keep_attrs,
731+
dask_array=dask_array)
753732
return apply_groupby_ufunc(this_apply, *args)
754733
elif any(is_dict_like(a) for a in args):
755-
return apply_dataset_ufunc(variables_ufunc, *args, signature=signature,
756-
join=join, exclude_dims=exclude_dims,
734+
return apply_dataset_ufunc(variables_ufunc, *args,
735+
signature=signature,
736+
join=join,
737+
exclude_dims=exclude_dims,
757738
fill_value=dataset_fill_value,
758739
dataset_join=dataset_join,
759740
keep_attrs=keep_attrs)
760741
elif any(isinstance(a, DataArray) for a in args):
761742
return apply_dataarray_ufunc(variables_ufunc, *args,
762743
signature=signature,
763-
join=join, exclude_dims=exclude_dims,
744+
join=join,
745+
exclude_dims=exclude_dims,
764746
keep_attrs=keep_attrs)
765747
elif any(isinstance(a, Variable) for a in args):
766748
return variables_ufunc(*args)

0 commit comments

Comments
 (0)