Closed
Description
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream
# Minimal reproducer: build a graph that adds Gaussian noise from a
# RandomStream to a 4-D tensor, then compile it with the JAX backend as a
# function that only performs a shared-variable update. Per the linked report,
# compiling/running this in mode="JAX" is expected to fail on the RandomType
# shared variable(s) involved -- confirm against the current PyTensor release.
pytensor_rng = RandomStream(1)  # random stream seeded with 1

# Shared buffer whose value the compiled function's `updates` replace.
batchshape = (25, 1, 28, 28)
inp_shared = pytensor.shared(np.zeros(shape=batchshape), name="inp_shared")

# Symbolic placeholder; `givens` below substitutes `inp_shared` for it.
inp = pt.tensor4("inp")
noise = pytensor_rng.normal(size=inp.shape, scale=1)
out = inp + noise

# No explicit inputs or outputs: the function exists only for its update.
updates = {inp_shared: out.reshape(inp_shared.shape)}
givens = {inp: inp_shared}
fn = pytensor.function(
    inputs=[],
    outputs=[],
    updates=updates,
    givens=givens,
    mode="JAX",
)
Originally reported on the PyMC Discourse forum: https://discourse.pymc.io/t/pytensor-cannot-handle-randomtype-sharedvariables-in-mode-jax/12135/