@@ -353,4 +353,82 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
353
353
354
354
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toList ().vec ()[0 ].toTensor (), trt_out.toList ().vec ()[0 ].toTensor (), 1e-5 ));
355
355
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toList ().vec ()[1 ].toTensor (), trt_out.toList ().vec ()[1 ].toTensor (), 1e-5 ));
356
+ }
357
+
358
+
359
+ TEST (CppAPITests, TestCollectionComplexModel) {
360
+
361
+ std::string path =
362
+ " /root/Torch-TensorRT/complex_model.ts" ;
363
+ torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
364
+ std::vector<at::Tensor> inputs;
365
+ inputs.push_back (in0);
366
+
367
+ torch::jit::Module mod;
368
+ try {
369
+ // Deserialize the ScriptModule from a file using torch::jit::load().
370
+ mod = torch::jit::load (path);
371
+ } catch (const c10::Error& e) {
372
+ std::cerr << " error loading the model\n " ;
373
+ }
374
+ mod.eval ();
375
+ mod.to (torch::kCUDA );
376
+
377
+
378
+ std::vector<torch::jit::IValue> inputs_;
379
+
380
+ for (auto in : inputs) {
381
+ inputs_.push_back (torch::jit::IValue (in.clone ()));
382
+ }
383
+
384
+ std::vector<torch::jit::IValue> complex_inputs;
385
+ auto input_list = c10::impl::GenericList (c10::TensorType::get ());
386
+ input_list.push_back (inputs_[0 ]);
387
+ input_list.push_back (inputs_[0 ]);
388
+
389
+ torch::jit::IValue input_list_ivalue = torch::jit::IValue (input_list);
390
+
391
+ complex_inputs.push_back (input_list_ivalue);
392
+
393
+
394
+ auto out = mod.forward (complex_inputs);
395
+ LOG_DEBUG (" Finish torchscirpt forward" );
396
+
397
+
398
+ // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
399
+ auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kHalf );
400
+
401
+ auto input_shape_ivalue = torch::jit::IValue (std::move (c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
402
+
403
+
404
+ c10::TypePtr elementType = input_shape_ivalue.type ();
405
+ auto list = c10::impl::GenericList (elementType);
406
+ list.push_back (input_shape_ivalue);
407
+ list.push_back (input_shape_ivalue);
408
+
409
+
410
+ torch::jit::IValue complex_input_shape (list);
411
+ std::tuple<torch::jit::IValue> input_tuple2 (complex_input_shape);
412
+ torch::jit::IValue complex_input_shape2 (input_tuple2);
413
+
414
+ auto compile_settings = torch_tensorrt::ts::CompileSpec (complex_input_shape2);
415
+ compile_settings.require_full_compilation = false ;
416
+ compile_settings.min_block_size = 1 ;
417
+
418
+ // Need to skip the conversion of __getitem__ and ListConstruct
419
+ compile_settings.torch_executed_ops .push_back (" aten::__getitem__" );
420
+ compile_settings.torch_executed_ops .push_back (" prim::ListConstruct" );
421
+
422
+ // // FP16 execution
423
+ compile_settings.enabled_precisions = {torch::kHalf };
424
+ // // Compile module
425
+ auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
426
+ LOG_DEBUG (" Finish compile" );
427
+ auto trt_out = trt_mod.forward (complex_inputs);
428
+ // auto trt_out = trt_mod.forward(complex_inputs_list);
429
+
430
+ // std::cout << out.toTuple()->elements()[0].toTensor() << std::endl;
431
+
432
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor (), 1e-5 ));
433
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor (), 1e-5 ));
356
434
}
0 commit comments