 import comfy.float
 import comfy.rmsnorm
 import contextlib
+import json
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -422,22 +423,12 @@ def fp8_linear(self, input):
 
     if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
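+        # identity (1.0) scales: this path no longer reads self.scale_weight / self.scale_input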
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
 
-        scale_weight = self.scale_weight
-        scale_input = self.scale_input
-        if scale_weight is None:
-            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        else:
-            scale_weight = scale_weight.to(input.device)
-
-        if scale_input is None:
-            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-        else:
-            scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
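+        # clamp to the representable range of fp8 e4m3fn (max finite value ±448) before casting down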
+        input = torch.clamp(input, min=-448, max=448, out=input)
+        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
 
         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@@ -471,59 +462,6 @@ def forward_comfy_cast_weights(self, input):
         uncast_bias_weight(self, weight, bias, offload_stream)
         return x
 
-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
-    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
-    class scaled_fp8_op(manual_cast):
-        class Linear(manual_cast.Linear):
-            def __init__(self, *args, **kwargs):
-                if override_dtype is not None:
-                    kwargs['dtype'] = override_dtype
-                super().__init__(*args, **kwargs)
-
-            def reset_parameters(self):
-                if not hasattr(self, 'scale_weight'):
-                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
-                if not scale_input:
-                    self.scale_input = None
-
-                if not hasattr(self, 'scale_input'):
-                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                if fp8_matrix_mult:
-                    out = fp8_linear(self, input)
-                    if out is not None:
-                        return out
-
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
-                if weight.numel() < input.numel(): #TODO: optimize
-                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
-                else:
-                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
-
-            def convert_weight(self, weight, inplace=False, **kwargs):
-                if inplace:
-                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-                    return weight
-                else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
-            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
-                if return_weight:
-                    return weight
-                if inplace_update:
-                    self.weight.data.copy_(weight)
-                else:
-                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-    return scaled_fp8_op
-
 CUBLAS_IS_AVAILABLE = False
 try:
     from cublas_ops import CublasLinear
@@ -550,9 +488,9 @@ def forward(self, *args, **kwargs):
 from .quant_ops import QuantizedTensor, QUANT_ALGOS
 
 
-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
     class MixedPrecisionOps(manual_cast):
-        _layer_quant_config = layer_quant_config
+        _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
 
@@ -595,27 +533,36 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
 
                 manually_loaded_keys = [weight_key]
 
-                if layer_name not in MixedPrecisionOps._layer_quant_config:
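+                # per-layer quantization metadata now ships inside the checkpoint itself:
+                # a "comfy_quant" uint8 tensor whose bytes are UTF-8 JSON (see state_dict() below)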
+                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+                if layer_conf is not None:
+                    layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+                if layer_conf is None:
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
-                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
-                    if quant_format is None:
+                    self.quant_format = layer_conf.get("format", None)
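+                    # a layer's own metadata may force full-precision matrix mult for that layer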
+                    if not self._full_precision_mm:
+                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+                    if self.quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")
 
-                    qconfig = QUANT_ALGOS[quant_format]
+                    qconfig = QUANT_ALGOS[self.quant_format]
                     self.layout_type = qconfig["comfy_tensor_layout"]
 
                     weight_scale_key = f"{prefix}weight_scale"
+                    scale = state_dict.pop(weight_scale_key, None)
                     layout_params = {
-                        'scale': state_dict.pop(weight_scale_key, None),
+                        'scale': scale,
                         'orig_dtype': MixedPrecisionOps._compute_dtype,
                         'block_size': qconfig.get("group_size", None),
                     }
-                    if layout_params['scale'] is not None:
+
+                    if scale is not None:
                         manually_loaded_keys.append(weight_scale_key)
 
                     self.weight = torch.nn.Parameter(
-                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
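+                        # cast to the algo's storage dtype when it defines one ("storage_t"); dtype=None leaves the weight unchanged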
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                         requires_grad=False
                     )
 
@@ -624,7 +571,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                         _v = state_dict.pop(param_key, None)
                         if _v is None:
                             continue
-                        setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
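+                        # register any extra per-algo tensors from the checkpoint (e.g. an input scale) as real module parameters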
+                        self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                         manually_loaded_keys.append(param_key)
 
                 super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -633,6 +580,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                     if key in missing_keys:
                         missing_keys.remove(key)
 
+            def state_dict(self, *args, destination=None, prefix="", **kwargs):
+                sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+                if isinstance(self.weight, QuantizedTensor):
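+                    # re-emit the scale and the JSON config so a quantized layer round-trips through save/load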
+                    sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                    quant_conf = {"format": self.quant_format}
+                    if self._full_precision_mm:
+                        quant_conf["full_precision_matrix_mult"] = True
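+                    # store the JSON as raw UTF-8 bytes so the config can live in a tensors-only state dict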
+                    sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+                return sd
+
             def _forward(self, input, weight, bias):
                 return torch.nn.functional.linear(input, weight, bias)
 
@@ -648,9 +605,8 @@ def forward(self, input, *args, **kwargs):
                 if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                     return self.forward_comfy_cast_weights(input, *args, **kwargs)
                 if (getattr(self, 'layout_type', None) is not None and
-                        getattr(self, 'input_scale', None) is not None and
                         not isinstance(input, QuantizedTensor)):
-                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
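+                    # input_scale is optional now; from_float receives None when the layer does not define one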
+                    input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
                 return self._forward(input, self.weight, self.bias)
 
             def convert_weight(self, weight, inplace=False, **kwargs):
@@ -661,7 +617,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
 
             def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                 if getattr(self, 'layout_type', None) is not None:
-                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
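+                    # "recalculate" asks from_float to derive a fresh scale for the incoming weights instead of reusing none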
+                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
                 else:
                     weight = weight.to(self.weight.dtype)
                 if return_weight:
@@ -672,15 +628,12 @@ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False
 
     return MixedPrecisionOps
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
 
-    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
-
-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
+    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+        logging.info("Using mixed precision operations")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
 
     if (
         fp8_compute and