Merged
Changes from 13 commits
99 changes: 62 additions & 37 deletions onnxscript/optimizer/_constant_folding.py
@@ -317,12 +317,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None:
return None


def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None:
if type is not None:
# TODO: merge types
value.type = type


def _get_input_element_type(node: ir.Node, index: int) -> int:
input = _get_input(node, index)
if input is not None and input.type is not None:
@@ -834,9 +828,11 @@ class FoldConstantsPass(ir.passes.InPlacePass):
shape_inference: Whether to perform shape inference.
input_size_limit: Maximum size of input tensors to fold.
output_size_limit: Maximum size of output tensors to fold.
always_fold_ops: Collection of op types that should always be folded.
always_fold_ops: Collection of op types that should always be folded, unless
folding the operator would duplicate model weights and allow_bloat is False.
For ops from the default opset, only the op_type is needed (e.g. "Transpose");
otherwise, specify the domain with ``{domain}::{op_type}``.
allow_bloat: If False, the pass will not fold ops whose folding would duplicate model weights.
"""

def __init__(
@@ -846,6 +842,7 @@ def __init__(
input_size_limit: int,
output_size_limit: int,
always_fold_ops: Collection[str] = frozenset(["Transpose"]),
allow_bloat: bool = False,
) -> None:
self.shape_inference = shape_inference
self.input_size_limit = input_size_limit
@@ -857,6 +854,7 @@ def __init__(
domain = ""
ops.append((domain, op_type))
self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
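# Illustration (editor's note, not part of this change): the normalization above
# maps "Transpose" to ("", "Transpose") and a hypothetical "custom.domain::MyOp"
# to ("custom.domain", "MyOp"), matching the {domain}::{op_type} form described
# in the docstring.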
self.allow_bloat = allow_bloat

self._opset_imports: dict[str, int] = {}
self._counts: dict[str, int] = {}
@@ -896,7 +894,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
input_data = {k: v for k, v in input_data.items() if v is not None}
if any(t is None for t in input_types.values()):
logger.debug(
"Skipping shape inference for node %s due to missing input type.",
"Skipping shape inference for node %r due to missing input type.",
node.name,
)
else:
Expand All @@ -922,7 +920,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
except Exception as e:
logger.debug(
"Skipping shape inference for node %s due to exception: %s",
"Skipping shape inference for node %r due to exception: %s",
node.name,
e,
)
@@ -1007,54 +1005,80 @@ def process_node(self, node: ir.Node) -> Replacement | None:
output = [output]
return Replacement(output, context.nodes)

if _is_control_flow_op(node) or _is_non_deterministic_op(node):
if _is_control_flow_op(node):
logger.info(
"Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
node.name,
node.domain,
node.op_type,
)

return None

if _is_non_deterministic_op(node):
return None

if _is_onnx_op(node, "Constant"):
_process_constant_node(node)
return None

if any(x.is_graph_input() for x in node.inputs if x is not None):
# Do not fold any graph inputs to preserve graph signature
logger.info(
"Skipping constant folding for node %r because it is graph input to preserve graph signature",
node.name,
)
return None

# Ensure all node inputs are constants
if any(x.const_value is None for x in node.inputs if x is not None):
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Skipping constant folding for node %s because it has non-constant inputs",
node,
[x.name for x in node.inputs if x is not None],
)
return None

input_tensors = [x.const_value if x is not None else None for x in node.inputs]
if any(
tensor.size > self.input_size_limit
for tensor in input_tensors
if tensor is not None
):
if (node.domain, node.op_type) in self.always_fold_ops and all(
len(input.consumers()) == 1 for input in node.inputs if input is not None
):
# If the op is in always_fold_ops and all inputs are used only by this node,
# we can still fold it even if the input size exceeds the limit.
logger.debug(
"Folding large constant for node %s because it is in the always_fold_ops list",
node,
)
else:
# Skip folding large tensors
if logger.isEnabledFor(logging.DEBUG):
input_sizes = [
tensor.size for tensor in input_tensors if tensor is not None
]
logger.debug(
"Skipping constant folding for node %s due to large input size: %s",
node,
input_sizes,
)
return None
large_inputs = [
tensor is not None and tensor.size > self.input_size_limit
for tensor in input_tensors
]
if any(large_inputs):

def log_large_inputs():
if logger.isEnabledFor(logging.INFO):
input_sizes = [
tensor.size for tensor in input_tensors if tensor is not None
]
logger.info(
"Skipping constant folding for node %r due to large input sizes: %s",
node,
input_sizes,
)

# Decide whether to fold large constants
if self.allow_bloat:
# If allow_bloat is True, we can fold large constants
if (node.domain, node.op_type) in self.always_fold_ops:
logger.debug(
"Folding large constant for node %r because it is in the always_fold_ops list and allow_bloat is True",
node,
)
else:
log_large_inputs()
return None
else:
assert len(node.inputs) == len(large_inputs)
if (node.domain, node.op_type) in self.always_fold_ops and all(
len(input.consumers()) == 1 or (not is_large)
for input, is_large in zip(node.inputs, large_inputs)
if input is not None
):
# If the op is in always_fold_ops and all large inputs are used only by this node,
# we can still fold it even if the input size exceeds the limit.
logger.debug(
"Folding large constant for node %r because it is in the always_fold_ops list",
node,
)
else:
log_large_inputs()
return None
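# Editor's illustration (not part of this change): with allow_bloat=False, a large
# initializer that feeds only this node (e.g. a weight consumed solely by a Transpose)
# can still be folded, because the folded result replaces the only use and the original
# tensor becomes dead. If the same initializer also feeds another node, folding here
# would keep both the original and the folded copy alive, roughly doubling that
# weight's footprint, so the fold is skipped.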

input_values = [_get_numpy_value(x) for x in node.inputs]

@@ -1072,11 +1096,12 @@ def convert(av):
return None
if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
replacement = self.new_constant(node, outputs)
if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
if replacement is None:
return None
return Replacement(replacement.outputs, [replacement])
else:
logger.warning(
# TODO(justinchuby): Enable folding of multiple outputs when allow_bloat is True
logger.info(
"Skipping constant folding for op %s with multiple outputs.", node.op_type
)
return None