Commit a53b8d9

Pylint
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c76c0b2 commit a53b8d9

3 files changed: 22 additions & 26 deletions

tests/backend_test_base.py

Lines changed: 8 additions & 5 deletions
@@ -95,14 +95,14 @@ def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False
         results = m.run(output_names, inputs)
         return results

-    def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
+    def run_backend(self, g, outputs, input_dict, large_model=False, postfix="", use_custom_ops=False):
         tensor_storage = ExternalTensorStorage() if large_model else None
         model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
         model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage,
                                           postfix=postfix)

         if self.config.backend == "onnxruntime":
-            y = self.run_onnxruntime(model_path, input_dict, outputs)
+            y = self.run_onnxruntime(model_path, input_dict, outputs, use_custom_ops)
         elif self.config.backend == "caffe2":
             y = self.run_onnxcaffe2(model_proto, input_dict)
         else:

@@ -307,7 +307,8 @@ def get_dtype(info):
     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                       rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
                       check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
-                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
+                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
+                      use_custom_ops=False):
         test_tf = not self.config.skip_tf_tests
         test_tflite = not self.config.skip_tflite_tests
         run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test

@@ -347,7 +348,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
                                      initialized_tables=initialized_tables,
                                      **process_args)
             g = optimizer.optimize_graph(g, catch_errors=False)
-            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
+            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
+                                      use_custom_ops=use_custom_ops)

             self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
             self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)

@@ -377,7 +379,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
                                      **tfl_process_args)
             g = optimizer.optimize_graph(g)
             onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
-            onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
+            onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
+                                            postfix="_from_tflite", use_custom_ops=use_custom_ops)

             self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
             self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)

tests/test_api.py

Lines changed: 12 additions & 10 deletions
@@ -14,6 +14,7 @@

 import numpy as np
 import tensorflow as tf
+from onnx import helper

 from common import check_tf_min_version, unittest_main, requires_custom_ops
 from backend_test_base import Tf2OnnxBackendTestBase

@@ -82,19 +83,21 @@ def test_keras_api_large(self):
     @check_tf_min_version("2.0")
     def test_keras_hashtable(self):

-        featCols = [tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
+        feature_cols = [
+            tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
             tf.feature_column.indicator_column(
                 tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
-            )]
-        featureLayer = tf.keras.layers.DenseFeatures(featCols)
+            )
+        ]
+        feature_layer = tf.keras.layers.DenseFeatures(feature_cols)

-        inputDict = {}
-        inputDict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
-        inputDict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
+        input_dict = {}
+        input_dict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
+        input_dict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)

-        inputs = [input for input in inputDict.values()]
-        standardFeatures = featureLayer(inputDict)
-        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standardFeatures)
+        inputs = list(input_dict.values())
+        standard_features = feature_layer(input_dict)
+        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standard_features)
         output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
         model = tf.keras.Model(inputs=inputs, outputs=output)
         model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)

@@ -106,7 +109,6 @@ def test_keras_hashtable(self):
                 tf.TensorSpec((None, 1), tf.string, name="s_inp"))
         output_path = os.path.join(self.test_data_directory, "model.onnx")

-        from onnx import helper
         model_proto, _ = tf2onnx.convert.from_keras(
             model, input_signature=spec, opset=self.config.opset, output_path=output_path,
             extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])
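
Aside from the snake_case renames and the hoisted import, the comprehension rewrite is the substantive lint fix: the loop variable shadowed the builtin input and the comprehension was redundant, which pylint typically reports as redefined-builtin and unnecessary-comprehension. A tiny illustration with hypothetical data:

values = {"f_inp": 1.0, "s_inp": "a"}           # hypothetical stand-in data
inputs = [input for input in values.values()]   # shadows the builtin `input`; redundant comprehension
inputs = list(values.values())                  # equivalent and lint-clean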

tests/test_string_ops.py

Lines changed: 2 additions & 11 deletions
@@ -135,17 +135,8 @@ def func(text):
     def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
         extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
         process_args = {"extra_opset": extra_opset}
-        return self.run_test_case(func, feed_dict, [], output_names_with_port, process_args=process_args, **kwargs)
-
-    def run_onnxruntime(self, model_path, inputs, output_names):
-        """Run test against onnxruntime backend."""
-        from onnxruntime_customops import get_library_path
-        import onnxruntime as rt
-        opt = rt.SessionOptions()
-        opt.register_custom_ops_library(get_library_path())
-        m = rt.InferenceSession(model_path, opt)
-        results = m.run(output_names, inputs)
-        return results
+        return self.run_test_case(func, feed_dict, [], output_names_with_port,
+                                  use_custom_ops=True, process_args=process_args, **kwargs)

     @requires_custom_ops("WordpieceTokenizer")
     @check_tf_min_version("2.0", "tensorflow_text")
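
With the per-file run_onnxruntime override removed, custom-op registration lives only behind the base class's use_custom_ops flag, so a string-op test just goes through _run_test_case. A hedged sketch of the call path (the test body, op, and tensor names below are hypothetical, shown as a method of the existing string-ops test class):

import numpy as np
import tensorflow as tf

def test_some_string_op(self):
    def func(text):
        # identity keeps the example trivial; real tests call tensorflow_text ops here
        return tf.identity(text, name="output")
    text_val = np.array(["hello", "world"], dtype=object)
    # _run_test_case injects the ai.onnx.contrib extra_opset and now also passes
    # use_custom_ops=True, so run_backend/run_onnxruntime load onnxruntime_customops.
    self._run_test_case(func, ["output:0"], {"text:0": text_val})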
