Commit fa90772

feat: [collection] add unit test for complex collection model
Signed-off-by: inocsin <[email protected]>
1 parent ccee7f5 commit fa90772

2 files changed: +102 −1 lines changed

tests/cpp/test_collection.cpp

Lines changed: 78 additions & 0 deletions
@@ -353,4 +353,82 @@ TEST(CppAPITests, TestCollectionListInputOutput) {

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[0].toTensor(), trt_out.toList().vec()[0].toTensor(), 1e-5));
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[1].toTensor(), trt_out.toList().vec()[1].toTensor(), 1e-5));
}

TEST(CppAPITests, TestCollectionComplexModel) {
  std::string path = "/root/Torch-TensorRT/complex_model.ts";
  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
  std::vector<at::Tensor> inputs;
  inputs.push_back(in0);

  torch::jit::Module mod;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    mod = torch::jit::load(path);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
  }
  mod.eval();
  mod.to(torch::kCUDA);

  std::vector<torch::jit::IValue> inputs_;
  for (auto in : inputs) {
    inputs_.push_back(torch::jit::IValue(in.clone()));
  }

  std::vector<torch::jit::IValue> complex_inputs;
  auto input_list = c10::impl::GenericList(c10::TensorType::get());
  input_list.push_back(inputs_[0]);
  input_list.push_back(inputs_[0]);

  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
  complex_inputs.push_back(input_list_ivalue);

  auto out = mod.forward(complex_inputs);
  LOG_DEBUG("Finish torchscript forward");

  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));

  c10::TypePtr elementType = input_shape_ivalue.type();
  auto list = c10::impl::GenericList(elementType);
  list.push_back(input_shape_ivalue);
  list.push_back(input_shape_ivalue);

  torch::jit::IValue complex_input_shape(list);
  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
  torch::jit::IValue complex_input_shape2(input_tuple2);

  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
  compile_settings.require_full_compilation = false;
  compile_settings.min_block_size = 1;

  // Need to skip the conversion of __getitem__ and ListConstruct
  compile_settings.torch_executed_ops.push_back("aten::__getitem__");
  compile_settings.torch_executed_ops.push_back("prim::ListConstruct");

  // FP16 execution
  compile_settings.enabled_precisions = {torch::kHalf};
  // Compile module
  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
  LOG_DEBUG("Finish compile");
  auto trt_out = trt_mod.forward(complex_inputs);

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
}
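
For readers tracing the compile-spec construction above, here is a condensed sketch of the same nested-input pattern pulled into a standalone helper. The helper name build_list_input_spec is hypothetical; the snippet only reuses API calls already exercised by this test, and it assumes a module whose forward takes a single List[torch.Tensor] argument holding two half-precision tensors of one static shape.

#include <torch/script.h>
#include <torch_tensorrt/torch_tensorrt.h>

// Hypothetical helper: builds a CompileSpec for forward(self, z: List[torch.Tensor])
// where z holds two FP16 tensors of the given static shape.
torch_tensorrt::ts::CompileSpec build_list_input_spec(std::vector<int64_t> shape) {
  // One Input spec per tensor element of the list argument.
  auto input_spec = torch_tensorrt::Input(shape, torch_tensorrt::DataType::kHalf);
  auto input_spec_ivalue = torch::jit::IValue(c10::make_intrusive<torch_tensorrt::Input>(input_spec));

  // Group the per-tensor specs into a list that mirrors List[torch.Tensor].
  auto spec_list = c10::impl::GenericList(input_spec_ivalue.type());
  spec_list.push_back(input_spec_ivalue);
  spec_list.push_back(input_spec_ivalue);

  // The full signature is a one-element tuple whose only member is that list.
  torch::jit::IValue list_ivalue(spec_list);
  std::tuple<torch::jit::IValue> signature(list_ivalue);

  auto settings = torch_tensorrt::ts::CompileSpec(torch::jit::IValue(signature));
  settings.enabled_precisions = {torch::kHalf};
  // The collection ops stay in Torch, so partial compilation must be allowed.
  settings.require_full_compilation = false;
  settings.min_block_size = 1;
  settings.torch_executed_ops.push_back("aten::__getitem__");
  settings.torch_executed_ops.push_back("prim::ListConstruct");
  return settings;
}

In the test above, the equivalent spec is passed to torch_tensorrt::torchscript::compile together with the loaded module, and the outputs are then compared element-wise against the TorchScript reference.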

tests/py/test_collection.py

Lines changed: 24 additions & 1 deletion
@@ -79,6 +79,22 @@ def forward(self, z: List[torch.Tensor]):
        r = [r1, r2]
        return r

class ComplexModel(nn.Module):
    def __init__(self):
        super(ComplexModel, self).__init__()
        self.list_model = ListInputOutput()
        self.tuple_model = TupleInputOutput()

    def forward(self, z: List[torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = z[0] - z[1]
        r3 = (r1, r2)
        r4 = [r2, r1]
        tuple_out = self.tuple_model(r3)
        list_out = self.list_model(r4)
        r = (tuple_out[1], list_out[0])
        return r

input_data = torch.randn((16, 3, 32, 32))
input_data = input_data.float().to("cuda")

@@ -115,4 +131,11 @@ def forward(self, z: List[torch.Tensor]):
print(list_input_ts.graph)
result = list_input_ts([input_data, input_data])
list_input_ts.to("cuda").eval()
torch.jit.save(list_input_ts, "./list_input_output.ts")

complex_model = ComplexModel()
complex_model_ts = torch.jit.script(complex_model)
print(complex_model_ts.graph)
result = complex_model_ts([input_data, input_data])
complex_model_ts.to("cuda").eval()
torch.jit.save(complex_model_ts, "./complex_model.ts")

0 commit comments
