diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 597f68865d..6f7d8ad438 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -191,7 +191,7 @@ def extract_batch_shape(p, ps, n): return shape batch_shape = [ - s if b is False else constant(1, "int64") + s if not b else constant(1, "int64") for s, b in zip(shape[:-n], p.type.broadcastable[:-n]) ] return batch_shape diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index ef55d2bf75..7e56fc53e3 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -109,8 +109,13 @@ def __init__( def parse_bcast_and_shape(s): if isinstance(s, (bool, np.bool_)): return 1 if s else None - else: + elif isinstance(s, (int, np.int_)): + return int(s) + elif s is None: return s + raise ValueError( + f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}" + ) self.shape = tuple(parse_bcast_and_shape(s) for s in shape) self.dtype_specs() # error checking is done there diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index adc4177085..10da939d07 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -16,6 +16,7 @@ from pytensor.graph.op import get_test_value from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.tensor import ones, stack from pytensor.tensor.random.basic import ( _gamma, bernoulli, @@ -1465,3 +1466,12 @@ def test_rebuild(): assert y_new.type.shape == (100,) assert y_new.shape.eval({x_new: x_new_test}) == (100,) assert y_new.eval({x_new: x_new_test}).shape == (100,) + + +def test_categorical_join_p_static_shape(): + """Regression test against a bug caused by misreading a numpy.bool_""" + p = ones(3) / 3 + prob = stack([p, 1 - p], axis=-1) + assert prob.type.shape == (3, 2) + x = categorical(p=prob) + assert x.type.shape == (3,) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index aea96365e1..2c2b82d1b5 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -2046,8 +2046,14 @@ def test_mixed_ndim_error(self): def test_static_shape_inference(self): a = at.tensor(dtype="int8", shape=(2, 3)) b = at.tensor(dtype="int8", shape=(2, 5)) - assert at.join(1, a, b).type.shape == (2, 8) - assert at.join(-1, a, b).type.shape == (2, 8) + + res = at.join(1, a, b).type.shape + assert res == (2, 8) + assert all(isinstance(s, int) for s in res) + + res = at.join(-1, a, b).type.shape + assert res == (2, 8) + assert all(isinstance(s, int) for s in res) # Check early informative errors from static shape info with pytest.raises(ValueError, match="must match exactly"): @@ -2055,8 +2061,9 @@ def test_static_shape_inference(self): # Check partial inference d = at.tensor(dtype="int8", shape=(2, None)) - assert at.join(1, a, b, d).type.shape == (2, None) - return + res = at.join(1, a, b, d).type.shape + assert res == (2, None) + assert isinstance(res[0], int) def test_split_0elem(self): rng = np.random.default_rng(seed=utt.fetch_seed()) diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 4e9c456829..e80ee47637 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -267,6 +267,27 @@ def test_fixed_shape_basic(): assert t2.shape == (2, 4) +def test_shape_type_conversion(): + t1 = TensorType("float64", shape=np.array([3], dtype=int)) + assert t1.shape == (3,) + assert isinstance(t1.shape[0], int) + assert t1.broadcastable == (False,) + assert isinstance(t1.broadcastable[0], bool) + + t2 = TensorType("float64", broadcastable=np.array([True, False], dtype="bool")) + assert t2.shape == (1, None) + assert isinstance(t2.shape[0], int) + assert t2.broadcastable == (True, False) + assert isinstance(t2.broadcastable[0], bool) + assert isinstance(t2.broadcastable[1], bool) + + with pytest.raises( + ValueError, + match="TensorType broadcastable/shape must be a boolean, integer or None", + ): + TensorType("float64", shape=("1", "2")) + + def test_fixed_shape_clone(): t1 = TensorType("float64", (1,))