Skip to content

Commit

Permalink
Added type hints for methods (#371)
Browse files Browse the repository at this point in the history
* Added type hints for methods

Fix #367

* Build docs with py313

* Modernize type hints

* Fix typing

* Fix type hints

* Add mypy test cases

* Move types to environs.types

* Rename types

* Rename mypy test file so it doesn't get run my pytest

* DRY common kwargs

* Add __future__ import to fix py39

* Update changelog

* Add back subcast_keys and subcast_values

* Add case for untyped parser

---------

Co-authored-by: Steven Loria <[email protected]>
  • Loading branch information
OkeyDev and sloria authored Jan 7, 2025
1 parent 562e25d commit 2989e03
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 39 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 12.1.0 (unreleased)

Features:

- Typing: Add type hints for parser methods ([#367](https://github.com/sloria/environs/issues/367)).
Thanks [OkeyDev](https://github/OkeyDev) for the PR.

## 12.0.0 (2025-01-06)

Features:
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ classifiers = [
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.9"
dependencies = ["python-dotenv", "marshmallow>=3.13.0"]
dependencies = [
"python-dotenv",
"marshmallow>=3.13.0",
"typing-extensions; python_version < '3.11'",
]

[project.urls]
Changelog = "https://github.com/sloria/environs/blob/master/CHANGELOG.md"
Expand Down
88 changes: 50 additions & 38 deletions src/environs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,39 @@

import collections
import contextlib
import datetime as dt
import decimal
import functools
import inspect
import json as pyjson
import logging
import os
import re
import typing
import uuid
from collections.abc import Mapping
from datetime import timedelta
from enum import Enum
from pathlib import Path
from urllib.parse import ParseResult, urlparse

import marshmallow as ma
from dj_database_url import DBConfig
from dotenv.main import _walk_to_root, load_dotenv

from .types import (
DictFieldMethod,
EnumFuncMethod,
EnumT,
ErrorList,
ErrorMapping,
FieldFactory,
FieldMethod,
FieldOrFactory,
ListFieldMethod,
ParserMethod,
Subcast,
)

if typing.TYPE_CHECKING:
try:
from dj_database_url import DBConfig
Expand All @@ -29,15 +46,9 @@
_T = typing.TypeVar("_T")
_StrType = str
_BoolType = bool

ErrorMapping = typing.Mapping[str, list[str]]
ErrorList = list[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[type, typing.Callable[..., _T], ma.fields.Field]
FieldType = type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable[..., typing.Any]

_IntType = int
_ListType = list
_DictType = dict

_EXPANDED_VAR_PATTERN = re.compile(r"(?<!\\)\$\{([A-Za-z0-9_]+)(:-[^\}:]*)?\}")
# Ordered duration strings, loosely based on the [GEP-2257](https://gateway-api.sigs.k8s.io/geps/gep-2257/) spec
Expand Down Expand Up @@ -91,12 +102,12 @@ def _field2method(
*,
preprocess: typing.Callable | None = None,
preprocess_kwarg_names: typing.Sequence[str] = tuple(),
) -> ParserMethod:
) -> typing.Any:
def method(
self: Env,
name: str,
default: typing.Any = ma.missing,
subcast: Subcast | None = None,
subcast: Subcast[_T] | None = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
Expand Down Expand Up @@ -161,13 +172,13 @@ def method(
self._errors[parsed_key].extend(error.messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method


def _func2method(func: typing.Callable, method_name: str) -> ParserMethod:
def _func2method(func: typing.Callable[..., _T], method_name: str) -> typing.Any:
def method(
self: Env,
name: str,
Expand Down Expand Up @@ -209,7 +220,7 @@ def method(
self._errors[parsed_key].extend(messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method
Expand Down Expand Up @@ -292,10 +303,7 @@ def _preprocess_json(value: str | typing.Mapping | list, **kwargs):
raise ma.ValidationError("Not valid JSON.") from error


_EnumT = typing.TypeVar("_EnumT", bound=Enum)


def _enum_parser(value, type: type[_EnumT], ignore_case: bool = False) -> _EnumT:
def _enum_parser(value, type: type[EnumT], ignore_case: bool = False) -> EnumT:
if isinstance(value, type):
return value

Expand Down Expand Up @@ -371,7 +379,7 @@ def deserialize( # type: ignore[override]
data: typing.Mapping[str, typing.Any] | None = None,
**kwargs,
) -> ParseResult:
ret = super().deserialize(value, attr, data, **kwargs)
ret = typing.cast(str, super().deserialize(value, attr, data, **kwargs))
return urlparse(ret)


Expand Down Expand Up @@ -423,20 +431,20 @@ def _deserialize(self, value, *args, **kwargs) -> timedelta:
class Env:
"""An environment variable reader."""

__call__: ParserMethod = _field2method(ma.fields.Raw, "__call__")
__call__: FieldMethod[typing.Any] = _field2method(ma.fields.Raw, "__call__")

int = _field2method(ma.fields.Int, "int")
bool = _field2method(ma.fields.Bool, "bool")
str = _field2method(ma.fields.Str, "str")
float = _field2method(ma.fields.Float, "float")
decimal = _field2method(ma.fields.Decimal, "decimal")
list = _field2method(
int: FieldMethod[int] = _field2method(ma.fields.Int, "int")
bool: FieldMethod[bool] = _field2method(ma.fields.Bool, "bool")
str: FieldMethod[str] = _field2method(ma.fields.Str, "str")
float: FieldMethod[float] = _field2method(ma.fields.Float, "float")
decimal: FieldMethod[decimal.Decimal] = _field2method(ma.fields.Decimal, "decimal")
list: ListFieldMethod = _field2method(
_make_list_field,
"list",
preprocess=_preprocess_list,
preprocess_kwarg_names=("subcast", "delimiter"),
)
dict = _field2method(
dict: DictFieldMethod = _field2method(
ma.fields.Dict,
"dict",
preprocess=_preprocess_dict,
Expand All @@ -448,16 +456,20 @@ class Env:
"delimiter",
),
)
json = _field2method(ma.fields.Field, "json", preprocess=_preprocess_json)
datetime = _field2method(ma.fields.DateTime, "datetime")
date = _field2method(ma.fields.Date, "date")
time = _field2method(ma.fields.Time, "time")
path = _field2method(_PathField, "path")
log_level = _field2method(_LogLevelField, "log_level")
timedelta = _field2method(_TimeDeltaField, "timedelta")
uuid = _field2method(ma.fields.UUID, "uuid")
url = _field2method(_URLField, "url")
enum = _func2method(_enum_parser, "enum")
json: FieldMethod[_ListType | _DictType] = _field2method(
ma.fields.Field, "json", preprocess=_preprocess_json
)
datetime: FieldMethod[dt.datetime] = _field2method(ma.fields.DateTime, "datetime")
date: FieldMethod[dt.date] = _field2method(ma.fields.Date, "date")
time: FieldMethod[dt.time] = _field2method(ma.fields.Time, "time")
timedelta: FieldMethod[dt.timedelta] = _field2method(_TimeDeltaField, "timedelta")
path: FieldMethod[Path] = _field2method(_PathField, "path")
log_level: FieldMethod[_IntType] = _field2method(_LogLevelField, "log_level")

uuid: FieldMethod[uuid.UUID] = _field2method(ma.fields.UUID, "uuid")
url: FieldMethod[ParseResult] = _field2method(_URLField, "url")

enum: EnumFuncMethod = _func2method(_enum_parser, "enum")
dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url")
dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url")
dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url")
Expand Down
89 changes: 89 additions & 0 deletions src/environs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Custom types and type aliases.
.. warning::
This module is provisional. Types may be modified, added, and removed between minor releases.
"""

from __future__ import annotations

import enum
import typing

try:
from typing import Unpack
except ImportError: # Remove when dropping Python 3.10
from typing_extensions import Unpack

import marshmallow as ma

T = typing.TypeVar("T")
EnumT = typing.TypeVar("EnumT", bound=enum.Enum)


ErrorMapping = typing.Mapping[str, list[str]]
ErrorList = list[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[type, typing.Callable[..., T], ma.fields.Field]
FieldType = type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable[..., T]


class BaseMethodKwargs(typing.TypedDict, total=False):
# Subset of relevant marshmallow.Field kwargs shared by all parser methods
load_default: typing.Any
validate: (
typing.Callable[[typing.Any], typing.Any]
| typing.Iterable[typing.Callable[[typing.Any], typing.Any]]
| None
)
required: bool
allow_none: bool | None
error_messages: dict[str, str] | None
metadata: typing.Mapping[str, typing.Any] | None


class FieldMethod(typing.Generic[T]):
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: Subcast[T] | None = None,
**kwargs: Unpack[BaseMethodKwargs],
) -> T | None: ...


class ListFieldMethod:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: Subcast[T] | None = None,
*,
delimiter: str | None = None,
**kwargs: Unpack[BaseMethodKwargs],
) -> list | None: ...


class DictFieldMethod:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
*,
subcast_keys: Subcast[T] | None = None,
subcast_values: Subcast[T] | None = None,
delimiter: str | None = None,
**kwargs: Unpack[BaseMethodKwargs],
) -> dict | None: ...


class EnumFuncMethod:
def __call__(
self,
value,
type: type[EnumT],
default: EnumT | None = None,
ignore_case: bool = False,
) -> EnumT | None: ...
41 changes: 41 additions & 0 deletions tests/mypy_test_cases/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Test cases for type hints of environs.Env.
To run these, use: ::
tox -e mypy-marshmallow3
Or ::
tox -e mypy-marshmallowdev
"""

import datetime as dt
import decimal
import pathlib
import uuid
from typing import Any
from urllib.parse import ParseResult

import environs

env = environs.Env()

A: int | None = env.int("FOO", None)
B: bool | None = env.bool("FOO", None)
C: str | None = env.str("FOO", None)
D: float | None = env.float("FOO", None)
E: decimal.Decimal | None = env.decimal("FOO", None)
F: list | None = env.list("FOO", None)
G: list[int] | None = env.list("FOO", None, subcast=int)
H: dict | None = env.dict("FOO", None)
J: dict[str, int] | None = env.dict("FOO", None, subcast_keys=str, subcast_values=int)
K: list | dict | None = env.json("FOO", None)
L: dt.datetime | None = env.datetime("FOO", None)
M: dt.date | None = env.date("FOO", None)
N: dt.time | None = env.time("FOO", None)
P: dt.timedelta | None = env.timedelta("FOO", None)
Q: pathlib.Path | None = env.path("FOO", None)
R: int | None = env.log_level("FOO", None)
S: uuid.UUID | None = env.uuid("FOO", None)
T: ParseResult | None = env.url("FOO", None)
U: Any = env("FOO", None)

0 comments on commit 2989e03

Please sign in to comment.