155 changes: 95 additions & 60 deletions onnxscript/optimizer/_constant_folding.py
@@ -9,7 +9,7 @@
import logging
import math
import typing
from typing import Any, Callable, Collection, Iterable, Sequence, Union
from typing import Any, Callable, Iterable, Sequence, Union

import numpy as np
import onnx
@@ -34,6 +34,13 @@
}
)

# Ops to always fold regardless of the input size limit, as long as the node
# is the single consumer of its large input tensors
_DEFAULT_ALWAYS_FOLD_OPS = frozenset(
{
("", "Transpose"),
}
)
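For orientation, entries are `(domain, op_type)` tuples, with the empty string standing for the default ONNX domain (the same normalization the removed `__init__` code applied to `"ai.onnx"`). A minimal standalone sketch of the lookup convention:

```python
_DEFAULT_ALWAYS_FOLD_OPS = frozenset({("", "Transpose")})

# "" denotes the default ONNX domain, so lookups must use the normalized form:
assert ("", "Transpose") in _DEFAULT_ALWAYS_FOLD_OPS
assert ("ai.onnx", "Transpose") not in _DEFAULT_ALWAYS_FOLD_OPS
```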

logger = logging.getLogger(__name__)

@@ -332,12 +339,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None:
return None


def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None:
if type is not None:
# TODO: merge types
value.type = type


def _get_input_element_type(node: ir.Node, index: int) -> int:
input = _get_input(node, index)
if input is not None and input.type is not None:
@@ -899,9 +900,10 @@ class FoldConstantsPass(ir.passes.InPlacePass):
shape_inference: Whether to perform shape inference.
input_size_limit: Maximum size of input tensors to fold.
output_size_limit: Maximum size of output tensors to fold.
always_fold_ops: Collection of op types that should always be folded.
For ops from the default opset, only op_type is neede (e.g. "Transpose"),
otherwise specify the domain with ``{domain}::{op_type}``.
should_fold: An optional callback that decides whether a node should be folded.
It should return (1) True to always fold the node, (2) False to never fold
the node, or (3) None to apply the default folding rules.
"""

def __init__(
@@ -910,18 +912,12 @@ def __init__(
shape_inference: bool,
input_size_limit: int,
output_size_limit: int,
always_fold_ops: Collection[str] = frozenset(["Transpose"]),
should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
) -> None:
self.shape_inference = shape_inference
self.input_size_limit = input_size_limit
self.output_size_limit = output_size_limit
ops = []
for name in always_fold_ops:
domain, op_type = name.split("::", 1) if "::" in name else ("", name)
if domain == "ai.onnx":
domain = ""
ops.append((domain, op_type))
self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
self.should_fold = should_fold

self._opset_imports: dict[str, int] = {}
self._counts: dict[str, int] = {}
@@ -961,7 +957,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
input_data = {k: v for k, v in input_data.items() if v is not None}
if any(t is None for t in input_types.values()):
logger.debug(
"Skipping shape inference for node %s due to missing input type.",
"Skipping shape inference for node %r due to missing input type.",
node.name,
)
else:
@@ -987,7 +983,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
except Exception as e:
logger.debug(
"Skipping shape inference for node %s due to exception: %s",
"Skipping shape inference for node %r due to exception: %s",
node.name,
e,
)
@@ -1072,62 +1068,102 @@ def process_node(self, node: ir.Node) -> Replacement | None:
output = [output]
return Replacement(output, context.nodes)

if _is_control_flow_op(node) or _is_non_deterministic_op(node):
if _is_control_flow_op(node):
logger.info(
"Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
node.name,
node.domain,
node.op_type,
)

return None

if _is_non_deterministic_op(node):
logger.info(
"Skipping constant folding for non-deterministic op %r (%s::%s)",
node.name,
node.domain,
node.op_type,
)
return None

if _is_onnx_op(node, "Constant"):
_process_constant_node(node)
return None

if any(x.is_graph_input() for x in node.inputs if x is not None):
# Do not fold nodes that consume graph inputs, to preserve the graph signature
logger.info(
"Skipping constant folding for node %r because one of its inputs is a graph input, which is preserved to keep the graph signature",
node.name,
)
return None

# Ensure all node inputs are constants
if any(x.const_value is None for x in node.inputs if x is not None):
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Skipping constant folding for node %s because it has non-constant inputs",
node,
[x.name for x in node.inputs if x is not None],
)
return None

input_tensors = [x.const_value if x is not None else None for x in node.inputs]
if any(
tensor.size > self.input_size_limit
for tensor in input_tensors
if tensor is not None
):
if (node.domain, node.op_type) in self.always_fold_ops and all(
len(input.consumers()) == 1 for input in node.inputs if input is not None
):
# If the op is in always_fold_ops and all inputs are used only by this node,
# we can still fold it even if the input size exceeds the limit.
logger.debug(
"Folding large constant for node %s because it is in the always_fold_ops list",
node,
should_fold = self.should_fold(node)

if should_fold is False:
logger.info(
"Skipping constant folding for node %r because should_fold returned False",
node.name,
)
return None

elif should_fold is None:
# Use default rules to decide whether to fold the node:
# - ConstantOfShape is preserved to avoid increasing model size unnecessarily
# - If any input tensor's size exceeds the input_size_limit, skip folding the node
if _is_onnx_op(node, "ConstantOfShape"):
logger.info(
"Skipping constant folding for node %r because ConstantOfShape is preserved by default",
node.name,
)
return None
else:
# Skip folding large tensors
if logger.isEnabledFor(logging.DEBUG):
input_sizes = [
tensor.size for tensor in input_tensors if tensor is not None
]
logger.debug(
"Skipping constant folding for node %s due to large input size: %s",
node,
input_sizes,
)
return None

input_tensors = [x.const_value if x is not None else None for x in node.inputs]
large_inputs = [
tensor is not None and tensor.size > self.input_size_limit
for tensor in input_tensors
]
if any(large_inputs):
# Decide whether to fold large constants
assert len(node.inputs) == len(large_inputs)
if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all(
len(input.consumers()) == 1 or (not is_large)
for input, is_large in zip(node.inputs, large_inputs)
if input is not None
):
# If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node,
# we can still fold it even if the input size exceeds the limit
pass
else:
# Skip folding large tensors
if logger.isEnabledFor(logging.INFO):
input_sizes = [
tensor.size for tensor in input_tensors if tensor is not None
]
logger.info(
"Skipping constant folding for node %r due to large input sizes: %s",
node,
input_sizes,
)
return None
else:
logger.info(
"Constant folding node %r because should_fold returned True",
node.name,
)

input_values = [_get_numpy_value(x) for x in node.inputs]

def convert(av):
if av.type == ir.AttributeType.TENSOR:
return ir.serde.serialize_tensor(av.value)
return av.value

# TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node
attr_values = {name: convert(attr) for name, attr in node.attributes.items()}
outputs = _reference_evaluator.evaluate(
node.domain, node.op_type, version, *input_values, **attr_values
@@ -1137,7 +1173,7 @@ def convert(av):
return None
if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
replacement = self.new_constant(node, outputs)
if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
if replacement is None:
return None
return Replacement(replacement.outputs, [replacement])
else:
@@ -1245,7 +1281,7 @@ def fold_constants(
onnx_shape_inference: bool = False,
input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
always_fold_ops: Collection[str] = frozenset(["Transpose"]),
should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
) -> FoldConstantsResult:
"""
Applies constant folding optimization to the model.
@@ -1260,10 +1296,9 @@
output_size_limit: The maximum size of output tensors
that can be stored after constant folding. Defaults to
`DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
always_fold_ops: A collection of op types that should always be folded,
regardless of their input or output sizes. For ops from the default opset,
only op_type is neede (e.g. "Transpose"), otherwise specify the domain
with ``{domain}::{op_type}``.
should_fold: An optional callback that decides whether a node should be folded:
return True to always fold the node, False to never fold it, or None to apply
the default folding rules. Defaults to a function that always returns None.

Returns:
An instance of `FoldConstantsResult`.
@@ -1273,6 +1308,6 @@
shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
always_fold_ops=always_fold_ops,
should_fold=should_fold,
)
return folder_pass(model) # type: ignore[return-value]
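A hedged usage sketch of the updated entry point (the model path and domain policy are illustrative assumptions; `fold_constants` is the function defined above, imported here from its private module):

```python
from onnxscript import ir
from onnxscript.optimizer._constant_folding import fold_constants

def no_custom_domain_folding(node: ir.Node) -> bool | None:
    # Never fold nodes from non-default domains; defer to defaults otherwise.
    return False if node.domain not in ("", "ai.onnx") else None

model = ir.load("model.onnx")  # hypothetical input path
result = fold_constants(model, should_fold=no_custom_domain_folding)
```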
27 changes: 27 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -581,6 +581,33 @@ def test_transpose_is_always_folded(self):
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant"])

def test_node_is_folded_if_specified_as_should_fold(self):
model_text = """
<ir_version: 10, opset_import: [ "" : 20]>
agraph (float[M, 256] x) => (float[42, 42] z)
<int64[2] w = {42, 42}>
{
z = ConstantOfShape <value: tensor = int64[1] {1}> (w)
}
"""
model = ir.from_onnx_text(model_text)

# ConstantOfShape is not folded by default
optimized = self._fold(model)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["ConstantOfShape"])

# But ConstantOfShape is folded when specified in should_fold
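# (The lambda below returns a plain bool, so it yields False, i.e. never fold,
# for every other op type; that is fine here because this graph contains
# only the ConstantOfShape node.)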
optimized = self._fold(
model, should_fold=lambda node: node.op_type == "ConstantOfShape"
)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant"])
np.testing.assert_array_equal(
optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
np.ones((42, 42), dtype=np.int64),
)

def test_multi_graph_identity_output_preserves_output_name(self):
model = """
<ir_version: 10, opset_import: ["" : 20]>
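For callers migrating from `always_fold_ops`, a rough callback equivalent (a sketch only; beyond the hypothetical op set, note that returning True bypasses the size and single-consumer checks the old `always_fold_ops` path still enforced, so this is more aggressive than the previous behavior):

```python
from onnxscript import ir

_ALWAYS = frozenset({("", "Transpose"), ("com.example", "MyOp")})  # hypothetical op set

def always_fold(node: ir.Node) -> bool | None:
    # True forces folding regardless of input sizes; None keeps the default rules.
    return True if (node.domain, node.op_type) in _ALWAYS else None
```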