Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci_build/azure_pipelines/pretrained_model_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ jobs:
python_versions: ['3.7']
tf_versions: ['2.4.1']
skip_tflite_tests: 'False'
skip_tfjs_tests: 'False'
skip_tf_tests: 'True'
job:
steps:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ steps:
status=0
# TODO: fix unity model path
# python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config tests/unity.yaml || status=$?
python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --config tests/run_pretrained_models.yaml || status=$?
python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --skip_tfjs_tests $CI_SKIP_TFJS_TESTS --config tests/run_pretrained_models.yaml || status=$?
exit $status
displayName: 'Test Pre-trained Model'
44 changes: 34 additions & 10 deletions tests/run_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
from tf2onnx.graph import ExternalTensorStorage
from tfjs_runner import run_tfjs

logger = logging.getLogger("run_pretrained")

Expand Down Expand Up @@ -251,6 +252,10 @@ def download_model(self):
elif self.model_type == 'tflite':
fname = self.local
dir_name = fname.replace(".tflite", "") + "_dir"
elif self.model_type == 'tfjs':
ftype = 'tgz'
fname = 'model.tar.gz'
dir_name = "_".join(url.split("/")[5:-3]) + "_dir"
dir_name = os.path.join(cache_dir, dir_name)
os.makedirs(dir_name, exist_ok=True)
fpath = os.path.join(dir_name, fname)
Expand Down Expand Up @@ -303,7 +308,8 @@ def run_tensorflow(self, sess, inputs):
return result

def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None):
const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None,
tfjs_path=None):
"""Convert graph to tensorflow."""
if extra_opset is None:
extra_opset = []
Expand All @@ -314,7 +320,7 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
input_names=input_names, output_names=self.output_names,
const_node_values=const_node_values, initialized_tables=initialized_tables,
tflite_path=tflite_path, dequantize=self.dequantize,
tensors_to_rename=tensors_to_rename)
tensors_to_rename=tensors_to_rename, tfjs_path=tfjs_path)

def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_storage=None):
"""Run test against onnxruntime backend."""
Expand Down Expand Up @@ -375,6 +381,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
initialized_tables = {}
outputs = self.output_names
tflite_path = None
tfjs_path = None
to_rename = {}
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
Expand All @@ -394,6 +401,9 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
elif self.model_type in ["tflite"]:
tflite_path = model_path
graph_def = None
elif self.model_type in ["tfjs"]:
tfjs_path = model_path
graph_def = None
else:
graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)

Expand Down Expand Up @@ -434,6 +444,16 @@ def run_tflite():
logger.info("TFLite perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
logger.info("TFLite OK")

if tfjs_path is not None:
inputs = {}
for k in input_names:
v = self.input_names[k]
inputs[k] = self.make_input(v)
if not self.skip_tensorflow:
logger.info("Running TFJS")
tf_results = run_tfjs(tfjs_path, inputs, dir_name)
logger.info("TFJS OK")

if not self.run_tf_frozen:
inputs = {}
for k in input_names:
Expand Down Expand Up @@ -465,7 +485,6 @@ def run_tflite():
logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
logger.info("TensorFlow OK")

shape_override = {}
const_node_values = None
tf_graph = None

Expand Down Expand Up @@ -497,10 +516,6 @@ def run_tflite():
else:
inputs[k] = self.make_input(v).astype(expected_dtype)

if self.force_input_shape:
for k, v in inputs.items():
shape_override[k] = list(v.shape)

# run the model with tensorflow
if self.skip_tensorflow:
logger.info("TensorFlow SKIPPED")
Expand All @@ -526,11 +541,15 @@ def run_tflite():
else:
try:
# convert model to onnx
if self.force_input_shape:
shape_override = {k: list(v.shape) for k, v in inputs.items()}
else:
shape_override = None
onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
shape_override=shape_override, input_names=inputs.keys(),
const_node_values=const_node_values,
initialized_tables=initialized_tables, tflite_path=tflite_path,
tensors_to_rename=to_rename)
tensors_to_rename=to_rename, tfjs_path=tfjs_path)
onnx_graph = optimizer.optimize_graph(onnx_graph)
print("ONNX", onnx_graph.dump_node_statistics())
external_tensor_storage = ExternalTensorStorage() if self.large_model else None
Expand Down Expand Up @@ -636,6 +655,7 @@ def get_args():
help="extra opset with format like domain:version, e.g. com.microsoft:1")
parser.add_argument("--skip_tf_tests", help="skip non-tflite tests", default="False")
parser.add_argument("--skip_tflite_tests", help="skip tflite tests", default="False")
parser.add_argument("--skip_tfjs_tests", help="skip tfjs tests", default="False")
parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
parser.add_argument("--debug", help="debug mode", action="store_true")
parser.add_argument("--list", help="list tests", action="store_true")
Expand All @@ -647,6 +667,7 @@ def get_args():
args.target = args.target.split(",")
args.skip_tf_tests = args.skip_tf_tests.upper() == "TRUE"
args.skip_tflite_tests = args.skip_tflite_tests.upper() == "TRUE"
args.skip_tfjs_tests = args.skip_tfjs_tests.upper() == "TRUE"
if args.extra_opset:
tokens = args.extra_opset.split(':')
if len(tokens) != 2:
Expand Down Expand Up @@ -739,11 +760,14 @@ def main():
logger.info("Skip %s: disabled", test)
continue

if args.skip_tfjs_tests and t.model_type == "tfjs":
logger.info("Skip %s: tfjs test", test)
continue
if args.skip_tflite_tests and t.model_type == "tflite":
logger.info("Skip %s: tflite test", test)
continue
if args.skip_tf_tests and t.model_type != "tflite":
logger.info("Skip %s: not tflite test", test)
if args.skip_tf_tests and t.model_type not in ["tflite", "tfjs"]:
logger.info("Skip %s: tf test", test)
continue

condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)
Expand Down
50 changes: 50 additions & 0 deletions tests/run_pretrained_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,53 @@ melgan_tflite: # TFLite model with FlexOps and rank-3 transposes
- Identity
rtol: 0.02
atol: 0.0005

handdetector_tfjs:
tf_min_version: 2.1
disabled: false
url: https://tfhub.dev/tensorflow/tfjs-model/handdetector/1/default/1?tfjs-format=compressed
model: "model.json"
model_type: tfjs
input_get: get_beach
inputs:
"input:0": [1, 256, 256, 3]
outputs:
- Identity:0
atol: 0.0005

posenet_mobilenet_float_100_tfjs:
tf_min_version: 2.1
disabled: false
url: https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/float/100/1/default/1?tfjs-format=compressed
model: "model-stride8.json"
model_type: tfjs
input_get: get_beach
force_input_shape: True # ORT doesn't implement autopadding for convs with dilations
inputs:
"sub_2:0": [1, 256, 256, 3]
outputs:
- MobilenetV1/offset_2/BiasAdd:0
- MobilenetV1/heatmap_2/BiasAdd:0
- MobilenetV1/displacement_fwd_2/BiasAdd:0
- MobilenetV1/displacement_bwd_2/BiasAdd:0
rtol: 0.02
atol: 0.0005

posenet_mobilenet_quantized_2_075_tfjs:
tf_min_version: 2.1
disabled: false
url: https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/quantized/2/075/1/default/1?tfjs-format=compressed
model: "model-stride16.json"
model_type: tfjs
input_get: get_beach
force_input_shape: True # ORT doesn't implement autopadding for convs with dilations
inputs:
"sub_2:0": [1, 256, 256, 3]
outputs:
- MobilenetV1/offset_2/BiasAdd:0
- MobilenetV1/heatmap_2/BiasAdd:0
- MobilenetV1/displacement_fwd_2/BiasAdd:0
- MobilenetV1/displacement_bwd_2/BiasAdd:0
rtol: 0.1
ptol: 0.2
atol: 0.005
19 changes: 17 additions & 2 deletions tests/run_tfjs.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/

const tf = require('@tensorflow/tfjs');
const zlib = require("zlib");

const fs = require('fs');
const http = require('http');
Expand Down Expand Up @@ -48,16 +49,30 @@ if (process.argv[2] == '--test') {
const modelDir = path.dirname(modelPath);
const modelName = path.basename(modelPath);

const fd = fs.openSync(modelPath, 'r');
const buffer = Buffer.alloc(2);
fs.readSync(fd, buffer, 0, 2);
fs.closeSync(fd);
// Check for gzip magic number
const needsUnzip = buffer[0] == 31 && buffer[1] == 139

// tf.loadGraphModel expects a url not a local file, so we serve it on localhost
http.createServer(function (req, res) {
fs.readFile(modelDir + req.url, function (err, data) {
const callback = function (err, data) {
if (err) {
res.writeHead(404);
res.end(JSON.stringify(err));
return;
}
res.writeHead(200);
res.end(data);
}
fs.readFile(modelDir + req.url, function (err, data) {
if (err || !needsUnzip) {
callback(err, data);
} else {
zlib.gunzip(data, callback);
}
});
}).listen(8080);

Expand Down Expand Up @@ -140,4 +155,4 @@ async function main() {
fs.writeFileSync(outputPath, outputString, 'utf8');
}

main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })
main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })
28 changes: 20 additions & 8 deletions tf2onnx/tfjs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def read_model_json(model_path):
return model, zip_compressed


def graphs_from_tfjs(model_path, input_names=None, output_names=None, ignore_default=None, use_default=None):
def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_override=None,
ignore_default=None, use_default=None):
"""Given the path to a model.json file, parses the model into onnx graphs and returns the main graph and a
topologically sorted list of subgraphs."""
model, zip_compressed = read_model_json(model_path)
Expand Down Expand Up @@ -236,11 +237,13 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, ignore_def
if output_names is None and 'signature' in model:
output_names = list(model['signature']['outputs'].keys())

main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, ignore_default, use_default)
main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, shape_override,
ignore_default, use_default)
subgraphs = []
funcs = sort_tfjs_functions(topology.get('library', {}).get('function', []))
for func in funcs:
sub_g = read_tfjs_graph(func.get('nodeDef', []), weights, func, None, None, ignore_default, use_default)
sub_g = read_tfjs_graph(func.get('nodeDef', []), weights, func, None, None, shape_override,
ignore_default, use_default)
subgraphs.append(sub_g)

return main_g, subgraphs
Expand All @@ -259,7 +262,7 @@ def read_tfjs_weight(weight, weights_data, offset):
if 'quantization' in weight:
q_info = weight['quantization']
q_dtype = np.dtype(q_info['dtype'])
np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=i)
np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=offset)
num_bytes = np_arr.nbytes
np_arr = np_arr.astype(np_dtype) * q_info['scale'] + q_info['min']
else:
Expand Down Expand Up @@ -303,18 +306,27 @@ def read_tfjs_function(func):
return tf_dtypes, output_shapes, inputs, outputs, name


def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None,
def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None, shape_override=None,
ignore_default=None, use_default=None):
"""Creates an onnx graph from the provided tfjs nodes"""
if shape_override is None:
shape_override = {}
onnx_nodes = []
output_shapes = {}
tf_dtypes = {}
op_info = {}
graph_name = 'tfjs_model'
func_name = None

def update_shapes(new_shapes):
if isinstance(new_shapes, dict):
new_shapes = new_shapes.items()
for k, v in new_shapes:
output_shapes[k] = shape_override.get(k, v)

if func is not None:
tf_dtypes, output_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
tf_dtypes, fn_input_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
update_shapes(fn_input_shapes)
graph_name = func_name
for inp in graph_inputs:
onnx_nodes.append(helper.make_node("Placeholder", [], outputs=[inp], name=inp))
Expand All @@ -338,7 +350,7 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=
onnx_tensor = numpy_helper.from_array(np_arr.astype(np_dtype), out_name)
onnx_node = helper.make_node("Const", [], outputs=[out_name], name=node_name, value=onnx_tensor)
onnx_nodes.append(onnx_node)
output_shapes[out_name] = list(np_arr.shape)
output_shapes[out_name] = shape_override.get(out_name, list(np_arr.shape))
tf_dtypes[out_name] = tf_dtype
op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]})
continue
Expand All @@ -365,7 +377,7 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=

output_names = [node_name + ":" + str(i) for i in range(len(out_dtypes))]
tf_dtypes.update(zip(output_names, out_dtypes))
output_shapes.update(zip(output_names, out_shapes))
update_shapes(zip(output_names, out_shapes))
unused_outputs.update(output_names)

if op_type == "PlaceholderWithDefault":
Expand Down
3 changes: 2 additions & 1 deletion tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
main_g, subgraphs = graphs_from_tflite(tflite_path, input_names, output_names)
is_tflite = True
elif tfjs_path is not None:
main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, ignore_default, use_default)
main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, shape_override,
ignore_default, use_default)
else:
main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
ignore_default, use_default)
Expand Down