29
29
//! that the items are not compatible (e.g. that a type doesn't implement a
30
30
//! necessary trait).
31
31
32
- use crate :: rand:: distributions:: Distribution ;
32
+ use crate :: rand:: distributions:: { Distribution , Uniform } ;
33
33
use crate :: rand:: rngs:: SmallRng ;
34
+ use crate :: rand:: seq:: index;
34
35
use crate :: rand:: { thread_rng, Rng , SeedableRng } ;
35
36
36
- use ndarray:: ShapeBuilder ;
37
+ use ndarray:: { Array , Axis , RemoveAxis , ShapeBuilder } ;
37
38
use ndarray:: { ArrayBase , DataOwned , Dimension } ;
39
+ #[ cfg( feature = "quickcheck" ) ]
40
+ use quickcheck:: { Arbitrary , Gen } ;
38
41
39
42
/// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility.
40
43
pub mod rand {
@@ -59,9 +62,9 @@ pub mod rand_distr {
59
62
/// low-quality random numbers, and reproducibility is not guaranteed. See its
60
63
/// documentation for information. You can select a different RNG with
61
64
/// [`.random_using()`](#tymethod.random_using).
62
- pub trait RandomExt < S , D >
65
+ pub trait RandomExt < S , A , D >
63
66
where
64
- S : DataOwned ,
67
+ S : DataOwned < Elem = A > ,
65
68
D : Dimension ,
66
69
{
67
70
/// Create an array with shape `dim` with elements drawn from
@@ -116,21 +119,125 @@ where
116
119
IdS : Distribution < S :: Elem > ,
117
120
R : Rng + ?Sized ,
118
121
Sh : ShapeBuilder < Dim = D > ;
122
+
123
+ /// Sample `n_samples` lanes slicing along `axis` using the default RNG.
124
+ ///
125
+ /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
126
+ /// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
127
+ ///
128
+ /// ***Panics*** when:
129
+ /// - creation of the RNG fails;
130
+ /// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
131
+ /// - length of `axis` is 0.
132
+ ///
133
+ /// ```
134
+ /// use ndarray::{array, Axis};
135
+ /// use ndarray_rand::{RandomExt, SamplingStrategy};
136
+ ///
137
+ /// # fn main() {
138
+ /// let a = array![
139
+ /// [1., 2., 3.],
140
+ /// [4., 5., 6.],
141
+ /// [7., 8., 9.],
142
+ /// [10., 11., 12.],
143
+ /// ];
144
+ /// // Sample 2 rows, without replacement
145
+ /// let sample_rows = a.sample_axis(Axis(0), 2, SamplingStrategy::WithoutReplacement);
146
+ /// println!("{:?}", sample_rows);
147
+ /// // Example Output: (1st and 3rd rows)
148
+ /// // [
149
+ /// // [1., 2., 3.],
150
+ /// // [7., 8., 9.]
151
+ /// // ]
152
+ /// // Sample 2 columns, with replacement
153
+ /// let sample_columns = a.sample_axis(Axis(1), 1, SamplingStrategy::WithReplacement);
154
+ /// println!("{:?}", sample_columns);
155
+ /// // Example Output: (2nd column, sampled twice)
156
+ /// // [
157
+ /// // [2., 2.],
158
+ /// // [5., 5.],
159
+ /// // [8., 8.],
160
+ /// // [11., 11.]
161
+ /// // ]
162
+ /// # }
163
+ /// ```
164
+ fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
165
+ where
166
+ A : Copy ,
167
+ D : RemoveAxis ;
168
+
169
+ /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
170
+ ///
171
+ /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
172
+ /// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
173
+ ///
174
+ /// ***Panics*** when:
175
+ /// - creation of the RNG fails;
176
+ /// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
177
+ /// - length of `axis` is 0.
178
+ ///
179
+ /// ```
180
+ /// use ndarray::{array, Axis};
181
+ /// use ndarray_rand::{RandomExt, SamplingStrategy};
182
+ /// use ndarray_rand::rand::SeedableRng;
183
+ /// use rand_isaac::isaac64::Isaac64Rng;
184
+ ///
185
+ /// # fn main() {
186
+ /// // Get a seeded random number generator for reproducibility (Isaac64 algorithm)
187
+ /// let seed = 42;
188
+ /// let mut rng = Isaac64Rng::seed_from_u64(seed);
189
+ ///
190
+ /// let a = array![
191
+ /// [1., 2., 3.],
192
+ /// [4., 5., 6.],
193
+ /// [7., 8., 9.],
194
+ /// [10., 11., 12.],
195
+ /// ];
196
+ /// // Sample 2 rows, without replacement
197
+ /// let sample_rows = a.sample_axis_using(Axis(0), 2, SamplingStrategy::WithoutReplacement, &mut rng);
198
+ /// println!("{:?}", sample_rows);
199
+ /// // Example Output: (1st and 3rd rows)
200
+ /// // [
201
+ /// // [1., 2., 3.],
202
+ /// // [7., 8., 9.]
203
+ /// // ]
204
+ ///
205
+ /// // Sample 2 columns, with replacement
206
+ /// let sample_columns = a.sample_axis_using(Axis(1), 1, SamplingStrategy::WithReplacement, &mut rng);
207
+ /// println!("{:?}", sample_columns);
208
+ /// // Example Output: (2nd column, sampled twice)
209
+ /// // [
210
+ /// // [2., 2.],
211
+ /// // [5., 5.],
212
+ /// // [8., 8.],
213
+ /// // [11., 11.]
214
+ /// // ]
215
+ /// # }
216
+ /// ```
217
+ fn sample_axis_using < R > (
218
+ & self ,
219
+ axis : Axis ,
220
+ n_samples : usize ,
221
+ strategy : SamplingStrategy ,
222
+ rng : & mut R ,
223
+ ) -> Array < A , D >
224
+ where
225
+ R : Rng + ?Sized ,
226
+ A : Copy ,
227
+ D : RemoveAxis ;
119
228
}
120
229
121
- impl < S , D > RandomExt < S , D > for ArrayBase < S , D >
230
+ impl < S , A , D > RandomExt < S , A , D > for ArrayBase < S , D >
122
231
where
123
- S : DataOwned ,
232
+ S : DataOwned < Elem = A > ,
124
233
D : Dimension ,
125
234
{
126
235
fn random < Sh , IdS > ( shape : Sh , dist : IdS ) -> ArrayBase < S , D >
127
236
where
128
237
IdS : Distribution < S :: Elem > ,
129
238
Sh : ShapeBuilder < Dim = D > ,
130
239
{
131
- let mut rng =
132
- SmallRng :: from_rng ( thread_rng ( ) ) . expect ( "create SmallRng from thread_rng failed" ) ;
133
- Self :: random_using ( shape, dist, & mut rng)
240
+ Self :: random_using ( shape, dist, & mut get_rng ( ) )
134
241
}
135
242
136
243
fn random_using < Sh , IdS , R > ( shape : Sh , dist : IdS , rng : & mut R ) -> ArrayBase < S , D >
@@ -141,6 +248,66 @@ where
141
248
{
142
249
Self :: from_shape_fn ( shape, |_| dist. sample ( rng) )
143
250
}
251
+
252
+ fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
253
+ where
254
+ A : Copy ,
255
+ D : RemoveAxis ,
256
+ {
257
+ self . sample_axis_using ( axis, n_samples, strategy, & mut get_rng ( ) )
258
+ }
259
+
260
+ fn sample_axis_using < R > (
261
+ & self ,
262
+ axis : Axis ,
263
+ n_samples : usize ,
264
+ strategy : SamplingStrategy ,
265
+ rng : & mut R ,
266
+ ) -> Array < A , D >
267
+ where
268
+ R : Rng + ?Sized ,
269
+ A : Copy ,
270
+ D : RemoveAxis ,
271
+ {
272
+ let indices: Vec < _ > = match strategy {
273
+ SamplingStrategy :: WithReplacement => {
274
+ let distribution = Uniform :: from ( 0 ..self . len_of ( axis) ) ;
275
+ ( 0 ..n_samples) . map ( |_| distribution. sample ( rng) ) . collect ( )
276
+ }
277
+ SamplingStrategy :: WithoutReplacement => {
278
+ index:: sample ( rng, self . len_of ( axis) , n_samples) . into_vec ( )
279
+ }
280
+ } ;
281
+ self . select ( axis, & indices)
282
+ }
283
+ }
284
+
285
+ /// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine
286
+ /// if lanes from the original array should only be sampled once (*without replacement*) or
287
+ /// multiple times (*with replacement*).
288
+ ///
289
+ /// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis
290
+ /// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using
291
+ #[ derive( Debug , Clone ) ]
292
+ pub enum SamplingStrategy {
293
+ WithReplacement ,
294
+ WithoutReplacement ,
295
+ }
296
+
297
+ // `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing.
298
+ #[ cfg( feature = "quickcheck" ) ]
299
+ impl Arbitrary for SamplingStrategy {
300
+ fn arbitrary < G : Gen > ( g : & mut G ) -> Self {
301
+ if g. gen_bool ( 0.5 ) {
302
+ SamplingStrategy :: WithReplacement
303
+ } else {
304
+ SamplingStrategy :: WithoutReplacement
305
+ }
306
+ }
307
+ }
308
+
309
+ fn get_rng ( ) -> SmallRng {
310
+ SmallRng :: from_rng ( thread_rng ( ) ) . expect ( "create SmallRng from thread_rng failed" )
144
311
}
145
312
146
313
/// A wrapper type that allows casting f64 distributions to f32
0 commit comments