diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py
index 372fb74e6cc..f26092673da 100644
--- a/examples/export/test/test_export.py
+++ b/examples/export/test/test_export.py
@@ -81,3 +81,11 @@ def test_vit_export_to_executorch(self):
         self._assert_eager_lowered_same_result(
             eager_model, example_inputs, self.validate_tensor_allclose
         )
+
+    def test_w2l_export_to_executorch(self):
+        eager_model, example_inputs = MODEL_NAME_TO_MODEL["w2l"]()
+        eager_model = eager_model.eval()
+
+        self._assert_eager_lowered_same_result(
+            eager_model, example_inputs, self.validate_tensor_allclose
+        )
diff --git a/examples/models/TARGETS b/examples/models/TARGETS
index f7bd4eb4607..15e30256578 100644
--- a/examples/models/TARGETS
+++ b/examples/models/TARGETS
@@ -11,6 +11,7 @@ python_library(
         "//executorch/examples/models/mobilenet_v2:mv2_export",
         "//executorch/examples/models/mobilenet_v3:mv3_export",
         "//executorch/examples/models/torchvision_vit:vit_export",
+        "//executorch/examples/models/wav2letter:w2l_export",
         "//executorch/exir/backend:compile_spec_schema",
     ],
 )
diff --git a/examples/models/models.py b/examples/models/models.py
index 37e7b1bd798..aa31718aaba 100644
--- a/examples/models/models.py
+++ b/examples/models/models.py
@@ -95,6 +95,13 @@ def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
     return TorchVisionViTModel.get_model(), TorchVisionViTModel.get_example_inputs()
 
 
+def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
+    from ..models.wav2letter import Wav2LetterModel
+
+    model = Wav2LetterModel()
+    return model.get_model(), model.get_example_inputs()
+
+
 MODEL_NAME_TO_MODEL = {
     "mul": lambda: (MulModule(), MulModule.get_example_inputs()),
     "linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -103,4 +110,5 @@ def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
     "mv2": gen_mobilenet_v2_model_inputs,
     "mv3": gen_mobilenet_v3_model_inputs,
     "vit": gen_torchvision_vit_model_and_inputs,
+    "w2l": gen_wav2letter_model_and_inputs,
 }
diff --git a/examples/models/wav2letter/TARGETS b/examples/models/wav2letter/TARGETS
new file mode 100644
index 00000000000..1d87315a3f1
--- /dev/null
+++ b/examples/models/wav2letter/TARGETS
@@ -0,0 +1,14 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "w2l_export",
+    srcs = [
+        "__init__.py",
+        "export.py",
+    ],
+    base_module = "executorch.examples.models.wav2letter",
+    deps = [
+        "//caffe2:torch",
+        "//pytorch/audio:torchaudio",
+    ],
+)
diff --git a/examples/models/wav2letter/__init__.py b/examples/models/wav2letter/__init__.py
new file mode 100644
index 00000000000..84473d4f54f
--- /dev/null
+++ b/examples/models/wav2letter/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .export import Wav2LetterModel
+
+__all__ = [
+    "Wav2LetterModel",
+]
diff --git a/examples/models/wav2letter/export.py b/examples/models/wav2letter/export.py
new file mode 100644
index 00000000000..38be6d9d9c4
--- /dev/null
+++ b/examples/models/wav2letter/export.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from torchaudio import models
+
+FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(format=FORMAT)
+
+
+class Wav2LetterModel:
+    def __init__(self):
+        self.batch_size = 10
+        self.input_frames = 700
+        self.vocab_size = 4096
+
+    def get_model(self):
+        logging.info("loading wav2letter model")
+        wav2letter = models.Wav2Letter(num_classes=self.vocab_size)
+        logging.info("loaded wav2letter model")
+        return wav2letter
+
+    def get_example_inputs(self):
+        input_shape = (self.batch_size, 1, self.input_frames)
+        return (torch.randn(input_shape),)
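
Note: below is a minimal sketch of how the new "w2l" registry entry can be exercised by hand,
mirroring the added test_w2l_export_to_executorch test. The import path
executorch.examples.models.models is an assumption based on the file layout above and may differ.

    import torch

    # Assumed import path for the registry defined in examples/models/models.py.
    from executorch.examples.models.models import MODEL_NAME_TO_MODEL

    # Build the eager torchaudio Wav2Letter model and a random example input of
    # shape (10, 1, 700), as produced by Wav2LetterModel.get_example_inputs().
    eager_model, example_inputs = MODEL_NAME_TO_MODEL["w2l"]()
    eager_model = eager_model.eval()

    # Run the eager model once; the added test compares this eager output against
    # the lowered program via validate_tensor_allclose.
    with torch.no_grad():
        eager_output = eager_model(*example_inputs)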