Skip to content

Commit f1dd6e5

Browse files
Fix bug with applying loras on fp8 scaled without fp8 ops. (#10279)
1 parent fc0fbf1 commit f1dd6e5

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

comfy/model_patcher.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,21 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
130130
self.set_func = set_func
131131

132132
def __call__(self, weight):
133+
intermediate_dtype = weight.dtype
133134
if self.convert_func is not None:
134135
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
135136

136-
intermediate_dtype = weight.dtype
137-
if self.set_func is None and intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
137+
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
138138
intermediate_dtype = torch.float32
139-
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
139+
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
140+
if self.set_func is None:
141+
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
142+
else:
143+
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
140144

141145
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
142146
if self.set_func is not None:
143-
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
147+
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
144148
else:
145149
return out
146150

0 commit comments

Comments
 (0)