Skip to content

Commit 867e956

Browse files
committed
chore: Add QAT test accuracy script
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 0dd11bd commit 867e956

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

tests/py/test_qat_trt_accuracy.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import unittest
2+
import trtorch
3+
from trtorch.logging import *
4+
import torch
5+
import torch.nn as nn
6+
from torch.nn import functional as F
7+
import torchvision
8+
import torchvision.transforms as transforms
9+
from model_test_case import ModelTestCase
10+
11+
12+
class TestAccuracy(ModelTestCase):
13+
14+
def setUp(self):
15+
self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
16+
train=False,
17+
download=True,
18+
transform=transforms.Compose([
19+
transforms.ToTensor(),
20+
transforms.Normalize((0.4914, 0.4822, 0.4465),
21+
(0.2023, 0.1994, 0.2010))
22+
]))
23+
24+
self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset,
25+
batch_size=16,
26+
shuffle=False,
27+
num_workers=1)
28+
29+
def compute_accuracy(self, testing_dataloader, model):
30+
total = 0
31+
correct = 0
32+
loss = 0.0
33+
class_probs = []
34+
class_preds = []
35+
device = torch.device('cuda:0')
36+
with torch.no_grad():
37+
idx = 0
38+
for data, labels in testing_dataloader:
39+
data, labels = data.to(device), labels.to(device)
40+
out = model(data)
41+
preds = torch.max(out, 1)[1]
42+
class_probs.append([F.softmax(i, dim=0) for i in out])
43+
class_preds.append(preds)
44+
total += labels.size(0)
45+
correct += (preds == labels).sum().item()
46+
idx += 1
47+
48+
test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
49+
test_preds = torch.cat(class_preds)
50+
return correct / total
51+
52+
def test_compile_script(self):
53+
fp32_test_acc = self.compute_accuracy(self.testing_dataloader, self.model)
54+
log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
55+
56+
compile_spec = {
57+
"inputs": [trtorch.Input([16, 3, 32, 32])],
58+
"op_precision": torch.int8,
59+
# "enabled_precision": {torch.float32, torch.int8},
60+
}
61+
62+
trt_mod = trtorch.compile(self.model, compile_spec)
63+
int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
64+
log(Level.Info, "[TRT QAT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
65+
acc_diff = fp32_test_acc - int8_test_acc
66+
self.assertTrue(abs(acc_diff) < 3)
67+
68+
69+
def test_suite():
70+
suite = unittest.TestSuite()
71+
# You need a VGG QAT model trained on CIFAR10 to run this test. Please follow instructions at
72+
# https://github.com/NVIDIA/TRTorch/tree/master/examples/int8/training/vgg16 to export this model.
73+
suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16_qat.jit.pt')))
74+
75+
return suite
76+
77+
78+
suite = test_suite()
79+
80+
runner = unittest.TextTestRunner()
81+
result = runner.run(suite)
82+
83+
exit(int(not result.wasSuccessful()))

0 commit comments

Comments
 (0)