
Commit f310cbe

Support transformers 4.43 (#1971)
* fix bt bark test
* setup
* patch clip models for sd
* infer ort model dtype property from inputs dtypes
* patch all clip variants
* device setter
* bigger model for now
* fix device attribution
* onnx opset for owlvit and owlv2
* model dtype
* revert
* use model part dtype instead
* no need for dtype with diffusion pipelines
* revert
* fix clip text model with projection not outputting hidden states
* whisper generation
* fix whisper, support cache_position, and using transformers whisper generation loop
* style
* create cache position for merged decoder and fix test for non whisper speech to text
* typo
* conditioned cache position argument
* update whisper min transformers version
* compare whisper ort generation with transformers
* fix generation length for speech to text model type
* cache position in whisper only with dynamic axis decoder_sequence_length
* use minimal prepare_inputs_for_generation in ORTModelForSpeechSeq2Seq
* remove version restrictions on whisper
* comment
* fix
* simpler

---------

Co-authored-by: Ella Charlaix <[email protected]>
1 parent 6e2cbb1 commit f310cbe

File tree

12 files changed: +261 -493 lines changed


optimum/exporters/onnx/config.py

Lines changed: 1 addition & 3 deletions

@@ -289,12 +289,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         if self._behavior is not ConfigBehavior.ENCODER:
             if self.use_past_in_inputs:
                 common_inputs["decoder_input_ids"] = {0: "batch_size"}
+                self.add_past_key_values(common_inputs, direction="inputs")
             else:
                 common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
 
-        if self.use_past_in_inputs:
-            self.add_past_key_values(common_inputs, direction="inputs")
-
         if self._behavior is ConfigBehavior.DECODER:
             common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

optimum/exporters/onnx/model_configs.py

Lines changed: 33 additions & 1 deletion

@@ -53,6 +53,7 @@
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
     NormalizedVisionConfig,
+    check_if_transformers_greater,
     is_diffusers_available,
     logging,
 )

@@ -71,6 +72,7 @@
 )
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import (
+    CLIPModelPatcher,
     FalconModelPatcher,
     MistralModelPatcher,
     MusicgenModelPatcher,

@@ -919,6 +921,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "image_embeds": {0: "image_batch_size"},
         }
 
+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
     @property

@@ -964,6 +973,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 
         return common_outputs
 
+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
     @property

@@ -981,12 +997,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
     def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
 
+        # TODO: fix should be by casting inputs during inference and not export
         if framework == "pt":
             import torch
 
             dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
         return dummy_inputs
 
+    def patch_model_for_export(
+        self,
+        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> "ModelPatcher":
+        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class UNetOnnxConfig(VisionOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-3

@@ -1119,6 +1143,9 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-4
     MIN_TORCH_VERSION = version.parse("2.1")
 
+    # needs einsum operator support, available since opset 12
+    DEFAULT_ONNX_OPSET = 12
+
     def __init__(
         self,
         config: "PretrainedConfig",

@@ -1422,7 +1449,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         if self._behavior is not ConfigBehavior.DECODER:
             common_inputs["input_features"] = {0: "batch_size"}  # Remove unnecessary dynamic axis.
 
-        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
+        if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
+            if check_if_transformers_greater("4.43.0"):
+                # since https://github.com/huggingface/transformers/pull/31166
+                common_inputs["cache_position"] = {0: "decoder_sequence_length"}
+
+        if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
             common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
         return common_inputs
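
To make the Whisper change concrete, here is a minimal standalone sketch (not repository code) of the version gate in the last hunk: the `cache_position` dynamic axis is only exposed for decoder exports that take past key values, and only when the installed transformers ships the reworked Whisper generation loop (>= 4.43, huggingface/transformers#31166). The helper name below is illustrative.

```python
# Standalone sketch of the gating above; `use_past_in_inputs` mirrors the
# ConfigBehavior.DECODER-with-past case handled by the real config.
from packaging import version
import transformers


def whisper_decoder_dynamic_axes(use_past_in_inputs: bool) -> dict:
    common_inputs = {"decoder_input_ids": {0: "batch_size"}}
    if use_past_in_inputs and version.parse(transformers.__version__) >= version.parse("4.43.0"):
        # cache_position carries the absolute positions of the tokens being decoded
        common_inputs["cache_position"] = {0: "decoder_sequence_length"}
    return common_inputs


print(whisper_decoder_dynamic_axes(use_past_in_inputs=True))
```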

optimum/exporters/onnx/model_patcher.py

Lines changed: 17 additions & 0 deletions

@@ -1138,3 +1138,20 @@ def __init__(
             self._update_causal_mask_original = self._model.model._update_causal_mask
         else:
             self._update_causal_mask_original = self._model._update_causal_mask
+
+
+class CLIPModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if _transformers_version >= version.parse("4.43"):
+            from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
+
+            self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if _transformers_version >= version.parse("4.43"):
+            from transformers.models.clip.modeling_clip import CLIPSdpaAttention
+
+            CLIPSdpaAttention.forward = self.original_sdpa_forward
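
For context, a hypothetical usage sketch (the model id and the export call are placeholders, not repository code): `patch_model_for_export` on the CLIP configs now returns this patcher, a context manager that swaps the SDPA attention forward for the eager `CLIPAttention.forward` while tracing and restores it on exit.

```python
# Hypothetical sketch: only the patching pattern is taken from the diff above.
from transformers import CLIPModel

from optimum.exporters.onnx.model_configs import CLIPOnnxConfig

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
onnx_config = CLIPOnnxConfig(model.config)

# Inside the `with` block, CLIPSdpaAttention.forward is replaced by the eager
# CLIPAttention.forward (only on transformers >= 4.43) and restored on exit.
with onnx_config.patch_model_for_export(model):
    pass  # torch.onnx.export(...) would run here during a real export
```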

optimum/exporters/utils.py

Lines changed: 5 additions & 3 deletions

@@ -87,10 +87,11 @@ def _get_submodels_for_export_stable_diffusion(
     projection_dim = pipeline.text_encoder.config.projection_dim
 
     # Text encoder
-    if pipeline.text_encoder is not None:
+    text_encoder = getattr(pipeline, "text_encoder", None)
+    if text_encoder is not None:
         if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline):
-            pipeline.text_encoder.config.output_hidden_states = True
-        models_for_export["text_encoder"] = pipeline.text_encoder
+            text_encoder.config.output_hidden_states = True
+        models_for_export["text_encoder"] = text_encoder
 
     # U-NET
     # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0

@@ -120,6 +121,7 @@ def _get_submodels_for_export_stable_diffusion(
     text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
     if text_encoder_2 is not None:
         text_encoder_2.config.output_hidden_states = True
+        text_encoder_2.text_model.config.output_hidden_states = True
         models_for_export["text_encoder_2"] = text_encoder_2
 
     return models_for_export
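
A small hypothetical sketch of the `text_encoder_2` hunk: in recent transformers the projection wrapper and its inner `text_model` are not guaranteed to share a single config object, so `output_hidden_states` is set on both (per the commit message, this fixes the CLIP text model with projection not outputting hidden states).

```python
# Hypothetical sketch mirroring the hunk above; the model is randomly
# initialized purely for illustration.
from transformers import CLIPTextConfig, CLIPTextModelWithProjection

text_encoder_2 = CLIPTextModelWithProjection(CLIPTextConfig())

text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True

# Depending on the transformers version these may or may not be the same object:
print(text_encoder_2.config is text_encoder_2.text_model.config)
```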

optimum/onnxruntime/base.py

Lines changed: 38 additions & 12 deletions

@@ -24,6 +24,7 @@
 
 from ..utils import NormalizedConfigManager
 from ..utils.logging import warn_once
+from .io_binding import TypeHelper
 from .modeling_ort import ORTModel
 from .utils import get_ordered_input_names, logging
 

@@ -62,6 +63,20 @@ def __init__(
     def device(self):
         return self.parent_model.device
 
+    @property
+    def dtype(self):
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass

@@ -220,6 +235,7 @@ def forward(
         encoder_attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
         use_cache_branch: None = None,
     ) -> Seq2SeqLMOutput:
         # Adding use_cache_branch in the signature here is just a hack for IO Binding

@@ -236,8 +252,8 @@ def forward(
         # no-ops if merged decoder is not used
         use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
         use_merged_cache = past_key_values is not None and self.parent_model.use_merged
-        use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged(
-            input_ids, past_key_values, use_torch=use_torch
+        use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged(
+            input_ids, past_key_values, cache_position, use_torch=use_torch
         )
 
         if self.parent_model.use_io_binding:

@@ -274,6 +290,9 @@ def forward(
             if use_cache_branch_tensor is not None:
                 model_inputs.append(use_cache_branch_tensor)
 
+            if "cache_position" in self.input_names:
+                model_inputs.append(cache_position)
+
             io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                 self.session,
                 *model_inputs,

@@ -346,6 +365,7 @@ def forward(
                 "decoder_attention_mask": decoder_attention_mask,
                 "encoder_attention_mask": encoder_attention_mask,
                 "use_cache_branch": use_cache_branch_tensor,
+                "cache_position": cache_position,
                 "labels": labels,
             }
             if past_key_values is not None:

@@ -405,20 +425,20 @@ def forward(
 
     def prepare_inputs_for_merged(
         self,
-        input_ids: Union[None, torch.LongTensor, np.ndarray],
-        past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
+        input_ids: Optional[Union[torch.LongTensor, np.ndarray]],
+        past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]],
+        cache_position: Optional[Union[torch.Tensor, np.ndarray]],
         use_torch: bool,
     ):
+        constructor = torch if use_torch is True else np
+
         if self.parent_model.use_merged:
-            constructor = torch if use_torch is True else np
             # Uses without/with branch of a merged decoder depending on whether real past key values are passed
-            use_cache_branch = constructor.full((1,), past_key_values is not None)
+            use_cache_branch_tensor = constructor.full((1,), past_key_values is not None)
+            if use_torch and use_cache_branch_tensor is not None:
+                use_cache_branch_tensor = use_cache_branch_tensor.to(self.device)
         else:
-            # Uses separate decoders
-            use_cache_branch = None
-
-        if use_torch and use_cache_branch is not None:
-            use_cache_branch = use_cache_branch.to(self.device)
+            use_cache_branch_tensor = None
 
         # Generate dummy past for the first forward if uses a merged decoder
         if self.parent_model.use_merged and past_key_values is None:

@@ -434,7 +454,13 @@ def prepare_inputs_for_merged(
 
             past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
 
-        return use_cache_branch, past_key_values
+        # Generate dummy position cache for the first forward if uses a merged decoder
+        if self.parent_model.use_merged and cache_position is None:
+            cache_position = constructor.zeros((1,), dtype=constructor.int64)
+            if use_torch is True:
+                cache_position = cache_position.to(self.device)
+
+        return use_cache_branch_tensor, past_key_values, cache_position
 
 
 class ORTDecoder(ORTDecoderForSeq2Seq):
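
To illustrate the new `dtype` property (duplicated for `ORTModel` in modeling_ort.py below), here is a standalone sketch of the inference logic; the hard-coded mapping stands in for `TypeHelper.ort_type_to_torch_type` and is only illustrative.

```python
# Standalone sketch of the dtype inference above: the model dtype is taken to
# be the first floating-point dtype found among the ONNX inputs, then the
# outputs; integer-only graphs yield None.
from typing import Dict, Optional

import torch

ORT_TO_TORCH = {
    "tensor(float)": torch.float32,
    "tensor(float16)": torch.float16,
    "tensor(int64)": torch.int64,
    "tensor(int32)": torch.int32,
}


def infer_model_dtype(
    input_dtypes: Dict[str, str], output_dtypes: Dict[str, str]
) -> Optional[torch.dtype]:
    for ort_type in (*input_dtypes.values(), *output_dtypes.values()):
        torch_dtype = ORT_TO_TORCH.get(ort_type)
        if torch_dtype is not None and torch_dtype.is_floating_point:
            return torch_dtype
    return None


# e.g. an fp16 decoder whose ids are int64:
assert infer_model_dtype({"input_ids": "tensor(int64)"}, {"logits": "tensor(float16)"}) == torch.float16
```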

optimum/onnxruntime/modeling_diffusion.py

Lines changed: 7 additions & 1 deletion

@@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]):
         Returns:
             `ORTModel`: the model placed on the requested device.
         """
+
         device, provider_options = parse_device(device)
         provider = get_provider_for_device(device)
         validate_provider_availability(provider)  # raise error if the provider is not available
-        self.device = device
+
+        if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
+            return self
+
         self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
         self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
         self.unet.session.set_providers([provider], provider_options=[provider_options])

@@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]):
             self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])
 
         self.providers = self.vae_decoder.session.get_providers()
+        self._device = device
+
         return self
 
     @classmethod

optimum/onnxruntime/modeling_ort.py

Lines changed: 23 additions & 5 deletions

@@ -276,7 +276,24 @@ def __init__(
 
         self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)
 
-    # TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value?
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the model.
+        """
+
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @property
     def device(self) -> torch.device:
         """

@@ -286,8 +303,8 @@ def device(self) -> torch.device:
         return self._device
 
     @device.setter
-    def device(self, value: torch.device):
-        self._device = value
+    def device(self, **kwargs):
+        raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.")
 
     @property
     def use_io_binding(self):

@@ -309,13 +326,13 @@ def to(self, device: Union[torch.device, str, int]):
         Returns:
             `ORTModel`: the model placed on the requested device.
         """
+
         device, provider_options = parse_device(device)
 
         if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
             return self
 
-        self.device = device
-        provider = get_provider_for_device(self.device)
+        provider = get_provider_for_device(device)
         validate_provider_availability(provider)  # raise error if the provider is not available
 
         # IOBinding is only supported for CPU and CUDA Execution Providers.

@@ -331,6 +348,7 @@ def to(self, device: Union[torch.device, str, int]):
 
         self.model.set_providers([provider], provider_options=[provider_options])
         self.providers = self.model.get_providers()
+        self._device = device
 
         return self
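
A hypothetical usage sketch of the resulting behaviour (the model id is only an example): `dtype` is now inferred from the ONNX graph, and `device` is read-only, so device changes go through `to()`.

```python
# Hypothetical sketch; the model id is illustrative.
import torch

from optimum.onnxruntime import ORTModelForSequenceClassification

model = ORTModelForSequenceClassification.from_pretrained(
    "optimum/distilbert-base-uncased-finetuned-sst-2-english"
)

print(model.dtype)   # first floating-point dtype found in the ONNX inputs/outputs, e.g. torch.float32
print(model.device)  # cpu

if torch.cuda.is_available():
    model = model.to("cuda")  # switches the session to CUDAExecutionProvider

# model.device = torch.device("cuda")  # direct assignment now errors instead of silently switching
```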
