Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openff/interchange/_tests/unit_tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Person(_BaseModel):


class Roster(_BaseModel):
people: dict[str, Person] = Field(dict())
people: dict[str, Person] = Field(default_factory=dict)

foo: _Quantity = Field()

Expand Down
2 changes: 1 addition & 1 deletion openff/interchange/components/interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Interchange(_BaseModel):
"""

collections: _AnnotatedCollections = Field(dict())
collections: _AnnotatedCollections = Field(default_factory=dict)
topology: _AnnotatedTopology
mdconfig: MDConfig | None = Field(None)
box: _BoxQuantity | None = Field(None) # Needs shape/OpenMM validation
Expand Down
15 changes: 14 additions & 1 deletion openff/interchange/components/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
class Potential(_BaseModel):
"""Base class for storing applied parameters."""

parameters: dict[str, _Quantity] = Field(dict())
parameters: dict[str, _Quantity] = Field(default_factory=dict)
map_key: int | None = None

def __hash__(self) -> int:
Expand Down Expand Up @@ -439,7 +439,20 @@ def validate_collections(
raise ValueError(f"Validation mode {info.mode} not implemented.")


def serialize_collections(v: Any, handler: Any, info: Any) -> dict:
"""Serialize collections using each collection's actual type schema.

Without this, pydantic uses the declared Collection base class schema and
drops subclass-specific fields (e.g. scale_14, cutoff, periodic_potential).
"""
if info.mode == "json":
return {name: json.loads(collection.model_dump_json()) for name, collection in v.items()}
else:
raise NotImplementedError(f"Serialization mode {info.mode} not implemented.")


_AnnotatedCollections = Annotated[
dict[str, Collection],
WrapValidator(validate_collections),
WrapSerializer(serialize_collections),
]
4 changes: 2 additions & 2 deletions openff/interchange/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class _BaseModel(BaseModel):
)

def model_dump(self, **kwargs) -> dict[str, Any]:
return super().model_dump(serialize_as_any=True, **kwargs)
return super().model_dump(**kwargs)

def model_dump_json(self, **kwargs) -> str:
return super().model_dump_json(serialize_as_any=True, **kwargs)
return super().model_dump_json(**kwargs)
Loading