Extend on-device sampling support for dual QPC VLMs #597
base: main
First changed file:

@@ -289,6 +289,7 @@
     QEffGrok1MultiHeadAttention,
 )
 from QEfficient.transformers.models.internvl.modeling_internvl import (
+    QEffInternDecoderWrapper,
     QEffInternVisionEmbeddings,
     QEffInternVLModel,
 )

@@ -392,6 +393,7 @@
     QEffQwen2_5_VLModel,
     QEffQwen2_5_VLTextModel,
     QEffQwen2_5_VLVisionAttention,
+    QEffQwen_2_5_vl_DecoderWrapper,
     QEffQwen_2_5_vl_ForConditionalGeneration,
 )
 from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -707,10 +709,12 @@ class SamplerTransform:
         QEffGPTJForCausalLM,
         QEffGraniteForCausalLM,
         QEffGraniteMoeForCausalLM,
+        QEffInternDecoderWrapper,
         QEffLlamaForCausalLM,
         QEffMptForCausalLM,
         QEffPhi3ForCausalLM,
         QEffQwen2ForCausalLM,
+        QEffQwen_2_5_vl_DecoderWrapper,
     }

     @classmethod

Review comments (on `QEffInternDecoderWrapper`):

Contributor: Does this mean we are enabling sampling only for the InternVL model?

Author: Other VLMs are also supposed to be supported, but currently only InternVL and Qwen 2.5 VL have been tested.
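For context on what this mapping controls: SamplerTransform is a class-gated transform, so only models (or decoder wrappers) whose class appears in `_module_mapping` have their forward routed through `sampler_forward`. The sketch below shows that pattern in a self-contained form; the stand-in class and the `apply` body are illustrative assumptions, not the repository implementation.

```python
from types import MethodType


def sampler_forward(self, **kwargs):
    # Placeholder for the real on-device sampler: it calls the saved forward,
    # then samples next tokens from the logits on the QAIC device.
    return self.old_forward(**kwargs)


class QEffInternDecoderWrapper:  # stand-in for the real dual-QPC language decoder
    def forward(self, **kwargs):
        return {"logits": None}


class SamplerTransform:
    # Only classes listed here are eligible for on-device sampling.
    _module_mapping = {QEffInternDecoderWrapper}

    @classmethod
    def apply(cls, model):
        transformed = False
        if model.__class__ in cls._module_mapping:
            model.old_forward = model.forward                   # keep the original forward
            model.forward = MethodType(sampler_forward, model)  # reroute decode through the sampler
            transformed = True
        return model, transformed


model, changed = SamplerTransform.apply(QEffInternDecoderWrapper())
print(changed)  # True -- this class is in the mapping, so it was transformed
```

The review thread above is about exactly this gate: adding `QEffInternDecoderWrapper` and `QEffQwen_2_5_vl_DecoderWrapper` makes the InternVL and Qwen 2.5 VL dual-QPC language decoders eligible, while other VLM wrappers remain untested for now.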
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): | |
|
|
||
| probs: torch.FloatTensor = None | ||
| next_tokens: torch.IntTensor = None | ||
| vision_embeds: Optional[torch.FloatTensor] = None # For VLMs | ||
| image_idx: Optional[torch.IntTensor] = None # for VLMs | ||
| past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | ||
| past_repetition_penalty_buffer: Optional[torch.Tensor] = None | ||
| past_presence_penalty_buffer: Optional[torch.Tensor] = None | ||
|
|
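The two new fields let the sampler echo VLM-specific state back to the caller: in the dual-QPC setup the vision encoder runs once, and the language QPC then consumes and re-emits `vision_embeds` and `image_idx` on every decode step, much like the KV cache. A rough, self-contained sketch of that round trip follows; the runner function and all shapes are invented for illustration.

```python
import numpy as np


def run_language_qpc(inputs):
    # Stand-in for the compiled language QPC with the on-device sampler fused in.
    return {
        "next_tokens": np.array([[42]]),           # sampled on-device; no logits copied to host
        "vision_embeds": inputs["vision_embeds"],  # retained state, echoed back unchanged
        "image_idx": inputs["image_idx"],
    }


vision_embeds = np.zeros((1, 256, 1024), dtype=np.float16)  # produced once by the vision QPC
image_idx = np.zeros((1, 1), dtype=np.int64)
next_tokens = np.array([[0]])
position_ids = np.array([[16]])

for _ in range(4):  # decode a few tokens
    out = run_language_qpc({
        "input_ids": next_tokens,
        "position_ids": position_ids,
        "vision_embeds": vision_embeds,
        "image_idx": image_idx,
    })
    next_tokens = out["next_tokens"]
    vision_embeds = out["vision_embeds"]  # fed back in on the next step
    image_idx = out["image_idx"]
    position_ids = position_ids + 1
```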
@@ -122,6 +124,8 @@ def sampler_forward(
     top_ps: Optional[torch.Tensor] = None,
     min_ps: Optional[torch.Tensor] = None,
     random_numbers: Optional[torch.Tensor] = None,
+    vision_embeds: Optional[torch.Tensor] = None,
+    image_idx: Optional[torch.Tensor] = None,
 ) -> Union[Tuple, SamplerOutput]:
     r"""
     Perform the sampling of next tokens on the QAIC device (instead of the host)

Review comment (on `vision_embeds`):

Contributor: Please add both vision_embeds and image_idx to the docstring's Args list.
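Following up on that review comment, possible Args entries for the two new parameters are sketched below, assuming the docstring follows the same Hugging Face style as the existing entries; the wording and descriptions are my reading of how these tensors are used, not repository text.

```
    vision_embeds (`torch.Tensor`, *optional*):
        Image embeddings produced by the vision-encoder QPC of a dual QPC VLM; forwarded to
        the language model and echoed back in `SamplerOutput` for the next decode step.
    image_idx (`torch.Tensor`, *optional*):
        Index tracking where image features are scattered into the input embeddings; echoed
        back in `SamplerOutput` so it can be reused across decode iterations.
```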
@@ -170,20 +174,35 @@ def sampler_forward(
         Sampling parameter that represents the random seeds to use for random sampling.
         Must be in [-1, 1].
     """

-    outputs = self.old_forward(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        batch_index=batch_index,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
+    if vision_embeds is not None:
+        forward_kwargs = dict(
+            input_ids=input_ids,
+            vision_embeds=vision_embeds,
+            position_ids=position_ids,
+            image_idx=image_idx,
+            past_key_values=past_key_values,
+        )
+        if batch_index is not None:
+            forward_kwargs["batch_index"] = batch_index
+
+        logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs)
+        outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
+        if position_ids.dim() == 3:  # For models using m-rope
+            position_ids = position_ids[0]
+    else:
+        outputs = self.old_forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )

     logits = outputs.get("logits", None)
     assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
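One detail worth calling out in the VLM branch: models that use m-rope (per the diff's own comment, e.g. Qwen 2.5 VL) pass position_ids with an extra leading dimension, one plane per rotary section, while the sampler's bookkeeping (last-token gathering, penalty buffers) only needs a plain (batch, seq) index, hence the `position_ids.dim() == 3` squeeze. A small self-contained illustration:

```python
import torch

batch_size, seq_len = 2, 5

# m-rope position ids as used by Qwen2.5-VL-style models: shape (3, batch, seq),
# one plane per rotary section (temporal, height, width).
mrope_position_ids = torch.arange(seq_len).expand(3, batch_size, seq_len)

# The sampler only needs a single (batch, seq) index for its buffers, so it keeps plane 0.
position_ids = mrope_position_ids[0] if mrope_position_ids.dim() == 3 else mrope_position_ids
print(position_ids.shape)  # torch.Size([2, 5])
```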
@@ -230,7 +249,9 @@ def sampler_forward(
         return SamplerOutput(
             probs=None,
             next_tokens=greedy_samples.reshape(-1, spec_length, 1),  # Return sampled next tokens instead of logits
-            past_key_values=outputs.past_key_values,
+            vision_embeds=outputs.get("vision_embeds", None),
+            image_idx=outputs.get("image_idx", None),
+            past_key_values=outputs.get("past_key_values", None),
             past_repetition_penalty_buffer=past_repetition_penalty_buffer,
             past_presence_penalty_buffer=past_presence_penalty_buffer,
         )
@@ -300,9 +321,8 @@ def sampler_forward(
     )  # (batch_size, spec_length, vocab_size)

     # Random Sampling
-    topk_probs_asc = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, max_top_k_ids)
     gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))  # Gumbel-Max Trick
-    y = topk_probs_asc + gumbel_noise
+    y = topk_values_asc + gumbel_noise  # (batch_size * spec_length, max_top_k_ids)
     random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
     random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices)  # (batch_size * spec_length, 1)
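The sampling change above switches the Gumbel-Max trick to operate on the top-k logits directly instead of their softmax: argmax(logits + Gumbel(0, 1)) draws from Categorical(softmax(logits)), whereas adding the noise to already-softmaxed probabilities does not reproduce that distribution. A quick self-contained check (plain PyTorch, independent of the repository code):

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.0])
target = torch.softmax(logits, dim=0)  # tensor([0.6652, 0.2447, 0.0900])

n = 200_000
gumbel = -torch.log(-torch.log(torch.rand(n, 3)))  # Gumbel(0, 1) noise, as in the diff

# Gumbel-Max: add the noise to the logits and take the argmax.
samples = torch.argmax(logits + gumbel, dim=1)
empirical = torch.bincount(samples, minlength=3).float() / n

print(target)     # the distribution we want to sample from
print(empirical)  # matches target closely; adding gumbel to softmax(logits) instead would not
```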
@@ -314,7 +334,9 @@ def sampler_forward(
     return SamplerOutput(
         probs=probs,
         next_tokens=next_tokens,  # Return sampled next tokens instead of logits
-        past_key_values=outputs.past_key_values,
+        vision_embeds=outputs.get("vision_embeds", None),
+        image_idx=outputs.get("image_idx", None),
+        past_key_values=outputs.get("past_key_values", None),
         past_repetition_penalty_buffer=past_repetition_penalty_buffer,
         past_presence_penalty_buffer=past_presence_penalty_buffer,
     )