diff --git a/README.md b/README.md index 0c3a65bd..d99bef0b 100644 --- a/README.md +++ b/README.md @@ -2112,6 +2112,7 @@ It has the following parameters that affects class selection rules: by which all the variants can be distinguished * `include_subtypes` — allow to deserialize subclasses * `include_supertypes` — allow to deserialize superclasses +* `include_current_type` - allow to deserialize to the class which hosts this discriminator * `variant_tagger_fn` — a custom function used to generate tag values associated with a variant diff --git a/mashumaro/core/meta/code/builder.py b/mashumaro/core/meta/code/builder.py index 1807dc32..b1f6ccbe 100644 --- a/mashumaro/core/meta/code/builder.py +++ b/mashumaro/core/meta/code/builder.py @@ -380,12 +380,19 @@ def _add_unpack_method_lines(self, method_name: str) -> None: if self.decoder is not None: self.add_line("d = decoder(d)") discr = self.get_discriminator() + inherited = False + if not discr: + discr = self.get_discriminator(look_in_parents=True) + if discr and discr.field is None: + discr = None + inherited = True if discr: if not discr.include_subtypes: raise ValueError( "Config based discriminator must have " "'include_subtypes' enabled" ) + fallthrough = inherited or discr.include_current_type discr = Discriminator( # prevent RecursionError field=discr.field, @@ -401,8 +408,18 @@ def _add_unpack_method_lines(self, method_name: str) -> None: field_ctx=FieldContext("", {}), ) ) - self.add_line(f"return {method}") - return + if fallthrough: + with self.indent("try:"): + self.add_line(f"return {method}") + with self.indent( + "except (" + "SuitableVariantNotFoundError, " + "MissingDiscriminatorError):" + ): + self.add_line("pass") + else: + self.add_line(f"return {method}") + return pre_deserialize = self.get_declared_hook(__PRE_DESERIALIZE__) if pre_deserialize: if not isinstance(pre_deserialize, classmethod): diff --git a/mashumaro/types.py b/mashumaro/types.py index 35091742..e203975e 100644 --- a/mashumaro/types.py +++ b/mashumaro/types.py @@ -101,6 +101,7 @@ class Discriminator: field: Optional[str] = None include_supertypes: bool = False include_subtypes: bool = False + include_current_type: bool = False variant_tagger_fn: Optional[Callable[[Any], Any]] = None def __post_init__(self) -> None: diff --git a/tests/test_config.py b/tests/test_config.py index 4a20cb06..3497f628 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -461,3 +461,176 @@ def test_forbid_extra_keys_with_discriminator_for_subclass(): {"x": "foo", "__type": "_VariantByField4", "y": "bar"} ) assert exc_info.value.extra_keys == {"y"} + + +@pytest.mark.parametrize("forbid", [False, True]) +def test_subclass_from_dict_inherits_discriminator(forbid): + + def _tagger(cls: type) -> str: + return f"{cls.__module__}.{cls.__qualname__}" + + @dataclass + class Root(DataClassDictMixin): + class Config(BaseConfig): + forbid_extra_keys = forbid + discriminator = Discriminator( + field="_type", + include_subtypes=True, + variant_tagger_fn=_tagger, + ) + + name: str = "base" + + def __post_serialize__(self, data: dict): + data["_type"] = _tagger(type(self)) + return data + + @classmethod + def __pre_deserialize__(cls, data: dict) -> dict: + return {k: v for k, v in data.items() if k != "_type"} + + @dataclass + class Middle(Root): + value: int = 0 + + @dataclass + class Child(Middle): + extra: int = 0 + + child = Child(name="hi", extra=42) + serialized = child.to_dict() + + # Root.from_dict works in both cases + from_root = Root.from_dict(serialized) + assert isinstance(from_root, Child) + assert from_root == child + + # Middle.from_dict should also work + from_middle = Middle.from_dict(serialized) + assert isinstance(from_middle, Child) + assert from_middle == child + + +def test_union_resolves_via_discriminator(): + """Union[Branch, int] should use discriminator to resolve Leaf.""" + + def _tagger(cls: type) -> str: + return f"{cls.__module__}.{cls.__qualname__}" + + @dataclass + class Root(DataClassDictMixin): + class Config(BaseConfig): + forbid_extra_keys = True + discriminator = Discriminator( + field="_type", + include_subtypes=True, + variant_tagger_fn=_tagger, + ) + + name: str = "base" + + def __post_serialize__(self, data: dict): + data["_type"] = _tagger(type(self)) + return data + + @classmethod + def __pre_deserialize__(cls, data: dict) -> dict: + return {k: v for k, v in data.items() if k != "_type"} + + @dataclass + class Branch(Root): + value: int = 0 + + @dataclass + class Leaf(Branch): + extra: int = 0 + + leaf = Leaf(name="hi", extra=42) + d = leaf.to_dict() + + assert isinstance(Root.from_dict(d), Leaf) + assert isinstance(Branch.from_dict(d), Leaf) + + @dataclass + class Unrelated(DataClassDictMixin): + class Config(BaseConfig): + forbid_extra_keys = True + discriminator = Discriminator( + field="_other_type", + include_subtypes=True, + variant_tagger_fn=_tagger, + ) + + def __post_serialize__(self, data: dict): + data["_other_type"] = _tagger(type(self)) + return data + + @classmethod + def __pre_deserialize__(cls, data: dict) -> dict: + return {k: v for k, v in data.items() if k != "_other_type"} + + @dataclass + class UnrelatedChild(Unrelated): + value: int = 0 + + @dataclass + class Container(DataClassDictMixin): + item: Union[Branch, int, Unrelated] + + container = Container(item=Leaf(name="hi", extra=42)) + result = Container.from_dict(container.to_dict()) + assert isinstance(result.item, Leaf) + assert result == container + + container = Container(item=1) + result = Container.from_dict(container.to_dict()) + assert isinstance(result.item, int) + assert result == container + + container = Container(item=UnrelatedChild(value=123)) + result = Container.from_dict(container.to_dict()) + assert isinstance(result.item, UnrelatedChild) + assert result == container + + +@pytest.mark.parametrize("forbid", [False, True]) +def test_discriminator_matches_current_class(forbid): + + def _tagger(cls: type) -> str: + return f"{cls.__module__}.{cls.__qualname__}" + + @dataclass + class Root(DataClassDictMixin): + class Config(BaseConfig): + forbid_extra_keys = forbid + discriminator = Discriminator( + field="_type", + include_subtypes=True, + include_current_type=True, + variant_tagger_fn=_tagger, + ) + + name: str = "base" + + def __post_serialize__(self, data: dict): + data["_type"] = _tagger(type(self)) + return data + + @classmethod + def __pre_deserialize__(cls, data: dict) -> dict: + return {k: v for k, v in data.items() if k != "_type"} + + @dataclass + class Middle(Root): + value: int = 0 + + @dataclass + class Container(DataClassDictMixin): + item: Root + + child = Container(item=Root(name="hi")) + serialized = child.to_dict() + + from_root = Container.from_dict(serialized) + assert isinstance(from_root.item, Root) + assert not isinstance(from_root.item, Middle)