Skip to content

Commit 3f820bf

Browse files
committed
Add tests for activation functions
ghstack-source-id: 693e508 ghstack-comment-id: 3003538739 Pull-Request: #11961
1 parent 9653a05 commit 3f820bf

15 files changed

+598
-0
lines changed

backends/test/compliance_suite/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def is_backend_enabled(backend):
5858
torch.float64,
5959
]
6060

61+
FLOAT_DTYPES =[
62+
torch.float16,
63+
torch.float32,
64+
torch.float64,
65+
]
66+
6167
class TestType(Enum):
6268
STANDARD = 1
6369
DTYPE = 2
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, alpha=1.0, inplace=False):
17+
super().__init__()
18+
self.alpha = alpha
19+
self.inplace = inplace
20+
21+
def forward(self, x):
22+
return torch.nn.functional.elu(x, alpha=self.alpha, inplace=self.inplace)
23+
24+
@operator_test
25+
class TestELU(OperatorTest):
26+
@dtype_test
27+
def test_elu_dtype(self, dtype, tester_factory: Callable) -> None:
28+
self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory)
29+
30+
def test_elu_f32_single_dim(self, tester_factory: Callable) -> None:
31+
self._test_op(Model(), (torch.randn(20),), tester_factory)
32+
33+
def test_elu_f32_multi_dim(self, tester_factory: Callable) -> None:
34+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
35+
36+
def test_elu_f32_alpha(self, tester_factory: Callable) -> None:
37+
self._test_op(Model(alpha=0.5), (torch.randn(3, 4, 5),), tester_factory)
38+
39+
def test_elu_f32_inplace(self, tester_factory: Callable) -> None:
40+
self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory)
41+
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, approximate="none"):
17+
super().__init__()
18+
self.approximate = approximate
19+
20+
def forward(self, x):
21+
return torch.nn.functional.gelu(x, approximate=self.approximate)
22+
23+
@operator_test
24+
class TestGELU(OperatorTest):
25+
@dtype_test
26+
def test_gelu_dtype(self, dtype, tester_factory: Callable) -> None:
27+
self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory)
28+
29+
def test_gelu_f32_single_dim(self, tester_factory: Callable) -> None:
30+
self._test_op(Model(), (torch.randn(20),), tester_factory)
31+
32+
def test_gelu_f32_multi_dim(self, tester_factory: Callable) -> None:
33+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
34+
35+
def test_gelu_f32_tanh_approximation(self, tester_factory: Callable) -> None:
36+
self._test_op(Model(approximate="tanh"), (torch.randn(3, 4, 5),), tester_factory)
37+
38+
def test_gelu_f32_boundary_values(self, tester_factory: Callable) -> None:
39+
# Test with specific values spanning negative and positive ranges
40+
x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])
41+
self._test_op(Model(), (x,), tester_factory)
42+
43+
def test_gelu_f32_tanh_boundary_values(self, tester_factory: Callable) -> None:
44+
# Test tanh approximation with specific values
45+
x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])
46+
self._test_op(Model(approximate="tanh"), (x,), tester_factory)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, dim=-1):
17+
super().__init__()
18+
self.dim = dim
19+
20+
def forward(self, x):
21+
return torch.nn.functional.glu(x, dim=self.dim)
22+
23+
@operator_test
24+
class TestGLU(OperatorTest):
25+
@dtype_test
26+
def test_glu_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Input must have even number of elements in the specified dimension
28+
self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory)
29+
30+
def test_glu_f32_dim_last(self, tester_factory: Callable) -> None:
31+
# Default dim is -1 (last dimension)
32+
self._test_op(Model(), (torch.randn(3, 4, 6),), tester_factory)
33+
34+
def test_glu_f32_dim_first(self, tester_factory: Callable) -> None:
35+
# Test with dim=0 (first dimension)
36+
self._test_op(Model(dim=0), (torch.randn(4, 3, 5),), tester_factory)
37+
38+
def test_glu_f32_dim_middle(self, tester_factory: Callable) -> None:
39+
# Test with dim=1 (middle dimension)
40+
self._test_op(Model(dim=1), (torch.randn(3, 8, 5),), tester_factory)
41+
42+
def test_glu_f32_boundary_values(self, tester_factory: Callable) -> None:
43+
# Test with specific values spanning negative and positive ranges
44+
# Input must have even number of elements in the specified dimension
45+
x = torch.tensor([[-10.0, -5.0, -1.0, 0.0], [1.0, 5.0, 10.0, -2.0]])
46+
self._test_op(Model(dim=1), (x,), tester_factory)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, inplace=False):
17+
super().__init__()
18+
self.inplace = inplace
19+
20+
def forward(self, x):
21+
return torch.nn.functional.hardsigmoid(x, inplace=self.inplace)
22+
23+
@operator_test
24+
class TestHardsigmoid(OperatorTest):
25+
@dtype_test
26+
def test_hardsigmoid_dtype(self, dtype, tester_factory: Callable) -> None:
27+
self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), tester_factory)
28+
29+
def test_hardsigmoid_f32_single_dim(self, tester_factory: Callable) -> None:
30+
self._test_op(Model(), (torch.randn(20),), tester_factory)
31+
32+
def test_hardsigmoid_f32_multi_dim(self, tester_factory: Callable) -> None:
33+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
34+
35+
def test_hardsigmoid_f32_inplace(self, tester_factory: Callable) -> None:
36+
self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory)
37+
38+
def test_hardsigmoid_f32_boundary_values(self, tester_factory: Callable) -> None:
39+
# Test with values that span the hardsigmoid's piecewise regions
40+
x = torch.tensor([-5.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0])
41+
self._test_op(Model(), (x,), tester_factory)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, inplace=False):
17+
super().__init__()
18+
self.inplace = inplace
19+
20+
def forward(self, x):
21+
return torch.nn.functional.hardswish(x, inplace=self.inplace)
22+
23+
@operator_test
24+
class TestHardswish(OperatorTest):
25+
@dtype_test
26+
def test_hardswish_dtype(self, dtype, tester_factory: Callable) -> None:
27+
self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), tester_factory)
28+
29+
def test_hardswish_f32_single_dim(self, tester_factory: Callable) -> None:
30+
self._test_op(Model(), (torch.randn(20),), tester_factory)
31+
32+
def test_hardswish_f32_multi_dim(self, tester_factory: Callable) -> None:
33+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
34+
35+
def test_hardswish_f32_inplace(self, tester_factory: Callable) -> None:
36+
self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory)
37+
38+
def test_hardswish_f32_boundary_values(self, tester_factory: Callable) -> None:
39+
# Test with values that span the hardswish's piecewise regions
40+
x = torch.tensor([-5.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0])
41+
self._test_op(Model(), (x,), tester_factory)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, min_val=-1.0, max_val=1.0, inplace=False):
17+
super().__init__()
18+
self.min_val = min_val
19+
self.max_val = max_val
20+
self.inplace = inplace
21+
22+
def forward(self, x):
23+
return torch.nn.functional.hardtanh(x, min_val=self.min_val, max_val=self.max_val, inplace=self.inplace)
24+
25+
@operator_test
26+
class TestHardtanh(OperatorTest):
27+
@dtype_test
28+
def test_hardtanh_dtype(self, dtype, tester_factory: Callable) -> None:
29+
self._test_op(Model(), ((torch.rand(2, 10) * 4 - 2).to(dtype),), tester_factory)
30+
31+
def test_hardtanh_f32_single_dim(self, tester_factory: Callable) -> None:
32+
self._test_op(Model(), (torch.randn(20),), tester_factory)
33+
34+
def test_hardtanh_f32_multi_dim(self, tester_factory: Callable) -> None:
35+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
36+
37+
def test_hardtanh_f32_custom_range(self, tester_factory: Callable) -> None:
38+
self._test_op(Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), tester_factory)
39+
40+
def test_hardtanh_f32_inplace(self, tester_factory: Callable) -> None:
41+
self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory)
42+
43+
def test_hardtanh_f32_boundary_values(self, tester_factory: Callable) -> None:
44+
# Test with values that span the hardtanh's piecewise regions
45+
x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
46+
self._test_op(Model(), (x,), tester_factory)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, negative_slope=0.01, inplace=False):
17+
super().__init__()
18+
self.negative_slope = negative_slope
19+
self.inplace = inplace
20+
21+
def forward(self, x):
22+
return torch.nn.functional.leaky_relu(x, negative_slope=self.negative_slope, inplace=self.inplace)
23+
24+
@operator_test
25+
class TestLeakyReLU(OperatorTest):
26+
@dtype_test
27+
def test_leaky_relu_dtype(self, dtype, tester_factory: Callable) -> None:
28+
self._test_op(Model(), ((torch.rand(2, 10) * 2 - 1).to(dtype),), tester_factory)
29+
30+
def test_leaky_relu_f32_single_dim(self, tester_factory: Callable) -> None:
31+
self._test_op(Model(), (torch.randn(20),), tester_factory)
32+
33+
def test_leaky_relu_f32_multi_dim(self, tester_factory: Callable) -> None:
34+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
35+
36+
def test_leaky_relu_f32_custom_slope(self, tester_factory: Callable) -> None:
37+
self._test_op(Model(negative_slope=0.1), (torch.randn(3, 4, 5),), tester_factory)
38+
39+
def test_leaky_relu_f32_inplace(self, tester_factory: Callable) -> None:
40+
self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory)
41+
42+
def test_leaky_relu_f32_boundary_values(self, tester_factory: Callable) -> None:
43+
# Test with specific positive and negative values
44+
x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
45+
self._test_op(Model(), (x,), tester_factory)
46+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def forward(self, x):
17+
return torch.nn.functional.logsigmoid(x)
18+
19+
@operator_test
20+
class TestLogSigmoid(OperatorTest):
21+
@dtype_test
22+
def test_logsigmoid_dtype(self, dtype, tester_factory: Callable) -> None:
23+
self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory)
24+
25+
def test_logsigmoid_f32_single_dim(self, tester_factory: Callable) -> None:
26+
self._test_op(Model(), (torch.randn(20),), tester_factory)
27+
28+
def test_logsigmoid_f32_multi_dim(self, tester_factory: Callable) -> None:
29+
self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
30+
31+
def test_logsigmoid_f32_boundary_values(self, tester_factory: Callable) -> None:
32+
# Test with specific values spanning negative and positive ranges
33+
x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0])
34+
self._test_op(Model(), (x,), tester_factory)

0 commit comments

Comments
 (0)