Skip to content

Commit

Permalink
Memory usage optimization via reuse of SchemaValidator and `SchemaS…
Browse files Browse the repository at this point in the history
…erializer` (#1616)
  • Loading branch information
sydney-runkle authored Feb 5, 2025
1 parent 3707dcd commit 164b9ff
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub(crate) mod prebuilt;
pub(crate) mod union;
43 changes: 43 additions & 0 deletions src/common/prebuilt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyType};

use crate::tools::SchemaDict;

/// Looks up a prebuilt validator / serializer stored directly on the class referenced by `schema`.
///
/// # Arguments
/// * `type_` - value of the schema's `type` key; only `"model"`, `"typed-dict"` and `"dataclass"`
///   schemas are eligible for reuse.
/// * `schema` - the core schema dict; its `cls` entry is the class whose `__dict__` is inspected.
/// * `prebuilt_attr_name` - key to read from the class `__dict__`
///   (callers pass e.g. `"__pydantic_validator__"`).
/// * `extractor` - converts the Python object found under `prebuilt_attr_name` into `T`.
///
/// # Returns
/// * `Ok(None)` when the schema type is not eligible, when a dataclass schema carries a
///   `generic_origin`, or when the class `__dict__` does not hold a truthy `__pydantic_complete__`.
/// * `Ok(Some(T))` when the attribute is present and `extractor` succeeds.
///
/// # Errors
/// Propagates the Python error when `cls` cannot be read from `schema`, when `__dict__` cannot be
/// fetched, when `prebuilt_attr_name` is absent from the class `__dict__`, or when `extractor` fails.
pub fn get_prebuilt<T>(
    type_: &str,
    schema: &Bound<'_, PyDict>,
    prebuilt_attr_name: &str,
    extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult<T>,
) -> PyResult<Option<T>> {
    let py = schema.py();

    // we can only use prebuilt validators / serializers from models, typed dicts, and dataclasses
    // however, we don't want to use a prebuilt structure from dataclasses if we have a generic_origin
    // because the validator / serializer is cached on the unparametrized dataclass
    //
    // NOTE: "dataclass" has to be part of the allowed set below, otherwise every dataclass schema is
    // rejected by the first clause and the `generic_origin` check can never be reached.
    if !matches!(type_, "model" | "typed-dict" | "dataclass")
        || (type_ == "dataclass" && schema.contains(intern!(py, "generic_origin"))?)
    {
        return Ok(None);
    }

    let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;

    // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
    // because we don't want to fetch prebuilt validators from parent classes.
    // We don't downcast here because __dict__ on a class is a readonly mappingproxy,
    // so we can just leave it as is and do get_item checks.
    let class_dict = class.getattr(intern!(py, "__dict__"))?;

    // A missing key or a value that is not a bool both count as "not complete".
    let is_complete: bool = class_dict
        .get_item(intern!(py, "__pydantic_complete__"))
        .is_ok_and(|b| b.extract().unwrap_or(false));

    if !is_complete {
        return Ok(None);
    }

    // Retrieve the prebuilt validator / serializer if available
    let prebuilt: Bound<'_, PyAny> = class_dict.get_item(prebuilt_attr_name)?;
    extractor(prebuilt).map(Some)
}
1 change: 1 addition & 0 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod fields;
mod filter;
mod infer;
mod ob_type;
mod prebuilt;
pub mod ser;
mod shared;
mod type_serializers;
Expand Down
68 changes: 68 additions & 0 deletions src/serializers/prebuilt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use std::borrow::Cow;

use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::common::prebuilt::get_prebuilt;
use crate::SchemaSerializer;

use super::extra::Extra;
use super::shared::{CombinedSerializer, TypeSerializer};

/// Serializer that forwards all work to a `SchemaSerializer` that already exists on a class,
/// instead of building a fresh serializer from the schema.
#[derive(Debug)]
pub struct PrebuiltSerializer {
/// Handle to the existing `SchemaSerializer` Python object that every call is delegated to.
schema_serializer: Py<SchemaSerializer>,
}

impl PrebuiltSerializer {
    /// Tries to build a `CombinedSerializer` that wraps the `__pydantic_serializer__` stored on the
    /// class referenced by `schema`.
    ///
    /// The lookup itself (eligibility checks included) is performed by `get_prebuilt`; this function
    /// only supplies the attribute name and the conversion of the found object into `Self`.
    pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedSerializer>> {
        get_prebuilt(type_, schema, "__pydantic_serializer__", |candidate| {
            // fails with a Python error when the stored object is not a `SchemaSerializer`
            let schema_serializer: Py<SchemaSerializer> = candidate.extract()?;
            Ok(Self { schema_serializer }.into())
        })
    }
}

// NOTE(review): macro defined elsewhere in the crate; presumably generates the GC-traversal impl
// that visits the `schema_serializer` field — confirm against the macro definition.
impl_py_gc_traverse!(PrebuiltSerializer { schema_serializer });

// Every method below is a pure pass-through to the `serializer` held by the wrapped
// `SchemaSerializer`; no state is kept or altered by `PrebuiltSerializer` itself.
impl TypeSerializer for PrebuiltSerializer {
    /// Serializes `value` to a Python object using the wrapped serializer.
    fn to_python(
        &self,
        value: &Bound<'_, PyAny>,
        include: Option<&Bound<'_, PyAny>>,
        exclude: Option<&Bound<'_, PyAny>>,
        extra: &Extra,
    ) -> PyResult<PyObject> {
        let delegate = &self.schema_serializer.get().serializer;
        delegate.to_python(value, include, exclude, extra)
    }

    /// Produces the JSON object key for `key` using the wrapped serializer.
    fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
        let delegate = &self.schema_serializer.get().serializer;
        delegate.json_key(key, extra)
    }

    /// Streams `value` into the given serde `serializer` using the wrapped serializer.
    fn serde_serialize<S: serde::ser::Serializer>(
        &self,
        value: &Bound<'_, PyAny>,
        serializer: S,
        include: Option<&Bound<'_, PyAny>>,
        exclude: Option<&Bound<'_, PyAny>>,
        extra: &Extra,
    ) -> Result<S::Ok, S::Error> {
        let delegate = &self.schema_serializer.get().serializer;
        delegate.serde_serialize(value, serializer, include, exclude, extra)
    }

    /// Reports the name of the wrapped serializer rather than a name of its own.
    fn get_name(&self) -> &str {
        let delegate = &self.schema_serializer.get().serializer;
        delegate.get_name()
    }

    /// Defers the lax-retry decision to the wrapped serializer.
    fn retry_with_lax_check(&self) -> bool {
        let delegate = &self.schema_serializer.get().serializer;
        delegate.retry_with_lax_check()
    }
}
12 changes: 11 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ combined_serializer! {
Function: super::type_serializers::function::FunctionPlainSerializer;
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
Fields: super::fields::GeneralFieldsSerializer;
// prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
Prebuilt: super::prebuilt::PrebuiltSerializer;
}
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
Expand Down Expand Up @@ -195,7 +197,14 @@ impl CombinedSerializer {
}

let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
Self::find_serializer(type_.to_str()?, schema, config, definitions)
let type_ = type_.to_str()?;

// if we have a SchemaSerializer on the type already, use it
if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) {
return Ok(prebuilt_serializer);
}

Self::find_serializer(type_, schema, config, definitions)
}
}

Expand All @@ -219,6 +228,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
Expand Down
12 changes: 11 additions & 1 deletion src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ mod model;
mod model_fields;
mod none;
mod nullable;
mod prebuilt;
mod set;
mod string;
mod time;
Expand Down Expand Up @@ -515,8 +516,15 @@ pub fn build_validator(
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let dict = schema.downcast::<PyDict>()?;
let type_: Bound<'_, PyString> = dict.get_as_req(intern!(schema.py(), "type"))?;
let py = schema.py();
let type_: Bound<'_, PyString> = dict.get_as_req(intern!(py, "type"))?;
let type_ = type_.to_str()?;

// if we have a SchemaValidator on the type already, use it
if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) {
return Ok(prebuilt_validator);
}

validator_match!(
type_,
dict,
Expand Down Expand Up @@ -763,6 +771,8 @@ pub enum CombinedValidator {
// input dependent
JsonOrPython(json_or_python::JsonOrPython),
Complex(complex::ComplexValidator),
// uses a reference to an existing SchemaValidator to reduce memory usage
Prebuilt(prebuilt::PrebuiltValidator),
}

/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
Expand Down
41 changes: 41 additions & 0 deletions src/validators/prebuilt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::common::prebuilt::get_prebuilt;
use crate::errors::ValResult;
use crate::input::Input;

use super::ValidationState;
use super::{CombinedValidator, SchemaValidator, Validator};

/// Validator that forwards all work to a `SchemaValidator` that already exists on a class,
/// instead of building a fresh validator from the schema.
#[derive(Debug)]
pub struct PrebuiltValidator {
/// Handle to the existing `SchemaValidator` Python object that every call is delegated to.
schema_validator: Py<SchemaValidator>,
}

impl PrebuiltValidator {
    /// Tries to build a `CombinedValidator` that wraps the `__pydantic_validator__` stored on the
    /// class referenced by `schema`.
    ///
    /// The lookup itself (eligibility checks included) is performed by `get_prebuilt`; this function
    /// only supplies the attribute name and the conversion of the found object into `Self`.
    pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedValidator>> {
        get_prebuilt(type_, schema, "__pydantic_validator__", |candidate| {
            // fails with a Python error when the stored object is not a `SchemaValidator`
            let schema_validator: Py<SchemaValidator> = candidate.extract()?;
            Ok(Self { schema_validator }.into())
        })
    }
}

// NOTE(review): macro defined elsewhere in the crate; presumably generates the GC-traversal impl
// that visits the `schema_validator` field — confirm against the macro definition.
impl_py_gc_traverse!(PrebuiltValidator { schema_validator });

// Both methods are pure pass-throughs to the `validator` held by the wrapped `SchemaValidator`.
impl Validator for PrebuiltValidator {
    /// Validates `input` with the wrapped validator, sharing the caller's `state`.
    fn validate<'py>(
        &self,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<PyObject> {
        let delegate = &self.schema_validator.get().validator;
        delegate.validate(py, input, state)
    }

    /// Reports the name of the wrapped validator rather than a name of its own.
    fn get_name(&self) -> &str {
        let delegate = &self.schema_validator.get().validator;
        delegate.get_name()
    }
}
48 changes: 48 additions & 0 deletions tests/test_prebuilt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pydantic_core import SchemaSerializer, SchemaValidator, core_schema


def test_prebuilt_val_and_ser_used() -> None:
    """The validator / serializer already attached to a nested class must be reused.

    The outer schema deliberately describes `InnerModel.x` as a str; the test only passes
    if the structures attached to `InnerModel` (built from the int schema) are picked up.
    """

    class InnerModel:
        x: int

    int_fields = core_schema.model_fields_schema(
        {'x': core_schema.model_field(schema=core_schema.int_schema())},
    )
    correct_inner_schema = core_schema.model_schema(InnerModel, schema=int_fields)

    # build both structures first, then attach them and mark the class as complete
    prebuilt_validator = SchemaValidator(correct_inner_schema)
    prebuilt_serializer = SchemaSerializer(correct_inner_schema)
    InnerModel.__pydantic_complete__ = True  # pyright: ignore[reportAttributeAccessIssue]
    InnerModel.__pydantic_validator__ = prebuilt_validator  # pyright: ignore[reportAttributeAccessIssue]
    InnerModel.__pydantic_serializer__ = prebuilt_serializer  # pyright: ignore[reportAttributeAccessIssue]

    class OuterModel:
        inner: InnerModel

    # note, we use str schema here even though that's incorrect
    # in order to verify that the prebuilt validator is used
    # off of InnerModel with the correct int schema, not this str schema
    wrong_inner_fields = core_schema.model_fields_schema(
        {'x': core_schema.model_field(schema=core_schema.str_schema())},
    )
    wrong_inner_schema = core_schema.model_schema(InnerModel, schema=wrong_inner_fields)

    outer_fields = core_schema.model_fields_schema(
        {'inner': core_schema.model_field(schema=wrong_inner_schema)}
    )
    outer_schema = core_schema.model_schema(OuterModel, schema=outer_fields)

    validator = SchemaValidator(outer_schema)
    serializer = SchemaSerializer(outer_schema)

    validated = validator.validate_python({'inner': {'x': 1}})
    assert validated.inner.x == 1
    assert serializer.to_python(validated) == {'inner': {'x': 1}}

0 comments on commit 164b9ff

Please sign in to comment.