@@ -123,16 +123,26 @@ def move_weight_functions(m, device):
     return memory
 
 class LowVramPatch:
-    def __init__(self, key, patches):
+    def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
+        self.convert_func = convert_func
+        self.set_func = set_func
+
     def __call__(self, weight):
+        if self.convert_func is not None:
+            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+
         intermediate_dtype = weight.dtype
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
+        if self.set_func is None and intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
             return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
 
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        if self.set_func is not None:
+            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
+        else:
+            return out
 
 def get_key_weight(model, key):
     set_func = None
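The two new hooks are the heart of this change: `convert_func` maps a weight stored in a non-standard format (e.g. fp8 or another quantized representation) into a plain float tensor that `comfy.lora.calculate_weight` can do math on, and `set_func` converts the patched result back into the storage format, seeding its stochastic rounding from the parameter key so repeated loads are reproducible. Below is a minimal sketch of that round-trip with hypothetical stand-in converters; the real pair comes from the layer via `get_key_weight` and is not part of this diff.

```python
import torch

# Hypothetical stand-ins for the (convert_func, set_func) pair that
# get_key_weight() returns for weights stored in a special format.
def convert_func(weight, inplace=False):
    # Undo the storage transform so LoRA math can run in plain float32.
    # Here the "transform" is trivial; a real model might dequantize fp8.
    return weight

def set_func(weight, seed=0, return_weight=False):
    # Re-encode the patched weight. The real set_func uses `seed` for
    # stochastic rounding; this toy just casts back to the storage dtype.
    out = weight.to(torch.bfloat16)
    return out if return_weight else None

stored = torch.randn(4, 4, dtype=torch.bfloat16)      # weight as stored
w = convert_func(stored.to(dtype=torch.float32, copy=True), inplace=True)
w = w + 0.01 * torch.randn_like(w)                    # stand-in for the LoRA patch math
patched = set_func(w, seed=1234, return_weight=True)
print(patched.dtype)                                  # back to the stored dtype
```

Note also that when `set_func` is present, `__call__` skips the float32 fallback entirely, presumably because the conversion path has already produced a dtype that supports the patch math.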
@@ -657,13 +667,15 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                         if force_patch_weights:
                             self.patch_weight_to_device(weight_key)
                         else:
-                            m.weight_function = [LowVramPatch(weight_key, self.patches)]
+                            _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                            m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
                             patch_counter += 1
                     if bias_key in self.patches:
                         if force_patch_weights:
                             self.patch_weight_to_device(bias_key)
                         else:
-                            m.bias_function = [LowVramPatch(bias_key, self.patches)]
+                            _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                            m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
                             patch_counter += 1
 
                     cast_weight = True
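`load` now asks `get_key_weight` for the layer's conversion hooks before building each low-VRAM patch. Judging from the unpacking at both call sites, `get_key_weight` returns a `(weight, set_func, convert_func)` tuple; only the hooks are needed here, hence the leading `_`. A rough sketch of what such a lookup could look like, under the assumption that specially-stored layers expose optional `convert_weight`/`set_weight` methods (the real attribute names live in ComfyUI's ops layer and may differ):

```python
import functools
import torch.nn as nn

def get_key_weight(model: nn.Module, key: str):
    # Resolve a dotted key like "blocks.0.attn.to_q.weight" to the
    # owning module, then pick up its optional conversion hooks.
    attrs = key.split(".")
    module = functools.reduce(getattr, attrs[:-1], model)
    weight = getattr(module, attrs[-1])
    # Assumption: only specially-stored layers define these methods;
    # plain layers return (weight, None, None).
    convert_func = getattr(module, "convert_weight", None)
    set_func = getattr(module, "set_weight", None)
    return weight, set_func, convert_func

# Usage mirrors the call sites in this diff:
#   _, set_func, convert_func = get_key_weight(self.model, weight_key)
#   m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
```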
@@ -825,10 +837,12 @@ def partially_unload(self, device_to, memory_to_free=0):
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
-                            m.weight_function.append(LowVramPatch(weight_key, self.patches))
+                            _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                            m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
                             patch_counter += 1
                         if bias_key in self.patches:
-                            m.bias_function.append(LowVramPatch(bias_key, self.patches))
+                            _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                            m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
                             patch_counter += 1
                         cast_weight = True
 
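`partially_unload` gets the same treatment, but note that it appends to `m.weight_function` / `m.bias_function` rather than assigning, since a module being demoted back to low-VRAM mode may already carry weight functions. Both paths ultimately seed stochastic rounding via `string_to_seed(self.key)`, which turns a parameter name into a stable integer so the rounding noise is reproducible per key. A plausible implementation of such a helper (the real one lives elsewhere in model_patcher.py and may use a different hash):

```python
import hashlib

def string_to_seed(key: str) -> int:
    # A content-stable hash: unlike Python's built-in hash(), this gives
    # the same seed for the same key across processes and runs.
    return int.from_bytes(hashlib.sha256(key.encode("utf-8")).digest()[:4], "little")

print(string_to_seed("blocks.0.attn.to_q.weight"))  # same value every run
```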