From 9481670731636bc8bcca0753a2abf431b4ab8423 Mon Sep 17 00:00:00 2001 From: "Kyle D. Kavanagh" Date: Thu, 12 Mar 2026 21:24:04 -0500 Subject: [PATCH 1/2] Add ability to enumerate possible subtypes to deserialize to. Fixes #281 --- README.md | 1 + mashumaro/core/meta/code/builder.py | 1 + mashumaro/core/meta/types/unpack.py | 8 +++ mashumaro/types.py | 3 +- .../test_parent_via_config.py | 49 +++++++++++++++++++ 5 files changed, 61 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0c3a65bd..81b55eb9 100644 --- a/README.md +++ b/README.md @@ -2114,6 +2114,7 @@ It has the following parameters that affects class selection rules: * `include_supertypes` — allow to deserialize superclasses * `variant_tagger_fn` — a custom function used to generate tag values associated with a variant +* `possible_subtypes_fn` — a custom function which enumerates possible subtypes to deserialize to. Used in cases where cyclic-imports prevent importing subtypes at the module-level. By default, each variant that you want to discriminate by tags should have a class-level attribute containing an associated tag value. This attribute should diff --git a/mashumaro/core/meta/code/builder.py b/mashumaro/core/meta/code/builder.py index 1807dc32..0b126f7a 100644 --- a/mashumaro/core/meta/code/builder.py +++ b/mashumaro/core/meta/code/builder.py @@ -391,6 +391,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None: field=discr.field, include_subtypes=discr.include_subtypes, variant_tagger_fn=discr.variant_tagger_fn, + possible_subtypes_fn=discr.possible_subtypes_fn, ) self.add_type_modules(self.cls) method = SubtypeUnpackerBuilder(discr).build( diff --git a/mashumaro/core/meta/types/unpack.py b/mashumaro/core/meta/types/unpack.py index 2c67cb9b..92f243f2 100644 --- a/mashumaro/core/meta/types/unpack.py +++ b/mashumaro/core/meta/types/unpack.py @@ -382,6 +382,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: variant_method_call = self._get_variant_method_call( variant_method_name, spec ) + if discriminator.possible_subtypes_fn: + spec.builder.ensure_object_imported( + discriminator.possible_subtypes_fn, "possible_subtypes_fn" + ) if discriminator.variant_tagger_fn: spec.builder.ensure_object_imported( discriminator.variant_tagger_fn, "variant_tagger_fn" @@ -420,6 +424,8 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: ) with lines.indent("except (KeyError, AttributeError):"): lines.append(f"variants_map = {variants_map}") + if discriminator.possible_subtypes_fn: + lines.append("list(possible_subtypes_fn())") with lines.indent(f"for variant in {variants}:"): if discriminator.variant_tagger_fn is not None: self._add_register_variant_tags( @@ -454,6 +460,8 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: "discriminator) from None" ) else: + if discriminator.possible_subtypes_fn: + lines.append("list(possible_subtypes_fn())") with lines.indent(f"for variant in {variants}:"): with lines.indent("try:"): if spec.builder.is_nailed: diff --git a/mashumaro/types.py b/mashumaro/types.py index 35091742..14dc1cbd 100644 --- a/mashumaro/types.py +++ b/mashumaro/types.py @@ -1,7 +1,7 @@ import decimal from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Optional, Type, Union +from typing import Any, Iterable, Optional, Type, Union from typing_extensions import Literal @@ -102,6 +102,7 @@ class Discriminator: include_supertypes: bool = False include_subtypes: bool = False variant_tagger_fn: Optional[Callable[[Any], Any]] = None + possible_subtypes_fn: Optional[Callable[[], Iterable[Type]]] = None def __post_init__(self) -> None: if not self.include_supertypes and not self.include_subtypes: diff --git a/tests/test_discriminated_unions/test_parent_via_config.py b/tests/test_discriminated_unions/test_parent_via_config.py index d73f4db7..25928a1e 100644 --- a/tests/test_discriminated_unions/test_parent_via_config.py +++ b/tests/test_discriminated_unions/test_parent_via_config.py @@ -811,3 +811,52 @@ def test_by_subtypes_with_custom_variant_tagger_and_multiple_tags(): VariantWithMultipleTags.from_dict({"type": "unknown"}) with pytest.raises(SuitableVariantNotFoundError): decode({"type": "unknown"}, _VariantWithMultipleTags) + + +def test_cross_module_discriminator_by_field_with_subtypes(tmp_path): + base_code = """\ +from dataclasses import dataclass +from mashumaro import DataClassDictMixin +from mashumaro.config import BaseConfig +from mashumaro.types import Discriminator + +def _get_possible_subtypes(): + from _sub_mod import SubType + yield SubType + +@dataclass +class MyBase(DataClassDictMixin): + class Config(BaseConfig): + discriminator = Discriminator( + field="config_type", + include_subtypes=True, + get_possible_subtypes=_get_possible_subtypes, + ) +""" + sub_code = """\ +from dataclasses import dataclass +from _base_mod import MyBase + +@dataclass +class SubType(MyBase): + config_type = "SubType" +""" + base_file = tmp_path / "_base_mod.py" + sub_file = tmp_path / "_sub_mod.py" + base_file.write_text(base_code) + sub_file.write_text(sub_code) + + import sys + + original_path = sys.path.copy() + sys.path.insert(0, str(tmp_path)) + try: + import _base_mod + + result = _base_mod.MyBase.from_dict({"config_type": "SubType"}) + assert type(result).__name__ == "SubType" + + finally: + sys.path[:] = original_path + sys.modules.pop("_base_mod", None) + sys.modules.pop("_sub_mod", None) From 6a1ba9676442a7227e0b2d3165d8f1ddf0b7d06f Mon Sep 17 00:00:00 2001 From: kdkavanagh Date: Sun, 29 Mar 2026 19:53:13 -0500 Subject: [PATCH 2/2] Fix missing rename --- tests/test_discriminated_unions/test_parent_via_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_discriminated_unions/test_parent_via_config.py b/tests/test_discriminated_unions/test_parent_via_config.py index 25928a1e..f85c9dab 100644 --- a/tests/test_discriminated_unions/test_parent_via_config.py +++ b/tests/test_discriminated_unions/test_parent_via_config.py @@ -830,7 +830,7 @@ class Config(BaseConfig): discriminator = Discriminator( field="config_type", include_subtypes=True, - get_possible_subtypes=_get_possible_subtypes, + possible_subtypes_fn=_get_possible_subtypes, ) """ sub_code = """\