Handle implicit broadcasting correctly in RandomVariable vectorization #664

Merged 2 commits on Mar 13, 2024
25 changes: 22 additions & 3 deletions pytensor/tensor/random/op.py
@@ -8,7 +8,7 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
@@ -20,7 +20,10 @@
infer_static_shape,
)
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
from pytensor.tensor.random.utils import (
explicit_expand_dims,
normalize_size_param,
)
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst
@@ -387,10 +390,26 @@ def vectorize_random_variable(
# If size was provided originally and a new size hasn't been provided,
# we extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values

# Need to make the parameters' implicit broadcasting explicit
original_dist_params = node.inputs[3:]
old_size = node.inputs[1]
len_old_size = get_vector_length(old_size)

original_expanded_dist_params = explicit_expand_dims(
original_dist_params, op.ndims_params, len_old_size
)
# We call vectorize_graph to automatically handle any new explicit expand_dims
dist_params = vectorize_graph(
original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
)

if len_old_size and equal_computations([old_size], [size]):
bcasted_param = broadcast_params(dist_params, op.ndims_params)[0]
# If the original RV had a size variable and a new one has not been provided,
# we need to define the new size as the concatenation of the original size dimensions
# and the new ones implied by the broadcasted batch dimensions of the parameters.
# We use the first broadcasted parameter as the reference.
bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
if new_param_ndim >= 0:
new_size_dims = bcasted_param.shape[:new_param_ndim]
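
For reference, here is a hedged sketch of the size-extension path described in the comments above, mirroring the `test_vectorize_node` case added below; the concrete shapes and the zeros array used for `eval` are illustrative only:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph.replace import vectorize_node
from pytensor.tensor.random.basic import normal

vec = pt.tensor("vec", shape=(None,))
mat = pt.tensor("mat", shape=(None, None))

# Original RV with an explicit size of (3,).
node = normal(vec, size=(3,)).owner

# Replace mu with a matrix, but do not provide a new size.
new_inputs = node.inputs.copy()
new_inputs[3] = mat  # mu

vect_node = vectorize_node(node, *new_inputs)
# The new size is the original (3,) extended by the leading batch
# dimension implied by the batched mu.
new_size = tuple(vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)}))
assert new_size == (2, 3)
```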
30 changes: 29 additions & 1 deletion pytensor/tensor/random/utils.py
Expand Up @@ -13,7 +13,7 @@
from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import maximum
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes
from pytensor.tensor.variable import TensorVariable

@@ -121,6 +121,34 @@ def broadcast_params(params, ndims_params):
return bcast_params


def explicit_expand_dims(
params: Sequence[TensorVariable],
ndim_params: tuple[int],
size_length: int = 0,
) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""

batch_dims = [
param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
]

if size_length:
# NOTE: PyTensor currently treats a zero-length size as size=None, which is not what NumPy does
# See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims = size_length
else:
max_batch_dims = max(batch_dims)

new_params = []
for new_param, batch_dim in zip(params, batch_dims):
missing_dims = max_batch_dims - batch_dim
if missing_dims:
new_param = shape_padleft(new_param, missing_dims)
new_params.append(new_param)

return new_params


def normalize_size_param(
size: Optional[Union[int, np.ndarray, Variable, Sequence]],
) -> Variable:
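
As a quick illustration of the new `explicit_expand_dims` helper above, here is a hedged sketch (the shapes are illustrative): parameters with fewer batch dimensions are left-padded so that every parameter carries the same number of batch dimensions, and a nonzero `size_length` raises that target further.

```python
import pytensor.tensor as pt
from pytensor.tensor.random.utils import explicit_expand_dims

mu = pt.tensor("mu", shape=(2, 5))      # two batch dims over a scalar core parameter
sigma = pt.tensor("sigma", shape=(5,))  # one batch dim over a scalar core parameter

# Without a size, sigma is padded on the left up to mu's two batch dims.
new_mu, new_sigma = explicit_expand_dims([mu, sigma], (0, 0))
assert new_mu.type.shape == (2, 5)
assert new_sigma.type.shape == (1, 5)

# With a length-3 size, both parameters are padded up to three batch dims.
new_mu, new_sigma = explicit_expand_dims([mu, sigma], (0, 0), size_length=3)
assert new_mu.type.shape == (1, 2, 5)
assert new_sigma.type.shape == (1, 1, 5)
```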
60 changes: 43 additions & 17 deletions tests/tensor/random/test_op.py
@@ -13,13 +13,13 @@
from pytensor.tensor.type import all_dtypes, iscalar, tensor


@pytest.fixture(scope="module", autouse=True)
def set_pytensor_flags():
@pytest.fixture(scope="function", autouse=False)
def strict_test_value_flags():
with config.change_flags(cxx="", compute_test_value="raise"):
yield


def test_RandomVariable_basics():
def test_RandomVariable_basics(strict_test_value_flags):
str_res = str(
RandomVariable(
"normal",
@@ -95,7 +95,7 @@ def test_RandomVariable_basics():
grad(rv_out, [rv_node.inputs[0]])


def test_RandomVariable_bcast():
def test_RandomVariable_bcast(strict_test_value_flags):
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

mu = tensor(dtype=config.floatX, shape=(1, None, None))
@@ -125,7 +125,7 @@ def test_RandomVariable_bcast():
assert res.broadcastable == (True, False)


def test_RandomVariable_bcast_specify_shape():
def test_RandomVariable_bcast_specify_shape(strict_test_value_flags):
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

s1 = pt.as_tensor(1, dtype=np.int64)
@@ -146,7 +146,7 @@ def test_RandomVariable_bcast_specify_shape():
assert res.type.shape == (1, None, None, None, 1)


def test_RandomVariable_floatX():
def test_RandomVariable_floatX(strict_test_value_flags):
test_rv_op = RandomVariable(
"normal",
0,
@@ -172,14 +172,14 @@
(3, default_rng, np.random.default_rng(3)),
],
)
def test_random_maker_op(seed, maker_op, numpy_res):
def test_random_maker_op(strict_test_value_flags, seed, maker_op, numpy_res):
seed = pt.as_tensor_variable(seed)
z = function(inputs=[], outputs=[maker_op(seed)])()
aes_res = z[0]
assert maker_op.random_type.values_eq(aes_res, numpy_res)


def test_random_maker_ops_no_seed():
def test_random_maker_ops_no_seed(strict_test_value_flags):
# Testing the initialization when seed=None
# Since the internal states are randomly generated,
# we just check the output classes
@@ -192,7 +192,7 @@ def test_random_maker_ops_no_seed():
assert isinstance(aes_res, np.random.Generator)


def test_RandomVariable_incompatible_size():
def test_RandomVariable_incompatible_size(strict_test_value_flags):
rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
@@ -216,7 +216,6 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
return [dist_params[0].shape[-1]]


@config.change_flags(compute_test_value="off")
def test_multivariate_rv_infer_static_shape():
"""Test that infer shape for multivariate random variable works when a parameter must be broadcasted."""
mv_op = MultivariateRandomVariable()
@@ -244,23 +243,21 @@ def test_multivariate_rv_infer_static_shape():

def test_vectorize_node():
vec = tensor(shape=(None,))
vec.tag.test_value = [0, 0, 0]
mat = tensor(shape=(None, None))
mat.tag.test_value = [[0, 0, 0], [1, 1, 1]]

# Test without size
node = normal(vec).owner
new_inputs = node.inputs.copy()
new_inputs[3] = mat
new_inputs[3] = mat # mu
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.inputs[3] is mat

# Test with size, new size provided
node = normal(vec, size=(3,)).owner
new_inputs = node.inputs.copy()
new_inputs[1] = (2, 3)
new_inputs[3] = mat
new_inputs[1] = (2, 3) # size
new_inputs[3] = mat # mu
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert tuple(vect_node.inputs[1].eval()) == (2, 3)
@@ -269,8 +266,37 @@ def test_vectorize_node():
# Test with size, new size not provided
node = normal(vec, size=(3,)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = mat
new_inputs[3] = mat # mu
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.inputs[3] is mat
assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3)
assert tuple(
vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
) == (2, 3)

# Test parameter broadcasting
node = normal(vec).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)

# Test parameter broadcasting with non-expanding size
node = normal(vec, size=(5,)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)

# Test parameter broadcasting with expanding size
node = normal(vec, size=(2, 5)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 2, 5)