Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions numpy_groupies/utils_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,21 +222,25 @@ def offset_labels(group_idx, inshape, axis, order, size):
Copied from
https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
"""

newaxes = tuple(ax for ax in range(len(inshape)) if ax != axis)
group_idx = np.broadcast_to(np.expand_dims(group_idx, newaxes), inshape)
if axis not in (-1, len(inshape) - 1):
group_idx = np.moveaxis(group_idx, axis, -1)
newshape = group_idx.shape

group_idx = (group_idx +
np.arange(np.prod(newshape[:-1]), dtype=int).reshape((*newshape[:-1], -1))
* size
)
if axis not in (-1, len(inshape) - 1):
newshape = (s for idx, s in enumerate(inshape) if idx != axis) + (inshape[axis],)
return np.moveaxis(group_idx, -1, axis)
else:
newshape = inshape
group_idx = np.broadcast_to(group_idx, newshape)
group_idx: np.ndarray = (
group_idx
+ np.arange(np.prod(group_idx.shape[:-1]), dtype=int).reshape((*group_idx.shape[:-1], -1))
* size
)
return group_idx.reshape(inshape).ravel()
return group_idx


def input_validation(group_idx, a, size=None, order='C', axis=None,
ravel_group_idx=True, check_bounds=True, method="ravel", func=None):
ravel_group_idx=True, check_bounds=True, method="offset", func=None):
""" Do some fairly extensive checking of group_idx and a, trying to
give the user as much help as possible with what is wrong. Also,
convert ndim-indexing to 1d indexing.
Expand Down