Skip to content

Commit 7b62579

Browse files
Michael0x2addfisher
authored andcommitted
Make Type[T] comparisons work (#1806)
This commit introduces a workaround to fix the bug discussed in #1787 Previously, code where you compared two types (eg `int == int` or `int != int`) would cause mypy to incorrectly report a "too few arguments" error.
1 parent 52092dd commit 7b62579

File tree

4 files changed

+84
-22
lines changed

4 files changed

+84
-22
lines changed

mypy/checkexpr.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ def check_call(self, callee: Type, args: List[Node],
286286
callee)
287287
elif isinstance(callee, Instance):
288288
call_function = analyze_member_access('__call__', callee, context,
289-
False, False, self.named_type, self.not_ready_callback,
290-
self.msg)
289+
False, False, False, self.named_type,
290+
self.not_ready_callback, self.msg)
291291
return self.check_call(call_function, args, arg_kinds, context, arg_names,
292292
callable_node, arg_messages)
293293
elif isinstance(callee, TypeVarType):
@@ -861,7 +861,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
861861
else:
862862
# This is a reference to a non-module attribute.
863863
return analyze_member_access(e.name, self.accept(e.expr), e,
864-
is_lvalue, False,
864+
is_lvalue, False, False,
865865
self.named_type, self.not_ready_callback, self.msg)
866866

867867
def analyze_external_member_access(self, member: str, base_type: Type,
@@ -870,7 +870,7 @@ def analyze_external_member_access(self, member: str, base_type: Type,
870870
refer to private definitions. Return the result type.
871871
"""
872872
# TODO remove; no private definitions in mypy
873-
return analyze_member_access(member, base_type, context, False, False,
873+
return analyze_member_access(member, base_type, context, False, False, False,
874874
self.named_type, self.not_ready_callback, self.msg)
875875

876876
def visit_int_expr(self, e: IntExpr) -> Type:
@@ -1008,7 +1008,7 @@ def check_op_local(self, method: str, base_type: Type, arg: Node,
10081008
10091009
Return tuple (result type, inferred operator method type).
10101010
"""
1011-
method_type = analyze_member_access(method, base_type, context, False, False,
1011+
method_type = analyze_member_access(method, base_type, context, False, False, True,
10121012
self.named_type, self.not_ready_callback, local_errors)
10131013
return self.check_call(method_type, [arg], [nodes.ARG_POS],
10141014
context, arg_messages=local_errors)
@@ -1434,7 +1434,7 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
14341434
if not self.chk.typing_mode_full():
14351435
return AnyType()
14361436
return analyze_member_access(e.name, self_type(e.info), e,
1437-
is_lvalue, True,
1437+
is_lvalue, True, False,
14381438
self.named_type, self.not_ready_callback,
14391439
self.msg, base)
14401440
else:

mypy/checkmember.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
DeletedType, NoneTyp, TypeType
99
)
1010
from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context
11-
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, function_type, Decorator, OverloadedFuncDef
11+
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, OpExpr, ComparisonExpr
12+
from mypy.nodes import function_type, Decorator, OverloadedFuncDef
1213
from mypy.messages import MessageBuilder
1314
from mypy.maptype import map_instance_to_supertype
1415
from mypy.expandtype import expand_type_by_instance
@@ -23,6 +24,7 @@ def analyze_member_access(name: str,
2324
node: Context,
2425
is_lvalue: bool,
2526
is_super: bool,
27+
is_operator: bool,
2628
builtin_type: Callable[[str], Instance],
2729
not_ready_callback: Callable[[str, Context], None],
2830
msg: MessageBuilder,
@@ -79,45 +81,59 @@ def analyze_member_access(name: str,
7981
elif isinstance(typ, NoneTyp):
8082
# The only attribute NoneType has are those it inherits from object
8183
return analyze_member_access(name, builtin_type('builtins.object'), node, is_lvalue,
82-
is_super, builtin_type, not_ready_callback, msg,
84+
is_super, is_operator, builtin_type, not_ready_callback, msg,
8385
report_type=report_type)
8486
elif isinstance(typ, UnionType):
8587
# The base object has dynamic type.
8688
msg.disable_type_names += 1
87-
results = [analyze_member_access(name, subtype, node, is_lvalue,
88-
is_super, builtin_type, not_ready_callback, msg)
89+
results = [analyze_member_access(name, subtype, node, is_lvalue, is_super,
90+
is_operator, builtin_type, not_ready_callback, msg)
8991
for subtype in typ.items]
9092
msg.disable_type_names -= 1
9193
return UnionType.make_simplified_union(results)
9294
elif isinstance(typ, TupleType):
9395
# Actually look up from the fallback instance type.
94-
return analyze_member_access(name, typ.fallback, node, is_lvalue,
95-
is_super, builtin_type, not_ready_callback, msg)
96+
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
97+
is_operator, builtin_type, not_ready_callback, msg)
9698
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
9799
# Class attribute.
98100
# TODO super?
99101
ret_type = typ.items()[0].ret_type
100102
if isinstance(ret_type, TupleType):
101103
ret_type = ret_type.fallback
102104
if isinstance(ret_type, Instance):
103-
result = analyze_class_attribute_access(ret_type, name, node, is_lvalue,
104-
builtin_type, not_ready_callback, msg)
105-
if result:
106-
return result
105+
if not is_operator:
106+
# When Python sees an operator (eg `3 == 4`), it automatically translates that
107+
# into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an
108+
# optimation.
109+
#
110+
# While it normally it doesn't matter which of the two versions are used, it
111+
# does cause inconsistencies when working with classes. For example, translating
112+
# `int == int` to `int.__eq__(int)` would not work since `int.__eq__` is meant to
113+
# compare two int _instances_. What we really want is `type(int).__eq__`, which
114+
# is meant to compare two types or classes.
115+
#
116+
# This check makes sure that when we encounter an operator, we skip looking up
117+
# the corresponding method in the current instance to avoid this edge case.
118+
# See https://github.com/python/mypy/pull/1787 for more info.
119+
result = analyze_class_attribute_access(ret_type, name, node, is_lvalue,
120+
builtin_type, not_ready_callback, msg)
121+
if result:
122+
return result
107123
# Look up from the 'type' type.
108124
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
109-
builtin_type, not_ready_callback, msg,
125+
is_operator, builtin_type, not_ready_callback, msg,
110126
report_type=report_type)
111127
else:
112128
assert False, 'Unexpected type {}'.format(repr(ret_type))
113129
elif isinstance(typ, FunctionLike):
114130
# Look up from the 'function' type.
115131
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
116-
builtin_type, not_ready_callback, msg,
132+
is_operator, builtin_type, not_ready_callback, msg,
117133
report_type=report_type)
118134
elif isinstance(typ, TypeVarType):
119135
return analyze_member_access(name, typ.upper_bound, node, is_lvalue, is_super,
120-
builtin_type, not_ready_callback, msg,
136+
is_operator, builtin_type, not_ready_callback, msg,
121137
report_type=report_type)
122138
elif isinstance(typ, DeletedType):
123139
msg.deleted_as_rvalue(typ, node)
@@ -130,14 +146,15 @@ def analyze_member_access(name: str,
130146
elif isinstance(typ.item, TypeVarType):
131147
if isinstance(typ.item.upper_bound, Instance):
132148
item = typ.item.upper_bound
133-
if item:
149+
if item and not is_operator:
150+
# See comment above for why operators are skipped
134151
result = analyze_class_attribute_access(item, name, node, is_lvalue,
135152
builtin_type, not_ready_callback, msg)
136153
if result:
137154
return result
138155
fallback = builtin_type('builtins.type')
139156
return analyze_member_access(name, fallback, node, is_lvalue, is_super,
140-
builtin_type, not_ready_callback, msg,
157+
is_operator, builtin_type, not_ready_callback, msg,
141158
report_type=report_type)
142159
return msg.has_no_attr(report_type, name, node)
143160

test-data/unit/check-classes.test

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,3 +2006,44 @@ reveal_type(User) # E: Revealed type is 'builtins.type'
20062006
[builtins fixtures/args.py]
20072007
[out]
20082008

2009+
[case testTypeTypeComparisonWorks]
2010+
class User: pass
2011+
2012+
User == User
2013+
User == type(User())
2014+
type(User()) == User
2015+
type(User()) == type(User())
2016+
2017+
User != User
2018+
User != type(User())
2019+
type(User()) != User
2020+
type(User()) != type(User())
2021+
2022+
int == int
2023+
int == type(3)
2024+
type(3) == int
2025+
type(3) == type(3)
2026+
2027+
int != int
2028+
int != type(3)
2029+
type(3) != int
2030+
type(3) != type(3)
2031+
2032+
User is User
2033+
User is type(User)
2034+
type(User) is User
2035+
type(User) is type(User)
2036+
2037+
int is int
2038+
int is type(3)
2039+
type(3) is int
2040+
type(3) is type(3)
2041+
2042+
int.__eq__(int)
2043+
int.__eq__(3, 4)
2044+
[builtins fixtures/args.py]
2045+
[out]
2046+
main:33: error: Too few arguments for "__eq__" of "int"
2047+
main:33: error: Unsupported operand types for == ("int" and "int")
2048+
2049+

test-data/unit/fixtures/args.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,21 @@
88

99
class object:
1010
def __init__(self) -> None: pass
11+
def __eq__(self, o: object) -> bool: pass
12+
def __ne__(self, o: object) -> bool: pass
1113

1214
class type:
1315
@overload
1416
def __init__(self, o: object) -> None: pass
1517
@overload
1618
def __init__(self, name: str, bases: Tuple[type, ...], dict: Dict[str, Any]) -> None: pass
19+
def __call__(self, *args: Any, **kwargs: Any) -> Any: pass
1720

1821
class tuple(Iterable[Tco], Generic[Tco]): pass
1922
class dict(Generic[T, S]): pass
2023

21-
class int: pass
24+
class int:
25+
def __eq__(self, o: object) -> bool: pass
2226
class str: pass
2327
class bool: pass
2428
class function: pass

0 commit comments

Comments
 (0)