@@ -433,17 +433,6 @@ macro_rules! overloaded_binary_func {
433433 ///
434434 /// An Array with results of the binary operation.
435435 ///
436- /// In the case of comparison operations such as the following, the type of output
437- /// Array is [DType::B8](./enum.DType.html). To retrieve the results of such boolean output
438- /// to host, an array of 8-bit wide types(eg. u8, i8) should be used since ArrayFire's internal
439- /// implementation uses char for boolean.
440- ///
441- /// * [gt](./fn.gt.html)
442- /// * [lt](./fn.lt.html)
443- /// * [ge](./fn.ge.html)
444- /// * [le](./fn.le.html)
445- /// * [eq](./fn.eq.html)
446- ///
447436 ///# Note
448437 ///
449438 /// The trait `Convertable` essentially translates to a scalar native type on rust or Array.
@@ -487,55 +476,123 @@ overloaded_binary_func!("Compute remainder from two Arrays", rem, rem_helper, af
487476overloaded_binary_func ! ( "Compute left shift" , shiftl, shiftl_helper, af_bitshiftl) ;
488477overloaded_binary_func ! ( "Compute right shift" , shiftr, shiftr_helper, af_bitshiftr) ;
489478overloaded_binary_func ! (
479+ "Compute modulo of two Arrays" ,
480+ modulo,
481+ modulo_helper,
482+ af_mod
483+ ) ;
484+ overloaded_binary_func ! (
485+ "Calculate atan2 of two Arrays" ,
486+ atan2,
487+ atan2_helper,
488+ af_atan2
489+ ) ;
490+ overloaded_binary_func ! (
491+ "Create complex array from two Arrays" ,
492+ cplx2,
493+ cplx2_helper,
494+ af_cplx2
495+ ) ;
496+ overloaded_binary_func ! ( "Compute root" , root, root_helper, af_root) ;
497+ overloaded_binary_func ! ( "Computer power" , pow, pow_helper, af_pow) ;
498+
499+ macro_rules! overloaded_compare_func {
500+ ( $doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => {
501+ fn $help_name<A , B >( lhs: & Array <A >, rhs: & Array <B >, batch: bool ) -> Array <bool >
502+ where
503+ A : HasAfEnum + ImplicitPromote <B >,
504+ B : HasAfEnum + ImplicitPromote <A >,
505+ {
506+ let mut temp: i64 = 0 ;
507+ unsafe {
508+ let err_val = $ffi_name(
509+ & mut temp as MutAfArray ,
510+ lhs. get( ) as AfArray ,
511+ rhs. get( ) as AfArray ,
512+ batch as c_int,
513+ ) ;
514+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
515+ }
516+ temp. into( )
517+ }
518+
519+ #[ doc=$doc_str]
520+ ///
521+ /// This is a comparison operation.
522+ ///
523+ ///# Parameters
524+ ///
525+ /// - `arg1`is an argument that implements an internal trait `Convertable`.
526+ /// - `arg2`is an argument that implements an internal trait `Convertable`.
527+ /// - `batch` is an boolean that indicates if the current operation is an batch operation.
528+ ///
529+ /// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral
530+ /// type.
531+ ///
532+ ///# Return Values
533+ ///
534+ /// An Array with results of the comparison operation a.k.a an Array of boolean values.
535+ ///# Note
536+ ///
537+ /// The trait `Convertable` essentially translates to a scalar native type on rust or Array.
538+ pub fn $fn_name<T , U >(
539+ arg1: & T ,
540+ arg2: & U ,
541+ batch: bool ,
542+ ) -> Array <bool >
543+ where
544+ T : Convertable ,
545+ U : Convertable ,
546+ <T as Convertable >:: OutType : HasAfEnum + ImplicitPromote <<U as Convertable >:: OutType >,
547+ <U as Convertable >:: OutType : HasAfEnum + ImplicitPromote <<T as Convertable >:: OutType >,
548+ {
549+ let lhs = arg1. convert( ) ; // Convert to Array<T>
550+ let rhs = arg2. convert( ) ; // Convert to Array<T>
551+ match ( lhs. is_scalar( ) , rhs. is_scalar( ) ) {
552+ ( true , false ) => {
553+ let l = tile( & lhs, rhs. dims( ) ) ;
554+ $help_name( & l, & rhs, batch)
555+ }
556+ ( false , true ) => {
557+ let r = tile( & rhs, lhs. dims( ) ) ;
558+ $help_name( & lhs, & r, batch)
559+ }
560+ _ => $help_name( & lhs, & rhs, batch) ,
561+ }
562+ }
563+ } ;
564+ }
565+
566+ overloaded_compare_func ! (
490567 "Perform `less than` comparison operation" ,
491568 lt,
492569 lt_helper,
493570 af_lt
494571) ;
495- overloaded_binary_func ! (
572+ overloaded_compare_func ! (
496573 "Perform `greater than` comparison operation" ,
497574 gt,
498575 gt_helper,
499576 af_gt
500577) ;
501- overloaded_binary_func ! (
578+ overloaded_compare_func ! (
502579 "Perform `less than equals` comparison operation" ,
503580 le,
504581 le_helper,
505582 af_le
506583) ;
507- overloaded_binary_func ! (
584+ overloaded_compare_func ! (
508585 "Perform `greater than equals` comparison operation" ,
509586 ge,
510587 ge_helper,
511588 af_ge
512589) ;
513- overloaded_binary_func ! (
590+ overloaded_compare_func ! (
514591 "Perform `equals` comparison operation" ,
515592 eq,
516593 eq_helper,
517594 af_eq
518595) ;
519- overloaded_binary_func ! (
520- "Compute modulo of two Arrays" ,
521- modulo,
522- modulo_helper,
523- af_mod
524- ) ;
525- overloaded_binary_func ! (
526- "Calculate atan2 of two Arrays" ,
527- atan2,
528- atan2_helper,
529- af_atan2
530- ) ;
531- overloaded_binary_func ! (
532- "Create complex array from two Arrays" ,
533- cplx2,
534- cplx2_helper,
535- af_cplx2
536- ) ;
537- overloaded_binary_func ! ( "Compute root" , root, root_helper, af_root) ;
538- overloaded_binary_func ! ( "Computer power" , pow, pow_helper, af_pow) ;
539596
540597fn clamp_helper < X , Y > (
541598 inp : & Array < X > ,
0 commit comments