Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9333b58
fix bt bark test
IlyasMoutawwakil Jul 25, 2024
4dda6df
setup
IlyasMoutawwakil Jul 25, 2024
5926bc5
patch clip models for sd
IlyasMoutawwakil Jul 25, 2024
c2a5c03
infer ort model dtype property from inputs dtypes
IlyasMoutawwakil Jul 25, 2024
b610212
patch all clip variants
IlyasMoutawwakil Jul 25, 2024
9923084
device setter
IlyasMoutawwakil Jul 25, 2024
0cb6be7
bigger model for now
IlyasMoutawwakil Jul 25, 2024
88831a5
fix device attribution
IlyasMoutawwakil Jul 25, 2024
a1f838c
onnx opset for owlvit and owlv2
IlyasMoutawwakil Jul 25, 2024
b8f5f32
model dtype
IlyasMoutawwakil Jul 25, 2024
81d0227
revert
IlyasMoutawwakil Jul 25, 2024
82a2879
use model part dtype instead
IlyasMoutawwakil Jul 25, 2024
d2a15b5
no need for dtype with diffusion pipelines
IlyasMoutawwakil Jul 25, 2024
c761026
revert
IlyasMoutawwakil Jul 25, 2024
0eb5dce
fix clip text model with projection not outputting hidden states
IlyasMoutawwakil Jul 25, 2024
f568bf6
whisper generation
IlyasMoutawwakil Jul 26, 2024
92ea60b
fix whisper, support cache_position, and using transformers whisper g…
IlyasMoutawwakil Jul 29, 2024
170eaba
style
IlyasMoutawwakil Jul 29, 2024
991b66b
create cache position for merged decoder and fix test for non whisper…
IlyasMoutawwakil Jul 29, 2024
8f8e6ca
typo
IlyasMoutawwakil Jul 29, 2024
e5934b3
Merge branch 'main' into support-transformers-4.43
echarlaix Jul 30, 2024
96bdde1
conditioned cache position argument
IlyasMoutawwakil Jul 30, 2024
9d09389
update whisper min transformers version
IlyasMoutawwakil Jul 30, 2024
056e450
compare whisper ort generation with transformers
IlyasMoutawwakil Jul 30, 2024
b3d9181
Merge branch 'support-transformers-4.43' of https://github.com/huggin…
IlyasMoutawwakil Jul 30, 2024
825cc6d
fix generation length for speech to text model type
IlyasMoutawwakil Jul 30, 2024
3fe0cac
cache position in whisper only with dynamic axis decoder_sequence_length
IlyasMoutawwakil Jul 30, 2024
b3948b9
use minimal prepare_inputs_for_generation in ORTModelForSpeechSeq2Seq
IlyasMoutawwakil Aug 2, 2024
2f69a8a
remove version restrictions on whisper
IlyasMoutawwakil Aug 2, 2024
4cc1065
comment
IlyasMoutawwakil Aug 2, 2024
8077ded
fix
IlyasMoutawwakil Aug 2, 2024
aa9b9d6
simpler
IlyasMoutawwakil Aug 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.ENCODER:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
self.add_past_key_values(common_inputs, direction="inputs")
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

Expand Down
42 changes: 40 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
check_if_transformers_greater,
is_diffusers_available,
logging,
)
Expand All @@ -71,6 +72,7 @@
)
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
CLIPModelPatcher,
FalconModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Expand Down Expand Up @@ -913,10 +915,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class CLIPOnnxConfig(TextAndVisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -935,6 +943,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
"image_embeds": {0: "image_batch_size"},
}

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
@property
Expand Down Expand Up @@ -980,6 +995,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
@property
Expand All @@ -997,12 +1019,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

# TODO: fix should be by casting inputs during inference and not export
if framework == "pt":
import torch

dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
return dummy_inputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
Expand Down Expand Up @@ -1135,6 +1165,9 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
MIN_TORCH_VERSION = version.parse("2.1")

# needs einsum operator support, available since opset 12
DEFAULT_ONNX_OPSET = 12

def __init__(
self,
config: "PretrainedConfig",
Expand Down Expand Up @@ -1438,7 +1471,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis.

if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
if check_if_transformers_greater("4.43.0"):
# since https://github.com/huggingface/transformers/pull/31166
common_inputs["cache_position"] = {0: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
return common_inputs

Expand Down
17 changes: 17 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,3 +1138,20 @@ def __init__(
self._update_causal_mask_original = self._model.model._update_causal_mask
else:
self._update_causal_mask_original = self._model._update_causal_mask


class CLIPModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()

if _transformers_version >= version.parse("4.43"):
from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention

self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if _transformers_version >= version.parse("4.43"):
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

CLIPSdpaAttention.forward = self.original_sdpa_forward
10 changes: 6 additions & 4 deletions optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_submodels_for_export_diffusion(
pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline)
)
is_stable_diffusion_xl = isinstance(
pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline)
pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
)
is_latent_consistency_model = isinstance(
pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline)
Expand All @@ -117,10 +117,11 @@ def _get_submodels_for_export_diffusion(
models_for_export = {}

# Text encoder
if pipeline.text_encoder is not None:
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
if is_stable_diffusion_xl:
pipeline.text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = pipeline.text_encoder
text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = text_encoder

# U-NET
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
Expand Down Expand Up @@ -151,6 +152,7 @@ def _get_submodels_for_export_diffusion(
text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export
Expand Down
50 changes: 38 additions & 12 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import get_ordered_input_names, logging

Expand Down Expand Up @@ -62,6 +63,20 @@ def __init__(
def device(self):
return self.parent_model.device

@property
def dtype(self):
for dtype in self.input_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype

for dtype in self.output_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype

return None

@abstractmethod
def forward(self, *args, **kwargs):
pass
Expand Down Expand Up @@ -220,6 +235,7 @@ def forward(
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
use_cache_branch: None = None,
) -> Seq2SeqLMOutput:
# Adding use_cache_branch in the signature here is just a hack for IO Binding
Expand All @@ -236,8 +252,8 @@ def forward(
# no-ops if merged decoder is not used
use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
use_merged_cache = past_key_values is not None and self.parent_model.use_merged
use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged(
input_ids, past_key_values, use_torch=use_torch
use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged(
input_ids, past_key_values, cache_position, use_torch=use_torch
)

if self.parent_model.use_io_binding:
Expand Down Expand Up @@ -274,6 +290,9 @@ def forward(
if use_cache_branch_tensor is not None:
model_inputs.append(use_cache_branch_tensor)

if "cache_position" in self.input_names:
model_inputs.append(cache_position)

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
Expand Down Expand Up @@ -346,6 +365,7 @@ def forward(
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"cache_position": cache_position,
"labels": labels,
}
if past_key_values is not None:
Expand Down Expand Up @@ -405,20 +425,20 @@ def forward(

def prepare_inputs_for_merged(
self,
input_ids: Union[None, torch.LongTensor, np.ndarray],
past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
input_ids: Optional[Union[torch.LongTensor, np.ndarray]],
past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]],
cache_position: Optional[Union[torch.Tensor, np.ndarray]],
use_torch: bool,
):
constructor = torch if use_torch is True else np

if self.parent_model.use_merged:
constructor = torch if use_torch is True else np
# Uses without/with branch of a merged decoder depending on whether real past key values are passed
use_cache_branch = constructor.full((1,), past_key_values is not None)
use_cache_branch_tensor = constructor.full((1,), past_key_values is not None)
if use_torch and use_cache_branch_tensor is not None:
use_cache_branch_tensor = use_cache_branch_tensor.to(self.device)
else:
# Uses separate decoders
use_cache_branch = None

if use_torch and use_cache_branch is not None:
use_cache_branch = use_cache_branch.to(self.device)
use_cache_branch_tensor = None

# Generate dummy past for the first forward if uses a merged decoder
if self.parent_model.use_merged and past_key_values is None:
Expand All @@ -434,7 +454,13 @@ def prepare_inputs_for_merged(

past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))

return use_cache_branch, past_key_values
# Generate dummy position cache for the first forward if uses a merged decoder
if self.parent_model.use_merged and cache_position is None:
cache_position = constructor.zeros((1,), dtype=constructor.int64)
if use_torch is True:
cache_position = cache_position.to(self.device)

return use_cache_branch_tensor, past_key_values, cache_position
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a breaking change so we should be careful, not sure this method is used by anyone though

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the method is only used by the forward pass, I don't think any sub packages use it



class ORTDecoder(ORTDecoderForSeq2Seq):
Expand Down
8 changes: 7 additions & 1 deletion optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]):
Returns:
`ORTModel`: the model placed on the requested device.
"""

device, provider_options = parse_device(device)
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available
self.device = device

if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
self.unet.session.set_providers([provider], provider_options=[provider_options])
Expand All @@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]):
self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])

self.providers = self.vae_decoder.session.get_providers()
self._device = device

return self

@classmethod
Expand Down
28 changes: 23 additions & 5 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,24 @@ def __init__(

self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

# TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value?
@property
def dtype(self) -> torch.dtype:
"""
`torch.dtype`: The dtype of the model.
"""

for dtype in self.input_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype

for dtype in self.output_dtypes.values():
torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
if torch_dtype.is_floating_point:
return torch_dtype

return None

@property
def device(self) -> torch.device:
"""
Expand All @@ -286,8 +303,8 @@ def device(self) -> torch.device:
return self._device

@device.setter
def device(self, value: torch.device):
self._device = value
def device(self, **kwargs):
raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.")

@property
def use_io_binding(self):
Expand All @@ -309,13 +326,13 @@ def to(self, device: Union[torch.device, str, int]):
Returns:
`ORTModel`: the model placed on the requested device.
"""

device, provider_options = parse_device(device)

if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.device = device
provider = get_provider_for_device(self.device)
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available

# IOBinding is only supported for CPU and CUDA Execution Providers.
Expand All @@ -331,6 +348,7 @@ def to(self, device: Union[torch.device, str, int]):

self.model.set_providers([provider], provider_options=[provider_options])
self.providers = self.model.get_providers()
self._device = device

return self

Expand Down
Loading