diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 095967dc3fa1..9ca9a83b7453 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -244,10 +244,20 @@ def transform(self) -> bool: tvar_def=order_tvar_def, ) + parent_decorator_arguments = [] + for parent in info.mro[1:-1]: + parent_args = parent.metadata.get("dataclass") + if parent_args: + parent_decorator_arguments.append(parent_args) + if decorator_arguments["frozen"]: + if any(not parent["frozen"] for parent in parent_decorator_arguments): + ctx.api.fail("Cannot inherit frozen dataclass from a non-frozen one", info) self._propertize_callables(attributes, settable=False) self._freeze(attributes) else: + if any(parent["frozen"] for parent in parent_decorator_arguments): + ctx.api.fail("Cannot inherit non-frozen dataclass from a frozen one", info) self._propertize_callables(attributes) if decorator_arguments["slots"]: @@ -446,6 +456,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # copy() because we potentially modify all_attrs below and if this code requires debugging # we'll have unmodified attrs laying around. all_attrs = attrs.copy() + known_super_attrs = set() for info in cls.info.mro[1:-1]: if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata: # We haven't processed the base class yet. Need another pass. @@ -467,6 +478,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: with state.strict_optional_set(ctx.api.options.strict_optional): attr.expand_typevar_from_subtype(ctx.cls.info) known_attrs.add(name) + known_super_attrs.add(name) super_attrs.append(attr) elif all_attrs: # How early in the attribute list an attribute appears is determined by the @@ -481,6 +493,14 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: all_attrs = super_attrs + all_attrs all_attrs.sort(key=lambda a: a.kw_only) + for known_super_attr_name in known_super_attrs: + sym_node = cls.info.names.get(known_super_attr_name) + if sym_node and sym_node.node and not isinstance(sym_node.node, Var): + ctx.api.fail( + "Dataclass attribute may only be overridden by another attribute", + sym_node.node, + ) + # Ensure that arguments without a default don't follow # arguments that have a default. found_default = False @@ -515,8 +535,8 @@ def _freeze(self, attributes: list[DataclassAttribute]) -> None: sym_node = info.names.get(attr.name) if sym_node is not None: var = sym_node.node - assert isinstance(var, Var) - var.is_property = True + if isinstance(var, Var): + var.is_property = True else: var = attr.to_var() var.info = info diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index d49a3a01e82d..719dc6aecd73 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -187,6 +187,66 @@ reveal_type(C) # N: Revealed type is "def (some_int: builtins.int, some_str: bu [builtins fixtures/dataclasses.pyi] +[case testDataclassIncompatibleOverrides] +# flags: --python-version 3.7 +from dataclasses import dataclass + +@dataclass +class Base: + foo: int + +@dataclass +class BadDerived1(Base): + def foo(self) -> int: # E: Dataclass attribute may only be overridden by another attribute \ + # E: Signature of "foo" incompatible with supertype "Base" + return 1 + +@dataclass +class BadDerived2(Base): + @property # E: Dataclass attribute may only be overridden by another attribute + def foo(self) -> int: # E: Cannot override writeable attribute with read-only property + return 2 + +@dataclass +class BadDerived3(Base): + class foo: pass # E: Dataclass attribute may only be overridden by another attribute +[builtins fixtures/dataclasses.pyi] + +[case testDataclassMultipleInheritance] +# flags: --python-version 3.7 +from dataclasses import dataclass + +class Unrelated: + foo: str + +@dataclass +class Base: + bar: int + +@dataclass +class Derived(Base, Unrelated): + pass + +d = Derived(3) +reveal_type(d.foo) # N: Revealed type is "builtins.str" +reveal_type(d.bar) # N: Revealed type is "builtins.int" +[builtins fixtures/dataclasses.pyi] + +[case testDataclassIncompatibleFrozenOverride] +# flags: --python-version 3.7 +from dataclasses import dataclass + +@dataclass(frozen=True) +class Base: + foo: int + +@dataclass(frozen=True) +class BadDerived(Base): + @property # E: Dataclass attribute may only be overridden by another attribute + def foo(self) -> int: + return 3 +[builtins fixtures/dataclasses.pyi] + [case testDataclassesFreezing] # flags: --python-version 3.7 from dataclasses import dataclass @@ -200,6 +260,28 @@ john.name = 'Ben' # E: Property "name" defined in "Person" is read-only [builtins fixtures/dataclasses.pyi] +[case testDataclassesInconsistentFreezing] +# flags: --python-version 3.7 +from dataclasses import dataclass + +@dataclass(frozen=True) +class FrozenBase: + pass + +@dataclass +class BadNormalDerived(FrozenBase): # E: Cannot inherit non-frozen dataclass from a frozen one + pass + +@dataclass +class NormalBase: + pass + +@dataclass(frozen=True) +class BadFrozenDerived(NormalBase): # E: Cannot inherit frozen dataclass from a non-frozen one + pass + +[builtins fixtures/dataclasses.pyi] + [case testDataclassesFields] # flags: --python-version 3.7 from dataclasses import dataclass, field @@ -1283,9 +1365,9 @@ from dataclasses import dataclass class A: foo: int -@dataclass +@dataclass(frozen=True) class B(A): - @property + @property # E: Dataclass attribute may only be overridden by another attribute def foo(self) -> int: pass reveal_type(B) # N: Revealed type is "def (foo: builtins.int) -> __main__.B"