Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add environment parsing support for enums. #252

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SettingsConfigDict(ConfigDict, total=False):
env_ignore_empty: bool
env_nested_delimiter: str | None
env_parse_none_str: str | None
env_parse_enums: bool | None
secrets_dir: str | Path | None
json_file: PathType | None
json_file_encoding: str | None
Expand Down Expand Up @@ -65,6 +66,7 @@ class BaseSettings(BaseModel):
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
_env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.)
into `None` type(None). Defaults to `None` type(None), which means no parsing should occur.
_env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur.
_secrets_dir: The secret files directory. Defaults to `None`.
"""

Expand All @@ -77,6 +79,7 @@ def __init__(
_env_ignore_empty: bool | None = None,
_env_nested_delimiter: str | None = None,
_env_parse_none_str: str | None = None,
_env_parse_enums: bool | None = None,
_secrets_dir: str | Path | None = None,
**values: Any,
) -> None:
Expand All @@ -91,6 +94,7 @@ def __init__(
_env_ignore_empty=_env_ignore_empty,
_env_nested_delimiter=_env_nested_delimiter,
_env_parse_none_str=_env_parse_none_str,
_env_parse_enums=_env_parse_enums,
_secrets_dir=_secrets_dir,
)
)
Expand Down Expand Up @@ -129,6 +133,7 @@ def _settings_build_values(
_env_ignore_empty: bool | None = None,
_env_nested_delimiter: str | None = None,
_env_parse_none_str: str | None = None,
_env_parse_enums: bool | None = None,
_secrets_dir: str | Path | None = None,
) -> dict[str, Any]:
# Determine settings config values
Expand All @@ -149,6 +154,7 @@ def _settings_build_values(
env_parse_none_str = (
_env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str')
)
env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums')
secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')

# Configure built-in sources
Expand All @@ -160,6 +166,7 @@ def _settings_build_values(
env_nested_delimiter=env_nested_delimiter,
env_ignore_empty=env_ignore_empty,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)
dotenv_settings = DotEnvSettingsSource(
self.__class__,
Expand All @@ -170,6 +177,7 @@ def _settings_build_values(
env_nested_delimiter=env_nested_delimiter,
env_ignore_empty=env_ignore_empty,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)

file_secret_settings = SecretsSettingsSource(
Expand Down Expand Up @@ -201,6 +209,7 @@ def _settings_build_values(
env_ignore_empty=False,
env_nested_delimiter=None,
env_parse_none_str=None,
env_parse_enums=None,
json_file=None,
json_file_encoding=None,
yaml_file=None,
Expand Down
26 changes: 23 additions & 3 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import is_dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast

Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
env_prefix: str | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
super().__init__(settings_cls)
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
Expand All @@ -189,6 +191,7 @@ def __init__(
self.env_parse_none_str = (
env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
)
self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')

def _apply_case_sensitive(self, value: str) -> str:
return value.lower() if not self.case_sensitive else value
Expand Down Expand Up @@ -357,8 +360,11 @@ def __init__(
env_prefix: str | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str)
super().__init__(
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
)
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')

def __call__(self) -> dict[str, Any]:
Expand Down Expand Up @@ -447,8 +453,11 @@ def __init__(
env_nested_delimiter: str | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str)
super().__init__(
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
)
self.env_nested_delimiter = (
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
)
Expand Down Expand Up @@ -498,6 +507,10 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
ValuesError: When There is an error in deserializing value for complex field.
"""
is_complex, allow_parse_failure = self._field_is_complex(field)
if self.env_parse_enums and lenient_issubclass(field.annotation, Enum):
if value in tuple(val.name for val in field.annotation): # type: ignore
value = field.annotation[value] # type: ignore

if is_complex or value_is_complex:
if value is None:
# field is complex but no value found so far, try explode_env_vars
Expand Down Expand Up @@ -645,13 +658,20 @@ def __init__(
env_nested_delimiter: str | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
self.env_file_encoding = (
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
)
super().__init__(
settings_cls, case_sensitive, env_prefix, env_nested_delimiter, env_ignore_empty, env_parse_none_str
settings_cls,
case_sensitive,
env_prefix,
env_nested_delimiter,
env_ignore_empty,
env_parse_none_str,
env_parse_enums,
)

def _load_env_vars(self) -> Mapping[str, str | None]:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import uuid
from datetime import datetime, timezone
from enum import IntEnum
from pathlib import Path
from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union

Expand Down Expand Up @@ -1853,6 +1854,40 @@ class Settings(BaseSettings):
]


def test_env_parse_enums(env):
class FruitsEnum(IntEnum):
pear = 0
kiwi = 1
lime = 2

class Settings(BaseSettings):
fruit: FruitsEnum

with pytest.raises(ValidationError) as exc_info:
env.set('FRUIT', 'kiwi')
s = Settings()
assert exc_info.value.errors(include_url=False) == [
{
'type': 'int_parsing',
'loc': ('fruit',),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'kiwi',
}
]

env.set('FRUIT', str(FruitsEnum.lime.value))
s = Settings()
assert s.fruit == FruitsEnum.lime

env.set('FRUIT', 'kiwi')
s = Settings(_env_parse_enums=True)
assert s.fruit == FruitsEnum.kiwi

env.set('FRUIT', str(FruitsEnum.lime.value))
s = Settings(_env_parse_enums=True)
assert s.fruit == FruitsEnum.lime


def test_env_parse_none_str(env):
env.set('x', 'null')
env.set('y', 'y_override')
Expand Down
Loading