1515import torch
1616import torch .nn as nn
1717import torch .nn .functional as F
18- from build .utils import find_multiple , get_precision , name_to_dtype , use_et_backend , state_dict_device
18+ from build .utils import (
19+ find_multiple ,
20+ get_precision ,
21+ name_to_dtype ,
22+ state_dict_device ,
23+ use_et_backend ,
24+ )
1925
2026
2127#########################################################################
@@ -116,6 +122,28 @@ def quantized_model(self) -> nn.Module:
116122 return self .model_ .to (device = self .device , dtype = self .dtype )
117123
118124
125+ #########################################################################
126+ ### wrapper for setting device as a QuantHandler ###
127+
128+
class ExecutorHandler(QuantHandler):
    """No-op QuantHandler used purely to place the model on a device.

    Unlike the real quantization handlers it performs no weight
    transformation: ``quantized_model`` just moves the wrapped module to
    ``self.device``. When ``accelerator`` is given as a string it is
    resolved via ``get_device_str`` and overrides the ``device`` argument.
    The ``tokenizer`` parameter is accepted only for interface parity with
    the other handlers and is ignored.
    """

    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
        self.model_ = model
        # A string accelerator name takes precedence over the explicit device.
        self.device = get_device_str(accelerator) if isinstance(accelerator, str) else device

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        # Nothing to quantize — intentionally a no-op.
        pass

    def convert_for_runtime(self) -> nn.Module:
        # Module graph is left untouched — intentionally a no-op.
        pass

    def quantized_model(self) -> nn.Module:
        # Only side effect of this handler: relocate the model.
        return self.model_.to(device=self.device)
145+
146+
119147#########################################################################
120148##### Quantization Primitives ######
121149
@@ -407,8 +435,8 @@ def __init__(
407435 @torch .no_grad ()
408436 def create_quantized_state_dict (self ) -> Dict :
409437 cur_state_dict = state_dict_device (self .model_ .state_dict ())
410- dict_device = "cpu" # self.device
411-
438+ dict_device = "cpu" # self.device
439+
412440 if self .bitwidth == 4 :
413441 range_min = - 8
414442 range_max = 7
@@ -824,12 +852,11 @@ def __init__(
824852 assert groupsize in [32 , 64 , 128 , 256 ]
825853 assert inner_k_tiles in [2 , 4 , 8 ]
826854
827-
828855 # @torch.no_grad()
829856 # def p(self):
830857 # cur_state_dict = state_dict_device(self.model_.state_dict())
831858 # dict_device = "cpu" # self.device
832- #
859+ #
833860 # for fqn, mod in self.model_.named_modules():
834861 # if hasattr(mod, "weight"):
835862 # print(f"device={str(mod.weight.data.device)}")
@@ -838,7 +865,7 @@ def __init__(
838865 def create_quantized_state_dict (self ):
839866 cur_state_dict = state_dict_device (self .model_ .state_dict ())
840867 dict_device = "cpu" # self.device
841-
868+
842869 for fqn , mod in self .model_ .named_modules ():
843870 if isinstance (mod , torch .nn .Linear ):
844871 assert not mod .bias
@@ -1282,4 +1309,5 @@ def quantized_model(self) -> nn.Module:
12821309 "linear:int4-gptq" : WeightOnlyInt4GPTQQuantHandler ,
12831310 "linear:hqq" : WeightOnlyInt4HqqQuantHandler ,
12841311 "precision" : PrecisionHandler ,
1312+ "executor" : ExecutorHandler ,
12851313}
0 commit comments