diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py index 29fbbab0dca..14c049d4f77 100644 --- a/examples/export/test/test_export.py +++ b/examples/export/test/test_export.py @@ -126,3 +126,13 @@ def test_resnet50_export_to_executorch(self): self._assert_eager_lowered_same_result( eager_model, example_inputs, self.validate_tensor_allclose ) + + def test_dl3_export_to_executorch(self): + eager_model, example_inputs = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL["dl3"] + ) + eager_model = eager_model.eval() + + self._assert_eager_lowered_same_result( + eager_model, example_inputs, self.validate_tensor_allclose + ) diff --git a/examples/models/TARGETS b/examples/models/TARGETS index e73978c5615..9fe0cc6a66e 100644 --- a/examples/models/TARGETS +++ b/examples/models/TARGETS @@ -9,6 +9,7 @@ python_library( deps = [ "//caffe2:torch", "//executorch/examples/models:model_base", # @manual + "//executorch/examples/models/deeplab_v3:dl3_model", # @manual "//executorch/examples/models/inception_v3:ic3_model", # @manual "//executorch/examples/models/inception_v4:ic4_model", # @manual "//executorch/examples/models/mobilebert:mobilebert_model", # @manual diff --git a/examples/models/__init__.py b/examples/models/__init__.py index 2eb0c3d4522..2a94c5c7d70 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -11,6 +11,7 @@ "linear": ("toy_model", "LinearModule"), "add": ("toy_model", "AddModule"), "add_mul": ("toy_model", "AddMulModule"), + "dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"), "mobilebert": ("mobilebert", "MobileBertModelExample"), "mv2": ("mobilenet_v2", "MV2Model"), "mv3": ("mobilenet_v3", "MV3Model"), diff --git a/examples/models/deeplab_v3/__init__.py b/examples/models/deeplab_v3/__init__.py new file mode 100644 index 00000000000..53030ca8c93 --- /dev/null +++ b/examples/models/deeplab_v3/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .model import DeepLabV3ResNet50Model + +__all__ = [ + DeepLabV3ResNet50Model, +] diff --git a/examples/models/deeplab_v3/model.py b/examples/models/deeplab_v3/model.py new file mode 100644 index 00000000000..07be6567603 --- /dev/null +++ b/examples/models/deeplab_v3/model.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50 # @manual + +from ..model_base import EagerModelBase + + +class DeepLabV3ResNet50Model(EagerModelBase): + def __init__(self): + pass + + def get_eager_model(self) -> torch.nn.Module: + logging.info("loading deeplabv3_resnet50 model") + deeplabv3_model = deeplabv3_resnet50( + weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT + ) + logging.info("loaded deeplabv3_resnet50 model") + return deeplabv3_model + + def get_example_inputs(self): + input_shape = (1, 3, 224, 224) + return (torch.randn(input_shape),)