Skip to content

Commit d533bad

Browse files
comfyanonymousadlerfaulkner
authored andcommitted
1 parent 3072253 commit d533bad

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

comfy/model_patcher.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,26 @@ def move_weight_functions(m, device):
123123
return memory
124124

125125
class LowVramPatch:
126-
def __init__(self, key, patches):
126+
def __init__(self, key, patches, convert_func=None, set_func=None):
127127
self.key = key
128128
self.patches = patches
129+
self.convert_func = convert_func
130+
self.set_func = set_func
131+
129132
def __call__(self, weight):
133+
if self.convert_func is not None:
134+
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
135+
130136
intermediate_dtype = weight.dtype
131-
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
137+
if self.set_func is None and intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
132138
intermediate_dtype = torch.float32
133139
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
134140

135-
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
141+
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
142+
if self.set_func is not None:
143+
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
144+
else:
145+
return out
136146

137147
def get_key_weight(model, key):
138148
set_func = None
@@ -657,13 +667,15 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
657667
if force_patch_weights:
658668
self.patch_weight_to_device(weight_key)
659669
else:
660-
m.weight_function = [LowVramPatch(weight_key, self.patches)]
670+
_, set_func, convert_func = get_key_weight(self.model, weight_key)
671+
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
661672
patch_counter += 1
662673
if bias_key in self.patches:
663674
if force_patch_weights:
664675
self.patch_weight_to_device(bias_key)
665676
else:
666-
m.bias_function = [LowVramPatch(bias_key, self.patches)]
677+
_, set_func, convert_func = get_key_weight(self.model, bias_key)
678+
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
667679
patch_counter += 1
668680

669681
cast_weight = True
@@ -825,10 +837,12 @@ def partially_unload(self, device_to, memory_to_free=0):
825837
module_mem += move_weight_functions(m, device_to)
826838
if lowvram_possible:
827839
if weight_key in self.patches:
828-
m.weight_function.append(LowVramPatch(weight_key, self.patches))
840+
_, set_func, convert_func = get_key_weight(self.model, weight_key)
841+
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
829842
patch_counter += 1
830843
if bias_key in self.patches:
831-
m.bias_function.append(LowVramPatch(bias_key, self.patches))
844+
_, set_func, convert_func = get_key_weight(self.model, bias_key)
845+
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
832846
patch_counter += 1
833847
cast_weight = True
834848

comfy/ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,10 @@ def convert_weight(self, weight, inplace=False, **kwargs):
416416
else:
417417
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
418418

419-
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
419+
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
420420
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
421+
if return_weight:
422+
return weight
421423
if inplace_update:
422424
self.weight.data.copy_(weight)
423425
else:

0 commit comments

Comments
 (0)