Skip to content

Commit 301d852

Browse files
voidspacematrixise
authored andcommitted
[3.8] bpo-38122: minor fixes to AsyncMock spec handling (pythonGH-16099).
(cherry picked from commit 14fd925) Co-authored-by: Michael Foord <[email protected]>
1 parent f05d39d commit 301d852

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

Lib/unittest/mock.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,18 +402,12 @@ def __new__(cls, /, *args, **kw):
402402
# so we can create magic methods on the
403403
# class without stomping on other mocks
404404
bases = (cls,)
405-
if not issubclass(cls, AsyncMock):
405+
if not issubclass(cls, AsyncMockMixin):
406406
# Check if spec is an async object or function
407-
sig = inspect.signature(NonCallableMock.__init__)
408-
bound_args = sig.bind_partial(cls, *args, **kw).arguments
409-
spec_arg = [
410-
arg for arg in bound_args.keys()
411-
if arg.startswith('spec')
412-
]
413-
if spec_arg:
414-
# what if spec_set is different than spec?
415-
if _is_async_obj(bound_args[spec_arg[0]]):
416-
bases = (AsyncMockMixin, cls,)
407+
bound_args = _MOCK_SIG.bind_partial(cls, *args, **kw).arguments
408+
spec_arg = bound_args.get('spec_set', bound_args.get('spec'))
409+
if spec_arg and _is_async_obj(spec_arg):
410+
bases = (AsyncMockMixin, cls)
417411
new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
418412
instance = object.__new__(new)
419413
return instance
@@ -1019,6 +1013,25 @@ def _calls_repr(self, prefix="Calls"):
10191013
return f"\n{prefix}: {safe_repr(self.mock_calls)}."
10201014

10211015

1016+
_MOCK_SIG = inspect.signature(NonCallableMock.__init__)
1017+
1018+
1019+
class _AnyComparer(list):
1020+
"""A list which checks if it contains a call which may have an
1021+
argument of ANY, flipping the components of item and self from
1022+
their traditional locations so that ANY is guaranteed to be on
1023+
the left."""
1024+
def __contains__(self, item):
1025+
for _call in self:
1026+
if len(item) != len(_call):
1027+
continue
1028+
if all([
1029+
expected == actual
1030+
for expected, actual in zip(item, _call)
1031+
]):
1032+
return True
1033+
return False
1034+
10221035

10231036
def _try_iter(obj):
10241037
if obj is None:

0 commit comments

Comments
 (0)