Skip to content

Commit b5d311d

Browse files
committed
move dequantize_weight() to PackableQuantLinear
Signed-off-by: ZX-ModelCloud <[email protected]>
1 parent 9b90b67 commit b5d311d

File tree

2 files changed

+61
-61
lines changed

2 files changed

+61
-61
lines changed

gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from gptqmodel.adapter.adapter import LORA_MERGED_WEIGHT_PATHS, Adapter
2626

2727
from ...models._const import DEVICE, PLATFORM
28+
from ...utils.torch import torch_compile
2829

2930

3031
class BaseQuantLinear(nn.Module):
@@ -420,3 +421,63 @@ def pack(self, linear, scales, zeros, g_idx=None):
420421
col += 1
421422

422423
self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
424+
425+
def dequantize_weight(self, num_itr: int=1):
426+
if self.bits in [2, 4, 8]:
427+
zeros = t.bitwise_right_shift(
428+
t.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor),
429+
self.wf.unsqueeze(0),
430+
).to(self.dequant_dtype)
431+
zeros = t.bitwise_and(zeros, self.maxq).reshape(self.scales.shape)
432+
433+
weight = t.bitwise_and(
434+
t.bitwise_right_shift(
435+
t.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1),
436+
self.wf.unsqueeze(-1),
437+
).to(self.dequant_dtype),
438+
self.maxq
439+
)
440+
elif self.bits == 3:
441+
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(
442+
-1, -1, -1, 12
443+
)
444+
zeros = zeros >> self.wf.unsqueeze(0)
445+
zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
446+
zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
447+
zeros = zeros & 0x7
448+
zeros = t.cat(
449+
[zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]],
450+
dim=2,
451+
).reshape(self.scales.shape)
452+
453+
weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(
454+
-1, -1, 12, -1
455+
)
456+
weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
457+
weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
458+
weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
459+
weight = weight & 0x7
460+
weight = t.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
461+
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
462+
463+
if num_itr == 1:
464+
weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])
465+
else:
466+
num_dim = self.g_idx.shape[0] // num_itr
467+
weights = []
468+
for i in range(num_itr):
469+
scale_i = self.scales[:, i * num_dim: (i + 1) * num_dim]
470+
weight_i = weight[:, i * num_dim: (i + 1) * num_dim]
471+
zeros_i = zeros[:, i * num_dim: (i + 1) * num_dim]
472+
g_idx_i = self.g_idx[i * num_dim: (i + 1) * num_dim].long()
473+
weights.append(scale_i[g_idx_i] * (weight_i - zeros_i[g_idx_i]))
474+
weights = t.cat(weights, dim=1)
475+
476+
return weights
477+
478+
def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False):
479+
# compile dequantize
480+
self.dequantize_weight = torch_compile(self.dequantize_weight, backend=backend, mode=mode, fullgraph=fullgraph)
481+
482+
#if self.adapter:
483+
# self.adapter.g_compile(backend=backend, mode=mode, fullgraph=fullgraph)

gptqmodel/nn_modules/qlinear/torch.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from transformers import PreTrainedModel
2626

2727
from ...models._const import DEVICE, PLATFORM
28-
from ...utils.torch import torch_compile
2928

3029
logger = setup_logger()
3130

@@ -113,13 +112,6 @@ def post_init(self):
113112

114113
self.wf = self.wf.to(device=self.qweight.device)
115114

116-
def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False):
117-
# compile dequantize
118-
self.dequantize_weight = torch_compile(self.dequantize_weight, backend=backend, mode=mode, fullgraph=fullgraph)
119-
120-
#if self.adapter:
121-
# self.adapter.g_compile(backend=backend, mode=mode, fullgraph=fullgraph)
122-
123115
def forward(self, x: torch.Tensor):
124116
if x.size(-1) != self.padded_infeatures:
125117
x = F.pad(x, (0, self.padded_infeatures - self.in_features))
@@ -150,59 +142,6 @@ def _empty_gptq_only_weights(self):
150142
self.g_idx = None
151143
self.scales = None
152144

153-
def dequantize_weight(self, num_itr: int=1):
154-
if self.bits in [2, 4, 8]:
155-
zeros = torch.bitwise_right_shift(
156-
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor),
157-
self.wf.unsqueeze(0),
158-
).to(self.dequant_dtype)
159-
zeros = torch.bitwise_and(zeros, self.maxq).reshape(self.scales.shape)
160-
161-
weight = torch.bitwise_and(
162-
torch.bitwise_right_shift(
163-
torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1),
164-
self.wf.unsqueeze(-1),
165-
).to(self.dequant_dtype),
166-
self.maxq
167-
)
168-
elif self.bits == 3:
169-
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(
170-
-1, -1, -1, 12
171-
)
172-
zeros = zeros >> self.wf.unsqueeze(0)
173-
zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
174-
zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
175-
zeros = zeros & 0x7
176-
zeros = torch.cat(
177-
[zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]],
178-
dim=2,
179-
).reshape(self.scales.shape)
180-
181-
weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(
182-
-1, -1, 12, -1
183-
)
184-
weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
185-
weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
186-
weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
187-
weight = weight & 0x7
188-
weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
189-
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
190-
191-
if num_itr == 1:
192-
weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])
193-
else:
194-
num_dim = self.g_idx.shape[0] // num_itr
195-
weights = []
196-
for i in range(num_itr):
197-
scale_i = self.scales[:, i * num_dim: (i + 1) * num_dim]
198-
weight_i = weight[:, i * num_dim: (i + 1) * num_dim]
199-
zeros_i = zeros[:, i * num_dim: (i + 1) * num_dim]
200-
g_idx_i = self.g_idx[i * num_dim: (i + 1) * num_dim].long()
201-
weights.append(scale_i[g_idx_i] * (weight_i - zeros_i[g_idx_i]))
202-
weights = torch.cat(weights, dim=1)
203-
204-
return weights
205-
206145
def dequantize_model(model: PreTrainedModel):
207146
for name, module in model.named_modules():
208147
if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchQuantLinear):

0 commit comments

Comments
 (0)