
Commit 72398b6

fix xpu device set weight and bias (#2010)
Signed-off-by: changwangss <[email protected]>
Co-authored-by: Sun, Xuehao <[email protected]>

1 parent 9d27743 · commit 72398b6

2 files changed: +14, -18 lines


neural_compressor/transformers/models/modeling_auto.py

Lines changed: 1 addition & 3 deletions

@@ -180,9 +180,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!"
             quantization_config.update(**{"device": "xpu"})
             quantization_config.post_init_xpu()
-        if (
-            not torch.cuda.is_available() or device_map == "cpu" or device_map == torch.device("cpu")
-        ) and model.config.model_type == "chatglm":
+        if (device_map == "cpu" or device_map == torch.device("cpu")) and model.config.model_type == "chatglm":
             model = model.float()
         model = convert_to_quantized_model(model, quantization_config, device=device_map)
         if isinstance(quantization_config, AwqConfig):
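What this hunk changes: the pre-fix condition also matched whenever CUDA was unavailable, so on a CUDA-less, XPU-only machine a chatglm model was cast to float32 even though it was headed for the XPU path. The fix keys the cast to an explicit CPU device_map alone. A minimal sketch of the two predicates (standalone illustration only; the should_float_* names are hypothetical, and the real check additionally requires model.config.model_type == "chatglm"):

import torch

def should_float_old(device_map) -> bool:
    # Pre-fix predicate: also fires on any machine without CUDA (e.g. XPU-only hosts).
    return not torch.cuda.is_available() or device_map == "cpu" or device_map == torch.device("cpu")

def should_float_new(device_map) -> bool:
    # Post-fix predicate: fires only when the user explicitly targets CPU.
    return device_map == "cpu" or device_map == torch.device("cpu")

# On a CUDA-less, XPU-capable host:
print(should_float_old("xpu"))  # True  -- chatglm would wrongly be cast to float32
print(should_float_new("xpu"))  # False -- the XPU placement is left alone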

neural_compressor/transformers/quantization/utils.py

Lines changed: 13 additions & 15 deletions
@@ -223,30 +223,28 @@ def _replace_linear(
                     module.qzeros if hasattr(module, "qzeros") else None,
                     g_idx,
                 )
+                if not hasattr(module, "qweight"):
+                    n_pack = 32 // quantization_config.bits
+
+                    weight = torch.zeros(
+                        (math.ceil(in_features / n_pack), out_features),
+                        dtype=torch.int32,
+                        device=torch.device(device),
+                    )
+                model._modules[name].set_weights_bias(
+                    module.qweight.data if hasattr(module, "qweight") else weight,
+                    None if module.bias is None else module.bias.data,
+                )
             else:
                 raise Exception("{} device Unsupported weight only quantization!".format(device))
 
             is_replaced = True
+            is_removed = True
             # Store the module class in case we need to transpose the weight later
             model._modules[name].source_cls = type(module)
             # Force requires grad to False to avoid unexpected errors
             model._modules[name].requires_grad_(False)
 
-        if device == "xpu" or device == torch.device("xpu"):
-            if not hasattr(module, "qweight"):
-                n_pack = 32 // quantization_config.bits
-
-                weight = torch.zeros(
-                    (math.ceil(in_features / n_pack), out_features),
-                    dtype=torch.int32,
-                    device=torch.device(device),
-                )
-            model._modules[name].set_weights_bias(
-                module.qweight.data if hasattr(module, "qweight") else weight,
-                None if module.bias is None else module.bias.data,
-            )
-            is_removed = True
-
         if not is_removed and len(list(module.children())) > 0:  # pylint: disable=E1101
             _, is_replaced = _replace_linear(
                 module,
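Two things happen in this hunk: the placeholder-weight logic and the set_weights_bias call move inside the XPU replacement branch, so they run right where the quantized linear is built rather than in a trailing block, and is_removed is now set alongside is_replaced, which keeps the recursion guard below (if not is_removed ...) from descending into a module that was just replaced. The placeholder shape itself is bit-packing arithmetic: with bits-bit values, n_pack = 32 // bits of them fit in one int32 word, so the packed weight needs ceil(in_features / n_pack) rows. A standalone sketch of that arithmetic (hypothetical helper name and layer sizes, not part of neural-compressor):

import math
import torch

def packed_qweight_shape(in_features: int, out_features: int, bits: int = 4):
    """Shape of an int32 buffer holding `bits`-bit values packed along in_features."""
    n_pack = 32 // bits  # quantized values per int32 word (8 for INT4)
    return (math.ceil(in_features / n_pack), out_features)

# e.g. a 4096x4096 linear layer quantized to 4 bits:
shape = packed_qweight_shape(4096, 4096, bits=4)
placeholder = torch.zeros(shape, dtype=torch.int32)  # all-zero stand-in qweight
print(shape)                    # (512, 4096)
print(placeholder.numel() * 4)  # 8388608 bytes, vs 67108864 for the fp32 weight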
