@@ -77,6 +77,15 @@ def convert_idx(self, g_idx, k):
     ret_idx[g_idx_2] = torch.arange(k).to(g_idx.device)
     return ret_idx.to(torch.int32)
 
+if HAS_IPEX:
+    try:
+        # monkey patch GPTQShuffle.convert_idx to use fixed convert_idx, fix the slow ipex generate issue
+        from intel_extension_for_pytorch.nn.utils._quantize_convert import GPTQShuffle
+
+        GPTQShuffle.convert_idx = convert_idx
+    except ImportError:
+        # if import GPTQShuffle failed, do nothing
+        pass
 
 class IPEXQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [4]
@@ -170,14 +179,6 @@ def post_init(self):
 
     def init_ipex_linear(self, x: torch.Tensor):
         if not self.training and HAS_IPEX and not x.requires_grad:
-            try:
-                # monkey patch GPTQShuffle.convert_idx to use fixed convert_idx, fix the slow ipex generate issue
-                from intel_extension_for_pytorch.nn.utils._quantize_convert import GPTQShuffle
-                GPTQShuffle.convert_idx = convert_idx
-            except ImportError:
-                # if import GPTQShuffle failed, do nothing
-                pass
-
             self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros,
                                                                          self.infeatures, self.outfeatures, None, self.bias,
                                                                          self.group_size, self.g_idx, quant_method=QuantMethod.GPTQ_GEMM, dtype=QuantDtype.INT4)
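
The diff hoists the monkey patch out of the per-call `init_ipex_linear` path and applies it once at module import time under `if HAS_IPEX:`, so `GPTQShuffle.convert_idx` is replaced a single time instead of on every inference call. Below is a minimal, self-contained sketch of that class-level monkey-patch pattern; `ShuffleTarget` and `fixed_convert_idx` are hypothetical stand-ins for `GPTQShuffle` and the module's fixed `convert_idx`, not the real implementations.

```python
import torch


class ShuffleTarget:
    """Hypothetical stand-in for GPTQShuffle, used only to illustrate the pattern."""

    def convert_idx(self, g_idx, k):
        # stand-in for the original (slower) method that the patch replaces
        return g_idx[:k].to(torch.int32)


def fixed_convert_idx(self, g_idx, k):
    # stand-in for the fixed convert_idx defined earlier in the patched module
    return torch.arange(k, dtype=torch.int32, device=g_idx.device)


# Bind the replacement onto the class once, at import time; every instance
# created afterwards resolves convert_idx to the patched function, so nothing
# has to be re-patched inside init_ipex_linear on each call.
ShuffleTarget.convert_idx = fixed_convert_idx

shuffle = ShuffleTarget()
print(shuffle.convert_idx(torch.tensor([2, 0, 1, 3]), 3))  # tensor([0, 1, 2], dtype=torch.int32)
```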