|
25 | 25 | from transformers import PreTrainedModel |
26 | 26 |
|
27 | 27 | from ...models._const import DEVICE, PLATFORM |
28 | | -from ...utils.torch import torch_compile |
29 | 28 |
|
30 | 29 | logger = setup_logger() |
31 | 30 |
|
@@ -113,13 +112,6 @@ def post_init(self): |
113 | 112 |
|
114 | 113 | self.wf = self.wf.to(device=self.qweight.device) |
115 | 114 |
|
def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False):
    """Compile this layer's dequantization path with ``torch.compile``.

    Rebinds ``self.dequantize_weight`` to a compiled wrapper produced by the
    project's ``torch_compile`` helper, so subsequent forward passes run the
    compiled version.

    Args:
        backend: torch.compile backend name (default "inductor").
        mode: torch.compile mode string, or None for the default mode.
        fullgraph: whether to require a single full graph capture.
    """
    # compile dequantize
    # NOTE: this replaces the bound method with the compiled callable;
    # callers keep using self.dequantize_weight(...) unchanged.
    self.dequantize_weight = torch_compile(self.dequantize_weight, backend=backend, mode=mode, fullgraph=fullgraph)

    # Adapter compilation intentionally disabled — kept for reference.
    #if self.adapter:
    #    self.adapter.g_compile(backend=backend, mode=mode, fullgraph=fullgraph)
123 | 115 | def forward(self, x: torch.Tensor): |
124 | 116 | if x.size(-1) != self.padded_infeatures: |
125 | 117 | x = F.pad(x, (0, self.padded_infeatures - self.in_features)) |
@@ -150,59 +142,6 @@ def _empty_gptq_only_weights(self): |
150 | 142 | self.g_idx = None |
151 | 143 | self.scales = None |
152 | 144 |
|
def dequantize_weight(self, num_itr: int = 1):
    """Unpack the packed integer weights and zeros and dequantize to real values.

    Supports 2/4/8-bit packing (power-of-two bits packed along one int32 lane)
    and the special 3-bit layout (three int32 rows carry 32 3-bit values, with
    values straddling word boundaries stitched together below).

    Args:
        num_itr: number of column chunks to dequantize separately (1 = whole
            tensor at once). Chunking presumably bounds peak memory — TODO confirm.

    Returns:
        Dequantized weight tensor ``scales[g_idx] * (weight - zeros[g_idx])``.
    """
    if self.bits in [2, 4, 8]:
        # Unpack zeros: expand each packed word into pack_factor lanes, shift
        # each lane by its bit offset (self.wf), then mask to `bits` bits.
        zeros = torch.bitwise_right_shift(
            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor),
            self.wf.unsqueeze(0),
        ).to(self.dequant_dtype)
        zeros = torch.bitwise_and(zeros, self.maxq).reshape(self.scales.shape)

        # Unpack weights the same way, but expanding along the row axis.
        weight = torch.bitwise_and(
            torch.bitwise_right_shift(
                torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1),
                self.wf.unsqueeze(-1),
            ).to(self.dequant_dtype),
            self.maxq
        )
    elif self.bits == 3:
        # 3-bit layout: groups of 3 packed words hold 32 values; each word is
        # expanded into 12 shift positions, then the two values that straddle
        # word boundaries are reassembled from bits of adjacent words.
        zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(
            -1, -1, -1, 12
        )
        zeros = zeros >> self.wf.unsqueeze(0)
        # Stitch the value split across word 0 -> word 1 (low 2 bits + 1 bit).
        zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
        # Stitch the value split across word 1 -> word 2 (low 1 bit + 2 bits).
        zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
        zeros = zeros & 0x7
        # Keep 11 + 11 + 10 = 32 valid values per 3-word group.
        zeros = torch.cat(
            [zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]],
            dim=2,
        ).reshape(self.scales.shape)

        # Same unpack-and-stitch for the weights, along the row axis.
        weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(
            -1, -1, 12, -1
        )
        weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
        weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
        weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
        weight = weight & 0x7
        weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    if num_itr == 1:
        # Group-wise dequantize: g_idx maps each row to its quant group.
        weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])
    else:
        # Chunked variant: dequantize num_itr column slices and concatenate.
        num_dim = self.g_idx.shape[0] // num_itr
        weights = []
        for i in range(num_itr):
            scale_i = self.scales[:, i * num_dim: (i + 1) * num_dim]
            weight_i = weight[:, i * num_dim: (i + 1) * num_dim]
            zeros_i = zeros[:, i * num_dim: (i + 1) * num_dim]
            g_idx_i = self.g_idx[i * num_dim: (i + 1) * num_dim].long()
            weights.append(scale_i[g_idx_i] * (weight_i - zeros_i[g_idx_i]))
        weights = torch.cat(weights, dim=1)

    return weights
206 | 145 | def dequantize_model(model: PreTrainedModel): |
207 | 146 | for name, module in model.named_modules(): |
208 | 147 | if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchQuantLinear): |
|
0 commit comments