@@ -334,7 +334,7 @@ def __and__(self, other: int | bool | Array, /) -> Array:
334
334
----------
335
335
self : Array
336
336
Array instance. Should have a numeric data type.
337
- other: int | float | Array
337
+ other: int | bool | Array
338
338
Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type.
339
339
340
340
Returns
@@ -354,7 +354,7 @@ def __or__(self, other: int | bool | Array, /) -> Array:
354
354
----------
355
355
self : Array
356
356
Array instance. Should have a numeric data type.
357
- other: int | float | Array
357
+ other: int | bool | Array
358
358
Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type.
359
359
360
360
Returns
@@ -374,7 +374,7 @@ def __xor__(self, other: int | bool | Array, /) -> Array:
374
374
----------
375
375
self : Array
376
376
Array instance. Should have a numeric data type.
377
- other: int | float | Array
377
+ other: int | bool | Array
378
378
Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type.
379
379
380
380
Returns
@@ -394,7 +394,7 @@ def __lshift__(self, other: int | Array, /) -> Array:
394
394
----------
395
395
self : Array
396
396
Array instance. Should have a numeric data type.
397
- other: int | float | Array
397
+ other: int | Array
398
398
Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type.
399
399
Each element must be greater than or equal to 0.
400
400
@@ -414,7 +414,7 @@ def __rshift__(self, other: int | Array, /) -> Array:
414
414
----------
415
415
self : Array
416
416
Array instance. Should have a numeric data type.
417
- other: int | float | Array
417
+ other: int | Array
418
418
Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type.
419
419
Each element must be greater than or equal to 0.
420
420
@@ -429,44 +429,121 @@ def __rshift__(self, other: int | Array, /) -> Array:
429
429
430
430
def __lt__ (self , other : int | float | Array , / ) -> Array :
431
431
"""
432
- Return self < other.
432
+ Computes the truth value of self_i < other_i for each element of an array instance with the respective
433
+ element of the array other.
434
+
435
+ Parameters
436
+ ----------
437
+ self : Array
438
+ Array instance. Should have a numeric data type.
439
+ other: int | float | Array
440
+ Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type.
441
+
442
+ Returns
443
+ -------
444
+ out : Array
445
+ An array containing the element-wise results. The returned array must have a data type of bool.
433
446
"""
434
447
return _process_c_function (self , other , backend .get ().af_lt )
435
448
436
449
def __le__ (self , other : int | float | Array , / ) -> Array :
437
450
"""
438
- Return self <= other.
451
+ Computes the truth value of self_i <= other_i for each element of an array instance with the respective
452
+ element of the array other.
453
+
454
+ Parameters
455
+ ----------
456
+ self : Array
457
+ Array instance. Should have a numeric data type.
458
+ other: int | float | Array
459
+ Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type.
460
+
461
+ Returns
462
+ -------
463
+ out : Array
464
+ An array containing the element-wise results. The returned array must have a data type of bool.
439
465
"""
440
466
return _process_c_function (self , other , backend .get ().af_le )
441
467
442
468
def __gt__ (self , other : int | float | Array , / ) -> Array :
443
469
"""
444
- Return self > other.
470
+ Computes the truth value of self_i > other_i for each element of an array instance with the respective
471
+ element of the array other.
472
+
473
+ Parameters
474
+ ----------
475
+ self : Array
476
+ Array instance. Should have a numeric data type.
477
+ other: int | float | Array
478
+ Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type.
479
+
480
+ Returns
481
+ -------
482
+ out : Array
483
+ An array containing the element-wise results. The returned array must have a data type of bool.
445
484
"""
446
485
return _process_c_function (self , other , backend .get ().af_gt )
447
486
448
487
def __ge__ (self , other : int | float | Array , / ) -> Array :
449
488
"""
450
- Return self >= other.
489
+ Computes the truth value of self_i >= other_i for each element of an array instance with the respective
490
+ element of the array other.
491
+
492
+ Parameters
493
+ ----------
494
+ self : Array
495
+ Array instance. Should have a numeric data type.
496
+ other: int | float | Array
497
+ Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type.
498
+
499
+ Returns
500
+ -------
501
+ out : Array
502
+ An array containing the element-wise results. The returned array must have a data type of bool.
451
503
"""
452
504
return _process_c_function (self , other , backend .get ().af_ge )
453
505
454
506
def __eq__ (self , other : int | float | bool | Array , / ) -> Array : # type: ignore[override] # FIXME
455
507
"""
456
- Return self == other.
508
+ Computes the truth value of self_i == other_i for each element of an array instance with the respective
509
+ element of the array other.
510
+
511
+ Parameters
512
+ ----------
513
+ self : Array
514
+ Array instance. Should have a numeric data type.
515
+ other: int | float | bool | Array
516
+ Other array. Must be compatible with self (see Broadcasting). May have any data type.
517
+
518
+ Returns
519
+ -------
520
+ out : Array
521
+ An array containing the element-wise results. The returned array must have a data type of bool.
457
522
"""
458
523
return _process_c_function (self , other , backend .get ().af_eq )
459
524
460
525
def __ne__ (self , other : int | float | bool | Array , / ) -> Array : # type: ignore[override] # FIXME
461
526
"""
462
- Return self != other.
527
+ Computes the truth value of self_i != other_i for each element of an array instance with the respective
528
+ element of the array other.
529
+
530
+ Parameters
531
+ ----------
532
+ self : Array
533
+ Array instance. Should have a numeric data type.
534
+ other: int | float | bool | Array
535
+ Other array. Must be compatible with self (see Broadcasting). May have any data type.
536
+
537
+ Returns
538
+ -------
539
+ out : Array
540
+ An array containing the element-wise results. The returned array must have a data type of bool.
463
541
"""
464
542
return _process_c_function (self , other , backend .get ().af_neq )
465
543
466
544
# Reflected Arithmetic Operators
467
545
468
546
def __radd__ (self , other : Array , / ) -> Array :
469
- # TODO discuss either we need to support complex and bool as other input type
470
547
"""
471
548
Return other + self.
472
549
"""
@@ -656,8 +733,24 @@ def __float__(self) -> float:
656
733
return NotImplemented
657
734
658
735
def __getitem__ (self , key : int | slice | tuple [int | slice ] | Array , / ) -> Array :
659
- # TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array - consider using af.span
660
- # TODO: refactor
736
+ """
737
+ Returns self[key].
738
+
739
+ Parameters
740
+ ----------
741
+ self : Array
742
+ Array instance.
743
+ key : int | slice | tuple[int | slice] | Array
744
+ Index key.
745
+
746
+ Returns
747
+ -------
748
+ out : Array
749
+ An array containing the accessed value(s). The returned array must have the same data type as self.
750
+ """
751
+ # TODO
752
+ # API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array.
753
+ # consider using af.span to replace ellipsis during refactoring
661
754
out = Array ()
662
755
ndims = self .ndim
663
756
@@ -706,6 +799,14 @@ def to_device(self, device: Any, /, *, stream: None | int | Any = None) -> Array
706
799
707
800
@property
708
801
def dtype (self ) -> Dtype :
802
+ """
803
+ Data type of the array elements.
804
+
805
+ Returns
806
+ -------
807
+ out : Dtype
808
+ Array data type.
809
+ """
709
810
out = ctypes .c_int ()
710
811
safe_call (backend .get ().af_get_type (ctypes .pointer (out ), self .arr ))
711
812
return _c_api_value_to_dtype (out .value )
@@ -724,29 +825,66 @@ def mT(self) -> Array:
724
825
def T (self ) -> Array :
725
826
"""
726
827
Transpose of the array.
828
+
829
+ Returns
830
+ -------
831
+ out : Array
832
+ Two-dimensional array whose first and last dimensions (axes) are permuted in reverse order relative to
833
+ original array. The returned array must have the same data type as the original array.
834
+
835
+ Note
836
+ ----
837
+ - The array instance must be two-dimensional. If the array instance is not two-dimensional, an error
838
+ should be raised.
727
839
"""
840
+ if self .ndim < 2 :
841
+ raise TypeError (f"Array should be at least 2-dimensional. Got { self .ndim } -dimensional array" )
842
+
843
+ # TODO add check if out.dtype == self.dtype
728
844
out = Array ()
729
- # NOTE conj support is removed because it is never used
730
845
safe_call (backend .get ().af_transpose (ctypes .pointer (out .arr ), self .arr , False ))
731
846
return out
732
847
733
848
@property
734
849
def size (self ) -> int :
850
+ """
851
+ Number of elements in an array.
852
+
853
+ Returns
854
+ -------
855
+ out : int
856
+ Number of elements in an array
857
+
858
+ Note
859
+ ----
860
+ - This must equal the product of the array's dimensions.
861
+ """
735
862
# NOTE previously - elements()
736
863
out = c_dim_t (0 )
737
864
safe_call (backend .get ().af_get_elements (ctypes .pointer (out ), self .arr ))
738
865
return out .value
739
866
740
867
@property
741
868
def ndim (self ) -> int :
869
+ """
870
+ Number of array dimensions (axes).
871
+
872
+ out : int
873
+ Number of array dimensions (axes).
874
+ """
742
875
out = ctypes .c_uint (0 )
743
876
safe_call (backend .get ().af_get_numdims (ctypes .pointer (out ), self .arr ))
744
877
return out .value
745
878
746
879
@property
747
880
def shape (self ) -> ShapeType :
748
881
"""
749
- Return the shape of the array as a tuple.
882
+ Array dimensions.
883
+
884
+ Returns
885
+ -------
886
+ out : tuple[int, ...]
887
+ Array dimensions.
750
888
"""
751
889
# TODO refactor
752
890
d0 = c_dim_t (0 )
0 commit comments