examples/quantization/basic_usage_wikitext2.py (3 changes: 0 additions & 3 deletions)
@@ -68,9 +68,6 @@ def main():
     # with value under torch.LongTensor type.
     model.quantize(traindataset)

-    # save quantized model
-    model.save(quantized_model_id)
-
     # save quantized model using safetensors
     model.save(quantized_model_id)
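The removed lines were an exact duplicate of the save call that follows them, so the example now quantizes and saves once. For reference, a minimal sketch of the surrounding flow; model.quantize() and model.save() appear in the example itself, while the GPTQModel.load() reload at the end is an assumption based on the loader referenced later in this PR:

# quantize with calibration data (tokenized examples with torch.LongTensor values)
model.quantize(traindataset)

# save once; the checkpoint is serialized with safetensors
model.save(quantized_model_id)

# hypothetical reload of the quantized checkpoint for inference
model = GPTQModel.load(quantized_model_id)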
gptqmodel/models/definitions/minicpm.py (1 change: 0 additions & 1 deletion)
@@ -29,5 +29,4 @@ class MiniCPMGPTQ(BaseGPTQModel):
         ["self_attn.v_proj"],
         ["self_attn.o_proj"],
         ["mlp.gate_proj", "mlp.up_proj","mlp.down_proj"],
-        ["mlp.c_proj"],
     ]
gptqmodel/nn_modules/qlinear/__init__.py (5 changes: 4 additions & 1 deletion)
@@ -205,7 +205,8 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym:
         if bits not in cls.SUPPORTS_BITS:
            err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`"
            return False, NotImplementedError(err)
-        if group_size not in cls.SUPPORTS_GROUP_SIZE:
+        # valid group sizes are cls.SUPPORTS_GROUP_SIZE plus in_features; group_size = -1 is an alias for group_size == in_features
+        if group_size not in cls.SUPPORTS_GROUP_SIZE and group_size != in_features:
            err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
            return False, NotImplementedError(err)
         if sym not in cls.SUPPORTS_SYM:
@@ -365,3 +366,5 @@ def pack(self, linear, scales, zeros, g_idx=None):
                 col += 1

         self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))
+
+        # print("self qw", self.qweight, self.scales, self.qzeros)
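The new _validate branch accepts group_size == in_features (one quantization group spanning the whole input dimension, i.e. channel-wise) even when it is absent from SUPPORTS_GROUP_SIZE, with -1 as the conventional alias for that case. A standalone sketch of the rule, assuming an illustrative SUPPORTS_GROUP_SIZE list; validate_group_size is not part of the library:

SUPPORTS_GROUP_SIZE = [32, 64, 128]

def validate_group_size(group_size: int, in_features: int) -> bool:
    # -1 is an alias for "one group across the full input dim", i.e. group_size == in_features
    if group_size == -1:
        group_size = in_features
    return group_size in SUPPORTS_GROUP_SIZE or group_size == in_features

assert validate_group_size(-1, 4096)      # alias resolves to in_features -> valid
assert validate_group_size(4096, 4096)    # explicit in_features -> valid
assert not validate_group_size(48, 4096)  # unsupported group size -> rejected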
gptqmodel/utils/model.py (50 changes: 20 additions & 30 deletions)
@@ -156,7 +156,7 @@ def make_quant(
     device: DEVICE = None,
     from_quantized: bool = False,
     pack_dtype: torch.dtype = None,
-) -> BaseQuantLinear:
+) -> Type[BaseQuantLinear]:
     # returns multiple validated kernels
     quant_linear_candidates = select_quant_linear(
         bits=bits,
@@ -175,15 +175,15 @@
     logger.info(f"make_quant: Linear candidates: {quant_linear_candidates}")

     # loop over actual QLinear init, catch errors and use fallbacks if applicable
-    for linear in quant_linear_candidates:
+    for cls in quant_linear_candidates:
         try:
             # if linear is not selectedQLinear:
             #     logger.info(f"make_quant: Failed linear: `{selectedQLinear}`, trying to use fallback: `{linear}`")
             # else:
             #     logger.info("make_quant: Testing linear: {linear}")

-            linear_instance = create_quant_layer(
-                linear=linear,
+            linear_cls = create_quant_layer(
+                linear_cls=cls,
                 bits=bits,
                 desc_act=desc_act,
                 dynamic=dynamic,
@@ -194,10 +194,10 @@
                 device=device,
                 lm_head_name=lm_head_name,
                 pack_dtype=pack_dtype)
-            logger.info(f"make_quant: Selected linear: `{linear}`.")
-            return linear_instance
+            logger.info(f"make_quant: Selected linear: `{cls}`.")
+            return linear_cls
         except NotImplementedError as e:
-            logger.info(f"make_quant: Skipped linear: `{linear}`.")
+            logger.info(f"make_quant: Skipped linear: `{cls}`.")
             # only fallback to other quant linears when backend is auto.
             if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]:
                 raise e
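The loop above attempts real layer construction for each validated kernel candidate, returns the first class that initializes, and only falls through on NotImplementedError when the backend is AUTO or AUTO_TRAINABLE; any other backend re-raises immediately. A compact, hypothetical sketch of that selection pattern (select_with_fallback and build are illustrative stand-ins, not library functions):

def select_with_fallback(candidates, build, auto: bool):
    # try each candidate kernel class in priority order
    for cls in candidates:
        try:
            build(cls)   # attempt the real QLinear construction for this kernel
            return cls   # first class that initializes cleanly wins
        except NotImplementedError:
            if not auto: # only AUTO backends may fall back to the next candidate
                raise
    return None          # no candidate survived validation/initialization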
@@ -206,7 +206,7 @@


 def create_quant_layer(
-    linear: nn.Module,
+    linear_cls: Type[BaseQuantLinear],
     bits: int,
     desc_act: bool,
     dynamic,
@@ -216,10 +216,9 @@
     sym: bool,
     device: DEVICE,
     lm_head_name: str,
-    pack_dtype: torch.dtype,
-) -> BaseQuantLinear:
-    if isinstance(module, linear):
-        return linear
+    pack_dtype: torch.dtype) -> Type[BaseQuantLinear]:
+    if isinstance(module, linear_cls):
+        return linear_cls
     for name, submodule in module.named_modules():
         if name in names:
             ori_layer_device = next(submodule.parameters()).device
@@ -266,7 +265,7 @@

             # when loading a quantized model, device is target device passed in GPTQModel.load()
             # validate in_features and out_features
-            _, err = linear.validate(
+            _, err = linear_cls.validate(
                 bits=tmp_bits,
                 group_size=tmp_group_size,
                 desc_act=tmp_desc_act,
@@ -278,7 +277,7 @@
             if err is not None:
                 raise err

-            new_layer = linear(
+            new_layer = linear_cls(
                 bits=tmp_bits,
                 group_size=tmp_group_size,
                 desc_act=tmp_desc_act,
@@ -293,7 +292,7 @@
             )
             new_layer.device = ori_layer_device
             recurse_setattr(module, name, new_layer.to(ori_layer_device))
-    return linear
+    return linear_cls

 # public/stable api exposed to transformer/optimum
 def hf_convert_gptq_v1_to_v2_format(
@@ -489,25 +488,13 @@ def pack_model(
     parallel_packing: bool = True,
     pack_dtype: torch.dtype = None,
 ):
-    quantLinear = select_quant_linear(
-        bits=bits,
-        dynamic=dynamic,
-        group_size=group_size,
-        desc_act=desc_act,
-        sym=sym,
-        backend=backend,
-        format=format,
-        pack=True,
-        pack_dtype=pack_dtype,
-    )
-
     model.to(CPU)

     logger.info("Packing model...")

     modules = find_modules(model)
     modules = {n: modules[n] for n in quantizers}
-    make_quant(
+    quant_linear_cls = make_quant(
         model,
         quantizers,
         bits,
@@ -520,7 +507,10 @@
         dynamic=dynamic,
         pack_dtype=pack_dtype,
     )
-    qModules = find_modules(model, [quantLinear])
+    qModules = find_modules(model, [quant_linear_cls])
+
+    assert len(qModules) > 0, f"No quantized modules [{quant_linear_cls}] found in the model."
+
     names = list(qModules.keys())

     if parallel_packing:
@@ -537,7 +527,7 @@ def wrapper(name):
         pass

     logger.info("Model packed.")
-    return quantLinear
+    return quant_linear_cls


 def verify_model_hash(file_path: str, verify_hash: str):
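Taken together, pack_model no longer pre-selects a kernel via select_quant_linear; it lets make_quant both create the quantized layers and report the kernel class it settled on, then asserts that class actually landed in the model before packing. A condensed sketch of the new control flow, with trimmed argument lists and an assumed (quantizer, scale, zero, g_idx) layout for the quantizers values:

def pack_model_sketch(model, quantizers, **quant_args):
    # restrict packing to the layers that were actually quantized
    modules = find_modules(model)
    modules = {n: modules[n] for n in quantizers}

    # make_quant swaps nn.Linear targets for quant layers and returns the chosen kernel class
    quant_linear_cls = make_quant(model, quantizers, **quant_args)

    q_modules = find_modules(model, [quant_linear_cls])
    assert len(q_modules) > 0, f"No quantized modules [{quant_linear_cls}] found in the model."

    for name, qlayer in q_modules.items():
        _, scale, zero, g_idx = quantizers[name]      # assumed tuple layout
        qlayer.pack(modules[name], scale, zero, g_idx=g_idx)

    return quant_linear_cls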