@@ -56,6 +56,7 @@ def plot_dependence(
56
56
xs_interval = "linear" ,
57
57
xs_values = None ,
58
58
var_idx = None ,
59
+ var_discrete = None ,
59
60
samples = 50 ,
60
61
instances = 10 ,
61
62
random_seed = None ,
@@ -89,13 +90,16 @@ def plot_dependence(
89
90
Method used to compute the values X used to evaluate the predicted function. "linear",
90
91
evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
91
92
quantiles of X. "insample", the evaluation is done at the values of X.
93
+ For discrete variables these options are ommited.
92
94
xs_values : int or list
93
95
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
94
96
points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
95
97
quantiles to compute, which must be between 0 and 1 inclusive.
96
98
Ignored when ``xs_interval="insample"``.
97
99
var_idx : list
98
100
List of the indices of the covariate for which to compute the pdp or ice.
101
+ var_discrete : list
102
+ List of the indices of the covariate treated as discrete.
99
103
samples : int
100
104
Number of posterior samples used in the predictions. Defaults to 50
101
105
instances : int
@@ -161,6 +165,8 @@ def plot_dependence(
161
165
162
166
if var_idx is None :
163
167
var_idx = indices
168
+ if var_discrete is None :
169
+ var_discrete = []
164
170
165
171
if X_names :
166
172
X_labels = [X_names [idx ] for idx in var_idx ]
@@ -178,6 +184,7 @@ def plot_dependence(
178
184
179
185
new_Y = []
180
186
new_X_target = []
187
+ y_mins = []
181
188
182
189
new_X = np .zeros_like (X )
183
190
idx_s = list (range (X .shape [0 ]))
@@ -186,12 +193,15 @@ def plot_dependence(
186
193
indices_mi .pop (i )
187
194
y_pred = []
188
195
if kind == "pdp" :
189
- if xs_interval == "linear" :
190
- new_X_i = np .linspace (np .nanmin (X [:, i ]), np .nanmax (X [:, i ]), xs_values )
191
- elif xs_interval == "quantiles" :
192
- new_X_i = np .quantile (X [:, i ], q = xs_values )
193
- elif xs_interval == "insample" :
194
- new_X_i = X [:, i ]
196
+ if i in var_discrete :
197
+ new_X_i = np .unique (X [:, i ])
198
+ else :
199
+ if xs_interval == "linear" :
200
+ new_X_i = np .linspace (np .nanmin (X [:, i ]), np .nanmax (X [:, i ]), xs_values )
201
+ elif xs_interval == "quantiles" :
202
+ new_X_i = np .quantile (X [:, i ], q = xs_values )
203
+ elif xs_interval == "insample" :
204
+ new_X_i = X [:, i ]
195
205
196
206
for x_i in new_X_i :
197
207
new_X [:, indices_mi ] = X [:, indices_mi ]
@@ -204,6 +214,7 @@ def plot_dependence(
204
214
new_X [:, indices_mi ] = X [:, indices_mi ][instance ]
205
215
y_pred .append (np .mean (predict (idata , rng , X_new = new_X , size = samples ), 0 ))
206
216
new_X_target .append (new_X [:, i ])
217
+ y_mins .append (np .min (y_pred ))
207
218
new_Y .append (np .array (y_pred ).T )
208
219
209
220
if ax is None :
@@ -212,19 +223,34 @@ def plot_dependence(
212
223
elif grid == "wide" :
213
224
fig , axes = plt .subplots (1 , len (var_idx ), sharey = sharey , figsize = figsize )
214
225
elif isinstance (grid , tuple ):
215
- _ , axes = plt .subplots (grid [0 ], grid [1 ], sharey = sharey , figsize = figsize )
226
+ fig , axes = plt .subplots (grid [0 ], grid [1 ], sharey = sharey , figsize = figsize )
216
227
axes = np .ravel (axes )
217
228
else :
218
229
axes = [ax ]
219
-
220
- if rug :
221
- lb = np .min (new_Y )
230
+ fig = ax .get_figure ()
222
231
223
232
for i , ax in enumerate (axes ):
224
233
if i >= len (var_idx ):
225
234
ax .set_axis_off ()
235
+ fig .delaxes (ax )
226
236
else :
227
- if smooth :
237
+ var = var_idx [i ]
238
+ if var in var_discrete :
239
+ if kind == "pdp" :
240
+ y_means = new_Y [i ].mean (0 )
241
+ hdi = az .hdi (new_Y [i ])
242
+ ax .errorbar (
243
+ new_X_target [i ],
244
+ y_means ,
245
+ (y_means - hdi [:, 0 ], hdi [:, 1 ] - y_means ),
246
+ fmt = "." ,
247
+ color = color ,
248
+ )
249
+ else :
250
+ ax .plot (new_X_target [i ], new_Y [i ], "." , color = color , alpha = alpha )
251
+ ax .plot (new_X_target [i ], new_Y [i ].mean (1 ), "o" , color = color_mean )
252
+ ax .set_xticks (new_X_target [i ])
253
+ elif smooth :
228
254
if smooth_kwargs is None :
229
255
smooth_kwargs = {}
230
256
smooth_kwargs .setdefault ("window_length" , 55 )
@@ -263,8 +289,10 @@ def plot_dependence(
263
289
ax .plot (new_X_target [i ][idx ], new_Y [i ][idx ].mean (1 ), color = color_mean )
264
290
265
291
if rug :
266
- ax .plot (X [:, i ], np .full_like (X [:, i ], lb ), "k|" )
292
+ lb = np .min (y_mins )
293
+ ax .plot (X [:, var ], np .full_like (X [:, var ], lb ), "k|" )
267
294
268
295
ax .set_xlabel (X_labels [i ])
269
- ax .get_figure ().text (- 0.05 , 0.5 , Y_label , va = "center" , rotation = "vertical" , fontsize = 15 )
296
+
297
+ fig .text (- 0.05 , 0.5 , Y_label , va = "center" , rotation = "vertical" , fontsize = 15 )
270
298
return axes
0 commit comments