From 67be27bb6787edc96bbd9739193b50d2d543621c Mon Sep 17 00:00:00 2001
From: Tom Wildenhain
Date: Tue, 20 Jul 2021 13:20:19 -0700
Subject: [PATCH] Add pretrained model tests for tfjs

Signed-off-by: Tom Wildenhain
---
 .../azure_pipelines/pretrained_model_test.yml |  1 +
 .../templates/pretrained_model_test.yml       |  2 +-
 tests/run_pretrained_models.py                | 44 ++++++++++++----
 tests/run_pretrained_models.yaml              | 50 +++++++++++++++++++
 tests/run_tfjs.js                             | 19 ++++++-
 tf2onnx/tfjs_utils.py                         | 28 ++++++++---
 tf2onnx/tfonnx.py                             |  3 +-
 7 files changed, 125 insertions(+), 22 deletions(-)

diff --git a/ci_build/azure_pipelines/pretrained_model_test.yml b/ci_build/azure_pipelines/pretrained_model_test.yml
index 5adeffb1b..ec04629ef 100644
--- a/ci_build/azure_pipelines/pretrained_model_test.yml
+++ b/ci_build/azure_pipelines/pretrained_model_test.yml
@@ -6,6 +6,7 @@ jobs:
       python_versions: ['3.7']
       tf_versions: ['2.4.1']
       skip_tflite_tests: 'False'
+      skip_tfjs_tests: 'False'
      skip_tf_tests: 'True'
     job:
       steps:
diff --git a/ci_build/azure_pipelines/templates/pretrained_model_test.yml b/ci_build/azure_pipelines/templates/pretrained_model_test.yml
index b67a9e8da..a0971d9f2 100644
--- a/ci_build/azure_pipelines/templates/pretrained_model_test.yml
+++ b/ci_build/azure_pipelines/templates/pretrained_model_test.yml
@@ -6,6 +6,6 @@ steps:
       status=0
       # TODO: fix unity model path
       # python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config tests/unity.yaml || status=$?
-      python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --config tests/run_pretrained_models.yaml || status=$?
+      python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --skip_tfjs_tests $CI_SKIP_TFJS_TESTS --config tests/run_pretrained_models.yaml || status=$?
       exit $status
     displayName: 'Test Pre-trained Model'
diff --git a/tests/run_pretrained_models.py b/tests/run_pretrained_models.py
index ae1254ded..d082e6e86 100644
--- a/tests/run_pretrained_models.py
+++ b/tests/run_pretrained_models.py
@@ -45,6 +45,7 @@ from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
 from tf2onnx.graph import ExternalTensorStorage
+from tfjs_runner import run_tfjs
 
 logger = logging.getLogger("run_pretrained")
 
@@ -251,6 +252,10 @@ def download_model(self):
         elif self.model_type == 'tflite':
             fname = self.local
             dir_name = fname.replace(".tflite", "") + "_dir"
+        elif self.model_type == 'tfjs':
+            ftype = 'tgz'
+            fname = 'model.tar.gz'
+            dir_name = "_".join(url.split("/")[5:-3]) + "_dir"
         dir_name = os.path.join(cache_dir, dir_name)
         os.makedirs(dir_name, exist_ok=True)
         fpath = os.path.join(dir_name, fname)
@@ -303,7 +308,8 @@ def run_tensorflow(self, sess, inputs):
         return result
 
     def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
-                const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None):
+                const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None,
+                tfjs_path=None):
         """Convert graph to tensorflow."""
         if extra_opset is None:
             extra_opset = []
@@ -314,7 +320,7 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
                               input_names=input_names, output_names=self.output_names,
                               const_node_values=const_node_values,
                               initialized_tables=initialized_tables, tflite_path=tflite_path, dequantize=self.dequantize,
-                              tensors_to_rename=tensors_to_rename)
+                              tensors_to_rename=tensors_to_rename, tfjs_path=tfjs_path)
 
     def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_storage=None):
         """Run test against onnxruntime backend."""
@@ -375,6 +381,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
         initialized_tables = {}
         outputs = self.output_names
         tflite_path = None
+        tfjs_path = None
         to_rename = {}
         if self.model_type in ["checkpoint"]:
             graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
@@ -394,6 +401,9 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
         elif self.model_type in ["tflite"]:
             tflite_path = model_path
             graph_def = None
+        elif self.model_type in ["tfjs"]:
+            tfjs_path = model_path
+            graph_def = None
         else:
             graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)
 
@@ -434,6 +444,16 @@ def run_tflite():
                 logger.info("TFLite perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
             logger.info("TFLite OK")
 
+        if tfjs_path is not None:
+            inputs = {}
+            for k in input_names:
+                v = self.input_names[k]
+                inputs[k] = self.make_input(v)
+            if not self.skip_tensorflow:
+                logger.info("Running TFJS")
+                tf_results = run_tfjs(tfjs_path, inputs, dir_name)
+                logger.info("TFJS OK")
+
         if not self.run_tf_frozen:
             inputs = {}
             for k in input_names:
@@ -465,7 +485,6 @@ def run_tflite():
                 logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
             logger.info("TensorFlow OK")
 
-        shape_override = {}
         const_node_values = None
         tf_graph = None
 
@@ -497,10 +516,6 @@ def run_tflite():
                 else:
                     inputs[k] = self.make_input(v).astype(expected_dtype)
 
-        if self.force_input_shape:
-            for k, v in inputs.items():
-                shape_override[k] = list(v.shape)
-
         # run the model with tensorflow
         if self.skip_tensorflow:
             logger.info("TensorFlow SKIPPED")
@@ -526,11 +541,15 @@ def run_tflite():
         else:
             try:
                 # convert model to onnx
+                if self.force_input_shape:
+                    shape_override = {k: list(v.shape) for k, v in inputs.items()}
+                else:
+                    shape_override = None
                 onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
                                           shape_override=shape_override, input_names=inputs.keys(),
                                           const_node_values=const_node_values,
                                           initialized_tables=initialized_tables, tflite_path=tflite_path,
-                                          tensors_to_rename=to_rename)
+                                          tensors_to_rename=to_rename, tfjs_path=tfjs_path)
                 onnx_graph = optimizer.optimize_graph(onnx_graph)
                 print("ONNX", onnx_graph.dump_node_statistics())
                 external_tensor_storage = ExternalTensorStorage() if self.large_model else None
@@ -636,6 +655,7 @@ def get_args():
                         help="extra opset with format like domain:version, e.g. com.microsoft:1")
     parser.add_argument("--skip_tf_tests", help="skip non-tflite tests", default="False")
     parser.add_argument("--skip_tflite_tests", help="skip tflite tests", default="False")
+    parser.add_argument("--skip_tfjs_tests", help="skip tfjs tests", default="False")
     parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
     parser.add_argument("--debug", help="debug mode", action="store_true")
     parser.add_argument("--list", help="list tests", action="store_true")
@@ -647,6 +667,7 @@ def get_args():
     args.target = args.target.split(",")
     args.skip_tf_tests = args.skip_tf_tests.upper() == "TRUE"
     args.skip_tflite_tests = args.skip_tflite_tests.upper() == "TRUE"
+    args.skip_tfjs_tests = args.skip_tfjs_tests.upper() == "TRUE"
     if args.extra_opset:
         tokens = args.extra_opset.split(':')
         if len(tokens) != 2:
@@ -739,11 +760,14 @@ def main():
             logger.info("Skip %s: disabled", test)
             continue
 
+        if args.skip_tfjs_tests and t.model_type == "tfjs":
+            logger.info("Skip %s: tfjs test", test)
+            continue
         if args.skip_tflite_tests and t.model_type == "tflite":
             logger.info("Skip %s: tflite test", test)
             continue
-        if args.skip_tf_tests and t.model_type != "tflite":
-            logger.info("Skip %s: not tflite test", test)
+        if args.skip_tf_tests and t.model_type not in ["tflite", "tfjs"]:
+            logger.info("Skip %s: tf test", test)
             continue
 
         condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)
diff --git a/tests/run_pretrained_models.yaml b/tests/run_pretrained_models.yaml
index 1238861f2..690382d2d 100644
--- a/tests/run_pretrained_models.yaml
+++ b/tests/run_pretrained_models.yaml
@@ -586,3 +586,53 @@ melgan_tflite: # TFLite model with FlexOps and rank-3 transposes
     - Identity
   rtol: 0.02
   atol: 0.0005
+
+handdetector_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/tensorflow/tfjs-model/handdetector/1/default/1?tfjs-format=compressed
+  model: "model.json"
+  model_type: tfjs
+  input_get: get_beach
+  inputs:
+    "input:0": [1, 256, 256, 3]
+  outputs:
+    - Identity:0
+  atol: 0.0005
+
+posenet_mobilenet_float_100_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/float/100/1/default/1?tfjs-format=compressed
+  model: "model-stride8.json"
+  model_type: tfjs
+  input_get: get_beach
+  force_input_shape: True  # ORT doesn't implement autopadding for convs with dilations
+  inputs:
+    "sub_2:0": [1, 256, 256, 3]
+  outputs:
+    - MobilenetV1/offset_2/BiasAdd:0
+    - MobilenetV1/heatmap_2/BiasAdd:0
+    - MobilenetV1/displacement_fwd_2/BiasAdd:0
+    - MobilenetV1/displacement_bwd_2/BiasAdd:0
+  rtol: 0.02
+  atol: 0.0005
+
+posenet_mobilenet_quantized_2_075_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/quantized/2/075/1/default/1?tfjs-format=compressed
+  model: "model-stride16.json"
+  model_type: tfjs
+  input_get: get_beach
+  force_input_shape: True  # ORT doesn't implement autopadding for convs with dilations
+  inputs:
+    "sub_2:0": [1, 256, 256, 3]
+  outputs:
+    - MobilenetV1/offset_2/BiasAdd:0
+    - MobilenetV1/heatmap_2/BiasAdd:0
+    - MobilenetV1/displacement_fwd_2/BiasAdd:0
+    - MobilenetV1/displacement_bwd_2/BiasAdd:0
+  rtol: 0.1
+  ptol: 0.2
+  atol: 0.005
diff --git a/tests/run_tfjs.js b/tests/run_tfjs.js
index 0ec057851..9b0c688d0 100644
--- a/tests/run_tfjs.js
+++ b/tests/run_tfjs.js
@@ -8,6 +8,7 @@
  */
 
 const tf = require('@tensorflow/tfjs');
+const zlib = require('zlib');
 const fs = require('fs');
 const http = require('http');
 
@@ -48,9 +49,16 @@ if (process.argv[2] == '--test') {
 const modelDir = path.dirname(modelPath);
 const modelName = path.basename(modelPath);
 
+const fd = fs.openSync(modelPath, 'r');
+const buffer = Buffer.alloc(2);
+fs.readSync(fd, buffer, 0, 2);
+fs.closeSync(fd);
+// Check for the gzip magic number (0x1f 0x8b)
+const needsUnzip = buffer[0] === 31 && buffer[1] === 139;
+
 // tf.loadGraphModel expects a url not a local file, so we serve it on localhost
 http.createServer(function (req, res) {
-    fs.readFile(modelDir + req.url, function (err, data) {
+    const callback = function (err, data) {
         if (err) {
             res.writeHead(404);
             res.end(JSON.stringify(err));
@@ -58,6 +66,13 @@ http.createServer(function (req, res) {
         }
         res.writeHead(200);
         res.end(data);
+    };
+    fs.readFile(modelDir + req.url, function (err, data) {
+        if (err || !needsUnzip) {
+            callback(err, data);
+        } else {
+            zlib.gunzip(data, callback);
+        }
     });
 }).listen(8080);
 
@@ -140,4 +155,4 @@ async function main() {
 
     fs.writeFileSync(outputPath, outputString, 'utf8');
 }
-main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })
\ No newline at end of file
+main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })
diff --git a/tf2onnx/tfjs_utils.py b/tf2onnx/tfjs_utils.py
index 3f4b2a4a2..686bffcb7 100644
--- a/tf2onnx/tfjs_utils.py
+++ b/tf2onnx/tfjs_utils.py
@@ -206,7 +206,8 @@ def read_model_json(model_path):
     return model, zip_compressed
 
 
-def graphs_from_tfjs(model_path, input_names=None, output_names=None, ignore_default=None, use_default=None):
+def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_override=None,
+                     ignore_default=None, use_default=None):
     """Given the path to a model.json file, parses the model into onnx graphs and returns the main graph
     and a topologically sorted list of subgraphs."""
     model, zip_compressed = read_model_json(model_path)
@@ -236,11 +237,13 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, ignore_def
     if output_names is None and 'signature' in model:
         output_names = list(model['signature']['outputs'].keys())
 
-    main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, ignore_default, use_default)
+    main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, shape_override,
+                             ignore_default, use_default)
     subgraphs = []
     funcs = sort_tfjs_functions(topology.get('library', {}).get('function', []))
     for func in funcs:
-        sub_g = read_tfjs_graph(func.get('nodeDef', []), weights, func, None, None, ignore_default, use_default)
+        sub_g = read_tfjs_graph(func.get('nodeDef', []), weights, func, None, None, shape_override,
+                                ignore_default, use_default)
         subgraphs.append(sub_g)
     return main_g, subgraphs
 
@@ -259,7 +262,7 @@ def read_tfjs_weight(weight, weights_data, offset):
     if 'quantization' in weight:
         q_info = weight['quantization']
         q_dtype = np.dtype(q_info['dtype'])
-        np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=i)
+        np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=offset)
         num_bytes = np_arr.nbytes
         np_arr = np_arr.astype(np_dtype) * q_info['scale'] + q_info['min']
     else:
@@ -303,9 +306,11 @@ def read_tfjs_function(func):
     return tf_dtypes, output_shapes, inputs, outputs, name
 
 
-def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None,
+def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None, shape_override=None,
                     ignore_default=None, use_default=None):
     """Creates an onnx graph from the provided tfjs nodes"""
+    if shape_override is None:
+        shape_override = {}
     onnx_nodes = []
     output_shapes = {}
     tf_dtypes = {}
@@ -313,8 +318,15 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=
     graph_name = 'tfjs_model'
     func_name = None
 
+    def update_shapes(new_shapes):
+        if isinstance(new_shapes, dict):
+            new_shapes = new_shapes.items()
+        for k, v in new_shapes:
+            output_shapes[k] = shape_override.get(k, v)
+
     if func is not None:
-        tf_dtypes, output_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
+        tf_dtypes, fn_input_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
+        update_shapes(fn_input_shapes)
         graph_name = func_name
         for inp in graph_inputs:
             onnx_nodes.append(helper.make_node("Placeholder", [], outputs=[inp], name=inp))
@@ -338,7 +350,7 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=
             onnx_tensor = numpy_helper.from_array(np_arr.astype(np_dtype), out_name)
             onnx_node = helper.make_node("Const", [], outputs=[out_name], name=node_name, value=onnx_tensor)
             onnx_nodes.append(onnx_node)
-            output_shapes[out_name] = list(np_arr.shape)
+            output_shapes[out_name] = shape_override.get(out_name, list(np_arr.shape))
             tf_dtypes[out_name] = tf_dtype
             op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]})
             continue
@@ -365,7 +377,7 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=
 
         output_names = [node_name + ":" + str(i) for i in range(len(out_dtypes))]
         tf_dtypes.update(zip(output_names, out_dtypes))
-        output_shapes.update(zip(output_names, out_shapes))
+        update_shapes(zip(output_names, out_shapes))
         unused_outputs.update(output_names)
 
         if op_type == "PlaceholderWithDefault":
diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py
index 6bcbd16ca..dba105572 100644
--- a/tf2onnx/tfonnx.py
+++ b/tf2onnx/tfonnx.py
@@ -427,7 +427,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
         main_g, subgraphs = graphs_from_tflite(tflite_path, input_names, output_names)
         is_tflite = True
     elif tfjs_path is not None:
-        main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, ignore_default, use_default)
+        main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, shape_override,
+                                             ignore_default, use_default)
     else:
         main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
                                            ignore_default, use_default)
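
Note: to exercise the new tfjs path locally, the flags added above can be combined
the same way the CI template does. A minimal sketch (the opset value is illustrative,
and tests/run_tfjs.js assumes @tensorflow/tfjs is available to Node):

    python tests/run_pretrained_models.py --backend onnxruntime --opset 13 \
        --skip_tf_tests True --skip_tflite_tests True --skip_tfjs_tests False \
        --config tests/run_pretrained_models.yaml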