Skip to content

Commit

Permalink
Merge pull request #145 from argyle-engineering/annotated-fields
Browse files Browse the repository at this point in the history
Support some annotated constraints on numerical fields
  • Loading branch information
povilasb authored Oct 23, 2024
2 parents 4279c3c + eca7603 commit 1293a4a
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 25 deletions.
76 changes: 54 additions & 22 deletions pydantic2zod/_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@
from typing import Callable

from .model import (
AnnotatedType,
AnyType,
BuiltinType,
ClassDecl,
ClassField,
GenericType,
LiteralType,
PrimitiveType,
PydanticField,
PyDict,
PyFloat,
PyInteger,
PyList,
PyName,
PyNone,
PyString,
PyType,
PyValue,
TupleType,
UnionType,
UserDefinedType,
Expand Down Expand Up @@ -128,38 +132,63 @@ def _class_field_to_zod(field: ClassField, code: "Lines") -> None:
_comment_to_ts(comment, code)

code.add(f"{field.name}: ")
_class_field_type_to_zod(field.type, code)
_class_field_type_to_zod(field.type, None, code)

if default := field.default_value:
code.add(".default(", inline=True)
match default:
case PyString(value=value):
code.add(f'"{value}"', inline=True)
case PyInteger(value=value):
code.add(value, inline=True)
case PyNone():
code.add("null", inline=True)
case PyName(value=name):
code.add(name, inline=True)
case PyDict():
code.add("{}", inline=True)
case PyList():
code.add("[]", inline=True)
case other:
assert False, f"Unsupported value type: '{other}'"
_value_to_zod(default, code)
code.add(")", inline=True)


def _class_field_type_to_zod(field_type: PyType, code: "Lines") -> None:
def _value_to_zod(pyval: PyValue, code: "Lines") -> None:
match pyval:
case PyString(value=value):
code.add(f'"{value}"', inline=True)
case PyInteger(value=value) | PyFloat(value=value):
code.add(value, inline=True)
case PyNone():
code.add("null", inline=True)
case PyName(value=name):
code.add(name, inline=True)
case PyDict():
code.add("{}", inline=True)
case PyList():
code.add("[]", inline=True)
case other:
assert False, f"Unsupported value type: '{other}'"


def _class_field_type_to_zod(
field_type: PyType, type_constraints: PydanticField | None, code: "Lines"
) -> None:
match field_type:
case BuiltinType(name=type_name) | PrimitiveType(name=type_name):
match type_name:
case "str":
code.add("z.string()", inline=True)
case "int":
code.add("z.number().int()", inline=True)
case "float":

case "int" | "float":
code.add("z.number()", inline=True)
if type_name == "int":
code.add(".int()", inline=True)
if type_constraints:
if type_constraints.gt is not None:
code.add(".gt(", inline=True)
_value_to_zod(type_constraints.gt, code)
code.add(")", inline=True)
if type_constraints.ge is not None:
code.add(".gte(", inline=True)
_value_to_zod(type_constraints.ge, code)
code.add(")", inline=True)
if type_constraints.lt is not None:
code.add(".lt(", inline=True)
_value_to_zod(type_constraints.lt, code)
code.add(")", inline=True)
if type_constraints.le is not None:
code.add(".lte(", inline=True)
_value_to_zod(type_constraints.le, code)
code.add(")", inline=True)

case "None":
code.add("z.null()", inline=True)
case "bool":
Expand All @@ -180,7 +209,7 @@ def _class_field_type_to_zod(field_type: PyType, code: "Lines") -> None:
with code as indent_code:
code.add("")
for i, tp in enumerate(types):
_class_field_type_to_zod(tp, indent_code)
_class_field_type_to_zod(tp, type_constraints, indent_code)
code.add(",", inline=True)
if i < len(types) - 1:
code.add("")
Expand All @@ -198,7 +227,7 @@ def _class_field_type_to_zod(field_type: PyType, code: "Lines") -> None:
assert False, f"Unsupported generic type: '{other}'"

for i, tv in enumerate(type_vars):
_class_field_type_to_zod(tv, code)
_class_field_type_to_zod(tv, type_constraints, code)
if i < len(type_vars) - 1:
code.add(", ", inline=True)
code.add(")", inline=True)
Expand All @@ -215,6 +244,9 @@ def _class_field_type_to_zod(field_type: PyType, code: "Lines") -> None:
case AnyType():
code.add("z.any()", inline=True)

case AnnotatedType(type_=type_, metadata=metadata):
_class_field_type_to_zod(type_, metadata, code)

case other:
assert False, f"Unsupported field type: '{other}'"

Expand Down
53 changes: 52 additions & 1 deletion pydantic2zod/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import Self

from .model import (
AnnotatedType,
AnyType,
BuiltinType,
ClassDecl,
Expand All @@ -23,7 +24,9 @@
Import,
LiteralType,
PrimitiveType,
PydanticField,
PyDict,
PyFloat,
PyInteger,
PyList,
PyNone,
Expand Down Expand Up @@ -448,7 +451,9 @@ def _get_user_defined_types(tp: PyType) -> list[str]:

def _parse_generic_type(
node: cst.Subscript,
) -> GenericType | LiteralType | UnionType | TupleType | UserDefinedType:
) -> (
GenericType | LiteralType | UnionType | TupleType | UserDefinedType | AnnotatedType
):
"""Try to parse a generic type.
Fall back to `UserDefinedType` when don't know how.
"""
Expand All @@ -468,6 +473,8 @@ def _parse_generic_type(
)
case "tuple" | "Tuple":
return TupleType(types=_parse_types_list(node))
case "Annotated":
return _parse_annotated(node)
case other:
_logger.warning("Generic type not supported: '%s'", other)
return UserDefinedType(name=other)
Expand All @@ -489,6 +496,18 @@ def _parse_literal(node: cst.Subscript) -> LiteralType | UnionType:
return UnionType(types=[LiteralType(value=v) for v in literal_values])


def _parse_annotated(node: cst.Subscript) -> AnnotatedType:
assert cst.ensure_type(node.value, cst.Name).value == "Annotated"
args = list(node.slice)
if len(args) != 2:
_logger.warning("Annotated type should have exactly two arguments")
return AnnotatedType(type_=AnyType(), metadata=None)

type_ = _extract_type(cst.ensure_type(args[0].slice, cst.Index).value)
metadata = _parse_field_constraints(cst.ensure_type(args[1].slice, cst.Index).value)
return AnnotatedType(type_=type_, metadata=metadata)


def _parse_types_list(node: cst.Subscript) -> list[PyType]:
types = list[PyType]()
for element in node.slice:
Expand Down Expand Up @@ -546,6 +565,8 @@ def _parse_value(node: cst.BaseExpression) -> PyValue:
return PyList()
case cst.Integer(value=value):
return PyInteger(value=value)
case cst.Float(value=value):
return PyFloat(value=value)
case cst.Call():
if empty_list := _parse_value_from_call(node):
return empty_list
Expand All @@ -567,3 +588,33 @@ def _parse_value_from_call(node: cst.Call) -> PyValue | None:
):
return PyList()
return None


def _parse_field_constraints(node: cst.BaseExpression) -> PydanticField | None:
if not m.matches(
node,
m.Call(func=m.Name("Field")),
):
return None
node = cst.ensure_type(node, cst.Call)

field_decl = PydanticField()

for arg in node.args:
if not (arg_name := arg.keyword):
continue

arg_value = _parse_value(arg.value)
match arg_name.value:
case "gt":
field_decl.gt = arg_value
case "ge":
field_decl.ge = arg_value
case "lt":
field_decl.lt = arg_value
case "le":
field_decl.le = arg_value
case _:
...

return field_decl
23 changes: 23 additions & 0 deletions pydantic2zod/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ class PyInteger(PyValue):
value: str


@dataclass
class PyFloat(PyValue):
value: str


@dataclass
class BuiltinType(PyType):
name: Literal[
Expand Down Expand Up @@ -139,3 +144,21 @@ class AnyType(PyType):
"""Represents `typing.Any`."""

...


@dataclass
class PydanticField:
"""Some constraints from `pydantic.Field()` declaration."""

gt: PyValue | None = None
ge: PyValue | None = None
lt: PyValue | None = None
le: PyValue | None = None


@dataclass
class AnnotatedType(PyType):
"""Represents `typing.Annotated`."""

type_: PyType
metadata: PydanticField | None = None
9 changes: 9 additions & 0 deletions tests/fixtures/annotated_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Annotated

from pydantic import BaseModel, Field


class Employee(BaseModel):
age: Annotated[int, Field(ge=18, le=67)]
level: Annotated[int, Field(1, gt=0, lt=6)]
salary: Annotated[float, Field(gt=1000, lt=10000)]
15 changes: 15 additions & 0 deletions tests/snapshots/snap_test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@

snapshots = Snapshot()

snapshots['test_annotated_fields 1'] = '''
/**
* NOTE: automatically generated by the pydantic2zod compiler.
*/
import { z } from "zod";
export const Employee = z.object({
age: z.number().int().gte(18).lte(67),
level: z.number().int().gt(0).lt(6),
salary: z.number().gt(1000).lt(10000),
}).strict();
export type EmployeeType = z.infer<typeof Employee>;
'''

snapshots['test_builtin_types 1'] = '''
/**
* NOTE: automatically generated by the pydantic2zod compiler.
Expand Down
12 changes: 10 additions & 2 deletions tests/test_compile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pyright: reportPrivateUsage=false

import pydantic
import pytest
from snapshottest.module import SnapshotTest

from pydantic2zod._compiler import Compiler
Expand Down Expand Up @@ -41,3 +41,11 @@ def test_class_variables_are_skipped(snapshot: SnapshotTest):
def test_builtin_types(snapshot: SnapshotTest):
out_src = Compiler().parse("tests.fixtures.builtin_types").to_zod()
snapshot.assert_match(out_src)


@pytest.mark.skipif(
not pydantic.VERSION.startswith("2"), reason="Only works with pydantic v2"
)
def test_annotated_fields(snapshot: SnapshotTest):
out_src = Compiler().parse("tests.fixtures.annotated_fields").to_zod()
snapshot.assert_match(out_src)

0 comments on commit 1293a4a

Please sign in to comment.