Skip to content

Deprecate redundant utilities for extracting constants #1046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pytensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# pytensor code, since this code may want to log some messages.
import logging
import sys
import warnings
from functools import singledispatch
from pathlib import Path
from typing import Any, NoReturn, Optional
Expand Down Expand Up @@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v):
If `v` is not some view of constant data, then raise a
`NotScalarConstantError`.
"""
# Is it necessary to test for presence of pytensor.sparse at runtime?
sparse = globals().get("sparse")
if sparse and isinstance(v.type, sparse.SparseTensorType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_underlying_scalar_constant_value(data)
return tensor.get_underlying_scalar_constant_value(v)
warnings.warn(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the logger instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? I've never seen deprecation warnings emitted via logging.

"get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
FutureWarning,
)
from pytensor.tensor.basic import get_underlying_scalar_constant_value

return get_underlying_scalar_constant_value(v)


# isort: off
Expand Down
9 changes: 6 additions & 3 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,7 @@ def try_to_copy_if_needed(var):
f" {i}. Since this input is only connected "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
f"{pytensor.get_underlying_scalar_constant(term)}."
f"{pytensor.get_underlying_scalar_constant_value(term)}."
)
raise ValueError(msg)

Expand Down Expand Up @@ -2157,6 +2157,9 @@ def _is_zero(x):
'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0.
"""
from pytensor.tensor import get_underlying_scalar_constant_value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need a local import here but not in the above function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the other is using it through tensor. This was not done on purpose but I prefer explicit imports and there is in fact a circular dependency here.

from pytensor.tensor.exceptions import NotScalarConstantError

if not hasattr(x, "type"):
return np.all(x == 0.0)
if isinstance(x.type, NullType):
Expand All @@ -2166,9 +2169,9 @@ def _is_zero(x):

no_constant_value = True
try:
constant_value = pytensor.get_underlying_scalar_constant(x)
constant_value = get_underlying_scalar_constant_value(x)
no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError:
except NotScalarConstantError:
pass

if no_constant_value:
Expand Down
6 changes: 3 additions & 3 deletions pytensor/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Split,
TensorFromScalar,
Tri,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
Expand Down Expand Up @@ -103,7 +103,7 @@ def join(axis, *tensors):
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_underlying_scalar_constant_value(axis)
constant_axis = get_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
Expand All @@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try:
constant_splits = np.array(
[
get_underlying_scalar_constant_value(splits[i])
get_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits))
]
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scan/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def wrap_into_list(x):
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
except NotScalarConstantError:
n_fixed_steps = None

Expand Down
42 changes: 26 additions & 16 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
Expand All @@ -71,7 +71,7 @@
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
from pytensor.tensor.variable import TensorConstant


list_opt_slice = [
Expand Down Expand Up @@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
all_ins = list(graph_inputs(op_outs))
for idx in range(op_info.n_seqs):
node_inp = node.inputs[idx + 1]
if (
isinstance(node_inp, TensorConstant)
and get_unique_constant_value(node_inp) is not None
):
if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None:
try:
# This works if input is a constant that has all entries
# equal
Expand Down Expand Up @@ -668,8 +665,10 @@ def inner_sitsot_only_last_step_used(
client = fgraph.clients[outer_var][0][0]
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1:
return True
return (
len(lst) == 1
and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1
)

return False

Expand Down Expand Up @@ -1344,10 +1343,17 @@ def scan_save_mem(fgraph, node):
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None
if isinstance(cf_slice[0], slice):
stop = pt.extract_constant(cf_slice[0].stop)
stop = get_scalar_constant_value(
cf_slice[0].stop, raise_not_constant=False
)
else:
stop = pt.extract_constant(cf_slice[0]) + 1
if stop == maxsize or stop == pt.extract_constant(length):
stop = (
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
stop = None
else:
# there is a **gotcha** here ! Namely, scan returns an
Expand Down Expand Up @@ -1451,9 +1457,13 @@ def scan_save_mem(fgraph, node):
cf_slice = get_canonical_form_slice(this_slice[0], length)

if isinstance(cf_slice[0], slice):
start = pt.extract_constant(cf_slice[0].start)
start = pt.get_scalar_constant_value(
cf_slice[0].start, raise_not_constant=False
)
else:
start = pt.extract_constant(cf_slice[0])
start = pt.get_scalar_constant_value(
cf_slice[0], raise_not_constant=False
)

if start == 0 or store_steps[i] == 0:
store_steps[i] = 0
Expand Down Expand Up @@ -1628,7 +1638,7 @@ def scan_save_mem(fgraph, node):
# 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if pt.extract_constant(node_ins[0]) == 0:
if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
return False

# Do not call make_node for test_value
Expand Down Expand Up @@ -1965,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes):

nsteps = node.inputs[0]
try:
nsteps = int(get_underlying_scalar_constant_value(nsteps))
nsteps = int(get_scalar_constant_value(nsteps))
except NotScalarConstantError:
pass

rep_nsteps = rep_node.inputs[0]
try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
pass

Expand Down
4 changes: 4 additions & 0 deletions pytensor/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,10 @@ def __str__(self):
def __repr__(self):
    # Debug representation intentionally mirrors the human-readable form
    # produced by __str__ (defined just above in this class).
    return str(self)

@property
def unique_value(self):
    # Sparse constants never report a single repeated scalar value.
    # Returning None makes callers that probe `constant.unique_value`
    # (e.g. remove_constants_and_unused_inputs_scan, which checks
    # `node_inp.unique_value is not None`) skip sparse constants,
    # mirroring the TensorConstant.unique_value API.
    return None


SparseTensorType.variable_type = SparseVariable
SparseTensorType.constant_type = SparseConstant
Expand Down
Loading
Loading