diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 47c462979..a0120b3ff 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -1,6 +1,6 @@
 # -----------------------------------------------------------------------------
 #
-# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
@@ -12,8 +12,19 @@
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+from transformers import AutoConfig
+
+from QEfficient.transformers.modeling_utils import MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS
 from QEfficient.utils.logging_utils import logger
+# Loop over all model types that are not present in transformers and register them
+for model_type, model_cls in MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS.items():
+    # Register the config class for this model type; it is the first element of the tuple
+    AutoConfig.register(model_type, model_cls[0])
+
+    # Register the non-transformers model class against its config class on the corresponding Auto class
+    model_cls[2].register(model_cls[0], model_cls[1])
+
 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 895388046..f9d529038 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -36,6 +36,83 @@ class QEffDynamicCache(DynamicCache):
     """
+    def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
+        """
+        Write the new `key_states` and `value_states` into the cache for the layer `layer_idx`, without reading back the cached values.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+        """
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            batch_index = cache_kwargs.get("batch_index", None)
+
+            # Scatter
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
+    def read_only(self, layer_idx, cache_kwargs):
+        """
+        Read the cached `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            layer_idx (`int`):
+                The index of the layer to read the cached states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the key and value states read from the cache.
+ """ + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + def update( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index ccad5e020..548d8ef80 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +from transformers import AutoModelForCausalLM from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, CodeGenBlock, @@ -88,6 +89,12 @@ from QEfficient.customop import CustomRMSNormAIC +# Placeholder for all non-transformer models +from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import ( + QEffLlamaSwiftKVConfig, + QEffLlamaSwiftKVForCausalLM, +) + from .models.codegen.modeling_codegen import ( QEffCodeGenAttention, QeffCodeGenBlock, @@ -271,6 +278,11 @@ WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration, } +# Map of model type to config class, Modelling class and transformer model architecture class +MODEL_TYPE_TO_CONFIG_CLS_AND_ARCH_CLS = { + "llama_swiftkv": [QEffLlamaSwiftKVConfig, QEffLlamaSwiftKVForCausalLM, AutoModelForCausalLM], +} + def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, diff --git a/QEfficient/transformers/models/llama_swiftkv/__init__.py b/QEfficient/transformers/models/llama_swiftkv/__init__.py new file mode 100644 index 000000000..72ba36c8a --- /dev/null +++ b/QEfficient/transformers/models/llama_swiftkv/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py new file mode 100644 index 000000000..2ae0b3f38 --- /dev/null +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -0,0 +1,420 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+# This file is adapted from the vLLM implementation by Snowflake: https://github.com/Snowflake-Labs/vllm/blob/swiftkv/vllm/model_executor/models/llama_swiftkv.py
+# The modules have been updated to meet Cloud AI 100 hardware requirements.
+
+
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from transformers.cache_utils import Cache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv
+
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.transformers.models.llama.modeling_llama import (
+    QEffLlamaDecoderLayer,
+    QEffLlamaRotaryEmbedding,
+    qeff_apply_rotary_pos_emb,
+)
+
+
+class QEffLlamaSwiftKVConfig(LlamaConfig):
+    """
+    Args:
+        swiftkv (bool, optional):
+            Whether the SwiftKV key/value projections are enabled for this
+            checkpoint. Defaults to False.
+        num_key_value_layers (int, optional):
+            The number of layers, from the first layer, that have keys and
+            values. If None, all layers have keys and values.
+        key_value_group_size (int, optional):
+            The number of consecutive layers, after the first
+            `num_key_value_layers` layers, that share the same keys and
+            values. Defaults to 1.
+    """
+
+    model_type = "llama_swiftkv"
+
+    def __init__(
+        self,
+        swiftkv: bool = False,
+        num_key_value_layers: Optional[int] = None,
+        key_value_group_size: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.swiftkv = swiftkv
+        self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers
+        self.key_value_group_size = key_value_group_size or 1
+        assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0
+
+
+class QEffLlamaSwiftKVAttention(nn.Module):
+    def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.attention_dropout = config.attention_dropout
+        self.num_heads = config.num_attention_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+        self.layer_idx = layer_idx
+        self.q_proj_swiftkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj_swiftkv = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj_swiftkv = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+        self.rotary_emb = QEffLlamaRotaryEmbedding(config=config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: torch.Tensor = None,
+
batch_index: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + q_len = 1 # as we always run this for single token + query = self.q_proj_swiftkv(hidden_states) + # Reshape the query, key, and value tensors. + query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = position_ids.shape[-1] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) + query_states, _ = qeff_apply_rotary_pos_emb( + query_states, torch.empty_like(query_states), cos, sin, position_ids + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, past_key_value + + +class QEffLlamaSwiftKVDecoderLayer(nn.Module): + def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.self_attn = QEffLlamaSwiftKVAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + causal_mask, + batch_index: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, past_key_values = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + past_key_value=past_key_values, + attention_mask=causal_mask, + batch_index=batch_index, + ) + + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = 
self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, past_key_values + + +class QEffLlamaSwiftKVModel(nn.Module): + config_class = QEffLlamaSwiftKVConfig + + def __init__(self, config: QEffLlamaSwiftKVConfig): + super().__init__() + self.vocab_size = config.vocab_size + self.config = config + + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, None) + self.layers = torch.nn.ModuleList( + [ + QEffLlamaDecoderLayer(config=config, layer_idx=idx) + if idx < config.num_key_value_layers + else QEffLlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def _run_swiftkv_layers( + self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index + ) -> torch.Tensor: + for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): + layer = self.layers[layer_idx] + hidden_states, past_key_values = layer( + hidden_states, position_ids, past_key_values, causal_mask, batch_index + ) + + hidden_states = self.norm(hidden_states) + return hidden_states, past_key_values + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + self.config._attn_implementation = "eager" + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + else: + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + past_key_values: List[torch.Tensor], + batch_index: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + use_cache = True + + if use_cache and not isinstance(past_key_values, Cache): + if past_key_values is None: + past_key_values = QEffDynamicCache() + else: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + None, inputs_embeds, cache_position, position_ids, past_key_values, False + ) + hidden_states = inputs_embeds + + next_decoder_cache = None + + for layer_idx in range(self.config.num_key_value_layers): + layer = self.layers[layer_idx] + hidden_states, next_decoder_cache = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=False, + use_cache=True, + ) + + bsz, q_len, _ = hidden_states.size() + swiftkv_hidden_states = self.norm_swiftkv(hidden_states) + #################################### + ## THE MAGIC OF SWIFT KV BEGINS HERE + #################################### + for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): + self_attn = self.layers[layer_idx].self_attn + key_states = self_attn.k_proj_swiftkv(swiftkv_hidden_states) + value_states = self_attn.v_proj_swiftkv(swiftkv_hidden_states) + key_states = key_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose( + 1, 2 + ) + + kv_seq_len = key_states.shape[-2] + if past_key_values is not None: + if self_attn.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self_attn.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx) + + cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) + _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids, "batch_index": batch_index} + past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) + + last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True) + orig_hidden_states = hidden_states + + # Extracting only the last valid position id to be processed by self-attn of half of the layers, as KV cache is already filled. + if batch_index is not None: + hidden_states = orig_hidden_states[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), last_pos_id, :] + causal_mask = causal_mask[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), :, last_pos_id, :] + else: + hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :] + causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :] + + hidden_states, next_decoder_cache = self._run_swiftkv_layers( + hidden_states, position_ids, past_key_values, causal_mask, batch_index + ) + # We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction + # we only need the last valid pos_indices hidden_states. 
+        # Here the shape of hidden_states is [batch_size, 1, hidden_dim] instead of [batch_size, seq_len, hidden_dim]
+        # This saves unnecessary data movement on the device.
+        ####################################
+        ## THE MAGIC OF SWIFT KV ENDS HERE
+        ####################################
+
+        next_cache = next_decoder_cache.to_legacy_cache()
+        return hidden_states, next_cache
+
+
+class QEffLlamaSwiftKVForCausalLM(PreTrainedModel):
+    config_class = QEffLlamaSwiftKVConfig
+
+    def __init__(self, config: QEffLlamaSwiftKVConfig):
+        super().__init__(config=config)
+
+        self.model = QEffLlamaSwiftKVModel(
+            config=config,
+        )
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.config = config
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+    ):
+        hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index)
+        logits = self.lm_head(hidden_states)
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=output_past_key_values,
+            hidden_states=None,
+            attentions=None,
+        )
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index ea9044e2c..8ba5e2c18 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -17,7 +17,12 @@
 import yaml
 from huggingface_hub import login, snapshot_download
 from requests.exceptions import HTTPError
-from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
 from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants, QnnConstants
 from QEfficient.utils.logging_utils import logger
diff --git a/README.md b/README.md
index 1096785bf..3a5a783f6 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@
 ---
 *Latest news* :fire:
+- [03/2025] Added support for the SwiftKV model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct)
 - [02/2025] [VLMs support](https://github.com/quic/efficient-transformers/pull/267) added for the models [InternVL-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B), [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) and [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
 - [01/2025] [FP8 models support](https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127) Added support for inference of FP8 models.
diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md
index 2ccb013e9..33b9a03d7 100644
--- a/docs/source/quick_start.md
+++ b/docs/source/quick_start.md
@@ -14,7 +14,7 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction.
 | Feature | Impact |
 | --- | --- |
 | Context Length Specializations (upcoming) | Increases the maximum context length that models can handle, allowing for better performance on tasks requiring long sequences of text. |
-| Swift KV (upcoming) | Reduces computational overhead during inference by optimizing key-value pair processing, leading to improved throughput. |
+| Swift KV [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | Reduces computational overhead during inference by optimizing key-value pair processing, leading to improved throughput. |
 | Block Attention (in progress) | Reduces inference latency and computational cost by dividing context into blocks and reusing key-value states, particularly useful in RAG. |
 | [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. |
 | [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**.
| diff --git a/docs/source/validate.md b/docs/source/validate.md index acd4c11da..7f1690d2d 100644 --- a/docs/source/validate.md +++ b/docs/source/validate.md @@ -33,6 +33,7 @@ | **Phi3ForCausalLM** | Phi-3, Phi-3.5 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | ✔️ | | **QwenForCausalLM** | DeepSeek-R1-Distill-Qwen | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | ✔️ | | | Qwen2, Qwen2.5 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✔️ | +| **LlamaSwiftKVForCausalLM** | swiftkv | [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | ✔️ | ## Embedding Models diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 418386780..88c26bc23 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -46,6 +46,10 @@ "ibm-granite/granite-guardian-3.1-2b", ] +swiftkv_test_models = [ + "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", # SwiftKV model +] + spd_test_models = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ] @@ -110,7 +114,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( ) pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - is_tlm = False if num_speculative_tokens is None else True qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) @@ -138,6 +141,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( ) exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size + gen_len = ort_tokens.shape[-1] assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), ( "Tokens don't match for ONNXRT output and Cloud AI 100 output." @@ -146,6 +150,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( # testing for CB models model_hf, _ = load_causal_lm_model(model_config) + config = model_hf.config full_batch_size = 4 fbs_prompts = Constants.INPUT_STR * 4 api_runner = ApiRunner( @@ -161,6 +166,102 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=False) + onnx_model_path = qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6_matmul=False, + aic_enable_depth_first=False, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) + + assert all( + [ + all(pt_token[:24] == cloud_token[:24]) + for pt_token, cloud_token in zip(pytorch_hf_tokens, exec_info_fbs.generated_ids) + ] + ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." 
+ assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + +def check_non_hf_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, + num_speculative_tokens: Optional[int] = None, +): + """ + Validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. + """ + replace_transformers_quantizers() + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_causal_lm_model(model_config) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + + is_tlm = False if num_speculative_tokens is None else True + + qeff_model = QEFFAutoModelForCausalLM(model_hf, is_tlm=is_tlm) + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + onnx_model_path = qeff_model.export() + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) + + assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + num_speculative_tokens=num_speculative_tokens, + ) + + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size + gen_len = ort_tokens.shape[-1] + + assert (ort_tokens == cloud_ai_100_tokens[:, :gen_len]).all(), ( + "Tokens don't match for ONNXRT output and Cloud AI 100 output." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + # testing for CB models + model_hf, _ = load_causal_lm_model(model_config) + config = model_hf.config + full_batch_size = 4 + fbs_prompts = Constants.INPUT_STR * 4 + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=True, is_tlm=is_tlm) onnx_model_path = qeff_model.export() @@ -176,12 +277,13 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, ) + exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) assert all( [ all(pt_token[:24] == cloud_token[:24]) - for pt_token, cloud_token in zip(pytorch_hf_tokens, exec_info_fbs.generated_ids) + for pt_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids) ] ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." 
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) @@ -233,6 +335,22 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", swiftkv_test_models) +def test_non_hf_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": + n_layer = 32 + else: + n_layer = 2 + + check_non_hf_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) + + @pytest.mark.skip() # remove when the SDK 1.20.0 issue solved for compiling this model @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", spd_test_models)
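---

Usage note for reviewers: because importing `QEfficient` now registers the `llama_swiftkv` config and model classes with `AutoConfig`/`AutoModelForCausalLM` (see the `QEfficient/__init__.py` hunk), the Snowflake checkpoint can be driven through the same high-level flow the new test exercises. The sketch below is illustrative only; it assumes `QEFFAutoModelForCausalLM.from_pretrained` and `QEfficient.utils.load_hf_tokenizer` behave as elsewhere in the library, and the compile arguments are example values, not recommended settings.

```python
# Illustrative sketch, not part of the patch. Assumes a Cloud AI 100 device is available.
from QEfficient import QEFFAutoModelForCausalLM  # importing QEfficient registers the llama_swiftkv model type
from QEfficient.utils import load_hf_tokenizer

model_name = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct"

# AutoConfig resolves model_type "llama_swiftkv" to QEffLlamaSwiftKVConfig, and
# AutoModelForCausalLM maps that config to QEffLlamaSwiftKVForCausalLM.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name)
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)

# Export to ONNX and compile for Cloud AI 100; the arguments below are illustrative.
qeff_model.compile(
    prefill_seq_len=32,
    ctx_len=128,
    num_cores=14,
    mxfp6_matmul=False,
    aic_enable_depth_first=False,
)

# Run generation on device; the SwiftKV model returns only the hidden state at the
# last valid position, which is all that next-token prediction needs.
qeff_model.generate(tokenizer, prompts=["My name is"])
```

The saving comes from the structure visible in `QEffLlamaSwiftKVModel.forward`: the last `num_hidden_layers - num_key_value_layers` layers have their KV cache filled from the `*_proj_swiftkv` projections and then attend only from the last valid position, so those layers process a single query token instead of the full prefill sequence.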