local_add_neg_to_sub rewrite gives wrong results with negative constants #584

Closed
@ricardoV94

Reported in https://discourse.pymc.io/t/shape-issue-with-custom-logp-in-densitydist/13608/5?u=ricardov94

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
neg_const = pt.as_tensor(np.full((10,), -0.1))
out = neg_const + x 

pytensor.function([x], out)
<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>> <class 'TypeError'> The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(10,))). local_add_neg_to_sub
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_add_neg_to_sub
ERROR (pytensor.graph.rewriting.basic): node: Add([-0.1 -0.1 ... -0.1 -0.1], ExpandDims{axis=0}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/rewriting/basic.py", line 1968, in process_node
    fgraph.replace_all_validate_remove(  # type: ignore
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/features.py", line 626, in replace_all_validate_remove
    chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/features.py", line 571, in replace_all_validate
    fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/fg.py", line 508, in replace
    self.change_node_input(
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/fg.py", line 419, in change_node_input
    raise TypeError(
TypeError: The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(10,))).

@register_specialize
@node_rewriter([add])
def local_add_neg_to_sub(fgraph, node):
    """
    -x + y -> y - x
    x + (-y) -> x - y
    """
    # This rewrite is only registered during specialization, because the
    # `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization

    # Rewrite is only applicable when there are two inputs to add
    if node.op == add and len(node.inputs) == 2:
        # Look for pattern with either input order
        for first, second in (node.inputs, reversed(node.inputs)):
            if second.owner:
                if second.owner.op == neg:
                    pre_neg = second.owner.inputs[0]
                    new_out = sub(first, pre_neg)
                    return [new_out]

            # Check if it is a negative constant
            const = get_constant(second)
            if const is not None and const < 0:
                new_out = sub(first, np.abs(const))
                return [new_out]

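The helper get_constant calls get_unique_constant_value under the hood, which in this case returns a single scalar rather than the filled homogeneous vector. As a result, the replacement built by sub(first, np.abs(const)) loses the constant's shape (10,) and ends up with shape (1,), which is the type mismatch reported in the error above. The rewrite itself also expects only a scalar constant; otherwise it would have to use np.all in the const < 0 check.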
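The shape mismatch can be reproduced with plain NumPy, without PyTensor. The sketch below is illustrative only: `unique_value` is a hypothetical stand-in for the behaviour described above (collapsing a homogeneous constant to one scalar), not the actual PyTensor helper, and the arrays only mimic the shapes from the error message.

```python
import numpy as np


def unique_value(arr):
    """Hypothetical stand-in: return the single repeated value of a
    homogeneous array as a scalar, or None if the values differ."""
    flat = np.asarray(arr).ravel()
    if flat.size and np.all(flat == flat[0]):
        return flat[0]
    return None


# Plays the role of ExpandDims(x): a vector of shape (1,)
first = np.array([2.0])
# The homogeneous negative constant from the report: shape (10,)
const = np.full((10,), -0.1)

# What the original Add node computes: broadcasting gives shape (10,)
original = const + first

# What the rewrite effectively builds after the constant is collapsed
# to a scalar: the (10,) shape is lost and the result has shape (1,)
scalar = unique_value(const)
collapsed = first - np.abs(scalar)

# Keeping the full constant preserves the shape, and the sign check
# then has to be reduced explicitly with np.all
if np.all(const < 0):
    preserved = first - np.abs(const)

print(original.shape, collapsed.shape, preserved.shape)
```

The values agree in all three cases; only the collapsed variant changes the output shape, which is what `fgraph.replace` rejects. One possible direction, stated here only as an assumption, is for the rewrite to either skip non-scalar constants or subtract the absolute value of the full constant.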
