Skip to content

Commit 5b56daa

Browse files
authored
gh-130870: Preserve GenericAlias subclasses in typing.get_type_hints() (#131583)
1 parent f0c7344 commit 5b56daa

File tree

3 files changed

+44
-12
lines changed

3 files changed

+44
-12
lines changed

Lib/test/test_typing.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,10 @@ def func1(*args: *Ts): pass
16051605
self.assertEqual(gth(func1), {'args': Unpack[Ts]})
16061606

16071607
def func2(*args: *tuple[int, str]): pass
1608-
self.assertEqual(gth(func2), {'args': Unpack[tuple[int, str]]})
1608+
hint = gth(func2)['args']
1609+
self.assertIsInstance(hint, types.GenericAlias)
1610+
self.assertEqual(hint.__args__[0], int)
1611+
self.assertIs(hint.__unpacked__, True)
16091612

16101613
class CustomVariadic(Generic[*Ts]): pass
16111614

@@ -1620,7 +1623,10 @@ def func1(*args: '*Ts'): pass
16201623
{'args': Unpack[Ts]})
16211624

16221625
def func2(*args: '*tuple[int, str]'): pass
1623-
self.assertEqual(gth(func2), {'args': Unpack[tuple[int, str]]})
1626+
hint = gth(func2)['args']
1627+
self.assertIsInstance(hint, types.GenericAlias)
1628+
self.assertEqual(hint.__args__[0], int)
1629+
self.assertIs(hint.__unpacked__, True)
16241630

16251631
class CustomVariadic(Generic[*Ts]): pass
16261632

@@ -7114,6 +7120,24 @@ def add_right(self, node: 'Node[T]' = None):
71147120
right_hints = get_type_hints(t.add_right, globals(), locals())
71157121
self.assertEqual(right_hints['node'], Node[T])
71167122

7123+
def test_get_type_hints_preserve_generic_alias_subclasses(self):
7124+
# https://github.com/python/cpython/issues/130870
7125+
# A real world example of this is `collections.abc.Callable`. When parameterized,
7126+
# the result is a subclass of `types.GenericAlias`.
7127+
class MyAlias(types.GenericAlias):
7128+
pass
7129+
7130+
class MyClass:
7131+
def __class_getitem__(cls, args):
7132+
return MyAlias(cls, args)
7133+
7134+
# Using a forward reference is important, otherwise it works as expected.
7135+
# `y` tests that the `GenericAlias` subclass is preserved when stripping `Annotated`.
7136+
def func(x: MyClass['int'], y: MyClass[Annotated[int, ...]]): ...
7137+
7138+
assert isinstance(get_type_hints(func)['x'], MyAlias)
7139+
assert isinstance(get_type_hints(func)['y'], MyAlias)
7140+
71177141

71187142
class GetUtilitiesTestCase(TestCase):
71197143
def test_get_origin(self):

Lib/typing.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,17 @@ def inner(*args, **kwds):
407407
return decorator
408408

409409

410+
def _rebuild_generic_alias(alias: GenericAlias, args: tuple[object, ...]) -> GenericAlias:
411+
is_unpacked = alias.__unpacked__
412+
if _should_unflatten_callable_args(alias, args):
413+
t = alias.__origin__[(args[:-1], args[-1])]
414+
else:
415+
t = alias.__origin__[args]
416+
if is_unpacked:
417+
t = Unpack[t]
418+
return t
419+
420+
410421
def _deprecation_warning_for_no_type_params_passed(funcname: str) -> None:
411422
import warnings
412423

@@ -454,25 +465,20 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
454465
_make_forward_ref(arg) if isinstance(arg, str) else arg
455466
for arg in t.__args__
456467
)
457-
is_unpacked = t.__unpacked__
458-
if _should_unflatten_callable_args(t, args):
459-
t = t.__origin__[(args[:-1], args[-1])]
460-
else:
461-
t = t.__origin__[args]
462-
if is_unpacked:
463-
t = Unpack[t]
468+
else:
469+
args = t.__args__
464470

465471
ev_args = tuple(
466472
_eval_type(
467473
a, globalns, localns, type_params, recursive_guard=recursive_guard,
468474
format=format, owner=owner,
469475
)
470-
for a in t.__args__
476+
for a in args
471477
)
472478
if ev_args == t.__args__:
473479
return t
474480
if isinstance(t, GenericAlias):
475-
return GenericAlias(t.__origin__, ev_args)
481+
return _rebuild_generic_alias(t, ev_args)
476482
if isinstance(t, Union):
477483
return functools.reduce(operator.or_, ev_args)
478484
else:
@@ -2404,7 +2410,7 @@ def _strip_annotations(t):
24042410
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
24052411
if stripped_args == t.__args__:
24062412
return t
2407-
return GenericAlias(t.__origin__, stripped_args)
2413+
return _rebuild_generic_alias(t, stripped_args)
24082414
if isinstance(t, Union):
24092415
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
24102416
if stripped_args == t.__args__:
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Preserve :class:`types.GenericAlias` subclasses in
2+
:func:`typing.get_type_hints`

0 commit comments

Comments
 (0)