Skip to content

Commit c1884d1

Browse files
guangy10facebook-github-bot
authored andcommitted
Add torchvision_vit model to the examples (#33)
Summary: Pull Request resolved: #33 Info about the the model: https://pytorch.org/vision/main/models/generated/torchvision.models.vit_b_16.html#torchvision.models.vit_b_16 Differential Revision: D48012005 fbshipit-source-id: 7eb02ceaae21977419e21c3bb4c1cfe8c7de7cee
1 parent b66faf2 commit c1884d1

File tree

9 files changed

+73
-11
lines changed

9 files changed

+73
-11
lines changed

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.1.0.dev20230809
1+
2.1.0.dev20230810

docs/website/docs/tutorials/00_setting_up_executorch.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Setting up PyTorch
2424
# Install the nightly builds
2525
# Note: if you are behind a firewall an appropriate proxy server must be setup
2626
# for all subsequent steps.
27-
TORCH_VERSION=2.1.0.dev20230809
27+
TORCH_VERSION=2.1.0.dev20230810
2828
pip install --force-reinstall --pre torch=="${TORCH_VERSION}" -i https://download.pytorch.org/whl/nightly/cpu
2929
```
3030

examples/export/test/test_export.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@
1111
from executorch.examples.export.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
1212
from executorch.examples.models import MODEL_NAME_TO_MODEL
1313

14-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
15-
from executorch.extension.pybindings.portable import ( # @manual
16-
_load_for_executorch_from_buffer,
17-
)
18-
1914

2015
class ExportTest(unittest.TestCase):
2116
def _assert_eager_lowered_same_result(
@@ -28,13 +23,11 @@ def _assert_eager_lowered_same_result(
2823
)
2924

3025
executorch_model = edge_model.to_executorch()
31-
# pyre-ignore
32-
pte_model = _load_for_executorch_from_buffer(executorch_model.buffer)
3326

3427
with torch.no_grad():
3528
eager_output = eager_model(*example_inputs)
3629
with torch.no_grad():
37-
executorch_output = pte_model.forward(example_inputs)
30+
executorch_output = executorch_model.exported_program(*example_inputs)
3831
self.assertTrue(
3932
torch.allclose(eager_output, executorch_output[0], rtol=1e-5, atol=1e-5)
4033
)
@@ -50,3 +43,9 @@ def test_mv2_export_to_executorch(self):
5043
eager_model = eager_model.eval()
5144

5245
self._assert_eager_lowered_same_result(eager_model, example_inputs)
46+
47+
def test_vit_export_to_executorch(self):
48+
eager_model, example_inputs = MODEL_NAME_TO_MODEL["vit"]()
49+
eager_model = eager_model.eval()
50+
51+
self._assert_eager_lowered_same_result(eager_model, example_inputs)

examples/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
# `TORCH_VISION_VERSION` value in this document will be the correct version for the
1717
# corresponsing version of the repo.
1818

19-
TORCH_VISION_VERSION=0.16.0.dev20230809
19+
TORCH_VISION_VERSION=0.16.0.dev20230810
2020
pip install --force-reinstall --pre torchvision=="${TORCH_VISION_VERSION}" -i https://download.pytorch.org/whl/nightly/cpu

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python_library(
1010
"//caffe2:torch",
1111
"//executorch/examples/models/mobilenet_v2:mv2_export",
1212
"//executorch/examples/models/mobilenet_v3:mv3_export",
13+
"//executorch/examples/models/torchvision_vit:vit_export",
1314
"//executorch/exir/backend:compile_spec_schema",
1415
],
1516
)

examples/models/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,18 @@ def gen_mobilenet_v2_model_inputs() -> Tuple[torch.nn.Module, Any]:
8989
return MV2Model.get_model(), MV2Model.get_example_inputs()
9090

9191

92+
def gen_torchvision_vit_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
93+
from ..models.torchvision_vit import TorchVisionViTModel
94+
95+
return TorchVisionViTModel.get_model(), TorchVisionViTModel.get_example_inputs()
96+
97+
9298
MODEL_NAME_TO_MODEL = {
9399
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
94100
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
95101
"add": lambda: (AddModule(), AddModule.get_example_inputs()),
96102
"add_mul": lambda: (AddMulModule(), AddMulModule.get_example_inputs()),
97103
"mv2": gen_mobilenet_v2_model_inputs,
98104
"mv3": gen_mobilenet_v3_model_inputs,
105+
"vit": gen_torchvision_vit_model_and_inputs,
99106
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "vit_export",
5+
srcs = [
6+
"__init__.py",
7+
"export.py",
8+
],
9+
base_module = "executorch.examples.models.torchvision_vit",
10+
deps = [
11+
"//caffe2:torch",
12+
"//pytorch/vision:torchvision",
13+
],
14+
)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .export import TorchVisionViTModel
8+
9+
__all__ = [
10+
TorchVisionViTModel,
11+
]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from torchvision import models
11+
12+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
13+
logging.basicConfig(format=FORMAT)
14+
15+
16+
class TorchVisionViTModel:
17+
def __init__(self):
18+
pass
19+
20+
@staticmethod
21+
def get_model():
22+
logging.info("loading torchvision vit_b_16 model")
23+
vit_b_16 = models.vit_b_16(weights="IMAGENET1K_V1")
24+
logging.info("loaded torchvision vit_b_16 model")
25+
return vit_b_16
26+
27+
@staticmethod
28+
def get_example_inputs():
29+
input_shape = (1, 3, 224, 224)
30+
return (torch.randn(input_shape),)

0 commit comments

Comments
 (0)