Skip to content
Merged
6 changes: 4 additions & 2 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,12 @@ def jacobian(f, vars=None):
def jacobian_diag(f, x):
idx = aet.arange(f.shape[0], dtype="int32")

def grad_ii(i):
def grad_ii(i, f, x):
return grad(f[i], x)[i]

return aesara.scan(grad_ii, sequences=[idx], n_steps=f.shape[0], name="jacobian_diag")[0]
return aesara.scan(
grad_ii, sequences=[idx], n_steps=f.shape[0], non_sequences=[f, x], name="jacobian_diag"
)[0]


@aesara.config.change_flags(compute_test_value="ignore")
Expand Down
2 changes: 1 addition & 1 deletion pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, name, model=None, vars=None, test_point=None):
if transform:
# We need to create and add an un-transformed version of
# each transformed variable
untrans_var = transform.backward(var)
untrans_var = transform.backward(v, var)
untrans_var.name = v.name
vars.append(untrans_var)
vars.append(var)
Expand Down
Loading