
Commit 73db30f

Merge pull request vllm-project#5 from ri938/more_improvements_awq
More improvements awq
2 parents a3ac858 + db4db0c commit 73db30f


5 files changed: +128, -83 lines changed


vllm/awq_quantization/qmodule.py

Lines changed: 62 additions & 15 deletions
@@ -1,33 +1,41 @@
 # adapted from llm-awq: https://github.com/mit-han-lab/llm-awq

-import math
 import torch
 import torch.nn as nn

 try:
     import awq_inference_engine  # with CUDA kernels
 except ImportError as ex:
-    msg = "Unable to import awq_inference_engine: run setup.py to install CUDA kernels"
-    raise ImportError(msg)
+    raise ImportError(
+        "Unable to import awq_inference_engine: run setup.py"
+        " to install AWQ CUDA kernels")


 class ScaledActivation(nn.Module):
     def __init__(self, module, scales):
         super().__init__()
         self.act = module
         self.scales = nn.Parameter(scales.data)
-
+
     def forward(self, x):
         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


 class WQLinear(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(
+        self,
+        w_bit,
+        group_size,
+        in_features,
+        out_features,
+        bias,
+        dev
+    ):
         super().__init__()
-
+
         if w_bit not in [4]:
             raise NotImplementedError("Only 4-bit are supported for now.")
-
+
         self.in_features = in_features
         self.out_features = out_features
         self.w_bit = w_bit
@@ -37,23 +45,62 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         assert self.in_features % self.group_size == 0
         assert out_features % (32 // self.w_bit) == 0

-        self.register_buffer('qweight', torch.empty((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('qzeros', torch.empty((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('scales', torch.empty((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
+        qweight_buffer = torch.empty(
+            (in_features, out_features // (32 // self.w_bit)),
+            dtype=torch.int32,
+            device=dev
+        )
+        self.register_buffer("qweight", qweight_buffer)
+
+        qzeros_buffer = torch.empty(
+            (
+                in_features // self.group_size,
+                out_features // (32 // self.w_bit)
+            ),
+            dtype=torch.int32,
+            device=dev
+        )
+        self.register_buffer("qzeros", qzeros_buffer)
+
+        scales_buffer = torch.empty(
+            (in_features // self.group_size, out_features),
+            dtype=torch.float16,
+            device=dev
+        )
+        self.register_buffer("scales", scales_buffer)

         if bias:
-            self.register_buffer('bias', torch.empty((out_features), dtype=torch.float16, device=dev))
+            bias_buffer = torch.empty(
+                (out_features),
+                dtype=torch.float16,
+                device=dev
+            )
+            self.register_buffer("bias", bias_buffer)
         else:
             self.bias = None

     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features, )
-        out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+
+        out = awq_inference_engine.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]),
+            self.qweight,
+            self.scales,
+            self.qzeros,
+            8
+        )
+
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-
+
     def extra_repr(self) -> str:
-        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
-            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
+        str_repr = "in_features={}, out_features={}, " \
+            "bias={}, w_bit={}, group_size={}"
+        return str_repr.format(
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            self.w_bit,
+            self.group_size
         )
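Note on the buffer shapes above: with `w_bit = 4`, each int32 element of `qweight` and `qzeros` packs 32 // 4 = 8 quantized values, which is why the output dimension is divided by `32 // self.w_bit`. A minimal sketch of the same shape arithmetic (not part of the diff; the layer dimensions are illustrative):

```python
# Sketch only: reproduces the WQLinear buffer shapes without the CUDA extension.
w_bit = 4
pack_factor = 32 // w_bit          # 8 int4 values per int32
in_features, out_features = 4096, 11008
group_size = 128

qweight_shape = (in_features, out_features // pack_factor)               # (4096, 1376), int32
qzeros_shape = (in_features // group_size, out_features // pack_factor)  # (32, 1376), int32
scales_shape = (in_features // group_size, out_features)                 # (32, 11008), float16

print(qweight_shape, qzeros_shape, scales_shape)
```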

vllm/config.py

Lines changed: 3 additions & 2 deletions
@@ -33,7 +33,7 @@ def __init__(
         self._verify()

     def _verify(self) -> None:
-        allowed_methods = ['awq']
+        allowed_methods = ["awq"]
         if self.method not in allowed_methods:
             raise ValueError(
                 f"Unknown quantization method ({self.method})"
@@ -118,7 +118,8 @@ def verify_with_parallel_config(
                 f"({pipeline_parallel_size}).")

         if self.quantization_config and tensor_parallel_size > 1:
-            raise NotImplementedError("Quantization does not currently support tensor parallelism")
+            raise NotImplementedError(
+                "Quantization does not currently support tensor parallelism")

     def get_hidden_size(self) -> int:
         return self.hf_config.hidden_size
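These two checks define the current limits of the feature: only the "awq" method is accepted, and quantization cannot be combined with tensor parallelism. A hedged illustration of the validation behavior (class and argument names follow the diff; the snippet itself is not part of the PR):

```python
# Illustrative only: exercises the method check added in vllm/config.py.
from vllm.config import QuantizationConfig

QuantizationConfig("awq")       # accepted: "awq" is the only allowed method

try:
    QuantizationConfig("gptq")  # rejected: not in allowed_methods
except ValueError as err:
    print(err)
```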

vllm/engine/arg_utils.py

Lines changed: 5 additions & 1 deletion
@@ -152,7 +152,11 @@ def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
-        quantization_config = QuantizationConfig(self.quantization) if self.quantization else None
+        if self.quantization is not None:
+            quantization_config = QuantizationConfig(self.quantization)
+        else:
+            quantization_config = None
+
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.use_np_weights,
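With this refactor, a QuantizationConfig is only built when the engine argument is actually set. A hedged usage sketch (the checkpoint path is a placeholder, and this assumes the `quantization` field on EngineArgs introduced by this PR; building the full configs would also require the model's HF config to be loadable):

```python
# Hedged sketch: how the quantization engine argument feeds the branch above.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="/path/to/awq-llama", quantization="awq")

# create_engine_configs() builds a QuantizationConfig only when the flag is
# given; otherwise model_config.quantization_config stays None.
if engine_args.quantization is not None:
    print(f"quantization enabled: {engine_args.quantization}")
```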

vllm/model_executor/model_loader.py

Lines changed: 4 additions & 1 deletion
@@ -51,7 +51,10 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     # The weights will be initialized as empty tensors.

     if _supports_quantization(model_class):
-        model = model_class(model_config.hf_config, model_config.quantization_config)
+        model = model_class(
+            model_config.hf_config,
+            model_config.quantization_config
+        )
     else:
         model = model_class(model_config.hf_config)
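The quantization config is only forwarded to model classes that can accept it; the `_supports_quantization` helper itself is not shown in this diff. A hypothetical sketch of one plausible shape for such a check, based on the `quant_config` constructor argument used by the Llama model below (not the actual implementation):

```python
# Hypothetical sketch of a _supports_quantization-style check.
import inspect


def _supports_quantization(model_class: type) -> bool:
    # Treat a model as quantization-aware if its constructor accepts a
    # quant_config argument in addition to the HF config.
    params = inspect.signature(model_class.__init__).parameters
    return "quant_config" in params
```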

vllm/model_executor/models/llama.py

Lines changed: 54 additions & 64 deletions
@@ -164,7 +164,7 @@ def __init__(
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
-        assert tp_size == 1, 'quantization does not support TP'
+        assert tp_size == 1, "quantization does not support TP"
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -178,7 +178,7 @@ def __init__(

         self.qkv_proj = get_quantized_layer(
             hidden_size,
-            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
+            self.q_size + 2 * self.kv_size,
             quant_config
         )

@@ -220,8 +220,17 @@ def __init__(
         quant_config: QuantizationConfig
     ):
         super().__init__()
-        self.gate_up_proj = get_quantized_layer(hidden_size, 2 * intermediate_size, quant_config)
-        self.down_proj = get_quantized_layer(intermediate_size, hidden_size, quant_config)
+
+        self.gate_up_proj = get_quantized_layer(
+            hidden_size,
+            2 * intermediate_size, quant_config
+        )
+
+        self.down_proj = get_quantized_layer(
+            intermediate_size,
+            hidden_size,
+            quant_config
+        )

         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -313,9 +322,12 @@ def __init__(self, config: LlamaConfig, quant_config: QuantizationConfig):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
             vocab_size, config.hidden_size, perform_initialization=False)
+
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers)
+            LlamaDecoderLayer(config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
@@ -414,82 +426,60 @@ def load_weights(self,
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

-            is_quantized = self.quant_config is not None and self.quant_config.method is not None
+            is_quantized = self.quant_config is not None

-            # merge linear layers
-            if not is_quantized:
-                is_attention_weight = False
-                for weight_name, shard_size, offset in attention_weight_specs:
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
+            is_attention_weight = False
+            for weight_name, shard_size, offset in attention_weight_specs:
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "qkv_proj")]

+                if not is_quantized:
                     loaded_weight = loaded_weight[
                         shard_size * tensor_model_parallel_rank:shard_size *
                         (tensor_model_parallel_rank + 1)]
                     param_slice = param.data[offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
+                else:
+                    # TODO: this is specific to AWQ
+                    if "qweight" in name or "qzeros" in name:
+                        adjustment = 32 / self.quant_config.bits
+                        shard_size = int(shard_size // adjustment)
+                        offset = int(offset // adjustment)
+                    param_slice = param.data[:, offset:offset + shard_size]
+
+                assert param_slice.shape == loaded_weight.shape

-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
                     continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]

-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if not is_quantized:
                     shard_size = param.shape[0] // 2
                     loaded_weight = loaded_weight[
                         shard_size * tensor_model_parallel_rank:shard_size *
                         (tensor_model_parallel_rank + 1)]
                     param_slice = param.data[shard_size * stride_id:shard_size *
                                              (stride_id + 1)]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
-            else:
-                # TODO: improve this block of code (not DRY, hacky, specific to AWQ)
-                is_attention_weight = False
-                for stride_id, (weight_name, shard_size, offset) in enumerate(attention_weight_specs):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                    # TODO: this is specific to AWQ (should be more general)
-                    if 'qweight' in name or 'qzeros' in name:
-                        shard_size = int(shard_size // (32 / self.quant_config.bits))
-                        offset = int(offset // (32 / self.quant_config.bits))
-
-                    param_slice = param.data[:, offset:offset + shard_size]
-                    assert param_slice.shape == loaded_weight.shape
-
-                    param_slice.copy_(loaded_weight)
-                    is_attention_weight = True
-                    break
-                if is_attention_weight:
-                    continue
-
-                is_gate_up_weight = False
-                for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
-                    if weight_name not in name:
-                        continue
-                    param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                else:
                     shard_size = param.shape[1] // 2
-
-                    start, end = shard_size * stride_id, shard_size * (stride_id + 1)
+                    start = shard_size * stride_id
+                    end = shard_size * (stride_id + 1)
                     param_slice = param.data[:, start:end]
-                    assert param_slice.shape == loaded_weight.shape
-                    param_slice.copy_(loaded_weight)
-                    is_gate_up_weight = True
-                    break
-                if is_gate_up_weight:
-                    continue
+
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue

             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
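The AWQ-specific step in load_weights is the shard adjustment for packed tensors: for `qweight` and `qzeros`, shard sizes and offsets computed in unpacked output units are divided by 32 / bits, and the slice is taken along the last dimension because WQLinear stores weights as (in_features, packed_out_features). A worked example of that arithmetic (numbers are illustrative, not taken from the diff):

```python
# With 4-bit AWQ, eight quantized values share one int32, so column
# offsets/sizes for qweight/qzeros shrink by 32 / bits = 8.
bits = 4
adjustment = 32 / bits                          # 8.0

# Hypothetical q_proj shard inside the fused qkv_proj output dimension:
shard_size, offset = 4096, 0
packed_shard = int(shard_size // adjustment)    # 512 packed int32 columns
packed_offset = int(offset // adjustment)       # 0

# k_proj would then start at unpacked offset 4096 -> packed offset 512, etc.;
# the slice is param.data[:, offset:offset + shard_size] (dim 1, not dim 0).
print(packed_shard, packed_offset)
```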
