@@ -123,16 +123,26 @@ def move_weight_functions(m, device):
123123    return  memory 
124124
125125class  LowVramPatch :
126-     def  __init__ (self , key , patches ):
126+     def  __init__ (self , key , patches ,  convert_func = None ,  set_func = None ):
127127        self .key  =  key 
128128        self .patches  =  patches 
129+         self .convert_func  =  convert_func 
130+         self .set_func  =  set_func 
131+ 
129132    def  __call__ (self , weight ):
133+         if  self .convert_func  is  not None :
134+             weight  =  self .convert_func (weight .to (dtype = torch .float32 , copy = True ), inplace = True )
135+ 
130136        intermediate_dtype  =  weight .dtype 
131-         if  intermediate_dtype  not  in torch .float32 , torch .float16 , torch .bfloat16 ]: #intermediate_dtype has to be one that is supported in math ops 
137+         if  self . set_func   is   None   and   intermediate_dtype  not  in torch .float32 , torch .float16 , torch .bfloat16 ]: #intermediate_dtype has to be one that is supported in math ops 
132138            intermediate_dtype  =  torch .float32 
133139            return  comfy .float .stochastic_rounding (comfy .lora .calculate_weight (self .patches [self .key ], weight .to (intermediate_dtype ), self .key , intermediate_dtype = intermediate_dtype ), weight .dtype , seed = string_to_seed (self .key ))
134140
135-         return  comfy .lora .calculate_weight (self .patches [self .key ], weight , self .key , intermediate_dtype = intermediate_dtype )
141+         out  =  comfy .lora .calculate_weight (self .patches [self .key ], weight , self .key , intermediate_dtype = intermediate_dtype )
142+         if  self .set_func  is  not None :
143+             return  self .set_func (out , seed = string_to_seed (self .key ), return_weight = True )
144+         else :
145+             return  out 
136146
137147def  get_key_weight (model , key ):
138148    set_func  =  None 
@@ -657,13 +667,15 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
657667                        if  force_patch_weights :
658668                            self .patch_weight_to_device (weight_key )
659669                        else :
660-                             m .weight_function  =  [LowVramPatch (weight_key , self .patches )]
670+                             _ , set_func , convert_func  =  get_key_weight (self .model , weight_key )
671+                             m .weight_function  =  [LowVramPatch (weight_key , self .patches , convert_func , set_func )]
661672                            patch_counter  +=  1 
662673                    if  bias_key  in  self .patches :
663674                        if  force_patch_weights :
664675                            self .patch_weight_to_device (bias_key )
665676                        else :
666-                             m .bias_function  =  [LowVramPatch (bias_key , self .patches )]
677+                             _ , set_func , convert_func  =  get_key_weight (self .model , bias_key )
678+                             m .bias_function  =  [LowVramPatch (bias_key , self .patches , convert_func , set_func )]
667679                            patch_counter  +=  1 
668680
669681                    cast_weight  =  True 
@@ -825,10 +837,12 @@ def partially_unload(self, device_to, memory_to_free=0):
825837                        module_mem  +=  move_weight_functions (m , device_to )
826838                        if  lowvram_possible :
827839                            if  weight_key  in  self .patches :
828-                                 m .weight_function .append (LowVramPatch (weight_key , self .patches ))
840+                                 _ , set_func , convert_func  =  get_key_weight (self .model , weight_key )
841+                                 m .weight_function .append (LowVramPatch (weight_key , self .patches , convert_func , set_func ))
829842                                patch_counter  +=  1 
830843                            if  bias_key  in  self .patches :
831-                                 m .bias_function .append (LowVramPatch (bias_key , self .patches ))
844+                                 _ , set_func , convert_func  =  get_key_weight (self .model , bias_key )
845+                                 m .bias_function .append (LowVramPatch (bias_key , self .patches , convert_func , set_func ))
832846                                patch_counter  +=  1 
833847                            cast_weight  =  True 
834848
0 commit comments