diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 9694a022e3..7fd02ca406 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -400,7 +400,7 @@ def local_exp_log(fgraph, node):
 
 
 @register_specialize
-@node_rewriter([exp, expm1])
+@node_rewriter([exp, expm1, softplus])
 def local_exp_log_nan_switch(fgraph, node):
     # Rewrites of the kind exp(log...(x)) that require a `nan` switch
     x = node.inputs[0]
@@ -453,6 +453,13 @@ def local_exp_log_nan_switch(fgraph, node):
         new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
         return [new_out]
 
+    # Case for softplus(log(x)) -> log1p(x)
+    if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Softplus):
+        x = x.owner.inputs[0]
+        old_out = node.outputs[0]
+        new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype))
+        return [new_out]
+
 
 @register_canonicalize
 @register_specialize
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 9a092663a9..cfd1265bad 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -2010,6 +2010,27 @@ def test_exp_softplus(self, exp_op):
             decimal=6,
         )
 
+    def test_softplus_log(self):
+        # softplus(log(x)) -> log1p(x)
+        data_valid = np.random.random((4, 3)).astype("float32") * 2
+        data_valid[0, 0] = 0  # edge case
+        data_invalid = data_valid - 2
+
+        x = fmatrix()
+        f = function([x], softplus(log(x)), mode=self.mode)
+        graph = f.maker.fgraph.toposort()
+        ops_graph = [
+            node
+            for node in graph
+            if isinstance(node.op, Elemwise)
+            and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus)
+        ]
+        assert len(ops_graph) == 0
+
+        expected = np.log1p(data_valid)
+        np.testing.assert_almost_equal(f(data_valid), expected)
+        assert np.all(np.isnan(f(data_invalid)))
+
 @pytest.mark.parametrize(
     ["nested_expression", "expected_switches"],
     [