@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend
+from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend, state_dict_device
 
 
 #########################################################################
@@ -63,7 +63,7 @@ def convert_for_runtime(self) -> nn.Module:
         pass
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = state_dict_device(self.create_quantized_state_dict())
         self.convert_for_runtime()
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_
@@ -406,8 +406,9 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
-        cur_state_dict = self.model_.state_dict()
-
+        cur_state_dict = state_dict_device(self.model_.state_dict())
+        dict_device = "cpu"  # self.device
+
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7
@@ -446,8 +447,8 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
+                weight = weight.to(device=dict_device)
+                scales = scales.to(device=dict_device)
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
@@ -553,7 +554,8 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
-        cur_state_dict = self.model_.state_dict()
+        cur_state_dict = state_dict_device(self.model_.state_dict())
+        dict_device = "cpu"  # self.device
 
         if self.bitwidth == 4:
             range_min = -8
@@ -595,8 +597,8 @@ def create_quantized_state_dict(self) -> Dict:
                 weight_packed = weight_even + weight_odd
                 weight = weight_packed
 
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
+                weight = weight.to(device=dict_device)
+                scales = scales.to(device=dict_device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
@@ -822,9 +824,21 @@ def __init__(
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
+
+    # @torch.no_grad()
+    # def p(self):
+    #     cur_state_dict = state_dict_device(self.model_.state_dict())
+    #     dict_device = "cpu"  # self.device
+    #
+    #     for fqn, mod in self.model_.named_modules():
+    #         if hasattr(mod, "weight"):
+    #             print(f"device={str(mod.weight.data.device)}")
+
     @torch.no_grad()
     def create_quantized_state_dict(self):
-        cur_state_dict = self.model_.state_dict()
+        cur_state_dict = state_dict_device(self.model_.state_dict())
+        dict_device = "cpu"  # self.device
+
         for fqn, mod in self.model_.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 assert not mod.bias
@@ -856,8 +870,8 @@ def create_quantized_state_dict(self):
                     weight.to(torch.float), self.groupsize, self.inner_k_tiles
                 )
             )
-            weight_int4pack = weight_int4pack.to(device=self.device)
-            scales_and_zeros = scales_and_zeros.to(device=self.device)
+            weight_int4pack = weight_int4pack.to(device=dict_device)
+            scales_and_zeros = scales_and_zeros.to(device=dict_device)
             cur_state_dict[f"{fqn}.weight"] = weight_int4pack
             cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
 
@@ -877,6 +891,7 @@ def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
         self.model_.load_state_dict(model_updated_state_dict)
+        # self.p()
         return self.model_
 
 
0 commit comments