@@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
         )
         self.assertListEqual(hist_keys, ref_keys)
 
-    @parameterized.named_parameters(
-        ("tf_saved_model", "tf_saved_model"),
-        ("onnx", "onnx"),
-    )
-    @pytest.mark.skipif(
-        backend.backend() not in ("tensorflow", "jax", "torch"),
-        reason=(
-            "Currently, `Model.export` only supports the tensorflow, jax and "
-            "torch backends."
-        ),
-    )
-    @pytest.mark.skipif(
-        testing.jax_uses_gpu(), reason="Leads to core dumps on CI"
-    )
-    def test_export(self, export_format):
-        if export_format == "tf_saved_model" and testing.torch_uses_gpu():
-            self.skipTest("Leads to core dumps on CI")
-
-        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
-        model = _get_model()
-        x1 = np.random.rand(1, 3).astype("float32")
-        x2 = np.random.rand(1, 3).astype("float32")
-        ref_output = model([x1, x2])
-
-        model.export(temp_filepath, format=export_format)
-
-        if export_format == "tf_saved_model":
-            import tensorflow as tf
-
-            revived_model = tf.saved_model.load(temp_filepath)
-            self.assertAllClose(ref_output, revived_model.serve([x1, x2]))
-
-            # Test with a different batch size
-            if backend.backend() == "torch":
-                # TODO: Dynamic shape is not supported yet in the torch backend
-                return
-            revived_model.serve(
-                [
-                    np.concatenate([x1, x1], axis=0),
-                    np.concatenate([x2, x2], axis=0),
-                ]
-            )
-        elif export_format == "onnx":
-            import onnxruntime
-
-            ort_session = onnxruntime.InferenceSession(temp_filepath)
-            ort_inputs = {
-                k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2])
-            }
-            self.assertAllClose(
-                ref_output, ort_session.run(None, ort_inputs)[0]
-            )
-
-            # Test with a different batch size
-            if backend.backend() == "torch":
-                # TODO: Dynamic shape is not supported yet in the torch backend
-                return
-            ort_inputs = {
-                k.name: v
-                for k, v in zip(
-                    ort_session.get_inputs(),
-                    [
-                        np.concatenate([x1, x1], axis=0),
-                        np.concatenate([x2, x2], axis=0),
-                    ],
-                )
-            }
-            ort_session.run(None, ort_inputs)
-
     def test_export_error(self):
         temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
         model = _get_model()