From a7b66d4b534689121660dbf61eca4869d2e8588b Mon Sep 17 00:00:00 2001
From: Guang Yang
Date: Wed, 16 Aug 2023 15:57:38 -0700
Subject: [PATCH] Add wav2letter model to examples

Summary:
Bring the `Wav2Letter` model into `executorch/examples`.
- General info about the `Wav2Letter` model: https://ai.meta.com/tools/wav2letter/
- Info about the `Wav2Letter` variant used in this example: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2letter.html

Reviewed By: JacobSzwejbka

Differential Revision: D48403704

fbshipit-source-id: 8cde04b0e1327e5b94b0dec087989b4cd16e6f51
---
 examples/export/test/test_export.py    |  8 +++++++
 examples/models/TARGETS                |  1 +
 examples/models/models.py              |  8 +++++++
 examples/models/wav2letter/TARGETS     | 14 ++++++++++++
 examples/models/wav2letter/__init__.py | 11 ++++++++++
 examples/models/wav2letter/export.py   | 30 ++++++++++++++++++++++++++
 6 files changed, 72 insertions(+)
 create mode 100644 examples/models/wav2letter/TARGETS
 create mode 100644 examples/models/wav2letter/__init__.py
 create mode 100644 examples/models/wav2letter/export.py

diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py
index 372fb74e6cc..f26092673da 100644
--- a/examples/export/test/test_export.py
+++ b/examples/export/test/test_export.py
@@ -81,3 +81,11 @@ def test_vit_export_to_executorch(self):
         self._assert_eager_lowered_same_result(
             eager_model, example_inputs, self.validate_tensor_allclose
         )
+
+    def test_w2l_export_to_executorch(self):
+        eager_model, example_inputs = MODEL_NAME_TO_MODEL["w2l"]()
+        eager_model = eager_model.eval()
+
+        self._assert_eager_lowered_same_result(
+            eager_model, example_inputs, self.validate_tensor_allclose
+        )
diff --git a/examples/models/TARGETS b/examples/models/TARGETS
index f7bd4eb4607..15e30256578 100644
--- a/examples/models/TARGETS
+++ b/examples/models/TARGETS
@@ -11,6 +11,7 @@ python_library(
         "//executorch/examples/models/mobilenet_v2:mv2_export",
         "//executorch/examples/models/mobilenet_v3:mv3_export",
         "//executorch/examples/models/torchvision_vit:vit_export",
+        "//executorch/examples/models/wav2letter:w2l_export",
         "//executorch/exir/backend:compile_spec_schema",
     ],
 )
diff --git a/examples/models/models.py b/examples/models/models.py
index 37e7b1bd798..aa31718aaba 100644
--- a/examples/models/models.py
+++ b/examples/models/models.py
@@ -95,6 +95,13 @@ def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
     return TorchVisionViTModel.get_model(), TorchVisionViTModel.get_example_inputs()
 
 
+def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
+    from ..models.wav2letter import Wav2LetterModel
+
+    model = Wav2LetterModel()
+    return model.get_model(), model.get_example_inputs()
+
+
 MODEL_NAME_TO_MODEL = {
     "mul": lambda: (MulModule(), MulModule.get_example_inputs()),
     "linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -103,4 +110,5 @@ def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
     "mv2": gen_mobilenet_v2_model_inputs,
     "mv3": gen_mobilenet_v3_model_inputs,
     "vit": gen_torchvision_vit_model_and_inputs,
+    "w2l": gen_wav2letter_model_and_inputs,
 }
diff --git a/examples/models/wav2letter/TARGETS b/examples/models/wav2letter/TARGETS
new file mode 100644
index 00000000000..1d87315a3f1
--- /dev/null
+++ b/examples/models/wav2letter/TARGETS
@@ -0,0 +1,14 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "w2l_export",
+    srcs = [
+        "__init__.py",
+        "export.py",
+    ],
+    base_module = "executorch.examples.models.wav2letter",
+    deps = [
+        "//caffe2:torch",
+        "//pytorch/audio:torchaudio",
+    ],
+)
diff --git a/examples/models/wav2letter/__init__.py b/examples/models/wav2letter/__init__.py
new file mode 100644
index 00000000000..84473d4f54f
--- /dev/null
+++ b/examples/models/wav2letter/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .export import Wav2LetterModel
+
+__all__ = [
+    "Wav2LetterModel",
+]
diff --git a/examples/models/wav2letter/export.py b/examples/models/wav2letter/export.py
new file mode 100644
index 00000000000..38be6d9d9c4
--- /dev/null
+++ b/examples/models/wav2letter/export.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from torchaudio import models
+
+FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(format=FORMAT, level=logging.INFO)
+
+
+class Wav2LetterModel:
+    def __init__(self):
+        self.batch_size = 10
+        self.input_frames = 700
+        self.vocab_size = 4096
+
+    def get_model(self):
+        logging.info("loading wav2letter model")
+        wav2letter = models.Wav2Letter(num_classes=self.vocab_size)
+        logging.info("loaded wav2letter model")
+        return wav2letter
+
+    def get_example_inputs(self):
+        input_shape = (self.batch_size, 1, self.input_frames)
+        return (torch.randn(input_shape),)