
Commit 74bbd10

chore: Rebase qat main.cpp

Signed-off-by: Dheeraj Peri <[email protected]>

1 parent: cc19809

File tree

1 file changed

cpp/int8/qat/main.cpp

Lines changed: 27 additions & 36 deletions
@@ -1,6 +1,5 @@
 #include "torch/script.h"
 #include "torch/torch.h"
-#include "trtorch/ptq.h"
 #include "trtorch/trtorch.h"
 
 #include "NvInfer.h"
@@ -28,23 +27,24 @@ struct Resize : public torch::data::transforms::TensorTransform<torch::Tensor> {
   std::vector<int64_t> new_size_;
 };
 
-torch::jit::Module compile_int8_qat_model(torch::jit::Module& mod) {
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::jit::Module& mod) {
+
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   /// Set operating precision to INT8
+  // compile_spec.enabled_precisions.insert(torch::kF16);
   compile_spec.enabled_precisions.insert(torch::kI8);
   /// Set max batch size for the engine
   compile_spec.max_batch_size = 32;
   /// Set a larger workspace
   compile_spec.workspace_size = 1 << 28;
 
-  mod.eval();
-
 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif
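
The hunk above cuts off at #endif, so the compile-and-return tail of compile_int8_qat_model is not visible in this diff. A minimal sketch of how the function presumably continues, assuming the trtorch 0.x CompileGraph entry point; the tail below (including leaving data_dir unused) is an illustration rather than code from this commit:

// Hypothetical continuation -- the hunk ends at #endif, so everything past
// the spec setup is an assumption based on the trtorch 0.x API.
torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::jit::Module& mod) {
  std::vector<trtorch::CompileSpec::Input> inputs = {
      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
  auto compile_spec = trtorch::CompileSpec(inputs);
  compile_spec.enabled_precisions.insert(torch::kI8); // a QAT graph already carries its Q/DQ scales
  compile_spec.max_batch_size = 32;
  compile_spec.workspace_size = 1 << 28;

  // Presumed tail: lower the TorchScript module to a TensorRT-backed one.
  // data_dir goes unused here; a QAT model needs no calibration dataset,
  // which is consistent with the removal of trtorch/ptq.h in the first hunk.
  return trtorch::CompileGraph(mod, compile_spec);
}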
@@ -71,62 +71,53 @@ int main(int argc, const char* argv[]) {
     return -1;
   }
 
-  /// Convert the model using TensorRT
-  auto trt_mod = compile_int8_qat_model(mod);
-  std::cout << "Model conversion to TensorRT completed." << std::endl;
-  /// Dataloader moved into calibrator so need another for inference
+  mod.eval();
+
+  /// Create the calibration dataset
   const std::string data_dir = std::string(argv[2]);
+
+  /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
   auto eval_dataloader = torch::data::make_data_loader(
       std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));
 
   /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
     auto outputs = mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
 
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%"
-            << " correct: " << correct << " total: " << total << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
+
+  /// Compile Graph
+  auto trt_mod = compile_int8_qat_model(data_dir, mod);
 
   /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
-    if (images.sizes()[0] < 32) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      auto diff = 32 - images.sizes()[0];
-      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
-      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
-      images = torch::cat({images, img_padding}, 0);
-      targets = torch::cat({targets, target_padding}, 0);
-    }
-
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
     predictions = predictions.reshape(predictions.sizes()[0]);
 
-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches util Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
   }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%"
-            << " correct: " << correct << " total: " << total << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;
 
   /// Time execution in JIT-FP32 and TRT-INT8
   std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
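
The final context line introduces the timing step, whose body lies beyond this hunk. A minimal sketch of what a JIT-FP32 vs. TRT-INT8 timing loop over the 32x3x32x32 batch could look like; the helper name benchmark_ms and the run counts are illustrative, and it assumes torch::cuda::synchronize() is available for accurate GPU timing:

// Illustrative helper, not part of the commit: times an average forward pass.
#include <chrono>

#include "torch/cuda.h"

float benchmark_ms(torch::jit::Module& mod, const std::vector<int64_t>& shape, int runs = 100) {
  auto in = torch::randn(shape, {torch::kCUDA});
  for (int i = 0; i < 10; i++) { // warm-up so lazy initialization is excluded
    mod.forward({in});
  }
  torch::cuda::synchronize(); // forward() launches CUDA work asynchronously
  auto start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < runs; i++) {
    mod.forward({in});
  }
  torch::cuda::synchronize();
  auto end = std::chrono::high_resolution_clock::now();
  return std::chrono::duration<float, std::milli>(end - start).count() / runs;
}

Calling benchmark_ms(mod, {32, 3, 32, 32}) and benchmark_ms(trt_mod, {32, 3, 32, 32}) would then report the per-batch latency of the two backends side by side.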
