Description
I use Python 3's typing features as much as possible. Unfortunately, jax's value hierarchy makes this a little challenging. Consider the following snippet:
```python
from typing import Any, NamedTuple

import jax.numpy as jp
from jax import lax, random

# Placeholder for the "mystery" array type this issue asks about.
ArrayType = Any


class Normal(NamedTuple):
    loc: ArrayType
    scale: ArrayType

    def sample(self, rng, sample_shape=()) -> ArrayType:
        # Broadcast loc and scale to a common batch shape, then draw samples.
        batch_shape = lax.broadcast_shapes(self.loc.shape, self.scale.shape)
        return self.loc + self.scale * random.normal(
            rng, shape=sample_shape + batch_shape)
```

I'd like to fill in the mystery `ArrayType`s with something like a made-up `jp.Array`, but AFAICT from the array class hierarchy, no existing type really fits. At first glance `jp.DeviceArray` looks like an eligible candidate, but there are also `ConcreteArray`, `ShapedArray`, and `UnshapedArray`. I'm not sure what the differences between them are, but some of them derive from `jax.core.AbstractValue`, while `DeviceArray` does not. If arrays can also live somewhere other than on-device, I'd like to avoid limiting my type signatures to on-device arrays only. To make matters more confusing, there also seem to be `_FilledConstant` and `DeviceConstant`:
```
In [10]: jp.ones((2, 3))
Out[10]:
_FilledConstant([[1., 1., 1.],
                 [1., 1., 1.]], dtype=float32)
```

All in all, it's not clear to me how these types fit together and how (if possible) to unify them. What's the appropriate type to use here? And if it does not yet exist, could we create such a type?
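As one possible way to "create such a type", assuming structural typing is acceptable, a `typing.Protocol` can name only the attributes the code actually uses, so unrelated array classes match it without sharing a base class. This is a hypothetical sketch, not a JAX API: `ArrayProtocol`, `OnDeviceArray`, `FilledConstant` and `batch_shape` are illustrative names, and `batch_shape` reimplements NumPy-style shape broadcasting in pure Python instead of calling `lax.broadcast_shapes`.

```python
from dataclasses import dataclass
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ArrayProtocol(Protocol):
    """Hypothetical structural array type: any object with .shape and .dtype.

    Membership depends on which attributes an object has, not on its base
    classes, so unrelated array classes can all satisfy it.
    """

    @property
    def shape(self) -> Tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


# Two unrelated stand-in classes (hypothetical, for illustration only),
# playing the roles of an on-device array and a lazily filled constant.
@dataclass
class OnDeviceArray:
    shape: Tuple[int, ...]
    dtype: str = "float32"


@dataclass
class FilledConstant:
    shape: Tuple[int, ...]
    fill_value: float
    dtype: str = "float32"


def batch_shape(x: ArrayProtocol, y: ArrayProtocol) -> Tuple[int, ...]:
    """NumPy-style broadcast of two shapes, written against the protocol only."""
    a, b = tuple(x.shape), tuple(y.shape)
    # Left-pad the shorter shape with 1s so both have the same rank.
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + a
    b = (1,) * (n - len(b)) + b
    out = []
    for da, db in zip(a, b):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"incompatible shapes {x.shape} and {y.shape}")
        # A size-1 dimension stretches to match the other one.
        out.append(db if da == 1 else da)
    return tuple(out)


print(batch_shape(OnDeviceArray((2, 1)), FilledConstant((3,), 1.0)))
```

The trade-off of this design is that a protocol only checks that the named attributes exist; it says nothing about which operations (such as `+` or `*`) an object supports unless those are added to the protocol as methods too.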