@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
101
101
Number of particles for the conditional SMC sampler. Defaults to 10
102
102
max_stages : int
103
103
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
104
- batch : int
104
+ batch : int or tuple
105
105
Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
106
- during tuning and 20% after tuning.
106
+ during tuning and 20% after tuning. If a tuple is passed, the first element is the batch size
107
+ during tuning and the second is the batch size after tuning.
107
108
model: PyMC Model
108
109
Optional model for sampling step. Defaults to None (taken from context).
109
110
@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
138
139
self .alpha = self .bart .alpha
139
140
self .k = self .bart .k
140
141
self .response = self .bart .response
141
- self .split_prior = self .bart .split_prior
142
- if self .split_prior is None :
143
- self .split_prior = np .ones (self .X .shape [1 ])
142
+ self .alpha_vec = self .bart .split_prior
143
+ if self .alpha_vec is None :
144
+ self .alpha_vec = np .ones (self .X .shape [1 ])
144
145
145
146
self .init_mean = self .Y .mean ()
146
147
# if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
149
150
self .mu_std = 6 / (self .k * self .m ** 0.5 )
150
151
# maybe we need to check for count data
151
152
else :
152
- self .mu_std = self .Y .std () / (self .k * self .m ** 0.5 )
153
+ self .mu_std = ( 2 * self .Y .std () ) / (self .k * self .m ** 0.5 )
153
154
154
155
self .num_observations = self .X .shape [0 ]
155
156
self .num_variates = self .X .shape [1 ]
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
167
168
168
169
self .normal = NormalSampler ()
169
170
self .prior_prob_leaf_node = compute_prior_probability (self .alpha )
170
- self .ssv = SampleSplittingVariable (self .split_prior )
171
+ self .ssv = SampleSplittingVariable (self .alpha_vec )
171
172
172
173
self .tune = True
173
- self .idx = 0
174
- self .batch = batch
175
174
176
- if self .batch == "auto" :
177
- self .batch = max (1 , int (self .m * 0.1 ))
175
+ if batch == "auto" :
176
+ self .batch = (max (1 , int (self .m * 0.1 )), max (1 , int (self .m * 0.2 )))
177
+ else :
178
+ if isinstance (batch , (tuple , list )):
179
+ self .batch = batch
180
+ else :
181
+ self .batch = (batch , batch )
182
+
178
183
self .log_num_particles = np .log (num_particles )
179
184
self .indices = list (range (1 , num_particles ))
180
185
self .len_indices = len (self .indices )
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
187
192
self .all_particles = []
188
193
for i in range (self .m ):
189
194
self .a_tree .tree_id = i
195
+ self .a_tree .leaf_node_value = (
196
+ self .init_mean / self .m + self .normal .random () * self .mu_std ,
197
+ )
190
198
p = ParticleTree (
191
199
self .a_tree ,
192
200
self .init_log_weight ,
@@ -201,20 +209,16 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
201
209
sum_trees_output = q .data
202
210
variable_inclusion = np .zeros (self .num_variates , dtype = "int" )
203
211
204
- if self .idx == self .m :
205
- self .idx = 0
206
-
207
- for tree_id in range (self .idx , self .idx + self .batch ):
208
- if tree_id >= self .m :
209
- break
212
+ tree_ids = np .random .randint (0 , self .m , size = self .batch [~ self .tune ])
213
+ for tree_id in tree_ids :
210
214
# Generate an initial set of SMC particles
211
215
# at the end of the algorithm we return one of these particles as the new tree
212
216
particles = self .init_particles (tree_id )
213
217
# Compute the sum of trees without the tree we are attempting to replace
214
218
self .sum_trees_output_noi = sum_trees_output - particles [0 ].tree .predict_output ()
215
219
216
220
# The old tree is not growing so we update the weights only once.
217
- self .update_weight (particles [0 ])
221
+ self .update_weight (particles [0 ], new = True )
218
222
for t in range (self .max_stages ):
219
223
# Sample each particle (try to grow each tree), except for the first one.
220
224
for p in particles [1 :]:
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
235
239
if tree_grew :
236
240
self .update_weight (p )
237
241
# Normalize weights
238
- W_t , normalized_weights = self .normalize (particles )
242
+ W_t , normalized_weights = self .normalize (particles [ 1 :] )
239
243
240
244
# Resample all but first particle
241
- re_n_w = normalized_weights [ 1 :] / normalized_weights [ 1 :]. sum ()
245
+ re_n_w = normalized_weights
242
246
new_indices = np .random .choice (self .indices , size = self .len_indices , p = re_n_w )
243
247
particles [1 :] = particles [new_indices ]
244
248
245
249
# Set the new weights
246
- for p in particles :
250
+ for p in particles [ 1 :] :
247
251
p .log_weight = W_t
248
252
249
253
# Check if particles can keep growing, otherwise stop iterating
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
254
258
if all (non_available_nodes_for_expansion ):
255
259
break
256
260
261
+ for p in particles [1 :]:
262
+ p .log_weight = p .old_likelihood_logp
263
+
264
+ _ , normalized_weights = self .normalize (particles )
257
265
# Get the new tree and update
258
266
new_particle = np .random .choice (particles , p = normalized_weights )
259
267
new_tree = new_particle .tree
260
- self .all_trees [self . idx ] = new_tree
268
+ self .all_trees [tree_id ] = new_tree
261
269
new_particle .log_weight = new_particle .old_likelihood_logp - self .log_num_particles
262
270
self .all_particles [tree_id ] = new_particle
263
271
sum_trees_output = self .sum_trees_output_noi + new_tree .predict_output ()
264
272
265
273
if self .tune :
274
+ self .ssv = SampleSplittingVariable (self .alpha_vec )
266
275
for index in new_particle .used_variates :
267
- self .split_prior [index ] += 1
268
- self .ssv = SampleSplittingVariable (self .split_prior )
276
+ self .alpha_vec [index ] += 1
269
277
else :
270
- self .batch = max (1 , int (self .m * 0.2 ))
271
278
for index in new_particle .used_variates :
272
279
variable_inclusion [index ] += 1
273
- self .idx += 1
274
280
275
281
stats = {"variable_inclusion" : variable_inclusion , "bart_trees" : copy (self .all_trees )}
276
282
sum_trees_output = RaveledVars (sum_trees_output , point_map_info )
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
323
329
324
330
return np .array (particles )
325
331
326
- def update_weight (self , particle : List [ParticleTree ]) -> None :
332
+ def update_weight (self , particle : List [ParticleTree ], new = False ) -> None :
327
333
"""
328
334
Update the weight of a particle
329
335
@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
333
339
new_likelihood = self .likelihood_logp (
334
340
self .sum_trees_output_noi + particle .tree .predict_output ()
335
341
)
336
- particle .log_weight += new_likelihood - particle .old_likelihood_logp
337
- particle .old_likelihood_logp = new_likelihood
342
+ if new :
343
+ particle .log_weight = new_likelihood
344
+ else :
345
+ particle .log_weight += new_likelihood - particle .old_likelihood_logp
346
+ particle .old_likelihood_logp = new_likelihood
338
347
339
348
340
349
class SampleSplittingVariable :
341
- def __init__ (self , alpha_prior ):
350
+ def __init__ (self , alpha_vec ):
342
351
"""
343
- Sample splitting variables proportional to `alpha_prior `.
352
+ Sample splitting variables proportional to `alpha_vec `.
344
353
345
- This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior`
346
- parameter and then using those weights to sample from the available spliting variables.
354
+ This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
347
355
This enforces sparsity.
348
356
"""
349
- self .enu = list (enumerate (np .cumsum (alpha_prior / alpha_prior .sum ())))
357
+ self .enu = list (enumerate (np .cumsum (alpha_vec / alpha_vec .sum ())))
350
358
351
359
def rvs (self ):
352
360
r = np .random .random ()
0 commit comments