77
88import os
99import time
10- from typing import Any , Callable , Dict , List , Optional , Union
10+ from typing import Callable , Dict , List , Optional , Union
1111
1212import numpy as np
1313import torch
2424from QEfficient .diffusers .pipelines .pipeline_utils import (
2525 ModulePerf ,
2626 QEffPipelineOutput ,
27+ calculate_compressed_latent_dimension ,
2728 compile_modules_parallel ,
2829 compile_modules_sequential ,
2930 config_manager ,
@@ -47,21 +48,21 @@ class QEFFFluxPipeline(FluxPipeline):
4748
4849 _hf_auto_class = FluxPipeline
4950
50- def __init__ (self , model , use_onnx_function : bool , * args , ** kwargs ):
51+ def __init__ (self , model , use_onnx_subfunctions : bool , * args , ** kwargs ):
5152 """
5253 Initialize the QEfficient Flux pipeline.
5354
5455 Args:
5556 model: Pre-loaded FluxPipeline model
56- use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
57+ use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
5758 **kwargs: Additional arguments including height and width
5859 """
60+
5961 # Wrap model components with QEfficient optimized versions
6062 self .text_encoder = QEffTextEncoder (model .text_encoder )
6163 self .text_encoder_2 = QEffTextEncoder (model .text_encoder_2 )
62- self .transformer = QEffFluxTransformerModel (model .transformer , use_onnx_function = use_onnx_function )
64+ self .transformer = QEffFluxTransformerModel (model .transformer , use_onnx_subfunctions = use_onnx_subfunctions )
6365 self .vae_decode = QEffVAE (model , "decoder" )
64- self .use_onnx_function = use_onnx_function
6566
6667 # Store all modules in a dictionary for easy iteration during export/compile
6768 self .modules = {
@@ -78,10 +79,6 @@ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
7879 self .tokenizer_max_length = model .tokenizer_max_length
7980 self .scheduler = model .scheduler
8081
81- # Set default image dimensions
82- self .height = kwargs .get ("height" , 256 )
83- self .width = kwargs .get ("width" , 256 )
84-
8582 # Override VAE forward method to use decode directly
8683 self .vae_decode .model .forward = lambda latent_sample , return_dict : self .vae_decode .model .decode (
8784 latent_sample , return_dict
@@ -102,10 +99,6 @@ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
10299
103100 # Calculate latent dimensions based on image size and VAE scale factor
104101 self .default_sample_size = 128
105- self .latent_height = self .height // self .vae_scale_factor
106- self .latent_width = self .width // self .vae_scale_factor
107- # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
108- self .cl = (self .latent_height * self .latent_width ) // 4
109102
110103 # Sync max position embeddings between text encoders
111104 self .text_encoder_2 .model .config .max_position_embeddings = (
@@ -116,17 +109,15 @@ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
116109 def from_pretrained (
117110 cls ,
118111 pretrained_model_name_or_path : Optional [Union [str , os .PathLike ]],
119- use_onnx_function : bool = False ,
120- height : Optional [int ] = 512 ,
121- width : Optional [int ] = 512 ,
112+ use_onnx_subfunctions : bool = False ,
122113 ** kwargs ,
123114 ):
124115 """
125116 Load a pretrained Flux model and wrap it with QEfficient optimizations.
126117
127118 Args:
128119 pretrained_model_name_or_path (str or os.PathLike): HuggingFace model ID or local path
129- use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
120+ use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
130121 height (int): Target image height (default: 512)
131122 width (int): Target image width (default: 512)
132123 **kwargs: Additional arguments passed to FluxPipeline.from_pretrained
@@ -144,10 +135,8 @@ def from_pretrained(
144135
145136 return cls (
146137 model = model ,
147- use_onnx_function = use_onnx_function ,
138+ use_onnx_subfunctions = use_onnx_subfunctions ,
148139 pretrained_model_name_or_path = pretrained_model_name_or_path ,
149- height = height ,
150- width = width ,
151140 ** kwargs ,
152141 )
153142
@@ -168,20 +157,12 @@ def export(self, export_dir: Optional[str] = None) -> str:
168157 # Get ONNX export configuration for this module
169158 example_inputs , dynamic_axes , output_names = module_obj .get_onnx_config ()
170159
171- export_kwargs = {}
172- # Special handling for transformer: export blocks as functions if enabled
173- if module_name == "transformer" and self .use_onnx_function :
174- export_kwargs = {
175- "export_modules_as_functions" : self .transformer .model ._block_classes ,
176- }
177-
178160 # Export the module to ONNX
179161 module_obj .export (
180162 inputs = example_inputs ,
181163 output_names = output_names ,
182164 dynamic_axes = dynamic_axes ,
183165 export_dir = export_dir ,
184- export_kwargs = export_kwargs ,
185166 )
186167
187168 @staticmethod
@@ -194,7 +175,9 @@ def get_default_config_path() -> str:
194175 """
195176 return os .path .join (os .path .dirname (__file__ ), "flux_config.json" )
196177
197- def compile (self , compile_config : Optional [str ] = None , parallel : bool = False ) -> None :
178+ def compile (
179+ self , compile_config : Optional [str ] = None , parallel : bool = False , height : int = 512 , width : int = 512
180+ ) -> None :
198181 """
199182 Compile ONNX models for deployment on Qualcomm AI hardware.
200183
@@ -204,7 +187,7 @@ def compile(self, compile_config: Optional[str] = None, parallel: bool = False)
204187 Args:
205188 compile_config (str, optional): Path to JSON configuration file.
206189 If None, uses default configuration.
207- parallel (bool): If True, compile modules in parallel using ProcessPoolExecutor .
190+ parallel (bool): If True, compile modules in parallel using ThreadPoolExecutor .
208191 If False, compile sequentially (default: False).
209192 """
210193 # Ensure all modules are exported to ONNX before compilation
@@ -223,12 +206,15 @@ def compile(self, compile_config: Optional[str] = None, parallel: bool = False)
223206 if self .custom_config is None :
224207 config_manager (self , config_source = compile_config )
225208
209+ # Calculate compressed latent dimension using utility function
210+ cl , latent_height , latent_width = calculate_compressed_latent_dimension (height , width , self .vae_scale_factor )
211+
226212 # Prepare dynamic specialization updates based on image dimensions
227213 specialization_updates = {
228- "transformer" : {"cl" : self . cl },
214+ "transformer" : {"cl" : cl },
229215 "vae_decoder" : {
230- "latent_height" : self . latent_height ,
231- "latent_width" : self . latent_width ,
216+ "latent_height" : latent_height ,
217+ "latent_width" : latent_width ,
232218 },
233219 }
234220
@@ -448,6 +434,8 @@ def encode_prompt(
448434
449435 def __call__ (
450436 self ,
437+ height : int = 512 ,
438+ width : int = 512 ,
451439 prompt : Union [str , List [str ]] = None ,
452440 prompt_2 : Optional [Union [str , List [str ]]] = None ,
453441 negative_prompt : Union [str , List [str ]] = None ,
@@ -464,8 +452,6 @@ def __call__(
464452 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
465453 negative_pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
466454 output_type : Optional [str ] = "pil" ,
467- return_dict : bool = True ,
468- joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
469455 callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] = None ,
470456 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
471457 max_sequence_length : int = 512 ,
@@ -513,19 +499,21 @@ def __call__(
513499 """
514500 device = "cpu"
515501
502+ if height is None or width is None :
503+ logger .warning ("Height or width is None. Setting default values of 512 for both dimensions." )
504+
516505 # Step 1: Load configuration and compile models if needed
517506 if custom_config_path is not None :
518507 config_manager (self , custom_config_path )
519508 set_module_device_ids (self )
520509
521- self .compile (compile_config = custom_config_path , parallel = parallel_compile )
522-
510+ self .compile (compile_config = custom_config_path , parallel = parallel_compile , height = height , width = width )
523511 # Validate all inputs
524512 self .check_inputs (
525513 prompt ,
526514 prompt_2 ,
527- self . height ,
528- self . width ,
515+ height ,
516+ width ,
529517 negative_prompt = negative_prompt ,
530518 negative_prompt_2 = negative_prompt_2 ,
531519 prompt_embeds = prompt_embeds ,
@@ -587,23 +575,26 @@ def __call__(
587575 latents , latent_image_ids = self .prepare_latents (
588576 batch_size * num_images_per_prompt ,
589577 num_channels_latents ,
590- self . height ,
591- self . width ,
578+ height ,
579+ width ,
592580 prompt_embeds .dtype ,
593581 device ,
594582 generator ,
595583 latents ,
596584 )
597585
598- # Step 6: Initialize transformer inference session
586+ # Step 6: Calculate compressed latent dimension for transformer buffer allocation
587+ cl , _ , _ = calculate_compressed_latent_dimension (height , width , self .vae_scale_factor )
588+
589+ # Initialize transformer inference session
599590 if self .transformer .qpc_session is None :
600591 self .transformer .qpc_session = QAICInferenceSession (
601592 str (self .transformer .qpc_path ), device_ids = self .transformer .device_ids
602593 )
603594
604595 # Allocate output buffer for transformer
605596 output_buffer = {
606- "output" : np .random .rand (batch_size , self . cl , self .transformer .model .config .in_channels ).astype (np .float32 ),
597+ "output" : np .random .rand (batch_size , cl , self .transformer .model .config .in_channels ).astype (np .float32 ),
607598 }
608599 self .transformer .qpc_session .set_buffers (output_buffer )
609600
@@ -693,7 +684,7 @@ def __call__(
693684 image = latents
694685 else :
695686 # Unpack and denormalize latents
696- latents = self ._unpack_latents (latents , self . height , self . width , self .vae_scale_factor )
687+ latents = self ._unpack_latents (latents , height , width , self .vae_scale_factor )
697688 latents = (latents / self .vae_decode .model .scaling_factor ) + self .vae_decode .model .shift_factor
698689
699690 # Initialize VAE decoder inference session
@@ -703,7 +694,7 @@ def __call__(
703694 )
704695
705696 # Allocate output buffer for VAE decoder
706- output_buffer = {"sample" : np .random .rand (batch_size , 3 , self . height , self . width ).astype (np .int32 )}
697+ output_buffer = {"sample" : np .random .rand (batch_size , 3 , height , width ).astype (np .int32 )}
707698 self .vae_decode .qpc_session .set_buffers (output_buffer )
708699
709700 # Run VAE decoder inference and measure time
0 commit comments