Summary
As discussed briefly on other channels, the number one thing preventing neos from being used in a realistic use case is the technical hurdle of differentiating through pyhf model construction (#882).
A recently discussed way to circumvent this would be to skip model construction entirely and instead modify the spec information of an existing model in place. I've tested that this works on a technical level, and it's super fast. However, the likelihood shape of the modified model doesn't validate: not all of the information seems to propagate correctly.
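For reference, here is the per-bin rate from a `histosys` modifier, assuming `correlated_background` uses one with pyhf's default piecewise-linear interpolation (our assumption, not checked against this exact pyhf version). Here $\nu_i^{\mathrm{nom}}$, $\nu_i^{\mathrm{up}}$, $\nu_i^{\mathrm{down}}$ are the nominal, up and down background yields in bin $i$, and $\alpha$ is the nuisance parameter:

```math
\nu_i(\alpha) = \nu_i^{\mathrm{nom}} +
\begin{cases}
\alpha\,\left(\nu_i^{\mathrm{up}} - \nu_i^{\mathrm{nom}}\right) & \alpha \ge 0 \\
\alpha\,\left(\nu_i^{\mathrm{nom}} - \nu_i^{\mathrm{down}}\right) & \alpha < 0
\end{cases}
```

The nominal yields appear inside both deltas, not just as the additive base term. So an in-place edit of the nominal rates would presumably also have to reach these deltas for the likelihood to be reproduced.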
Here's an example that hand-alters the nominal rates of one model to match those of another; even so, their likelihoods don't match:
```python
import pyhf

pyhf.set_backend("jax")


def make_model(s, b, bup, bdown):
    return pyhf.simplemodels.correlated_background(s, b, bup, bdown)


s = [1, 2, 3]
b = [21, 20, 22]
bup = [22, 22, 24]
bdown = [20, 18, 20]

# some target model
model1 = make_model(s, b, bup, bdown)
# a skeleton model with different yields, but the same up/down variations
model2 = make_model([0, 0, 0], [0, 0, 0], bup, bdown)
# simply copy the target model's yields into the skeleton model
model2.main_model.nominal_rates = model1.main_model.nominal_rates

# Plotting:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

plt.rc("figure", figsize=(6, 2), dpi=200, facecolor="w")

data = [23, 23, 23] + model1.config.auxdata


def logpdf_grid(model, grid=(500, 500)):
    """Evaluate model.logpdf on a grid over [0, 10] x [0, 10] of the two
    model parameters (in the model's own parameter order)."""
    x_range = np.linspace(0, 10, grid[0])
    y_range = np.linspace(0, 10, grid[1])
    xx, yy = np.meshgrid(x_range, y_range)
    xy = np.concatenate([xx.reshape(-1, 1), yy.reshape(-1, 1)], axis=1)

    # vectorise the per-point logpdf over all grid points
    @jax.vmap
    def logpdf(pars):
        return model.logpdf(pars, data)[0]

    return xx, yy, logpdf(jnp.array(xy)).reshape(xx.shape)


def plot_model(model, ax, model_name=None):
    xx, yy, z = logpdf_grid(model)
    ax.contourf(xx, yy, z)
    if model_name is not None:
        ax.set_title(model_name)


def plot_difference(model1, model2, ax):
    xx, yy, z1 = logpdf_grid(model1)
    _, _, z2 = logpdf_grid(model2)
    ax.contourf(xx, yy, z2 - z1)
    ax.set_title("differences")


fig, axs = plt.subplots(1, 3)
plot_model(model1, axs[0], "original model")
plot_model(model2, axs[1], "manually set yields")
plot_difference(model1, model2, axs[2])
plt.tight_layout()
plt.show()
```
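One possible explanation rests on an assumption: pyhf's `histosys` modifier (used by `correlated_background`) may precompute its up/down deltas from the nominal yields it saw at construction time. This is our reading of the internals, not something verified here. If so, overwriting `main_model.nominal_rates` afterwards would leave those deltas computed against the skeleton model's zero yields.

The toy below is a hypothetical, pure-Python piecewise-linear interpolator, not pyhf's code. `ToyHistoSys` is an illustrative name. It uses the yields from the example to show how such stale deltas would behave: the patched model matches the target at α = 0 but diverges everywhere else.

```python
# Hypothetical toy, NOT pyhf's implementation: a histosys-like modifier that
# caches its up/down deltas relative to the nominal yields it was built with.


class ToyHistoSys:
    def __init__(self, nominal, up, down):
        self.nominal = list(nominal)
        # Deltas are computed once, at construction time, from `nominal`.
        self.delta_up = [u - n for u, n in zip(up, nominal)]
        self.delta_down = [n - d for n, d in zip(nominal, down)]

    def expected(self, alpha):
        """Piecewise-linear interpolation in the nuisance parameter alpha."""
        deltas = self.delta_up if alpha >= 0 else self.delta_down
        return [n + alpha * dlt for n, dlt in zip(self.nominal, deltas)]


b = [21, 20, 22]
bup = [22, 22, 24]
bdown = [20, 18, 20]

target = ToyHistoSys(b, bup, bdown)
# Build a skeleton with zero yields, then overwrite only the nominal rates.
patched = ToyHistoSys([0, 0, 0], bup, bdown)
patched.nominal = list(target.nominal)

for alpha in (-1.0, 0.0, 1.0):
    print(alpha, target.expected(alpha), patched.expected(alpha))
```

In the toy, both models give [21, 20, 22] at α = 0. At α = 1 the target gives the intended [22, 22, 24], but the patched model gives [43, 42, 46]; at α = −1 they give [20, 18, 20] and [41, 38, 42] respectively. If pyhf behaves like this, the two likelihoods would only coincide where the `histosys` parameter is zero.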