Skip to content

Commit 2da5bc7

Browse files
committed
Merge pull request #90 from bluss/mat-mul-trait
Add a scalar trait for specializable scalars
2 parents 95e24b9 + 2342ede commit 2da5bc7

File tree

3 files changed

+35
-25
lines changed

3 files changed

+35
-25
lines changed

examples/linalg.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::ops::{Add, Sub, Mul, Div};
1313

1414
use ndarray::{RcArray, Ix};
1515
use ndarray::{rcarr1, rcarr2};
16+
use ndarray::LinalgScalar;
1617

1718
/// Column vector.
1819
pub type Col<A> = RcArray<A, Ix>;
@@ -29,7 +30,7 @@ pub trait Field : Ring + Div<Output=Self> { }
2930
impl<A: Ring + Div<Output=A>> Field for A { }
3031

3132
/// A real or complex number.
32-
pub trait ComplexField : Copy + Field
33+
pub trait ComplexField : LinalgScalar
3334
{
3435
#[inline]
3536
fn conjugate(self) -> Self { self }
@@ -50,7 +51,7 @@ impl ComplexField for f64
5051
fn sqrt_real(self) -> f64 { self.sqrt() }
5152
}
5253

53-
impl<A: Num + Float> ComplexField for Complex<A>
54+
impl<A: LinalgScalar + Float + Num> ComplexField for Complex<A>
5455
{
5556
#[inline]
5657
fn conjugate(self) -> Complex<A> { self.conj() }

src/lib.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,8 @@ pub use iterators::{
9797
AxisChunksIterMut,
9898
};
9999

100-
#[allow(deprecated)]
101-
use linalg::{Field, Ring};
100+
pub use linalg::LinalgScalar;
102101

103-
pub mod linalg;
104102
mod arraytraits;
105103
#[cfg(feature = "serde")]
106104
mod arrayserialize;
@@ -110,6 +108,7 @@ pub mod blas;
110108
mod dimension;
111109
mod indexes;
112110
mod iterators;
111+
mod linalg;
113112
mod linspace;
114113
mod numeric_util;
115114
mod si;
@@ -2252,7 +2251,7 @@ impl<A, S, D> ArrayBase<S, D>
22522251
/// **Panics** if `axis` is out of bounds.
22532252
#[allow(deprecated)]
22542253
pub fn mean(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
2255-
where A: Copy + Field,
2254+
where A: LinalgScalar,
22562255
D: RemoveAxis,
22572256
{
22582257
let n = self.shape()[axis];
@@ -2289,7 +2288,7 @@ impl<A, S> ArrayBase<S, Ix>
22892288
/// **Panics** if the arrays are not of the same length.
22902289
pub fn dot<S2>(&self, rhs: &ArrayBase<S2, Ix>) -> A
22912290
where S2: Data<Elem=A>,
2292-
A: Clone + Add<Output=A> + Mul<Output=A> + libnum::Zero,
2291+
A: LinalgScalar,
22932292
{
22942293
assert_eq!(self.len(), rhs.len());
22952294
if let Some(self_s) = self.as_slice() {
@@ -2369,7 +2368,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
23692368
///
23702369
#[allow(deprecated)]
23712370
pub fn mat_mul(&self, rhs: &ArrayBase<S, (Ix, Ix)>) -> OwnedArray<A, (Ix, Ix)>
2372-
where A: Copy + Ring
2371+
where A: LinalgScalar,
23732372
{
23742373
// NOTE: Matrix multiplication only defined for Copy types to
23752374
// avoid trouble with panicking + and *, and destructors
@@ -2414,7 +2413,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24142413
/// **Panics** if shapes are incompatible.
24152414
#[allow(deprecated)]
24162415
pub fn mat_mul_col(&self, rhs: &ArrayBase<S, Ix>) -> OwnedArray<A, Ix>
2417-
where A: Copy + Ring
2416+
where A: LinalgScalar,
24182417
{
24192418
let ((m, a), n) = (self.dim, rhs.dim);
24202419
let (self_columns, other_rows) = (a, n);
@@ -2609,7 +2608,7 @@ impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
26092608
fn $mth (self, rhs: &'a ArrayBase<S2, E>) -> OwnedArray<A, D>
26102609
{
26112610
// FIXME: Can we co-broadcast arrays here? And how?
2612-
self.to_owned().$mth(rhs.view())
2611+
self.to_owned().$mth(rhs)
26132612
}
26142613
}
26152614

src/linalg.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
1-
#![allow(non_snake_case, deprecated)]
2-
#![cfg_attr(has_deprecated, deprecated(note="`linalg` is not in good shape."))]
3-
4-
//! ***Deprecated: linalg is not in good shape.***
5-
//!
6-
//! A few linear algebra operations on two-dimensional arrays.
7-
8-
use libnum::{Zero, One};
1+
use libnum::{Zero, One, Float};
92
use std::ops::{Add, Sub, Mul, Div};
3+
use std::any::Any;
104

11-
/// Trait union for a ring with 1.
12-
pub trait Ring : Clone + Zero + Add<Output=Self> + Sub<Output=Self>
13-
+ One + Mul<Output=Self> { }
14-
impl<A: Clone + Zero + Add<Output=A> + Sub<Output=A> + One + Mul<Output=A>> Ring for A { }
5+
/// Trait union for scalars (array elements) that support linear algebra operations.
6+
///
7+
/// `Any` for type-based specialization, `Copy` so that they don't need move
8+
/// semantics or destructors, and the rest are numerical traits.
9+
pub trait LinalgScalar :
10+
Any +
11+
Copy +
12+
Zero + One +
13+
Add<Output=Self> +
14+
Sub<Output=Self> +
15+
Mul<Output=Self> +
16+
Div<Output=Self>
17+
{ }
1518

16-
/// Trait union for a field.
17-
pub trait Field : Ring + Div<Output=Self> { }
18-
impl<A: Ring + Div<Output = A>> Field for A {}
19+
impl<T> LinalgScalar for T
20+
where T:
21+
Any +
22+
Copy +
23+
Zero + One +
24+
Add<Output=T> +
25+
Sub<Output=T> +
26+
Mul<Output=T> +
27+
Div<Output=T>
28+
{ }

0 commit comments

Comments
 (0)