diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py
index 09a907fb..bf996805 100644
--- a/optimum/commands/export/executorch.py
+++ b/optimum/commands/export/executorch.py
@@ -126,6 +126,20 @@ def parse_args_executorch(parser):
         required=False,
         help="Maximum sequence length for the model. If not specified, uses the model's default max_position_embeddings.",
     )
+    required_group.add_argument(
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        required=False,
+        help="Data type for model weights. Options: float32, float16, bfloat16. Default: float32. For quantization (int8/int4), use the --qlinear arguments.",
+    )
+    required_group.add_argument(
+        "--device",
+        type=str,
+        choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
+        required=False,
+        help="Device to run the model on. Options: cpu, cuda, or cuda:N to select a specific GPU index. Default: cpu.",
+    )
 
 
 class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -159,6 +173,10 @@ def run(self):
             kwargs["qembedding_group_size"] = self.args.qembedding
         if self.args.max_seq_len:
             kwargs["max_seq_len"] = self.args.max_seq_len
+        if hasattr(self.args, "dtype") and self.args.dtype:
+            kwargs["dtype"] = self.args.dtype
+        if hasattr(self.args, "device") and self.args.device:
+            kwargs["device"] = self.args.device
 
         main_export(
             model_name_or_path=self.args.model,
diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py
index 8a381423..d64dc527 100644
--- a/optimum/exporters/executorch/convert.py
+++ b/optimum/exporters/executorch/convert.py
@@ -86,5 +86,6 @@ def export_to_executorch(
         logging.info(
             f"Saved exported program to {full_path} ({os.path.getsize(full_path) / (1024 * 1024):.2f} MB)"
         )
+        prog.write_tensor_data_to_file(output_dir)
 
     return executorch_progs
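Note: the CLI changes above simply thread --dtype and --device through to main_export as extra keyword arguments. A minimal programmatic equivalent of the new flags, as a sketch: it assumes main_export (the function the command calls, per the hunk above) is importable from optimum.exporters.executorch and accepts task, recipe, and output_dir the same way the command forwards them; the output directory is hypothetical.

    from optimum.exporters.executorch import main_export

    # Mirrors: optimum-cli export executorch --model ... --recipe cuda \
    #          --dtype bfloat16 --device cuda:0 --max_seq_len 1024
    main_export(
        model_name_or_path="mistralai/Voxtral-Mini-3B-2507",
        task="multimodal-text-to-text",
        recipe="cuda",
        output_dir="./voxtral_cuda",  # hypothetical output directory
        dtype="bfloat16",
        device="cuda:0",
        max_seq_len=1024,
    )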
diff --git a/optimum/exporters/executorch/recipes/cuda.py b/optimum/exporters/executorch/recipes/cuda.py
new file mode 100644
index 00000000..5aa69735
--- /dev/null
+++ b/optimum/exporters/executorch/recipes/cuda.py
@@ -0,0 +1,129 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, Union
+
+import torch
+from tabulate import tabulate
+from torch.export import ExportedProgram
+from torch.nn.attention import SDPBackend
+
+from executorch.devtools.backend_debug import get_delegation_info
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchProgram,
+    to_edge_transform_and_lower,
+)
+from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass
+
+from ..integrations import (
+    CausalLMExportableModule,
+    MaskedLMExportableModule,
+    MultiModalTextToTextExportableModule,
+    Seq2SeqLMExportableModule,
+)
+from ..recipe_registry import register_recipe
+
+
+aten = torch.ops.aten
+
+
+@register_recipe("cuda")
+def export_to_executorch_with_cuda(
+    model: Union[
+        CausalLMExportableModule,
+        MaskedLMExportableModule,
+        Seq2SeqLMExportableModule,
+        MultiModalTextToTextExportableModule,
+    ],
+    **kwargs,
+):
+    """
+    Export a PyTorch model to ExecuTorch with delegation to the CUDA backend.
+    This function also writes the metadata required by the ExecuTorch runtime to the .pte file.
+    Args:
+        model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultiModalTextToTextExportableModule]):
+            The PyTorch model to be exported to ExecuTorch.
+        **kwargs:
+            Additional keyword arguments for recipe-specific configurations, e.g. exporting with different example inputs or different compile/backend configs.
+    Returns:
+        Dict[str, ExecutorchProgram]:
+            A map of exported and optimized programs for ExecuTorch.
+            For encoder-decoder models or multimodal models, it may generate multiple programs.
+    """
+    # Import here to avoid version conflicts.
+    from torch._inductor.decomposition import conv1d_to_conv2d
+
+    from executorch.backends.cuda.cuda_backend import CudaBackend
+    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
+
+    def _lower_to_executorch(
+        exported_programs: Dict[str, ExportedProgram],
+        metadata=None,
+    ) -> Dict[str, ExecutorchProgram]:
+        logging.debug(f"\nExported program: {exported_programs}")
+
+        # If there is only one exported program, its method name in the .pte should be "forward".
+        if len(exported_programs) == 1:
+            exported_programs = {"forward": next(iter(exported_programs.values()))}
+
+        # CUDA backend compile spec with method name.
+        partitioners = {
+            key: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(key)])]
+            for key in exported_programs.keys()
+        }
+        # Add decompositions for Triton to generate kernels.
+        for key, ep in exported_programs.items():
+            exported_programs[key] = ep.run_decompositions(
+                {
+                    aten.conv1d.default: conv1d_to_conv2d,
+                }
+            )
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
+            et_prog = to_edge_transform_and_lower(
+                exported_programs,
+                partitioner=partitioners,
+                compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False,
+                    _skip_dim_order=True,
+                ),
+                constant_methods=metadata,
+                transform_passes=[RemovePaddingIdxEmbeddingPass()],
+            )
+        et_prog = et_prog.to_executorch()
+        pte_name = "model"
+        for method in et_prog.methods:
+            logging.debug(f"---------------------- Method: {method} ----------------------")
+            logging.debug(f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}")
+            delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module)
+            logging.debug(f"\nDelegation info summary for {pte_name}.pte: {delegation_info.get_summary()}")
+            logging.debug(
+                f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
+            )
+        return {pte_name: et_prog}
+
+    # Decompose SDPA, since there is no flash attention kernel for it yet.
+    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        exported_progs = model.export()
+
+    if (
+        model.config._attn_implementation == "custom_sdpa"
+        or model.config._attn_implementation == "custom_sdpa_ring_kv_cache"
+    ):
+        raise NotImplementedError(
+            "Custom SDPA implementation is not supported for CUDA yet. Please use 'flash_attention' instead."
+        )
+
+    return _lower_to_executorch(exported_progs, model.metadata)
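Note: the run_decompositions call in the recipe is what lets Triton generate kernels for 1-D convolutions, per the comment in the new file: aten.conv1d is rewritten into an equivalent conv2d at export time. A standalone toy illustration of that step, as a sketch; ToyConv and the tensor shapes are made up for demonstration.

    import torch
    from torch._inductor.decomposition import conv1d_to_conv2d

    class ToyConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(4, 8, kernel_size=3)

        def forward(self, x):
            return self.conv(x)

    ep = torch.export.export(ToyConv(), (torch.randn(1, 4, 16),))
    # Rewrite aten.conv1d into an equivalent conv2d, mirroring the recipe's decomposition table.
    ep = ep.run_decompositions({torch.ops.aten.conv1d.default: conv1d_to_conv2d})
    print(ep.graph_module.code)  # the graph now contains conv2d instead of conv1d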
diff --git a/tests/models/test_modeling_voxtral.py b/tests/models/test_modeling_voxtral.py
index 3dd1f84e..e0b13e61 100644
--- a/tests/models/test_modeling_voxtral.py
+++ b/tests/models/test_modeling_voxtral.py
@@ -16,7 +16,9 @@
 import gc
 import logging
 import os
+import subprocess
 import sys
+import tempfile
 import unittest
 
 import pytest
@@ -324,3 +326,28 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8
         self.assertTrue(
             check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
         )
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(is_linux_ci, reason="OOM")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA backend required")
+    def test_voxtral_export_to_executorch_cuda_recipe(self):
+        model_id = "mistralai/Voxtral-Mini-3B-2507"
+        task = "multimodal-text-to-text"
+        recipe = "cuda"
+        output_subdir = "executorch"
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            output_dir = os.path.join(tempdir, output_subdir)
+            cmd = (
+                "optimum-cli export executorch "
+                f"--model {model_id} "
+                f"--task {task} "
+                f"--recipe {recipe} "
+                "--dtype bfloat16 "
+                "--device cuda:0 "
+                "--max_seq_len 1024 "
+                f"--output_dir {output_dir}"
+            )
+            subprocess.run(cmd, shell=True, check=True)
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "model.pte")))
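Note: beyond the file-existence assertion, the exported program could be opened with ExecuTorch's Python runtime bindings as a further sanity check. A sketch under stated assumptions: it uses the executorch.runtime API of recent ExecuTorch releases, and actually executing a CUDA-delegated method would additionally require pybindings built with the CUDA backend.

    from executorch.runtime import Runtime

    runtime = Runtime.get()
    program = runtime.load_program("model.pte")  # path produced by the export above
    print(program.method_names)  # should include "forward"
    method = program.load_method("forward")  # loading alone verifies the program parses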