3 changes: 1 addition & 2 deletions comfy/model_patcher.py
@@ -231,7 +231,6 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
         self.object_patches_backup = {}
         self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
-        self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
@@ -286,7 +285,7 @@ def lowvram_patch_counter(self):
         return self.model.lowvram_patch_counter

     def clone(self):
-        n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
+        n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
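
The net effect is that model size is computed lazily rather than at construction time: __init__ no longer calls self.model_size(), and clone() calls self.model_size() so the size still gets computed (and cached) the first time it is actually needed. A rough sketch of the pattern, with assumed internals (it is assumed here that ModelPatcher.model_size() caches its result in self.size, as the size=0 default suggests):

    # Illustrative only, not the PR's code.
    class LazySized:
        def __init__(self, model, size=0):
            self.model = model
            self.size = size              # 0 means "not computed yet"

        def model_size(self):
            if self.size > 0:
                return self.size          # cached from an earlier call
            self.size = sum(p.numel() * p.element_size() for p in self.model.parameters())
            return self.size

        def clone(self):
            # The clone receives a concrete size, computed at most once.
            return LazySized(self.model, self.model_size())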
188 changes: 95 additions & 93 deletions comfy/ops.py
@@ -540,113 +540,115 @@ def forward(self, *args, **kwargs):
 # ==============================================================================
 from .quant_ops import QuantizedTensor, QUANT_ALGOS

-class MixedPrecisionOps(disable_weight_init):
-    _layer_quant_config = {}
-    _compute_dtype = torch.bfloat16
-
-    class Linear(torch.nn.Module, CastWeightBiasOp):
-        def __init__(
-            self,
-            in_features: int,
-            out_features: int,
-            bias: bool = True,
-            device=None,
-            dtype=None,
-        ) -> None:
-            super().__init__()
-
-            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
-            # self.factory_kwargs = {"device": device, "dtype": dtype}
-
-            self.in_features = in_features
-            self.out_features = out_features
-            if bias:
-                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
-            else:
-                self.register_parameter("bias", None)
-
-            self.tensor_class = None
-
-        def reset_parameters(self):
-            return None
-
-        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
-                                  strict, missing_keys, unexpected_keys, error_msgs):
-
-            device = self.factory_kwargs["device"]
-            layer_name = prefix.rstrip('.')
-            weight_key = f"{prefix}weight"
-            weight = state_dict.pop(weight_key, None)
-            if weight is None:
-                raise ValueError(f"Missing weight for layer {layer_name}")
-
-            manually_loaded_keys = [weight_key]
-
-            if layer_name not in MixedPrecisionOps._layer_quant_config:
-                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
-            else:
-                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
-                if quant_format is None:
-                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
-
-                qconfig = QUANT_ALGOS[quant_format]
-                self.layout_type = qconfig["comfy_tensor_layout"]
-
-                weight_scale_key = f"{prefix}weight_scale"
-                layout_params = {
-                    'scale': state_dict.pop(weight_scale_key, None),
-                    'orig_dtype': MixedPrecisionOps._compute_dtype,
-                    'block_size': qconfig.get("group_size", None),
-                }
-                if layout_params['scale'] is not None:
-                    manually_loaded_keys.append(weight_scale_key)
-
-                self.weight = torch.nn.Parameter(
-                    QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
-                    requires_grad=False
-                )
-
-                for param_name in qconfig["parameters"]:
-                    param_key = f"{prefix}{param_name}"
-                    _v = state_dict.pop(param_key, None)
-                    if _v is None:
-                        continue
-                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
-                    manually_loaded_keys.append(param_key)
-
-            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-            for key in manually_loaded_keys:
-                if key in missing_keys:
-                    missing_keys.remove(key)
-
-        def _forward(self, input, weight, bias):
-            return torch.nn.functional.linear(input, weight, bias)
-
-        def forward_comfy_cast_weights(self, input):
-            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-            x = self._forward(input, weight, bias)
-            uncast_bias_weight(self, weight, bias, offload_stream)
-            return x
-
-        def forward(self, input, *args, **kwargs):
-            run_every_op()
-
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                return self.forward_comfy_cast_weights(input, *args, **kwargs)
-            if (getattr(self, 'layout_type', None) is not None and
-                getattr(self, 'input_scale', None) is not None and
-                not isinstance(input, QuantizedTensor)):
-                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
-            return self._forward(input, self.weight, self.bias)
+def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+    class MixedPrecisionOps(manual_cast):
+        _layer_quant_config = layer_quant_config
+        _compute_dtype = compute_dtype
+        _full_precision_mm = full_precision_mm
+
+        class Linear(torch.nn.Module, CastWeightBiasOp):
+            def __init__(
+                self,
+                in_features: int,
+                out_features: int,
+                bias: bool = True,
+                device=None,
+                dtype=None,
+            ) -> None:
+                super().__init__()
+
+                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+                # self.factory_kwargs = {"device": device, "dtype": dtype}
+
+                self.in_features = in_features
+                self.out_features = out_features
+                if bias:
+                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+                else:
+                    self.register_parameter("bias", None)
+
+                self.tensor_class = None
+                self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+
+            def reset_parameters(self):
+                return None
+
+            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys, error_msgs):
+
+                device = self.factory_kwargs["device"]
+                layer_name = prefix.rstrip('.')
+                weight_key = f"{prefix}weight"
+                weight = state_dict.pop(weight_key, None)
+                if weight is None:
+                    raise ValueError(f"Missing weight for layer {layer_name}")
+
+                manually_loaded_keys = [weight_key]
+
+                if layer_name not in MixedPrecisionOps._layer_quant_config:
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                else:
+                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                    if quant_format is None:
+                        raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                    qconfig = QUANT_ALGOS[quant_format]
+                    self.layout_type = qconfig["comfy_tensor_layout"]
+
+                    weight_scale_key = f"{prefix}weight_scale"
+                    layout_params = {
+                        'scale': state_dict.pop(weight_scale_key, None),
+                        'orig_dtype': MixedPrecisionOps._compute_dtype,
+                        'block_size': qconfig.get("group_size", None),
+                    }
+                    if layout_params['scale'] is not None:
+                        manually_loaded_keys.append(weight_scale_key)
+
+                    self.weight = torch.nn.Parameter(
+                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                        requires_grad=False
+                    )
+
+                    for param_name in qconfig["parameters"]:
+                        param_key = f"{prefix}{param_name}"
+                        _v = state_dict.pop(param_key, None)
+                        if _v is None:
+                            continue
+                        setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                        manually_loaded_keys.append(param_key)
+
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+                for key in manually_loaded_keys:
+                    if key in missing_keys:
+                        missing_keys.remove(key)
+
+            def _forward(self, input, weight, bias):
+                return torch.nn.functional.linear(input, weight, bias)
+
+            def forward_comfy_cast_weights(self, input):
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+                x = self._forward(input, weight, bias)
+                uncast_bias_weight(self, weight, bias, offload_stream)
+                return x
+
+            def forward(self, input, *args, **kwargs):
+                run_every_op()
+
+                if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                    return self.forward_comfy_cast_weights(input, *args, **kwargs)
+                if (getattr(self, 'layout_type', None) is not None and
+                    getattr(self, 'input_scale', None) is not None and
+                    not isinstance(input, QuantizedTensor)):
+                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                return self._forward(input, self.weight, self.bias)
+    return MixedPrecisionOps

 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
     if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
-        MixedPrecisionOps._compute_dtype = compute_dtype
         logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return MixedPrecisionOps
+        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype)

     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
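
For orientation, a minimal usage sketch of the new factory (not taken from the PR; the layer name and the "fp8_e4m3" format string are illustrative assumptions, since valid format names come from QUANT_ALGOS):

    import torch
    import comfy.ops

    # Layers listed in the config are loaded as QuantizedTensor weights;
    # every other Linear loads its weight directly in compute_dtype.
    layer_quant_config = {
        "blocks.0.attn.qkv": {"format": "fp8_e4m3", "params": {}},   # hypothetical layer/format
    }
    ops = comfy.ops.mixed_precision_ops(layer_quant_config, compute_dtype=torch.bfloat16)

    # The returned class is used like the other ops classes: model code instantiates
    # ops.Linear(...), and weights are materialized when the state dict is loaded.
    linear = ops.Linear(4096, 4096, bias=True, device="cpu")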
12 changes: 12 additions & 0 deletions comfy/quant_ops.py
@@ -338,6 +338,18 @@ def generic_copy_(func, args, kwargs):
     return func(*args, **kwargs)


+@register_generic_util(torch.ops.aten.to.dtype)
+def generic_to_dtype(func, args, kwargs):
+    """Handle .to(dtype) calls - dtype conversion only."""
+    src = args[0]
+    if isinstance(src, QuantizedTensor):
+        # For dtype-only conversion, just change the orig_dtype, no real cast is needed
+        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
+        src._layout_params["orig_dtype"] = target_dtype
+        return src
+    return func(*args, **kwargs)
+
+
 @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
     return True
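
To make the intent of the new handler concrete: a dtype-only .to() on a QuantizedTensor should not dequantize the payload, it should only retag the dtype the tensor dequantizes back to. A small sketch under that assumption (retag_dtype is a hypothetical helper, not part of the PR):

    import torch
    from comfy.quant_ops import QuantizedTensor

    def retag_dtype(qt: QuantizedTensor, dtype: torch.dtype) -> QuantizedTensor:
        # Assuming .to(dtype) dispatches to the aten.to.dtype handler above,
        # the quantized payload is untouched and only orig_dtype is updated.
        out = qt.to(dtype)
        assert out._layout_params["orig_dtype"] == dtype
        return out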
9 changes: 8 additions & 1 deletion comfy/sd.py
@@ -917,7 +917,12 @@ class CLIPType(Enum):
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
-        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
+        sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
+        if metadata is not None:
+            quant_metadata = metadata.get("_quantization_metadata", None)
+            if quant_metadata is not None:
+                sd["_quantization_metadata"] = quant_metadata
+        clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)


@@ -1142,6 +1147,8 @@ class EmptyClass:

     parameters = 0
     for c in clip_data:
+        if "_quantization_metadata" in c:
+            c.pop("_quantization_metadata")
         parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

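
For context, a sketch of the producer side this loader expects (not part of the PR): the quantization config rides along as a JSON string under the safetensors metadata key "_quantization_metadata". The tensor contents, layer name, and format string below are placeholders:

    import json
    import torch
    from safetensors.torch import save_file

    quant_meta = json.dumps({
        "layers": {"text_model.encoder.layers.0.mlp.fc1": {"format": "fp8_e4m3", "params": {}}}
    })
    save_file({"dummy": torch.zeros(1)}, "clip.safetensors",
              metadata={"_quantization_metadata": quant_meta})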
18 changes: 14 additions & 4 deletions comfy/sd1_clip.py
@@ -109,13 +109,23 @@ def __init__(self, device="cpu", max_length=77,

         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
+        quantization_metadata = model_options.get("quantization_metadata", None)

         if operations is None:
-            scaled_fp8 = model_options.get("scaled_fp8", None)
-            if scaled_fp8 is not None:
-                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+            layer_quant_config = None
+            if quantization_metadata is not None:
+                layer_quant_config = json.loads(quantization_metadata).get("layers", None)
+
+            if layer_quant_config is not None:
+                operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
+                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
             else:
-                operations = comfy.ops.manual_cast
+                # Fallback to scaled_fp8_ops for backward compatibility
+                scaled_fp8 = model_options.get("scaled_fp8", None)
+                if scaled_fp8 is not None:
+                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+                else:
+                    operations = comfy.ops.manual_cast

         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
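
A minimal sketch of the shape the branch above expects once the metadata is parsed (layer name and format string are assumptions); note that full_precision_mm=True routes the resulting Linear layers through forward_comfy_cast_weights, i.e. dequantize-then-matmul rather than a quantized matmul:

    import json

    quantization_metadata = json.dumps({
        "layers": {
            "transformer.text_model.encoder.layers.0.self_attn.q_proj": {"format": "fp8_e4m3", "params": {}},
        }
    })
    layer_quant_config = json.loads(quantization_metadata).get("layers", None)
    # layer_quant_config is what gets handed to comfy.ops.mixed_precision_ops(..., full_precision_mm=True)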
3 changes: 3 additions & 0 deletions comfy/text_encoders/hunyuan_video.py
@@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
     if scaled_fp8_key in state_dict:
         out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

+    if "_quantization_metadata" in state_dict:
+        out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+
     return out

17 changes: 5 additions & 12 deletions tests-unit/comfy_quant/test_mixed_precision.py
@@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):

     def test_all_layers_standard(self):
         """Test that model with no quantization works normally"""
-        # Configure no quantization
-        ops.MixedPrecisionOps._layer_quant_config = {}
-
         # Create model
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))

         # Initialize weights manually
         model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
@@ -76,7 +73,6 @@ def test_mixed_precision_load(self):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create state dict with mixed precision
         fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -99,7 +95,7 @@
         }

         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)

         # Verify weights are wrapped in QuantizedTensor
@@ -132,7 +128,6 @@ def test_state_dict_quantized_preserved(self):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -146,7 +141,7 @@
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict1, strict=False)

         # Save state dict
@@ -170,7 +165,6 @@ def test_weight_function_compatibility(self):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -184,7 +178,7 @@
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)

         # Add a weight function (simulating LoRA)
@@ -210,7 +204,6 @@ def test_error_handling_unknown_format(self):
                 "params": {}
            }
        }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
@@ -223,7 +216,7 @@
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
        with self.assertRaises(KeyError):
            model.load_state_dict(state_dict, strict=False)
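
All of the test changes follow one pattern: instead of mutating the module-level class attribute, each test builds its own ops class from the factory, so quantization configs cannot leak between tests. A small illustration (the format string is a placeholder):

    import comfy.ops as ops

    ops_a = ops.mixed_precision_ops({"layer2": {"format": "fp8_e4m3", "params": {}}})
    ops_b = ops.mixed_precision_ops({})
    assert ops_a is not ops_b                                  # independent classes
    assert ops_a._layer_quant_config != ops_b._layer_quant_config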