Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,7 @@ It has the following parameters that affects class selection rules:
by which all the variants can be distinguished
* `include_subtypes` — allow to deserialize subclasses
* `include_supertypes` — allow to deserialize superclasses
* `include_current_type` - allow to deserialize to the class which hosts this discriminator
* `variant_tagger_fn` — a custom function used to generate tag values
associated with a variant

Expand Down
21 changes: 19 additions & 2 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,19 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
if self.decoder is not None:
self.add_line("d = decoder(d)")
discr = self.get_discriminator()
inherited = False
if not discr:
discr = self.get_discriminator(look_in_parents=True)
if discr and discr.field is None:
discr = None
inherited = True
if discr:
if not discr.include_subtypes:
raise ValueError(
"Config based discriminator must have "
"'include_subtypes' enabled"
)
fallthrough = inherited or discr.include_current_type
discr = Discriminator(
# prevent RecursionError
field=discr.field,
Expand All @@ -401,8 +408,18 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
field_ctx=FieldContext("", {}),
)
)
self.add_line(f"return {method}")
return
if fallthrough:
with self.indent("try:"):
self.add_line(f"return {method}")
with self.indent(
"except ("
"SuitableVariantNotFoundError, "
"MissingDiscriminatorError):"
):
self.add_line("pass")
else:
self.add_line(f"return {method}")
return
pre_deserialize = self.get_declared_hook(__PRE_DESERIALIZE__)
if pre_deserialize:
if not isinstance(pre_deserialize, classmethod):
Expand Down
1 change: 1 addition & 0 deletions mashumaro/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class Discriminator:
field: Optional[str] = None
include_supertypes: bool = False
include_subtypes: bool = False
include_current_type: bool = False
variant_tagger_fn: Optional[Callable[[Any], Any]] = None

def __post_init__(self) -> None:
Expand Down
173 changes: 173 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,176 @@ def test_forbid_extra_keys_with_discriminator_for_subclass():
{"x": "foo", "__type": "_VariantByField4", "y": "bar"}
)
assert exc_info.value.extra_keys == {"y"}


@pytest.mark.parametrize("forbid", [False, True])
def test_subclass_from_dict_inherits_discriminator(forbid):

def _tagger(cls: type) -> str:
return f"{cls.__module__}.{cls.__qualname__}"

@dataclass
class Root(DataClassDictMixin):
class Config(BaseConfig):
forbid_extra_keys = forbid
discriminator = Discriminator(
field="_type",
include_subtypes=True,
variant_tagger_fn=_tagger,
)

name: str = "base"

def __post_serialize__(self, data: dict):
data["_type"] = _tagger(type(self))
return data

@classmethod
def __pre_deserialize__(cls, data: dict) -> dict:
return {k: v for k, v in data.items() if k != "_type"}

@dataclass
class Middle(Root):
value: int = 0

@dataclass
class Child(Middle):
extra: int = 0

child = Child(name="hi", extra=42)
serialized = child.to_dict()

# Root.from_dict works in both cases
from_root = Root.from_dict(serialized)
assert isinstance(from_root, Child)
assert from_root == child

# Middle.from_dict should also work
from_middle = Middle.from_dict(serialized)
assert isinstance(from_middle, Child)
assert from_middle == child


def test_union_resolves_via_discriminator():
"""Union[Branch, int] should use discriminator to resolve Leaf."""

def _tagger(cls: type) -> str:
return f"{cls.__module__}.{cls.__qualname__}"

@dataclass
class Root(DataClassDictMixin):
class Config(BaseConfig):
forbid_extra_keys = True
discriminator = Discriminator(
field="_type",
include_subtypes=True,
variant_tagger_fn=_tagger,
)

name: str = "base"

def __post_serialize__(self, data: dict):
data["_type"] = _tagger(type(self))
return data

@classmethod
def __pre_deserialize__(cls, data: dict) -> dict:
return {k: v for k, v in data.items() if k != "_type"}

@dataclass
class Branch(Root):
value: int = 0

@dataclass
class Leaf(Branch):
extra: int = 0

leaf = Leaf(name="hi", extra=42)
d = leaf.to_dict()

assert isinstance(Root.from_dict(d), Leaf)
assert isinstance(Branch.from_dict(d), Leaf)

@dataclass
class Unrelated(DataClassDictMixin):
class Config(BaseConfig):
forbid_extra_keys = True
discriminator = Discriminator(
field="_other_type",
include_subtypes=True,
variant_tagger_fn=_tagger,
)

def __post_serialize__(self, data: dict):
data["_other_type"] = _tagger(type(self))
return data

@classmethod
def __pre_deserialize__(cls, data: dict) -> dict:
return {k: v for k, v in data.items() if k != "_other_type"}

@dataclass
class UnrelatedChild(Unrelated):
value: int = 0

@dataclass
class Container(DataClassDictMixin):
item: Union[Branch, int, Unrelated]

container = Container(item=Leaf(name="hi", extra=42))
result = Container.from_dict(container.to_dict())
assert isinstance(result.item, Leaf)
assert result == container

container = Container(item=1)
result = Container.from_dict(container.to_dict())
assert isinstance(result.item, int)
assert result == container

container = Container(item=UnrelatedChild(value=123))
result = Container.from_dict(container.to_dict())
assert isinstance(result.item, UnrelatedChild)
assert result == container


@pytest.mark.parametrize("forbid", [False, True])
def test_discriminator_matches_current_class(forbid):

def _tagger(cls: type) -> str:
return f"{cls.__module__}.{cls.__qualname__}"

@dataclass
class Root(DataClassDictMixin):
class Config(BaseConfig):
forbid_extra_keys = forbid
discriminator = Discriminator(
field="_type",
include_subtypes=True,
include_current_type=True,
variant_tagger_fn=_tagger,
)

name: str = "base"

def __post_serialize__(self, data: dict):
data["_type"] = _tagger(type(self))
return data

@classmethod
def __pre_deserialize__(cls, data: dict) -> dict:
return {k: v for k, v in data.items() if k != "_type"}

@dataclass
class Middle(Root):
value: int = 0

@dataclass
class Container(DataClassDictMixin):
item: Root

child = Container(item=Root(name="hi"))
serialized = child.to_dict()

from_root = Container.from_dict(serialized)
assert isinstance(from_root.item, Root)
assert not isinstance(from_root.item, Middle)
Loading