Skip to content

Commit 9dc6061

Browse files
committed
fix(qat): Rescale input data for C++ application
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent b9b7f63 commit 9dc6061

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

examples/int8/datasets/cifar10.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
5050
labels.push_back(label);
5151
auto image_tensor =
5252
torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8))
53-
.to(torch::kF32);
53+
.to(torch::kF32).div(255);
5454
images.push_back(image_tensor);
5555
}
5656

examples/int8/ptq/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,5 @@ int main(int argc, const char* argv[]) {
140140

141141
auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
142142
print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
143+
trt_mod.save("/tmp/ptq_vgg16.trt.ts");
143144
}

examples/int8/qat/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,6 @@ int main(int argc, const char* argv[]) {
124124

125125
auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
126126
print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
127+
trt_mod.save("/tmp/qat_vgg16.trt.ts");
127128
}
128129

0 commit comments

Comments
 (0)