Skip to content

Commit

Permalink
Fixes and improvements to Nested typing (#2721)
Browse files Browse the repository at this point in the history
* Fixes and improvements to Nested typing

* Remove SchemaABC
  • Loading branch information
sloria authored Jan 4, 2025
1 parent f4ca03f commit c188cdb
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 49 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Features:

- Typing: Replace type comments with inline typings (:pr:`2718`).

Bug fixes:

- Typing: Fix type hint for ``nested`` parameter of `Nested <marshmallow.fields.Nested>`.


3.23.3 (2025-01-03)
*******************
Expand Down
28 changes: 0 additions & 28 deletions src/marshmallow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,3 @@ def _serialize(self, value, attr, obj, **kwargs):
@abstractmethod
def _deserialize(self, value, attr, data, **kwargs):
pass


class SchemaABC(ABC):
"""Abstract base class from which all Schemas inherit."""

@abstractmethod
def dump(self, obj, *, many: bool | None = None):
pass

@abstractmethod
def dumps(self, obj, *, many: bool | None = None):
pass

@abstractmethod
def load(self, data, *, many: bool | None = None, partial=None, unknown=None):
pass

@abstractmethod
def loads(
self,
json_data,
*,
many: bool | None = None,
partial=None,
unknown=None,
**kwargs,
):
pass
10 changes: 10 additions & 0 deletions src/marshmallow/class_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ class MyClass:
return None


@typing.overload
def get_class(classname: str, all: typing.Literal[False]) -> SchemaType: ...


@typing.overload
def get_class(
classname: str, all: typing.Literal[True]
) -> list[SchemaType] | SchemaType: ...


def get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:
"""Retrieve a class from the registry.
Expand Down
38 changes: 18 additions & 20 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from enum import Enum as EnumType

from marshmallow import class_registry, types, utils, validate
from marshmallow.base import FieldABC, SchemaABC
from marshmallow.base import FieldABC
from marshmallow.exceptions import (
FieldInstanceResolutionError,
StringNotCollectionError,
Expand All @@ -32,7 +32,7 @@
from marshmallow.warnings import RemovedInMarshmallow4Warning

if typing.TYPE_CHECKING:
from marshmallow.schema import SchemaMeta
from marshmallow.schema import Schema, SchemaMeta


__all__ = [
Expand Down Expand Up @@ -519,13 +519,11 @@ class ParentSchema(Schema):
def __init__(
self,
nested: (
SchemaABC
Schema
| SchemaMeta
| str
| dict[str, Field | type[Field]]
| typing.Callable[
[], SchemaABC | SchemaMeta | dict[str, Field | type[Field]]
]
| dict[str, Field]
| typing.Callable[[], Schema | SchemaMeta | dict[str, Field]]
),
*,
dump_default: typing.Any = missing_,
Expand Down Expand Up @@ -555,11 +553,11 @@ def __init__(
self.exclude = exclude
self.many = many
self.unknown = unknown
self._schema = None # Cached Schema instance
self._schema: Schema | None = None # Cached Schema instance
super().__init__(default=default, dump_default=dump_default, **kwargs)

@property
def schema(self):
def schema(self) -> Schema:
"""The nested Schema object.
.. versionchanged:: 1.0.0
Expand All @@ -571,18 +569,18 @@ def schema(self):
if callable(self.nested) and not isinstance(self.nested, type):
nested = self.nested()
else:
nested = self.nested
if isinstance(nested, dict):
# defer the import of `marshmallow.schema` to avoid circular imports
from marshmallow.schema import Schema
nested = typing.cast("Schema", self.nested)
# defer the import of `marshmallow.schema` to avoid circular imports
from marshmallow.schema import Schema

if isinstance(nested, dict):
nested = Schema.from_dict(nested)

if isinstance(nested, SchemaABC):
if isinstance(nested, Schema):
self._schema = copy.copy(nested)
self._schema.context.update(context)
# Respect only and exclude passed from parent and re-initialize fields
set_class = self._schema.set_class
set_class = typing.cast(type[set], self._schema.set_class)
if self.only is not None:
if self._schema.only is not None:
original = self._schema.only
Expand All @@ -594,17 +592,17 @@ def schema(self):
self._schema.exclude = set_class(self.exclude) | set_class(original)
self._schema._init_fields()
else:
if isinstance(nested, type) and issubclass(nested, SchemaABC):
schema_class = nested
if isinstance(nested, type) and issubclass(nested, Schema):
schema_class: type[Schema] = nested
elif not isinstance(nested, (str, bytes)):
raise ValueError(
"`Nested` fields must be passed a "
f"`Schema`, not {nested.__class__}."
)
elif nested == "self":
schema_class = self.root.__class__
schema_class = typing.cast(Schema, self.root).__class__
else:
schema_class = class_registry.get_class(nested)
schema_class = class_registry.get_class(nested, all=False)
self._schema = schema_class(
many=self.many,
only=self.only,
Expand Down Expand Up @@ -688,7 +686,7 @@ class AlbumSchema(Schema):

def __init__(
self,
nested: SchemaABC | SchemaMeta | str | typing.Callable[[], SchemaABC],
nested: Schema | SchemaMeta | str | typing.Callable[[], Schema],
field_name: str,
**kwargs,
):
Expand Down
2 changes: 1 addition & 1 deletion src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __init__(self, meta, ordered: bool = False):
self.many = getattr(meta, "many", False)


class Schema(base.SchemaABC, metaclass=SchemaMeta):
class Schema(metaclass=SchemaMeta):
"""Base schema class with which to define custom schemas.
Example usage:
Expand Down

0 comments on commit c188cdb

Please sign in to comment.