From d1ab119771b1f204fbe8a9955632298ab0f3d95c Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Tue, 5 Nov 2024 07:42:21 -0700 Subject: [PATCH 1/2] Fix alias resolution for default settings source. --- pydantic_settings/sources.py | 86 +++++++++++++++++++----------------- tests/test_settings.py | 22 +++++++++ 2 files changed, 67 insertions(+), 41 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index e0e08099..2688a049 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -39,7 +39,6 @@ from dotenv import dotenv_values from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, TypeAdapter from pydantic._internal._repr import Representation -from pydantic._internal._signature import _field_name_for_signature from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass from pydantic.dataclasses import is_pydantic_dataclass @@ -336,10 +335,12 @@ def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partia ) if self.nested_model_default_partial_update: for field_name, field_info in settings_cls.model_fields.items(): + alias_names, *_ = _get_alias_names(field_name, field_info) + preferred_alias = alias_names[0] if is_dataclass(type(field_info.default)): - self.defaults[_field_name_for_signature(field_name, field_info)] = asdict(field_info.default) + self.defaults[preferred_alias] = asdict(field_info.default) elif is_model_class(type(field_info.default)): - self.defaults[_field_name_for_signature(field_name, field_info)] = field_info.default.model_dump() + self.defaults[preferred_alias] = field_info.default.model_dump() def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: # Nothing to do here. Only implement the return statement to make mypy happy @@ -1422,41 +1423,6 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F sub_models.append(type_) # type: ignore return sub_models - def _get_alias_names( - self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] - ) -> tuple[tuple[str, ...], bool]: - alias_names: list[str] = [] - is_alias_path_only: bool = True - if not any((field_info.alias, field_info.validation_alias)): - alias_names += [field_name] - is_alias_path_only = False - else: - new_alias_paths: list[AliasPath] = [] - for alias in (field_info.alias, field_info.validation_alias): - if alias is None: - continue - elif isinstance(alias, str): - alias_names.append(alias) - is_alias_path_only = False - elif isinstance(alias, AliasChoices): - for name in alias.choices: - if isinstance(name, str): - alias_names.append(name) - is_alias_path_only = False - else: - new_alias_paths.append(name) - else: - new_alias_paths.append(alias) - for alias_path in new_alias_paths: - name = cast(str, alias_path.path[0]) - name = name.lower() if not self.case_sensitive else name - alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list' - if not alias_names and is_alias_path_only: - alias_names.append(name) - if not self.case_sensitive: - alias_names = [alias_name.lower() for alias_name in alias_names] - return tuple(dict.fromkeys(alias_names)), is_alias_path_only - def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None: if _CliImplicitFlag in field_info.metadata: cli_flag_name = 'CliImplicitFlag' @@ -1481,7 +1447,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] if not field_info.is_required(): raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value') else: - alias_names, *_ = self._get_alias_names(field_name, field_info, {}) + alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases') field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)] @@ -1495,7 +1461,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] if not field_info.is_required(): raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value') else: - alias_names, *_ = self._get_alias_names(field_name, field_info, {}) + alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') positional_args.append((field_name, field_info)) @@ -1597,7 +1563,9 @@ def _add_parser_args( alias_path_args: dict[str, str] = {} for field_name, field_info in self._sort_arg_fields(model): sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info) - alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args) + alias_names, is_alias_path_only = _get_alias_names( + field_name, field_info, alias_path_args=alias_path_args, case_sensitive=self.case_sensitive + ) preferred_alias = alias_names[0] if _CliSubCommand in field_info.metadata: for model in sub_models: @@ -2241,5 +2209,41 @@ def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]: raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass') +def _get_alias_names( + field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] = {}, case_sensitive: bool = True +) -> tuple[tuple[str, ...], bool]: + alias_names: list[str] = [] + is_alias_path_only: bool = True + if not any((field_info.alias, field_info.validation_alias)): + alias_names += [field_name] + is_alias_path_only = False + else: + new_alias_paths: list[AliasPath] = [] + for alias in (field_info.alias, field_info.validation_alias): + if alias is None: + continue + elif isinstance(alias, str): + alias_names.append(alias) + is_alias_path_only = False + elif isinstance(alias, AliasChoices): + for name in alias.choices: + if isinstance(name, str): + alias_names.append(name) + is_alias_path_only = False + else: + new_alias_paths.append(name) + else: + new_alias_paths.append(alias) + for alias_path in new_alias_paths: + name = cast(str, alias_path.path[0]) + name = name.lower() if not case_sensitive else name + alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list' + if not alias_names and is_alias_path_only: + alias_names.append(name) + if not case_sensitive: + alias_names = [alias_name.lower() for alias_name in alias_names] + return tuple(dict.fromkeys(alias_names)), is_alias_path_only + + def _is_function(obj: Any) -> bool: return isinstance(obj, (FunctionType, BuiltinFunctionType)) diff --git a/tests/test_settings.py b/tests/test_settings.py index 8e6297e8..7447ec3a 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -13,6 +13,7 @@ from annotated_types import MinLen from pydantic import ( AliasChoices, + AliasGenerator, AliasPath, BaseModel, Discriminator, @@ -32,6 +33,7 @@ from pydantic_settings import ( BaseSettings, + CliApp, DotEnvSettingsSource, EnvSettingsSource, InitSettingsSource, @@ -621,6 +623,26 @@ def settings_customise_sources( assert s.model_dump() == s_final +def test_alias_nested_model_default_partial_update(): + class SubModel(BaseModel): + v1: str = 'default' + v2: bytes = b'hello' + v3: int + + class Settings(BaseSettings): + model_config = SettingsConfigDict( + nested_model_default_partial_update=True, alias_generator=AliasGenerator(lambda s: s.replace('_', '-')) + ) + + v0: str = 'ok' + sub_model: SubModel = SubModel(v1='top default', v3=33) + + assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli']).model_dump() == { + 'v0': 'ok', + 'sub_model': {'v1': 'cli', 'v2': b'hello', 'v3': 33}, + } + + def test_env_str(env): class Settings(BaseSettings): apple: str = Field(None, validation_alias='BOOM') From 84bd950e6d524d2bb8ceb422d3390ec9c49ac44c Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Wed, 6 Nov 2024 06:30:48 -0700 Subject: [PATCH 2/2] Remove CliApp call. --- tests/test_settings.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 7447ec3a..e266093c 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -33,7 +33,6 @@ from pydantic_settings import ( BaseSettings, - CliApp, DotEnvSettingsSource, EnvSettingsSource, InitSettingsSource, @@ -637,7 +636,7 @@ class Settings(BaseSettings): v0: str = 'ok' sub_model: SubModel = SubModel(v1='top default', v3=33) - assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli']).model_dump() == { + assert Settings(**{'sub-model': {'v1': 'cli'}}).model_dump() == { 'v0': 'ok', 'sub_model': {'v1': 'cli', 'v2': b'hello', 'v3': 33}, }