Skip to content

Commit 22a8e4d

Browse files
committed
Add support for assert_type
See python/cpython#30843. The implementation mostly follows that of cast(). It relies on `mypy.sametypes.is_same_type()`.
1 parent 0c6b290 commit 22a8e4d

24 files changed

+173
-12
lines changed

mypy/checkexpr.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
get_proper_types, flatten_nested_unions, LITERAL_TYPE_NAMES,
2424
)
2525
from mypy.nodes import (
26-
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
26+
AssertTypeExpr, NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
2727
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
2828
OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr,
2929
TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression,
@@ -3144,6 +3144,14 @@ def visit_cast_expr(self, expr: CastExpr) -> Type:
31443144
context=expr)
31453145
return target_type
31463146

3147+
def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type:
3148+
source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form),
3149+
allow_none_return=True, always_allow_any=True)
3150+
target_type = expr.type
3151+
if not is_same_type(source_type, target_type):
3152+
self.msg.assert_type_fail(source_type, target_type, expr)
3153+
return source_type
3154+
31473155
def visit_reveal_expr(self, expr: RevealExpr) -> Type:
31483156
"""Type check a reveal_type expression."""
31493157
if expr.kind == REVEAL_TYPE:

mypy/errorcodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def __str__(self) -> str:
113113
REDUNDANT_CAST: Final = ErrorCode(
114114
"redundant-cast", "Check that cast changes type of expression", "General"
115115
)
116+
ASSERT_TYPE: Final = ErrorCode(
117+
"assert-type", "Check that assert_type() call succeeds", "General"
118+
)
116119
COMPARISON_OVERLAP: Final = ErrorCode(
117120
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
118121
)

mypy/literals.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing_extensions import Final
33

44
from mypy.nodes import (
5-
Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES,
5+
AssertTypeExpr, Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES,
66
LITERAL_NO, NameExpr, LITERAL_TYPE, IntExpr, FloatExpr, ComplexExpr, StrExpr, BytesExpr,
77
UnicodeExpr, ListExpr, TupleExpr, SetExpr, DictExpr, CallExpr, SliceExpr, CastExpr,
88
ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr,
@@ -175,6 +175,9 @@ def visit_slice_expr(self, e: SliceExpr) -> None:
175175
def visit_cast_expr(self, e: CastExpr) -> None:
176176
return None
177177

178+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
179+
return None
180+
178181
def visit_conditional_expr(self, e: ConditionalExpr) -> None:
179182
return None
180183

mypy/messages.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,11 @@ def redundant_cast(self, typ: Type, context: Context) -> None:
12131213
self.fail('Redundant cast to {}'.format(format_type(typ)), context,
12141214
code=codes.REDUNDANT_CAST)
12151215

1216+
def assert_type_fail(self, source_type: Type, target_type: Type, context: Context) -> None:
1217+
self.fail(f"Expression is of type {format_type(source_type)}, "
1218+
f"not {format_type(target_type)}", context,
1219+
code=codes.ASSERT_TYPE)
1220+
12161221
def unimported_type_becomes_any(self, prefix: str, typ: Type, ctx: Context) -> None:
12171222
self.fail("{} becomes {} due to an unfollowed import".format(prefix, format_type(typ)),
12181223
ctx, code=codes.NO_ANY_UNIMPORTED)

mypy/mixedtraverser.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22

33
from mypy.nodes import (
4-
Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
4+
AssertTypeExpr, Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
55
CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr,
66
PromoteExpr, NewTypeExpr
77
)
@@ -79,6 +79,10 @@ def visit_cast_expr(self, o: CastExpr) -> None:
7979
super().visit_cast_expr(o)
8080
o.type.accept(self)
8181

82+
def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
83+
super().visit_assert_type_expr(o)
84+
o.type.accept(self)
85+
8286
def visit_type_application(self, o: TypeApplication) -> None:
8387
super().visit_type_application(o)
8488
for t in o.types:

mypy/nodes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,6 +1945,22 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
19451945
return visitor.visit_cast_expr(self)
19461946

19471947

1948+
class AssertTypeExpr(Expression):
1949+
"""Represents a typing.assert_type(expr, type) call."""
1950+
__slots__ = ('expr', 'type')
1951+
1952+
expr: Expression
1953+
type: "mypy.types.Type"
1954+
1955+
def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None:
1956+
super().__init__()
1957+
self.expr = expr
1958+
self.type = typ
1959+
1960+
def accept(self, visitor: ExpressionVisitor[T]) -> T:
1961+
return visitor.visit_assert_type_expr(self)
1962+
1963+
19481964
class RevealExpr(Expression):
19491965
"""Reveal type expression reveal_type(expr) or reveal_locals() expression."""
19501966

mypy/semanal.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from typing_extensions import Final, TypeAlias as _TypeAlias
5757

5858
from mypy.nodes import (
59-
MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
59+
AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
6060
ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue,
6161
ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr,
6262
IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt,
@@ -94,7 +94,7 @@
9494
from mypy.errorcodes import ErrorCode
9595
from mypy import message_registry, errorcodes as codes
9696
from mypy.types import (
97-
NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType,
97+
ASSERT_TYPE_NAMES, NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType,
9898
CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue,
9999
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
100100
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
@@ -3897,6 +3897,19 @@ def visit_call_expr(self, expr: CallExpr) -> None:
38973897
expr.analyzed.line = expr.line
38983898
expr.analyzed.column = expr.column
38993899
expr.analyzed.accept(self)
3900+
elif refers_to_fullname(expr.callee, ASSERT_TYPE_NAMES):
3901+
if not self.check_fixed_args(expr, 2, 'assert_type'):
3902+
return
3903+
# Translate second argument to an unanalyzed type.
3904+
try:
3905+
target = self.expr_to_unanalyzed_type(expr.args[1])
3906+
except TypeTranslationError:
3907+
self.fail('assert_type() type is not a type', expr)
3908+
return
3909+
expr.analyzed = AssertTypeExpr(expr.args[0], target)
3910+
expr.analyzed.line = expr.line
3911+
expr.analyzed.column = expr.column
3912+
expr.analyzed.accept(self)
39003913
elif refers_to_fullname(expr.callee, REVEAL_TYPE_NAMES):
39013914
if not self.check_fixed_args(expr, 1, 'reveal_type'):
39023915
return
@@ -4200,6 +4213,12 @@ def visit_cast_expr(self, expr: CastExpr) -> None:
42004213
if analyzed is not None:
42014214
expr.type = analyzed
42024215

4216+
def visit_assert_type_expr(self, expr: AssertTypeExpr) -> None:
4217+
expr.expr.accept(self)
4218+
analyzed = self.anal_type(expr.type)
4219+
if analyzed is not None:
4220+
expr.type = analyzed
4221+
42034222
def visit_reveal_expr(self, expr: RevealExpr) -> None:
42044223
if expr.kind == REVEAL_TYPE:
42054224
if expr.expr is not None:

mypy/server/astmerge.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from typing import Dict, List, cast, TypeVar, Optional
4949

5050
from mypy.nodes import (
51-
MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
51+
AssertTypeExpr, MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
5252
FuncDef, ClassDef, NamedTupleExpr, SymbolNode, Var, Statement, SuperExpr, NewTypeExpr,
5353
OverloadedFuncDef, LambdaExpr, TypedDictExpr, EnumCallExpr, FuncBase, TypeAliasExpr, CallExpr,
5454
CastExpr, TypeAlias,
@@ -226,6 +226,10 @@ def visit_cast_expr(self, node: CastExpr) -> None:
226226
super().visit_cast_expr(node)
227227
self.fixup_type(node.type)
228228

229+
def visit_assert_type_expr(self, node: AssertTypeExpr) -> None:
230+
super().visit_assert_type_expr(node)
231+
self.fixup_type(node.type)
232+
229233
def visit_super_expr(self, node: SuperExpr) -> None:
230234
super().visit_super_expr(node)
231235
if node.info is not None:

mypy/server/deps.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
8484

8585
from mypy.checkmember import bind_self
8686
from mypy.nodes import (
87-
Node, Expression, MypyFile, FuncDef, ClassDef, AssignmentStmt, NameExpr, MemberExpr, Import,
87+
AssertTypeExpr, Node, Expression, MypyFile, FuncDef, ClassDef, AssignmentStmt, NameExpr, MemberExpr, Import,
8888
ImportFrom, CallExpr, CastExpr, TypeVarExpr, TypeApplication, IndexExpr, UnaryExpr, OpExpr,
8989
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
9090
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
@@ -686,6 +686,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
686686
super().visit_cast_expr(e)
687687
self.add_type_dependencies(e.type)
688688

689+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
690+
super().visit_assert_type_expr(e)
691+
self.add_type_dependencies(e.type)
692+
689693
def visit_type_application(self, e: TypeApplication) -> None:
690694
super().visit_type_application(e)
691695
for typ in e.types:

mypy/server/subexpr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List
44

55
from mypy.nodes import (
6-
Expression, Node, MemberExpr, YieldFromExpr, YieldExpr, CallExpr, OpExpr, ComparisonExpr,
6+
AssertTypeExpr, Expression, Node, MemberExpr, YieldFromExpr, YieldExpr, CallExpr, OpExpr, ComparisonExpr,
77
SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr,
88
IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
99
ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr,
@@ -99,6 +99,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
9999
self.add(e)
100100
super().visit_cast_expr(e)
101101

102+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
103+
self.add(e)
104+
super().visit_assert_type_expr(e)
105+
102106
def visit_reveal_expr(self, e: RevealExpr) -> None:
103107
self.add(e)
104108
super().visit_reveal_expr(e)

0 commit comments

Comments
 (0)