
Commit 3b03131

[FIX] not pack when group_size=-1 (#1298)

* Fix skipping pack() when group_size = -1
* assert len(qModules) > 0
* Update __init__.py
* Update __init__.py

Co-authored-by: Qubitium-ModelCloud <[email protected]>

1 parent 4aa3520 · commit 3b03131

4 files changed: +24 -35 lines changed

examples/quantization/basic_usage_wikitext2.py

Lines changed: 0 additions & 3 deletions

@@ -68,9 +68,6 @@ def main():
     # with value under torch.LongTensor type.
     model.quantize(traindataset)
 
-    # save quantized model
-    model.save(quantized_model_id)
-
     # save quantized model using safetensors
     model.save(quantized_model_id)
 

gptqmodel/models/definitions/minicpm.py

Lines changed: 0 additions & 1 deletion

@@ -29,5 +29,4 @@ class MiniCPMGPTQ(BaseGPTQModel):
         ["self_attn.v_proj"],
         ["self_attn.o_proj"],
         ["mlp.gate_proj", "mlp.up_proj","mlp.down_proj"],
-        ["mlp.c_proj"],
     ]

gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -205,7 +205,8 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym:
         if bits not in cls.SUPPORTS_BITS:
             err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`"
             return False, NotImplementedError(err)
-        if group_size not in cls.SUPPORTS_GROUP_SIZE:
+        # valid group size is set of cls.SUPPORTS_GROUP_SIZE + in_features; group_size = -1 is alias for group_size == in_features
+        if group_size not in cls.SUPPORTS_GROUP_SIZE and group_size != in_features:
             err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
             return False, NotImplementedError(err)
         if sym not in cls.SUPPORTS_SYM:

@@ -365,3 +366,5 @@ def pack(self, linear, scales, zeros, g_idx=None):
                 col += 1
 
         self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
+
+        # print("self qw", self.qweight, self.scales, self.qzeros)

gptqmodel/utils/model.py

Lines changed: 20 additions & 30 deletions

@@ -156,7 +156,7 @@ def make_quant(
     device: DEVICE = None,
     from_quantized: bool = False,
     pack_dtype: torch.dtype = None,
-) -> BaseQuantLinear:
+) -> Type[BaseQuantLinear]:
     # returns multiple validated kernels
     quant_linear_candidates = select_quant_linear(
         bits=bits,

@@ -175,15 +175,15 @@ def make_quant(
     logger.info(f"make_quant: Linear candidates: {quant_linear_candidates}")
 
     # loop over actual QLinear init, catch errors and use fallbacks if applicable
-    for linear in quant_linear_candidates:
+    for cls in quant_linear_candidates:
         try:
             # if linear is not selectedQLinear:
             #     logger.info(f"make_quant: Faild linear: `{selectedQLinear}` failed, trying to use fallback: `{linear}`")
             # else:
             #     logger.info("make_quant: Testing linear: {linear}")
 
-            linear_instance = create_quant_layer(
-                linear=linear,
+            linear_cls = create_quant_layer(
+                linear_cls=cls,
                 bits=bits,
                 desc_act=desc_act,
                 dynamic=dynamic,

@@ -194,10 +194,10 @@ def make_quant(
                 device=device,
                 lm_head_name=lm_head_name,
                 pack_dtype=pack_dtype)
-            logger.info(f"make_quant: Selected linear: `{linear}`.")
-            return linear_instance
+            logger.info(f"make_quant: Selected linear: `{cls}`.")
+            return linear_cls
         except NotImplementedError as e:
-            logger.info(f"make_quant: Skipped linear: `{linear}`.")
+            logger.info(f"make_quant: Skipped linear: `{cls}`. ")
             # only fallback to other quant linears when backend is auto.
             if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]:
                 raise e
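
The hunks above also change what make_quant returns: the selected kernel class itself rather than an instance, found by trying each candidate and falling back on NotImplementedError (and only falling back when the backend is AUTO). A self-contained sketch of that selection pattern, with stand-in kernel classes in place of BaseQuantLinear subclasses:

# Sketch of the candidate-fallback selection make_quant now performs;
# KernelA/KernelB are stand-ins, not GPTQModel's kernel classes.
from typing import List, Type

class KernelA:
    @classmethod
    def validate(cls) -> None:
        # pretend this kernel rejects the quant config
        raise NotImplementedError(f"{cls.__name__} does not support this config")

class KernelB:
    @classmethod
    def validate(cls) -> None:
        pass  # this kernel accepts the config

def select_first_valid(candidates: List[Type]) -> Type:
    for cls in candidates:
        try:
            cls.validate()
            # return the class itself (not an instance) so callers such as
            # pack_model can later filter modules by this exact type
            return cls
        except NotImplementedError:
            continue  # fall back to the next candidate kernel
    raise ValueError("no compatible quant linear found")

assert select_first_valid([KernelA, KernelB]) is KernelB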
@@ -206,7 +206,7 @@
 
 
 def create_quant_layer(
-    linear: nn.Module,
+    linear_cls: Type[BaseQuantLinear],
     bits: int,
     desc_act: bool,
     dynamic,

@@ -216,10 +216,9 @@ def create_quant_layer(
     sym: bool,
     device: DEVICE,
     lm_head_name: str,
-    pack_dtype: torch.dtype,
-) -> BaseQuantLinear:
-    if isinstance(module, linear):
-        return linear
+    pack_dtype: torch.dtype) -> Type[BaseQuantLinear]:
+    if isinstance(module, linear_cls):
+        return linear_cls
     for name, submodule in module.named_modules():
         if name in names:
             ori_layer_device = next(submodule.parameters()).device

@@ -266,7 +265,7 @@ def create_quant_layer(
 
             # when loading a quantized model, device is target device passed in GPTQModel.load()
             # check in_features and out_features validate
-            _, err = linear.validate(
+            _, err = linear_cls.validate(
                 bits=tmp_bits,
                 group_size=tmp_group_size,
                 desc_act=tmp_desc_act,

@@ -278,7 +277,7 @@ def create_quant_layer(
             if err is not None:
                 raise err
 
-            new_layer = linear(
+            new_layer = linear_cls(
                 bits=tmp_bits,
                 group_size=tmp_group_size,
                 desc_act=tmp_desc_act,

@@ -293,7 +292,7 @@ def create_quant_layer(
             )
             new_layer.device = ori_layer_device
             recurse_setattr(module, name, new_layer.to(ori_layer_device))
-            return linear
+            return linear_cls
 
 # public/stable api exposed to transformer/optimum
 def hf_convert_gptq_v1_to_v2_format(
@@ -489,25 +488,13 @@ def pack_model(
     parallel_packing: bool = True,
     pack_dtype: torch.dtype = None,
 ):
-    quantLinear = select_quant_linear(
-        bits=bits,
-        dynamic=dynamic,
-        group_size=group_size,
-        desc_act=desc_act,
-        sym=sym,
-        backend=backend,
-        format=format,
-        pack=True,
-        pack_dtype=pack_dtype,
-    )
-
     model.to(CPU)
 
     logger.info("Packing model...")
 
     modules = find_modules(model)
     modules = {n: modules[n] for n in quantizers}
-    make_quant(
+    quant_linear_cls = make_quant(
         model,
         quantizers,
         bits,

@@ -520,7 +507,10 @@ def pack_model(
         dynamic=dynamic,
         pack_dtype=pack_dtype,
     )
-    qModules = find_modules(model, [quantLinear])
+    qModules = find_modules(model, [quant_linear_cls])
+
+    assert len(qModules) > 0, f"No quantizeed modules[{quant_linear_cls}] found in the model."
+
     names = list(qModules.keys())
 
     if parallel_packing:

@@ -537,7 +527,7 @@ def wrapper(name):
         pass
 
     logger.info("Model packed.")
-    return quantLinear
+    return quant_linear_cls
 
 
 def verify_model_hash(file_path: str, verify_hash: str):
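
Taken together, the pack_model hunks remove the separate select_quant_linear call, which could choose a different kernel class than the one make_quant actually installed; find_modules would then match nothing and packing was silently skipped. Reusing make_quant's returned class and asserting the module map is non-empty closes that hole. A self-contained illustration of the failure mode; FakeQuantLinear and this find_modules are stand-ins, not GPTQModel's own implementations:

# Illustration of the silent no-op this commit guards against.
import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    pass  # stand-in for a BaseQuantLinear kernel class

def find_modules(model: nn.Module, layer_types: list) -> dict:
    # collect submodules matching any of the given types, keyed by name
    return {name: m for name, m in model.named_modules()
            if isinstance(m, tuple(layer_types))}

model = nn.Sequential(FakeQuantLinear(16, 16))

# filtering by the class make_quant actually used finds the module...
qModules = find_modules(model, [FakeQuantLinear])
assert len(qModules) > 0, f"No quantized modules [{FakeQuantLinear}] found in the model."

# ...while filtering by any other class matches nothing, which previously
# let pack_model fall through without packing a single layer
assert len(find_modules(model, [nn.Conv2d])) == 0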
