diff --git a/utils/check_tf_ops.py b/utils/check_tf_ops.py deleted file mode 100644 index a3b9593bb209..000000000000 --- a/utils/check_tf_ops.py +++ /dev/null @@ -1,101 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -import os - -from tensorflow.core.protobuf.saved_model_pb2 import SavedModel - - -# All paths are set with the intent you should run this script from the root of the repo with the command -# python utils/check_copies.py -REPO_PATH = "." - -# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model) -INTERNAL_OPS = [ - "Assert", - "AssignVariableOp", - "EmptyTensorList", - "MergeV2Checkpoints", - "ReadVariableOp", - "ResourceGather", - "RestoreV2", - "SaveV2", - "ShardedFilename", - "StatefulPartitionedCall", - "StaticRegexFullMatch", - "VarHandleOp", -] - - -def onnx_compliance(saved_model_path, strict, opset): - saved_model = SavedModel() - onnx_ops = [] - - with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f: - onnx_opsets = json.load(f)["opsets"] - - for i in range(1, opset + 1): - onnx_ops.extend(onnx_opsets[str(i)]) - - with open(saved_model_path, "rb") as f: - saved_model.ParseFromString(f.read()) - - model_op_names = set() - - # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs) - for meta_graph in saved_model.meta_graphs: - # Add operations in the graph definition - model_op_names.update(node.op for node in meta_graph.graph_def.node) - - # Go through the functions in the graph definition - for func in meta_graph.graph_def.library.function: - # Add operations in each function - model_op_names.update(node.op for node in func.node_def) - - # Convert to list, sorted if you want - model_op_names = sorted(model_op_names) - incompatible_ops = [] - - for op in model_op_names: - if op not in onnx_ops and op not in INTERNAL_OPS: - incompatible_ops.append(op) - - if strict and len(incompatible_ops) > 0: - raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops) - elif len(incompatible_ops) > 0: - print(f"Found the following incompatible ops for the opset {opset}:") - print(*incompatible_ops, sep="\n") - else: - print(f"The saved model {saved_model_path} can properly be converted with ONNX.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).") - parser.add_argument( - "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested." - ) - parser.add_argument( - "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model." - ) - parser.add_argument( - "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)" - ) - args = parser.parse_args() - - if args.framework == "onnx": - onnx_compliance(args.saved_model_path, args.strict, args.opset)