Skip to content

Commit c200802

Browse files
authored
better representation for discrete/categorical variables (#5272)
1 parent c15af2a commit c200802

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

pymc/bart/utils.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def plot_dependence(
5656
xs_interval="linear",
5757
xs_values=None,
5858
var_idx=None,
59+
var_discrete=None,
5960
samples=50,
6061
instances=10,
6162
random_seed=None,
@@ -89,13 +90,16 @@ def plot_dependence(
8990
Method used to compute the values X used to evaluate the predicted function. "linear",
9091
evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
9192
quantiles of X. "insample", the evaluation is done at the values of X.
93+
For discrete variables these options are ommited.
9294
xs_values : int or list
9395
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
9496
points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
9597
quantiles to compute, which must be between 0 and 1 inclusive.
9698
Ignored when ``xs_interval="insample"``.
9799
var_idx : list
98100
List of the indices of the covariate for which to compute the pdp or ice.
101+
var_discrete : list
102+
List of the indices of the covariate treated as discrete.
99103
samples : int
100104
Number of posterior samples used in the predictions. Defaults to 50
101105
instances : int
@@ -161,6 +165,8 @@ def plot_dependence(
161165

162166
if var_idx is None:
163167
var_idx = indices
168+
if var_discrete is None:
169+
var_discrete = []
164170

165171
if X_names:
166172
X_labels = [X_names[idx] for idx in var_idx]
@@ -178,6 +184,7 @@ def plot_dependence(
178184

179185
new_Y = []
180186
new_X_target = []
187+
y_mins = []
181188

182189
new_X = np.zeros_like(X)
183190
idx_s = list(range(X.shape[0]))
@@ -186,12 +193,15 @@ def plot_dependence(
186193
indices_mi.pop(i)
187194
y_pred = []
188195
if kind == "pdp":
189-
if xs_interval == "linear":
190-
new_X_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values)
191-
elif xs_interval == "quantiles":
192-
new_X_i = np.quantile(X[:, i], q=xs_values)
193-
elif xs_interval == "insample":
194-
new_X_i = X[:, i]
196+
if i in var_discrete:
197+
new_X_i = np.unique(X[:, i])
198+
else:
199+
if xs_interval == "linear":
200+
new_X_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values)
201+
elif xs_interval == "quantiles":
202+
new_X_i = np.quantile(X[:, i], q=xs_values)
203+
elif xs_interval == "insample":
204+
new_X_i = X[:, i]
195205

196206
for x_i in new_X_i:
197207
new_X[:, indices_mi] = X[:, indices_mi]
@@ -204,6 +214,7 @@ def plot_dependence(
204214
new_X[:, indices_mi] = X[:, indices_mi][instance]
205215
y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 0))
206216
new_X_target.append(new_X[:, i])
217+
y_mins.append(np.min(y_pred))
207218
new_Y.append(np.array(y_pred).T)
208219

209220
if ax is None:
@@ -212,19 +223,34 @@ def plot_dependence(
212223
elif grid == "wide":
213224
fig, axes = plt.subplots(1, len(var_idx), sharey=sharey, figsize=figsize)
214225
elif isinstance(grid, tuple):
215-
_, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
226+
fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
216227
axes = np.ravel(axes)
217228
else:
218229
axes = [ax]
219-
220-
if rug:
221-
lb = np.min(new_Y)
230+
fig = ax.get_figure()
222231

223232
for i, ax in enumerate(axes):
224233
if i >= len(var_idx):
225234
ax.set_axis_off()
235+
fig.delaxes(ax)
226236
else:
227-
if smooth:
237+
var = var_idx[i]
238+
if var in var_discrete:
239+
if kind == "pdp":
240+
y_means = new_Y[i].mean(0)
241+
hdi = az.hdi(new_Y[i])
242+
ax.errorbar(
243+
new_X_target[i],
244+
y_means,
245+
(y_means - hdi[:, 0], hdi[:, 1] - y_means),
246+
fmt=".",
247+
color=color,
248+
)
249+
else:
250+
ax.plot(new_X_target[i], new_Y[i], ".", color=color, alpha=alpha)
251+
ax.plot(new_X_target[i], new_Y[i].mean(1), "o", color=color_mean)
252+
ax.set_xticks(new_X_target[i])
253+
elif smooth:
228254
if smooth_kwargs is None:
229255
smooth_kwargs = {}
230256
smooth_kwargs.setdefault("window_length", 55)
@@ -263,8 +289,10 @@ def plot_dependence(
263289
ax.plot(new_X_target[i][idx], new_Y[i][idx].mean(1), color=color_mean)
264290

265291
if rug:
266-
ax.plot(X[:, i], np.full_like(X[:, i], lb), "k|")
292+
lb = np.min(y_mins)
293+
ax.plot(X[:, var], np.full_like(X[:, var], lb), "k|")
267294

268295
ax.set_xlabel(X_labels[i])
269-
ax.get_figure().text(-0.05, 0.5, Y_label, va="center", rotation="vertical", fontsize=15)
296+
297+
fig.text(-0.05, 0.5, Y_label, va="center", rotation="vertical", fontsize=15)
270298
return axes

pymc/tests/test_bart.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_leaf_node():
3131

3232

3333
def test_bart_vi():
34-
X = np.random.normal(0, 1, size=(3, 250)).T
34+
X = np.random.normal(0, 1, size=(250, 3))
3535
Y = np.random.normal(0, 1, size=250)
3636
X[:, 0] = np.random.normal(Y, 0.1)
3737

@@ -51,7 +51,7 @@ def test_bart_vi():
5151

5252

5353
def test_missing_data():
54-
X = np.random.normal(0, 1, size=(2, 50)).T
54+
X = np.random.normal(0, 1, size=(50, 2))
5555
Y = np.random.normal(0, 1, size=50)
5656
X[10:20, 0] = np.nan
5757

@@ -63,7 +63,9 @@ def test_missing_data():
6363

6464

6565
class TestUtils:
66-
X = np.random.normal(0, 1, size=(2, 50)).T
66+
X_norm = np.random.normal(0, 1, size=(50, 2))
67+
X_binom = np.random.binomial(1, 0.5, size=(50, 1))
68+
X = np.hstack([X_norm, X_binom])
6769
Y = np.random.normal(0, 1, size=50)
6870

6971
with pm.Model() as model:
@@ -91,6 +93,7 @@ def test_predict(self):
9193
"samples": 2,
9294
"xs_interval": "quantiles",
9395
"xs_values": [0.25, 0.5, 0.75],
96+
"var_discrete": [3],
9497
},
9598
{"kind": "ice", "instances": 2},
9699
{"var_idx": [0], "rug": False, "smooth": False, "color": "k"},

0 commit comments

Comments
 (0)