diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 956ccf316..0481ace3e 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -25,7 +25,12 @@ def check_qaic_sdk(): # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" if QAIC_INSTALLED: - from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader,QEFFAutoModelForImageTextToText + from QEfficient.base import ( + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForImageTextToText, + QEFFCommonLoader, + ) from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2760cf52f..b77279dcf 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -114,6 +114,7 @@ def compile(self, *args, **kwargs) -> Path: def _export( self, + model, example_inputs: Dict[str, torch.Tensor], output_names: List[str], dynamic_axes: Dict[str, Dict[int, str]], @@ -157,7 +158,7 @@ def _export( try: export_kwargs = {} if export_kwargs is None else export_kwargs torch.onnx.export( - self.model, + model, (example_inputs,), str(tmp_onnx_path), input_names=input_names, @@ -175,6 +176,7 @@ def _export( } if onnx_transform_kwargs is not None: transform_kwargs.update(onnx_transform_kwargs) + for transform in self._onnx_transforms: model, transformed = transform.apply(model, **transform_kwargs) model.metadata_props.append( diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index f749cc0c3..23364655f 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -6,8 +6,9 @@ # ----------------------------------------------------------------------------- from collections import namedtuple -from typing import Dict, Type +from typing import Dict, Optional, Tuple, Type +import torch import torch.nn as nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -242,3 +243,95 @@ GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, } + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32) + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32) + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, 
full_text_row_masked_out_mask + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +def _create_causal_mask( + position_ids, + target_length, + sliding_window: Optional[int] = None, +): + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + """ + if sliding_window is not None: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, -1) + # --- Rolling buffer --- + pos_max = position_ids.max(1, keepdim=True).values + kv_start = (pos_max // target_length) * target_length + kv_indices_high = kv_indices + kv_start + kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length) + kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) + kv_indices = kv_indices.unsqueeze(1) + # ------ + causal_mask = kv_indices > query_indices + attention_mask = causal_mask + + window_indices = query_indices - sliding_window + 1 + window_mask = kv_indices < window_indices + attention_mask = attention_mask | window_mask + attention_mask = attention_mask.unsqueeze(1) + else: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, 1, -1) + attention_mask = kv_indices > query_indices + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 90be64096..76f4bd102 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -40,6 +40,21 @@ ) from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_utils import ( + _create_causal_mask, + _prepare_aspect_ratio_attention_mask, + _prepare_cross_attention_mask, +) +from QEfficient.utils import constants +from QEfficient.utils.constants import Constants + +bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE +max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES +max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES +image_length = constants.ONNX_EXPORT_IMAGE_LENGHT +image_width = constants.ONNX_EXPORT_IMAGE_WIDTH +num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH +seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): @@ -72,73 +87,93 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return 
q_embed.to(q.dtype), k_embed.to(k.dtype) -def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, -) -> torch.Tensor: - # Expand aspect ratio mask to target_length - batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) - attention_mask = attention_mask.repeat(1, 1, target_length, 1) - - # Mask padding patches - pad_patches = target_length - num_patches - attention_mask[:, :, -pad_patches:] = 0 - - # Invert the mask (0 -> 1, 1 -> 0) - attention_mask = 1 - attention_mask - - # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) - attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) - attention_mask = ( - attention_mask - @ attention_mask.transpose(-1, -2) - * torch.tensor(-10000.0, dtype=torch.float32) - ) - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - -def _create_causal_mask( - position_ids, - target_length, - sliding_window: Optional[int] = None, -): + +class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): """ - A utility attention mask class that allows one to: - - Create a causal 4d mask - - Create a causal 4d mask with slided window + Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py + The only differences are: + - add new args cache idx for the kv retention """ - if sliding_window is not None: - query_indices = position_ids.unsqueeze(-1) - kv_indices = torch.arange(target_length).view(1, -1) - # --- Rolling buffer --- - pos_max = position_ids.max(1, keepdim=True).values - kv_start = (pos_max // target_length) * target_length - kv_indices_high = kv_indices + kv_start - kv_indices_low = torch.where( - kv_indices_high < target_length, kv_indices, kv_indices_high - target_length - ) - kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) - kv_indices = kv_indices.unsqueeze(1) - # ------ - causal_mask = kv_indices > query_indices - attention_mask = causal_mask - window_indices = query_indices - sliding_window + 1 - window_mask = kv_indices < window_indices - attention_mask = attention_mask | window_mask - attention_mask = attention_mask.unsqueeze(1) - else: - query_indices = position_ids.unsqueeze(-1) - kv_indices = torch.arange(target_length).view(1, 1, -1) - attention_mask = kv_indices > query_indices - attention_mask = attention_mask.unsqueeze(1) + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) - return attention_mask + # elif past_key_value is not None: + # Fetch old cache + key_states_old = past_key_value.key_cache[self.layer_idx] + value_states_old = 
past_key_value.value_cache[self.layer_idx] + + # if cross_attention_states is not None: + # Compute new KV states + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # if past_key_value is not None: + # # if we have a new image + new tokens, we only computed key_states on that new image + # # we still update the cross key states, past_image, new_image. And use it! + # key_states, value_states = past_key_value.update( + # key_states, + # value_states, + # self.layer_idx, + # {"batch_index": batch_index, "position_ids": position_ids}, + # ) + + # Out-of-place Scatter new into old + # out-of-place is important so the original tensor is not affected, + # otherwise leads to same operations in both graphs + indices = (torch.arange(bsz),) + key_states_new = torch.index_put(key_states_old, indices, key_states) + value_states_new = torch.index_put(value_states_old, indices, value_states) + + # Select old or new image KV states based on q_len + key_states = torch.where(q_len == 1, key_states_old, key_states_new) + value_states = torch.where(q_len == 1, value_states_old, value_states_new) + + # Update the image cache + past_key_value.key_cache[self.layer_idx] = key_states + past_key_value.value_cache[self.layer_idx] = value_states + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + # attn_weights = torch.where( + # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights + # ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): @@ -196,7 +231,12 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -222,89 +262,6 @@ def forward( return attn_output, attn_weights, past_key_value -class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): - """ - Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py - The only differences are: - - add new args cache idx for the kv retention - """ - - def 
forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - batch_index: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, - value_states, - self.layer_idx, - {"batch_index": batch_index, "position_ids": position_ids}, - ) - elif past_key_value is not None: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - # attn_weights = torch.where( - # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights - # ) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - class QEffMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer): """ @@ -326,9 +283,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -479,9 +434,7 @@ def __init__( else: # BC: "rope_type" was originally "type" if config.rope_scaling is not None: - 
self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -490,9 +443,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -504,9 +455,7 @@ def __init__( def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as( - self.inv_freq - ) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) @@ -535,23 +484,15 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, height, width - ) + pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) # Patch embedding @@ -564,16 +505,12 @@ def forward( hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) # Add cls token - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim - ) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) hidden_state = self.apply_class_embedding(hidden_state) num_patches += 1 # Position embeddings - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches, dim - ) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) hidden_state = self.layernorm_pre(hidden_state) @@ -633,16 +570,12 @@ def forward( batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim ) hidden_state = hidden_state[:, :, :slice_index] - hidden_state = hidden_state.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, dim - ) + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) # Collect intermediate layer outputs from encoder output all_intermediate_hidden_states = output[1] 
intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) - intermediate_hidden_states = intermediate_hidden_states[ - ..., self.intermediate_layers_indices - ] + intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices] # Remove padding from intermediate hidden states intermediate_hidden_states = intermediate_hidden_states.reshape( @@ -663,9 +596,7 @@ def forward( if output_attentions: # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range - global_attn = ( - tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) - ) + global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) attentions = tuple(output[2]) + global_attn else: attentions = None @@ -704,13 +635,9 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -730,16 +657,12 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance( - past_key_values, Cache - ): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -788,11 +711,7 @@ def forward( # TODO: vbaddi: since past_key_values are retained from previous states, the condition for is_cross_attention_cache_empty is False # so explicitly making it true in order to skip the cross attention for language model # comment once there is vision and cross attention support - if ( - is_cross_attention_layer - and cross_attention_states is None - and is_cross_attention_cache_empty - ): + if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue if self.gradient_checkpointing and self.training: @@ -859,11 +778,7 @@ def forward( next_cache = next_cache.to_legacy_cache() if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -908,11 +823,7 @@ def _update_causal_mask( # TODO: vbaddi: unused, comment to fix linters # sequence_length = input_tensor.shape[1] - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - 
else past_seen_tokens - ) + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) @@ -957,13 +868,9 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1015,14 +922,216 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - + + +class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = 
self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + + def generate_input(self, kv_offload): + # vision_inputs + vision_inputs = { + "pixel_values": torch.zeros( + (bs, max_num_images, max_image_tiles, num_channel, image_length, image_width), dtype=torch.int64 + ), + "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), + "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles, 1), dtype=torch.int64), + } + + vision_output_names = [] + for i in self.config.text_config.cross_attention_layers: + vision_output_names.append(f"past_key.{i}") + vision_output_names.append(f"past_value.{i}") + + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": { + 0: "batch_size", + 1: "max_num_images", + 2: "max_image_tiles", + }, + } + + # lang_inputs + lang_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "cross_attention_mask": torch.ones((bs, max_image_tiles), dtype=torch.int64), + "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), + } + + lang_inputs["position_ids"] = torch.where( + lang_inputs.pop("attention_mask") == 1, + torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + + ctx_len = Constants.CTX_LEN + txt_cfg = self.config.get_text_config() + num_hidden_layers = txt_cfg.num_hidden_layers + cross_attention_layers = txt_cfg.cross_attention_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + + vis_cfg = self.config.vision_config + num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 + image_tokens_len = vis_cfg.max_num_tiles * num_patches + + lang_inputs["past_key_values"] = DynamicCache(num_hidden_layers) + lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + + for i in range(num_hidden_layers): + if i in cross_attention_layers: + idx = cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + lang_inputs["past_key_values"].key_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) + 
lang_inputs["past_key_values"].value_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) + else: + lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + + lang_output_names = [ + "logits", + *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]], + ] + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "cross_attention_mask": { + 0: "batch_size", + 1: "seq_len", + 2: "max_num_images", + 3: "max_image_tiles", + }, + } + + for i in range(num_hidden_layers): + if i in cross_attention_layers: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + else: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, ctx_len - 1) + + inputs = [] + output_names = [] + dynamic_axes = [] + + if kv_offload: + inputs.extend([vision_inputs, lang_inputs]) + output_names.extend([vision_output_names, lang_output_names]) + dynamic_axes.extend([vision_dynamic_axes, lang_dynamic_axes]) + else: + inputs.append({**vision_inputs, **lang_inputs}) + output_names = vision_output_names + lang_output_names + dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes}) + + return inputs, output_names, dynamic_axes + + class VisionEncoder(nn.Module): def __init__(self, mllama: MllamaForConditionalGeneration): super().__init__() self.mllama = mllama - self.cross_attention_layers = ( - self.mllama.config.get_text_config().cross_attention_layers - ) + self.cross_attention_layers = self.mllama.config.get_text_config().cross_attention_layers self.config = self.mllama.config.get_text_config() def forward( @@ -1037,9 +1146,9 @@ def forward( aspect_ratio_mask=aspect_ratio_mask, ) cross_attention_states = vision_outputs[0] - cross_attention_states = self.mllama.multi_modal_projector( - cross_attention_states - ).reshape(-1, cross_attention_states.shape[-2], self.mllama.hidden_size) + cross_attention_states = self.mllama.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.mllama.hidden_size + ) bsz = pixel_values.shape[0] outputs = [] @@ -1047,16 +1156,15 @@ def forward( cross_attn = self.mllama.language_model.model.layers[i].cross_attn key_states = cross_attn.k_proj(cross_attention_states) value_states = cross_attn.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim - ).transpose(1, 2) + key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose( + 1, 2 + ) outputs.append((key_states, value_states)) return outputs + class ModelWrapper(nn.Module): def __init__(self, mllama): super().__init__() @@ -1107,4 +1215,4 @@ def forward( ) if "past_key_values" in outputs: outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() - 
return outputs \ No newline at end of file + return outputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 7fd8ef94f..c4558cb3d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,10 +14,9 @@ from typing import List, Optional, Union import numpy as np -import requests import torch import torch.nn as nn -from PIL import Image +import transformers from transformers import ( AutoModel, AutoModelForCausalLM, @@ -33,7 +32,6 @@ from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import get_compilation_dims -from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers @@ -722,8 +720,10 @@ def from_pretrained( self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) + self.tokenizer = self.processor.tokenizer self.continuous_batching = continuous_batching self.kv_offload = kv_offload + # self.model_name=pretrained_model_name_or_path self.is_tlm = is_tlm return self @@ -739,202 +739,47 @@ def model_hash(self) -> str: mhash = mhash.hexdigest()[:16] return mhash - def _generate_inputs(self, **kwargs): - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - # seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - # fbs = constants.ONNX_EXPORT_EXAMPLE_FBS - - self.ctx_len = kwargs["ctx_len"] if "ctx_len" in kwargs else self.ctx_len - - ## PREPROCESSING THE MULTI-MODAL INPUTS for Phi-3.5 for now - # TODO: Create a map for the other models to have their own inputs accordingly - images = [] - placeholder = "" - - # Note: if OOM, you might consider reduce number of frames in this example. 
- for i in range(1, 2): - url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" - images.append(Image.open(requests.get(url, stream=True).raw)) - placeholder += f"<|image_{1}|>\n" - - messages = [ - {"role": "user", "content": placeholder + "Summarize the deck of slides."}, - ] - - prompt = self.processor.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ) - inputs = dict(self.processor(images=images, text=prompt, return_tensors="pt")) - inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - inputs["past_key_values"] = [] - for i in range(self.num_layers): - inputs["past_key_values"].append( - ( - torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), - torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), - ) - ) - output_names = [ - "logits", - "pixel_values_RetainedState", - "image_sizes_RetainedState", - *[f"past_{kv}.{i}_RetainedState" for i in range(self.num_layers) for kv in ["key", "value"]], - ] - dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "position_ids": {0: "batch_size", 1: "seq_len"}, - # "pixel_values": {0: "img_batch_size"}, - } - for i in range(self.num_layers): - dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - - # Avoid issues due to index out of range - inputs["position_ids"] = torch.full(inputs["position_ids"].shape, self.ctx_len - 1) - - return inputs, dynamic_axes, output_names - - def _generate_inputs_mllama( - self, - ): - url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" - image = Image.open(requests.get(url, stream=True).raw) - - messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}, - ], - } - ] - input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True) - - split_inputs = self.processor( - text=input_text, - images=image, - return_tensors="pt", - add_special_tokens=False, - padding="max_length", - max_length=32, - ) - - lang_inputs = {} - vision_input = {} - - for k, v in split_inputs.items(): - if k in ["input_ids", "attention_mask", "cross_attention_mask"]: - lang_inputs[k] = v - else: - vision_input[k] = v - - return lang_inputs, vision_input - def export( self, export_dir: Optional[str] = None, **kwargs, ) -> str: - self.kv_offload = True + self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.kv_offload) if self.kv_offload: - print("generating input") - lang_inputs, vision_input = self._generate_inputs_mllama() - print("generating vision model") - self.vision_export_path = self.export_vision(vision_input, export_dir) - print("generating lang model") - self.lang_export_path = self.export_lang(lang_inputs, export_dir) - - def export_vision(self, vision_input, export_dir): - model = self.model - self.vision_encoder = self.model = VisionEncoder(self.model) - - vision_output_names = [] - for i in self.model.cross_attention_layers: - vision_output_names.append(f"past_key.{i}") - vision_output_names.append(f"past_value.{i}") - vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, - "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, - "aspect_ratio_mask": { - 0: "batch_size", - 1: 
"max_num_images", - 2: "max_image_tiles", - }, - } + self.vision_export_path = self.export_vision(export_dir) + self.lang_export_path = self.export_lang(export_dir) + else: + self.model = ModelWrapper(self.model) + self._export(self.model, self.inputs[0], self.output_names[0], self.dynamic_axes[0], export_dir=export_dir) + + def export_vision(self, export_dir): + self.vision_encoder_model = VisionEncoder(self.model) + + vision_inputs = self.inputs[0] + vision_output_names = self.output_names[0] + vision_dynamic_axes = self.dynamic_axes[0] self.vision_onnx_path = self._export( - vision_input, + self.vision_encoder_model, + vision_inputs, vision_output_names, vision_dynamic_axes, export_dir=export_dir, ) - self.model = model - self.vision_output_names = vision_output_names return self.vision_onnx_path - def export_lang(self, lang_inputs, export_dir): - self.num_layers = num_hidden_layers = self.model.config.get_text_config().num_hidden_layers - - lang_inputs["position_ids"] = torch.where( - lang_inputs.pop("attention_mask") == 1, - torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), - -1, - ) - - lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) - lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers - lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + def export_lang(self, export_dir): + self.lang_model = ModelWrapper(self.model) - for i in range(num_hidden_layers): - if i in self.vision_encoder.cross_attention_layers: - idx = self.vision_encoder.cross_attention_layers.index(i) - assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) - else: - lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128)) - - lang_inputs["position_ids"] = torch.full((1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1) - lang_output_names = ["logits", "past_key_values"] - pkv_idx = lang_output_names.index("past_key_values") + lang_inputs = self.inputs[1] + lang_output_names = self.output_names[1] + lang_dynamic_axes = self.dynamic_axes[1] - lang_output_names[pkv_idx : pkv_idx + 1] = [ - f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"] - ] - - lang_dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "position_ids": {0: "batch_size", 1: "seq_len"}, - "cross_attention_mask": { - 0: "batch_size", - 1: "seq_len", - 2: "max_num_images", - 3: "max_image_tiles", - }, - } + self.lang_onnx_path = self._export( + self.lang_model, lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir + ) - for i in range(num_hidden_layers): - if i in self.vision_encoder.cross_attention_layers: - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} - continue - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - - lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() - lang_inputs["input_ids"] = torch.tensor([[374]]) - lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:] - self.lang_output_names = lang_output_names - model = self.model - self.model = ModelWrapper(model) - - self.lang_onnx_path = self._export(lang_inputs, 
lang_output_names, lang_dynamic_axes, export_dir=export_dir) - self.model = model return self.lang_onnx_path def compile( @@ -950,11 +795,10 @@ def compile( mxfp6_matmul: bool = False, **compiler_options, ) -> str: - self.kv_offload = True if self.kv_offload: model = self.model self.model = VisionEncoder(model) - vision_specializations = [{"batch_size": "1", "max_num_images": "1", "max_image_tiles": "4"}] + vision_specializations = [{"batch_size": batch_size, "max_num_images": "1", "max_image_tiles": "4"}] custom_io = {} kv_cache_dtype = "float16" @@ -995,12 +839,13 @@ def compile( "max_image_tiles": "4", }, ] - + # num_devices=4 custom_io_lang = {} # Inputs for output_name in self.lang_output_names: if output_name.startswith("past_"): custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype + # outputs for output_name in self.lang_output_names: if output_name.startswith("past_"): @@ -1022,6 +867,49 @@ def compile( ) self.model = model return self.vision_qpc_path, self.lang_qpc_path + else: + specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + ] + custom_io = {} + kv_cache_dtype = "float16" + + # inputs + for input_name in self.output_names: + if input_name.endswith("_RetainedState"): + custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype + + # outputs + for output_name in self.output_names: + if output_name.endswith("_RetainedState"): + custom_io[output_name] = kv_cache_dtype + + compiler_options.update({"retained-state": True}) + self.lang_qpc_path = self._compile( + self.onnx_path, + compile_dir, + compile_only=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io, + **compiler_options, + ) def generate( self, @@ -1047,71 +935,121 @@ def generate( if self.kv_offload: self.kv_offload_generate(inputs, streamer, device_ids) else: - return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) + return self.cloud_ai_100_generate(inputs=inputs, device_ids=device_ids) # PyTorch runtime else: return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) - # TODO: Add the code based on how we did in single inference script - def cloud_ai_100_vlm_generate( + def cloud_ai_100_generate( self, inputs: torch.Tensor, device_ids: List[int] = [0], + enable_debug_logs: bool = False, ) -> np.ndarray: - """ - Generates features with list of prompts using AI 100 runtime. - - ``Mandatory`` Args: - :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. - ``Optional`` Args: - device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0]. + qpc_session = QAICInferenceSession( + self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False + ) - Returns: - np.ndarray: A list of dictionaries containing the generated output features. 
- """ + batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path) - if self.qpc_session is None: - self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) - self.batch_size = self.qpc_session.bindings[0].dims[0] - self.seq_len = self.qpc_session.bindings[0].dims[1] # Skip inputs/outputs - self.qpc_session.skip_buffers( - [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] - + ["pixel_values_RetainedState", "image_sizes_RetainedState"] + qpc_session.skip_buffers( + [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] ) # Read prompt and ctx len from session - # batch_size = max( - # [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] - # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] - # ) - - # prefill_seq_len = max( - # [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] - # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] - # ) - # Prepare input - input_ids_len = inputs["input_ids"].shape[1] - input_ids = np.array( - torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + batch_size = max( + [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]] ) - attention_mask = np.array( - torch.nn.functional.pad( - inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0 - ) + + prefill_seq_len = max( + [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]] ) - inputs = dict(input_ids=input_ids, attention_mask=attention_mask) + # lang_inputs = tokenizer(prompt, return_tensors="np", padding=True) + input_len = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + generation_len = None + if generation_len is None: + generation_len = ctx_len - input_len.max() - outputs = { - "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( - np.float32 - ), - } - self.qpc_session.set_buffers(outputs) - outputs = self.qpc_session.run(inputs) - outputs = outputs["output"][:, :input_ids_len, :] - return outputs + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), self.tokenizer.pad_token_id) + stream = None + if stream: + streamer = transformers.TextStreamer(self.tokenizer) + + # Prepare inputs for prefill + start = perf_counter() + + inputs["position_ids"] = np.where( + inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + inputs = dict(inputs) + + # vision_session.deactivate() + qpc_session.activate() + + # Run prefill + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + outputs = qpc_session.run(chunk_inputs) + + # Skip 
inputs/outputs again + qpc_session.skip_buffers( + [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] + ) + + # Get first token + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] = input_len + inputs["cross_attention_mask"] = inputs["cross_attention_mask"][:, -1:, :, :] + generated_ids[:, 0] = inputs["input_ids"].squeeze(1) + finished_sequences = inputs["input_ids"] == self.tokenizer.eos_token_id + if stream: + streamer.put(inputs["input_ids"][0]) + + # Decode loop + loop_start = perf_counter() + for num_token in range(1, generation_len): + outputs = qpc_session.run(inputs) + + # Prepare inputs for next iteration + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] += 1 + generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) + finished_sequences |= inputs["input_ids"] == self.tokenizer.eos_token_id + if stream: + streamer.put(inputs["input_ids"][0]) + if finished_sequences.all(): + break + + end = perf_counter() + if stream: + streamer.end() + generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(1 if stream else 0, batch_size): + print(i, generated_texts[i]) + + prefill_perf = 1 / (loop_start - start) + decode_perf = (num_token - 1) / (end - loop_start) + total_perf = num_token / (end - start) + + print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) + print("E2ET:", round(end - start, 2), "s", file=sys.stderr) + print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) + print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) + print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) + if batch_size > 1: + print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) def pytorch_vlm_generate( self, @@ -1169,6 +1107,7 @@ def kv_offload_generate( **kwargs, ): lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) + vision_session = QAICInferenceSession(self.vision_qpc_path, device_id) batch_size, ctx_len, fbs = get_compilation_dims(self.lang_qpc_path) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index c3ad99f85..3580d4fda 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -69,6 +69,7 @@ from transformers.models.mllama.modeling_mllama import ( MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaForConditionalGeneration, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, @@ -167,6 +168,7 @@ from QEfficient.transformers.models.mllama.modeling_mllama import ( QEffMllamaCrossAttentionDecoderLayer, QEffMllamaForCausalLM, + QEffMllamaForConditionalGeneration, QEffMllamaRotaryEmbedding, QEffMllamaSelfAttentionDecoderLayer, QEffMllamaTextCrossAttention, @@ -258,14 +260,16 @@ class KVCacheTransform(ModuleMappingTransform): Gemma2Model: QEffGemma2Model, Gemma2ForCausalLM: QEffGemma2ForCausalLM, # mllama - MllamaForCausalLM: QEffMllamaForCausalLM, - MllamaTextModel: QEffMllamaTextModel, - MllamaVisionModel: QEffMllamaVisionModel, - MllamaTextSelfAttention: QEffMllamaTextSelfAttention, + MllamaTextRMSNorm: CustomRMSNormAIC, MllamaTextCrossAttention: QEffMllamaTextCrossAttention, - 
MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, + MllamaTextSelfAttention: QEffMllamaTextSelfAttention, MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer, + MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding, + MllamaVisionModel: QEffMllamaVisionModel, + MllamaTextModel: QEffMllamaTextModel, + MllamaForCausalLM: QEffMllamaForCausalLM, + MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration, # Mistral MistralAttention: QEffMistralAttention, MistralDecoderLayer: QEffMistralDecoderLayer, @@ -349,4 +353,4 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: f"model class {model_class} does not yet support returning multiple logits to keep." ) - return model, transformed \ No newline at end of file + return model, transformed diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 462acf169..028dd13b7 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -49,11 +49,12 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 -ONNX_EXPORT_MAX_NUM_IMAGES =1 +ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 ONNX_EXPORT_IMAGE_LENGHT = 560 -ONNX_EXPORT_IMAGE_DEPTH =3 +ONNX_EXPORT_IMAGE_DEPTH = 3 +ONNX_EXPORT_CTX_LEN = 1024 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
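
A minimal usage sketch of the `_create_causal_mask` helper added in QEfficient/transformers/modeling_utils.py. It assumes only that torch and this QEfficient branch are importable; the concrete sizes (8 query tokens, a 16-entry KV cache, a window of 4) are illustrative and not values used anywhere in the patch. Both the plain causal path and the sliding-window (rolling-buffer) path are exercised to show the returned mask shape.

import torch

from QEfficient.transformers.modeling_utils import _create_causal_mask

# 8 query tokens in a batch of 1, attending into a 16-entry KV cache
position_ids = torch.arange(8).view(1, 8)

mask = _create_causal_mask(position_ids=position_ids, target_length=16)
print(mask.shape)     # torch.Size([1, 1, 8, 16]), boolean; True marks masked-out positions
print(mask[0, 0, 0])  # the first query token can only see KV index 0

# Same helper via the rolling-buffer / sliding-window branch (window of 4)
swa_mask = _create_causal_mask(position_ids=position_ids, target_length=16, sliding_window=4)
print(swa_mask.shape)  # torch.Size([1, 1, 8, 16]); each query now sees at most 4 KV positions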