Commit c0c5a80

BART: improve accuracy and other minor fixes (#5177)

* improve accuracy and other minor fixes
* update release notes
* fix typo

1 parent a11eaa2 commit c0c5a80

File tree

2 files changed: +45 -36 lines changed


RELEASE-NOTES.md

Lines changed: 3 additions & 2 deletions
@@ -86,8 +86,9 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cumulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - New features for BART:
-  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
-  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
+  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Modified how particle weights are computed. This improves the accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
 - `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
 - ...

pymc/bart/pgbart.py

Lines changed: 42 additions & 34 deletions
@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
         Number of particles for the conditional SMC sampler. Defaults to 10
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
-    batch : int
+    batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
-        during tuning and 20% after tuning.
+        during tuning and 20% after tuning. If a tuple is passed, the first element is the batch size
+        during tuning and the second the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
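The resolution rule this docstring describes, sketched as a standalone helper (`resolve_batch` is a hypothetical name; the commit keeps the logic inline in `__init__`, shown further down):

def resolve_batch(batch, m):
    # "auto": 10% of the m trees while tuning, 20% afterwards
    if batch == "auto":
        return (max(1, int(m * 0.1)), max(1, int(m * 0.2)))
    # a (tuning, post-tuning) pair is used as given
    if isinstance(batch, (tuple, list)):
        return tuple(batch)
    # a single int applies to both phases
    return (batch, batch)

print(resolve_batch("auto", 50))   # (5, 10)
print(resolve_batch(4, 50))        # (4, 4)
print(resolve_batch((5, 10), 50))  # (5, 10)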
@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.alpha = self.bart.alpha
         self.k = self.bart.k
         self.response = self.bart.response
-        self.split_prior = self.bart.split_prior
-        if self.split_prior is None:
-            self.split_prior = np.ones(self.X.shape[1])
+        self.alpha_vec = self.bart.split_prior
+        if self.alpha_vec is None:
+            self.alpha_vec = np.ones(self.X.shape[1])

         self.init_mean = self.Y.mean()
         # if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
-            self.mu_std = self.Y.std() / (self.k * self.m ** 0.5)
+            self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)

         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
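For non-binary data this doubles the prior scale of the leaf values. A quick numeric check with illustrative values for `Y`, `k`, and `m`:

import numpy as np

rng = np.random.default_rng(42)
Y = rng.normal(10.0, 3.0, size=500)          # illustrative continuous response
k, m = 2.0, 50
mu_std_old = Y.std() / (k * m ** 0.5)        # before this commit, ~0.21
mu_std_new = (2 * Y.std()) / (k * m ** 0.5)  # after, ~0.42: twice as wide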
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.split_prior)
+        self.ssv = SampleSplittingVariable(self.alpha_vec)

         self.tune = True
-        self.idx = 0
-        self.batch = batch

-        if self.batch == "auto":
-            self.batch = max(1, int(self.m * 0.1))
+        if batch == "auto":
+            self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+        else:
+            if isinstance(batch, (tuple, list)):
+                self.batch = batch
+            else:
+                self.batch = (batch, batch)
+
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
         self.len_indices = len(self.indices)
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.all_particles = []
         for i in range(self.m):
             self.a_tree.tree_id = i
+            self.a_tree.leaf_node_value = (
+                self.init_mean / self.m + self.normal.random() * self.mu_std,
+            )
             p = ParticleTree(
                 self.a_tree,
                 self.init_log_weight,
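Each of the `m` single-leaf trees now starts from an explicit leaf value: an equal share of the data mean plus Gaussian jitter of scale `mu_std`, so the initial sum of trees is centered on the observed mean. A numeric sketch with illustrative values:

import numpy as np

rng = np.random.default_rng(0)
init_mean, m, mu_std = 10.0, 50, 0.4  # illustrative values
leaf_values = init_mean / m + rng.normal(size=m) * mu_std
print(leaf_values.sum())  # near init_mean; jitter std is sqrt(m) * mu_std, ~2.8 here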
@@ -201,20 +209,16 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        if self.idx == self.m:
-            self.idx = 0
-
-        for tree_id in range(self.idx, self.idx + self.batch):
-            if tree_id >= self.m:
-                break
+        tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+        for tree_id in tree_ids:
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
             # Compute the sum of trees without the tree we are attempting to replace
             self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()

             # The old tree is not growing so we update the weights only once.
-            self.update_weight(particles[0])
+            self.update_weight(particles[0], new=True)
             for t in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
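The `self.batch[~self.tune]` index is a small Python idiom: bitwise-not on a bool gives `~True == -2` and `~False == -1`, so a 2-tuple yields its first element while tuning and its second afterwards:

batch = (5, 10)  # (tuning, post-tuning) batch sizes
for tune in (True, False):
    print(tune, batch[~tune])  # True -> 5, False -> 10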
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     if tree_grew:
                         self.update_weight(p)
                 # Normalize weights
-                W_t, normalized_weights = self.normalize(particles)
+                W_t, normalized_weights = self.normalize(particles[1:])

                 # Resample all but first particle
-                re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+                re_n_w = normalized_weights
                 new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]

                 # Set the new weights
-                for p in particles:
+                for p in particles[1:]:
                     p.log_weight = W_t

                 # Check if particles can keep growing, otherwise stop iterating
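Normalizing `particles[1:]` directly removes a double normalization: the old code normalized all particles, frozen first one included, and then had to rescale the slice so it summed to one again. A condensed sketch of what the resampling step now does with the log-weights (illustrative numbers; `normalize` is assumed to play roughly this role):

import numpy as np

log_w = np.array([-3.1, -2.7, -4.0])  # illustrative log-weights of particles[1:]
w = np.exp(log_w - log_w.max())       # exponentiate with overflow protection
normalized_weights = w / w.sum()      # sums to 1, usable directly as p
new_indices = np.random.choice(len(log_w), size=len(log_w), p=normalized_weights)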
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                 if all(non_available_nodes_for_expansion):
                     break

+            for p in particles[1:]:
+                p.log_weight = p.old_likelihood_logp
+
+            _, normalized_weights = self.normalize(particles)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
-            self.all_trees[self.idx] = new_tree
+            self.all_trees[tree_id] = new_tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
             self.all_particles[tree_id] = new_particle
             sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

             if self.tune:
+                self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_particle.used_variates:
-                    self.split_prior[index] += 1
-                    self.ssv = SampleSplittingVariable(self.split_prior)
+                    self.alpha_vec[index] += 1
             else:
-                self.batch = max(1, int(self.m * 0.2))
                 for index in new_particle.used_variates:
                     variable_inclusion[index] += 1
-            self.idx += 1

         stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
         sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
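During tuning, the `alpha_vec` entries of the variates used by the accepted tree are incremented, and `SampleSplittingVariable` is now rebuilt once per iteration instead of once per used variate, so frequently chosen predictors become ever more likely to be proposed again. With illustrative numbers:

import numpy as np

alpha_vec = np.ones(3)    # one weight per predictor
for index in [0, 0, 2]:   # variates used by an accepted tree (illustrative)
    alpha_vec[index] += 1
print(alpha_vec / alpha_vec.sum())  # [0.5, ~0.167, ~0.333]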
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:

         return np.array(particles)

-    def update_weight(self, particle: List[ParticleTree]) -> None:
+    def update_weight(self, particle: List[ParticleTree], new=False) -> None:
         """
         Update the weight of a particle
@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
         new_likelihood = self.likelihood_logp(
             self.sum_trees_output_noi + particle.tree.predict_output()
         )
-        particle.log_weight += new_likelihood - particle.old_likelihood_logp
-        particle.old_likelihood_logp = new_likelihood
+        if new:
+            particle.log_weight = new_likelihood
+        else:
+            particle.log_weight += new_likelihood - particle.old_likelihood_logp
+        particle.old_likelihood_logp = new_likelihood


 class SampleSplittingVariable:
-    def __init__(self, alpha_prior):
+    def __init__(self, alpha_vec):
         """
-        Sample splitting variables proportional to `alpha_prior`.
+        Sample splitting variables proportional to `alpha_vec`.

-        This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior`
-        parameter and then using those weights to sample from the available spliting variables.
+        This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
         This enforces sparsity.
         """
-        self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum())))
+        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

     def rvs(self):
         r = np.random.random()
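`SampleSplittingVariable` stores cumulative probabilities, and `rvs` (cut off just past this hunk) draws an index by inverse-CDF sampling. A self-contained sketch; the loop tail of `rvs` is an assumption consistent with the `np.cumsum` setup shown above:

import numpy as np

class SampleSplittingVariableSketch:
    def __init__(self, alpha_vec):
        # pair each variable index with its cumulative selection probability
        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

    def rvs(self):
        # inverse-CDF draw: first index whose cumulative probability covers r
        r = np.random.random()
        for index, cum_prob in self.enu:
            if r <= cum_prob:
                return index

ssv = SampleSplittingVariableSketch(np.array([3.0, 1.0, 2.0]))
print(ssv.rvs())  # 0 with prob 1/2, 1 with prob 1/6, 2 with prob 1/3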
