Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,7 @@ It has the following parameters that affects class selection rules:
* `include_supertypes` — allow to deserialize superclasses
* `variant_tagger_fn` — a custom function used to generate tag values
associated with a variant
* `possible_subtypes_fn` — a custom function which enumerates possible subtypes to deserialize to. Used in cases where cyclic-imports prevent importing subtypes at the module-level.

By default, each variant that you want to discriminate by tags should have a
class-level attribute containing an associated tag value. This attribute should
Expand Down
1 change: 1 addition & 0 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
field=discr.field,
include_subtypes=discr.include_subtypes,
variant_tagger_fn=discr.variant_tagger_fn,
possible_subtypes_fn=discr.possible_subtypes_fn,
)
self.add_type_modules(self.cls)
method = SubtypeUnpackerBuilder(discr).build(
Expand Down
8 changes: 8 additions & 0 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
variant_method_call = self._get_variant_method_call(
variant_method_name, spec
)
if discriminator.possible_subtypes_fn:
spec.builder.ensure_object_imported(
discriminator.possible_subtypes_fn, "possible_subtypes_fn"
)
if discriminator.variant_tagger_fn:
spec.builder.ensure_object_imported(
discriminator.variant_tagger_fn, "variant_tagger_fn"
Expand Down Expand Up @@ -420,6 +424,8 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
)
with lines.indent("except (KeyError, AttributeError):"):
lines.append(f"variants_map = {variants_map}")
if discriminator.possible_subtypes_fn:
lines.append("list(possible_subtypes_fn())")
with lines.indent(f"for variant in {variants}:"):
if discriminator.variant_tagger_fn is not None:
self._add_register_variant_tags(
Expand Down Expand Up @@ -454,6 +460,8 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
"discriminator) from None"
)
else:
if discriminator.possible_subtypes_fn:
lines.append("list(possible_subtypes_fn())")
with lines.indent(f"for variant in {variants}:"):
with lines.indent("try:"):
if spec.builder.is_nailed:
Expand Down
3 changes: 2 additions & 1 deletion mashumaro/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import decimal
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, Type, Union
from typing import Any, Iterable, Optional, Type, Union

from typing_extensions import Literal

Expand Down Expand Up @@ -102,6 +102,7 @@ class Discriminator:
include_supertypes: bool = False
include_subtypes: bool = False
variant_tagger_fn: Optional[Callable[[Any], Any]] = None
possible_subtypes_fn: Optional[Callable[[], Iterable[Type]]] = None

def __post_init__(self) -> None:
if not self.include_supertypes and not self.include_subtypes:
Expand Down
49 changes: 49 additions & 0 deletions tests/test_discriminated_unions/test_parent_via_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,3 +811,52 @@ def test_by_subtypes_with_custom_variant_tagger_and_multiple_tags():
VariantWithMultipleTags.from_dict({"type": "unknown"})
with pytest.raises(SuitableVariantNotFoundError):
decode({"type": "unknown"}, _VariantWithMultipleTags)


def test_cross_module_discriminator_by_field_with_subtypes(tmp_path):
base_code = """\
from dataclasses import dataclass
from mashumaro import DataClassDictMixin
from mashumaro.config import BaseConfig
from mashumaro.types import Discriminator

def _get_possible_subtypes():
from _sub_mod import SubType
yield SubType

@dataclass
class MyBase(DataClassDictMixin):
class Config(BaseConfig):
discriminator = Discriminator(
field="config_type",
include_subtypes=True,
possible_subtypes_fn=_get_possible_subtypes,
)
"""
sub_code = """\
from dataclasses import dataclass
from _base_mod import MyBase

@dataclass
class SubType(MyBase):
config_type = "SubType"
"""
base_file = tmp_path / "_base_mod.py"
sub_file = tmp_path / "_sub_mod.py"
base_file.write_text(base_code)
sub_file.write_text(sub_code)

import sys

original_path = sys.path.copy()
sys.path.insert(0, str(tmp_path))
try:
import _base_mod

result = _base_mod.MyBase.from_dict({"config_type": "SubType"})
assert type(result).__name__ == "SubType"

finally:
sys.path[:] = original_path
sys.modules.pop("_base_mod", None)
sys.modules.pop("_sub_mod", None)
Loading