
How to manually override the yields inputs to the model spec #1894

@phinate


Summary

As discussed briefly on other channels, the biggest obstacle to using neos in a realistic use case is the technical hurdle of differentiating through pyhf model construction (#882).

A recently discussed way to circumvent this would be to skip model construction entirely and instead modify the spec information (e.g. the nominal yields) of an already-built model in place. I've tested that this works on a technical level, and it's very fast. However, the likelihood of the modified model does not match that of the target model, i.e. not all of the information propagates correctly.
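The property neos ultimately needs is a likelihood that is a differentiable function of the yields. As a minimal sketch of that idea (a toy, not pyhf's implementation or API), the hypothetical `toy_logpdf` below takes the signal, nominal, up and down yields as call-time arguments, uses an assumed piecewise-linear shape interpolation plus a unit-Gaussian constraint on the nuisance parameter, and lets `jax.grad` differentiate with respect to the nominal background.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson


# Hypothetical toy likelihood (NOT pyhf's API): every yield is a call-time
# input, so nothing is frozen when the "model" is defined.
def toy_logpdf(pars, data, sig, nom, up, down):
    mu, alpha = pars
    # piecewise-linear shape interpolation (an assumption; pyhf's defaults differ)
    delta = jnp.where(alpha >= 0, alpha * (up - nom), alpha * (nom - down))
    expected = mu * sig + nom + delta
    main = jnp.sum(poisson.logpmf(data, expected))
    constraint = norm.logpdf(alpha)  # unit-Gaussian constraint on alpha
    return main + constraint


sig = jnp.array([1.0, 2.0, 3.0])
nom = jnp.array([21.0, 20.0, 22.0])
up = jnp.array([22.0, 22.0, 24.0])
down = jnp.array([20.0, 18.0, 20.0])
data = jnp.array([23.0, 23.0, 23.0])
pars = jnp.array([1.0, 0.5])  # (mu, alpha)

# gradient of the log-likelihood with respect to the nominal background yields
grad_nom = jax.grad(toy_logpdf, argnums=3)(pars, data, sig, nom, up, down)
print(grad_nom)
```

Because no yields are captured when the function is defined, "overriding" them amounts to passing different arrays, which is the behaviour one would hope to get from an in-place override of the model spec.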

Here's an example in which I hand-alter the nominal rates of one model to match those of another, yet their likelihoods don't match:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pyhf

pyhf.set_backend("jax")


def make_model(s, b, bup, bdown):
    return pyhf.simplemodels.correlated_background(s, b, bup, bdown)


s = [1, 2, 3]
b = [21, 20, 22]

bup = [22, 22, 24]
bdown = [20, 18, 20]

# the target model
model1 = make_model(s, b, bup, bdown)

# a skeleton model with zero yields, but the same up/down variations
model2 = make_model([0, 0, 0], [0, 0, 0], bup, bdown)

# overwrite the skeleton model's nominal rates with those of the target model
model2.main_model.nominal_rates = model1.main_model.nominal_rates

# Plotting:
plt.rc("figure", figsize=(6, 2), dpi=200, facecolor="w")

# observed main-model counts followed by the auxiliary data
data = [23, 23, 23] + model1.config.auxdata


def logpdf_on_grid(model, grid=(500, 500)):
    """Evaluate model.logpdf on [0, 10] x [0, 10] in the first two parameters."""
    x_range = np.linspace(0, 10, grid[0])
    y_range = np.linspace(0, 10, grid[1])
    xx, yy = np.meshgrid(x_range, y_range)
    xy = np.concatenate([xx.reshape(-1, 1), yy.reshape(-1, 1)], axis=1)

    # vectorise the scalar logpdf over all grid points
    @jax.vmap
    def logpdf(pars):
        return model.logpdf(pars, data)[0]

    # meshgrid returns arrays of shape (len(y_range), len(x_range))
    return xx, yy, logpdf(jnp.array(xy)).reshape(xx.shape)


def plot_model(model, ax, model_name=None):
    xx, yy, z = logpdf_on_grid(model)
    ax.contourf(xx, yy, z)
    if model_name is not None:
        ax.set_title(model_name)


def plot_difference(model1, model2, ax):
    xx, yy, z1 = logpdf_on_grid(model1)
    _, _, z2 = logpdf_on_grid(model2)
    ax.contourf(xx, yy, z2 - z1)
    ax.set_title("differences")


fig, axs = plt.subplots(1, 3)

plot_model(model1, axs[0], "original model")
plot_model(model2, axs[1], "manually set yields")
plot_difference(model1, model2, axs[2])
plt.tight_layout()
plt.show()
```

[Figure: logpdf contours in three panels, "original model", "manually set yields", and "differences"]
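One hypothesis worth checking (an assumption, not verified here against pyhf's source): if a modifier such as the `histosys` keeps its own copy of the nominal and up/down histograms from construction time, overwriting `main_model.nominal_rates` alone would leave the variations computed relative to the skeleton model's zero nominal. The background-only toy below uses hypothetical `ToyHistoSys` and `ToyModel` classes (not pyhf classes) with an assumed piecewise-linear interpolation to show how such cached state agrees at `alpha = 0` and diverges away from it.

```python
import numpy as np


class ToyHistoSys:
    """Hypothetical modifier that caches its up/down deltas when built."""

    def __init__(self, nom, up, down):
        nom, up, down = (np.asarray(v, dtype=float) for v in (nom, up, down))
        # deltas are frozen relative to the nominal seen at construction time
        self.delta_up = up - nom
        self.delta_down = nom - down

    def apply(self, alpha):
        # piecewise-linear interpolation (an assumption; pyhf's defaults differ)
        return alpha * (self.delta_up if alpha >= 0 else self.delta_down)


class ToyModel:
    """Hypothetical background-only model: nominal rates plus one histosys."""

    def __init__(self, nom, up, down):
        self.nominal_rates = np.asarray(nom, dtype=float)
        self.histosys = ToyHistoSys(nom, up, down)

    def expected(self, alpha):
        return self.nominal_rates + self.histosys.apply(alpha)


bup, bdown = [22, 22, 24], [20, 18, 20]
target = ToyModel([21, 20, 22], bup, bdown)
skeleton = ToyModel([0, 0, 0], bup, bdown)

# override only the nominal rates, mirroring the pyhf example above
skeleton.nominal_rates = target.nominal_rates

print(target.expected(0.0), skeleton.expected(0.0))  # agree at alpha = 0
print(target.expected(1.0), skeleton.expected(1.0))  # differ away from alpha = 0
```

If the hypothesis holds for pyhf, the difference between the two logpdfs should vanish wherever the `correlated_bkg_uncertainty` parameter is 0 and grow away from it.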




Labels: feat/enhancement (New feature or request), needs-triage (Needs a maintainer to categorize and assign)
