Migrate mypy to pyright (#2295)
* feat: fix types

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: remove mypy from deps

* feat: run lint and other checks in CI

* fix: fix setup env

* fix: remove unnecessary lines in CI config

* fix: remove unnecessary lines in CI config

* fix: fix CI uv

* fix: fix CI uv

* fix: fix types

* fix: fix types

* fix: fix types

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: fix types

* fix: fix types

* fix: fix types

* fix: fix types

* fix: fix types

* fix: fix types

* fix: ignore local coverage check in CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: exclude readme for windows

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
koxudaxi and pre-commit-ci[bot] authored Feb 3, 2025
1 parent 2c7f319 commit 95b28c3
Showing 23 changed files with 132 additions and 100 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/test.yml
@@ -74,3 +74,33 @@ jobs:
env_vars: OS,PY,ISORT
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
check:
name: tox env ${{ matrix.tox_env }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
tox_env:
- type
- dev
- docs
- pkg_meta
- readme
os:
- ubuntu-latest
- windows-latest
exclude:
- { os: windows-latest, tox_env: docs }
- { os: windows-latest, tox_env: readme }
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v5
- name: Install tox
run: uv tool install --python-preference only-managed --python 3.13 tox --with tox-uv
- name: Setup check suite
run: tox r -vv --notest --skip-missing-interpreters false -e ${{ matrix.tox_env }}
- name: Run check for ${{ matrix.tox_env }}
run: tox r --skip-pkg-install -e ${{ matrix.tox_env }}
14 changes: 11 additions & 3 deletions datamodel_code_generator/__init__.py
@@ -81,6 +81,9 @@ def enable_debug_message() -> None: # pragma: no cover
pysnooper.tracer.DISABLED = False


DEFAULT_MAX_VARIABLE_LENGTH: int = 100


def snooper_to_methods( # type: ignore
output=None,
watch=(),
@@ -90,7 +93,7 @@ def snooper_to_methods( # type: ignore
overwrite=False,
thread_info=False,
custom_repr=(),
max_variable_length=100,
max_variable_length: Optional[int] = DEFAULT_MAX_VARIABLE_LENGTH,
) -> Callable[..., Any]:
def inner(cls: Type[T]) -> Type[T]:
if not pysnooper:
@@ -108,7 +111,9 @@ def inner(cls: Type[T]) -> Type[T]:
overwrite,
thread_info,
custom_repr,
max_variable_length,
max_variable_length
if max_variable_length is not None
else DEFAULT_MAX_VARIABLE_LENGTH,
)(method)
setattr(cls, name, snooper_method)
return cls
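
The new `max_variable_length` signature shows the pattern this migration repeats throughout: a parameter that callers may pass as None becomes `Optional`, and the body coalesces None back to a module-level default so pyright sees a concrete type afterwards. A minimal sketch of the idea, with illustrative names (`truncate`, `limit`) that are not from the codebase:

    from typing import Optional

    DEFAULT_LIMIT: int = 100

    def truncate(text: str, limit: Optional[int] = DEFAULT_LIMIT) -> str:
        # A caller may pass None explicitly; coalescing gives the body a
        # plain int, which pyright verifies without any ignore comment.
        effective = limit if limit is not None else DEFAULT_LIMIT
        return text[:effective]

    print(truncate('a' * 200, None))  # None falls back to DEFAULT_LIMIT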
@@ -424,8 +429,10 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
data_model_types = get_data_model_types(
output_model_type, target_python_version, output_datetime_class
)
source = input_text or input_
assert not isinstance(source, Mapping)
parser = parser_class(
source=input_text or input_,
source=source,
data_model_type=data_model_types.data_model,
data_model_root_type=data_model_types.root_model,
data_model_field_type=data_model_types.field_model,
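
Binding `input_text or input_` to a local and asserting away the `Mapping` case is assert-based narrowing: after the `assert`, pyright treats `source` as the remaining members of the union, so the `parser_class` call type-checks. A reduced sketch of the same technique, using hypothetical names:

    from pathlib import Path
    from typing import Mapping, Union

    def read_source(source: Union[str, Path, Mapping[str, str]]) -> str:
        # After this assert, pyright narrows `source` to Union[str, Path],
        # so APIs that cannot accept a Mapping type-check cleanly.
        assert not isinstance(source, Mapping)
        return str(source)

    print(read_source(Path('schema.yaml')))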
@@ -514,6 +521,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
# input_ might be a dict object provided directly, and missing a name field
input_filename = getattr(input_, 'name', '<dict>')
else:
assert isinstance(input_, Path)
input_filename = input_.name
if not results:
raise Error('Models not found in the input data')
24 changes: 13 additions & 11 deletions datamodel_code_generator/__main__.py
@@ -85,7 +85,7 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover

class Config(BaseModel):
if PYDANTIC_V2:
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True) # pyright: ignore [reportAssignmentType]

def get(self, item: str) -> Any:
return getattr(self, item)
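
The `# pyright: ignore [rule]` comments introduced throughout this diff are pyright's scoped suppressions: unlike a bare `# type: ignore`, they silence only the named diagnostic, so any other error on the same line still surfaces. A small self-contained illustration (the override below is deliberately incompatible):

    class Base:
        config: dict = {}

    class Child(Base):
        # Only reportIncompatibleVariableOverride is suppressed here;
        # pyright would still report any other diagnostic on this line.
        config: str = ''  # pyright: ignore [reportIncompatibleVariableOverride]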
@@ -185,8 +185,8 @@ def validate_custom_file_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:

@model_validator(mode='after')
def validate_keyword_only(cls, values: Dict[str, Any]) -> Dict[str, Any]:
output_model_type: DataModelType = values.get('output_model_type')
python_target: PythonVersion = values.get('target_python_version')
output_model_type: DataModelType = values.get('output_model_type') # pyright: ignore [reportAssignmentType]
python_target: PythonVersion = values.get('target_python_version') # pyright: ignore [reportAssignmentType]
if (
values.get('keyword_only')
and output_model_type == DataModelType.DataclassesDataclass
@@ -219,7 +219,7 @@ def validate_http_headers(cls, value: Any) -> Optional[List[Tuple[str, str]]]:
def validate_each_item(each_item: Any) -> Tuple[str, str]:
if isinstance(each_item, str): # pragma: no cover
try:
field_name, field_value = each_item.split(':', maxsplit=1) # type: str, str
field_name, field_value = each_item.split(':', maxsplit=1)
return field_name, field_value.lstrip()
except ValueError:
raise Error(f'Invalid http header: {each_item!r}')
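
The deleted `# type: str, str` trailer is a mypy-era inline type comment; pyright infers both unpacked names as `str` directly from `str.split`, so the comment is redundant. In isolation:

    header = 'Content-Type: application/json'
    # pyright infers field_name and field_value as str; no comment needed
    field_name, field_value = header.split(':', maxsplit=1)
    print(field_name, field_value.lstrip())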
@@ -236,7 +236,7 @@ def validate_http_query_parameters(
def validate_each_item(each_item: Any) -> Tuple[str, str]:
if isinstance(each_item, str): # pragma: no cover
try:
field_name, field_value = each_item.split('=', maxsplit=1) # type: str, str
field_name, field_value = each_item.split('=', maxsplit=1)
return field_name, field_value.lstrip()
except ValueError:
raise Error(f'Invalid http query parameter: {each_item!r}')
@@ -248,14 +248,16 @@ def validate_each_item(each_item: Any) -> Tuple[str, str]:

@model_validator(mode='before')
def validate_additional_imports(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get('additional_imports') is not None:
values['additional_imports'] = values.get('additional_imports').split(',')
additional_imports = values.get('additional_imports')
if additional_imports is not None:
values['additional_imports'] = additional_imports.split(',')
return values

@model_validator(mode='before')
def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get('custom_formatters') is not None:
values['custom_formatters'] = values.get('custom_formatters').split(',')
custom_formatters = values.get('custom_formatters')
if custom_formatters is not None:
values['custom_formatters'] = custom_formatters.split(',')
return values

if PYDANTIC_V2:
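
Both validators switch from calling `values.get(...)` twice to reading it once into a local. `dict.get` returns an Optional, so chaining `.split(',')` onto it trips pyright's reportOptionalMemberAccess; a None check narrows only a local binding, not a repeated method call. A sketch under those assumptions:

    from typing import Any, Dict

    def split_csv_option(values: Dict[str, Any], key: str) -> Dict[str, Any]:
        # values.get(key) is Optional; the local lets the None check
        # narrow it, so .split type-checks. Chaining the call would not.
        raw = values.get(key)
        if raw is not None:
            values[key] = raw.split(',')
        return values

    print(split_csv_option({'additional_imports': 'os,sys'}, 'additional_imports'))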
@@ -282,7 +284,7 @@ def validate_root(cls, values: Any) -> Any:
disable_warnings: bool = False
target_python_version: PythonVersion = PythonVersion.PY_38
base_class: str = ''
additional_imports: Optional[List[str]] = (None,)
additional_imports: Optional[List[str]] = None
custom_template_dir: Optional[Path] = None
extra_template_data: Optional[TextIOBase] = None
validation: bool = False
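
The `additional_imports` change fixes a genuine bug the stricter checking surfaced: the old default `(None,)` is a one-element tuple, not None, because of the trailing comma, so it never matched the declared `Optional[List[str]]`. In isolation:

    from typing import List, Optional

    # The trailing comma makes this a one-element tuple, not None:
    bad = (None,)
    print(type(bad))  # <class 'tuple'>

    # The corrected default, matching the annotation:
    good: Optional[List[str]] = None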
@@ -427,7 +429,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
print(
f"Installed black doesn't support Python version {config.target_python_version.value}.\n" # type: ignore
f'You have to install a newer black.\n'
f'Installed black version: {black.__version__}',
f'Installed black version: {black.__version__}', # pyright: ignore [reportPrivateImportUsage]
file=sys.stderr,
)
return Exit.ERROR
2 changes: 1 addition & 1 deletion datamodel_code_generator/format.py
@@ -161,7 +161,7 @@ def __init__(
if black.__version__.startswith('19.'): # type: ignore
warn(
f"black doesn't support `experimental-string-processing` option" # type: ignore
f' for wrapping string literal in {black.__version__}'
f' for wrapping string literal in {black.__version__}' # pyright: ignore [reportPrivateImportUsage]
)
elif black.__version__ < '24.1.0': # type: ignore
black_kwargs['experimental_string_processing'] = (
3 changes: 2 additions & 1 deletion datamodel_code_generator/http.py
@@ -21,7 +21,8 @@ def get_body(
headers=headers,
verify=not ignore_tls,
follow_redirects=True,
params=query_parameters,
params=query_parameters, # pyright: ignore [reportArgumentType]
# TODO: Improve params type
).text


4 changes: 3 additions & 1 deletion datamodel_code_generator/model/__init__.py
@@ -28,12 +28,14 @@ class DataModelSet(NamedTuple):
def get_data_model_types(
data_model_type: DataModelType,
target_python_version: PythonVersion = DEFAULT_TARGET_PYTHON_VERSION,
target_datetime_class: DatetimeClassType = DEFAULT_TARGET_DATETIME_CLASS,
target_datetime_class: Optional[DatetimeClassType] = None,
) -> DataModelSet:
from .. import DataModelType
from . import dataclass, msgspec, pydantic, pydantic_v2, rootmodel, typed_dict
from .types import DataTypeManager

if target_datetime_class is None:
target_datetime_class = DEFAULT_TARGET_DATETIME_CLASS
if data_model_type == DataModelType.PydanticBaseModel:
return DataModelSet(
data_model=pydantic.BaseModel,
10 changes: 6 additions & 4 deletions datamodel_code_generator/model/base.py
@@ -53,7 +53,7 @@ class ConstraintsBase(_BaseModel):
unique_items: Optional[bool] = Field(None, alias='uniqueItems')
_exclude_fields: ClassVar[Set[str]] = {'has_constraints'}
if PYDANTIC_V2:
model_config = ConfigDict(
model_config = ConfigDict( # pyright: ignore [reportAssignmentType]
arbitrary_types_allowed=True, ignored_types=(cached_property,)
)
else:
@@ -87,7 +87,9 @@ def merge_constraints(
else:
model_field_constraints = {}

if not issubclass(constraints_class, ConstraintsBase): # pragma: no cover
if constraints_class is None or not issubclass(
constraints_class, ConstraintsBase
): # pragma: no cover
return None

return constraints_class.parse_obj(
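
The guard added above matters because `issubclass` raises TypeError at runtime when its first argument is None, and pyright rejects the call statically for the same reason, so the explicit None check covers both. A reduced sketch with a stand-in class:

    from typing import Any, Optional, Type

    class ConstraintsBase:
        pass

    def safe_constraints_class(cls: Optional[Type[Any]]) -> Optional[Type[ConstraintsBase]]:
        # issubclass(None, ...) raises TypeError; checking for None first
        # satisfies pyright and keeps the runtime path safe.
        if cls is None or not issubclass(cls, ConstraintsBase):
            return None
        return cls

    print(safe_constraints_class(None))             # None
    print(safe_constraints_class(ConstraintsBase))  # the class itself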
@@ -165,7 +167,7 @@ def imports(self) -> Tuple[Import, ...]:
type_hint = self.type_hint
has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint
imports: List[Union[Tuple[Import], Iterator[Import]]] = [
(
iter(
i
for i in self.data_type.all_imports
if not (not has_union and i == IMPORT_UNION)
@@ -251,7 +253,7 @@ def get_module_name(name: str, file_path: Optional[Path]) -> str:


class TemplateBase(ABC):
@property
@cached_property
@abstractmethod
def template_file_path(self) -> Path:
raise NotImplementedError
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/enum.py
@@ -87,7 +87,7 @@ def find_member(self, value: Any) -> Optional[Member]:

for field in self.fields:
# Remove surrounding quotes from field default value
field_default = field.default.strip('\'"')
field_default = (field.default or '').strip('\'"')

# Compare values after removing quotes
if field_default == str_value:
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/pydantic/base_model.py
@@ -322,4 +322,4 @@ def __init__(
if config_parameters:
from datamodel_code_generator.model.pydantic import Config

self.extra_template_data['config'] = Config.parse_obj(config_parameters)
self.extra_template_data['config'] = Config.parse_obj(config_parameters) # pyright: ignore [reportArgumentType]
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/pydantic/types.py
@@ -180,7 +180,7 @@ def __init__(
self.data_type,
strict_types=self.strict_types,
pattern_key=self.PATTERN_KEY,
target_datetime_class=target_datetime_class,
target_datetime_class=self.target_datetime_class,
)
self.strict_type_map: Dict[StrictTypes, DataType] = strict_type_map_factory(
self.data_type,
4 changes: 2 additions & 2 deletions datamodel_code_generator/model/pydantic_v2/base_model.py
@@ -98,7 +98,7 @@ class DataModelField(DataModelFieldV1):
'max_length',
'union_mode',
}
constraints: Optional[Constraints] = None
constraints: Optional[Constraints] = None # pyright: ignore [reportIncompatibleVariableOverride]
_PARSE_METHOD: ClassVar[str] = 'model_validate'
can_have_extra_keys: ClassVar[bool] = False

@@ -234,7 +234,7 @@ def __init__(
if config_parameters:
from datamodel_code_generator.model.pydantic_v2 import ConfigDict

self.extra_template_data['config'] = ConfigDict.parse_obj(config_parameters)
self.extra_template_data['config'] = ConfigDict.parse_obj(config_parameters) # pyright: ignore [reportArgumentType]
self._additional_imports.append(IMPORT_CONFIG_DICT)

def _get_config_extra(self) -> Optional[Literal["'allow'", "'forbid'"]]:
5 changes: 4 additions & 1 deletion datamodel_code_generator/model/pydantic_v2/types.py
@@ -24,7 +24,10 @@ def type_map_factory(
) -> Dict[Types, DataType]:
result = {
**super().type_map_factory(
data_type, strict_types, pattern_key, target_datetime_class
data_type,
strict_types,
pattern_key,
target_datetime_class or DatetimeClassType.Datetime,
),
Types.hostname: self.data_type.from_import(
IMPORT_CONSTR,
11 changes: 6 additions & 5 deletions datamodel_code_generator/parser/base.py
@@ -411,7 +411,7 @@ def __init__(
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
default_field_extras: Optional[Dict[str, Any]] = None,
target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
target_datetime_class: Optional[DatetimeClassType] = DatetimeClassType.Datetime,
keyword_only: bool = False,
no_alias: bool = False,
) -> None:
@@ -849,12 +849,12 @@ def check_paths(

# Check the main discriminator model path
if mapping:
check_paths(discriminator_model, mapping)
check_paths(discriminator_model, mapping) # pyright: ignore [reportArgumentType]

# Check the base_classes if they exist
if len(type_names) == 0:
for base_class in discriminator_model.base_classes:
check_paths(base_class.reference, mapping)
check_paths(base_class.reference, mapping) # pyright: ignore [reportArgumentType]
else:
type_names = [discriminator_model.path.split('/')[-1]]
if not type_names: # pragma: no cover
@@ -1061,7 +1061,7 @@ def __collapse_root_models(

data_type.parent.data_type = copied_data_type

elif data_type.parent.is_list:
elif data_type.parent is not None and data_type.parent.is_list:
if self.field_constraints:
model_field.constraints = ConstraintsBase.merge_constraints(
root_type_field.constraints, model_field.constraints
@@ -1073,6 +1073,7 @@ def __collapse_root_models(
discriminator = root_type_field.extras.get('discriminator')
if discriminator:
model_field.extras['discriminator'] = discriminator
assert isinstance(data_type.parent, DataType)
data_type.parent.data_types.remove(
data_type
) # pragma: no cover
@@ -1358,7 +1359,7 @@ def sort_key(data_model: DataModel) -> Tuple[int, Tuple[str, ...]]:
module_to_import: Dict[Tuple[str, ...], Imports] = {}

previous_module = () # type: Tuple[str, ...]
for module, models in ((k, [*v]) for k, v in grouped_models): # type: Tuple[str, ...], List[DataModel]
for module, models in ((k, [*v]) for k, v in grouped_models):
for model in models:
model_to_module_models[model] = module, models
self.__delete_duplicate_models(models)
8 changes: 4 additions & 4 deletions datamodel_code_generator/parser/graphql.py
@@ -373,7 +373,7 @@ def parse_enum(self, enum_object: graphql.GraphQLEnumType) -> None:
def parse_field(
self,
field_name: str,
alias: str,
alias: Optional[str],
field: Union[graphql.GraphQLField, graphql.GraphQLInputField],
) -> DataModelFieldBase:
final_data_type = DataType(
@@ -399,9 +399,9 @@ def parse_field(
elif graphql.is_non_null_type(obj): # pragma: no cover
data_type.is_optional = False

obj = obj.of_type
obj = obj.of_type # pyright: ignore [reportAttributeAccessIssue]

data_type.type = obj.name
data_type.type = obj.name # pyright: ignore [reportAttributeAccessIssue]

required = (not self.force_optional_for_required_fields) and (
not final_data_type.is_optional
@@ -456,7 +456,7 @@ def parse_object_like(

base_classes = []
if hasattr(obj, 'interfaces'): # pragma: no cover
base_classes = [self.references[i.name] for i in obj.interfaces]
base_classes = [self.references[i.name] for i in obj.interfaces] # pyright: ignore [reportAttributeAccessIssue]

data_model_type = self.data_model_type(
reference=self.references[obj.name],