@@ -36,8 +36,8 @@ def _get_optimizers():
36
36
return _optimizers
37
37
38
38
39
- def optimize_graph (graph ):
40
- """ Optimize graph, return optimized graph. No throw. """
39
+ def optimize_graph (graph , catch_errors = True ):
40
+ """ Optimize graph, return optimized graph. No throw if catch_errors is true """
41
41
logger = logging .getLogger (__name__ )
42
42
logger .info ("Optimizing ONNX model" )
43
43
@@ -47,17 +47,21 @@ def optimize_graph(graph):
47
47
while continue_flag :
48
48
continue_flag = False
49
49
for name , factory in opts .items ():
50
- try :
51
- logger .verbose ("Apply %s" , name )
52
- current = copy .deepcopy (graph )
50
+ logger .verbose ("Apply %s" , name )
51
+ if catch_errors :
52
+ try :
53
+ current = copy .deepcopy (graph )
54
+ opt = factory ()
55
+ graph = opt .optimize (current ) or graph
56
+ continue_flag = continue_flag or opt .graph_been_opt
57
+ except Exception : # pylint: disable=broad-except
58
+ # if current optimizer fails, continue with other optimizers
59
+ logger .warning ("Failed to apply %s" , name , exc_info = 1 )
60
+ else :
53
61
opt = factory ()
54
- graph = opt .optimize (current ) or graph
62
+ graph = opt .optimize (graph )
55
63
continue_flag = continue_flag or opt .graph_been_opt
56
64
57
- except Exception : # pylint: disable=broad-except
58
- # if current optimizer fails, continue with other optimizers
59
- logger .warning ("Failed to apply %s" , name , exc_info = 1 )
60
-
61
65
try :
62
66
graph .topological_sort (graph .get_nodes ())
63
67
except Exception : # pylint: disable=broad-except
0 commit comments