Skip to content

Commit 2753379

Browse files
edwinsolisfsyurkevi
authored andcommitted
Fixed indexing with Arrays
1 parent 0892940 commit 2753379

File tree

1 file changed

+66
-12
lines changed

1 file changed

+66
-12
lines changed

arrayfire/array_object.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -746,27 +746,40 @@ def __getitem__(self, key: IndexKey, /) -> Array:
746746
----------
747747
self : Array
748748
Array instance.
749-
key : int | slice | tuple[int | slice, ...] | Array
749+
key : int | slice | tuple[int | slice | Array, ...] | Array
750750
Index key.
751751
752752
Returns
753753
-------
754754
out : Array
755755
An array containing the accessed value(s). The returned array must have the same data type as self.
756756
"""
757-
# TODO
758-
# API Specification - key: Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], array].
759-
# consider using af.span to replace ellipsis during refactoring
760757
out = Array()
761758
ndims = self.ndim
762759

763-
if isinstance(key, Array) and key == afbool.c_api_value:
760+
indexing = key
761+
762+
if isinstance(key, int | float | slice): # when indexing with one dimension, treat it as indexing a flat array
764763
ndims = 1
765-
if wrapper.count_all(key.arr) == 0: # HACK was count() method before
766-
return out
767-
768-
# HACK known issue
769-
out._arr = wrapper.index_gen(self._arr, ndims, wrapper.get_indices(key)) # type: ignore[arg-type]
764+
elif isinstance(key, Array): # when indexing with one array, treat it as indexing a flat array
765+
ndims = 1
766+
if key.is_bool:
767+
indexing = wrapper.where(key.arr)
768+
else:
769+
indexing = key.arr
770+
elif isinstance(key, tuple):
771+
key_list = []
772+
for elem in key:
773+
if isinstance(elem, Array):
774+
if elem.is_bool:
775+
key_list.append(wrapper.where(elem.arr))
776+
else:
777+
key_list.append(elem.arr)
778+
else:
779+
key_list.append(elem)
780+
indexing = tuple(key_list)
781+
782+
out._arr = wrapper.index_gen(self._arr, ndims, wrapper.get_indices(indexing)) # type: ignore[arg-type]
770783
return out
771784

772785
def __index__(self) -> int:
@@ -781,6 +794,18 @@ def __len__(self) -> int:
781794
return self.shape[0] if self.shape else 0
782795

783796
def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> None:
797+
"""
798+
Assigns self[key] = value
799+
800+
Parameters
801+
----------
802+
self : Array
803+
Array instance.
804+
key : int | slice | tuple[int | slice | Array, ...] | Array
805+
Index key.
806+
value: int | float | complex | bool | Array
807+
808+
"""
784809
ndims = self.ndim
785810

786811
is_array_with_bool = isinstance(key, Array) and type(key) is afbool
@@ -803,7 +828,29 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No
803828
other_arr = value.arr
804829
del_other = False
805830

806-
indices = wrapper.get_indices(key) # type: ignore[arg-type] # FIXME
831+
indexing = key
832+
if isinstance(key, int | float | slice): # when indexing with one dimension, treat it as indexing a flat array
833+
ndims = 1
834+
elif isinstance(key, Array): # when indexing with one array, treat it as indexing a flat array
835+
ndims = 1
836+
if key.is_bool:
837+
indexing = wrapper.where(key.arr)
838+
else:
839+
indexing = key.arr
840+
elif isinstance(key, tuple):
841+
key_list = []
842+
for elem in key:
843+
if isinstance(elem, Array):
844+
if elem.is_bool:
845+
key_list.append(wrapper.where(elem.arr))
846+
else:
847+
key_list.append(elem.arr)
848+
else:
849+
key_list.append(elem)
850+
indexing = tuple(key_list)
851+
852+
indices = wrapper.get_indices(indexing)
853+
807854
out = wrapper.assign_gen(self._arr, other_arr, ndims, indices)
808855

809856
wrapper.release_array(self._arr)
@@ -1144,7 +1191,10 @@ def _get_processed_index(key: IndexKey, shape: tuple[int, ...]) -> tuple[int, ..
11441191
if isinstance(key, tuple):
11451192
return tuple(_index_to_afindex(key[i], shape[i]) for i in range(len(key)))
11461193

1147-
return (_index_to_afindex(key, shape[0]),) + shape[1:]
1194+
size = 1
1195+
for dim_size in shape:
1196+
size *= dim_size
1197+
return (_index_to_afindex(key, size),)
11481198

11491199

11501200
def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.ParallelRange | Array, axis: int) -> int:
@@ -1168,6 +1218,10 @@ def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.Parall
11681218

11691219

11701220
def _slice_to_length(key: slice, axis: int) -> int:
1221+
start = key.start
1222+
stop = key.stop
1223+
step = key.step
1224+
11711225
if key.start is None:
11721226
start = 0
11731227
elif key.start < 0:

0 commit comments

Comments
 (0)