diff --git a/stdlib/functools.pyi b/stdlib/functools.pyi index 9957fa8f1634..b3a009ac4699 100644 --- a/stdlib/functools.pyi +++ b/stdlib/functools.pyi @@ -2,8 +2,8 @@ import sys import types from _typeshed import SupportsAllComparisons, SupportsItems from collections.abc import Callable, Hashable, Iterable, Sequence, Sized -from typing import Any, Generic, Literal, NamedTuple, TypedDict, TypeVar, final, overload -from typing_extensions import ParamSpec, Self, TypeAlias +from typing import Any, Generic, Literal, NamedTuple, Protocol, TypedDict, TypeVar, overload, type_check_only +from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias if sys.version_info >= (3, 9): from types import GenericAlias @@ -28,12 +28,16 @@ if sys.version_info >= (3, 9): __all__ += ["cache"] _T = TypeVar("_T") +_T_contra = TypeVar("_T_contra", contravariant=True) _T_co = TypeVar("_T_co", covariant=True) _S = TypeVar("_S") _PWrapped = ParamSpec("_PWrapped") _RWrapped = TypeVar("_RWrapped") _PWrapper = ParamSpec("_PWrapper") _RWrapper = TypeVar("_RWrapper") +_R = TypeVar("_R") +_R_co = TypeVar("_R_co", covariant=True) +_P = ParamSpec("_P") @overload def reduce(function: Callable[[_T, _S], _T], sequence: Iterable[_S], initial: _T, /) -> _T: ... @@ -51,22 +55,102 @@ if sys.version_info >= (3, 9): maxsize: int typed: bool -@final -class _lru_cache_wrapper(Generic[_T]): - __wrapped__: Callable[..., _T] - def __call__(self, *args: Hashable, **kwargs: Hashable) -> _T: ... +class _Method(Protocol[_T_contra, _P, _R_co]): + def __call__(__self, /, self: _T_contra, *args: _P.args, **kwds: _P.kwargs) -> _R_co: ... + +class _Classmethod(Protocol[_T_contra, _P, _R_co]): + def __call__(__self, /, cls: type[_T_contra], *args: _P.args, **kwds: _P.kwargs) -> _R_co: ... + +class _Function(Protocol[_P, _R_co]): + def __call__(__self, /, *args: _P.args, **kwds: _P.kwargs) -> _R_co: ... + +class _lru_cache_wrapper(Generic[_P, _R]): + __wrapped__: Callable[_P, _R] + # def __call__(self, *args: Hashable, **kwargs: Hashable) -> _T: ... def cache_info(self) -> _CacheInfo: ... def cache_clear(self) -> None: ... if sys.version_info >= (3, 9): def cache_parameters(self) -> _CacheParameters: ... - def __copy__(self) -> _lru_cache_wrapper[_T]: ... - def __deepcopy__(self, memo: Any, /) -> _lru_cache_wrapper[_T]: ... + def __copy__(self) -> Self: ... + def __deepcopy__(self, memo: Any, /) -> Self: ... + +# Below types are type_check_only. Type +# checkers assume that cache functions +# are descriptors because the cache +# decorators aren't simple pass through +# decorators. At run time the normal method +# binding behavior is applied but type +# checkers don't know, so the below +# descriptor types mirror the normal +# binding behavior but aren't present at runtime. +# Returned objects are still +# _lru_cache_wrapper at runtime. + +# function, staticmethod and bound +# class/instance method descriptor. +@type_check_only +class _FunctionDescriptor(_lru_cache_wrapper[_P, _R]): + def __call__(__self, /, *args: _P.args, **kwds: _P.kwargs) -> _R: ... + +@type_check_only +class _MethodDescriptor(_lru_cache_wrapper[Concatenate[_T, _P], _R], Generic[_T, _P, _R]): + def __call__(__self, /, self: _T, *args: _P.args, **kwds: _P.kwargs) -> _R: ... + @overload + def __get__(self, instance: None, owner: type[_T]) -> Self: ... + @overload + def __get__(self, instance: _T, owner: type[_T]) -> _FunctionDescriptor[_P, _R]: ... + +@type_check_only +class _ClassmethodDescriptor(_lru_cache_wrapper[Concatenate[type[_T], _P], _R], Generic[_T, _P, _R]): + def __call__(self, cls: type[_T], *args: _P.args, **kwds: _P.kwargs) -> _R: ... + @overload + def __get__(self, instance: None, owner: type[_T]) -> _FunctionDescriptor[_P, _R]: ... + @overload + def __get__(self, instance: _T, owner: type[_T]) -> _FunctionDescriptor[_P, _R]: ... + +# All functions except unbound classmethods +# because @classmethod is typed to expect a function +# with first parameter of type `type[T]` and has +# different binding behavior to __call__. +@type_check_only +class _FunctionHasHashableArgs(_lru_cache_wrapper[_P, _R]): + def __call__(__self, /, *args: Hashable, **kwds: Hashable) -> _R: ... +@type_check_only +class _ClassmethodHasHashableArgs(_lru_cache_wrapper[Concatenate[type[_T], _P], _R], Generic[_T, _P, _R]): + def __call__(__self, /, cls: type[_T], *args: Hashable, **kwds: Hashable) -> _R: ... + @overload + def __get__(__self, /, instance: None, owner: type[_T]) -> _FunctionHasHashableArgs[_P, _R]: ... + @overload + def __get__(__self, /, instance: _T, owner: type[_T]) -> _FunctionHasHashableArgs[_P, _R]: ... + +class _LruInnerFunction(Protocol): + @overload + def __call__( + self, fn: _Method[_T, _P, _R] + ) -> _MethodDescriptor[_T, _P, _R] | _FunctionHasHashableArgs[Concatenate[_T, _P], _R]: ... + @overload + def __call__( + self, fn: _Classmethod[_T, _P, _R] + ) -> _ClassmethodDescriptor[_T, _P, _R] | _ClassmethodHasHashableArgs[_T, _P, _R]: ... + @overload + def __call__(self, fn: _Function[_P, _R]) -> _FunctionDescriptor[_P, _R] | _FunctionHasHashableArgs[_P, _R]: ... + +@overload +def lru_cache(maxsize: int | None = 128, typed: bool = False) -> _LruInnerFunction: ... +@overload +def lru_cache( + maxsize: _Method[_T, _P, _R], typed: bool = False +) -> _MethodDescriptor[_T, _P, _R] | _FunctionHasHashableArgs[Concatenate[_T, _P], _R]: ... @overload -def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ... +def lru_cache( + maxsize: _Classmethod[_T, _P, _R], typed: bool = False +) -> _ClassmethodDescriptor[_T, _P, _R] | _ClassmethodHasHashableArgs[_T, _P, _R]: ... @overload -def lru_cache(maxsize: Callable[..., _T], typed: bool = False) -> _lru_cache_wrapper[_T]: ... +def lru_cache( + maxsize: _Function[_P, _R], typed: bool = False +) -> _FunctionDescriptor[_P, _R] | _FunctionHasHashableArgs[_P, _R]: ... if sys.version_info >= (3, 12): WRAPPER_ASSIGNMENTS: tuple[ @@ -199,7 +283,16 @@ class cached_property(Generic[_T_co]): def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... if sys.version_info >= (3, 9): - def cache(user_function: Callable[..., _T], /) -> _lru_cache_wrapper[_T]: ... + @overload + def cache( + user_function: _Method[_T, _P, _R], / + ) -> _MethodDescriptor[_T, _P, _R] | _FunctionHasHashableArgs[Concatenate[_T, _P], _R]: ... + @overload + def cache( + user_function: _Classmethod[_T, _P, _R], / + ) -> _ClassmethodDescriptor[_T, _P, _R] | _ClassmethodHasHashableArgs[_T, _P, _R]: ... + @overload + def cache(user_function: _Function[_P, _R], /) -> _FunctionDescriptor[_P, _R] | _FunctionHasHashableArgs[_P, _R]: ... def _make_key( args: tuple[Hashable, ...],