diff --git a/README.md b/README.md index 68a24c06f..694c14c0e 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,13 @@ Only valid with parameter `--saved_model`. When set, creates a zip file containi Saves the frozen tensorflow graph to file. +#### --custom-ops + +If a model contains ops not recognized by onnx runtime, you can tag these ops with a custom op domain so that the +runtime can still open the model. The format is a comma-separated map of tf op names to domains in the format +OpName:domain. If only an op name is provided (no colon), the default domain of `ai.onnx.converters.tensorflow` +will be used. + #### --target Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value. diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 8e3d8d762..52a86e125 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -60,7 +60,7 @@ def get_args(): parser.add_argument("--inputs", help="model input_names") parser.add_argument("--outputs", help="model output_names") parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain") - parser.add_argument("--custom-ops", help="list of custom ops") + parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain") parser.add_argument("--extra_opset", default=None, help="extra opset with format like domain:version, e.g. com.microsoft:1") parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS, @@ -103,11 +103,11 @@ def get_args(): return args - -def default_custom_op_handler(ctx, node, name, args): - node.domain = constants.TENSORFLOW_OPSET.domain - return node - +def make_default_custom_op_handler(domain): + def default_custom_op_handler(ctx, node, name, args): + node.domain = domain + return node + return default_custom_op_handler def main(): args = get_args() @@ -121,9 +121,17 @@ def main(): custom_ops = {} initialized_tables = None if args.custom_ops: - # default custom ops for tensorflow-onnx are in the "tf" namespace - custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")} - extra_opset.append(constants.TENSORFLOW_OPSET) + using_tf_opset = False + for op in args.custom_ops.split(","): + if ":" in op: + op, domain = op.split(":") + else: + # default custom ops for tensorflow-onnx are in the "tf" namespace + using_tf_opset = True + domain = constants.TENSORFLOW_OPSET.domain + custom_ops[op] = (make_default_custom_op_handler(domain), []) + if using_tf_opset: + extra_opset.append(constants.TENSORFLOW_OPSET) # get the frozen tensorflow model from graphdef, checkpoint or saved_model. if args.graphdef: