
Make xp.searchsorted accept Python scalars for the second argument #967

@ogrisel

import array_api_strict as xp

a = xp.arange(10, device=xp.Device("device1"))
xp.searchsorted(a, 42)

raises:

Traceback (most recent call last):
  Cell In[5], line 5
    xp.searchsorted(a, 42)
  File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_flags.py:395 in wrapper
    return func(*args, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_searching_functions.py:78 in searchsorted
    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
AttributeError: 'int' object has no attribute 'dtype'

This is a bit annoying, as it requires writing xp.searchsorted(a, xp.asarray(42, device=a.device)) instead, which feels unnecessarily verbose.
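Until this is resolved, a caller-side wrapper can hide the verbosity. The sketch below assumes a namespace `xp` that follows the standard, meaning it has `asarray(..., device=...)` and `searchsorted(x1, x2, side=..., sorter=...)`, and that `x1` exposes a `.device` attribute. `searchsorted_with_scalars` is a hypothetical name, not part of any library. It only automates the `xp.asarray(42, device=a.device)` workaround and deliberately does not try to reconcile the scalar's dtype with the dtype of `x1`.

```python
from typing import Any


def searchsorted_with_scalars(xp: Any, x1: Any, x2: Any, /, *,
                              side: str = "left", sorter: Any = None) -> Any:
    """Hypothetical helper: behaves like xp.searchsorted, but also accepts a
    Python int or float as the search value ``x2``."""
    # bool is a subclass of int in Python; searchsorted expects real-valued
    # numeric search values, so booleans are deliberately not wrapped.
    if isinstance(x2, (int, float)) and not isinstance(x2, bool):
        # Same workaround as in the issue: wrap the scalar in a 0-D array
        # placed on the same device as x1. The dtype is whatever
        # xp.asarray picks by default for that Python scalar.
        x2 = xp.asarray(x2, device=x1.device)
    # Arrays (and anything else) are passed through unchanged, so the
    # underlying library still performs its own validation.
    return xp.searchsorted(x1, x2, side=side, sorter=sorter)
```

With array_api_strict, `searchsorted_with_scalars(xp, a, 42)` could then stand in for the verbose call, with both arrays living on `a.device`.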

For instance, even PyTorch accepts the following without complaint:

import array_api_compat.torch as xp

a = xp.arange(10, device="mps")
xp.searchsorted(a, 42)

However, the specification does not explicitly mention support for Python scalars, so perhaps it would need to be updated first?

https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html#searchsorted
