@@ -10,9 +10,10 @@ TEST(CppAPITests, TestCollectionTupleInput) {
10
10
11
11
std::string path =
12
12
" /root/Torch-TensorRT/tuple_input.ts" ;
13
- torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kFloat );
14
- std::vector<at::Tensor> inputs;
15
- inputs.push_back (in0);
13
+ // torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kFloat);
14
+ torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
15
+ // std::vector<at::Tensor> inputs;
16
+ // inputs.push_back(in0);
16
17
17
18
torch::jit::Module mod;
18
19
try {
@@ -23,13 +24,13 @@ TEST(CppAPITests, TestCollectionTupleInput) {
23
24
}
24
25
mod.eval ();
25
26
mod.to (torch::kCUDA );
26
-
27
27
28
- std::vector<torch::jit::IValue> inputs_;
29
28
30
- for (auto in : inputs) {
31
- inputs_.push_back (torch::jit::IValue (in.clone ()));
32
- }
29
+ // std::vector<torch::jit::IValue> inputs_;
30
+
31
+ // for (auto in : inputs) {
32
+ // inputs_.push_back(torch::jit::IValue(in.clone()));
33
+ // }
33
34
34
35
35
36
std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
@@ -42,16 +43,12 @@ TEST(CppAPITests, TestCollectionTupleInput) {
42
43
// torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
43
44
44
45
complex_inputs.push_back (input_tuple);
45
- // complex_inputs_list.push_back(in0);
46
- // complex_inputs_list.push_back(in0);
47
-
48
-
49
46
50
47
auto out = mod.forward (complex_inputs);
51
48
LOG_DEBUG (" Finish torchscirpt forward" );
52
49
53
-
54
- auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kUnknown );
50
+ // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
51
+ auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kHalf );
55
52
56
53
auto input_shape_ivalue = torch::jit::IValue (std::move (c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
57
54
@@ -63,7 +60,6 @@ TEST(CppAPITests, TestCollectionTupleInput) {
63
60
64
61
std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple (input_shape_ivalue, input_shape_ivalue);
65
62
66
-
67
63
torch::jit::IValue complex_input_shape (input_shape_tuple);
68
64
std::tuple<torch::jit::IValue> input_tuple2 (complex_input_shape);
69
65
torch::jit::IValue complex_input_shape2 (input_tuple2);
@@ -74,13 +70,12 @@ TEST(CppAPITests, TestCollectionTupleInput) {
74
70
compile_settings.min_block_size = 1 ;
75
71
76
72
// // FP16 execution
77
- // compile_settings.enabled_precisions = {torch::kHalf};
73
+ compile_settings.enabled_precisions = {torch::kHalf };
78
74
// // Compile module
79
75
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
80
76
LOG_DEBUG (" Finish compile" );
81
77
auto trt_out = trt_mod.forward (complex_inputs);
82
- // auto trt_out = trt_mod.forward(complex_inputs_list);
83
-
78
+ // std::cout << out.toTensor() << std::endl;
84
79
85
80
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTensor (), trt_out.toTensor (), 1e-5 ));
86
81
}
@@ -90,7 +85,7 @@ TEST(CppAPITests, TestCollectionNormalInput) {
90
85
91
86
std::string path =
92
87
" /root/Torch-TensorRT/normal_model.ts" ;
93
- torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kFloat );
88
+ torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
94
89
std::vector<at::Tensor> inputs;
95
90
inputs.push_back (in0);
96
91
inputs.push_back (in0);
@@ -116,14 +111,14 @@ TEST(CppAPITests, TestCollectionNormalInput) {
116
111
LOG_DEBUG (" Finish torchscirpt forward" );
117
112
118
113
std::vector<torch_tensorrt::Input> input_range;
119
- input_range.push_back ({in0.sizes (), torch::kF32 });
120
- input_range.push_back ({in0.sizes (), torch::kF32 });
114
+ input_range.push_back ({in0.sizes (), torch::kF16 });
115
+ input_range.push_back ({in0.sizes (), torch::kF16 });
121
116
torch_tensorrt::ts::CompileSpec compile_settings (input_range);
122
117
compile_settings.require_full_compilation = true ;
123
118
compile_settings.min_block_size = 1 ;
124
119
125
120
// // FP16 execution
126
- // compile_settings.enabled_precisions = {torch::kHalf};
121
+ compile_settings.enabled_precisions = {torch::kHalf };
127
122
// // Compile module
128
123
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
129
124
LOG_DEBUG (" Finish compile" );
@@ -138,7 +133,7 @@ TEST(CppAPITests, TestCollectionListInput) {
138
133
139
134
std::string path =
140
135
" /root/Torch-TensorRT/list_input.ts" ;
141
- torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kFloat );
136
+ torch::Tensor in0 = torch::randn ({1 , 3 , 512 , 512 }, torch::kCUDA ).to (torch::kHalf );
142
137
std::vector<at::Tensor> inputs;
143
138
inputs.push_back (in0);
144
139
@@ -173,7 +168,8 @@ TEST(CppAPITests, TestCollectionListInput) {
173
168
LOG_DEBUG (" Finish torchscirpt forward" );
174
169
175
170
176
- auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kUnknown );
171
+ // auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kUnknown);
172
+ auto input_shape = torch_tensorrt::Input (in0.sizes (), torch_tensorrt::DataType::kHalf );
177
173
178
174
auto input_shape_ivalue = torch::jit::IValue (std::move (c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
179
175
@@ -194,13 +190,13 @@ TEST(CppAPITests, TestCollectionListInput) {
194
190
compile_settings.torch_executed_ops .push_back (" aten::__getitem__" );
195
191
196
192
// // FP16 execution
197
- // compile_settings.enabled_precisions = {torch::kHalf};
193
+ compile_settings.enabled_precisions = {torch::kHalf };
198
194
// // Compile module
199
195
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
200
196
LOG_DEBUG (" Finish compile" );
201
197
auto trt_out = trt_mod.forward (complex_inputs);
202
198
// auto trt_out = trt_mod.forward(complex_inputs_list);
203
199
204
-
200
+ // std::cout << out.toTensor() << std::endl;
205
201
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (out.toTensor (), trt_out.toTensor (), 1e-5 ));
206
202
}
0 commit comments