diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py
index 55ef52a8e01..3ec2d0de9a2 100644
--- a/neural_compressor/transformers/models/modeling_auto.py
+++ b/neural_compressor/transformers/models/modeling_auto.py
@@ -180,9 +180,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!"
                 quantization_config.update(**{"device": "xpu"})
                 quantization_config.post_init_xpu()
-            if (
-                not torch.cuda.is_available() or device_map == "cpu" or device_map == torch.device("cpu")
-            ) and model.config.model_type == "chatglm":
+            if (device_map == "cpu" or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
                 model = model.float()
             model = convert_to_quantized_model(model, quantization_config, device=device_map)
             if isinstance(quantization_config, AwqConfig):
diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index e66e573e3b2..d6f90804a52 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -223,30 +223,28 @@ def _replace_linear(
                         module.qzeros if hasattr(module, "qzeros") else None,
                         g_idx,
                     )
+                    if not hasattr(module, "qweight"):
+                        n_pack = 32 // quantization_config.bits
+
+                        weight = torch.zeros(
+                            (math.ceil(in_features / n_pack), out_features),
+                            dtype=torch.int32,
+                            device=torch.device(device),
+                        )
+                    model._modules[name].set_weights_bias(
+                        module.qweight.data if hasattr(module, "qweight") else weight,
+                        None if module.bias is None else module.bias.data,
+                    )
                 else:
                     raise Exception("{} device Unsupported weight only quantization!".format(device))

                 is_replaced = True
+                is_removed = True
                 # Store the module class in case we need to transpose the weight later
                 model._modules[name].source_cls = type(module)
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)

-            if device == "xpu" or device == torch.device("xpu"):
-                if not hasattr(module, "qweight"):
-                    n_pack = 32 // quantization_config.bits
-
-                    weight = torch.zeros(
-                        (math.ceil(in_features / n_pack), out_features),
-                        dtype=torch.int32,
-                        device=torch.device(device),
-                    )
-                model._modules[name].set_weights_bias(
-                    module.qweight.data if hasattr(module, "qweight") else weight,
-                    None if module.bias is None else module.bias.data,
-                )
-                is_removed = True
-
         if not is_removed and len(list(module.children())) > 0:  # pylint: disable=E1101
             _, is_replaced = _replace_linear(
                 module,
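For context on the placeholder `qweight` allocated in the second hunk: with `bits`-wide quantization, each 32-bit integer packs `32 // bits` values along the input dimension, so the packed tensor needs only `ceil(in_features / n_pack)` rows. A minimal sketch of that arithmetic, assuming 4-bit quantization; the layer dimensions below are hypothetical, not taken from the patch:

```python
import math

import torch

# Sketch of the placeholder-qweight shape logic, assuming 4-bit quantization.
bits = 4
in_features, out_features = 4096, 11008  # illustrative layer dimensions

# Each int32 packs 32 // bits quantized values, so the packed weight
# needs only ceil(in_features / n_pack) rows per output column.
n_pack = 32 // bits                            # 8 four-bit values per int32
packed_rows = math.ceil(in_features / n_pack)  # 4096 / 8 = 512 rows

placeholder = torch.zeros((packed_rows, out_features), dtype=torch.int32)
print(placeholder.shape)  # torch.Size([512, 11008])
```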