|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from decimal import Decimal |
| 4 | +from fractions import Fraction |
| 5 | +from math import prod |
| 6 | +from typing import Any, Literal, Union |
| 7 | +from typing_extensions import assert_type |
| 8 | + |
| 9 | + |
| 10 | +class SupportsMul: |
| 11 | + def __mul__(self, other: Any) -> SupportsMul: |
| 12 | + return SupportsMul() |
| 13 | + |
| 14 | + |
| 15 | +class SupportsRMul: |
| 16 | + def __rmul__(self, other: Any) -> SupportsRMul: |
| 17 | + return SupportsRMul() |
| 18 | + |
| 19 | + |
| 20 | +class SupportsMulAndRMul: |
| 21 | + def __mul__(self, other: Any) -> SupportsMulAndRMul: |
| 22 | + return SupportsMulAndRMul() |
| 23 | + |
| 24 | + def __rmul__(self, other: Any) -> SupportsMulAndRMul: |
| 25 | + return SupportsMulAndRMul() |
| 26 | + |
| 27 | + |
| 28 | +literal_list: list[Literal[0, 1]] = [0, 1, 1] |
| 29 | + |
| 30 | +assert_type(prod([2, 4]), int) |
| 31 | +assert_type(prod([3, 5], start=4), int) |
| 32 | + |
| 33 | +assert_type(prod([True, False]), int) |
| 34 | +assert_type(prod([True, False], start=True), int) |
| 35 | +assert_type(prod(literal_list), int) |
| 36 | + |
| 37 | +assert_type(prod([SupportsMul(), SupportsMul()], start=SupportsMul()), SupportsMul) |
| 38 | +assert_type(prod([SupportsMulAndRMul(), SupportsMulAndRMul()]), Union[SupportsMulAndRMul, Literal[1]]) |
| 39 | + |
| 40 | +assert_type(prod([5.6, 3.2]), Union[float, Literal[1]]) |
| 41 | +assert_type(prod([5.6, 3.2], start=3), Union[float, int]) |
| 42 | + |
| 43 | +assert_type(prod([Fraction(7, 2), Fraction(3, 5)]), Union[Fraction, Literal[1]]) |
| 44 | +assert_type(prod([Fraction(7, 2), Fraction(3, 5)], start=Fraction(1)), Fraction) |
| 45 | +assert_type(prod([Decimal("3.14"), Decimal("2.71")]), Union[Decimal, Literal[1]]) |
| 46 | +assert_type(prod([Decimal("3.14"), Decimal("2.71")], start=Decimal("1.00")), Decimal) |
| 47 | +assert_type(prod([complex(7, 2), complex(3, 5)]), Union[complex, Literal[1]]) |
| 48 | +assert_type(prod([complex(7, 2), complex(3, 5)], start=complex(1, 0)), complex) |
| 49 | + |
| 50 | + |
| 51 | +# mypy and pyright infer the types differently for these, so we can't use assert_type |
| 52 | +# Just test that no error is emitted for any of these |
| 53 | +prod([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]` |
| 54 | +prod([2.5, 5.8], start=5) # mypy: `float`; pyright: `float | int` |
| 55 | + |
| 56 | +# These all fail at runtime |
| 57 | +prod([SupportsMul(), SupportsMul()]) # type: ignore |
| 58 | +prod([SupportsRMul(), SupportsRMul()], start=SupportsRMul()) # type: ignore |
| 59 | +prod([SupportsRMul(), SupportsRMul()]) # type: ignore |
| 60 | + |
| 61 | +# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error: |
| 62 | +# prod([3, Fraction(7, 22), complex(8, 0), 9.83]) |
| 63 | +# prod([3, Decimal("0.98")]) |
0 commit comments