Skip to content

Improving type support for math.prod #13572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions stdlib/@tests/test_cases/check_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from decimal import Decimal
from fractions import Fraction
from math import prod
from typing import Any, Literal, Union
from typing_extensions import assert_type


class SupportsMul:
def __mul__(self, other: Any) -> SupportsMul:
return SupportsMul()


class SupportsRMul:
def __rmul__(self, other: Any) -> SupportsRMul:
return SupportsRMul()


class SupportsMulAndRMul:
def __mul__(self, other: Any) -> SupportsMulAndRMul:
return SupportsMulAndRMul()

def __rmul__(self, other: Any) -> SupportsMulAndRMul:
return SupportsMulAndRMul()


literal_list: list[Literal[0, 1]] = [0, 1, 1]

assert_type(prod([2, 4]), int)
assert_type(prod([3, 5], start=4), int)

assert_type(prod([True, False]), int)
assert_type(prod([True, False], start=True), int)
assert_type(prod(literal_list), int)

assert_type(prod([SupportsMul(), SupportsMul()], start=SupportsMul()), SupportsMul)
assert_type(prod([SupportsMulAndRMul(), SupportsMulAndRMul()]), Union[SupportsMulAndRMul, Literal[1]])

assert_type(prod([5.6, 3.2]), Union[float, Literal[1]])
assert_type(prod([5.6, 3.2], start=3), Union[float, int])

assert_type(prod([Fraction(7, 2), Fraction(3, 5)]), Union[Fraction, Literal[1]])
assert_type(prod([Fraction(7, 2), Fraction(3, 5)], start=Fraction(1)), Fraction)
assert_type(prod([Decimal("3.14"), Decimal("2.71")]), Union[Decimal, Literal[1]])
assert_type(prod([Decimal("3.14"), Decimal("2.71")], start=Decimal("1.00")), Decimal)
assert_type(prod([complex(7, 2), complex(3, 5)]), Union[complex, Literal[1]])
assert_type(prod([complex(7, 2), complex(3, 5)], start=complex(1, 0)), complex)


# mypy and pyright infer the types differently for these, so we can't use assert_type
# Just test that no error is emitted for any of these
prod([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]`
prod([2.5, 5.8], start=5) # mypy: `float`; pyright: `float | int`

# These all fail at runtime
prod([SupportsMul(), SupportsMul()]) # type: ignore
prod([SupportsRMul(), SupportsRMul()], start=SupportsRMul()) # type: ignore
prod([SupportsRMul(), SupportsRMul()]) # type: ignore

# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error:
# prod([3, Fraction(7, 22), complex(8, 0), 9.83])
# prod([3, Decimal("0.98")])
6 changes: 6 additions & 0 deletions stdlib/_typeshed/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ class SupportsSub(Protocol[_T_contra, _T_co]):
class SupportsRSub(Protocol[_T_contra, _T_co]):
def __rsub__(self, x: _T_contra, /) -> _T_co: ...

class SupportsMul(Protocol[_T_contra, _T_co]):
def __mul__(self, x: _T_contra, /) -> _T_co: ...

class SupportsRMul(Protocol[_T_contra, _T_co]):
def __rmul__(self, x: _T_contra, /) -> _T_co: ...

class SupportsDivMod(Protocol[_T_contra, _T_co]):
def __divmod__(self, other: _T_contra, /) -> _T_co: ...

Expand Down
26 changes: 23 additions & 3 deletions stdlib/math.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
from _typeshed import SupportsMul, SupportsRMul
from collections.abc import Iterable
from typing import Final, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload
from typing import Any, Final, Literal, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload
from typing_extensions import TypeAlias

_T = TypeVar("_T")
Expand Down Expand Up @@ -99,10 +100,29 @@ elif sys.version_info >= (3, 9):

def perm(n: SupportsIndex, k: SupportsIndex | None = None, /) -> int: ...
def pow(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ...

_PositiveInteger: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
_NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20]
_LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed

_MultiplicableT1 = TypeVar("_MultiplicableT1", bound=SupportsMul[Any, Any])
_MultiplicableT2 = TypeVar("_MultiplicableT2", bound=SupportsMul[Any, Any])

class _SupportsProdWithNoDefaultGiven(SupportsMul[Any, Any], SupportsRMul[int, Any], Protocol): ...

_SupportsProdNoDefaultT = TypeVar("_SupportsProdNoDefaultT", bound=_SupportsProdWithNoDefaultGiven)

# This stub is based on the type stub for `builtins.sum`.
# Like `builtins.sum`, it cannot be precisely represented in a type stub
# without introducing many false positives.
# For more details on its limitations and false positives, see #13572.
# Instead, just like `builtins.sum`, we explicitly handle several useful cases.
@overload
def prod(iterable: Iterable[bool | _LiteralInteger], /, *, start: int = 1) -> int: ... # type: ignore[overload-overlap]
@overload
def prod(iterable: Iterable[SupportsIndex], /, *, start: SupportsIndex = 1) -> int: ... # type: ignore[overload-overlap]
def prod(iterable: Iterable[_SupportsProdNoDefaultT], /) -> _SupportsProdNoDefaultT | Literal[1]: ...
@overload
def prod(iterable: Iterable[_SupportsFloatOrIndex], /, *, start: _SupportsFloatOrIndex = 1) -> float: ...
def prod(iterable: Iterable[_MultiplicableT1], /, *, start: _MultiplicableT2) -> _MultiplicableT1 | _MultiplicableT2: ...
def radians(x: _SupportsFloatOrIndex, /) -> float: ...
def remainder(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ...
def sin(x: _SupportsFloatOrIndex, /) -> float: ...
Expand Down