Skip to content

Commit

Permalink
feat: ignore empty env vars (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
niventc authored Dec 11, 2023
1 parent 1d6950f commit 4f24fad
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 22 deletions.
5 changes: 5 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,11 @@ except ValidationError as e:

## Parsing environment variable values

By default environment variables are parsed verbatim, including if the value is empty. You can choose to
ignore empty environment variables by setting the `env_ignore_empty` config setting to `True`. This can be
useful if you would prefer to use the default value for a field rather than an empty value from the
environment.

For most simple field types (such as `int`, `float`, `str`, etc.), the environment variable value is parsed
the same way it would be if passed directly to the initialiser (as a string).

Expand Down
11 changes: 11 additions & 0 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class SettingsConfigDict(ConfigDict, total=False):
env_prefix: str
env_file: DotenvType | None
env_file_encoding: str | None
env_ignore_empty: bool
env_nested_delimiter: str | None
secrets_dir: str | Path | None

Expand Down Expand Up @@ -53,6 +54,7 @@ class BaseSettings(BaseModel):
means that the value from `model_config['env_file']` should be used. You can also pass
`None` to indicate that environment variables should not be loaded from an env file.
_env_file_encoding: The env file encoding, e.g. `'latin-1'`. Defaults to `None`.
_env_ignore_empty: Ignore environment variables where the value is an empty string. Default to `False`.
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
_secrets_dir: The secret files directory. Defaults to `None`.
"""
Expand All @@ -63,6 +65,7 @@ def __init__(
_env_prefix: str | None = None,
_env_file: DotenvType | None = ENV_FILE_SENTINEL,
_env_file_encoding: str | None = None,
_env_ignore_empty: bool | None = None,
_env_nested_delimiter: str | None = None,
_secrets_dir: str | Path | None = None,
**values: Any,
Expand All @@ -75,6 +78,7 @@ def __init__(
_env_prefix=_env_prefix,
_env_file=_env_file,
_env_file_encoding=_env_file_encoding,
_env_ignore_empty=_env_ignore_empty,
_env_nested_delimiter=_env_nested_delimiter,
_secrets_dir=_secrets_dir,
)
Expand Down Expand Up @@ -111,6 +115,7 @@ def _settings_build_values(
_env_prefix: str | None = None,
_env_file: DotenvType | None = None,
_env_file_encoding: str | None = None,
_env_ignore_empty: bool | None = None,
_env_nested_delimiter: str | None = None,
_secrets_dir: str | Path | None = None,
) -> dict[str, Any]:
Expand All @@ -121,6 +126,9 @@ def _settings_build_values(
env_file_encoding = (
_env_file_encoding if _env_file_encoding is not None else self.model_config.get('env_file_encoding')
)
env_ignore_empty = (
_env_ignore_empty if _env_ignore_empty is not None else self.model_config.get('env_ignore_empty')
)
env_nested_delimiter = (
_env_nested_delimiter
if _env_nested_delimiter is not None
Expand All @@ -135,6 +143,7 @@ def _settings_build_values(
case_sensitive=case_sensitive,
env_prefix=env_prefix,
env_nested_delimiter=env_nested_delimiter,
env_ignore_empty=env_ignore_empty,
)
dotenv_settings = DotEnvSettingsSource(
self.__class__,
Expand All @@ -143,6 +152,7 @@ def _settings_build_values(
case_sensitive=case_sensitive,
env_prefix=env_prefix,
env_nested_delimiter=env_nested_delimiter,
env_ignore_empty=env_ignore_empty,
)

file_secret_settings = SecretsSettingsSource(
Expand Down Expand Up @@ -171,6 +181,7 @@ def _settings_build_values(
env_prefix='',
env_file=None,
env_file_encoding=None,
env_ignore_empty=False,
env_nested_delimiter=None,
secrets_dir=None,
protected_namespaces=('model_', 'settings_'),
Expand Down
54 changes: 39 additions & 15 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,18 @@ def __repr__(self) -> str:

class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
def __init__(
self, settings_cls: type[BaseSettings], case_sensitive: bool | None = None, env_prefix: str | None = None
self,
settings_cls: type[BaseSettings],
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_ignore_empty: bool | None = None,
) -> None:
super().__init__(settings_cls)
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
self.env_ignore_empty = (
env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
)

def _apply_case_sensitive(self, value: str) -> str:
return value.lower() if not self.case_sensitive else value
Expand Down Expand Up @@ -279,8 +286,9 @@ def __init__(
secrets_dir: str | Path | None = None,
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_ignore_empty: bool | None = None,
) -> None:
super().__init__(settings_cls, case_sensitive, env_prefix)
super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty)
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')

def __call__(self) -> dict[str, Any]:
Expand Down Expand Up @@ -367,8 +375,9 @@ def __init__(
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str | None = None,
env_ignore_empty: bool | None = None,
) -> None:
super().__init__(settings_cls, case_sensitive, env_prefix)
super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty)
self.env_nested_delimiter = (
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
)
Expand All @@ -377,9 +386,7 @@ def __init__(
self.env_vars = self._load_env_vars()

def _load_env_vars(self) -> Mapping[str, str | None]:
if self.case_sensitive:
return os.environ
return {k.lower(): v for k, v in os.environ.items()}
return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty)

def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Expand Down Expand Up @@ -562,17 +569,18 @@ def __init__(
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str | None = None,
env_ignore_empty: bool | None = None,
) -> None:
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
self.env_file_encoding = (
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
)
super().__init__(settings_cls, case_sensitive, env_prefix, env_nested_delimiter)
super().__init__(settings_cls, case_sensitive, env_prefix, env_nested_delimiter, env_ignore_empty)

def _load_env_vars(self) -> Mapping[str, str | None]:
return self._read_env_files(self.case_sensitive)
return self._read_env_files()

def _read_env_files(self, case_sensitive: bool) -> Mapping[str, str | None]:
def _read_env_files(self) -> Mapping[str, str | None]:
env_files = self.env_file
if env_files is None:
return {}
Expand All @@ -585,7 +593,12 @@ def _read_env_files(self, case_sensitive: bool) -> Mapping[str, str | None]:
env_path = Path(env_file).expanduser()
if env_path.is_file():
dotenv_vars.update(
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
read_env_file(
env_path,
encoding=self.env_file_encoding,
case_sensitive=self.case_sensitive,
ignore_empty=self.env_ignore_empty,
)
)

return dotenv_vars
Expand Down Expand Up @@ -618,14 +631,25 @@ def __repr__(self) -> str:
)


def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
return key if case_sensitive else key.lower()


def parse_env_vars(
env_vars: Mapping[str, str | None], case_sensitive: bool = False, ignore_empty: bool = False
) -> Mapping[str, str | None]:
return {_get_env_var_key(k, case_sensitive): v for k, v in env_vars.items() if not (ignore_empty and v == '')}


def read_env_file(
file_path: Path, *, encoding: str | None = None, case_sensitive: bool = False
file_path: Path,
*,
encoding: str | None = None,
case_sensitive: bool = False,
ignore_empty: bool = False,
) -> Mapping[str, str | None]:
file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
if not case_sensitive:
return {k.lower(): v for k, v in file_vars.items()}
else:
return file_vars
return parse_env_vars(file_vars, case_sensitive, ignore_empty)


def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
Expand Down
63 changes: 56 additions & 7 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class SimpleSettings(BaseSettings):
apple: str


class SettingWithIgnoreEmpty(BaseSettings):
apple: str = 'default'

model_config = SettingsConfigDict(env_ignore_empty=True)


def test_sub_env(env):
env.set('apple', 'hello')
s = SimpleSettings()
Expand All @@ -71,6 +77,44 @@ def test_other_setting():
SimpleSettings(apple='a', foobar=42)


def test_ignore_empty_when_empty_uses_default(env):
env.set('apple', '')
s = SettingWithIgnoreEmpty()
assert s.apple == 'default'


def test_ignore_empty_when_not_empty_uses_value(env):
env.set('apple', 'a')
s = SettingWithIgnoreEmpty()
assert s.apple == 'a'


def test_ignore_empty_with_dotenv_when_empty_uses_default(tmp_path):
p = tmp_path / '.env'
p.write_text('a=')

class Settings(BaseSettings):
a: str = 'default'

model_config = SettingsConfigDict(env_file=p, env_ignore_empty=True)

s = Settings()
assert s.a == 'default'


def test_ignore_empty_with_dotenv_when_not_empty_uses_value(tmp_path):
p = tmp_path / '.env'
p.write_text('a=b')

class Settings(BaseSettings):
a: str = 'default'

model_config = SettingsConfigDict(env_file=p, env_ignore_empty=True)

s = Settings()
assert s.a == 'b'


def test_with_prefix(env):
class Settings(BaseSettings):
apple: str
Expand Down Expand Up @@ -851,7 +895,7 @@ class Settings(BaseSettings):
assert s.a == 'ignore non-file'


def test_read_env_file_cast_sensitive(tmp_path):
def test_read_env_file_case_sensitive(tmp_path):
p = tmp_path / '.env'
p.write_text('a="test"\nB=123')

Expand Down Expand Up @@ -976,14 +1020,19 @@ def test_read_dotenv_vars(tmp_path):
prod_env = tmp_path / '.env.prod'
prod_env.write_text(test_prod_env_file)

source = DotEnvSettingsSource(BaseSettings(), env_file=[base_env, prod_env], env_file_encoding='utf8')
assert source._read_env_files(case_sensitive=False) == {
source = DotEnvSettingsSource(
BaseSettings(), env_file=[base_env, prod_env], env_file_encoding='utf8', case_sensitive=False
)
assert source._read_env_files() == {
'debug_mode': 'false',
'host': 'https://example.com/services',
'port': '8000',
}

assert source._read_env_files(case_sensitive=True) == {
source = DotEnvSettingsSource(
BaseSettings(), env_file=[base_env, prod_env], env_file_encoding='utf8', case_sensitive=True
)
assert source._read_env_files() == {
'debug_mode': 'false',
'host': 'https://example.com/services',
'Port': '8000',
Expand All @@ -992,9 +1041,9 @@ def test_read_dotenv_vars(tmp_path):

def test_read_dotenv_vars_when_env_file_is_none():
assert (
DotEnvSettingsSource(BaseSettings(), env_file=None, env_file_encoding=None)._read_env_files(
case_sensitive=False
)
DotEnvSettingsSource(
BaseSettings(), env_file=None, env_file_encoding=None, case_sensitive=False
)._read_env_files()
== {}
)

Expand Down

0 comments on commit 4f24fad

Please sign in to comment.