-
Notifications
You must be signed in to change notification settings - Fork 606
Support transformers 4.43 #1971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9333b58
4dda6df
5926bc5
c2a5c03
b610212
9923084
0cb6be7
88831a5
a1f838c
b8f5f32
81d0227
82a2879
d2a15b5
c761026
0eb5dce
f568bf6
92ea60b
170eaba
991b66b
8f8e6ca
e5934b3
96bdde1
9d09389
056e450
b3d9181
825cc6d
3fe0cac
b3948b9
2f69a8a
4cc1065
8077ded
aa9b9d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
|
|
||
| from ..utils import NormalizedConfigManager | ||
| from ..utils.logging import warn_once | ||
| from .io_binding import TypeHelper | ||
| from .modeling_ort import ORTModel | ||
| from .utils import get_ordered_input_names, logging | ||
|
|
||
|
|
@@ -62,6 +63,20 @@ def __init__( | |
| def device(self): | ||
| return self.parent_model.device | ||
|
|
||
| @property | ||
| def dtype(self): | ||
| for dtype in self.input_dtypes.values(): | ||
| torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) | ||
| if torch_dtype.is_floating_point: | ||
| return torch_dtype | ||
|
|
||
| for dtype in self.output_dtypes.values(): | ||
| torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) | ||
| if torch_dtype.is_floating_point: | ||
| return torch_dtype | ||
|
|
||
| return None | ||
|
|
||
| @abstractmethod | ||
| def forward(self, *args, **kwargs): | ||
| pass | ||
|
|
@@ -220,6 +235,7 @@ def forward( | |
| encoder_attention_mask: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | ||
| labels: Optional[torch.LongTensor] = None, | ||
| cache_position: Optional[torch.Tensor] = None, | ||
| use_cache_branch: None = None, | ||
| ) -> Seq2SeqLMOutput: | ||
| # Adding use_cache_branch in the signature here is just a hack for IO Binding | ||
|
|
@@ -236,8 +252,8 @@ def forward( | |
| # no-ops if merged decoder is not used | ||
| use_merged_no_cache = past_key_values is None and self.parent_model.use_merged | ||
| use_merged_cache = past_key_values is not None and self.parent_model.use_merged | ||
| use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged( | ||
| input_ids, past_key_values, use_torch=use_torch | ||
| use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged( | ||
| input_ids, past_key_values, cache_position, use_torch=use_torch | ||
| ) | ||
|
|
||
| if self.parent_model.use_io_binding: | ||
|
|
@@ -274,6 +290,9 @@ def forward( | |
| if use_cache_branch_tensor is not None: | ||
| model_inputs.append(use_cache_branch_tensor) | ||
|
|
||
| if "cache_position" in self.input_names: | ||
| model_inputs.append(cache_position) | ||
|
|
||
| io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( | ||
| self.session, | ||
| *model_inputs, | ||
|
|
@@ -346,6 +365,7 @@ def forward( | |
| "decoder_attention_mask": decoder_attention_mask, | ||
| "encoder_attention_mask": encoder_attention_mask, | ||
| "use_cache_branch": use_cache_branch_tensor, | ||
| "cache_position": cache_position, | ||
| "labels": labels, | ||
| } | ||
| if past_key_values is not None: | ||
|
|
@@ -405,20 +425,20 @@ def forward( | |
|
|
||
| def prepare_inputs_for_merged( | ||
| self, | ||
| input_ids: Union[None, torch.LongTensor, np.ndarray], | ||
| past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]], | ||
| input_ids: Optional[Union[torch.LongTensor, np.ndarray]], | ||
| past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]], | ||
| cache_position: Optional[Union[torch.Tensor, np.ndarray]], | ||
| use_torch: bool, | ||
| ): | ||
| constructor = torch if use_torch is True else np | ||
|
|
||
| if self.parent_model.use_merged: | ||
| constructor = torch if use_torch is True else np | ||
| # Uses without/with branch of a merged decoder depending on whether real past key values are passed | ||
| use_cache_branch = constructor.full((1,), past_key_values is not None) | ||
| use_cache_branch_tensor = constructor.full((1,), past_key_values is not None) | ||
| if use_torch and use_cache_branch_tensor is not None: | ||
| use_cache_branch_tensor = use_cache_branch_tensor.to(self.device) | ||
| else: | ||
| # Uses separate decoders | ||
| use_cache_branch = None | ||
|
|
||
| if use_torch and use_cache_branch is not None: | ||
| use_cache_branch = use_cache_branch.to(self.device) | ||
| use_cache_branch_tensor = None | ||
|
|
||
| # Generate dummy past for the first forward if uses a merged decoder | ||
| if self.parent_model.use_merged and past_key_values is None: | ||
|
|
@@ -434,7 +454,13 @@ def prepare_inputs_for_merged( | |
|
|
||
| past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names))) | ||
|
|
||
| return use_cache_branch, past_key_values | ||
| # Generate dummy position cache for the first forward if uses a merged decoder | ||
| if self.parent_model.use_merged and cache_position is None: | ||
| cache_position = constructor.zeros((1,), dtype=constructor.int64) | ||
| if use_torch is True: | ||
| cache_position = cache_position.to(self.device) | ||
|
|
||
| return use_cache_branch_tensor, past_key_values, cache_position | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a breaking change, so we should be careful; not sure this method is used by anyone, though.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The method is only used by the forward pass; I don't think any sub-packages use it. |
||
|
|
||
|
|
||
| class ORTDecoder(ORTDecoderForSeq2Seq): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.