Skip to content

Commit cb6a213

Browse files
Add hash table support and frozen graph repairing for from_keras (#1512)
* Add hash table support and frozen graph repairing for from_keras Signed-off-by: Tom Wildenhain <[email protected]> * Pylint Signed-off-by: Tom Wildenhain <[email protected]>
1 parent dc91ae8 commit cb6a213

File tree

5 files changed

+126
-63
lines changed

5 files changed

+126
-63
lines changed

tests/backend_test_base.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
7070
results = prepared_backend.run(inputs)
7171
return results
7272

73-
def run_onnxruntime(self, model_path, inputs, output_names):
73+
def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False):
7474
"""Run test against onnxruntime backend."""
7575
import onnxruntime as rt
7676
providers = ['CPUExecutionProvider']
@@ -79,6 +79,9 @@ def run_onnxruntime(self, model_path, inputs, output_names):
7979
if gpus is None or len(gpus) > 1:
8080
providers = ['CUDAExecutionProvider']
8181
opt = rt.SessionOptions()
82+
if use_custom_ops:
83+
from onnxruntime_customops import get_library_path
84+
opt.register_custom_ops_library(get_library_path())
8285
# in case of issues with the runtime, one can enable more logging
8386
# opt.log_severity_level = 0
8487
# opt.log_verbosity_level = 255
@@ -87,14 +90,14 @@ def run_onnxruntime(self, model_path, inputs, output_names):
8790
results = m.run(output_names, inputs)
8891
return results
8992

90-
def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
93+
def run_backend(self, g, outputs, input_dict, large_model=False, postfix="", use_custom_ops=False):
9194
tensor_storage = ExternalTensorStorage() if large_model else None
9295
model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
9396
model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage,
9497
postfix=postfix)
9598

9699
if self.config.backend == "onnxruntime":
97-
y = self.run_onnxruntime(model_path, input_dict, outputs)
100+
y = self.run_onnxruntime(model_path, input_dict, outputs, use_custom_ops)
98101
elif self.config.backend == "caffe2":
99102
y = self.run_onnxcaffe2(model_proto, input_dict)
100103
else:
@@ -299,7 +302,8 @@ def get_dtype(info):
299302
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
300303
rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
301304
check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
302-
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
305+
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
306+
use_custom_ops=False):
303307
test_tf = not self.config.skip_tf_tests
304308
test_tflite = not self.config.skip_tflite_tests
305309
run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
@@ -339,7 +343,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
339343
initialized_tables=initialized_tables,
340344
**process_args)
341345
g = optimizer.optimize_graph(g, catch_errors=False)
342-
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
346+
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
347+
use_custom_ops=use_custom_ops)
343348

344349
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
345350
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -369,7 +374,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
369374
**tfl_process_args)
370375
g = optimizer.optimize_graph(g)
371376
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
372-
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
377+
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
378+
postfix="_from_tflite", use_custom_ops=use_custom_ops)
373379

374380
self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
375381
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)

tests/test_api.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
import numpy as np
1212
import tensorflow as tf
13+
from onnx import helper
1314

14-
from common import check_tf_min_version, unittest_main
15+
from common import check_tf_min_version, unittest_main, requires_custom_ops, check_opset_min_version
1516
from backend_test_base import Tf2OnnxBackendTestBase
1617
import tf2onnx
1718

@@ -74,6 +75,48 @@ def test_keras_api(self):
7475
def test_keras_api_large(self):
7576
self._test_keras_api(large_model=True)
7677

78+
@requires_custom_ops()
79+
@check_tf_min_version("2.0")
80+
@check_opset_min_version(11, "SparseToDense")
81+
def test_keras_hashtable(self):
82+
83+
feature_cols = [
84+
tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
85+
tf.feature_column.indicator_column(
86+
tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
87+
)
88+
]
89+
feature_layer = tf.keras.layers.DenseFeatures(feature_cols)
90+
91+
input_dict = {}
92+
input_dict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
93+
input_dict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
94+
95+
inputs = list(input_dict.values())
96+
standard_features = feature_layer(input_dict)
97+
hidden1 = tf.keras.layers.Dense(512, activation='relu')(standard_features)
98+
output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
99+
model = tf.keras.Model(inputs=inputs, outputs=output)
100+
model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)
101+
102+
inp1 = np.array([[2.], [3.]], dtype=np.float32)
103+
inp2 = np.array([["a"], ["b"]], dtype=np.str)
104+
k_res = model.predict([inp1, inp2])
105+
spec = (tf.TensorSpec((None, 1), dtype=tf.float32, name="f_inp"),
106+
tf.TensorSpec((None, 1), tf.string, name="s_inp"))
107+
output_path = os.path.join(self.test_data_directory, "model.onnx")
108+
109+
model_proto, _ = tf2onnx.convert.from_keras(
110+
model, input_signature=spec, opset=self.config.opset, output_path=output_path,
111+
extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])
112+
output_names = [n.name for n in model_proto.graph.output]
113+
114+
o_res = self.run_onnxruntime(output_path, {"f_inp": inp1, "s_inp": inp2}, output_names, use_custom_ops=True)
115+
self.assertAllClose(k_res, o_res[0], rtol=0.3, atol=0.1)
116+
# make sure the original keras model wasn't trashed
117+
k_res2 = model.predict([inp1, inp2])
118+
self.assertAllClose(k_res2, o_res[0], rtol=0.3, atol=0.1)
119+
77120
@check_tf_min_version("2.0")
78121
def test_function(self):
79122
def func(x, y):

tests/test_string_ops.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,8 @@ def func(text):
130130
def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
131131
extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
132132
process_args = {"extra_opset": extra_opset}
133-
return self.run_test_case(func, feed_dict, [], output_names_with_port, process_args=process_args, **kwargs)
134-
135-
def run_onnxruntime(self, model_path, inputs, output_names):
136-
"""Run test against onnxruntime backend."""
137-
from onnxruntime_customops import get_library_path
138-
import onnxruntime as rt
139-
opt = rt.SessionOptions()
140-
opt.register_custom_ops_library(get_library_path())
141-
m = rt.InferenceSession(model_path, opt)
142-
results = m.run(output_names, inputs)
143-
return results
133+
return self.run_test_case(func, feed_dict, [], output_names_with_port,
134+
use_custom_ops=True, process_args=process_args, **kwargs)
144135

145136
@requires_custom_ops("WordpieceTokenizer")
146137
@check_tf_min_version("2.0", "tensorflow_text")

tf2onnx/convert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ def wrap_call(*args, training=False, **kwargs):
329329
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
330330
if output_tensor.dtype != tf.dtypes.resource]
331331

332-
initialized_tables = None
333332
tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
334333
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
335334

@@ -341,7 +340,8 @@ def wrap_call(*args, training=False, **kwargs):
341340
output_names = [reverse_lookup[out] for out in concrete_func.structured_outputs.keys()]
342341

343342
with tf.device("/cpu:0"):
344-
frozen_graph = tf_loader.from_function(concrete_func, input_names, output_names, large_model=large_model)
343+
frozen_graph, initialized_tables = \
344+
tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)
345345
model_proto, external_tensor_storage = _convert_common(
346346
frozen_graph,
347347
name=model.name,

tf2onnx/tf_loader.py

Lines changed: 66 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
4747
try:
4848
# pylint: disable=protected-access
4949
from tensorflow.python.saved_model.load import _RestoredResource as TfRestoredResourceType
50+
from tensorflow.python.ops.lookup_ops import StaticHashTable as TfStaticHashTableType
51+
from tensorflow.python.training.tracking.base import Trackable as TfTrackableType
5052
except ImportError:
5153
TfRestoredResourceType = tuple() # isinstance(x, tuple()) is always false
52-
53-
try:
54-
from tensorflow.python.training.tracking.tracking import AutoTrackable as TfAutoTrackableType
55-
except ImportError:
56-
TfAutoTrackableType = tuple()
54+
TfStaticHashTableType = tuple()
55+
TfTrackableType = tuple()
5756

5857
if is_tf2():
5958
convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
@@ -147,6 +146,46 @@ def fix_freezing_errors(graph_def):
147146
return graph_def
148147

149148

149+
def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
150+
err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
151+
152+
# Avoid errors due to bug in TF freezing
153+
removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
154+
_remove_non_variable_resources_from_captures(concrete_func)
155+
156+
try:
157+
frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
158+
except ValueError as e:
159+
if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
160+
raise ValueError(err_large_model)
161+
raise e
162+
163+
# We might be returning the concrete_func so let's put it back in working order
164+
_restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
165+
166+
table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
167+
placeholder_to_table_info = {}
168+
_get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
169+
removed_resource_to_placeholder, placeholder_to_table_info)
170+
171+
initialized_tables = {}
172+
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
173+
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
174+
try:
175+
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
176+
initialized_tables[n] = (k.numpy(), v.numpy())
177+
except Exception: # pylint: disable=broad-except
178+
logger.warning("Could not initialize table with shared_name = %r", n)
179+
180+
for placeholder in removed_resource_to_placeholder.values():
181+
if placeholder not in placeholder_to_table_info:
182+
logger.error("Could not find table resource to replace placeholder %s", placeholder)
183+
184+
replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
185+
186+
return frozen_graph, initialized_tables
187+
188+
150189
def from_function(func, input_names, output_names, large_model=False):
151190
if large_model:
152191
return convert_variables_to_constants_large_model(func)
@@ -327,7 +366,27 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
327366
def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
328367
removed_resource_to_placeholder, placeholder_to_table_info):
329368
# pylint: disable=protected-access
330-
for r in trackable.__dict__.values():
369+
stack = [trackable]
370+
visited = set()
371+
while stack:
372+
r = stack.pop()
373+
visited.add(id(r))
374+
try:
375+
for trackable_ref in r._checkpoint_dependencies:
376+
if id(trackable_ref.ref) not in visited:
377+
if isinstance(trackable_ref.ref, TfTrackableType):
378+
stack.append(trackable_ref.ref)
379+
except Exception: # pylint: disable=broad-except
380+
continue
381+
for t in r.__dict__.values() if hasattr(r, '__dict__') else []:
382+
if isinstance(t, TfStaticHashTableType) and hasattr(t, '_shared_name'):
383+
table_names.append(t._shared_name.encode())
384+
key_dtypes.append(t.key_dtype.as_datatype_enum)
385+
value_dtypes.append(t.value_dtype.as_datatype_enum)
386+
table_handle = id(t.resource_handle)
387+
if table_handle in removed_resource_to_placeholder:
388+
table_info = (table_names[-1], key_dtypes[-1], value_dtypes[-1])
389+
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
331390
if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource'):
332391
try:
333392
table_handle = id(r.resource_handle)
@@ -341,9 +400,6 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
341400
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
342401
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
343402
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
344-
if isinstance(r, TfAutoTrackableType):
345-
_get_hash_table_info_from_trackable(r, table_names, key_dtypes, value_dtypes,
346-
removed_resource_to_placeholder, placeholder_to_table_info)
347403

348404

349405
def _remove_non_variable_resources_from_captures(concrete_func):
@@ -399,7 +455,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
399455
err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
400456
err_no_sig = "No signatures found in model. Try --concrete_function instead."
401457
err_sig_nomatch = "Specified signature not in model %s"
402-
err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."
403458

404459
if tag is None:
405460
tag = ['serve']
@@ -456,39 +511,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
456511
outputs = output_names
457512
logger.info("Outputs not left as None; will use provided names not structured output names.")
458513

459-
# Avoid errors due to bug in TF freezing
460-
removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
461-
_remove_non_variable_resources_from_captures(concrete_func)
462-
463-
try:
464-
frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
465-
except ValueError as e:
466-
if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
467-
raise ValueError(err_large_model)
468-
raise e
469-
470-
# We might be returning the concrete_func so let's put it back in working order
471-
_restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
472-
473-
table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
474-
placeholder_to_table_info = {}
475-
_get_hash_table_info_from_trackable(imported, table_names, key_dtypes, value_dtypes,
476-
removed_resource_to_placeholder, placeholder_to_table_info)
477-
478-
initialized_tables = {}
479-
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
480-
h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
481-
try:
482-
k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
483-
initialized_tables[n] = (k.numpy(), v.numpy())
484-
except Exception: # pylint: disable=broad-except
485-
logger.warning("Could not initialize table with shared_name = %r", n)
486-
487-
for placeholder in removed_resource_to_placeholder.values():
488-
if placeholder not in placeholder_to_table_info:
489-
logger.error("Could not find table resource to replace placeholder %s", placeholder)
490-
491-
replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
514+
frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model)
492515

493516
return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables, tensors_to_rename
494517

0 commit comments

Comments
 (0)