@@ -435,7 +435,7 @@ def tmp(_, inp, out):
             torch.cuda.empty_cache()
 
         if self.quantize_config.format == FORMAT.BITBLAS:
-            from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
+            from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
 
             # BitBLASQuantLinear does not have a pack method and needs to be converted to BitBLAS format when saving.
             logger.info("Converting model to BitBlas Format...")
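Note: the change drops the `import QuantLinear as BitBLASQuantLinear` alias in favour of a class that already carries the backend in its name at the definition site. A minimal sketch of why one distinctly named subclass per backend is convenient; all class and function names below are illustrative stand-ins, not the gptqmodel API.

```python
# Sketch: one distinctly named QuantLinear subclass per backend, so call sites
# never need "import QuantLinear as XxxQuantLinear" aliases.
# All names here are illustrative stand-ins, not the gptqmodel API.

class BaseQuantLinear:
    """Shared interface for quantized linear layers."""

class BitBLASQuantLinear(BaseQuantLinear):
    """BitBLAS backend; the class name itself identifies the backend."""

class TritonV2QuantLinear(BaseQuantLinear):
    """Triton v2 backend."""

def describe(layer: BaseQuantLinear) -> str:
    # type(layer).__name__ is unambiguous because every backend defines its own class
    return f"{type(layer).__name__} handles this module"

print(describe(BitBLASQuantLinear()))   # BitBLASQuantLinear handles this module
print(describe(TritonV2QuantLinear()))  # TritonV2QuantLinear handles this module
```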
@@ -515,15 +515,15 @@ def save_quantized(
         # no need to set it back, no calculation below
         if quantize_config.bits != 4:
             cuda_name_modules = {}
-            from gptqmodel.nn_modules.qlinear.qlinear_cuda import BaseCudaQuantLinear
+            from gptqmodel.nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
             for name, module in model.named_modules():
-                if isinstance(module, BaseCudaQuantLinear):
+                if isinstance(module, CudaQuantLinear):
                     cuda_name_modules[name] = module.gptqmodel_cuda
                     module.gptqmodel_cuda = None
             model = copy.deepcopy(self.model)
 
             for name, modules in model.named_modules():
-                if isinstance(module, BaseCudaQuantLinear) and name in cuda_name_modules:
+                if isinstance(module, CudaQuantLinear) and name in cuda_name_modules:
                     module.gptqmodel_cuda = cuda_name_modules[name]
 
             del cuda_name_modules
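The surrounding code temporarily detaches the `gptqmodel_cuda` extension handle from every CUDA kernel layer so that `copy.deepcopy(self.model)` succeeds, then restores the handles by module name. Below is a self-contained sketch of that detach/deepcopy/restore pattern, with a dummy attribute standing in for the CUDA handle; the class and attribute names are hypothetical, not the gptqmodel implementation.

```python
# Sketch of the detach -> deepcopy -> restore pattern used around CudaQuantLinear.
# "cuda_handle" stands in for the non-deepcopyable gptqmodel_cuda extension object;
# the classes here are illustrative, not the gptqmodel implementation.
import copy
import torch.nn as nn

class FakeCudaLinear(nn.Linear):
    def __init__(self, in_f, out_f):
        super().__init__(in_f, out_f)
        self.cuda_handle = object()  # pretend this object cannot be deepcopied safely

model = nn.Sequential(FakeCudaLinear(8, 8), nn.ReLU(), FakeCudaLinear(8, 4))

# 1) stash handles by module name and detach them
saved = {}
for name, module in model.named_modules():
    if isinstance(module, FakeCudaLinear):
        saved[name] = module.cuda_handle
        module.cuda_handle = None

# 2) deepcopy is now safe
model_copy = copy.deepcopy(model)

# 3) restore the handles on the original modules
for name, module in model.named_modules():
    if isinstance(module, FakeCudaLinear) and name in saved:
        module.cuda_handle = saved[name]
```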
@@ -1109,9 +1109,9 @@ def skip(*args, **kwargs):
 
         # == step6: (optional) warmup triton == #
         if backend == Backend.TRITON and warmup_triton:
-            from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
+            from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear
 
-            QuantLinear.warmup(model, seqlen=model.seqlen)
+            TritonV2QuantLinear.warmup(model, seqlen=model.seqlen)
 
         return cls(
             model,
@@ -1124,9 +1124,9 @@ def warmup_triton(self, enabled: bool = True):
         if not enabled:
             return
 
-        from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
+        from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear
 
-        QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
+        TritonV2QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
 
     def __getattr__(self, item):
         try:
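Both warmup call sites (the `from_quantized` path above and the `warmup_triton` helper here) now reference `TritonV2QuantLinear` directly. A sketch of the classmethod-warmup pattern these calls follow is below; the classes and the dummy forward pass are illustrative stand-ins, not the real Triton kernel warmup.

```python
# Sketch of the classmethod-warmup pattern the rename touches: a backend class
# exposes warmup(model, seqlen=...) and the model wrapper simply forwards to it.
# Names are illustrative stand-ins for TritonV2QuantLinear and the model wrapper.
import torch
import torch.nn as nn

class FakeTritonQuantLinear(nn.Linear):
    @classmethod
    def warmup(cls, model: nn.Module, seqlen: int = 2048) -> None:
        # run one dummy forward through every layer of this type so that
        # JIT-compiled kernels are built before real inference traffic arrives
        for module in model.modules():
            if isinstance(module, cls):
                dummy = torch.zeros(1, seqlen, module.in_features)
                module(dummy)

class QuantizedModelWrapper:
    def __init__(self, model: nn.Module, seqlen: int = 2048):
        self.model = model
        self.seqlen = seqlen

    def warmup_triton(self, enabled: bool = True) -> None:
        if not enabled:
            return
        FakeTritonQuantLinear.warmup(self.model, seqlen=self.seqlen)

wrapper = QuantizedModelWrapper(nn.Sequential(FakeTritonQuantLinear(16, 16)), seqlen=8)
wrapper.warmup_triton()
```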