
Function with only updates after givens fails in JAX mode #314

Closed
@ricardoV94

Description


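import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

# Seeded random stream used to draw the additive noise
pytensor_rng = RandomStream(1)

batchshape = (25, 1, 28, 28)
# Shared variable holding the batch that the function updates in place
inp_shared = pytensor.shared(np.zeros(batchshape), name="inp_shared")

# Symbolic input, replaced by inp_shared through `givens` below
inp = pt.tensor4(name="inp")
# Add standard-normal noise with the same shape as the input
out = inp + pytensor_rng.normal(size=inp.shape, scale=1)

# The function's only effect is this update of the shared variable
updates = {inp_shared: out.reshape(inp_shared.shape)}

# Only updates (no inputs or outputs) combined with `givens`:
# this is the case that fails in JAX mode
fn = pytensor.function(
    inputs=[],
    outputs=[],
    updates=updates,
    givens={inp: inp_shared},
    mode="JAX",
)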

Reported in https://discourse.pymc.io/t/pytensor-cannot-handle-randomtype-sharedvariables-in-mode-jax/12135/
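For comparison, here is a minimal NumPy sketch of what each call of `fn` is expected to do, assuming `givens` simply substitutes `inp_shared` for `inp`: read the current batch, add standard-normal noise of the same shape, and write the result back. `reference_step` is a hypothetical helper, and NumPy's `default_rng` draws different numbers than PyTensor's `RandomStream`, so only shapes and summary statistics are comparable, not exact values.

```python
import numpy as np


def reference_step(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical NumPy reference for one call of `fn`:
    add noise drawn from N(0, 1) with the same shape as x."""
    noise = rng.normal(loc=0.0, scale=1.0, size=x.shape)
    # Mirrors out.reshape(inp_shared.shape): the shape is unchanged
    return (x + noise).reshape(x.shape)


batchshape = (25, 1, 28, 28)
state = np.zeros(batchshape)      # stands in for inp_shared
rng = np.random.default_rng(1)    # NumPy generator, not PyTensor's RandomStream

# Each call of fn should overwrite inp_shared with the noisy value
state = reference_step(state, rng)
print(state.shape)  # (25, 1, 28, 28)
```

Because the update feeds back into `inp_shared`, repeated calls accumulate noise, so after k calls the per-element standard deviation should be roughly sqrt(k).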

Labels: bug, jax
