Description
import array_api_strict as xp
a = xp.arange(10, device=xp.Device("device1"))
xp.searchsorted(a, 42)
raises:
Traceback (most recent call last):
Cell In[5], line 5
xp.searchsorted(a, 42)
File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_flags.py:395 in wrapper
return func(*args, **kwargs)
File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_searching_functions.py:78 in searchsorted
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
AttributeError: 'int' object has no attribute 'dtype'
This is a bit annoying, as it requires writing the following instead: xp.searchsorted(a, xp.asarray(42, device=a.device))
which feels unnecessarily verbose.
Even PyTorch accepts the following without complaining, for instance:
import array_api_compat.torch as xp
a = xp.arange(10, device="mps")
xp.searchsorted(a, 42)
However, the spec does not explicitly mention Python scalar support, so maybe it would need to be updated first?
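In the meantime, the verbose workaround could be wrapped in a small helper. This is only a sketch: `searchsorted_with_scalar` is a hypothetical name, not part of array_api_strict or the standard, and it assumes the namespace's `asarray` accepts a `device` keyword and that arrays expose a `.device` attribute, as in the workaround above.

```python
def searchsorted_with_scalar(xp, x1, x2, **kwargs):
    """Hypothetical helper: call xp.searchsorted, accepting a Python scalar for x2.

    `xp` is any array API namespace (for example array_api_strict).
    """
    # Only wrap real-valued Python scalars; bool is excluded explicitly
    # because it is a subclass of int but not a real numeric dtype.
    if isinstance(x2, (int, float)) and not isinstance(x2, bool):
        # Same conversion as the manual workaround: place the scalar
        # on the device of the sorted array.
        x2 = xp.asarray(x2, device=x1.device)
    # Arrays (and anything else) are passed through unchanged.
    return xp.searchsorted(x1, x2, **kwargs)
```

With the setup from the first snippet, the call would then read `searchsorted_with_scalar(xp, a, 42)`. Whether the scalar should also be cast to `x1.dtype` is a separate design question that this sketch leaves open.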