@@ -746,27 +746,40 @@ def __getitem__(self, key: IndexKey, /) -> Array:
746
746
----------
747
747
self : Array
748
748
Array instance.
749
- key : int | slice | tuple[int | slice, ...] | Array
749
+ key : int | slice | tuple[int | slice | Array , ...] | Array
750
750
Index key.
751
751
752
752
Returns
753
753
-------
754
754
out : Array
755
755
An array containing the accessed value(s). The returned array must have the same data type as self.
756
756
"""
757
- # TODO
758
- # API Specification - key: Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], array].
759
- # consider using af.span to replace ellipsis during refactoring
760
757
out = Array ()
761
758
ndims = self .ndim
762
759
763
- if isinstance (key , Array ) and key == afbool .c_api_value :
760
+ indexing = key
761
+
762
+ if isinstance (key , int | float | slice ): # when indexing with one dimension, treat it as indexing a flat array
764
763
ndims = 1
765
- if wrapper .count_all (key .arr ) == 0 : # HACK was count() method before
766
- return out
767
-
768
- # HACK known issue
769
- out ._arr = wrapper .index_gen (self ._arr , ndims , wrapper .get_indices (key )) # type: ignore[arg-type]
764
+ elif isinstance (key , Array ): # when indexing with one array, treat it as indexing a flat array
765
+ ndims = 1
766
+ if key .is_bool :
767
+ indexing = wrapper .where (key .arr )
768
+ else :
769
+ indexing = key .arr
770
+ elif isinstance (key , tuple ):
771
+ key_list = []
772
+ for elem in key :
773
+ if isinstance (elem , Array ):
774
+ if elem .is_bool :
775
+ key_list .append (wrapper .where (elem .arr ))
776
+ else :
777
+ key_list .append (elem .arr )
778
+ else :
779
+ key_list .append (elem )
780
+ indexing = tuple (key_list )
781
+
782
+ out ._arr = wrapper .index_gen (self ._arr , ndims , wrapper .get_indices (indexing )) # type: ignore[arg-type]
770
783
return out
771
784
772
785
def __index__ (self ) -> int :
@@ -781,6 +794,18 @@ def __len__(self) -> int:
781
794
return self .shape [0 ] if self .shape else 0
782
795
783
796
def __setitem__ (self , key : IndexKey , value : int | float | bool | Array , / ) -> None :
797
+ """
798
+ Assigns self[key] = value
799
+
800
+ Parameters
801
+ ----------
802
+ self : Array
803
+ Array instance.
804
+ key : int | slice | tuple[int | slice | Array, ...] | Array
805
+ Index key.
806
+ value: int | float | complex | bool | Array
807
+
808
+ """
784
809
ndims = self .ndim
785
810
786
811
is_array_with_bool = isinstance (key , Array ) and type (key ) is afbool
@@ -803,7 +828,29 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No
803
828
other_arr = value .arr
804
829
del_other = False
805
830
806
- indices = wrapper .get_indices (key ) # type: ignore[arg-type] # FIXME
831
+ indexing = key
832
+ if isinstance (key , int | float | slice ): # when indexing with one dimension, treat it as indexing a flat array
833
+ ndims = 1
834
+ elif isinstance (key , Array ): # when indexing with one array, treat it as indexing a flat array
835
+ ndims = 1
836
+ if key .is_bool :
837
+ indexing = wrapper .where (key .arr )
838
+ else :
839
+ indexing = key .arr
840
+ elif isinstance (key , tuple ):
841
+ key_list = []
842
+ for elem in key :
843
+ if isinstance (elem , Array ):
844
+ if elem .is_bool :
845
+ key_list .append (wrapper .where (elem .arr ))
846
+ else :
847
+ key_list .append (elem .arr )
848
+ else :
849
+ key_list .append (elem )
850
+ indexing = tuple (key_list )
851
+
852
+ indices = wrapper .get_indices (indexing )
853
+
807
854
out = wrapper .assign_gen (self ._arr , other_arr , ndims , indices )
808
855
809
856
wrapper .release_array (self ._arr )
@@ -1144,7 +1191,10 @@ def _get_processed_index(key: IndexKey, shape: tuple[int, ...]) -> tuple[int, ..
1144
1191
if isinstance (key , tuple ):
1145
1192
return tuple (_index_to_afindex (key [i ], shape [i ]) for i in range (len (key )))
1146
1193
1147
- return (_index_to_afindex (key , shape [0 ]),) + shape [1 :]
1194
+ size = 1
1195
+ for dim_size in shape :
1196
+ size *= dim_size
1197
+ return (_index_to_afindex (key , size ),)
1148
1198
1149
1199
1150
1200
def _index_to_afindex (key : int | float | complex | bool | slice | wrapper .ParallelRange | Array , axis : int ) -> int :
@@ -1168,6 +1218,10 @@ def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.Parall
1168
1218
1169
1219
1170
1220
def _slice_to_length (key : slice , axis : int ) -> int :
1221
+ start = key .start
1222
+ stop = key .stop
1223
+ step = key .step
1224
+
1171
1225
if key .start is None :
1172
1226
start = 0
1173
1227
elif key .start < 0 :
0 commit comments