3333 ((pt .gt , pt .ge ), "logcdf" , "logsf" , (0.5 , pt .random .normal (0 , 1 ))),
3434 ],
3535)
36- def test_continuous_rv_comparison (comparison_op , exp_logp_true , exp_logp_false , inputs ):
36+ def test_continuous_rv_comparison_bitwise (comparison_op , exp_logp_true , exp_logp_false , inputs ):
3737 for op in comparison_op :
3838 comp_x_rv = op (* inputs )
3939
@@ -48,6 +48,17 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false,
4848 assert np .isclose (logp_fn (0 ), getattr (ref_scipy , exp_logp_false )(0.5 ))
4949 assert np .isclose (logp_fn (1 ), getattr (ref_scipy , exp_logp_true )(0.5 ))
5050
51+ bitwise_rv = pt .bitwise_not (op (* inputs ))
52+ bitwise_vv = bitwise_rv .clone ()
53+
54+ logprob_not = logp (bitwise_rv , bitwise_vv )
55+ assert_no_rvs (logprob_not )
56+
57+ logp_fn_not = pytensor .function ([bitwise_vv ], logprob_not )
58+
59+ assert np .isclose (logp_fn_not (0 ), getattr (ref_scipy , exp_logp_true )(0.5 ))
60+ assert np .isclose (logp_fn_not (1 ), getattr (ref_scipy , exp_logp_false )(0.5 ))
61+
5162
5263@pytest .mark .parametrize (
5364 "comparison_op, exp_logp_true, exp_logp_false, inputs" ,
@@ -87,7 +98,7 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false,
8798 ),
8899 ],
89100)
90- def test_discrete_rv_comparison (inputs , comparison_op , exp_logp_true , exp_logp_false ):
101+ def test_discrete_rv_comparison_bitwise (inputs , comparison_op , exp_logp_true , exp_logp_false ):
91102 cens_x_rv = comparison_op (* inputs )
92103
93104 cens_x_vv = cens_x_rv .clone ()
@@ -100,6 +111,17 @@ def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_f
100111 assert np .isclose (logp_fn (1 ), exp_logp_true (3 ))
101112 assert np .isclose (logp_fn (0 ), exp_logp_false (3 ))
102113
114+ bitwise_rv = pt .bitwise_not (comparison_op (* inputs ))
115+ bitwise_vv = bitwise_rv .clone ()
116+
117+ logprob_not = logp (bitwise_rv , bitwise_vv )
118+ assert_no_rvs (logprob_not )
119+
120+ logp_fn_not = pytensor .function ([bitwise_vv ], logprob_not )
121+
122+ assert np .isclose (logp_fn_not (1 ), exp_logp_false (3 ))
123+ assert np .isclose (logp_fn_not (0 ), exp_logp_true (3 ))
124+
103125
104126def test_potentially_measurable_operand ():
105127 x_rv = pt .random .normal (2 )
0 commit comments