Skip to content

Commit 0e0f8fd

Browse files
asl3pytorchmergebot
authored andcommitted
Implement QAT for APoT (pytorch#83282)
### Summary: This PR implements QAT for APoT FakeQuant. It runs QAT with FX graph mode quantized models (Resnet-18 pre-trained model, full ImageNet dataset) to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activation and weight. It also refactors the APoT PTQ module `apot_fx_graph_mode_ptq.py` (previously `fx_graph_mode_apot.py`) such that shared helper functions between PTQ and QAT are in a separate file `quantization_util.py`. Model #2 (uniformly quantized activation, APoT quantized weight) shows comparable accuracy compared to model #1 (uniformly quantized activation, APoT quantized weight) for 8-bit and significant accuracy improvement for 4-bit (see "Accuracy Stats" section below). ### Test Plan: Run QAT models with: `python test/quantization/core/experimental/apot_qat.py` Run PTQ models with: `python test/quantization/core/experimental/apot_ptq.py` ### Accuracy Stats 8-bit (Uniform int8, APoT b = 8 k = 2) Model #1: Uniform activation, uniform weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 69.67% (Top-1), 89.04% (Top-5) Model #2: Uniform activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 69.72% (Top-1), 89.06% (Top-5) 4-bit (Uniform int4, APoT b = 4 k = 2) Model #1: Uniform activation, uniform weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 46.85% (Top-1), 72.85% (Top-5) Model #2: Uniform activation, APoT weight (FX Graph Mode quantized) Evaluation accuracy on test dataset: 66.45% (Top-1), 86.23% (Top-5) Pull Request resolved: pytorch#83282 Approved by: https://github.com/jerryzh168
1 parent 2ca721c commit 0e0f8fd

File tree

3 files changed

+248
-129
lines changed

3 files changed

+248
-129
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.quantization
4+
from torchvision.models.quantization.resnet import resnet18
5+
from torch.ao.quantization.experimental.quantization_helper import (
6+
evaluate,
7+
prepare_data_loaders
8+
)
9+
10+
# validation dataset: full ImageNet dataset
11+
data_path = '~/my_imagenet/'
12+
13+
data_loader, data_loader_test = prepare_data_loaders(data_path)
14+
criterion = nn.CrossEntropyLoss()
15+
float_model = resnet18(pretrained=True)
16+
float_model.eval()
17+
18+
# deepcopy the model since we need to keep the original model around
19+
import copy
20+
model_to_quantize = copy.deepcopy(float_model)
21+
22+
model_to_quantize.eval()
23+
24+
"""
25+
Prepare models
26+
"""
27+
28+
# Note that this is temporary, we'll expose these functions to torch.quantization after official releasee
29+
from torch.quantization.quantize_fx import prepare_qat_fx
30+
31+
def calibrate(model, data_loader):
32+
model.eval()
33+
with torch.no_grad():
34+
for image, target in data_loader:
35+
model(image)
36+
37+
from torch.ao.quantization.experimental.qconfig import (
38+
uniform_qconfig_8bit,
39+
apot_weights_qconfig_8bit,
40+
apot_qconfig_8bit,
41+
uniform_qconfig_4bit,
42+
apot_weights_qconfig_4bit,
43+
apot_qconfig_4bit
44+
)
45+
46+
"""
47+
Prepare full precision model
48+
"""
49+
full_precision_model = float_model
50+
51+
top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
52+
print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
53+
54+
"""
55+
Prepare model PTQ for specified qconfig for torch.nn.Linear
56+
"""
57+
def prepare_ptq_linear(qconfig):
58+
qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
59+
prepared_model = prepare_qat_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers
60+
calibrate(prepared_model, data_loader_test) # run calibration on sample data
61+
return prepared_model
62+
63+
"""
64+
Prepare model with uniform activation, uniform weight
65+
b=8, k=2
66+
"""
67+
68+
prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
69+
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
70+
71+
top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
72+
print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
73+
74+
"""
75+
Prepare model with uniform activation, uniform weight
76+
b=4, k=2
77+
"""
78+
79+
prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
80+
quantized_model = convert_fx(prepared_model) # convert the calibrated model to a quantized model
81+
82+
top1, top5 = evaluate(quantized_model1, criterion, data_loader_test)
83+
print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
84+
85+
"""
86+
Prepare model with uniform activation, APoT weight
87+
(b=8, k=2)
88+
"""
89+
90+
prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)
91+
92+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
93+
print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
94+
95+
"""
96+
Prepare model with uniform activation, APoT weight
97+
(b=4, k=2)
98+
"""
99+
100+
prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)
101+
102+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
103+
print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
104+
105+
106+
"""
107+
Prepare model with APoT activation and weight
108+
(b=8, k=2)
109+
"""
110+
111+
prepared_model = prepare_ptq_linear(apot_qconfig_8bit)
112+
113+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
114+
print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
115+
116+
"""
117+
Prepare model with APoT activation and weight
118+
(b=4, k=2)
119+
"""
120+
121+
prepared_model = prepare_ptq_linear(apot_qconfig_4bit)
122+
123+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
124+
print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
125+
126+
"""
127+
Prepare eager mode quantized model
128+
"""
129+
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
130+
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
131+
print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from torchvision.models.quantization.resnet import resnet18
2+
from torch.ao.quantization.experimental.quantization_helper import (
3+
evaluate,
4+
prepare_data_loaders,
5+
training_loop
6+
)
7+
8+
# training and validation dataset: full ImageNet dataset
9+
data_path = '~/my_imagenet/'
10+
11+
train_batch_size = 30
12+
eval_batch_size = 50
13+
14+
data_loader, data_loader_test = prepare_data_loaders(data_path)
15+
criterion = nn.CrossEntropyLoss()
16+
float_model = resnet18(pretrained=True)
17+
float_model.eval()
18+
19+
# deepcopy the model since we need to keep the original model around
20+
import copy
21+
model_to_quantize = copy.deepcopy(float_model)
22+
23+
model_to_quantize.eval()
24+
25+
"""
26+
Prepare model QAT for specified qconfig for torch.nn.Linear
27+
"""
28+
def prepare_qat_linear(qconfig):
29+
qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
30+
prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict) # fuse modules and insert observers
31+
training_loop(prepared_model, criterion, data_loader)
32+
prepared_model.eval()
33+
return prepared_model
34+
35+
"""
36+
Prepare model with uniform activation, uniform weight
37+
b=8, k=2
38+
"""
39+
40+
prepared_model = prepare_qat_linear(uniform_qconfig_8bit)
41+
42+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
43+
print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
44+
45+
"""
46+
Prepare model with uniform activation, uniform weight
47+
b=4, k=2
48+
"""
49+
50+
prepared_model = prepare_qat_linear(uniform_qconfig_4bit)
51+
52+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
53+
print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
54+
55+
"""
56+
Prepare model with uniform activation, APoT weight
57+
(b=8, k=2)
58+
"""
59+
60+
prepared_model = prepare_qat_linear(apot_weights_qconfig_8bit)
61+
62+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
63+
print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
64+
65+
"""
66+
Prepare model with uniform activation, APoT weight
67+
(b=4, k=2)
68+
"""
69+
70+
prepared_model = prepare_qat_linear(apot_weights_qconfig_4bit)
71+
72+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
73+
print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
74+
75+
76+
"""
77+
Prepare model with APoT activation and weight
78+
(b=8, k=2)
79+
"""
80+
81+
prepared_model = prepare_qat_linear(apot_qconfig_8bit)
82+
83+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
84+
print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))
85+
86+
"""
87+
Prepare model with APoT activation and weight
88+
(b=4, k=2)
89+
"""
90+
91+
prepared_model = prepare_qat_linear(apot_qconfig_4bit)
92+
93+
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
94+
print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

0 commit comments

Comments
 (0)