7 changes: 6 additions & 1 deletion QEfficient/__init__.py
@@ -25,7 +25,12 @@ def check_qaic_sdk():
# Conditionally import QAIC-related modules if the SDK is installed
__version__ = "0.0.1.dev0"
if QAIC_INSTALLED:
from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader,QEFFAutoModelForImageTextToText
from QEfficient.base import (
QEFFAutoModel,
QEFFAutoModelForCausalLM,
QEFFAutoModelForImageTextToText,
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
4 changes: 3 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -114,6 +114,7 @@ def compile(self, *args, **kwargs) -> Path:

def _export(
self,
model,
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
@@ -157,7 +158,7 @@ def _export(
try:
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
self.model,
model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
@@ -175,6 +176,7 @@ def _export(
}
if onnx_transform_kwargs is not None:
transform_kwargs.update(onnx_transform_kwargs)

for transform in self._onnx_transforms:
model, transformed = transform.apply(model, **transform_kwargs)
model.metadata_props.append(
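The net effect of this hunk is that _export now exports whichever torch module the caller passes in, instead of always reaching for self.model. A minimal, hypothetical call site inside the defining class (the example inputs and names are illustrative, not taken from this PR):

# Hypothetical usage: a subclass exporting an explicitly chosen nn.Module.
onnx_path = self._export(
    self.model,  # or any other torch.nn.Module the subclass wants exported
    example_inputs={"input_ids": torch.zeros(1, 8, dtype=torch.int64)},
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch_size", 1: "seq_len"}},
)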
95 changes: 94 additions & 1 deletion QEfficient/transformers/modeling_utils.py
@@ -6,8 +6,9 @@
# -----------------------------------------------------------------------------

from collections import namedtuple
from typing import Dict, Type
from typing import Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
@@ -242,3 +243,95 @@
GPTBigCodeBlock: QEffGPTBigCodeBlock,
GPTBigCodeModel: QEffGPTBigCodeModel,
}


def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)

# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32)
)

# apply full-row bias: returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row of the cross attn
# mask's last dimension holds the -10000.0 fill value (used here in place of negative infinity), and 1 otherwise
negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32)
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask

return cross_attention_mask, full_text_row_masked_out_mask
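A quick shape sketch of the helper above (the tensor sizes and token counts are illustrative assumptions, not values from this PR):

import torch

from QEfficient.transformers.modeling_utils import _prepare_cross_attention_mask

# Hypothetical sizes: 1 sample, 8 text tokens, 1 image split into 4 tiles, 1601 vision tokens per tile.
mask = torch.ones(1, 8, 1, 4)  # 1 = this text token may attend to this tile
cross_mask, full_row_mask = _prepare_cross_attention_mask(mask, num_vision_tokens=1601, dtype=torch.float32)
print(cross_mask.shape)     # torch.Size([1, 1, 8, 6404]); blocked slots hold -10000.0 rather than -inf
print(full_row_mask.shape)  # torch.Size([1, 1, 8, 1]); 0 marks text rows that attend to no vision token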


def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)

# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0

# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask

# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32)
attention_mask = attention_mask.unsqueeze(1)

return attention_mask
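Similarly, a shape sketch for the aspect-ratio mask helper (tile and patch counts are illustrative assumptions):

import torch

from QEfficient.transformers.modeling_utils import _prepare_aspect_ratio_attention_mask

aspect_ratio_mask = torch.tensor([[1, 1, 0, 0]])  # hypothetical: 2 of max_num_tiles=4 tiles are real
attn = _prepare_aspect_ratio_attention_mask(aspect_ratio_mask, num_patches=1025, target_length=1032, dtype=torch.float32)
print(attn.shape)  # torch.Size([1, 1, 4128, 4128]) == (batch, 1, 4 * 1032, 4 * 1032)
# Entries are additive biases (0.0 or -10000.0), with -10000.0 used in place of -inf.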


def _create_causal_mask(
position_ids,
target_length,
sliding_window: Optional[int] = None,
):
"""
A utility function that allows one to:
- Create a causal 4D mask
- Create a causal 4D mask with a sliding window
"""
if sliding_window is not None:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, -1)
# --- Rolling buffer ---
pos_max = position_ids.max(1, keepdim=True).values
kv_start = (pos_max // target_length) * target_length
kv_indices_high = kv_indices + kv_start
kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length)
kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high)
kv_indices = kv_indices.unsqueeze(1)
# ------
causal_mask = kv_indices > query_indices
attention_mask = causal_mask

window_indices = query_indices - sliding_window + 1
window_mask = kv_indices < window_indices
attention_mask = attention_mask | window_mask
attention_mask = attention_mask.unsqueeze(1)
else:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, 1, -1)
attention_mask = kv_indices > query_indices
attention_mask = attention_mask.unsqueeze(1)

return attention_mask
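Finally, a sketch of the causal-mask helper on the non-sliding-window path; the returned boolean mask is True at key positions a query must not attend to (sizes are illustrative):

import torch

from QEfficient.transformers.modeling_utils import _create_causal_mask

position_ids = torch.arange(4).unsqueeze(0)  # (batch=1, seq_len=4): a prefill starting at position 0
mask = _create_causal_mask(position_ids, target_length=8)
print(mask.shape)     # torch.Size([1, 1, 4, 8])
print(mask[0, 0, 0])  # tensor([False, True, True, True, True, True, True, True]) -> token 0 sees only slot 0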