Skip to content

Commit dbb8293

Browse files
committed
Merge pull request #97 from bluss/strongly-typed-axis
Strongly typed Axis argument (newtype called Axis)
2 parents ca61280 + 9bd49a8 commit dbb8293

File tree

8 files changed

+127
-90
lines changed

8 files changed

+127
-90
lines changed

benches/bench1.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use rblas::matrix::Matrix;
1212

1313
use ndarray::{
1414
OwnedArray,
15+
Axis,
1516
};
1617
use ndarray::{arr0, arr1, arr2};
1718

@@ -562,5 +563,5 @@ fn dot_f32_1024(bench: &mut test::Bencher)
562563
fn means(bench: &mut test::Bencher) {
563564
let a = OwnedArray::from_iter(0..100_000i64);
564565
let a = a.into_shape((100, 1000)).unwrap();
565-
bench.iter(|| a.mean(0));
566+
bench.iter(|| a.mean(Axis(0)));
566567
}

examples/axis.rs

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
extern crate ndarray;
2+
3+
use ndarray::{
4+
OwnedArray,
5+
Axis,
6+
};
7+
8+
fn main() {
9+
let a = OwnedArray::<f32, _>::linspace(0., 24., 25).into_shape((5, 5)).unwrap();
10+
println!("{:?}", a.subview(Axis(0), 0));
11+
println!("{:?}", a.subview(Axis(0), 1));
12+
println!("{:?}", a.subview(Axis(1), 1));
13+
println!("{:?}", a.subview(Axis(0), 1));
14+
println!("{:?}", a.subview(Axis(2), 1)); // PANIC
15+
}

src/dimension.rs

+10
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,13 @@ mod test {
715715
assert!(super::dim_stride_overlap(&dim, &strides));
716716
}
717717
}
718+
719+
/// An axis index.
720+
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
721+
pub struct Axis(pub usize);
722+
723+
impl Axis {
724+
#[inline(always)]
725+
pub fn axis(&self) -> usize { self.0 }
726+
}
727+

src/lib.rs

+55-50
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ use itertools::free::enumerate;
9292
pub use dimension::{
9393
Dimension,
9494
RemoveAxis,
95+
Axis,
9596
};
9697

9798
use dimension::stride_offset;
@@ -294,7 +295,7 @@ pub type Ixs = isize;
294295
/// Subview takes two arguments: `axis` and `index`.
295296
///
296297
/// ```
297-
/// use ndarray::{arr3, aview2};
298+
/// use ndarray::{arr3, aview2, Axis};
298299
///
299300
/// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`.
300301
///
@@ -310,8 +311,8 @@ pub type Ixs = isize;
310311
/// // Let’s take a subview along the greatest dimension (axis 0),
311312
/// // taking submatrix 0, then submatrix 1
312313
///
313-
/// let sub_0 = a.subview(0, 0);
314-
/// let sub_1 = a.subview(0, 1);
314+
/// let sub_0 = a.subview(Axis(0), 0);
315+
/// let sub_1 = a.subview(Axis(0), 1);
315316
///
316317
/// assert_eq!(sub_0, aview2(&[[ 1, 2, 3],
317318
/// [ 4, 5, 6]]));
@@ -320,7 +321,7 @@ pub type Ixs = isize;
320321
/// assert_eq!(sub_0.shape(), &[2, 3]);
321322
///
322323
/// // This is the subview picking only axis 2, column 0
323-
/// let sub_col = a.subview(2, 0);
324+
/// let sub_col = a.subview(Axis(2), 0);
324325
///
325326
/// assert_eq!(sub_col, aview2(&[[ 1, 4],
326327
/// [ 7, 10]]));
@@ -1336,7 +1337,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
13361337
/// **Panics** if `axis` or `index` is out of bounds.
13371338
///
13381339
/// ```
1339-
/// use ndarray::{arr1, arr2};
1340+
/// use ndarray::{arr1, arr2, Axis};
13401341
///
13411342
/// let a = arr2(&[[1., 2.], // -- axis 0, row 0
13421343
/// [3., 4.], // -- axis 0, row 1
@@ -1345,13 +1346,13 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
13451346
/// // \ axis 1, column 1
13461347
/// // axis 1, column 0
13471348
/// assert!(
1348-
/// a.subview(0, 1) == arr1(&[3., 4.]) &&
1349-
/// a.subview(1, 1) == arr1(&[2., 4., 6.])
1349+
/// a.subview(Axis(0), 1) == arr1(&[3., 4.]) &&
1350+
/// a.subview(Axis(1), 1) == arr1(&[2., 4., 6.])
13501351
/// );
13511352
/// ```
1352-
pub fn subview(&self, axis: usize, index: Ix)
1353+
pub fn subview(&self, axis: Axis, index: Ix)
13531354
-> ArrayView<A, <D as RemoveAxis>::Smaller>
1354-
where D: RemoveAxis
1355+
where D: RemoveAxis,
13551356
{
13561357
self.view().into_subview(axis, index)
13571358
}
@@ -1362,19 +1363,19 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
13621363
/// **Panics** if `axis` or `index` is out of bounds.
13631364
///
13641365
/// ```
1365-
/// use ndarray::{arr2, aview2};
1366+
/// use ndarray::{arr2, aview2, Axis};
13661367
///
13671368
/// let mut a = arr2(&[[1., 2.],
13681369
/// [3., 4.]]);
13691370
///
1370-
/// a.subview_mut(1, 1).iadd_scalar(&10.);
1371+
/// a.subview_mut(Axis(1), 1).iadd_scalar(&10.);
13711372
///
13721373
/// assert!(
13731374
/// a == aview2(&[[1., 12.],
13741375
/// [3., 14.]])
13751376
/// );
13761377
/// ```
1377-
pub fn subview_mut(&mut self, axis: usize, index: Ix)
1378+
pub fn subview_mut(&mut self, axis: Axis, index: Ix)
13781379
-> ArrayViewMut<A, D::Smaller>
13791380
where S: DataMut,
13801381
D: RemoveAxis,
@@ -1386,19 +1387,21 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
13861387
/// and select the subview of `index` along that axis.
13871388
///
13881389
/// **Panics** if `index` is past the length of the axis.
1389-
pub fn isubview(&mut self, axis: usize, index: Ix) {
1390-
dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, axis, index)
1390+
pub fn isubview(&mut self, axis: Axis, index: Ix) {
1391+
dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides,
1392+
axis.axis(), index)
13911393
}
13921394

13931395
/// Along `axis`, select the subview `index` and return `self`
13941396
/// with that axis removed.
13951397
///
13961398
/// See [`.subview()`](#method.subview) and [*Subviews*](#subviews) for full documentation.
1397-
pub fn into_subview(mut self, axis: usize, index: Ix)
1399+
pub fn into_subview(mut self, axis: Axis, index: Ix)
13981400
-> ArrayBase<S, <D as RemoveAxis>::Smaller>
1399-
where D: RemoveAxis
1401+
where D: RemoveAxis,
14001402
{
14011403
self.isubview(axis, index);
1404+
let axis = axis.axis();
14021405
// don't use reshape -- we always know it will fit the size,
14031406
// and we can use remove_axis on the strides as well
14041407
ArrayBase {
@@ -1450,15 +1453,16 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14501453
/// Iterator element is `ArrayView<A, D::Smaller>` (read-only array view).
14511454
///
14521455
/// ```
1453-
/// use ndarray::arr3;
1456+
/// use ndarray::{arr3, Axis};
1457+
///
14541458
/// let a = arr3(&[[[ 0, 1, 2], // \ axis 0, submatrix 0
14551459
/// [ 3, 4, 5]], // /
14561460
/// [[ 6, 7, 8], // \ axis 0, submatrix 1
14571461
/// [ 9, 10, 11]]]); // /
14581462
/// // `outer_iter` yields the two submatrices along axis 0.
14591463
/// let mut iter = a.outer_iter();
1460-
/// assert_eq!(iter.next().unwrap(), a.subview(0, 0));
1461-
/// assert_eq!(iter.next().unwrap(), a.subview(0, 1));
1464+
/// assert_eq!(iter.next().unwrap(), a.subview(Axis(0), 0));
1465+
/// assert_eq!(iter.next().unwrap(), a.subview(Axis(0), 1));
14621466
/// ```
14631467
pub fn outer_iter(&self) -> OuterIter<A, D::Smaller>
14641468
where D: RemoveAxis,
@@ -1489,10 +1493,10 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14891493
/// See [*Subviews*](#subviews) for full documentation.
14901494
///
14911495
/// **Panics** if `axis` is out of bounds.
1492-
pub fn axis_iter(&self, axis: usize) -> OuterIter<A, D::Smaller>
1493-
where D: RemoveAxis
1496+
pub fn axis_iter(&self, axis: Axis) -> OuterIter<A, D::Smaller>
1497+
where D: RemoveAxis,
14941498
{
1495-
iterators::new_axis_iter(self.view(), axis)
1499+
iterators::new_axis_iter(self.view(), axis.axis())
14961500
}
14971501

14981502

@@ -1503,11 +1507,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
15031507
/// (read-write array view).
15041508
///
15051509
/// **Panics** if `axis` is out of bounds.
1506-
pub fn axis_iter_mut(&mut self, axis: usize) -> OuterIterMut<A, D::Smaller>
1510+
pub fn axis_iter_mut(&mut self, axis: Axis) -> OuterIterMut<A, D::Smaller>
15071511
where S: DataMut,
15081512
D: RemoveAxis,
15091513
{
1510-
iterators::new_axis_iter_mut(self.view_mut(), axis)
1514+
iterators::new_axis_iter_mut(self.view_mut(), axis.axis())
15111515
}
15121516

15131517

@@ -1523,20 +1527,22 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
15231527
///
15241528
/// ```
15251529
/// use ndarray::OwnedArray;
1526-
/// use ndarray::arr3;
1530+
/// use ndarray::{arr3, Axis};
15271531
///
15281532
/// let a = OwnedArray::from_iter(0..28).into_shape((2, 7, 2)).unwrap();
1529-
/// let mut iter = a.axis_chunks_iter(1, 2);
1533+
/// let mut iter = a.axis_chunks_iter(Axis(1), 2);
15301534
///
15311535
/// // first iteration yields a 2 × 2 × 2 view
15321536
/// assert_eq!(iter.next().unwrap(),
1533-
/// arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]));
1537+
/// arr3(&[[[ 0, 1], [ 2, 3]],
1538+
/// [[14, 15], [16, 17]]]));
15341539
///
15351540
/// // however the last element is a 2 × 1 × 2 view since 7 % 2 == 1
1536-
/// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]], [[26, 27]]]));
1541+
/// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]],
1542+
/// [[26, 27]]]));
15371543
/// ```
1538-
pub fn axis_chunks_iter(&self, axis: usize, size: usize) -> AxisChunksIter<A, D> {
1539-
iterators::new_chunk_iter(self.view(), axis, size)
1544+
pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<A, D> {
1545+
iterators::new_chunk_iter(self.view(), axis.axis(), size)
15401546
}
15411547

15421548
/// Return an iterator that traverses over `axis` by chunks of `size`,
@@ -1545,11 +1551,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
15451551
/// Iterator element is `ArrayViewMut<A, D>`
15461552
///
15471553
/// **Panics** if `axis` is out of bounds.
1548-
pub fn axis_chunks_iter_mut(&mut self, axis: usize, size: usize)
1554+
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize)
15491555
-> AxisChunksIterMut<A, D>
15501556
where S: DataMut
15511557
{
1552-
iterators::new_chunk_iter_mut(self.view_mut(), axis, size)
1558+
iterators::new_chunk_iter_mut(self.view_mut(), axis.axis(), size)
15531559
}
15541560

15551561
// Return (length, stride) for diagonal
@@ -2301,24 +2307,24 @@ impl<A, S, D> ArrayBase<S, D>
23012307
/// Return sum along `axis`.
23022308
///
23032309
/// ```
2304-
/// use ndarray::{aview0, aview1, arr2};
2310+
/// use ndarray::{aview0, aview1, arr2, Axis};
23052311
///
23062312
/// let a = arr2(&[[1., 2.],
23072313
/// [3., 4.]]);
23082314
/// assert!(
2309-
/// a.sum(0) == aview1(&[4., 6.]) &&
2310-
/// a.sum(1) == aview1(&[3., 7.]) &&
2315+
/// a.sum(Axis(0)) == aview1(&[4., 6.]) &&
2316+
/// a.sum(Axis(1)) == aview1(&[3., 7.]) &&
23112317
///
2312-
/// a.sum(0).sum(0) == aview0(&10.)
2318+
/// a.sum(Axis(0)).sum(Axis(0)) == aview0(&10.)
23132319
/// );
23142320
/// ```
23152321
///
23162322
/// **Panics** if `axis` is out of bounds.
2317-
pub fn sum(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
2323+
pub fn sum(&self, axis: Axis) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
23182324
where A: Clone + Add<Output=A>,
23192325
D: RemoveAxis,
23202326
{
2321-
let n = self.shape()[axis];
2327+
let n = self.shape()[axis.axis()];
23222328
let mut res = self.subview(axis, 0).to_owned();
23232329
for i in 1..n {
23242330
let view = self.subview(axis, i);
@@ -2355,24 +2361,23 @@ impl<A, S, D> ArrayBase<S, D>
23552361

23562362
/// Return mean along `axis`.
23572363
///
2364+
/// **Panics** if `axis` is out of bounds.
2365+
///
23582366
/// ```
2359-
/// use ndarray::{aview1, arr2};
2367+
/// use ndarray::{aview1, arr2, Axis};
23602368
///
23612369
/// let a = arr2(&[[1., 2.],
23622370
/// [3., 4.]]);
23632371
/// assert!(
2364-
/// a.mean(0) == aview1(&[2.0, 3.0]) &&
2365-
/// a.mean(1) == aview1(&[1.5, 3.5])
2372+
/// a.mean(Axis(0)) == aview1(&[2.0, 3.0]) &&
2373+
/// a.mean(Axis(1)) == aview1(&[1.5, 3.5])
23662374
/// );
23672375
/// ```
2368-
///
2369-
///
2370-
/// **Panics** if `axis` is out of bounds.
2371-
pub fn mean(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
2376+
pub fn mean(&self, axis: Axis) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
23722377
where A: LinalgScalar,
23732378
D: RemoveAxis,
23742379
{
2375-
let n = self.shape()[axis];
2380+
let n = self.shape()[axis.axis()];
23762381
let mut sum = self.sum(axis);
23772382
let one = libnum::one::<A>();
23782383
let mut cnt = one;
@@ -2485,7 +2490,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24852490
/// **Panics** if `index` is out of bounds.
24862491
pub fn row(&self, index: Ix) -> ArrayView<A, Ix>
24872492
{
2488-
self.subview(0, index)
2493+
self.subview(Axis(0), index)
24892494
}
24902495

24912496
/// Return a mutable array view of row `index`.
@@ -2494,15 +2499,15 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24942499
pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut<A, Ix>
24952500
where S: DataMut
24962501
{
2497-
self.subview_mut(0, index)
2502+
self.subview_mut(Axis(0), index)
24982503
}
24992504

25002505
/// Return an array view of column `index`.
25012506
///
25022507
/// **Panics** if `index` is out of bounds.
25032508
pub fn column(&self, index: Ix) -> ArrayView<A, Ix>
25042509
{
2505-
self.subview(1, index)
2510+
self.subview(Axis(1), index)
25062511
}
25072512

25082513
/// Return a mutable array view of column `index`.
@@ -2511,7 +2516,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
25112516
pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut<A, Ix>
25122517
where S: DataMut
25132518
{
2514-
self.subview_mut(1, index)
2519+
self.subview_mut(Axis(1), index)
25152520
}
25162521

25172522
/// Perform matrix multiplication of rectangular arrays `self` and `rhs`.

0 commit comments

Comments
 (0)