@@ -24,7 +24,7 @@ use rand::Rng;
2424/// ```rust
2525/// use rand_distr::{Pert, Distribution};
2626///
27- /// let d = Pert::new(0., 5., 2.5).unwrap();
27+ /// let d = Pert::new(0., 5.).with_mode( 2.5).unwrap();
2828/// let v = d.sample(&mut rand::thread_rng());
2929/// println!("{} is from a PERT distribution", v);
3030/// ```
@@ -75,27 +75,71 @@ where
7575 Exp1 : Distribution < F > ,
7676 Open01 : Distribution < F > ,
7777{
78- /// Set up the PERT distribution with defined `min`, `max` and `mode`.
78+ /// Construct a PERT distribution with defined `min`, `max`
7979 ///
80- /// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
80+ /// # Example
81+ ///
82+ /// ```
83+ /// use rand_distr::Pert;
84+ /// let pert_dist = Pert::new(0.0, 10.0)
85+ /// .with_shape(3.5)
86+ /// .with_mean(3.0)
87+ /// .unwrap();
88+ /// # let _unused: Pert<f64> = pert_dist;
89+ /// ```
8190 #[ inline]
82- pub fn new ( min : F , max : F , mode : F ) -> Result < Pert < F > , PertError > {
83- Pert :: new_with_shape ( min, max, mode, F :: from ( 4. ) . unwrap ( ) )
91+ pub fn new ( min : F , max : F ) -> PertBuilder < F > {
92+ let shape = F :: from ( 4.0 ) . unwrap ( ) ;
93+ PertBuilder { min, max, shape }
8494 }
95+ }
96+
97+ /// Struct used to build a [`Pert`]
98+ #[ derive( Debug ) ]
99+ pub struct PertBuilder < F > {
100+ min : F ,
101+ max : F ,
102+ shape : F ,
103+ }
85104
86- /// Set up the PERT distribution with defined `min`, `max`, `mode` and
87- /// `shape`.
88- pub fn new_with_shape ( min : F , max : F , mode : F , shape : F ) -> Result < Pert < F > , PertError > {
89- if !( max > min) {
105+ impl < F > PertBuilder < F >
106+ where
107+ F : Float ,
108+ StandardNormal : Distribution < F > ,
109+ Exp1 : Distribution < F > ,
110+ Open01 : Distribution < F > ,
111+ {
112+ /// Set the shape parameter
113+ ///
114+ /// If not specified, this defaults to 4.
115+ #[ inline]
116+ pub fn with_shape ( mut self , shape : F ) -> PertBuilder < F > {
117+ self . shape = shape;
118+ self
119+ }
120+
121+ /// Specify the mean
122+ #[ inline]
123+ pub fn with_mean ( self , mean : F ) -> Result < Pert < F > , PertError > {
124+ let two = F :: from ( 2.0 ) . unwrap ( ) ;
125+ let mode = ( ( self . shape + two) * mean - self . min - self . max ) / self . shape ;
126+ self . with_mode ( mode)
127+ }
128+
129+ /// Specify the mode
130+ #[ inline]
131+ pub fn with_mode ( self , mode : F ) -> Result < Pert < F > , PertError > {
132+ if !( self . max > self . min ) {
90133 return Err ( PertError :: RangeTooSmall ) ;
91134 }
92- if !( mode >= min && max >= mode) {
135+ if !( mode >= self . min && self . max >= mode) {
93136 return Err ( PertError :: ModeRange ) ;
94137 }
95- if !( shape >= F :: from ( 0. ) . unwrap ( ) ) {
138+ if !( self . shape >= F :: from ( 0. ) . unwrap ( ) ) {
96139 return Err ( PertError :: ShapeTooSmall ) ;
97140 }
98141
142+ let ( min, max, shape) = ( self . min , self . max , self . shape ) ;
99143 let range = max - min;
100144 let v = F :: from ( 1.0 ) . unwrap ( ) + shape * ( mode - min) / range;
101145 let w = F :: from ( 1.0 ) . unwrap ( ) + shape * ( max - mode) / range;
@@ -124,34 +168,38 @@ mod test {
124168 #[ test]
125169 fn test_pert ( ) {
126170 for & ( min, max, mode) in & [ ( -1. , 1. , 0. ) , ( 1. , 2. , 1. ) , ( 5. , 25. , 25. ) ] {
127- let _distr = Pert :: new ( min, max, mode) . unwrap ( ) ;
171+ let _distr = Pert :: new ( min, max) . with_mode ( mode) . unwrap ( ) ;
128172 // TODO: test correctness
129173 }
130174
131175 for & ( min, max, mode) in & [ ( -1. , 1. , 2. ) , ( -1. , 1. , -2. ) , ( 2. , 1. , 1. ) ] {
132- assert ! ( Pert :: new( min, max, mode) . is_err( ) ) ;
176+ assert ! ( Pert :: new( min, max) . with_mode ( mode) . is_err( ) ) ;
133177 }
134178 }
135179
136180 #[ test]
137181 fn distributions_can_be_compared ( ) {
138- assert_eq ! ( Pert :: new( 1.0 , 3.0 , 2.0 ) , Pert :: new( 1.0 , 3.0 , 2.0 ) ) ;
182+ let ( min, mode, max, shape) = ( 1.0 , 2.0 , 3.0 , 4.0 ) ;
183+ let p1 = Pert :: new ( min, max) . with_mode ( mode) . unwrap ( ) ;
184+ let mean = ( min + shape * mode + max) / ( shape + 2.0 ) ;
185+ let p2 = Pert :: new ( min, max) . with_mean ( mean) . unwrap ( ) ;
186+ assert_eq ! ( p1, p2) ;
139187 }
140188
141189 #[ test]
142190 fn mode_almost_half_range ( ) {
143- assert ! ( Pert :: new( 0.0f32 , 0.48258883 , 0.24129441 ) . is_ok( ) ) ;
191+ assert ! ( Pert :: new( 0.0f32 , 0.48258883 ) . with_mode ( 0.24129441 ) . is_ok( ) ) ;
144192 }
145193
146194 #[ test]
147195 fn almost_symmetric_about_zero ( ) {
148- let distr = Pert :: new ( -10f32 , 10f32 , f32:: EPSILON ) ;
196+ let distr = Pert :: new ( -10f32 , 10f32 ) . with_mode ( f32:: EPSILON ) ;
149197 assert ! ( distr. is_ok( ) ) ;
150198 }
151199
152200 #[ test]
153201 fn almost_symmetric ( ) {
154- let distr = Pert :: new ( 0f32 , 2f32 , 1f32 + f32:: EPSILON ) ;
202+ let distr = Pert :: new ( 0f32 , 2f32 ) . with_mode ( 1f32 + f32:: EPSILON ) ;
155203 assert ! ( distr. is_ok( ) ) ;
156204 }
157205}
0 commit comments