Skip to content

Commit 2d31ac3

Browse files
authored
Add outlier in AWQ test cases (#3106)
1 parent 28612d0 commit 2d31ac3

File tree

1 file changed

+37
-29
lines changed

1 file changed

+37
-29
lines changed

test/prototype/test_awq.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,35 @@
2020

2121

2222
class ToyLinearModel(torch.nn.Module):
23-
def __init__(self, m=512, n=256, k=128):
24-
super().__init__()
25-
self.linear1 = torch.nn.Linear(m, n, bias=False)
26-
self.linear2 = torch.nn.Linear(n, k, bias=False)
27-
self.linear3 = torch.nn.Linear(k, 64, bias=False)
28-
29-
def example_inputs(
30-
self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
23+
def __init__(
24+
self,
25+
m=512,
26+
n=256,
27+
k=128,
28+
dtype=None,
29+
device=None,
3130
):
32-
return [
33-
torch.randn(
34-
1, sequence_length, self.linear1.in_features, dtype=dtype, device=device
35-
)
36-
for j in range(batch_size)
37-
]
31+
super().__init__()
32+
self.dtype = dtype
33+
self.device = device
34+
self.linear1 = torch.nn.Linear(m, n, bias=False, device=device, dtype=dtype)
35+
self.linear2 = torch.nn.Linear(n, k, bias=False, device=device, dtype=dtype)
36+
self.linear3 = torch.nn.Linear(k, 64, bias=False, device=device, dtype=dtype)
37+
38+
def example_inputs(self, batch_size, sequence_length=10):
39+
# For AWQ tests, we intentionally insert some outliers to input features
40+
x = torch.randn(
41+
batch_size,
42+
sequence_length,
43+
self.linear1.in_features,
44+
dtype=self.dtype,
45+
device=self.device,
46+
)
47+
n_outliers = max(1, int(x.size(-1) * 0.1))
48+
# Randomly select outlier features
49+
outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
50+
x[:, :, outlier_indices] *= 10.0
51+
return (x,)
3852

3953
def forward(self, x):
4054
x = self.linear1(x)
@@ -92,14 +106,12 @@ def test_awq_functionality(self, device):
92106
base_configs = device_to_base_configs[device]
93107

94108
for base_config in base_configs:
95-
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
109+
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
96110
m_baseline = copy.deepcopy(m)
97111

98112
dataset = m.example_inputs(
99113
dataset_size,
100114
sequence_length=sequence_length,
101-
dtype=original_dtype,
102-
device=device,
103115
)
104116
# for test, we use calibration_data = dataset so that awq is
105117
# guranteed to be better than baseline
@@ -142,12 +154,10 @@ def test_awq_loading(self, device):
142154
base_configs = device_to_base_configs[device]
143155

144156
for base_config in base_configs:
145-
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
157+
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
146158
dataset = m.example_inputs(
147159
dataset_size,
148160
sequence_length=sequence_length,
149-
dtype=original_dtype,
150-
device=device,
151161
)
152162
# for test purpose, we don't need to get a subset
153163
calibration_data = dataset
@@ -171,9 +181,9 @@ def test_awq_loading(self, device):
171181
f.seek(0)
172182
state_dict = torch.load(f)
173183

174-
loaded_model = (
175-
ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
176-
)
184+
loaded_model = ToyLinearModel(
185+
l1, l2, l3, device=device, dtype=original_dtype
186+
).eval()
177187
loaded_model.load_state_dict(state_dict, assign=True)
178188

179189
m = torch.compile(m, fullgraph=True)
@@ -203,12 +213,10 @@ def test_awq_loading_vllm(self, device):
203213
base_configs = device_to_base_configs[device]
204214

205215
for base_config in base_configs:
206-
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
216+
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
207217
dataset = m.example_inputs(
208218
dataset_size,
209219
sequence_length=sequence_length,
210-
dtype=original_dtype,
211-
device=device,
212220
)
213221
# for test purpose, we don't need to get a subset
214222
calibration_data = dataset
@@ -231,9 +239,9 @@ def test_awq_loading_vllm(self, device):
231239
f.seek(0)
232240
state_dict = torch.load(f)
233241

234-
loaded_model = (
235-
ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
236-
)
242+
loaded_model = ToyLinearModel(
243+
l1, l2, l3, device=device, dtype=original_dtype
244+
).eval()
237245
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
238246
quantize_(loaded_model, quant_config)
239247

0 commit comments

Comments
 (0)