Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix alias resolution for default settings source. #468

Merged
merged 3 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 45 additions & 41 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from dotenv import dotenv_values
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, TypeAdapter
from pydantic._internal._repr import Representation
from pydantic._internal._signature import _field_name_for_signature
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
from pydantic.dataclasses import is_pydantic_dataclass
Expand Down Expand Up @@ -336,10 +335,12 @@ def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partia
)
if self.nested_model_default_partial_update:
for field_name, field_info in settings_cls.model_fields.items():
alias_names, *_ = _get_alias_names(field_name, field_info)
preferred_alias = alias_names[0]
if is_dataclass(type(field_info.default)):
self.defaults[_field_name_for_signature(field_name, field_info)] = asdict(field_info.default)
self.defaults[preferred_alias] = asdict(field_info.default)
elif is_model_class(type(field_info.default)):
self.defaults[_field_name_for_signature(field_name, field_info)] = field_info.default.model_dump()
self.defaults[preferred_alias] = field_info.default.model_dump()

def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
# Nothing to do here. Only implement the return statement to make mypy happy
Expand Down Expand Up @@ -1422,41 +1423,6 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
sub_models.append(type_) # type: ignore
return sub_models

def _get_alias_names(
self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str]
) -> tuple[tuple[str, ...], bool]:
alias_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
alias_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
alias_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
alias_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
else:
new_alias_paths.append(alias)
for alias_path in new_alias_paths:
name = cast(str, alias_path.path[0])
name = name.lower() if not self.case_sensitive else name
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
if not alias_names and is_alias_path_only:
alias_names.append(name)
if not self.case_sensitive:
alias_names = [alias_name.lower() for alias_name in alias_names]
return tuple(dict.fromkeys(alias_names)), is_alias_path_only

def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
if _CliImplicitFlag in field_info.metadata:
cli_flag_name = 'CliImplicitFlag'
Expand All @@ -1481,7 +1447,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
if not field_info.is_required():
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
else:
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
Expand All @@ -1495,7 +1461,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
else:
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
Expand Down Expand Up @@ -1597,7 +1563,9 @@ def _add_parser_args(
alias_path_args: dict[str, str] = {}
for field_name, field_info in self._sort_arg_fields(model):
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args)
alias_names, is_alias_path_only = _get_alias_names(
field_name, field_info, alias_path_args=alias_path_args, case_sensitive=self.case_sensitive
)
preferred_alias = alias_names[0]
if _CliSubCommand in field_info.metadata:
for model in sub_models:
Expand Down Expand Up @@ -2241,5 +2209,41 @@ def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]:
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')


def _get_alias_names(
field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] = {}, case_sensitive: bool = True
) -> tuple[tuple[str, ...], bool]:
alias_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
alias_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
alias_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
alias_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
else:
new_alias_paths.append(alias)
for alias_path in new_alias_paths:
name = cast(str, alias_path.path[0])
name = name.lower() if not case_sensitive else name
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
if not alias_names and is_alias_path_only:
alias_names.append(name)
if not case_sensitive:
alias_names = [alias_name.lower() for alias_name in alias_names]
return tuple(dict.fromkeys(alias_names)), is_alias_path_only


def _is_function(obj: Any) -> bool:
return isinstance(obj, (FunctionType, BuiltinFunctionType))
21 changes: 21 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from annotated_types import MinLen
from pydantic import (
AliasChoices,
AliasGenerator,
AliasPath,
BaseModel,
Discriminator,
Expand Down Expand Up @@ -621,6 +622,26 @@ def settings_customise_sources(
assert s.model_dump() == s_final


def test_alias_nested_model_default_partial_update():
class SubModel(BaseModel):
v1: str = 'default'
v2: bytes = b'hello'
v3: int

class Settings(BaseSettings):
model_config = SettingsConfigDict(
nested_model_default_partial_update=True, alias_generator=AliasGenerator(lambda s: s.replace('_', '-'))
)

v0: str = 'ok'
sub_model: SubModel = SubModel(v1='top default', v3=33)

assert Settings(**{'sub-model': {'v1': 'cli'}}).model_dump() == {
'v0': 'ok',
'sub_model': {'v1': 'cli', 'v2': b'hello', 'v3': 33},
}


def test_env_str(env):
class Settings(BaseSettings):
apple: str = Field(None, validation_alias='BOOM')
Expand Down
Loading