From 607c96a3f4e7245b61472b82a32c2b795dd05842 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 15 Feb 2017 16:52:53 +0530 Subject: [PATCH 1/4] Implement arithmetic ops/triats for Array and &Array combinations Earlier to this commit, arithmetic ops/traits were implemented for only &Array type. --- src/arith/mod.rs | 54 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 32e291dd9..22371364c 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -403,33 +403,51 @@ arith_scalar_spec!(i32); arith_scalar_spec!(u8); macro_rules! arith_func { - ($op_name:ident, $fn_name:ident, $ffi_fn: ident) => ( + ($op_name:ident, $fn_name:ident) => ( impl $op_name for Array { type Output = Array; fn $fn_name(self, rhs: Array) -> Array { - unsafe { - let mut temp: i64 = 0; - let err_val = $ffi_fn(&mut temp as MutAfArray, - self.get() as AfArray, rhs.get() as AfArray, 0); - HANDLE_ERROR(AfError::from(err_val)); - Array::from(temp) - } + add(&self, &rhs, false) + } + } + + impl<'a> $op_name<&'a Array> for Array { + type Output = Array; + + fn $fn_name(self, rhs: &'a Array) -> Array { + add(&self, rhs, false) + } + } + + impl<'a> $op_name for &'a Array { + type Output = Array; + + fn $fn_name(self, rhs: Array) -> Array { + add(self, &rhs, false) + } + } + + impl<'a, 'b> $op_name<&'a Array> for &'b Array { + type Output = Array; + + fn $fn_name(self, rhs: &'a Array) -> Array { + add(self, rhs, false) } } ) } -arith_func!(Add, add, af_add); -arith_func!(Sub, sub, af_sub); -arith_func!(Mul, mul, af_mul); -arith_func!(Div, div, af_div); -arith_func!(Rem, rem, af_rem); -arith_func!(BitAnd, bitand, af_bitand); -arith_func!(BitOr, bitor, af_bitor); -arith_func!(BitXor, bitxor, af_bitxor); -arith_func!(Shl, shl, af_bitshiftl); -arith_func!(Shr, shr, af_bitshiftr); +arith_func!(Add , add ); +arith_func!(Sub , sub ); +arith_func!(Mul , mul ); +arith_func!(Div , div ); +arith_func!(Rem , rem ); +arith_func!(BitAnd, bitand); +arith_func!(BitOr , bitor ); +arith_func!(BitXor, bitxor); +arith_func!(Shl , shl ); +arith_func!(Shr , shr ); #[cfg(op_assign)] mod op_assign { From 51702507360f86f09bd90aff73693705844567d2 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 15 Feb 2017 17:16:10 +0530 Subject: [PATCH 2/4] Delegate scalar arithmetic ops to function versions Earlier to this change, the arithmetic trait implementations used to call the C-API directly. --- src/arith/mod.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 22371364c..958ec3081 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -217,6 +217,8 @@ macro_rules! convertable_type_def { ) } +convertable_type_def!(Complex); +convertable_type_def!(Complex); convertable_type_def!(u64); convertable_type_def!(i64); convertable_type_def!(f64); @@ -355,14 +357,8 @@ macro_rules! arith_scalar_func { type Output = Array; fn $fn_name(self, rhs: $rust_type) -> Array { - let cnst_arr = constant(rhs, self.dims()); - unsafe { - let mut temp: i64 = 0; - let err_val = $ffi_fn(&mut temp as MutAfArray, self.get() as AfArray, - cnst_arr.get() as AfArray, 0); - HANDLE_ERROR(AfError::from(err_val)); - Array::from(temp) - } + let temp = rhs.clone(); + add(self, &temp, false) } } @@ -370,14 +366,8 @@ macro_rules! arith_scalar_func { type Output = Array; fn $fn_name(self, rhs: $rust_type) -> Array { - let cnst_arr = constant(rhs, self.dims()); - unsafe { - let mut temp: i64 = 0; - let err_val = $ffi_fn(&mut temp as MutAfArray, self.get() as AfArray, - cnst_arr.get() as AfArray, 0); - HANDLE_ERROR(AfError::from(err_val)); - Array::from(temp) - } + let temp = rhs.clone(); + add(&self, &temp, false) } } ) From 718990894ed94733d78858d2a40194ce2fe2ef35 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 15 Feb 2017 18:06:15 +0530 Subject: [PATCH 3/4] Add batch parameter to logical operation fns --- src/arith/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 958ec3081..113a827a8 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -180,12 +180,12 @@ macro_rules! binary_func { /// /// This is an element wise binary operation. #[allow(unused_mut)] - pub fn $fn_name(lhs: &Array, rhs: &Array) -> Array { + pub fn $fn_name(lhs: &Array, rhs: &Array, batch: bool) -> Array { unsafe { let mut temp: i64 = 0; let err_val = $ffi_fn(&mut temp as MutAfArray, lhs.get() as AfArray, rhs.get() as AfArray, - 0); + batch as c_int); HANDLE_ERROR(AfError::from(err_val)); Array::from(temp) } @@ -485,7 +485,7 @@ macro_rules! bit_assign_func { let mut idxrs = Indexer::new(); idxrs.set_index(&Seq::::default(), 0, Some(false)); idxrs.set_index(&Seq::::default(), 1, Some(false)); - let tmp = assign_gen(self as &Array, &idxrs, & $func(self as &Array, &rhs)); + let tmp = assign_gen(self as &Array, &idxrs, & $func(self as &Array, &rhs, false)); mem::replace(self, tmp); } } From 3ba93c8d3b355992eca2bc48afde94cc9c5770ed Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 15 Feb 2017 18:06:58 +0530 Subject: [PATCH 4/4] fix typos in arith ops/traits implementation --- src/arith/mod.rs | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 113a827a8..a3f194ce8 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -352,13 +352,13 @@ pub fn clamp (input: &Array, arg1: &T, arg2: &U, batch: bool) -> Array } macro_rules! arith_scalar_func { - ($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => ( + ($rust_type: ty, $op_name:ident, $fn_name: ident) => ( impl<'f> $op_name<$rust_type> for &'f Array { type Output = Array; fn $fn_name(self, rhs: $rust_type) -> Array { let temp = rhs.clone(); - add(self, &temp, false) + $fn_name(self, &temp, false) } } @@ -367,7 +367,7 @@ macro_rules! arith_scalar_func { fn $fn_name(self, rhs: $rust_type) -> Array { let temp = rhs.clone(); - add(&self, &temp, false) + $fn_name(&self, &temp, false) } } ) @@ -375,10 +375,10 @@ macro_rules! arith_scalar_func { macro_rules! arith_scalar_spec { ($ty_name:ty) => ( - arith_scalar_func!($ty_name, Add, add, af_add); - arith_scalar_func!($ty_name, Sub, sub, af_sub); - arith_scalar_func!($ty_name, Mul, mul, af_mul); - arith_scalar_func!($ty_name, Div, div, af_div); + arith_scalar_func!($ty_name, Add, add); + arith_scalar_func!($ty_name, Sub, sub); + arith_scalar_func!($ty_name, Mul, mul); + arith_scalar_func!($ty_name, Div, div); ) } @@ -393,12 +393,12 @@ arith_scalar_spec!(i32); arith_scalar_spec!(u8); macro_rules! arith_func { - ($op_name:ident, $fn_name:ident) => ( + ($op_name:ident, $fn_name:ident, $delegate:ident) => ( impl $op_name for Array { type Output = Array; fn $fn_name(self, rhs: Array) -> Array { - add(&self, &rhs, false) + $delegate(&self, &rhs, false) } } @@ -406,7 +406,7 @@ macro_rules! arith_func { type Output = Array; fn $fn_name(self, rhs: &'a Array) -> Array { - add(&self, rhs, false) + $delegate(&self, rhs, false) } } @@ -414,7 +414,7 @@ macro_rules! arith_func { type Output = Array; fn $fn_name(self, rhs: Array) -> Array { - add(self, &rhs, false) + $delegate(self, &rhs, false) } } @@ -422,22 +422,22 @@ macro_rules! arith_func { type Output = Array; fn $fn_name(self, rhs: &'a Array) -> Array { - add(self, rhs, false) + $delegate(self, rhs, false) } } ) } -arith_func!(Add , add ); -arith_func!(Sub , sub ); -arith_func!(Mul , mul ); -arith_func!(Div , div ); -arith_func!(Rem , rem ); -arith_func!(BitAnd, bitand); -arith_func!(BitOr , bitor ); -arith_func!(BitXor, bitxor); -arith_func!(Shl , shl ); -arith_func!(Shr , shr ); +arith_func!(Add , add , add ); +arith_func!(Sub , sub , sub ); +arith_func!(Mul , mul , mul ); +arith_func!(Div , div , div ); +arith_func!(Rem , rem , rem ); +arith_func!(Shl , shl , shiftl); +arith_func!(Shr , shr , shiftr); +arith_func!(BitAnd, bitand, bitand); +arith_func!(BitOr , bitor , bitor ); +arith_func!(BitXor, bitxor, bitxor); #[cfg(op_assign)] mod op_assign {