Skip to content

Commit b51268e

Browse files
committed
Add PertBuilder; allow specification via mean or mode
1 parent 3888d88 commit b51268e

File tree

2 files changed

+66
-18
lines changed

2 files changed

+66
-18
lines changed

rand_distr/src/pert.rs

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use rand::Rng;
2424
/// ```rust
2525
/// use rand_distr::{Pert, Distribution};
2626
///
27-
/// let d = Pert::new(0., 5., 2.5).unwrap();
27+
/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap();
2828
/// let v = d.sample(&mut rand::thread_rng());
2929
/// println!("{} is from a PERT distribution", v);
3030
/// ```
@@ -75,27 +75,71 @@ where
7575
Exp1: Distribution<F>,
7676
Open01: Distribution<F>,
7777
{
78-
/// Set up the PERT distribution with defined `min`, `max` and `mode`.
78+
/// Construct a PERT distribution with defined `min`, `max`
7979
///
80-
/// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
80+
/// # Example
81+
///
82+
/// ```
83+
/// use rand_distr::Pert;
84+
/// let pert_dist = Pert::new(0.0, 10.0)
85+
/// .with_shape(3.5)
86+
/// .with_mean(3.0)
87+
/// .unwrap();
88+
/// # let _unused: Pert<f64> = pert_dist;
89+
/// ```
8190
#[inline]
82-
pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
83-
Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
91+
pub fn new(min: F, max: F) -> PertBuilder<F> {
92+
let shape = F::from(4.0).unwrap();
93+
PertBuilder { min, max, shape }
8494
}
95+
}
96+
97+
/// Struct used to build a [`Pert`]
98+
#[derive(Debug)]
99+
pub struct PertBuilder<F> {
100+
min: F,
101+
max: F,
102+
shape: F,
103+
}
85104

86-
/// Set up the PERT distribution with defined `min`, `max`, `mode` and
87-
/// `shape`.
88-
pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
89-
if !(max > min) {
105+
impl<F> PertBuilder<F>
106+
where
107+
F: Float,
108+
StandardNormal: Distribution<F>,
109+
Exp1: Distribution<F>,
110+
Open01: Distribution<F>,
111+
{
112+
/// Set the shape parameter
113+
///
114+
/// If not specified, this defaults to 4.
115+
#[inline]
116+
pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
117+
self.shape = shape;
118+
self
119+
}
120+
121+
/// Specify the mean
122+
#[inline]
123+
pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
124+
let two = F::from(2.0).unwrap();
125+
let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
126+
self.with_mode(mode)
127+
}
128+
129+
/// Specify the mode
130+
#[inline]
131+
pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
132+
if !(self.max > self.min) {
90133
return Err(PertError::RangeTooSmall);
91134
}
92-
if !(mode >= min && max >= mode) {
135+
if !(mode >= self.min && self.max >= mode) {
93136
return Err(PertError::ModeRange);
94137
}
95-
if !(shape >= F::from(0.).unwrap()) {
138+
if !(self.shape >= F::from(0.).unwrap()) {
96139
return Err(PertError::ShapeTooSmall);
97140
}
98141

142+
let (min, max, shape) = (self.min, self.max, self.shape);
99143
let range = max - min;
100144
let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
101145
let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
@@ -124,34 +168,38 @@ mod test {
124168
#[test]
125169
fn test_pert() {
126170
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
127-
let _distr = Pert::new(min, max, mode).unwrap();
171+
let _distr = Pert::new(min, max).with_mode(mode).unwrap();
128172
// TODO: test correctness
129173
}
130174

131175
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
132-
assert!(Pert::new(min, max, mode).is_err());
176+
assert!(Pert::new(min, max).with_mode(mode).is_err());
133177
}
134178
}
135179

136180
#[test]
137181
fn distributions_can_be_compared() {
138-
assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0));
182+
let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
183+
let p1 = Pert::new(min, max).with_mode(mode).unwrap();
184+
let mean = (min + shape * mode + max) / (shape + 2.0);
185+
let p2 = Pert::new(min, max).with_mean(mean).unwrap();
186+
assert_eq!(p1, p2);
139187
}
140188

141189
#[test]
142190
fn mode_almost_half_range() {
143-
assert!(Pert::new(0.0f32, 0.48258883, 0.24129441).is_ok());
191+
assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok());
144192
}
145193

146194
#[test]
147195
fn almost_symmetric_about_zero() {
148-
let distr = Pert::new(-10f32, 10f32, f32::EPSILON);
196+
let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON);
149197
assert!(distr.is_ok());
150198
}
151199

152200
#[test]
153201
fn almost_symmetric() {
154-
let distr = Pert::new(0f32, 2f32, 1f32 + f32::EPSILON);
202+
let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON);
155203
assert!(distr.is_ok());
156204
}
157205
}

rand_distr/tests/value_stability.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ fn pert_stability() {
250250
// mean = 4, var = 12/7
251251
test_samples(
252252
860,
253-
Pert::new(2., 10., 3.).unwrap(),
253+
Pert::new(2., 10.).with_mode(3.).unwrap(),
254254
&[
255255
4.908681667460367,
256256
4.014196196158352,

0 commit comments

Comments
 (0)