@@ -95,14 +95,14 @@ def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False
95
95
results = m .run (output_names , inputs )
96
96
return results
97
97
98
- def run_backend (self , g , outputs , input_dict , large_model = False , postfix = "" ):
98
+ def run_backend (self , g , outputs , input_dict , large_model = False , postfix = "" , use_custom_ops = False ):
99
99
tensor_storage = ExternalTensorStorage () if large_model else None
100
100
model_proto = g .make_model ("test" , external_tensor_storage = tensor_storage )
101
101
model_path = self .save_onnx_model (model_proto , input_dict , external_tensor_storage = tensor_storage ,
102
102
postfix = postfix )
103
103
104
104
if self .config .backend == "onnxruntime" :
105
- y = self .run_onnxruntime (model_path , input_dict , outputs )
105
+ y = self .run_onnxruntime (model_path , input_dict , outputs , use_custom_ops )
106
106
elif self .config .backend == "caffe2" :
107
107
y = self .run_onnxcaffe2 (model_proto , input_dict )
108
108
else :
@@ -307,7 +307,8 @@ def get_dtype(info):
307
307
def run_test_case (self , func , feed_dict , input_names_with_port , output_names_with_port ,
308
308
rtol = 1e-07 , atol = 1e-5 , mtol = None , convert_var_to_const = True , constant_fold = True ,
309
309
check_value = True , check_shape = True , check_dtype = True , process_args = None , onnx_feed_dict = None ,
310
- graph_validator = None , as_session = False , large_model = False , premade_placeholders = False ):
310
+ graph_validator = None , as_session = False , large_model = False , premade_placeholders = False ,
311
+ use_custom_ops = False ):
311
312
test_tf = not self .config .skip_tf_tests
312
313
test_tflite = not self .config .skip_tflite_tests
313
314
run_tfl_consistency_test = test_tf and test_tflite and self .config .run_tfl_consistency_test
@@ -347,7 +348,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
347
348
initialized_tables = initialized_tables ,
348
349
** process_args )
349
350
g = optimizer .optimize_graph (g , catch_errors = False )
350
- actual = self .run_backend (g , output_names_with_port , onnx_feed_dict , large_model )
351
+ actual = self .run_backend (g , output_names_with_port , onnx_feed_dict , large_model ,
352
+ use_custom_ops = use_custom_ops )
351
353
352
354
self .assert_results_equal (expected , actual , rtol , atol , mtol , check_value , check_shape , check_dtype )
353
355
self .assert_shapes_correct (g , self .config .allow_missing_shapes , not self .config .skip_onnx_checker )
@@ -377,7 +379,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
377
379
** tfl_process_args )
378
380
g = optimizer .optimize_graph (g )
379
381
onnx_feed_dict_without_port = {k .split (':' )[0 ]: v for k , v in onnx_feed_dict .items ()}
380
- onnx_tfl_res = self .run_backend (g , tfl_outputs , onnx_feed_dict_without_port , postfix = "_from_tflite" )
382
+ onnx_tfl_res = self .run_backend (g , tfl_outputs , onnx_feed_dict_without_port ,
383
+ postfix = "_from_tflite" , use_custom_ops = use_custom_ops )
381
384
382
385
self .assert_results_equal (tfl_res , onnx_tfl_res , rtol , atol , mtol , check_value , check_shape , check_dtype )
383
386
self .assert_shapes_correct (g , self .config .allow_missing_shapes , not self .config .skip_onnx_checker )
0 commit comments