 #include "torch/script.h"
 #include "torch/torch.h"
-#include "trtorch/ptq.h"
 #include "trtorch/trtorch.h"

 #include "NvInfer.h"
@@ -28,23 +27,26 @@ struct Resize : public torch::data::transforms::TensorTransform<torch::Tensor> {
   std::vector<int64_t> new_size_;
 };

-torch::jit::Module compile_int8_qat_model(torch::jit::Module& mod) {
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::jit::Module& mod) {
+
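+  /// A single static input: a batch of 32 3x32x32 CIFAR10 images, explicitly FP32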
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   /// Set operating precision to INT8
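+  /// FP16 kernels could additionally be allowed by uncommenting the next line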
+  // compile_spec.enabled_precisions.insert(torch::kF16);
   compile_spec.enabled_precisions.insert(torch::kI8);
   /// Set max batch size for the engine
   compile_spec.max_batch_size = 32;
   /// Set a larger workspace
   compile_spec.workspace_size = 1 << 28;

-  mod.eval();
-
 #ifdef SAVE_ENGINE
-  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
+  std::cout << "Compiling graph to save as TRT engine (/tmp/int8_engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif
@@ -71,62 +71,56 @@ int main(int argc, const char* argv[]) {
     return -1;
   }

-  /// Convert the model using TensorRT
-  auto trt_mod = compile_int8_qat_model(mod);
-  std::cout << "Model conversion to TensorRT completed." << std::endl;
-  /// Dataloader moved into calibrator so need another for inference
+  mod.eval();
+
+  /// Create the calibration dataset
   const std::string data_dir = std::string(argv[2]);
+
+  /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
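+                          /// Evaluate on a 3200-image subset of the test set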
+                          .use_subset(3200)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
   auto eval_dataloader = torch::data::make_data_loader(
       std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));

   /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
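+  /// Keep the running counts on the GPU; read them back on the host only once at the end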
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);

     auto outputs = mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));

-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%"
-            << " correct: " << correct << " total: " << total << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
+
+  /// Compile Graph
+  auto trt_mod = compile_int8_qat_model(data_dir, mod);

   /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);

-    if (images.sizes()[0] < 32) {
-      /// To handle smaller batches until Optimization profiles work with Int8
-      auto diff = 32 - images.sizes()[0];
-      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
-      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
-      images = torch::cat({images, img_padding}, 0);
-      targets = torch::cat({targets, target_padding}, 0);
-    }
-
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
     predictions = predictions.reshape(predictions.sizes()[0]);

-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches until Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
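+    /// Accumulate on the GPU, mirroring the JIT accuracy loop above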
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%"
-            << " correct: " << correct << " total: " << total << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;

   /// Time execution in JIT-FP32 and TRT-INT8
   std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};