@@ -223,30 +223,28 @@ def _replace_linear(
                             module.qzeros if hasattr(module, "qzeros") else None,
                             g_idx,
                         )
+                        if not hasattr(module, "qweight"):
+                            n_pack = 32 // quantization_config.bits
+
+                            weight = torch.zeros(
+                                (math.ceil(in_features / n_pack), out_features),
+                                dtype=torch.int32,
+                                device=torch.device(device),
+                            )
+                        model._modules[name].set_weights_bias(
+                            module.qweight.data if hasattr(module, "qweight") else weight,
+                            None if module.bias is None else module.bias.data,
+                        )
                     else:
                         raise Exception("{} device Unsupported weight only quantization!".format(device))

                     is_replaced = True
+                    is_removed = True
                     # Store the module class in case we need to transpose the weight later
                     model._modules[name].source_cls = type(module)
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)

-        if device == "xpu" or device == torch.device("xpu"):
-            if not hasattr(module, "qweight"):
-                n_pack = 32 // quantization_config.bits
-
-                weight = torch.zeros(
-                    (math.ceil(in_features / n_pack), out_features),
-                    dtype=torch.int32,
-                    device=torch.device(device),
-                )
-            model._modules[name].set_weights_bias(
-                module.qweight.data if hasattr(module, "qweight") else weight,
-                None if module.bias is None else module.bias.data,
-            )
-            is_removed = True
-
         if not is_removed and len(list(module.children())) > 0:  # pylint: disable=E1101
             _, is_replaced = _replace_linear(
                 module,
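For context on the placeholder weight built above: with `bits`-wide quantized values, `32 // bits` of them pack into a single `int32`, so an `in_features × out_features` weight needs `math.ceil(in_features / n_pack)` packed rows. A minimal sketch of that shape arithmetic (the concrete sizes below are made up for illustration; in the commit they come from `quantization_config` and the module being replaced):

import math
import torch

# Illustrative values only, not taken from the commit.
bits = 4
in_features = 4099   # deliberately not a multiple of n_pack
out_features = 11008

n_pack = 32 // bits                      # 8 four-bit values fit in one int32 word
rows = math.ceil(in_features / n_pack)   # ragged last row gets padded

placeholder = torch.zeros((rows, out_features), dtype=torch.int32)
print(n_pack, rows, tuple(placeholder.shape))   # 8 513 (513, 11008)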