Skip to content

Commit e6eeb0c

Browse files
committed
Rename Elemwise.check_runtime_broadcast
1 parent bf648df commit e6eeb0c

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

tests/link/jax/test_elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from tests.tensor.test_elemwise import TestElemwise
1919

2020

21-
def test_elemwise_runtime_shape_error():
22-
TestElemwise.check_runtime_shapes_error(get_mode("JAX"))
21+
def test_elemwise_runtime_broadcast():
22+
TestElemwise.check_runtime_broadcast(get_mode("JAX"))
2323

2424

2525
def test_jax_Dimshuffle():

tests/link/numba/test_elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
122122

123123

124124
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
125-
def test_elemwise_runtime_shape_error():
126-
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
125+
def test_elemwise_runtime_broadcast():
126+
TestElemwise.check_runtime_broadcast(get_mode("NUMBA"))
127127

128128

129129
def test_elemwise_speed(benchmark):

tests/tensor/test_elemwise.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def test_input_dimensions_overflow(self):
751751
g(*[np.zeros(2**11, config.floatX) for i in range(6)])
752752

753753
@staticmethod
754-
def check_runtime_shapes_error(mode):
754+
def check_runtime_broadcast(mode):
755755
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
756756
x_v = matrix("x")
757757
m_v = vector("m")
@@ -777,15 +777,15 @@ def check_runtime_shapes_error(mode):
777777
with pytest.raises((ValueError, TypeError)):
778778
f(x, m)
779779

780-
def test_runtime_shapes_error_python(self):
781-
self.check_runtime_shapes_error(Mode(linker="py"))
780+
def test_runtime_broadcast_python(self):
781+
self.check_runtime_broadcast(Mode(linker="py"))
782782

783783
@pytest.mark.skipif(
784784
not pytensor.config.cxx,
785785
reason="G++ not available, so we need to skip this test.",
786786
)
787-
def test_runtime_shapes_error_c(self):
788-
self.check_runtime_shapes_error(Mode(linker="c"))
787+
def test_runtime_broadcast_c(self):
788+
self.check_runtime_broadcast(Mode(linker="c"))
789789

790790
def test_str(self):
791791
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)

0 commit comments

Comments
 (0)