@@ -156,7 +156,7 @@ def make_quant(
156156 device : DEVICE = None ,
157157 from_quantized : bool = False ,
158158 pack_dtype : torch .dtype = None ,
159- ) -> BaseQuantLinear :
159+ ) -> Type [ BaseQuantLinear ] :
160160 # returns multiple validated kernels
161161 quant_linear_candidates = select_quant_linear (
162162 bits = bits ,
@@ -175,15 +175,15 @@ def make_quant(
175175 logger .info (f"make_quant: Linear candidates: { quant_linear_candidates } " )
176176
177177 # loop over actual QLinear init, catch errors and use fallbacks if applicable
178- for linear in quant_linear_candidates :
178+ for cls in quant_linear_candidates :
179179 try :
180180 # if linear is not selectedQLinear:
181181 # logger.info(f"make_quant: Failed linear: `{selectedQLinear}` failed, trying to use fallback: `{linear}`")
182182 # else:
183183 # logger.info("make_quant: Testing linear: {linear}")
184184
185- linear_instance = create_quant_layer (
186- linear = linear ,
185+ linear_cls = create_quant_layer (
186+ linear_cls = cls ,
187187 bits = bits ,
188188 desc_act = desc_act ,
189189 dynamic = dynamic ,
@@ -194,10 +194,10 @@ def make_quant(
194194 device = device ,
195195 lm_head_name = lm_head_name ,
196196 pack_dtype = pack_dtype )
197- logger .info (f"make_quant: Selected linear: `{ linear } `." )
198- return linear_instance
197+ logger .info (f"make_quant: Selected linear: `{ cls } `." )
198+ return linear_cls
199199 except NotImplementedError as e :
200- logger .info (f"make_quant: Skipped linear: `{ linear } `." )
200+ logger .info (f"make_quant: Skipped linear: `{ cls } `. " )
201201 # only fallback to other quant linears when backend is auto.
202202 if backend not in [BACKEND .AUTO , BACKEND .AUTO_TRAINABLE ]:
203203 raise e
@@ -206,7 +206,7 @@ def make_quant(
206206
207207
208208def create_quant_layer (
209- linear : nn . Module ,
209+ linear_cls : Type [ BaseQuantLinear ] ,
210210 bits : int ,
211211 desc_act : bool ,
212212 dynamic ,
@@ -216,10 +216,9 @@ def create_quant_layer(
216216 sym : bool ,
217217 device : DEVICE ,
218218 lm_head_name : str ,
219- pack_dtype : torch .dtype ,
220- ) -> BaseQuantLinear :
221- if isinstance (module , linear ):
222- return linear
219+ pack_dtype : torch .dtype ) -> Type [BaseQuantLinear ]:
220+ if isinstance (module , linear_cls ):
221+ return linear_cls
223222 for name , submodule in module .named_modules ():
224223 if name in names :
225224 ori_layer_device = next (submodule .parameters ()).device
@@ -266,7 +265,7 @@ def create_quant_layer(
266265
267266 # when loading a quantized model, device is target device passed in GPTQModel.load()
268267 # check in_features and out_features validate
269- _ , err = linear .validate (
268+ _ , err = linear_cls .validate (
270269 bits = tmp_bits ,
271270 group_size = tmp_group_size ,
272271 desc_act = tmp_desc_act ,
@@ -278,7 +277,7 @@ def create_quant_layer(
278277 if err is not None :
279278 raise err
280279
281- new_layer = linear (
280+ new_layer = linear_cls (
282281 bits = tmp_bits ,
283282 group_size = tmp_group_size ,
284283 desc_act = tmp_desc_act ,
@@ -293,7 +292,7 @@ def create_quant_layer(
293292 )
294293 new_layer .device = ori_layer_device
295294 recurse_setattr (module , name , new_layer .to (ori_layer_device ))
296- return linear
295+ return linear_cls
297296
298297# public/stable api exposed to transformer/optimum
299298def hf_convert_gptq_v1_to_v2_format (
@@ -489,25 +488,13 @@ def pack_model(
489488 parallel_packing : bool = True ,
490489 pack_dtype : torch .dtype = None ,
491490):
492- quantLinear = select_quant_linear (
493- bits = bits ,
494- dynamic = dynamic ,
495- group_size = group_size ,
496- desc_act = desc_act ,
497- sym = sym ,
498- backend = backend ,
499- format = format ,
500- pack = True ,
501- pack_dtype = pack_dtype ,
502- )
503-
504491 model .to (CPU )
505492
506493 logger .info ("Packing model..." )
507494
508495 modules = find_modules (model )
509496 modules = {n : modules [n ] for n in quantizers }
510- make_quant (
497+ quant_linear_cls = make_quant (
511498 model ,
512499 quantizers ,
513500 bits ,
@@ -520,7 +507,10 @@ def pack_model(
520507 dynamic = dynamic ,
521508 pack_dtype = pack_dtype ,
522509 )
523- qModules = find_modules (model , [quantLinear ])
510+ qModules = find_modules (model , [quant_linear_cls ])
511+
512+ assert len (qModules ) > 0 , f"No quantized modules[{ quant_linear_cls } ] found in the model."
513+
524514 names = list (qModules .keys ())
525515
526516 if parallel_packing :
@@ -537,7 +527,7 @@ def wrapper(name):
537527 pass
538528
539529 logger .info ("Model packed." )
540- return quantLinear
530+ return quant_linear_cls
541531
542532
543533def verify_model_hash (file_path : str , verify_hash : str ):
0 commit comments