From c491f3fa5b5470d770875bd4391616b8b1227e41 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 15:46:55 -0700
Subject: [PATCH 01/20] Fold constant of shapes?

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 55fb8759d4..77ad0d4efc 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1014,14 +1014,17 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             _process_constant_node(node)
             return None

+        if _is_onnx_op(node, "ConstantOfShape"):
+            return None
+
         if any(x.is_graph_input() for x in node.inputs if x is not None):
             # Do not fold any graph inputs to preserve graph signature
             return None

         # Ensure all node inputs are constants
         if any(x.const_value is None for x in node.inputs if x is not None):
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
+            if logger.isEnabledFor(logging.INFO):
+                logger.info(
                     "Skipping constant folding for node %s because it has non-constant inputs",
                     node,
                     [x.name for x in node.inputs if x is not None],
@@ -1039,17 +1042,17 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         ):
             # If the op is in always_fold_ops and all inputs are used only by this node,
             # we can still fold it even if the input size exceeds the limit.
-            logger.debug(
+            logger.info(
                 "Folding large constant for node %s because it is in the always_fold_ops list",
                 node,
             )
         else:
             # Skip folding large tensors
-            if logger.isEnabledFor(logging.DEBUG):
+            if logger.isEnabledFor(logging.INFO):
                 input_sizes = [
                     tensor.size for tensor in input_tensors if tensor is not None
                 ]
-                logger.debug(
+                logger.info(
                     "Skipping constant folding for node %s due to large input size: %s",
                     node,
                     input_sizes,
@@ -1072,11 +1075,11 @@ def convert(av):
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
             replacement = self.new_constant(node, outputs)
-            if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
+            if replacement is None:
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
-            logger.warning(
+            logger.info(
                 "Skipping constant folding for op %s with multiple outputs.", node.op_type
             )
             return None

From dddd923e9756253338b09301527367826c27211d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 16:11:37 -0700
Subject: [PATCH 02/20] Update constant folding behavior for large tensors

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 107 +++++++++++++++-------
 1 file changed, 75 insertions(+), 32 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 77ad0d4efc..6d777a6060 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -317,12 +317,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None:
     return None


-def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None:
-    if type is not None:
-        # TODO: merge types
-        value.type = type
-
-
 def _get_input_element_type(node: ir.Node, index: int) -> int:
     input = _get_input(node, index)
     if input is not None and input.type is not None:
@@ -834,9 +828,11 @@ class FoldConstantsPass(ir.passes.InPlacePass):
         shape_inference: Whether to perform shape inference.
         input_size_limit: Maximum size of input tensors to fold.
         output_size_limit: Maximum size of output tensors to fold.
-        always_fold_ops: Collection of op types that should always be folded.
+        always_fold_ops: Collection of op types that should always be folded, unless
+            folding the operator will duplicate model weights and allow_bloat is False.
             For ops from the default opset, only op_type is neede (e.g. "Transpose"),
             otherwise specify the domain with ``{domain}::{op_type}``.
+        allow_bloat: If False, the pass will not fold ops that will duplicate model weights.
     """

     def __init__(
@@ -846,6 +842,7 @@ def __init__(
         input_size_limit: int,
         output_size_limit: int,
         always_fold_ops: Collection[str] = frozenset(["Transpose"]),
+        allow_bloat: bool = False,
     ) -> None:
         self.shape_inference = shape_inference
         self.input_size_limit = input_size_limit
@@ -857,6 +854,7 @@ def __init__(
                 domain = ""
             ops.append((domain, op_type))
         self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)
+        self.allow_bloat = allow_bloat
         self._opset_imports: dict[str, int] = {}
         self._counts: dict[str, int] = {}

@@ -896,7 +894,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
             input_data = {k: v for k, v in input_data.items() if v is not None}
             if any(t is None for t in input_types.values()):
                 logger.debug(
-                    "Skipping shape inference for node %s due to missing input type.",
+                    "Skipping shape inference for node %r due to missing input type.",
                     node.name,
                 )
             else:
@@ -922,7 +920,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                     output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
         except Exception as e:
             logger.debug(
-                "Skipping shape inference for node %s due to exception: %s",
+                "Skipping shape inference for node %r due to exception: %s",
                 node.name,
                 e,
             )
@@ -1007,57 +1005,102 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                 output = [output]
             return Replacement(output, context.nodes)

-        if _is_control_flow_op(node) or _is_non_deterministic_op(node):
+        if _is_control_flow_op(node):
+            logger.info(
+                "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet",
+                node.name,
+                node.domain,
+                node.op_type,
+            )
+
+            return None
+
+        if _is_non_deterministic_op(node):
+            logger.info(
+                "Skipping constant folding for non-deterministic op %r (%s::%s)",
+                node.name,
+                node.domain,
+                node.op_type,
+            )
             return None

         if _is_onnx_op(node, "Constant"):
             _process_constant_node(node)
             return None

-        if _is_onnx_op(node, "ConstantOfShape"):
+        if (
+            _is_onnx_op(node, "ConstantOfShape")
+            and "ConstantOfShape" not in self.always_fold_ops
+        ):
+            logger.info(
+                "Skipping constant folding for ConstantOfShape node %r because it is considered a constant "
+                "and is not in the always_fold_ops list",
+                node.name,
+            )
             return None

         if any(x.is_graph_input() for x in node.inputs if x is not None):
-            # Do not fold any graph inputs to preserve graph signature
+            logger.info(
+                "Skipping constant folding for node %r because it is graph input to preserve graph signature",
+                node,
+            )
             return None

         # Ensure all node inputs are constants
         if any(x.const_value is None for x in node.inputs if x is not None):
-            if logger.isEnabledFor(logging.INFO):
-                logger.info(
-                    "Skipping constant folding for node %s because it has non-constant inputs",
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "Skipping constant folding for node %r because it has non-constant inputs",
                     node,
                     [x.name for x in node.inputs if x is not None],
                 )
             return None

-        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
-        if any(
-            tensor.size > self.input_size_limit
+        input_tensors = [x.const_value for x in node.inputs if x is not None]
+        large_inputs = [
+            tensor is not None and tensor.size > self.input_size_limit
             for tensor in input_tensors
-            if tensor is not None
-        ):
-            if (node.domain, node.op_type) in self.always_fold_ops and all(
-                len(input.consumers()) == 1 for input in node.inputs if input is not None
-            ):
-                # If the op is in always_fold_ops and all inputs are used only by this node,
-                # we can still fold it even if the input size exceeds the limit.
-                logger.info(
-                    "Folding large constant for node %s because it is in the always_fold_ops list",
-                    node,
-                )
-            else:
+        ]
+        if any(large_inputs):
+
+            def log_large_inputs():
                 # Skip folding large tensors
                 if logger.isEnabledFor(logging.INFO):
                     input_sizes = [
                         tensor.size for tensor in input_tensors if tensor is not None
                     ]
                     logger.info(
-                        "Skipping constant folding for node %s due to large input size: %s",
+                        "Skipping constant folding for node %r due to large input sizes: %s",
                         node,
                         input_sizes,
                     )
-                return None
+
+            # Decide whether to fold large constants
+            if self.allow_bloat:
+                # If allow_bloat is True, we can fold large constants
+                if (node.domain, node.op_type) in self.always_fold_ops:
+                    logger.info(
+                        "Folding large constant for node %r because it is in the always_fold_ops list",
+                        node,
+                    )
+                else:
+                    log_large_inputs()
+                    return None
+            else:
+                if (node.domain, node.op_type) in self.always_fold_ops and all(
+                    len(input.consumers()) == 1 or (not is_large)
+                    for input, is_large in zip(node.inputs, large_inputs)
+                    if input is not None
+                ):
+                    # If the op is in always_fold_ops and all large inputs are used only by this node,
+                    # we can still fold it even if the input size exceeds the limit.
+                    logger.info(
+                        "Folding large constant for node %r because it is in the always_fold_ops list",
+                        node,
+                    )
+                else:
+                    log_large_inputs()
+                    return None

         input_values = [_get_numpy_value(x) for x in node.inputs]

From e69fdf2af76aa30028d8815a05593306e92b25b5 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 16:14:14 -0700
Subject: [PATCH 03/20] docs

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 6d777a6060..1cb18f99cd 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -824,6 +824,10 @@ def merge_dims(dim1, dim2):
 class FoldConstantsPass(ir.passes.InPlacePass):
     """A pass that folds constant expressions in the model.

+    .. NOTE::
+        ``ConstantOfShape`` will not be replaced unless explicitly specified in
+        ``always_fold_ops``.
+
     Attributes:
         shape_inference: Whether to perform shape inference.
        input_size_limit: Maximum size of input tensors to fold.
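Editor's note: the `__init__` shown in PATCH 02 normalizes every `always_fold_ops` entry into a `(domain, op_type)` pair, which is why later patches compare against tuples such as `("", "ConstantOfShape")`. Below is a minimal standalone sketch of that normalization; the helper name and the example op strings are illustrative, and only the loop body is taken verbatim from the diff above.

```python
# Standalone sketch (not part of the patch series) of the always_fold_ops
# normalization performed in FoldConstantsPass.__init__ in PATCH 02:
# each name is split into a (domain, op_type) pair, and the default
# "ai.onnx" domain is mapped to the empty string.
from typing import Collection


def normalize_always_fold_ops(always_fold_ops: Collection[str]) -> frozenset[tuple[str, str]]:
    ops = []
    for name in always_fold_ops:
        domain, op_type = name.split("::", 1) if "::" in name else ("", name)
        if domain == "ai.onnx":
            domain = ""
        ops.append((domain, op_type))
    return frozenset(ops)


# "Transpose" and "ai.onnx::Transpose" both land in the default domain "";
# a custom op keeps its domain. The op names here are only examples.
print(normalize_always_fold_ops(["Transpose", "ai.onnx::Transpose", "com.microsoft::FusedMatMul"]))
# frozenset({('', 'Transpose'), ('com.microsoft', 'FusedMatMul')})
```

Because membership is stored as tuples, the bare-string check `"ConstantOfShape" not in self.always_fold_ops` added above can never match; PATCH 04 below fixes exactly that by testing `("", "ConstantOfShape")` instead.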
From 0c235df371f25282ff02430fce8017a8e1c5f47c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 16:33:15 -0700
Subject: [PATCH 04/20] Fix

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 1cb18f99cd..2e4831f33e 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1034,7 +1034,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         if (
             _is_onnx_op(node, "ConstantOfShape")
-            and "ConstantOfShape" not in self.always_fold_ops
+            and ("", "ConstantOfShape") not in self.always_fold_ops
         ):
             logger.info(
                 "Skipping constant folding for ConstantOfShape node %r because it is considered a constant "
                 "and is not in the always_fold_ops list",
@@ -1091,9 +1091,10 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     log_large_inputs()
                     return None
             else:
+                non_none_inputs = [input for input in node.inputs if input is not None]
                 if (node.domain, node.op_type) in self.always_fold_ops and all(
                     len(input.consumers()) == 1 or (not is_large)
-                    for input, is_large in zip(node.inputs, large_inputs)
+                    for input, is_large in zip(non_none_inputs, large_inputs)
                     if input is not None
                 ):
                     # If the op is in always_fold_ops and all large inputs are used only by this node,

From 1503b686772129512363327a381b83502fe531f2 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 16:36:50 -0700
Subject: [PATCH 05/20] Fix

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 2e4831f33e..a26f42320b 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1060,7 +1060,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                 )
             return None

-        input_tensors = [x.const_value for x in node.inputs if x is not None]
+        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
         large_inputs = [
             tensor is not None and tensor.size > self.input_size_limit
             for tensor in input_tensors
@@ -1091,10 +1091,9 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     log_large_inputs()
                     return None
             else:
-                non_none_inputs = [input for input in node.inputs if input is not None]
                 if (node.domain, node.op_type) in self.always_fold_ops and all(
                     len(input.consumers()) == 1 or (not is_large)
-                    for input, is_large in zip(non_none_inputs, large_inputs)
+                    for input, is_large in zip(node.inputs, large_inputs)
                     if input is not None
                 ):
                     # If the op is in always_fold_ops and all large inputs are used only by this node,

From 85b36fce9418ca4f4248612882086c59ea75357b Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 16:38:33 -0700
Subject: [PATCH 06/20] Update onnxscript/optimizer/_constant_folding.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 onnxscript/optimizer/_constant_folding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index a26f42320b..50699145d5 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1046,7 +1046,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         if any(x.is_graph_input() for x in node.inputs if x is not None):
             logger.info(
                 "Skipping constant folding for node %r because it is graph input to preserve graph signature",
-                node,
+                node.name,
             )
             return None

From a386851aaee5b3b32bbfc321996c55e85739a51c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Aug 2025 16:39:52 -0700
Subject: [PATCH 07/20] assert

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 50699145d5..b3e12e97f6 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1091,6 +1091,7 @@ def log_large_inputs():
                     log_large_inputs()
                     return None
             else:
+                assert len(node.inputs) == len(large_inputs)
                 if (node.domain, node.op_type) in self.always_fold_ops and all(
                     len(input.consumers()) == 1 or (not is_large)
                     for input, is_large in zip(node.inputs, large_inputs)

From a6bc0c82812c2647098e0b916be968c65bc20355 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 15 Aug 2025 12:55:33 -0700
Subject: [PATCH 08/20] Constant inputs

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index b3e12e97f6..d99585f1af 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1052,12 +1052,6 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         # Ensure all node inputs are constants
         if any(x.const_value is None for x in node.inputs if x is not None):
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "Skipping constant folding for node %r because it has non-constant inputs",
-                    node,
-                    [x.name for x in node.inputs if x is not None],
-                )
             return None

         input_tensors = [x.const_value if x is not None else None for x in node.inputs]
@@ -1084,7 +1078,7 @@ def log_large_inputs():
                 # If allow_bloat is True, we can fold large constants
                 if (node.domain, node.op_type) in self.always_fold_ops:
                     logger.info(
-                        "Folding large constant for node %r because it is in the always_fold_ops list",
+                        "Folding large constant for node %r because it is in the always_fold_ops list and allow_bloat is True",
                         node,
                     )
                 else:

From 2622ffa45ca6cae9d88563b1f13edc7fa82827e4 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 15 Aug 2025 12:57:35 -0700
Subject: [PATCH 09/20] fix

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index d99585f1af..f11576ab2c 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1121,6 +1121,7 @@ def convert(av):
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
+            # TODO(justinchuby): Enable folding of multiple outputs when allow_bloat is True
             logger.info(
                 "Skipping constant folding for op %s with multiple outputs.", node.op_type
             )
             return None

From f824a51cc45419bb9ca47d2458347927eea0048c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 15 Aug 2025 12:59:31 -0700
Subject: [PATCH 10/20] deterministic

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index f11576ab2c..1dda95f02e 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1020,12 +1020,6 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             return None

         if _is_non_deterministic_op(node):
-            logger.info(
-                "Skipping constant folding for non-deterministic op %r (%s::%s)",
-                node.name,
-                node.domain,
-                node.op_type,
-            )
             return None

         if _is_onnx_op(node, "Constant"):

From b576318f80cb34f023e112b6b4fde5fc7073730d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 15 Aug 2025 13:01:42 -0700
Subject: [PATCH 11/20] debug

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 1dda95f02e..11dd4bc151 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1071,7 +1071,7 @@ def log_large_inputs():
             if self.allow_bloat:
                 # If allow_bloat is True, we can fold large constants
                 if (node.domain, node.op_type) in self.always_fold_ops:
-                    logger.info(
+                    logger.debug(
                         "Folding large constant for node %r because it is in the always_fold_ops list and allow_bloat is True",
                         node,
                     )
@@ -1087,7 +1087,7 @@ def log_large_inputs():
                 ):
                     # If the op is in always_fold_ops and all large inputs are used only by this node,
                     # we can still fold it even if the input size exceeds the limit.
-                    logger.info(
+                    logger.debug(
                         "Folding large constant for node %r because it is in the always_fold_ops list",
                         node,
                     )

From f12b982c831ece77aa7ddc3b7a7cec4488e80d6f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 15 Aug 2025 15:11:35 -0700
Subject: [PATCH 12/20] Remove special casing for constofshape

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 11dd4bc151..b78bd22bd4 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -824,10 +824,6 @@ def merge_dims(dim1, dim2):
 class FoldConstantsPass(ir.passes.InPlacePass):
     """A pass that folds constant expressions in the model.

-    .. NOTE::
-        ``ConstantOfShape`` will not be replaced unless explicitly specified in
-        ``always_fold_ops``.
-
     Attributes:
         shape_inference: Whether to perform shape inference.
         input_size_limit: Maximum size of input tensors to fold.
@@ -1026,17 +1022,6 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             _process_constant_node(node)
             return None

-        if (
-            _is_onnx_op(node, "ConstantOfShape")
-            and ("", "ConstantOfShape") not in self.always_fold_ops
-        ):
-            logger.info(
-                "Skipping constant folding for ConstantOfShape node %r because it is considered a constant "
-                "and is not in the always_fold_ops list",
-                node.name,
-            )
-            return None
-
         if any(x.is_graph_input() for x in node.inputs if x is not None):
             logger.info(
                 "Skipping constant folding for node %r because it is graph input to preserve graph signature",

From fafba24f4ef19036f12035aeb3b24a6e14dba3b6 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 16:50:05 -0700
Subject: [PATCH 13/20] should_fold

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 116 ++++++++++------------
 1 file changed, 55 insertions(+), 61 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 21f7d6d3ad..7bc7e4f991 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -34,6 +34,11 @@
     }
 )

+_DEFAULT_ALWAYS_FOLD_OPS = frozenset(
+    {
+        "Transpose",
+    }
+)

 logger = logging.getLogger(__name__)

@@ -893,11 +898,10 @@ class FoldConstantsPass(ir.passes.InPlacePass):
         shape_inference: Whether to perform shape inference.
         input_size_limit: Maximum size of input tensors to fold.
         output_size_limit: Maximum size of output tensors to fold.
-        always_fold_ops: Collection of op types that should always be folded, unless
-            folding the operator will duplicate model weights and allow_bloat is False.
-            For ops from the default opset, only op_type is neede (e.g. "Transpose"),
-            otherwise specify the domain with ``{domain}::{op_type}``.
-        allow_bloat: If False, the pass will not fold ops that will duplicate model weights.
+        should_fold: An optional function that takes a node and returns True if
+            the node should be considered for folding.
+            The function should return (1) True to always fold the node, (2) False to
+            never fold the node, (3) None to use the default rules.
""" def __init__( @@ -906,20 +910,12 @@ def __init__( shape_inference: bool, input_size_limit: int, output_size_limit: int, - always_fold_ops: Collection[str] = frozenset(["Transpose"]), - allow_bloat: bool = False, + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, ) -> None: self.shape_inference = shape_inference self.input_size_limit = input_size_limit self.output_size_limit = output_size_limit - ops = [] - for name in always_fold_ops: - domain, op_type = name.split("::", 1) if "::" in name else ("", name) - if domain == "ai.onnx": - domain = "" - ops.append((domain, op_type)) - self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops) - self.allow_bloat = allow_bloat + self.should_fold = should_fold self._opset_imports: dict[str, int] = {} self._counts: dict[str, int] = {} @@ -1098,52 +1094,51 @@ def process_node(self, node: ir.Node) -> Replacement | None: if any(x.const_value is None for x in node.inputs if x is not None): return None - input_tensors = [x.const_value if x is not None else None for x in node.inputs] - large_inputs = [ - tensor is not None and tensor.size > self.input_size_limit - for tensor in input_tensors - ] - if any(large_inputs): - - def log_large_inputs(): - # Skip folding large tensors - if logger.isEnabledFor(logging.INFO): - input_sizes = [ - tensor.size for tensor in input_tensors if tensor is not None - ] - logger.info( - "Skipping constant folding for node %r due to large input sizes: %s", - node, - input_sizes, - ) - - # Decide whether to fold large constants - if self.allow_bloat: - # If allow_bloat is True, we can fold large constants - if (node.domain, node.op_type) in self.always_fold_ops: - logger.debug( - "Folding large constant for node %r because it is in the always_fold_ops list and allow_bloat is True", - node, - ) - else: - log_large_inputs() - return None - else: + should_fold = self.should_fold(node) + + if should_fold is False: + logger.info( + "Skipping constant folding for node %r because should_fold returned False", + node.name, + ) + return None + + elif should_fold is None: + # Use default rules to decide whether to fold the node: + # If the any tensor input size exceeds the input_size_limit, skip folding the node. + input_tensors = [x.const_value if x is not None else None for x in node.inputs] + large_inputs = [ + tensor is not None and tensor.size > self.input_size_limit + for tensor in input_tensors + ] + if any(large_inputs): + # Decide whether to fold large constants assert len(node.inputs) == len(large_inputs) - if (node.domain, node.op_type) in self.always_fold_ops and all( + if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all( len(input.consumers()) == 1 or (not is_large) for input, is_large in zip(node.inputs, large_inputs) if input is not None ): - # If the op is in always_fold_ops and all large inputs are used only by this node, + # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node, # we can still fold it even if the input size exceeds the limit. 
-                    logger.debug(
-                        "Folding large constant for node %r because it is in the always_fold_ops list",
-                        node,
-                    )
+                    pass
                 else:
-                    log_large_inputs()
+                    # Skip folding large tensors
+                    if logger.isEnabledFor(logging.INFO):
+                        input_sizes = [
+                            tensor.size for tensor in input_tensors if tensor is not None
+                        ]
+                        logger.info(
+                            "Skipping constant folding for node %r due to large input sizes: %s",
+                            node,
+                            input_sizes,
+                        )
                     return None
+        else:
+            logger.debug(
+                "Constant folding node %r because should_fold returned True",
+                node.name,
+            )

         input_values = [_get_numpy_value(x) for x in node.inputs]

@@ -1152,6 +1147,7 @@ def convert(av):
             return ir.serde.serialize_tensor(av.value)
         return av.value

+        # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node
         attr_values = {name: convert(attr) for name, attr in node.attributes.items()}
         outputs = _reference_evaluator.evaluate(
             node.domain, node.op_type, version, *input_values, **attr_values
@@ -1165,8 +1161,7 @@ def convert(av):
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
-            # TODO(justinchuby): Enable folding of multiple outputs when allow_bloat is True
-            logger.info(
+            logger.warning(
                 "Skipping constant folding for op %s with multiple outputs.", node.op_type
             )
             return None
@@ -1270,7 +1265,7 @@ def fold_constants(
     onnx_shape_inference: bool = False,
     input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
-    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
+    should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
 ) -> FoldConstantsResult:
     """
     Applies constant folding optimization to the model.
@@ -1285,10 +1280,9 @@ def fold_constants(
         output_size_limit: The maximum size of output tensors
             that can be stored after constant folding. Defaults to
             `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
-        always_fold_ops: A collection of op types that should always be folded,
-            regardless of their input or output sizes. For ops from the default opset,
-            only op_type is neede (e.g. "Transpose"), otherwise specify the domain
-            with ``{domain}::{op_type}``.
+        should_fold: An optional function that takes a node and returns True if
+            the node should be considered for folding, False if it should not be folded,
+            or None to use the default rules. Defaults to a function that always returns None.

     Returns:
         An instance of `FoldConstantsResult`.
@@ -1298,6 +1292,6 @@ def fold_constants(
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
         output_size_limit=output_size_limit,
-        always_fold_ops=always_fold_ops,
+        should_fold=should_fold,
     )
     return folder_pass(model)  # type: ignore[return-value]

From bdc64a809cc1c855d60d4e5e7998d0ded1441327 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 16:53:26 -0700
Subject: [PATCH 14/20] should_fold

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 7bc7e4f991..d8a714841f 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1135,7 +1135,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             )
             return None
         else:
-            logger.debug(
+            logger.info(
                 "Constant folding node %r because should_fold returned True",
                 node.name,
             )

From f0e124b00d2725a1e5598fd7519fd94fbdad4879 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 16:54:24 -0700
Subject: [PATCH 15/20] _is_non_deterministic_op

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index d8a714841f..fa67ade679 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1077,6 +1077,12 @@ def process_node(self, node: ir.Node) -> Replacement | None:
             return None

         if _is_non_deterministic_op(node):
+            logger.info(
+                "Skipping constant folding for non-deterministic op %r (%s::%s)",
+                node.name,
+                node.domain,
+                node.op_type,
+            )
             return None

         if _is_onnx_op(node, "Constant"):

From 7449c2a77333ded530d0ba2a899ac7c9621371db Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 17:07:54 -0700
Subject: [PATCH 16/20] test

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index fa67ade679..7fa972b26a 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -9,7 +9,7 @@
 import logging
 import math
 import typing
-from typing import Any, Callable, Collection, Iterable, Sequence, Union
+from typing import Any, Callable, Iterable, Sequence, Union

 import numpy as np
 import onnx
@@ -34,9 +34,11 @@
     }
 )

+# A list of ops to always fold regardless of their input size limits, as long as
+# they are the single consumer of the large input tensors
 _DEFAULT_ALWAYS_FOLD_OPS = frozenset(
     {
-        "Transpose",
+        ("", "Transpose"),
     }
 )

From 2c93163a521dcda1b59a443981ac2ce17644b09e Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 28 Aug 2025 17:22:17 -0700
Subject: [PATCH 17/20] test

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py     | 12 +++++++--
 .../optimizer/_constant_folding_test.py       | 27 +++++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 7fa972b26a..c34a44e6d6 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1113,7 +1113,15 @@ def process_node(self, node: ir.Node) -> Replacement | None:
         elif should_fold is None:
             # Use default rules to decide whether to fold the node:
-            # If any tensor input size exceeds the input_size_limit, skip folding the node.
+            # - ConstantOfShape is preserved to avoid increasing model size unnecessarily
+            # - If any tensor input size exceeds the input_size_limit, skip folding the node
+            if _is_onnx_op(node, "ConstantOfShape"):
+                logger.info(
+                    "Skipping constant folding for node %r because ConstantOfShape is preserved by default",
+                    node.name,
+                )
+                return None
+
             input_tensors = [x.const_value if x is not None else None for x in node.inputs]
             large_inputs = [
                 tensor is not None and tensor.size > self.input_size_limit
                 for tensor in input_tensors
@@ -1128,7 +1136,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     if input is not None
                 ):
                     # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node,
-                    # we can still fold it even if the input size exceeds the limit.
+                    # we can still fold it even if the input size exceeds the limit
                     pass
                 else:
                     # Skip folding large tensors
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index 8c05fbc0a4..852ef51152 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -581,6 +581,33 @@ def test_transpose_is_always_folded(self):
         ops = [node.op_type for node in optimized.graph]
         self.assertEqual(ops, ["Constant"])

+    def test_node_is_folded_if_specified_as_should_fold(self):
+        model_text = """
+
+            agraph (float[M, 256] x) => (float[42, 42] z)
+
+            {
+                z = ConstantOfShape (w)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+
+        # ConstantOfShape is not folded by default
+        optimized = self._fold(model)
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["ConstantOfShape"])
+
+        # But ConstantOfShape is folded when specified in should_fold
+        optimized = self._fold(
+            model, should_fold=lambda node: node.op_type == "ConstantOfShape"
+        )
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Constant"])
+        np.testing.assert_array_equal(
+            optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
+            np.ones((42, 42), dtype=np.int64),
+        )
+
     def test_multi_graph_identity_output_preserves_output_name(self):
         model = """

From a53dcd1b9583d235f43aaba49a843c09e1d28cdd Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 3 Sep 2025 12:05:51 -0700
Subject: [PATCH 18/20] update

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index c34a44e6d6..fb60879741 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -902,8 +902,8 @@ class FoldConstantsPass(ir.passes.InPlacePass):
         output_size_limit: Maximum size of output tensors to fold.
         should_fold: An optional function that takes a node and returns True if
             the node should be considered for folding.
-            The function should return (1) True to always fold the node, (2) False to
-            never fold the node, (3) None to use the default rules.
+            The function should return (1) True to fold the node, (2) False to
+            avoid folding the node, (3) None to use the default rules.
""" def __init__( From 9922bae04d35abe755ad8294a1018a5ef4925989 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:06:55 -0700 Subject: [PATCH 19/20] docs Signed-off-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index fb60879741..5f34e430dc 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -902,8 +902,8 @@ class FoldConstantsPass(ir.passes.InPlacePass): output_size_limit: Maximum size of output tensors to fold. should_fold: An optional function that takes a node and returns True if the node should be considered for folding. - The function should return (1) True to fold the node, (2) False to - avoid folding the node, (3) None to use the default rules. + The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. """ def __init__( From b99fe8896029f532c1b1b8aa9c2ff4bd87b89d36 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:13:08 -0700 Subject: [PATCH 20/20] fix Signed-off-by: Justin Chu --- onnxscript/optimizer/_constant_folding_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 852ef51152..6b2557551e 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -599,7 +599,7 @@ def test_node_is_folded_if_specified_as_should_fold(self): # But ConstantOfShape is folded when specified in should_fold optimized = self._fold( - model, should_fold=lambda node: node.op_type == "ConstantOfShape" + model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None ) ops = [node.op_type for node in optimized.graph] self.assertEqual(ops, ["Constant"])