diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index dc2101dc01f7..6683955d9afe 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -2,7 +2,7 @@ import abc import sys from _typeshed import FileDescriptorOrPath, Unused from abc import abstractmethod -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Generator, Iterator +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Coroutine, Generator, Iterator from types import TracebackType from typing import IO, Any, Generic, Protocol, TypeVar, overload, runtime_checkable from typing_extensions import ParamSpec, Self, TypeAlias @@ -104,6 +104,11 @@ else: self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None ) -> bool | None: ... +@overload +def asynccontextmanager( + func: Callable[_P, Coroutine[Any, Any, AsyncIterator[_T_co]]] +) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... +@overload def asynccontextmanager(func: Callable[_P, AsyncIterator[_T_co]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ... class _SupportsClose(Protocol): diff --git a/test_cases/stdlib/check_contextlib.py b/test_cases/stdlib/check_contextlib.py index 648661bca856..d5badda38dfe 100644 --- a/test_cases/stdlib/check_contextlib.py +++ b/test_cases/stdlib/check_contextlib.py @@ -1,6 +1,7 @@ from __future__ import annotations -from contextlib import ExitStack +from contextlib import ExitStack, asynccontextmanager +from typing import AsyncGenerator from typing_extensions import assert_type @@ -18,3 +19,17 @@ class Thing(ExitStack): assert_type(cm, ExitStack) with thing as cm2: assert_type(cm2, Thing) + + +@asynccontextmanager +async def async_context() -> AsyncGenerator[str, None]: + yield "example" + + +async def async_gen() -> AsyncGenerator[str, None]: + yield "async gen" + + +@asynccontextmanager +def async_cm_func() -> AsyncGenerator[str, None]: + return async_gen()