Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve import aliases #7

Merged
merged 8 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pydantic2zod/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,12 @@ def _class_field_type_to_zod(field_type: PyType, code: "Lines") -> None:
code.add(", ", inline=True)
code.add(")", inline=True)

case UserDefinedType(name=type_):
if type_ == "UUID":
case UserDefinedType(name=type_name):
if type_name == "uuid.UUID":
code.add("z.string().uuid()", inline=True)
else:
code.add(type_, inline=True)
type_name = type_name.split(".")[-1]
code.add(type_name, inline=True)

case other:
assert False, f"Unsupported field type: '{other}'"
Expand Down
14 changes: 14 additions & 0 deletions pydantic2zod/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ class PyValue:
...


@dataclass
class Import:
from_module: str
"pkg.module1"

name: str
"ClassName"

alias: str | None = None
"""`import module.Class as OtherClass`"""


@dataclass
class ClassField:
name: str
Expand All @@ -28,6 +40,8 @@ class ClassField:
@dataclass
class ClassDecl:
name: str
full_path: str = ""
"""pkg1.module.ClassName"""
fields: list[ClassField] = field(default_factory=lambda: [])
base_classes: list[str] = field(default_factory=lambda: [])
comment: str | None = None
Expand Down
162 changes: 106 additions & 56 deletions pydantic2zod/_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""An incomplete Python parser focused around Pydantic declarations."""

import inspect
import logging
from importlib import import_module
from importlib.util import resolve_name
Expand All @@ -18,6 +19,7 @@
ClassDecl,
ClassField,
GenericType,
Import,
LiteralType,
PrimitiveType,
PyDict,
Expand All @@ -32,7 +34,8 @@

_logger = logging.getLogger(__name__)

Imports = NewType("Imports", dict[str, str])

Imports = NewType("Imports", dict[str, Import])
"""imported_symbol -> from_module

e.g. Request --> scanner_common.http.cassette
Expand All @@ -42,7 +45,7 @@
def parse(module: ModuleType) -> list[ClassDecl]:
model_graph = DiGraph()
pydantic_models = _parse(module, set(), model_graph)
models_by_name = {c.name: c for c in pydantic_models}
models_by_name = {c.full_path: c for c in pydantic_models}
ordered_models = list[str](dfs_postorder_nodes(model_graph))
return [models_by_name[c] for c in ordered_models if c in models_by_name]

Expand All @@ -55,19 +58,18 @@ def _parse(

classes = list[ClassDecl]()

parse_module = _ParseModule(model_graph, parse_only_models)
parse_module = _ParseModule(module, model_graph, parse_only_models)
m = cst.parse_module(Path(fname).read_text())
classes += parse_module.visit(m).classes()

if depends_on := parse_module.external_models():
_logger.info("'%s' depends on other pydantic models:", fname)
for model_name, pymodule in depends_on.items():
_logger.info(" '%s' from '%s'", model_name, pymodule)
for model_path in depends_on:
_logger.info(" '%s'", model_path)

# TODO(povilas): group models by pymodule
for model_name, pymodule in depends_on.items():
abs_module_name = resolve_name(pymodule, module.__package__)
m = import_module(abs_module_name)
for model_path in depends_on:
m = import_module(".".join(model_path.split(".")[:-1]))
model_name = model_path.split(".")[-1]
classes += _parse(m, {model_name}, model_graph)

return classes
Expand All @@ -84,47 +86,52 @@ def visit(self, node: _NodeT) -> Self:

class _ParseModule(_Parse[cst.Module]):
def __init__(
self, model_graph: DiGraph, parse_only_models: set[str] | None = None
self,
module: ModuleType,
model_graph: DiGraph,
parse_only_models: set[str] | None = None,
) -> None:
super().__init__()

self._parse_only_models = parse_only_models
self._model_graph = model_graph
self._parsing_module = module

self._pydantic_classes: dict[str, ClassDecl] = {}
# All classes found in the module.
self._classes: dict[str, ClassDecl] = {}
self._pydantic_classes: dict[str, ClassDecl] = {}
self._class_nodes: dict[str, cst.ClassDef] = {}
self._alias_nodes: dict[str, cst.AnnAssign] = {}

self._external_models = set[str]()
self._imports = Imports({})

def external_models(self) -> Imports:
"""A List of pydantic models coming from other Python modules."""
return Imports(
{k: v for k, v in self._imports.items() if k in self._external_models}
)
def exec(self) -> Self:
    """A helper for tests: parse the wrapped module's source and return self.

    Reads the source of the module handed to the constructor, parses it into
    a syntax tree and runs this visitor over it, so that a test can write
    ``_ParseModule(...).exec()`` and query the results on the returned object.
    """
    source = inspect.getsource(self._parsing_module)
    self.visit(cst.parse_module(source))
    return self

def external_models(self) -> set[str]:
    """Return the pydantic models coming from other Python modules.

    Built-in common types like uuid.UUID are filtered out so that pydantic2zod
    would not try to parse them recursively.

    The returned set is the parser's own internal set, not a copy.
    """
    return self._external_models

def classes(self) -> list[ClassDecl]:
ordered_models = list(dfs_postorder_nodes(self._model_graph))
return [
self._pydantic_classes[c]
for c in ordered_models
if c in self._pydantic_classes
]
return list(self._pydantic_classes.values())

def visit_ImportFrom(self, node: cst.ImportFrom):
self._imports |= _ParseImportFrom().visit(node).imports()
self._imports |= {
i.alias or i.name: i for i in _ParseImportFrom().visit(node).imports()
}

def visit_ClassDef(self, node: cst.ClassDef):
parse = _ParseClassDecl()
parse.visit_ClassDef(node)
cls = parse.class_decl

cls = _ParseClassDecl().visit(node).class_decl
cls.full_path = f"{self._parsing_module.__name__}.{cls.name}"
self._class_nodes[cls.name] = node
self._classes[cls.name] = cls
if cls.name in self._model_graph:
_logger.warning("Model with name '%s' already exists.", cls.name)
self._model_graph.add_node(cls.name)

@m.call_if_inside(
m.AnnAssign(annotation=m.Annotation(annotation=m.Name("TypeAlias")))
Expand All @@ -141,39 +148,67 @@ def leave_Module(self, original_node: cst.Module) -> None:
"""Parse the class definitions and resolve imported classes."""
if self._parse_only_models:
for m in self._parse_only_models:
self._parse_pydantic_model(self._classes[m])
self._recursively_parse_pydantic_model(self._classes[m])
else:
self._parse_all_classes()
for cls in self._pydantic_classes.values():
for dep in self._class_deps(cls):
self._model_graph.add_edge(cls.name, dep)
if dep in self._imports:
self._external_models.add(dep)
elif dep not in self._classes:
_logger.warning(
"Can't infer where '%s' is coming from. '%s' depends on it.",
dep,
cls.name,
)

def _parse_pydantic_model(self, cls: ClassDecl) -> None:
"""Parse a Pydantic model."""
self._parse_class_deps(cls)

for cls in self._pydantic_classes.values():
for field in cls.fields:
self._resolve_class_field_names(field.type)

def _recursively_parse_pydantic_model(self, cls: ClassDecl) -> None:
    """Parse `cls` and then, depth-first, the local classes it depends on.

    Nothing happens when `cls` is not a pydantic model or has already been
    parsed. The "already parsed" check is also what terminates the recursion
    when the same class is reached more than once.
    """
    if not self._is_pydantic_model(cls):
        return
    if cls.name in self._pydantic_classes:
        return

    parsed = self._finish_parsing_class(cls)
    # Only dependencies declared in this module are returned here; imported
    # ones are recorded as external models by `_parse_class_deps` itself.
    for local_dep in self._parse_class_deps(parsed):
        self._recursively_parse_pydantic_model(local_dep)

def _parse_class_deps(self, cls: ClassDecl) -> list[ClassDecl]:
local_deps = []
for dep in self._class_deps(cls):
self._model_graph.add_edge(cls.name, dep)
if dep in self._imports:
self._external_models.add(dep)
if resolved_dep_path := self._is_imported(dep):
if resolved_dep_path not in ["uuid.UUID", "pydantic.BaseModel"]:
self._external_models.add(resolved_dep_path)
self._model_graph.add_edge(cls.full_path, resolved_dep_path)

elif cls_decl := self._classes.get(dep):
self._parse_pydantic_model(cls_decl)
local_deps.append(cls_decl)
self._model_graph.add_edge(cls.full_path, cls_decl.full_path)
else:
_logger.warning(
"Can't infer where '%s' is coming from. '%s' depends on it.",
dep,
cls.name,
)
return local_deps

def _resolve_class_field_names(self, field_type: PyType) -> None:
"""Resolve fully qualified model names in the field type."""
match field_type:
case UserDefinedType(name=name):
if full_path := self._is_imported(name):
field_type.name = full_path
case GenericType(type_vars=type_vars):
for type_var in type_vars:
self._resolve_class_field_names(type_var)

def _is_imported(self, cls_name: str) -> str | None:
"""
Returns: a full path to the class.
"""
if cls_name not in self._imports:
return None

import_ = self._imports[cls_name]
abs_module_name = resolve_name(
import_.from_module, self._parsing_module.__package__
)
abs_cls_name = f"{abs_module_name}.{import_.name}"

return abs_cls_name

def _class_deps(self, cls: ClassDecl) -> list[str]:
deps = [c for c in cls.base_classes if c != "BaseModel"]
Expand All @@ -191,8 +226,9 @@ def _parse_all_classes(self) -> None:

def _finish_parsing_class(self, cls_decl: ClassDecl) -> ClassDecl:
cls = _ParseClassDecl().visit(self._class_nodes[cls_decl.name]).class_decl
cls.full_path = cls_decl.full_path
self._model_graph.add_node(cls.full_path)
self._pydantic_classes[cls.name] = cls
self._model_graph.add_node(cls.name)

# Try to resolve type aliases.
for f in cls.fields:
Expand All @@ -204,7 +240,10 @@ def _finish_parsing_class(self, cls_decl: ClassDecl) -> ClassDecl:
return cls

def _is_pydantic_model(self, cls: ClassDecl) -> bool:
if "BaseModel" in cls.base_classes and self._imports["BaseModel"] == "pydantic":
if (
"BaseModel" in cls.base_classes
and self._is_imported("BaseModel") == "pydantic.BaseModel"
):
return True

# TODO(povilas): when the base is imported model
Expand Down Expand Up @@ -252,16 +291,17 @@ def visit_AnnAssign(self, node: cst.AnnAssign):
)


class _ParseImportFrom(_Parse):
class _ParseImportFrom(_Parse[cst.ImportFrom]):
def __init__(self) -> None:
super().__init__()
self._from = list[str]()
self._imports = list[str]()
self._imports = list[Import]()
self._relative = 0

def imports(self) -> dict[str, str]:
from_ = "." * self._relative + ".".join(self._from)
return {imp: from_ for imp in self._imports}
def imports(self) -> list[Import]:
for imp in self._imports:
imp.from_module = "." * self._relative + ".".join(self._from)
return self._imports

def visit_ImportFrom(self, node: cst.ImportFrom):
self._relative = len(list(node.relative))
Expand All @@ -271,7 +311,17 @@ def visit_Name(self, node: cst.Name):
self._from.append(node.value)

def visit_ImportAlias(self, node: cst.ImportAlias) -> None:
self._imports.append(cst.ensure_type(node.name, cst.Name).value)
import_name = cst.ensure_type(node.name, cst.Name).value
import_ = Import(from_module="", name=import_name)
if node.asname:
if isinstance(node.asname.name, cst.Name):
import_.alias = node.asname.name.value
else:
_logger.warning(
"Don't know how to parse this import alias: '%s'", node.asname
)

self._imports.append(import_)


def _extract_type(node: cst.BaseExpression) -> PyType:
Expand Down
8 changes: 8 additions & 0 deletions tests/fixtures/builtin_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from uuid import UUID

from pydantic import BaseModel


class User(BaseModel):
id: UUID
name: str
8 changes: 8 additions & 0 deletions tests/fixtures/import_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel

from .all_in_one import Class as Cls


class Module(BaseModel):
name: str
classes: list[Cls]
Loading