diff --git a/tf2onnx/graph_matcher.py b/tf2onnx/graph_matcher.py index 1ac9e437f..ae50eca0f 100644 --- a/tf2onnx/graph_matcher.py +++ b/tf2onnx/graph_matcher.py @@ -50,6 +50,7 @@ def __init__(self, op_type, name=None, inputs=None): input_pattern if isinstance(input_pattern, OpTypePattern) else OpTypePattern(input_pattern) for input_pattern in inputs ] + self.op_type_set = set(op_type.split('|')) if op_type else set() @property def op_type(self): @@ -154,7 +155,7 @@ def _is_op_type_same(op, pattern): if pattern.op_type == "*": return True - if op.type in pattern.op_type.split('|'): + if op.type in pattern.op_type_set: return True return False