Merged
18 changes: 18 additions & 0 deletions optimum/commands/export/executorch.py
@@ -126,6 +126,20 @@ def parse_args_executorch(parser):
required=False,
help="Maximum sequence length for the model. If not specified, uses the model's default max_position_embeddings.",
)
required_group.add_argument(
"--dtype",
type=str,
choices=["float32", "float16", "bfloat16"],
required=False,
help="Data type for model weights. Options: float32, float16, bfloat16. Default: float32. For quantization (int8/int4), use the --qlinear arguments.",
)
required_group.add_argument(
"--device",
type=str,
choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
Collaborator: Why 0, 1, 2, etc.? That seems like a strange thing.

Collaborator Author (@larryliu0820, Oct 3, 2025): For an environment with more than one GPU, you can specify which GPU to use.

Collaborator: Is that a normal thing to do? I am a bit surprised about where something like this would be used, especially since optimum-executorch is really for generating an AOT artifact and we don't know where it will be deployed.

Collaborator Author: I think this should be fine; I don't think the AOTI-compiled artifact has the GPU index hardcoded in. It is convenient if you have two GPUs and the first one is occupied or doesn't have enough memory, and you want to export on the second one.

Collaborator: Sure. I do find it a bit awkward, but sure.
required=False,
help="Device to run the model on. Options: cpu, cuda. Default: cpu.",
)


class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -159,6 +173,10 @@ def run(self):
kwargs["qembedding_group_size"] = self.args.qembedding
if self.args.max_seq_len:
kwargs["max_seq_len"] = self.args.max_seq_len
if hasattr(self.args, "dtype") and self.args.dtype:
kwargs["dtype"] = self.args.dtype
if hasattr(self.args, "device") and self.args.device:
kwargs["device"] = self.args.device

main_export(
model_name_or_path=self.args.model,
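For context, a minimal programmatic sketch of what the two new flags map to. The import path and the task/recipe/output_dir keyword names are assumptions inferred from the CLI test further down; dtype and device are simply forwarded through kwargs as in run() above.

# Sketch only: mirrors `optimum-cli export executorch ... --dtype bfloat16 --device cuda:0`.
from optimum.exporters.executorch import main_export  # import path assumed

main_export(
    model_name_or_path="mistralai/Voxtral-Mini-3B-2507",
    task="multimodal-text-to-text",
    recipe="cuda",
    output_dir="./voxtral_cuda",  # illustrative path
    dtype="bfloat16",             # new --dtype flag, forwarded via kwargs["dtype"]
    device="cuda:0",              # new --device flag, forwarded via kwargs["device"]
    max_seq_len=1024,
)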
1 change: 1 addition & 0 deletions optimum/exporters/executorch/convert.py
@@ -86,5 +86,6 @@ def export_to_executorch(
logging.info(
f"Saved exported program to {full_path} ({os.path.getsize(full_path) / (1024 * 1024):.2f} MB)"
)
prog.write_tensor_data_to_file(output_dir)
Collaborator: Are you sure that we want this behavior for all existing recipes?

Collaborator Author: If we don't put the weights as external, then this is just a no-op.

Another commenter: Yeah, that's right.


return executorch_progs
129 changes: 129 additions & 0 deletions optimum/exporters/executorch/recipes/cuda.py
@@ -0,0 +1,129 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
Collaborator: You are the author, so add the Meta copyright.

Collaborator Author: I think Hugging Face requires us to add this license.

Collaborator: I think you're mixing two things: the copyright notice and the license text.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Union

import torch
from tabulate import tabulate
from torch.export import ExportedProgram
from torch.nn.attention import SDPBackend

from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import (
EdgeCompileConfig,
ExecutorchProgram,
to_edge_transform_and_lower,
)
from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass
Collaborator: Needed?

Collaborator Author: Used on line 96 (in transform_passes).

from ..integrations import (
CausalLMExportableModule,
MaskedLMExportableModule,
MultiModalTextToTextExportableModule,
Seq2SeqLMExportableModule,
)
from ..recipe_registry import register_recipe


aten = torch.ops.aten


@register_recipe("cuda")
def export_to_executorch_with_cuda(
model: Union[
CausalLMExportableModule,
MaskedLMExportableModule,
Seq2SeqLMExportableModule,
MultiModalTextToTextExportableModule,
],
**kwargs,
Collaborator: The dtype and device from the CLI are never used.
):
"""
Export a PyTorch model to ExecuTorch with delegation to the CUDA backend.
This function also writes metadata required by the ExecuTorch runtime to the .pte file.
Args:
model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultiModalTextToTextExportableModule]):
The PyTorch model to be exported to ExecuTorch.
**kwargs:
Additional keyword arguments for recipe-specific configurations, e.g. exporting with different example inputs or different compile/backend configs.
Returns:
Dict[str, ExecutorchProgram]:
A map of exported and optimized programs for ExecuTorch.
For encoder-decoder models or multimodal models, it may generate multiple programs.
"""
# Import here to avoid version conflicts.
from torch._inductor.decomposition import conv1d_to_conv2d

from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner

def _lower_to_executorch(
exported_programs: Dict[str, ExportedProgram],
metadata=None,
) -> Dict[str, ExecutorchProgram]:
logging.debug(f"\nExported program: {exported_programs}")

# If just one exported program, the method name in the .pte for it should be "forward".
if len(exported_programs) == 1:
exported_programs = {"forward": next(iter(exported_programs.values()))}

# CUDA backend compile spec with method name.
partitioners = {
key: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(key)])]
for key in exported_programs.keys()
}
# Add decompositions for triton to generate kernels.
for key, ep in exported_programs.items():
exported_programs[key] = ep.run_decompositions(
{
aten.conv1d.default: conv1d_to_conv2d,
}
)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
et_prog = to_edge_transform_and_lower(
exported_programs,
partitioner=partitioners,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods=metadata,
transform_passes=[RemovePaddingIdxEmbeddingPass()],
)
et_prog = et_prog.to_executorch()
pte_name = "model"
for method in et_prog.methods:
logging.debug(f"---------------------- Method: {method} ----------------------")
logging.debug(f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}")
delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module)
logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
logging.debug(
f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
)
return {pte_name: et_prog}

# Decomposes SDPA since we don't have a flash attention kernel for it yet.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
exported_progs = model.export()

if (
model.config._attn_implementation == "custom_sdpa"
Collaborator: This is a generic recipe; do all supported transformer models have _attn_implementation? I don't suppose you are introducing this new behavior, just asking.

Collaborator (on lines +121 to +122): Can you error out early, before exporting, so that users don't have to wait the whole time?

Collaborator Author: Good point, I can do that.

or model.config._attn_implementation == "custom_sdpa_ring_kv_cache"
):
Collaborator (on lines +121 to +124): How are we running custom SDPA for CUDA? Or is it supposed to be a graph break?

Collaborator Author: Right, this doesn't work. I can remove the whole block, or actually it should throw an error for now.

raise NotImplementedError(
"Custom SDPA implementation is not supported for CUDA yet. Please use 'flash_attention' instead."
)

return _lower_to_executorch(exported_progs, model.metadata)
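Following up on the two review points above (error out early; dtype/device from the CLI are unused), here is a minimal sketch of a fail-fast check that could run at the top of export_to_executorch_with_cuda, before model.export(). The helper name and the device check are hypothetical, not part of the merged recipe.

# Sketch only: called before model.export() so unsupported configurations fail
# immediately instead of after a long export.
_UNSUPPORTED_ATTN = {"custom_sdpa", "custom_sdpa_ring_kv_cache"}


def _validate_cuda_export_args(model, **kwargs) -> None:
    attn = getattr(model.config, "_attn_implementation", None)
    if attn in _UNSUPPORTED_ATTN:
        raise NotImplementedError(
            "Custom SDPA implementation is not supported for CUDA yet. Please use 'flash_attention' instead."
        )
    # The CLI forwards --device into kwargs; anything other than a CUDA device
    # makes no sense for this recipe (hypothetical check).
    device = str(kwargs.get("device", "cuda"))
    if not device.startswith("cuda"):
        raise ValueError(f"The 'cuda' recipe expects a CUDA device, got {device!r}.")

Calling _validate_cuda_export_args(model, **kwargs) as the first statement of the recipe would make the NotImplementedError branch after export redundant.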
27 changes: 27 additions & 0 deletions tests/models/test_modeling_voxtral.py
@@ -16,7 +16,9 @@
import gc
import logging
import os
import subprocess
import sys
import tempfile
import unittest

import pytest
@@ -324,3 +326,28 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8
self.assertTrue(
check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
)

@slow
@pytest.mark.run_slow
@pytest.mark.skipif(is_linux_ci, reason="OOM")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA backend required")
def test_voxtral_export_to_executorch_cuda_recipe(self):
model_id = "mistralai/Voxtral-Mini-3B-2507"
task = "multimodal-text-to-text"
recipe = "cuda"
output_subdir = "executorch"

with tempfile.TemporaryDirectory() as tempdir:
output_dir = os.path.join(tempdir, output_subdir)
cmd = (
"optimum-cli export executorch "
f"--model {model_id} "
f"--task {task} "
f"--recipe {recipe} "
"--dtype bfloat16 "
"--device cuda:0 "
"--max_seq_len 1024 "
f"--output_dir {output_dir}"
)
subprocess.run(cmd, shell=True, check=True)
self.assertTrue(os.path.exists(os.path.join(output_dir, "model.pte")))
Collaborator: Can you also test the existence of the .ptd file?
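A minimal sketch of the extra assertion suggested above, meant to sit right after the existing model.pte check inside the TemporaryDirectory block. Whether the CUDA recipe tags weights as external and the resulting .ptd filename are assumptions, so the check just looks for any .ptd in the output directory.

# Sketch only: add after the model.pte assertion, inside the TemporaryDirectory block.
ptd_files = [name for name in os.listdir(output_dir) if name.endswith(".ptd")]
self.assertTrue(
    ptd_files,
    f"Expected external tensor data (.ptd) alongside model.pte in {output_dir}",
)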
