Skip to content

Commit 9f11db4

Browse files
Improving type support for math.prod (#13572)
1 parent 175e700 commit 9f11db4

File tree

3 files changed

+92
-3
lines changed

3 files changed

+92
-3
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
from decimal import Decimal
4+
from fractions import Fraction
5+
from math import prod
6+
from typing import Any, Literal, Union
7+
from typing_extensions import assert_type
8+
9+
10+
class SupportsMul:
11+
def __mul__(self, other: Any) -> SupportsMul:
12+
return SupportsMul()
13+
14+
15+
class SupportsRMul:
16+
def __rmul__(self, other: Any) -> SupportsRMul:
17+
return SupportsRMul()
18+
19+
20+
class SupportsMulAndRMul:
21+
def __mul__(self, other: Any) -> SupportsMulAndRMul:
22+
return SupportsMulAndRMul()
23+
24+
def __rmul__(self, other: Any) -> SupportsMulAndRMul:
25+
return SupportsMulAndRMul()
26+
27+
28+
literal_list: list[Literal[0, 1]] = [0, 1, 1]
29+
30+
assert_type(prod([2, 4]), int)
31+
assert_type(prod([3, 5], start=4), int)
32+
33+
assert_type(prod([True, False]), int)
34+
assert_type(prod([True, False], start=True), int)
35+
assert_type(prod(literal_list), int)
36+
37+
assert_type(prod([SupportsMul(), SupportsMul()], start=SupportsMul()), SupportsMul)
38+
assert_type(prod([SupportsMulAndRMul(), SupportsMulAndRMul()]), Union[SupportsMulAndRMul, Literal[1]])
39+
40+
assert_type(prod([5.6, 3.2]), Union[float, Literal[1]])
41+
assert_type(prod([5.6, 3.2], start=3), Union[float, int])
42+
43+
assert_type(prod([Fraction(7, 2), Fraction(3, 5)]), Union[Fraction, Literal[1]])
44+
assert_type(prod([Fraction(7, 2), Fraction(3, 5)], start=Fraction(1)), Fraction)
45+
assert_type(prod([Decimal("3.14"), Decimal("2.71")]), Union[Decimal, Literal[1]])
46+
assert_type(prod([Decimal("3.14"), Decimal("2.71")], start=Decimal("1.00")), Decimal)
47+
assert_type(prod([complex(7, 2), complex(3, 5)]), Union[complex, Literal[1]])
48+
assert_type(prod([complex(7, 2), complex(3, 5)], start=complex(1, 0)), complex)
49+
50+
51+
# mypy and pyright infer the types differently for these, so we can't use assert_type
52+
# Just test that no error is emitted for any of these
53+
prod([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]`
54+
prod([2.5, 5.8], start=5) # mypy: `float`; pyright: `float | int`
55+
56+
# These all fail at runtime
57+
prod([SupportsMul(), SupportsMul()]) # type: ignore
58+
prod([SupportsRMul(), SupportsRMul()], start=SupportsRMul()) # type: ignore
59+
prod([SupportsRMul(), SupportsRMul()]) # type: ignore
60+
61+
# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error:
62+
# prod([3, Fraction(7, 22), complex(8, 0), 9.83])
63+
# prod([3, Decimal("0.98")])

stdlib/_typeshed/__init__.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ class SupportsSub(Protocol[_T_contra, _T_co]):
117117
class SupportsRSub(Protocol[_T_contra, _T_co]):
118118
def __rsub__(self, x: _T_contra, /) -> _T_co: ...
119119

120+
class SupportsMul(Protocol[_T_contra, _T_co]):
121+
def __mul__(self, x: _T_contra, /) -> _T_co: ...
122+
123+
class SupportsRMul(Protocol[_T_contra, _T_co]):
124+
def __rmul__(self, x: _T_contra, /) -> _T_co: ...
125+
120126
class SupportsDivMod(Protocol[_T_contra, _T_co]):
121127
def __divmod__(self, other: _T_contra, /) -> _T_co: ...
122128

stdlib/math.pyi

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
2+
from _typeshed import SupportsMul, SupportsRMul
23
from collections.abc import Iterable
3-
from typing import Final, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload
4+
from typing import Any, Final, Literal, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload
45
from typing_extensions import TypeAlias
56

67
_T = TypeVar("_T")
@@ -99,10 +100,29 @@ elif sys.version_info >= (3, 9):
99100

100101
def perm(n: SupportsIndex, k: SupportsIndex | None = None, /) -> int: ...
101102
def pow(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ...
103+
104+
_PositiveInteger: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
105+
_NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20]
106+
_LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed
107+
108+
_MultiplicableT1 = TypeVar("_MultiplicableT1", bound=SupportsMul[Any, Any])
109+
_MultiplicableT2 = TypeVar("_MultiplicableT2", bound=SupportsMul[Any, Any])
110+
111+
class _SupportsProdWithNoDefaultGiven(SupportsMul[Any, Any], SupportsRMul[int, Any], Protocol): ...
112+
113+
_SupportsProdNoDefaultT = TypeVar("_SupportsProdNoDefaultT", bound=_SupportsProdWithNoDefaultGiven)
114+
115+
# This stub is based on the type stub for `builtins.sum`.
116+
# Like `builtins.sum`, it cannot be precisely represented in a type stub
117+
# without introducing many false positives.
118+
# For more details on its limitations and false positives, see #13572.
119+
# Instead, just like `builtins.sum`, we explicitly handle several useful cases.
120+
@overload
121+
def prod(iterable: Iterable[bool | _LiteralInteger], /, *, start: int = 1) -> int: ... # type: ignore[overload-overlap]
102122
@overload
103-
def prod(iterable: Iterable[SupportsIndex], /, *, start: SupportsIndex = 1) -> int: ... # type: ignore[overload-overlap]
123+
def prod(iterable: Iterable[_SupportsProdNoDefaultT], /) -> _SupportsProdNoDefaultT | Literal[1]: ...
104124
@overload
105-
def prod(iterable: Iterable[_SupportsFloatOrIndex], /, *, start: _SupportsFloatOrIndex = 1) -> float: ...
125+
def prod(iterable: Iterable[_MultiplicableT1], /, *, start: _MultiplicableT2) -> _MultiplicableT1 | _MultiplicableT2: ...
106126
def radians(x: _SupportsFloatOrIndex, /) -> float: ...
107127
def remainder(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ...
108128
def sin(x: _SupportsFloatOrIndex, /) -> float: ...

0 commit comments

Comments
 (0)