Skip to content

Commit 8d78ac9

Browse files
author
Amit Raj
committed
Height and width can now be passed via the compile and __call__ methods, and review comments addressed
Signed-off-by: Amit Raj <[email protected]>
1 parent b91e2c9 commit 8d78ac9

File tree

8 files changed

+77
-111
lines changed

8 files changed

+77
-111
lines changed

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
QEffFluxAttnProcessor,
3636
QEffFluxSingleTransformerBlock,
3737
QEffFluxTransformer2DModel,
38-
QEffFluxTransformer2DModelOF,
3938
QEffFluxTransformerBlock,
4039
)
4140

@@ -81,12 +80,3 @@ class NormalizationTransform(ModuleMappingTransform):
8180
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
8281
model, transformed = super().apply(model)
8382
return model, transformed
84-
85-
86-
class OnnxFunctionTransform(ModuleMappingTransform):
87-
_module_mapping = {QEffFluxTransformer2DModel, QEffFluxTransformer2DModelOF}
88-
89-
@classmethod
90-
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
91-
model, transformed = super().apply(model)
92-
return model, transformed

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import os
99
import time
10-
from typing import Any, Callable, Dict, List, Optional, Union
10+
from typing import Callable, Dict, List, Optional, Union
1111

1212
import numpy as np
1313
import torch
@@ -24,6 +24,7 @@
2424
from QEfficient.diffusers.pipelines.pipeline_utils import (
2525
ModulePerf,
2626
QEffPipelineOutput,
27+
calculate_compressed_latent_dimension,
2728
compile_modules_parallel,
2829
compile_modules_sequential,
2930
config_manager,
@@ -47,21 +48,21 @@ class QEFFFluxPipeline(FluxPipeline):
4748

4849
_hf_auto_class = FluxPipeline
4950

50-
def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
51+
def __init__(self, model, use_onnx_subfunctions: bool, *args, **kwargs):
5152
"""
5253
Initialize the QEfficient Flux pipeline.
5354
5455
Args:
5556
model: Pre-loaded FluxPipeline model
56-
use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
57+
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
5758
**kwargs: Additional arguments including height and width
5859
"""
60+
5961
# Wrap model components with QEfficient optimized versions
6062
self.text_encoder = QEffTextEncoder(model.text_encoder)
6163
self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
62-
self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_function=use_onnx_function)
64+
self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_subfunctions=use_onnx_subfunctions)
6365
self.vae_decode = QEffVAE(model, "decoder")
64-
self.use_onnx_function = use_onnx_function
6566

6667
# Store all modules in a dictionary for easy iteration during export/compile
6768
self.modules = {
@@ -78,10 +79,6 @@ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
7879
self.tokenizer_max_length = model.tokenizer_max_length
7980
self.scheduler = model.scheduler
8081

81-
# Set default image dimensions
82-
self.height = kwargs.get("height", 256)
83-
self.width = kwargs.get("width", 256)
84-
8582
# Override VAE forward method to use decode directly
8683
self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
8784
latent_sample, return_dict
@@ -102,10 +99,6 @@ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
10299

103100
# Calculate latent dimensions based on image size and VAE scale factor
104101
self.default_sample_size = 128
105-
self.latent_height = self.height // self.vae_scale_factor
106-
self.latent_width = self.width // self.vae_scale_factor
107-
# cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
108-
self.cl = (self.latent_height * self.latent_width) // 4
109102

110103
# Sync max position embeddings between text encoders
111104
self.text_encoder_2.model.config.max_position_embeddings = (
@@ -116,17 +109,15 @@ def __init__(self, model, use_onnx_function: bool, *args, **kwargs):
116109
def from_pretrained(
117110
cls,
118111
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
119-
use_onnx_function: bool = False,
120-
height: Optional[int] = 512,
121-
width: Optional[int] = 512,
112+
use_onnx_subfunctions: bool = False,
122113
**kwargs,
123114
):
124115
"""
125116
Load a pretrained Flux model and wrap it with QEfficient optimizations.
126117
127118
Args:
128119
pretrained_model_name_or_path (str or os.PathLike): HuggingFace model ID or local path
129-
use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
120+
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
130121
height (int): Target image height (default: 512)
131122
width (int): Target image width (default: 512)
132123
**kwargs: Additional arguments passed to FluxPipeline.from_pretrained
@@ -144,10 +135,8 @@ def from_pretrained(
144135

145136
return cls(
146137
model=model,
147-
use_onnx_function=use_onnx_function,
138+
use_onnx_subfunctions=use_onnx_subfunctions,
148139
pretrained_model_name_or_path=pretrained_model_name_or_path,
149-
height=height,
150-
width=width,
151140
**kwargs,
152141
)
153142

@@ -168,20 +157,12 @@ def export(self, export_dir: Optional[str] = None) -> str:
168157
# Get ONNX export configuration for this module
169158
example_inputs, dynamic_axes, output_names = module_obj.get_onnx_config()
170159

171-
export_kwargs = {}
172-
# Special handling for transformer: export blocks as functions if enabled
173-
if module_name == "transformer" and self.use_onnx_function:
174-
export_kwargs = {
175-
"export_modules_as_functions": self.transformer.model._block_classes,
176-
}
177-
178160
# Export the module to ONNX
179161
module_obj.export(
180162
inputs=example_inputs,
181163
output_names=output_names,
182164
dynamic_axes=dynamic_axes,
183165
export_dir=export_dir,
184-
export_kwargs=export_kwargs,
185166
)
186167

187168
@staticmethod
@@ -194,7 +175,9 @@ def get_default_config_path() -> str:
194175
"""
195176
return os.path.join(os.path.dirname(__file__), "flux_config.json")
196177

197-
def compile(self, compile_config: Optional[str] = None, parallel: bool = False) -> None:
178+
def compile(
179+
self, compile_config: Optional[str] = None, parallel: bool = False, height: int = 512, width: int = 512
180+
) -> None:
198181
"""
199182
Compile ONNX models for deployment on Qualcomm AI hardware.
200183
@@ -204,7 +187,7 @@ def compile(self, compile_config: Optional[str] = None, parallel: bool = False)
204187
Args:
205188
compile_config (str, optional): Path to JSON configuration file.
206189
If None, uses default configuration.
207-
parallel (bool): If True, compile modules in parallel using ProcessPoolExecutor.
190+
parallel (bool): If True, compile modules in parallel using ThreadPoolExecutor.
208191
If False, compile sequentially (default: False).
209192
"""
210193
# Ensure all modules are exported to ONNX before compilation
@@ -223,12 +206,15 @@ def compile(self, compile_config: Optional[str] = None, parallel: bool = False)
223206
if self.custom_config is None:
224207
config_manager(self, config_source=compile_config)
225208

209+
# Calculate compressed latent dimension using utility function
210+
cl, latent_height, latent_width = calculate_compressed_latent_dimension(height, width, self.vae_scale_factor)
211+
226212
# Prepare dynamic specialization updates based on image dimensions
227213
specialization_updates = {
228-
"transformer": {"cl": self.cl},
214+
"transformer": {"cl": cl},
229215
"vae_decoder": {
230-
"latent_height": self.latent_height,
231-
"latent_width": self.latent_width,
216+
"latent_height": latent_height,
217+
"latent_width": latent_width,
232218
},
233219
}
234220

@@ -448,6 +434,8 @@ def encode_prompt(
448434

449435
def __call__(
450436
self,
437+
height: int = 512,
438+
width: int = 512,
451439
prompt: Union[str, List[str]] = None,
452440
prompt_2: Optional[Union[str, List[str]]] = None,
453441
negative_prompt: Union[str, List[str]] = None,
@@ -464,8 +452,6 @@ def __call__(
464452
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
465453
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
466454
output_type: Optional[str] = "pil",
467-
return_dict: bool = True,
468-
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
469455
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
470456
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
471457
max_sequence_length: int = 512,
@@ -513,19 +499,21 @@ def __call__(
513499
"""
514500
device = "cpu"
515501

502+
if height is None or width is None:
503+
logger.warning("Height or width is None. Setting default values of 512 for both dimensions.")
504+
516505
# Step 1: Load configuration and compile models if needed
517506
if custom_config_path is not None:
518507
config_manager(self, custom_config_path)
519508
set_module_device_ids(self)
520509

521-
self.compile(compile_config=custom_config_path, parallel=parallel_compile)
522-
510+
self.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width)
523511
# Validate all inputs
524512
self.check_inputs(
525513
prompt,
526514
prompt_2,
527-
self.height,
528-
self.width,
515+
height,
516+
width,
529517
negative_prompt=negative_prompt,
530518
negative_prompt_2=negative_prompt_2,
531519
prompt_embeds=prompt_embeds,
@@ -587,23 +575,26 @@ def __call__(
587575
latents, latent_image_ids = self.prepare_latents(
588576
batch_size * num_images_per_prompt,
589577
num_channels_latents,
590-
self.height,
591-
self.width,
578+
height,
579+
width,
592580
prompt_embeds.dtype,
593581
device,
594582
generator,
595583
latents,
596584
)
597585

598-
# Step 6: Initialize transformer inference session
586+
# Step 6: Calculate compressed latent dimension for transformer buffer allocation
587+
cl, _, _ = calculate_compressed_latent_dimension(height, width, self.vae_scale_factor)
588+
589+
# Initialize transformer inference session
599590
if self.transformer.qpc_session is None:
600591
self.transformer.qpc_session = QAICInferenceSession(
601592
str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
602593
)
603594

604595
# Allocate output buffer for transformer
605596
output_buffer = {
606-
"output": np.random.rand(batch_size, self.cl, self.transformer.model.config.in_channels).astype(np.float32),
597+
"output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32),
607598
}
608599
self.transformer.qpc_session.set_buffers(output_buffer)
609600

@@ -693,7 +684,7 @@ def __call__(
693684
image = latents
694685
else:
695686
# Unpack and denormalize latents
696-
latents = self._unpack_latents(latents, self.height, self.width, self.vae_scale_factor)
687+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
697688
latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor
698689

699690
# Initialize VAE decoder inference session
@@ -703,7 +694,7 @@ def __call__(
703694
)
704695

705696
# Allocate output buffer for VAE decoder
706-
output_buffer = {"sample": np.random.rand(batch_size, 3, self.height, self.width).astype(np.int32)}
697+
output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)}
707698
self.vae_decode.qpc_session.set_buffers(output_buffer)
708699

709700
# Run VAE decoder inference and measure time

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
AttentionTransform,
1818
CustomOpsTransform,
1919
NormalizationTransform,
20-
OnnxFunctionTransform,
20+
)
21+
from QEfficient.diffusers.models.transformers.transformer_flux import (
22+
QEffFluxSingleTransformerBlock,
23+
QEffFluxTransformerBlock,
2124
)
2225
from QEfficient.transformers.models.pytorch_transforms import (
2326
T5ModelTransform,
@@ -377,28 +380,20 @@ class QEffFluxTransformerModel(QEFFBaseModel):
377380
_pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform]
378381
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
379382

380-
def __init__(self, model: nn.Module, use_onnx_function: bool) -> None:
383+
def __init__(self, model: nn.Module, use_onnx_subfunctions: bool) -> None:
381384
"""
382385
Initialize the Flux transformer wrapper.
383386
384387
Args:
385388
model (nn.Module): The Flux transformer model to wrap
386-
use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
389+
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
387390
for better modularity and potential optimization
388391
"""
389-
390-
# Optionally apply ONNX function transform for modular export
391-
392-
if use_onnx_function:
393-
model, _ = OnnxFunctionTransform.apply(model)
394-
395392
super().__init__(model)
396393

397-
if use_onnx_function:
398-
self._pytorch_transforms.append(OnnxFunctionTransform)
399-
400394
# Ensure model is on CPU to avoid meta device issues
401395
self.model = model.to("cpu")
396+
self.use_onnx_subfunctions = use_onnx_subfunctions
402397

403398
def get_onnx_config(
404399
self, batch_size: int = 1, seq_length: int = 256, cl: int = 4096
@@ -423,17 +418,12 @@ def get_onnx_config(
423418
example_inputs = {
424419
# Latent representation of the image
425420
"hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32),
426-
# Text embeddings from T5 encoder
427421
"encoder_hidden_states": torch.randn(
428422
batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32
429423
),
430-
# Pooled text embeddings from CLIP encoder
431424
"pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32),
432-
# Diffusion timestep (normalized to [0, 1])
433425
"timestep": torch.tensor([1.0], dtype=torch.float32),
434-
# Position IDs for image patches
435426
"img_ids": torch.randn(cl, 3, dtype=torch.float32),
436-
# Position IDs for text tokens
437427
"txt_ids": torch.randn(seq_length, 3, dtype=torch.float32),
438428
# AdaLN embeddings for dual transformer blocks
439429
# Shape: [num_layers, 12 chunks (6 for norm1 + 6 for norm1_context), hidden_dim]
@@ -490,6 +480,9 @@ def export(
490480
Returns:
491481
str: Path to the exported ONNX model
492482
"""
483+
if self.use_onnx_subfunctions:
484+
export_kwargs = {"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}}
485+
493486
return self._export(
494487
example_inputs=inputs,
495488
output_names=output_names,
@@ -498,35 +491,6 @@ def export(
498491
export_kwargs=export_kwargs,
499492
)
500493

501-
def get_specializations(self, batch_size: int, seq_len: int, cl: int) -> List[Dict]:
502-
"""
503-
Generate specialization configuration for compilation.
504-
505-
Specializations define fixed values for certain dimensions to enable
506-
compiler optimizations specific to the target use case.
507-
508-
Args:
509-
batch_size (int): Batch size for inference
510-
seq_len (int): Text sequence length
511-
cl (int): Compressed latent dimension
512-
513-
Returns:
514-
List[Dict]: Specialization configurations for the compiler
515-
"""
516-
specializations = [
517-
{
518-
"batch_size": batch_size,
519-
"stats-batchsize": batch_size,
520-
"num_layers": self.model.config.num_layers,
521-
"num_single_layers": self.model.config.num_single_layers,
522-
"seq_len": seq_len,
523-
"cl": cl,
524-
"steps": 1,
525-
}
526-
]
527-
528-
return specializations
529-
530494
def compile(self, specializations: List[Dict], **compiler_options) -> None:
531495
"""
532496
Compile the ONNX model for Qualcomm AI hardware.

0 commit comments

Comments
 (0)