Description of your problem
I want to make predictions on new data, so I set the value of a shared variable of shape (num_samples, num_features). This behaves unexpectedly when num_samples is equal to 1: the posterior predictive samples appear to be generated for the original value of the shared variable, the one the model was trained on, instead of the new value.
Please provide a minimal, self-contained, and reproducible example.
import numpy as np
import pymc3 as pm

# Training data: 123 samples, 2 features, label is True when the features sum to > 0
x = np.random.randn(123, 2)
y = x[:, 0] + x[:, 1] > 0

with pm.Model() as model:
    x_shared = pm.Data('x_shared', x)
    coeff = pm.Normal('x', mu=0, sigma=1, shape=(2,))
    print(coeff.shape)
    logistic = pm.math.sigmoid(coeff[0] * x_shared[:, 0] + coeff[1] * x_shared[:, 1])
    pm.Bernoulli('obs', p=logistic, observed=y)
    # fit the model
    trace = pm.sample()

good_values = np.array([[0, 0],
                        [0, 1],
                        [1, 1]])
bad_values = np.array([[1, 2]])
print('good', good_values.shape)
print('bad', bad_values.shape, flush=True)

with model:
    # For good: three rows
    pm.set_data({'x_shared': good_values})
    post_pred = pm.sample_posterior_predictive(trace, samples=500)
    print('\ngood res', post_pred['obs'].shape)
    # For bad: a single row
    pm.set_data({'x_shared': bad_values})
    post_pred = pm.sample_posterior_predictive(trace, samples=500)
    print('\nbad res', post_pred['obs'].shape)  # Expect (500, 1), get (500, 123) !!!
Output
2020-01-07 14:26:25,534: Auto-assigning NUTS sampler...
2020-01-07 14:26:25,535: Initializing NUTS using jitter+adapt_diag...
Shape.0
2020-01-07 14:26:26,115: Sequential sampling (2 chains in 1 job)
2020-01-07 14:26:26,116: NUTS: [x]
Sampling chain 0, 0 divergences: 100%|██████████| 1000/1000 [00:00<00:00, 1147.73it/s]
Sampling chain 1, 0 divergences: 100%|██████████| 1000/1000 [00:00<00:00, 1124.48it/s]
good (3, 2)
bad (1, 2)
/usr/local/lib/python3.6/dist-packages/pymc3/sampling.py:1247: UserWarning: samples parameter is smaller than nchains times ndraws, some draws and/or chains may not be represented in the returned posterior predictive sample
"samples parameter is smaller than nchains times ndraws, some draws "
100%|██████████| 500/500 [00:00<00:00, 628.52it/s]
25%|██▌ | 127/500 [00:00<00:00, 1260.77it/s]
good res (500, 3)
100%|██████████| 500/500 [00:00<00:00, 1273.05it/s]
bad res (500, 123)
Versions and main components
- PyMC3 Version: 3.8
- Theano Version: 1.0.4
- Python Version: 3.6.9
- Operating system: Linux Debian buster/sid
- How did you install PyMC3: pip
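A possible stopgap, sketched under the assumption that only single-row inputs are affected (the three-row case above returns the expected shape; whether two rows do as well has not been verified here): pad a single-row input with a duplicate row before calling pm.set_data, then keep only the first column of the predictions. The helper name pad_single_row is hypothetical and not part of PyMC3.

```python
import numpy as np

def pad_single_row(values):
    """Return (padded, n_rows); a single row is duplicated so the input has two rows."""
    values = np.atleast_2d(np.asarray(values))
    n_rows = values.shape[0]
    if n_rows == 1:
        # Assumption: inputs with more than one row avoid the problem described above.
        values = np.repeat(values, 2, axis=0)
    return values, n_rows

bad_values = np.array([[1, 2]])
padded, n_rows = pad_single_row(bad_values)
print(padded.shape, n_rows)  # (2, 2) 1

# Hypothetical use with the model from the example above (not run here):
# with model:
#     pm.set_data({'x_shared': padded})
#     post_pred = pm.sample_posterior_predictive(trace, samples=500)
# preds = post_pred['obs'][:, :n_rows]  # drop the column for the padding row
```

Inputs that already have several rows pass through unchanged, so the same call can wrap every prediction request.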
vtseng