Skip to content

Commit

Permalink
Add custom formatters (#1733)
Browse files Browse the repository at this point in the history
* Support custom formatters for CodeFormatter

* Add custom formatters argument

* Add graphql to docs/supported-data-types.md

* Add test

custom formatter for custom-scalar-types.graphql;

* Run poetry run scripts/format.sh

* Add simple doc
  • Loading branch information
denisart authored Nov 24, 2023
1 parent 3e0f0aa commit a36ce94
Show file tree
Hide file tree
Showing 21 changed files with 351 additions and 5 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,10 @@ Template customization:
Wrap string literal by using black `experimental-
string-processing` option (require black 20.8b0 or
later)
--additional-imports Custom imports for output (delimited list input)
--additional-imports Custom imports for output (delimited list input).
For example "datetime.date,datetime.datetime"
--custom-formatters List of modules with custom formatter (delimited list input).
--custom-formatters-kwargs A file with kwargs for custom formatters.

OpenAPI-only options:
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
Expand Down
4 changes: 4 additions & 0 deletions datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ def generate(
keep_model_order: bool = False,
custom_file_header: Optional[str] = None,
custom_file_header_path: Optional[Path] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
Expand Down Expand Up @@ -452,6 +454,8 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=data_model_types.known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
**kwargs,
)

Expand Down
38 changes: 37 additions & 1 deletion datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ class Config:
def get_fields(cls) -> Dict[str, Any]:
return cls.__fields__

@field_validator('aliases', 'extra_template_data', mode='before')
@field_validator(
'aliases', 'extra_template_data', 'custom_formatters_kwargs', mode='before'
)
def validate_file(cls, value: Any) -> Optional[TextIOBase]:
if value is None or isinstance(value, TextIOBase):
return value
Expand Down Expand Up @@ -204,6 +206,14 @@ def validate_additional_imports(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values['additional_imports'] = []
return values

@model_validator(mode='before')
def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get('custom_formatters') is not None:
values['custom_formatters'] = values.get('custom_formatters').split(',')
else:
values['custom_formatters'] = []
return values

if PYDANTIC_V2:

@model_validator(mode='after') # type: ignore
Expand Down Expand Up @@ -282,6 +292,8 @@ def validate_root(cls, values: Any) -> Any:
keep_model_order: bool = False
custom_file_header: Optional[str] = None
custom_file_header_path: Optional[Path] = None
custom_formatters: Optional[List[str]] = None
custom_formatters_kwargs: Optional[TextIOBase] = None

def merge_args(self, args: Namespace) -> None:
set_args = {
Expand Down Expand Up @@ -391,6 +403,28 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
)
return Exit.ERROR

if config.custom_formatters_kwargs is None:
custom_formatters_kwargs = None
else:
with config.custom_formatters_kwargs as data:
try:
custom_formatters_kwargs = json.load(data)
except json.JSONDecodeError as e:
print(
f'Unable to load custom_formatters_kwargs mapping: {e}',
file=sys.stderr,
)
return Exit.ERROR
if not isinstance(custom_formatters_kwargs, dict) or not all(
isinstance(k, str) and isinstance(v, str)
for k, v in custom_formatters_kwargs.items()
):
print(
'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
file=sys.stderr,
)
return Exit.ERROR

try:
generate(
input_=config.url or config.input or sys.stdin.read(),
Expand Down Expand Up @@ -452,6 +486,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
keep_model_order=config.keep_model_order,
custom_file_header=config.custom_file_header,
custom_file_header_path=config.custom_file_header_path,
custom_formatters=config.custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)
return Exit.OK
except InvalidClassNameError as e:
Expand Down
13 changes: 12 additions & 1 deletion datamodel_code_generator/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,21 @@ def start_section(self, heading: Optional[str]) -> None:
)
base_options.add_argument(
'--additional-imports',
help='Custom imports for output (delimited list input)',
help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"',
type=str,
default=None,
)
base_options.add_argument(
'--custom-formatters',
help='List of modules with custom formatter (delimited list input).',
type=str,
default=None,
)
template_options.add_argument(
'--custom-formatters-kwargs',
help='A file with kwargs for custom formatters.',
type=FileType('rt'),
)

# ======================================================================================
# Options specific to OpenAPI input schemas
Expand Down
48 changes: 48 additions & 0 deletions datamodel_code_generator/format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from enum import Enum
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from warnings import warn
Expand Down Expand Up @@ -112,6 +113,8 @@ def __init__(
wrap_string_literal: Optional[bool] = None,
skip_string_normalization: bool = True,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if not settings_path:
settings_path = Path().resolve()
Expand Down Expand Up @@ -167,12 +170,49 @@ def __init__(
settings_path=self.settings_path, **self.isort_config_kwargs
)

self.custom_formatters_kwargs = custom_formatters_kwargs or {}
self.custom_formatters = self._check_custom_formatters(custom_formatters)

def _load_custom_formatter(
self, custom_formatter_import: str
) -> CustomCodeFormatter:
import_ = import_module(custom_formatter_import)

if not hasattr(import_, 'CodeFormatter'):
raise NameError(
f'Custom formatter module `{import_.__name__}` must contains object with name Formatter'
)

formatter_class = import_.__getattribute__('CodeFormatter')

if not issubclass(formatter_class, CustomCodeFormatter):
raise TypeError(
f'The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`'
)

return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)

def _check_custom_formatters(
self, custom_formatters: Optional[List[str]]
) -> List[CustomCodeFormatter]:
if custom_formatters is None:
return []

return [
self._load_custom_formatter(custom_formatter_import)
for custom_formatter_import in custom_formatters
]

def format_code(
self,
code: str,
) -> str:
code = self.apply_isort(code)
code = self.apply_black(code)

for formatter in self.custom_formatters:
code = formatter.apply(code)

return code

def apply_black(self, code: str) -> str:
Expand Down Expand Up @@ -200,3 +240,11 @@ def apply_isort(self, code: str) -> str:

def apply_isort(self, code: str) -> str:
return isort.code(code, config=self.isort_config)


class CustomCodeFormatter:
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
self.formatter_kwargs = formatter_kwargs

def apply(self, code: str) -> str:
raise NotImplementedError
6 changes: 6 additions & 0 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def __init__(
keep_model_order: bool = False,
use_one_literal_as_default: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.data_type_manager: DataTypeManager = data_type_manager_type(
python_version=target_python_version,
Expand Down Expand Up @@ -502,6 +504,8 @@ def __init__(
self.keep_model_order = keep_model_order
self.use_one_literal_as_default = use_one_literal_as_default
self.known_third_party = known_third_party
self.custom_formatter = custom_formatters
self.custom_formatters_kwargs = custom_formatters_kwargs

@property
def iter_source(self) -> Iterator[Source]:
Expand Down Expand Up @@ -1143,6 +1147,8 @@ def parse(
self.wrap_string_literal,
skip_string_normalization=not self.use_double_quotes,
known_third_party=self.known_third_party,
custom_formatters=self.custom_formatter,
custom_formatters_kwargs=self.custom_formatters_kwargs,
)
else:
code_formatter = None
Expand Down
4 changes: 4 additions & 0 deletions datamodel_code_generator/parser/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def __init__(
keep_model_order: bool = False,
use_one_literal_as_default: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -217,6 +219,8 @@ def __init__(
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)

self.data_model_scalar_type = data_model_scalar_type
Expand Down
4 changes: 4 additions & 0 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def __init__(
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -485,6 +487,8 @@ def __init__(
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)

self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
Expand Down
4 changes: 4 additions & 0 deletions datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def __init__(
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
source=source,
Expand Down Expand Up @@ -281,6 +283,8 @@ def __init__(
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
OpenAPIScope.Schemas
Expand Down
23 changes: 23 additions & 0 deletions docs/custom-formatters.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Custom Code Formatters

New features of the `datamodel-code-generator` it is custom code formatters.

## Usage
To use the `--custom-formatters` option, you'll need to pass the module with your formatter. For example

**your_module.py**
```python
from datamodel_code_generator.format import CustomCodeFormatter

class CodeFormatter(CustomCodeFormatter):
def apply(self, code: str) -> str:
# processed code
return ...

```

and run the following command

```sh
$ datamodel-codegen --input {your_input_file} --output {your_output_file} --custom-formatters "{path_to_your_module}.your_module"
```
7 changes: 5 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,11 @@ Template customization:
Wrap string literal by using black `experimental-
string-processing` option (require black 20.8b0 or
later)
--additional-imports Custom imports for output (delimited list input)

--additional-imports Custom imports for output (delimited list input).
For example "datetime.date,datetime.datetime"
--custom-formatters List of modules with custom formatter (delimited list input).
--custom-formatters-kwargs A file with kwargs for custom formatters.

OpenAPI-only options:
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
Scopes of OpenAPI model generation (default: schemas)
Expand Down
1 change: 1 addition & 0 deletions docs/supported-data-types.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This code generator supports the following input formats:
- JSON Schema ([JSON Schema Core](http://json-schema.org/draft/2019-09/json-schema-validation.html) /[JSON Schema Validation](http://json-schema.org/draft/2019-09/json-schema-validation.html))
- JSON/YAML Data (it will be converted to JSON Schema)
- Python dictionary (it will be converted to JSON Schema)
- GraphQL schema ([GraphQL Schemas and Types](https://graphql.org/learn/schema/))

## Implemented data types and features

Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ nav:
- Generate from JSON Data: jsondata.md
- Generate from GraphQL Schema: graphql.md
- Custom template: custom_template.md
- Custom formatters: custom-formatters.md
- Using as module: using_as_module.md
- Formatting: formatting.md
- Field Constraints: field-constraints.md
Expand Down
37 changes: 37 additions & 0 deletions tests/data/expected/main/main_graphql_custom_formatters/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# generated by datamodel-codegen:
# filename: custom-scalar-types.graphql
# timestamp: 2019-07-26T00:00:00+00:00

# a comment
from __future__ import annotations

from typing import Optional, TypeAlias

from pydantic import BaseModel, Field
from typing_extensions import Literal

Boolean: TypeAlias = bool
"""
The `Boolean` scalar type represents `true` or `false`.
"""


ID: TypeAlias = str
"""
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
"""


Long: TypeAlias = str


String: TypeAlias = str
"""
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
"""


class A(BaseModel):
duration: Long
id: ID
typename__: Optional[Literal['A']] = Field('A', alias='__typename')
7 changes: 7 additions & 0 deletions tests/data/python/custom_formatters/add_comment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from datamodel_code_generator.format import CustomCodeFormatter


class CodeFormatter(CustomCodeFormatter):
"""Simple correct formatter. Adding a comment to top of code."""
def apply(self, code: str) -> str:
return f'# a comment\n{code}'
Loading

0 comments on commit a36ce94

Please sign in to comment.