Description
Now that #306 is merged, there are a couple of follow-ups we should do:
- Blockwise more Ops (everything in linalg? see the Cholesky sketch after this list)
- Dispatch vectorize for more Ops (see the dispatch sketch after this list):
  - Alloc (Blockwise improvements #532)
  - Shape (Vectorize dispatch for shape operations #454)
  - Subtensor (Blockwise improvements #532)
  - Arange (Blockwise improvements #532)
  - ExtractDiag (Blockwise improvements #532)
  - Assert
- Implement JAX and Numba dispatch:
  - Use JAX's vectorize (Support Blockwise in JAX backend #487; see the JAX sketch after this list)
  - Use the machinery developed in Add support for random Generators in Numba backend #691
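On the first point, here is a minimal sketch of what blockwising a linalg Op looks like with the machinery from #306; Cholesky and the explicit signature kwarg are just illustrative choices:

```python
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Cholesky

# Wrap the core matrix Op so it broadcasts over leading batch dims,
# following the gufunc-style signature convention
batched_cholesky = Blockwise(Cholesky(lower=True), signature="(m,m)->(m,m)")

x = pt.tensor3("x")  # a batch of square matrices
L = batched_cholesky(x)  # Cholesky factor of each matrix in the batch
```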
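On dispatching vectorize for more Ops: vectorize walks the graph node by node and looks up a per-Op dispatch; Ops without a registration fall back to being wrapped in Blockwise. A hedged sketch of what a dedicated dispatch for Shape could look like (the `_vectorize_node` registry name and import path are assumptions based on the #306 implementation, and whatever #454 lands may differ):

```python
import pytensor.tensor as pt
from pytensor.tensor.blockwise import _vectorize_node  # assumed location
from pytensor.tensor.shape import Shape

@_vectorize_node.register(Shape)
def vectorize_shape(op, node, batched_x):
    [core_x] = node.inputs
    batch_ndim = batched_x.type.ndim - core_x.type.ndim
    if batch_ndim == 0:
        # Input was not actually batched; rebuild the node unchanged
        return op.make_node(batched_x)
    # Every batch entry shares the same core shape, so we can slice the
    # batched shape instead of creating a useless Blockwise(Shape)
    core_shape = batched_x.shape[batch_ndim:]
    return pt.as_tensor_variable(core_shape).owner
```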
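And for the JAX backend, the natural route is jax.numpy.vectorize, which already implements gufunc-style batching. A rough sketch of the direction #487 could take, where make_core_jax_fn is a hypothetical helper standing in for compiling the core Op to a plain JAX callable:

```python
import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise

@jax_funcify.register(Blockwise)
def jax_funcify_Blockwise(op, node, **kwargs):
    # Hypothetical helper: compile op.core_op to a plain JAX callable
    core_fn = make_core_jax_fn(op, node)
    # jnp.vectorize loops/broadcasts over the leading batch dims
    # according to the Op's gufunc signature
    return jnp.vectorize(core_fn, signature=op.signature)
```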
Vectorize already works through OpFromGraph, which blockwises the wrapped graph as a single unit:

```python
import pytensor.tensor as pt
from pytensor.graph import vectorize
from pytensor.compile.builders import OpFromGraph

i = pt.scalar("i", dtype=int)
# The inner graph builds a length-i arange (ragged across batch entries)
# and immediately reduces it away
y_ = pt.sum(pt.arange(0, i))
y = OpFromGraph([i], [y_])(i)

new_i = pt.vector("new_i", dtype=int)
new_y = vectorize(y, {i: new_i})
new_y.eval({new_i: [1, 2, 3, 4]})  # [0, 1, 3, 6]
```
We could explore automatically wrapping such sequences (a blockwised Op with non-square core dims, followed by a reduction of those dims) in a blockwised OpFromGraph during rewrites, to support those cases.