diff --git a/pydantic2zod/_codegen.py b/pydantic2zod/_codegen.py index bb51581..aca2d2a 100644 --- a/pydantic2zod/_codegen.py +++ b/pydantic2zod/_codegen.py @@ -33,9 +33,11 @@ def __init__( self, model_rename_rules: dict[str, str] | None = None, modify_models: Callable[[list[ClassDecl]], list[ClassDecl]] | None = None, + gen_header: Callable[[], str] | None = None, ) -> None: self._model_rename_rules = model_rename_rules or {} self._modify_models = modify_models or (lambda m: m) + self._gen_header = gen_header or (lambda: "") def to_zod(self, pydantic_models: list[ClassDecl]) -> str: self._apply_model_rename_rules(pydantic_models) @@ -43,7 +45,7 @@ def to_zod(self, pydantic_models: list[ClassDecl]) -> str: _warn_about_duplicate_models(models) code = Lines() - self._gen_header(code) + code.add(self._gen_header()) for cls in models: if not cls.name.startswith("_"): @@ -52,16 +54,6 @@ def to_zod(self, pydantic_models: list[ClassDecl]) -> str: return str(code) - def _gen_header(self, code: "Lines") -> None: - header = """ -/** - * NOTE: automatically generated by the pydantic2zod compiler. - */ - -import { z } from "zod"; -""" - code.add(header) - def _apply_model_rename_rules(self, pydantic_models: list[ClassDecl]) -> None: for model in pydantic_models: if new_name := self._model_rename_rules.get(model.full_path): diff --git a/pydantic2zod/_compiler.py b/pydantic2zod/_compiler.py index 86fd465..c9a188b 100644 --- a/pydantic2zod/_compiler.py +++ b/pydantic2zod/_compiler.py @@ -33,7 +33,9 @@ class Compiler: """ def __init__(self) -> None: - self._codegen = Codegen(self.MODEL_RENAME_RULES, self._modify_models) + self._codegen = Codegen( + self.MODEL_RENAME_RULES, self._modify_models, self._gen_header + ) self._pydantic_models: list[ClassDecl] = [] def parse(self, module_name: str) -> Self: @@ -52,3 +54,13 @@ def _modify_models(self, pydantic_models: list[ClassDecl]) -> list[ClassDecl]: e.g. remove default field values. """ return pydantic_models + + def _gen_header(self) -> str: + """Override in case you want to add some header to the generated code.""" + return """ +/** + * NOTE: automatically generated by the pydantic2zod compiler. + */ + +import { z } from "zod"; +"""