
Commit 4a87e96

Make old fp8 system use new mixed quant system.
Upgrade the internal mixed quant format to be easier to deal with.
1 parent: 440268d

23 files changed, +241 -256 lines
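In practice the quantization info moves from two global checkpoint keys (scaled_fp8 and _quantization_metadata, removed in comfy/model_base.py below) to one small JSON blob stored next to each quantized layer. A minimal round-trip of the new per-layer key, using the same encoding as the state_dict() hook added in comfy/ops.py; the layer name and format string here are illustrative placeholders, not values taken from this commit:

import json
import torch

# Per-layer config, as written by MixedPrecisionOps.Linear.state_dict() in comfy/ops.py.
# The "format" value has to be a key of QUANT_ALGOS; the one used here is a placeholder.
quant_conf = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}

# Encode: JSON -> uint8 tensor, so the config can live in the checkpoint like any other
# tensor, e.g. under "diffusion_model.input_blocks.0.0.comfy_quant" (illustrative key).
# PyTorch may warn that the bytes buffer is not writable; the diff uses the same call.
blob = torch.frombuffer(json.dumps(quant_conf).encode("utf-8"), dtype=torch.uint8)

# Decode: exactly what _load_from_state_dict() does after popping the key back out.
decoded = json.loads(blob.numpy().tobytes())
assert decoded == quant_conf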

comfy/model_base.py

Lines changed: 2 additions & 13 deletions
@@ -57,6 +57,7 @@
 import comfy.latent_formats
 import comfy.model_sampling
 import math
+import json
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
@@ -134,7 +135,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -329,18 +330,6 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

         unet_state_dict = self.diffusion_model.state_dict()
-
-        if self.model_config.scaled_fp8 is not None:
-            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
-        # Save mixed precision metadata
-        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
-            metadata = {
-                "format_version": "1.0",
-                "layers": self.model_config.layer_quant_config
-            }
-            unet_state_dict["_quantization_metadata"] = metadata
-
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:

comfy/model_detection.py

Lines changed: 4 additions & 29 deletions
@@ -6,20 +6,6 @@
 import logging
 import torch

-
-def detect_layer_quantization(metadata):
-    quant_key = "_quantization_metadata"
-    if metadata is not None and quant_key in metadata:
-        quant_metadata = metadata.pop(quant_key)
-        quant_metadata = json.loads(quant_metadata)
-        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
-            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
-            return quant_metadata["layers"]
-        else:
-            raise ValueError("Invalid quantization metadata format")
-    return None
-
-
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)

-    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
-    if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
-        model_config.scaled_fp8 = scaled_fp8_weight.dtype
-        if model_config.scaled_fp8 == torch.float32:
-            model_config.scaled_fp8 = torch.float8_e4m3fn
-        if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
-        else:
-            model_config.optimizations["fp8"] = True
-
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(metadata)
-    if layer_quant_config:
-        model_config.layer_quant_config = layer_quant_config
-        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+    quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+    if quant_config:
+        model_config.quant_config = quant_config
+        logging.info("Detected mixed precision quantization")

     return model_config
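comfy.utils.detect_layer_quantization is new and lives in comfy/utils.py, which is among the 23 changed files but not shown in this excerpt. Based on how the per-layer keys are written and read in comfy/ops.py, a plausible minimal sketch is below; the function body is an assumption, not the actual implementation:

import json

def detect_layer_quantization_sketch(state_dict, prefix):
    # Hypothetical stand-in for comfy.utils.detect_layer_quantization().
    # The new format stores one "<layer>.comfy_quant" uint8 tensor per quantized
    # layer, so detection only needs to notice those keys under the given prefix.
    layers = {}
    for key in state_dict:
        if key.startswith(prefix) and key.endswith(".comfy_quant"):
            layer_name = key[len(prefix):-len(".comfy_quant")]
            layers[layer_name] = json.loads(state_dict[key].numpy().tobytes())
    return layers or None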

comfy/model_patcher.py

Lines changed: 2 additions & 18 deletions
@@ -126,27 +126,11 @@ class LowVramPatch:
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
-        self.convert_func = convert_func
+        self.convert_func = convert_func # TODO: remove
         self.set_func = set_func

     def __call__(self, weight):
-        intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
-            weight = self.convert_func(weight, inplace=False)
-
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
-            intermediate_dtype = torch.float32
-            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
-            if self.set_func is None:
-                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
-            else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
-        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
-            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
-        else:
-            return out
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

 #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3

comfy/ops.py

Lines changed: 41 additions & 88 deletions
@@ -23,6 +23,7 @@
 import comfy.float
 import comfy.rmsnorm
 import contextlib
+import json

 def run_every_op():
     if torch.compiler.is_compiling():
@@ -422,22 +423,12 @@ def fp8_linear(self, input):

     if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)

-        scale_weight = self.scale_weight
-        scale_input = self.scale_input
-        if scale_weight is None:
-            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        else:
-            scale_weight = scale_weight.to(input.device)
-
-        if scale_input is None:
-            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-        else:
-            scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+        input = torch.clamp(input, min=-448, max=448, out=input)
+        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)

         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@@ -471,59 +462,6 @@ def forward_comfy_cast_weights(self, input):
         uncast_bias_weight(self, weight, bias, offload_stream)
         return x

-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
-    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
-    class scaled_fp8_op(manual_cast):
-        class Linear(manual_cast.Linear):
-            def __init__(self, *args, **kwargs):
-                if override_dtype is not None:
-                    kwargs['dtype'] = override_dtype
-                super().__init__(*args, **kwargs)
-
-            def reset_parameters(self):
-                if not hasattr(self, 'scale_weight'):
-                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
-                if not scale_input:
-                    self.scale_input = None
-
-                if not hasattr(self, 'scale_input'):
-                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                if fp8_matrix_mult:
-                    out = fp8_linear(self, input)
-                    if out is not None:
-                        return out
-
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
-                if weight.numel() < input.numel(): #TODO: optimize
-                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
-                else:
-                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
-
-            def convert_weight(self, weight, inplace=False, **kwargs):
-                if inplace:
-                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-                    return weight
-                else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
-            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
-                if return_weight:
-                    return weight
-                if inplace_update:
-                    self.weight.data.copy_(weight)
-                else:
-                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-    return scaled_fp8_op
-
 CUBLAS_IS_AVAILABLE = False
 try:
     from cublas_ops import CublasLinear
@@ -550,9 +488,9 @@ def forward(self, *args, **kwargs):
 from .quant_ops import QuantizedTensor, QUANT_ALGOS


-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
     class MixedPrecisionOps(manual_cast):
-        _layer_quant_config = layer_quant_config
+        _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm

@@ -595,27 +533,36 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,

                manually_loaded_keys = [weight_key]

-                if layer_name not in MixedPrecisionOps._layer_quant_config:
+                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+                if layer_conf is not None:
+                    layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+                if layer_conf is None:
                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                else:
-                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
-                    if quant_format is None:
+                    self.quant_format = layer_conf.get("format", None)
+                    if not self._full_precision_mm:
+                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+                    if self.quant_format is None:
                        raise ValueError(f"Unknown quantization format for layer {layer_name}")

-                    qconfig = QUANT_ALGOS[quant_format]
+                    qconfig = QUANT_ALGOS[self.quant_format]
                    self.layout_type = qconfig["comfy_tensor_layout"]

                    weight_scale_key = f"{prefix}weight_scale"
+                    scale = state_dict.pop(weight_scale_key, None)
                    layout_params = {
-                        'scale': state_dict.pop(weight_scale_key, None),
+                        'scale': scale,
                        'orig_dtype': MixedPrecisionOps._compute_dtype,
                        'block_size': qconfig.get("group_size", None),
                    }
-                    if layout_params['scale'] is not None:
+
+                    if scale is not None:
                        manually_loaded_keys.append(weight_scale_key)

                    self.weight = torch.nn.Parameter(
-                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                        requires_grad=False
                    )

@@ -624,7 +571,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                    _v = state_dict.pop(param_key, None)
                    if _v is None:
                        continue
-                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                    manually_loaded_keys.append(param_key)

                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -633,6 +580,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                    if key in missing_keys:
                        missing_keys.remove(key)

+            def state_dict(self, *args, destination=None, prefix="", **kwargs):
+                sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+                if isinstance(self.weight, QuantizedTensor):
+                    sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                    quant_conf = {"format": self.quant_format}
+                    if self._full_precision_mm:
+                        quant_conf["full_precision_matrix_mult"] = True
+                    sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+                return sd
+
            def _forward(self, input, weight, bias):
                return torch.nn.functional.linear(input, weight, bias)

@@ -648,9 +605,8 @@ def forward(self, input, *args, **kwargs):
                if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                    return self.forward_comfy_cast_weights(input, *args, **kwargs)
                if (getattr(self, 'layout_type', None) is not None and
-                    getattr(self, 'input_scale', None) is not None and
                    not isinstance(input, QuantizedTensor)):
-                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                    input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
                return self._forward(input, self.weight, self.bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
@@ -661,7 +617,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):

            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                if getattr(self, 'layout_type', None) is not None:
-                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
                else:
                    weight = weight.to(self.weight.dtype)
                if return_weight:
@@ -672,15 +628,12 @@ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=Fals

     return MixedPrecisionOps

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular

-    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
-
-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
+    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+        logging.info("Using mixed precision operations")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)

     if (
         fp8_compute and
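With scaled_fp8 gone from the signature, callers of pick_operations() only hand over the model config, and the presence of quant_config decides whether mixed precision ops are used. A rough usage sketch, assuming a ComfyUI environment with this commit; DummyConfig and its placeholder quant_config are made up for illustration (the real value comes from comfy.utils.detect_layer_quantization):

import torch
import comfy.ops

class DummyConfig:
    # Any truthy quant_config routes pick_operations() to mixed_precision_ops();
    # the per-layer details are read from the checkpoint's comfy_quant keys at load time.
    quant_config = {"detected": True}

ops = comfy.ops.pick_operations(
    torch.float8_e4m3fn,        # weight_dtype
    torch.bfloat16,             # compute_dtype
    fp8_optimizations=False,
    model_config=DummyConfig(),
)
# The returned class is what comfy/model_base.py passes as operations= when it
# builds the diffusion model; its Linear replaces torch.nn.Linear inside the model.
linear_cls = ops.Linear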

comfy/quant_ops.py

Lines changed: 16 additions & 7 deletions
@@ -238,6 +238,12 @@
     def is_contiguous(self, *arg, **kwargs):
         return self._qdata.is_contiguous(*arg, **kwargs)

+    def storage(self):
+        return self._qdata.storage()
+
+    def untyped_storage(self):
+        return self._qdata.untyped_storage()
+
 # ==============================================================================
 # Generic Utilities (Layout-Agnostic Operations)
 # ==============================================================================
@@ -397,17 +403,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype

-        if scale is None:
+        if scale == "recalculate":
             scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

-        if not isinstance(scale, torch.Tensor):
-            scale = torch.tensor(scale)
-        scale = scale.to(device=tensor.device, dtype=torch.float32)
+        if scale is not None:
+            if not isinstance(scale, torch.Tensor):
+                scale = torch.tensor(scale)
+            scale = scale.to(device=tensor.device, dtype=torch.float32)

-        if inplace_ops:
-            tensor *= (1.0 / scale).to(tensor.dtype)
+            if inplace_ops:
+                tensor *= (1.0 / scale).to(tensor.dtype)
+            else:
+                tensor = tensor * (1.0 / scale).to(tensor.dtype)
         else:
-            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+            scale = torch.ones((), device=tensor.device, dtype=torch.float32)

         if stochastic_rounding > 0:
             tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
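The net effect is a change in what the scale argument means: scale=None used to trigger an automatic amax-based scale, which is now an explicit opt-in via the "recalculate" string (this is what set_weight() in comfy/ops.py passes), while None now means "store the values unscaled and record a scale of 1". A standalone mirror of that branching, not the ComfyUI class itself, assuming a PyTorch build with float8 dtypes (2.1 or newer):

import torch

def fp8_scale_semantics(tensor, scale=None, dtype=torch.float8_e4m3fn):
    # Mirrors the branching in TensorCoreFP8Layout.quantize() after this commit.
    if scale == "recalculate":
        # Explicit opt-in: derive a per-tensor scale from the largest magnitude.
        scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

    if scale is not None:
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)
        tensor = tensor * (1.0 / scale).to(tensor.dtype)
    else:
        # scale=None now means "store as-is" and record a scale of 1.0.
        scale = torch.ones((), device=tensor.device, dtype=torch.float32)

    return tensor.to(dtype), scale

x = torch.randn(8, 8)
q_plain, s_plain = fp8_scale_semantics(x)                # no scaling (previously None meant "recalculate")
q_auto, s_auto = fp8_scale_semantics(x, "recalculate")   # amax-based scale, as used by set_weight()
q_fixed, s_fixed = fp8_scale_semantics(x, 0.05)          # explicit scale, used as given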

0 commit comments
