diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index c50ed9792720..5c4a5551d08a 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -174,6 +174,9 @@ def __ge__(self, other): # '__getattribute__', def __getitem__(self, key): + if isinstance(key, dpnp_array): + key = key.get_array() + item = self._array_obj.__getitem__(key) if not isinstance(item, dpt.usm_ndarray): raise RuntimeError( @@ -290,6 +293,11 @@ def __rtruediv__(self, other): # '__setattr__', def __setitem__(self, key, val): + if isinstance(key, dpnp_array): + key = key.get_array() + if isinstance(val, dpnp_array): + val = val.get_array() + self._array_obj.__setitem__(key, val) # '__setstate__',