Skip to content

Commit

Permalink
Added type hints for methods
Browse files Browse the repository at this point in the history
  • Loading branch information
OkeyDev committed Nov 24, 2024
1 parent 815a477 commit e0a962b
Showing 1 changed file with 130 additions and 29 deletions.
159 changes: 130 additions & 29 deletions src/environs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
import builtins
import collections
import contextlib
import datetime as dt
import decimal
import functools
import inspect
import json as pyjson
import logging
import os
import re
import typing
import uuid
from collections.abc import Mapping
from datetime import timedelta
from enum import Enum
from pathlib import Path
from urllib.parse import ParseResult, urlparse

import marshmallow as ma
from dj_database_url import DBConfig
from dotenv.main import _walk_to_root, load_dotenv

__all__ = ["EnvError", "Env"]

_T = typing.TypeVar("_T")
_StrType = str
_BoolType = bool
_EnumT = typing.TypeVar("_EnumT", bound=Enum)


ErrorMapping = typing.Mapping[str, typing.List[str]]
ErrorList = typing.List[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[typing.Type, typing.Callable[..., _T], ma.fields.Field]
FieldType = typing.Type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable
ParserMethod = typing.Callable[..., _T]


_EXPANDED_VAR_PATTERN = re.compile(r"(?<!\\)\$\{([A-Za-z0-9_]+)(:-[^\}:]*)?\}")
Expand Down Expand Up @@ -75,13 +82,96 @@ class ParserConflictError(ValueError):
"""


class Field2MethodType(typing.Generic[_T]):
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: typing.Optional[Subcast[_T]] = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
) -> typing.Optional[_T]:
pass


class Field2MethodListType:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: typing.Optional[Subcast[_T]] = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
delimiter: typing.Optional[str] = None,
) -> typing.Optional[list]:
pass


class Field2MethodDictType:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
subcast_keys: typing.Optional[Subcast[_T]],
subcast_values: typing.Optional[Subcast[_T]],
delimiter: typing.Optional[str] = None,
) -> typing.Optional[dict]:
pass


class Func2MethodEnum:
def __call__(
self,
value,
type: typing.Type[_EnumT],
default: typing.Optional[_EnumT] = None,
ignore_case: bool = False,
) -> typing.Optional[_EnumT]:
pass


def _field2method(
field_or_factory: FieldOrFactory,
method_name: str,
*,
preprocess: typing.Optional[typing.Callable] = None,
preprocess_kwarg_names: typing.Sequence[str] = tuple(),
) -> ParserMethod:
) -> typing.Any:
def method(
self: "Env",
name: str,
Expand Down Expand Up @@ -152,13 +242,13 @@ def method(
self._errors[parsed_key].extend(error.messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method


def _func2method(func: typing.Callable, method_name: str) -> ParserMethod:
def _func2method(func: typing.Callable[..., _T], method_name: str) -> typing.Any:
def method(
self: "Env",
name: str,
Expand Down Expand Up @@ -200,14 +290,14 @@ def method(
self._errors[parsed_key].extend(messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method


def _make_subcast_field(
subcast: typing.Optional[Subcast],
subcast: typing.Optional[Subcast[_T]],
) -> typing.Type[ma.fields.Field]:
if isinstance(subcast, type) and subcast in ma.Schema.TYPE_MAPPING:
inner_field = ma.Schema.TYPE_MAPPING[subcast]
Expand Down Expand Up @@ -274,9 +364,6 @@ def _preprocess_json(value: typing.Union[str, typing.Mapping, typing.List], **kw
raise ma.ValidationError("Not valid JSON.") from error


_EnumT = typing.TypeVar("_EnumT", bound=Enum)


def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) -> _EnumT:
invalid_exc = ma.ValidationError(f"Not a valid '{type.__name__}' enum.")

Expand All @@ -293,7 +380,7 @@ def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) ->
raise invalid_exc


def _dj_db_url_parser(value: str, **kwargs) -> dict:
def _dj_db_url_parser(value: str, **kwargs) -> DBConfig:
try:
import dj_database_url
except ImportError as error:
Expand Down Expand Up @@ -350,7 +437,7 @@ def deserialize(
data: typing.Optional[typing.Mapping] = None,
**kwargs,
) -> ParseResult:
ret = super().deserialize(value, attr, data, **kwargs)
ret = typing.cast(str, super().deserialize(value, attr, data, **kwargs))
return urlparse(ret)


Expand Down Expand Up @@ -398,18 +485,20 @@ class Env:

__call__: ParserMethod = _field2method(ma.fields.Field, "__call__")

int = _field2method(ma.fields.Int, "int")
bool = _field2method(ma.fields.Bool, "bool")
str = _field2method(ma.fields.Str, "str")
float = _field2method(ma.fields.Float, "float")
decimal = _field2method(ma.fields.Decimal, "decimal")
list = _field2method(
int: Field2MethodType["int"] = _field2method(ma.fields.Int, "int")
bool: Field2MethodType["bool"] = _field2method(ma.fields.Bool, "bool")
str: Field2MethodType["str"] = _field2method(ma.fields.Str, "str")
float: Field2MethodType["float"] = _field2method(ma.fields.Float, "float")
decimal: Field2MethodType["decimal.Decimal"] = _field2method(
ma.fields.Decimal, "decimal"
)
list: Field2MethodListType = _field2method(
_make_list_field,
"list",
preprocess=_preprocess_list,
preprocess_kwarg_names=("subcast", "delimiter"),
)
dict = _field2method(
dict: Field2MethodDictType = _field2method(
ma.fields.Dict,
"dict",
preprocess=_preprocess_dict,
Expand All @@ -421,16 +510,26 @@ class Env:
"delimiter",
),
)
json = _field2method(ma.fields.Field, "json", preprocess=_preprocess_json)
datetime = _field2method(ma.fields.DateTime, "datetime")
date = _field2method(ma.fields.Date, "date")
time = _field2method(ma.fields.Time, "time")
path = _field2method(PathField, "path")
log_level = _field2method(LogLevelField, "log_level")
timedelta = _field2method(TimeDeltaField, "timedelta")
uuid = _field2method(ma.fields.UUID, "uuid")
url = _field2method(URLField, "url")
enum = _func2method(_enum_parser, "enum")
json: Field2MethodType[typing.Union[typing.List, typing.Dict]] = _field2method(
ma.fields.Field, "json", preprocess=_preprocess_json
)
datetime: Field2MethodType["dt.datetime"] = _field2method(
ma.fields.DateTime, "datetime"
)
date: Field2MethodType["dt.date"] = _field2method(ma.fields.Date, "date")
time: Field2MethodType["dt.time"] = _field2method(ma.fields.Time, "time")
timedelta: Field2MethodType["dt.timedelta"] = _field2method(
TimeDeltaField, "timedelta"
)
path: Field2MethodType[Path] = _field2method(PathField, "path")
log_level: Field2MethodType["builtins.int"] = _field2method(
LogLevelField, "log_level"
)

uuid: Field2MethodType["uuid.UUID"] = _field2method(ma.fields.UUID, "uuid")
url: Field2MethodType[ParseResult] = _field2method(URLField, "url")

enum: Func2MethodEnum = _func2method(_enum_parser, "enum")
dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url")
dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url")
dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url")
Expand All @@ -439,7 +538,9 @@ def __init__(self, *, eager: _BoolType = True, expand_vars: _BoolType = False):
self.eager = eager
self._sealed: bool = False
self.expand_vars = expand_vars
self._fields: typing.Dict[_StrType, typing.Union[ma.fields.Field, type]] = {}
self._fields: typing.Dict[
_StrType, typing.Union[ma.fields.Field, type[ma.fields.Field]]
] = {}
self._values: typing.Dict[_StrType, typing.Any] = {}
self._errors: ErrorMapping = collections.defaultdict(list)
self._prefix: typing.Optional[_StrType] = None
Expand Down

0 comments on commit e0a962b

Please sign in to comment.