diff --git a/docs/index.md b/docs/index.md index 3d514729..dc770d9e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -464,6 +464,592 @@ class Settings(BaseSettings): So if you provide extra values in a dotenv file, whether they start with `env_prefix` or not, a `ValidationError` will be raised. +## Command Line Support + +Pydantic settings provides integrated CLI support, making it easy to quickly define CLI applications using Pydantic +models. There are two primary use cases for Pydantic settings CLI: + +1. When using a CLI to override fields in Pydantic models. +2. When using Pydantic models to define CLIs. + +By default, the experience is tailored towards use case #1 and builds on the foundations established in [parsing +environment variables](#parsing-environment-variables). If your use case primarily falls into #2, you will likely want +to enable [enforcing required arguments at the CLI](#enforce-required-arguments-at-cli). + +### The Basics + +To get started, let's revisit the example presented in [parsing environment variables](#parsing-environment-variables) +but using a Pydantic settings CLI: + +```py +import sys + +from pydantic import BaseModel + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class DeepSubModel(BaseModel): + v4: str + + +class SubModel(BaseModel): + v1: str + v2: bytes + v3: int + deep: DeepSubModel + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(cli_parse_args=True) + + v0: str + sub_model: SubModel + + +sys.argv = [ + 'example.py', + '--v0=0', + '--sub_model={"v1": "json-1", "v2": "json-2"}', + '--sub_model.v2=nested-2', + '--sub_model.v3=3', + '--sub_model.deep.v4=v4', +] + +print(Settings().model_dump()) +""" +{ + 'v0': '0', + 'sub_model': {'v1': 'json-1', 'v2': b'nested-2', 'v3': 3, 'deep': {'v4': 'v4'}}, +} +""" +``` + +To enable CLI parsing, we simply set the `cli_parse_args` flag to a valid value, which retains similar connotations as +defined in `argparse`. 
Alternatively, we can also directly provide the args to parse at the time of instantiation: + +```py test="skip" lint="skip" +Settings( + _cli_parse_args=[ + '--v0=0', + '--sub_model={"v1": "json-1", "v2": "json-2"}', + '--sub_model.v2=nested-2', + '--sub_model.v3=3', + '--sub_model.deep.v4=v4', + ] +) +``` + +Note that a CLI settings source is [**the topmost source**](#field-value-priority) by default unless its [priority value +is customised](#customise-settings-sources): + +```py +import os +import sys +from typing import Tuple, Type + +from pydantic_settings import ( + BaseSettings, + CliSettingsSource, + PydanticBaseSettingsSource, +) + + +class Settings(BaseSettings): + my_foo: str + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return env_settings, CliSettingsSource(settings_cls, cli_parse_args=True) + + +os.environ['MY_FOO'] = 'from environment' + +sys.argv = ['example.py', '--my_foo=from cli'] + +print(Settings().model_dump()) +#> {'my_foo': 'from environment'} +``` + +#### Lists + +CLI argument parsing of lists supports intermixing of any of the below three styles: + + * JSON style `--field='[1,2]'` + * Argparse style `--field 1 --field 2` + * Lazy style `--field=1,2` + +```py +import sys +from typing import List + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True): + my_list: List[int] + + +sys.argv = ['example.py', '--my_list', '[1,2]'] +print(Settings().model_dump()) +#> {'my_list': [1, 2]} + +sys.argv = ['example.py', '--my_list', '1', '--my_list', '2'] +print(Settings().model_dump()) +#> {'my_list': [1, 2]} + +sys.argv = ['example.py', '--my_list', '1,2'] +print(Settings().model_dump()) +#> {'my_list': [1, 2]} +``` + +#### 
Dictionaries + +CLI argument parsing of dictionaries supports intermixing of any of the below two styles: + + * JSON style `--field='{"k1": 1, "k2": 2}'` + * Environment variable style `--field k1=1 --field k2=2` + +These can be used in conjunction with list forms as well, e.g.: + + * `--field k1=1,k2=2 --field k3=3 --field '{"k4": 4}'` etc. + +```py +import sys +from typing import Dict + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True): + my_dict: Dict[str, int] + + +sys.argv = ['example.py', '--my_dict', '{"k1":1,"k2":2}'] +print(Settings().model_dump()) +#> {'my_dict': {'k1': 1, 'k2': 2}} + +sys.argv = ['example.py', '--my_dict', 'k1=1', '--my_dict', 'k2=2'] +print(Settings().model_dump()) +#> {'my_dict': {'k1': 1, 'k2': 2}} +``` + +#### Literals and Enums + +During CLI argument parsing, literals and enums are converted into CLI choices. + +```py +import sys +from enum import IntEnum +from typing import Literal + +from pydantic_settings import BaseSettings + + +class Fruit(IntEnum): + pear = 0 + kiwi = 1 + lime = 2 + + +class Settings(BaseSettings, cli_parse_args=True): + fruit: Fruit + pet: Literal['dog', 'cat', 'bird'] + + +sys.argv = ['example.py', '--fruit', 'lime', '--pet', 'cat'] +print(Settings().model_dump()) +#> {'fruit': <Fruit.lime: 2>, 'pet': 'cat'} +``` + +### Subcommands and Positional Arguments + +Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These +annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore, +subcommands must be a valid type derived from the pydantic `BaseModel` class. + +!!! note + CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model + are grouped under a single subparser; it does not allow for multiple subparsers with each subparser having its own + set of subcommands. 
For more information on subparsers, see [argparse + subcommands](https://docs.python.org/3/library/argparse.html#sub-commands). + +```py +import sys + +from pydantic import BaseModel, Field + +from pydantic_settings import ( + BaseSettings, + CliPositionalArg, + CliSubCommand, +) + + +class FooPlugin(BaseModel): + """git-plugins-foo - Extra deep foo plugin command""" + + my_feature: bool = Field( + default=False, description='Enable my feature on foo plugin' + ) + + +class BarPlugin(BaseModel): + """git-plugins-bar - Extra deep bar plugin command""" + + my_feature: bool = Field( + default=False, description='Enable my feature on bar plugin' + ) + + +class Plugins(BaseModel): + """git-plugins - Fake plugins for GIT""" + + foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin') + + bar: CliSubCommand[BarPlugin] = Field(description='Bar is also a fake plugin') + + +class Clone(BaseModel): + """git-clone - Clone a repository into a new directory""" + + repository: CliPositionalArg[str] = Field(description='The repository to clone') + + directory: CliPositionalArg[str] = Field(description='The directory to clone into') + + local: bool = Field( + default=False, + description='When the resposity to clone from is on a local machine, bypass ...', + ) + + +class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'): + """git - The stupid content tracker""" + + clone: CliSubCommand[Clone] = Field( + description='Clone a repository into a new directory' + ) + + plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugin commands') + + +try: + sys.argv = ['example.py', '--help'] + Git() +except SystemExit as e: + print(e) + #> 0 +""" +usage: git [-h] {clone,plugins} ... 
+ +git - The stupid content tracker + +options: + -h, --help show this help message and exit + +subcommands: + {clone,plugins} + clone Clone a repository into a new directory + plugins Fake GIT plugin commands +""" + + +try: + sys.argv = ['example.py', 'clone', '--help'] + Git() +except SystemExit as e: + print(e) + #> 0 +""" +usage: git clone [-h] [--local bool] REPOSITORY DIRECTORY + +git-clone - Clone a repository into a new directory + +positional arguments: + REPOSITORY The repository to clone + DIRECTORY The directory to clone into + +options: + -h, --help show this help message and exit + --local bool When the resposity to clone from is on a local machine, bypass ... (default: False) +""" + + +try: + sys.argv = ['example.py', 'plugins', 'bar', '--help'] + Git() +except SystemExit as e: + print(e) + #> 0 +""" +usage: git plugins bar [-h] [--my_feature bool] + +git-plugins-bar - Extra deep bar plugin command + +options: + -h, --help show this help message and exit + --my_feature bool Enable my feature on bar plugin (default: False) +""" +``` + +### Customizing the CLI Experience + +The below flags can be used to customise the CLI experience to your needs. + +#### Change the Displayed Program Name + +Change the default program name displayed in the help text usage by setting `cli_prog_name`. By default, it will derive +the name of the currently executing program from `sys.argv[0]`, just like argparse. + +```py test="skip" +import sys + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True, cli_prog_name='appdantic'): + pass + + +sys.argv = ['example.py', '--help'] +Settings() +""" +usage: appdantic [-h] + +options: + -h, --help show this help message and exit +""" +``` + +#### Enforce Required Arguments at CLI + +Pydantic settings is designed to pull values in from various sources when instantiating a model. This means a field that +is required is not strictly required from any single source (e.g. 
the CLI). Instead, all that matters is that one of the +sources provides the required value. + +However, if your use case [aligns more with #2](#command-line-support), using Pydantic models to define CLIs, you will +likely want required fields to be _strictly required at the CLI_. We can enable this behavior by using the +`cli_enforce_required`. + +```py +import os +import sys + +from pydantic import Field + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True, cli_enforce_required=True): + my_required_field: str = Field(description='a top level required field') + + +os.environ['MY_REQUIRED_FIELD'] = 'hello from environment' + +try: + sys.argv = ['example.py'] + Settings() +except SystemExit as e: + print(e) + #> 2 +""" +usage: example.py [-h] --my_required_field str +example.py: error: the following arguments are required: --my_required_field +""" +``` + +#### Change the None Type Parse String + +Change the CLI string value that will be parsed (e.g. "null", "void", "None", etc.) into `None` type(None) by setting +`cli_parse_none_str`. By default it will use the `env_parse_none_str` value if set. Otherwise, it will default to "null" +if `cli_avoid_json` is `False`, and "None" if `cli_avoid_json` is `True`. + +```py +import sys +from typing import Optional + +from pydantic import Field + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True, cli_parse_none_str='void'): + v1: Optional[int] = Field(description='the top level v0 option') + + +sys.argv = ['example.py', '--v1', 'void'] +print(Settings().model_dump()) +#> {'v1': None} +``` + +#### Hide None Type Values + +Hide `None` values from the CLI help text by enabling `cli_hide_none_type`. 
+ +```py test="skip" +import sys +from typing import Optional + +from pydantic import Field + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True, cli_hide_none_type=True): + v0: Optional[str] = Field(description='the top level v0 option') + + +sys.argv = ['example.py', '--help'] +Settings() +""" +usage: example.py [-h] [--v0 str] + +options: + -h, --help show this help message and exit + --v0 str the top level v0 option (required) +""" +``` + +#### Avoid Adding JSON CLI Options + +Avoid adding complex fields that result in JSON strings at the CLI by enabling `cli_avoid_json`. + +```py test="skip" +import sys + +from pydantic import BaseModel, Field + +from pydantic_settings import BaseSettings + + +class SubModel(BaseModel): + v1: int = Field(description='the sub model v1 option') + + +class Settings(BaseSettings, cli_parse_args=True, cli_avoid_json=True): + sub_model: SubModel = Field( + description='The help summary for SubModel related options' + ) + + +sys.argv = ['example.py', '--help'] +Settings() +""" +usage: example.py [-h] [--sub_model.v1 int] + +options: + -h, --help show this help message and exit + +sub_model options: + The help summary for SubModel related options + + --sub_model.v1 int the sub model v1 option (required) +""" +``` + +#### Use Class Docstring for Group Help Text + +By default, when populating the group help text for nested models it will pull from the field descriptions. +Alternatively, we can also configure CLI settings to pull from the class docstring instead. + +!!! note + If the field is a union of nested models the group help text will always be pulled from the field description; + even if `cli_use_class_docs_for_groups` is set to `True`. 
+ +```py test="skip" +import sys + +from pydantic import BaseModel, Field + +from pydantic_settings import BaseSettings + + +class SubModel(BaseModel): + """The help text from the class docstring.""" + + v1: int = Field(description='the sub model v1 option') + + +class Settings(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True): + """My application help text.""" + + sub_model: SubModel = Field(description='The help text from the field description') + + +sys.argv = ['example.py', '--help'] +Settings() +""" +usage: example.py [-h] [--sub_model JSON] [--sub_model.v1 int] + +My application help text. + +options: + -h, --help show this help message and exit + +sub_model options: + The help text from the class docstring. + + --sub_model JSON set sub_model from JSON string + --sub_model.v1 int the sub model v1 option (required) +""" +``` + +### Integrating with Existing Parsers + +A CLI settings source can be integrated with existing parsers by overriding the default CLI settings source with a user +defined one that specifies the `root_parser` object. + +```py +import sys +from argparse import ArgumentParser + +from pydantic_settings import BaseSettings, CliSettingsSource + +parser = ArgumentParser() +parser.add_argument('--food', choices=['pear', 'kiwi', 'lime']) + + +class Settings(BaseSettings): + name: str = 'Bob' + + +# Set existing `parser` as the `root_parser` object for the user defined settings source +cli_settings = CliSettingsSource(Settings, root_parser=parser) + +# Parse and load CLI settings from the command line into the settings source. +sys.argv = ['example.py', '--food', 'kiwi', '--name', 'waldo'] +print(Settings(_cli_settings_source=cli_settings(args=True)).model_dump()) +#> {'name': 'waldo'} + +# Load CLI settings from pre-parsed arguments. i.e., the parsing occurs elsewhere and we +# just need to load the pre-parsed args into the settings source. 
+parsed_args = parser.parse_args(['--food', 'kiwi', '--name', 'ralph']) +print(Settings(_cli_settings_source=cli_settings(parsed_args=parsed_args)).model_dump()) +#> {'name': 'ralph'} +``` + +A `CliSettingsSource` connects with a `root_parser` object by using parser methods to add `settings_cls` fields as +command line arguments. The `CliSettingsSource` internal parser representation is based on the `argparse` library, and +therefore, requires parser methods that support the same attributes as their `argparse` counterparts. The available +parser methods that can be customised, along with their argparse counterparts (the defaults), are listed below: + +* `parse_args_method` - argparse.ArgumentParser.parse_args +* `add_argument_method` - argparse.ArgumentParser.add_argument +* `add_argument_group_method` - argparse.ArgumentParser.add\_argument_group +* `add_parser_method` - argparse.\_SubParsersAction.add_parser +* `add_subparsers_method` - argparse.ArgumentParser.add_subparsers +* `formatter_class` - argparse.HelpFormatter + +For a non-argparse parser the parser methods can be set to `None` if not supported. The CLI settings will only raise an +error when connecting to the root parser if a parser method is necessary but set to `None`. + ## Secrets Placing secret values in files is a common pattern to provide sensitive configuration to an application. @@ -719,11 +1305,12 @@ class ExplicitFilePathSettings(BaseSettings): In the case where a value is specified for the same `Settings` field in multiple ways, the selected value is determined as follows (in descending order of priority): -1. Arguments passed to the `Settings` class initialiser. -2. Environment variables, e.g. `my_prefix_special_function` as described above. -3. Variables loaded from a dotenv (`.env`) file. -4. Variables loaded from the secrets directory. -5. The default field values for the `Settings` model. +1. If `cli_parse_args` is enabled, arguments passed in at the CLI. +2. 
Arguments passed to the `Settings` class initialiser. +3. Environment variables, e.g. `my_prefix_special_function` as described above. +4. Variables loaded from a dotenv (`.env`) file. +5. Variables loaded from the secrets directory. +6. The default field values for the `Settings` model. ## Customise settings sources diff --git a/pydantic_settings/__init__.py b/pydantic_settings/__init__.py index 5f08cc62..d70ccc8a 100644 --- a/pydantic_settings/__init__.py +++ b/pydantic_settings/__init__.py @@ -1,5 +1,8 @@ from .main import BaseSettings, SettingsConfigDict from .sources import ( + CliPositionalArg, + CliSettingsSource, + CliSubCommand, DotEnvSettingsSource, EnvSettingsSource, InitSettingsSource, @@ -16,6 +19,9 @@ 'BaseSettings', 'DotEnvSettingsSource', 'EnvSettingsSource', + 'CliSettingsSource', + 'CliSubCommand', + 'CliPositionalArg', 'InitSettingsSource', 'JsonConfigSettingsSource', 'PyprojectTomlConfigSettingsSource', diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index c5764fc2..3b3ba434 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -10,6 +10,7 @@ from .sources import ( ENV_FILE_SENTINEL, + CliSettingsSource, DotEnvSettingsSource, DotenvType, EnvSettingsSource, @@ -29,6 +30,15 @@ class SettingsConfigDict(ConfigDict, total=False): env_nested_delimiter: str | None env_parse_none_str: str | None env_parse_enums: bool | None + cli_prog_name: str | None + cli_parse_args: bool | list[str] | tuple[str, ...] | None + cli_settings_source: CliSettingsSource[Any] | None + cli_parse_none_str: str | None + cli_hide_none_type: bool + cli_avoid_json: bool + cli_enforce_required: bool + cli_use_class_docs_for_groups: bool + cli_prefix: str secrets_dir: str | Path | None json_file: PathType | None json_file_encoding: str | None @@ -87,6 +97,20 @@ class BaseSettings(BaseModel): _env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.) into `None` type(None). 
Defaults to `None` type(None), which means no parsing should occur. _env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur. + _cli_prog_name: The CLI program name to display in help text. Defaults to `None` if _cli_parse_args is `None`. + Otherwise, defaults to sys.argv[0]. + _cli_parse_args: The list of CLI arguments to parse. Defaults to None. + If set to `True`, defaults to sys.argv[1:]. + _cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to None. + _cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into + `None` type(None). Defaults to _env_parse_none_str value if set. Otherwise, defaults to "null" if + _cli_avoid_json is `False`, and "None" if _cli_avoid_json is `True`. + _cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`. + _cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`. + _cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`. + _cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions. + Defaults to `False`. + _cli_prefix: The root parser command line arguments prefix. Defaults to "". _secrets_dir: The secret files directory. Defaults to `None`. """ @@ -100,6 +124,15 @@ def __init__( _env_nested_delimiter: str | None = None, _env_parse_none_str: str | None = None, _env_parse_enums: bool | None = None, + _cli_prog_name: str | None = None, + _cli_parse_args: bool | list[str] | tuple[str, ...] 
| None = None, + _cli_settings_source: CliSettingsSource[Any] | None = None, + _cli_parse_none_str: str | None = None, + _cli_hide_none_type: bool | None = None, + _cli_avoid_json: bool | None = None, + _cli_enforce_required: bool | None = None, + _cli_use_class_docs_for_groups: bool | None = None, + _cli_prefix: str | None = None, _secrets_dir: str | Path | None = None, **values: Any, ) -> None: @@ -115,6 +148,15 @@ def __init__( _env_nested_delimiter=_env_nested_delimiter, _env_parse_none_str=_env_parse_none_str, _env_parse_enums=_env_parse_enums, + _cli_prog_name=_cli_prog_name, + _cli_parse_args=_cli_parse_args, + _cli_settings_source=_cli_settings_source, + _cli_parse_none_str=_cli_parse_none_str, + _cli_hide_none_type=_cli_hide_none_type, + _cli_avoid_json=_cli_avoid_json, + _cli_enforce_required=_cli_enforce_required, + _cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups, + _cli_prefix=_cli_prefix, _secrets_dir=_secrets_dir, ) ) @@ -154,6 +196,15 @@ def _settings_build_values( _env_nested_delimiter: str | None = None, _env_parse_none_str: str | None = None, _env_parse_enums: bool | None = None, + _cli_prog_name: str | None = None, + _cli_parse_args: bool | list[str] | tuple[str, ...] 
| None = None, + _cli_settings_source: CliSettingsSource[Any] | None = None, + _cli_parse_none_str: str | None = None, + _cli_hide_none_type: bool | None = None, + _cli_avoid_json: bool | None = None, + _cli_enforce_required: bool | None = None, + _cli_use_class_docs_for_groups: bool | None = None, + _cli_prefix: str | None = None, _secrets_dir: str | Path | None = None, ) -> dict[str, Any]: # Determine settings config values @@ -175,10 +226,52 @@ def _settings_build_values( _env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str') ) env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums') + + cli_prog_name = _cli_prog_name if _cli_prog_name is not None else self.model_config.get('cli_prog_name') + cli_parse_args = _cli_parse_args if _cli_parse_args is not None else self.model_config.get('cli_parse_args') + cli_settings_source = ( + _cli_settings_source if _cli_settings_source is not None else self.model_config.get('cli_settings_source') + ) + cli_parse_none_str = ( + _cli_parse_none_str if _cli_parse_none_str is not None else self.model_config.get('cli_parse_none_str') + ) + cli_parse_none_str = cli_parse_none_str if not env_parse_none_str else env_parse_none_str + cli_hide_none_type = ( + _cli_hide_none_type if _cli_hide_none_type is not None else self.model_config.get('cli_hide_none_type') + ) + cli_avoid_json = _cli_avoid_json if _cli_avoid_json is not None else self.model_config.get('cli_avoid_json') + cli_enforce_required = ( + _cli_enforce_required + if _cli_enforce_required is not None + else self.model_config.get('cli_enforce_required') + ) + cli_use_class_docs_for_groups = ( + _cli_use_class_docs_for_groups + if _cli_use_class_docs_for_groups is not None + else self.model_config.get('cli_use_class_docs_for_groups') + ) + cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix') + secrets_dir = _secrets_dir if 
_secrets_dir is not None else self.model_config.get('secrets_dir') # Configure built-in sources init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs) + cli_settings = ( + CliSettingsSource( + self.__class__, + cli_prog_name=cli_prog_name, + cli_parse_args=cli_parse_args, + cli_parse_none_str=cli_parse_none_str, + cli_hide_none_type=cli_hide_none_type, + cli_avoid_json=cli_avoid_json, + cli_enforce_required=cli_enforce_required, + cli_use_class_docs_for_groups=cli_use_class_docs_for_groups, + cli_prefix=cli_prefix, + case_sensitive=case_sensitive, + ) + if cli_settings_source is None + else cli_settings_source + ) env_settings = EnvSettingsSource( self.__class__, case_sensitive=case_sensitive, @@ -211,6 +304,9 @@ def _settings_build_values( dotenv_settings=dotenv_settings, file_secret_settings=file_secret_settings, ) + if not any([source for source in sources if isinstance(source, CliSettingsSource)]): + if cli_parse_args or cli_settings_source: + sources = (cli_settings,) + sources if sources: return deep_update(*reversed([source() for source in sources])) else: @@ -230,6 +326,15 @@ def _settings_build_values( env_nested_delimiter=None, env_parse_none_str=None, env_parse_enums=None, + cli_prog_name=None, + cli_parse_args=None, + cli_settings_source=None, + cli_parse_none_str=None, + cli_hide_none_type=False, + cli_avoid_json=False, + cli_enforce_required=False, + cli_use_class_docs_for_groups=False, + cli_prefix='', json_file=None, json_file_encoding=None, yaml_file=None, diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 5e6a32a4..069e50a6 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -2,21 +2,43 @@ import json import os +import re +import shlex import sys +import typing import warnings from abc import ABC, abstractmethod +from argparse import SUPPRESS, ArgumentParser, HelpFormatter, Namespace, _SubParsersAction from collections import deque from dataclasses import is_dataclass 
from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast - +from types import FunctionType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + List, + Mapping, + Sequence, + Tuple, + TypeVar, + Union, + cast, + overload, +) + +import typing_extensions from dotenv import dotenv_values from pydantic import AliasChoices, AliasPath, BaseModel, Json -from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union +from pydantic._internal._repr import Representation +from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass +from pydantic.dataclasses import is_pydantic_dataclass from pydantic.fields import FieldInfo -from typing_extensions import _AnnotatedAlias, get_args, get_origin +from pydantic_core import PydanticUndefined +from typing_extensions import Annotated, _AnnotatedAlias, get_args, get_origin from pydantic_settings.utils import path_type_label @@ -71,6 +93,23 @@ def import_toml() -> None: ENV_FILE_SENTINEL: DotenvType = Path('') +class _CliSubCommand: + pass + + +class _CliPositionalArg: + pass + + +class _CliInternalArgParser(ArgumentParser): + pass + + +T = TypeVar('T') +CliSubCommand = Annotated[Union[T, None], _CliSubCommand] +CliPositionalArg = Annotated[T, _CliPositionalArg] + + class EnvNoneType(str): pass @@ -512,7 +551,9 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val value = field.annotation[value] # type: ignore if is_complex or value_is_complex: - if value is None: + if isinstance(value, EnvNoneType): + return value + elif value is None: # field is complex but no value found so far, try explode_env_vars env_val_built = self.explode_env_vars(field_name, field, self.env_vars) if env_val_built: @@ -640,7 +681,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ 
if not allow_json_failure: raise e if isinstance(env_var, dict): - env_var[last_key] = env_val + if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] is {}: + env_var[last_key] = env_val return result @@ -740,6 +782,624 @@ def __repr__(self) -> str: ) +class CliSettingsSource(EnvSettingsSource, Generic[T]): + """ + Source class for loading settings values from CLI. + + Note: + A `CliSettingsSource` connects with a `root_parser` object by using the parser methods to add + `settings_cls` fields as command line arguments. The `CliSettingsSource` internal parser representation + is based upon the `argparse` parsing library, and therefore, requires the parser methods to support + the same attributes as their `argparse` library counterparts. + + Args: + cli_prog_name: The CLI program name to display in help text. Defaults to `None` if cli_parse_args is `None`. + Otherwise, defaults to sys.argv[0]. + cli_parse_args: The list of CLI arguments to parse. Defaults to None. + If set to `True`, defaults to sys.argv[1:]. + cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to `None`. + cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into `None` + type(None). Defaults to "null" if cli_avoid_json is `False`, and "None" if cli_avoid_json is `True`. + cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`. + cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`. + cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`. + cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions. + Defaults to `False`. + cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "". + case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`. 
+ Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI + subcommands. + root_parser: The root parser object. + parse_args_method: The root parser parse args method. Defaults to `argparse.ArgumentParser.parse_args`. + add_argument_method: The root parser add argument method. Defaults to `argparse.ArgumentParser.add_argument`. + add_argument_group_method: The root parser add argument group method. + Defaults to `argparse.ArgumentParser.add_argument_group`. + add_parser_method: The root parser add new parser (sub-command) method. + Defaults to `argparse._SubParsersAction.add_parser`. + add_subparsers_method: The root parser add subparsers (sub-commands) method. + Defaults to `argparse.ArgumentParser.add_subparsers`. + formatter_class: A class for customizing the root parser help text. Defaults to `argparse.HelpFormatter`. + """ + + def __init__( + self, + settings_cls: type[BaseSettings], + cli_prog_name: str | None = None, + cli_parse_args: bool | list[str] | tuple[str, ...] 
| None = None, + cli_parse_none_str: str | None = None, + cli_hide_none_type: bool | None = None, + cli_avoid_json: bool | None = None, + cli_enforce_required: bool | None = None, + cli_use_class_docs_for_groups: bool | None = None, + cli_prefix: str | None = None, + case_sensitive: bool | None = True, + root_parser: Any = None, + parse_args_method: Callable[..., Any] | None = ArgumentParser.parse_args, + add_argument_method: Callable[..., Any] | None = ArgumentParser.add_argument, + add_argument_group_method: Callable[..., Any] | None = ArgumentParser.add_argument_group, + add_parser_method: Callable[..., Any] | None = _SubParsersAction.add_parser, + add_subparsers_method: Callable[..., Any] | None = ArgumentParser.add_subparsers, + formatter_class: Any = HelpFormatter, + ) -> None: + self.cli_prog_name = ( + cli_prog_name if cli_prog_name is not None else settings_cls.model_config.get('cli_prog_name', sys.argv[0]) + ) + self.cli_hide_none_type = ( + cli_hide_none_type + if cli_hide_none_type is not None + else settings_cls.model_config.get('cli_hide_none_type', False) + ) + self.cli_avoid_json = ( + cli_avoid_json if cli_avoid_json is not None else settings_cls.model_config.get('cli_avoid_json', False) + ) + if not cli_parse_none_str: + cli_parse_none_str = 'None' if self.cli_avoid_json is True else 'null' + self.cli_parse_none_str = cli_parse_none_str + self.cli_enforce_required = ( + cli_enforce_required + if cli_enforce_required is not None + else settings_cls.model_config.get('cli_enforce_required', False) + ) + self.cli_use_class_docs_for_groups = ( + cli_use_class_docs_for_groups + if cli_use_class_docs_for_groups is not None + else settings_cls.model_config.get('cli_use_class_docs_for_groups', False) + ) + self.cli_prefix = cli_prefix if cli_prefix is not None else settings_cls.model_config.get('cli_prefix', '') + if self.cli_prefix: + if cli_prefix.startswith('.') or cli_prefix.endswith('.') or not cli_prefix.replace('.', '').isidentifier(): # type: 
ignore + raise SettingsError(f'CLI settings source prefix is invalid: {cli_prefix}') + self.cli_prefix += '.' + + case_sensitive = case_sensitive if case_sensitive is not None else True + if not case_sensitive and root_parser is not None: + raise SettingsError('Case-insensitive matching is only supported on the internal root parser') + + super().__init__( + settings_cls, + env_nested_delimiter='.', + env_parse_none_str=self.cli_parse_none_str, + env_parse_enums=True, + env_prefix=self.cli_prefix, + case_sensitive=case_sensitive, + ) + + root_parser = ( + _CliInternalArgParser(prog=self.cli_prog_name, description=settings_cls.__doc__) + if root_parser is None + else root_parser + ) + self._connect_root_parser( + root_parser=root_parser, + parse_args_method=parse_args_method, + add_argument_method=add_argument_method, + add_argument_group_method=add_argument_group_method, + add_parser_method=add_parser_method, + add_subparsers_method=add_subparsers_method, + formatter_class=formatter_class, + ) + + if cli_parse_args not in (None, False): + if cli_parse_args is True: + cli_parse_args = sys.argv[1:] + elif not isinstance(cli_parse_args, (list, tuple)): + raise SettingsError( + f'cli_parse_args must be List[str] or Tuple[str, ...], recieved {type(cli_parse_args)}' + ) + self._load_env_vars(parsed_args=self._parse_args(self.root_parser, cli_parse_args)) + + @overload + def __call__(self) -> dict[str, Any]: ... + + @overload + def __call__(self, *, args: list[str] | tuple[str, ...] | bool) -> CliSettingsSource[T]: + """ + Parse and load the command line arguments list into the CLI settings source. + + Args: + args: + The command line arguments to parse and load. Defaults to `None`, which means do not parse + command line arguments. If set to `True`, defaults to sys.argv[1:]. If set to `False`, does + not parse command line arguments. + + Returns: + CliSettingsSource: The object instance itself. + """ + ... 
+ + @overload + def __call__(self, *, parsed_args: Namespace | dict[str, list[str] | str]) -> CliSettingsSource[T]: + """ + Loads parsed command line arguments into the CLI settings source. + + Note: + The parsed args must be in `argparse.Namespace` or vars dictionary (e.g., vars(argparse.Namespace)) + format. + + Args: + parsed_args: The parsed args to load. + + Returns: + CliSettingsSource: The object instance itself. + """ + ... + + def __call__( + self, + *, + args: list[str] | tuple[str, ...] | bool | None = None, + parsed_args: Namespace | dict[str, list[str] | str] | None = None, + ) -> dict[str, Any] | CliSettingsSource[T]: + if args is not None and parsed_args is not None: + raise SettingsError('`args` and `parsed_args` are mutually exclusive') + elif args is not None: + if args is False: + return self._load_env_vars(parsed_args={}) + if args is True: + args = sys.argv[1:] + return self._load_env_vars(parsed_args=self._parse_args(self.root_parser, args)) + elif parsed_args is not None: + return self._load_env_vars(parsed_args=parsed_args) + else: + return super().__call__() + + @overload + def _load_env_vars(self) -> Mapping[str, str | None]: ... + + @overload + def _load_env_vars(self, *, parsed_args: Namespace | dict[str, list[str] | str]) -> CliSettingsSource[T]: + """ + Loads the parsed command line arguments into the CLI environment settings variables. + + Note: + The parsed args must be in `argparse.Namespace` or vars dictionary (e.g., vars(argparse.Namespace)) + format. + + Args: + parsed_args: The parsed args to load. + + Returns: + CliSettingsSource: The object instance itself. + """ + ... 
+ + def _load_env_vars( + self, *, parsed_args: Namespace | dict[str, list[str] | str] | None = None + ) -> Mapping[str, str | None] | CliSettingsSource[T]: + if parsed_args is None: + return {} + + if isinstance(parsed_args, Namespace): + parsed_args = vars(parsed_args) + + selected_subcommands: list[str] = [] + for field_name, val in parsed_args.items(): + if isinstance(val, list): + parsed_args[field_name] = self._merge_parsed_list(val, field_name) + elif field_name.endswith(':subcommand') and val is not None: + selected_subcommands.append(field_name.split(':')[0] + val) + + for subcommands in self._cli_subcommands.values(): + for subcommand in subcommands: + if subcommand not in selected_subcommands: + parsed_args[subcommand] = self.cli_parse_none_str + + parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')} + if selected_subcommands: + last_selected_subcommand = max(selected_subcommands, key=len) + if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' 
in field_name): + parsed_args[last_selected_subcommand] = '{}' + + self.env_vars = parse_env_vars( + cast(Mapping[str, str], parsed_args), + self.case_sensitive, + self.env_ignore_empty, + self.cli_parse_none_str, + ) + + return self + + def _merge_parsed_list(self, parsed_list: list[str], field_name: str) -> str: + try: + merged_list: list[str] = [] + is_last_consumed_a_value = False + merge_type = self._cli_dict_args.get(field_name, list) + if ( + merge_type is list + or not origin_is_union(get_origin(merge_type)) + or not any( + type_ + for type_ in get_args(merge_type) + if type_ is not type(None) and get_origin(type_) not in (dict, Mapping) + ) + ): + inferred_type = merge_type + else: + inferred_type = ( + list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str + ) + for val in parsed_list: + if val.startswith('[') and val.endswith(']'): + val = val[1:-1] + while val: + if val.startswith(','): + val = self._consume_comma(val, merged_list, is_last_consumed_a_value) + is_last_consumed_a_value = False + else: + if val.startswith('{') or val.startswith('['): + val = self._consume_object_or_array(val, merged_list) + else: + try: + val = self._consume_string_or_number(val, merged_list, merge_type) + except ValueError as e: + if merge_type is inferred_type: + raise e + merge_type = inferred_type + val = self._consume_string_or_number(val, merged_list, merge_type) + is_last_consumed_a_value = True + if not is_last_consumed_a_value: + val = self._consume_comma(val, merged_list, is_last_consumed_a_value) + + if merge_type is str: + return merged_list[0] + elif merge_type is list: + return f'[{",".join(merged_list)}]' + else: + merged_dict: dict[str, str] = {} + for item in merged_list: + merged_dict.update(json.loads(item)) + return json.dumps(merged_dict) + except Exception as e: + raise SettingsError(f'Parsing error encountered for {field_name}: {e}') + + def _consume_comma(self, item: str, merged_list: list[str], 
is_last_consumed_a_value: bool) -> str: + if not is_last_consumed_a_value: + merged_list.append('""') + return item[1:] + + def _consume_object_or_array(self, item: str, merged_list: list[str]) -> str: + count = 1 + close_delim = '}' if item.startswith('{') else ']' + for consumed in range(1, len(item)): + if item[consumed] in ('{', '['): + count += 1 + elif item[consumed] in ('}', ']'): + count -= 1 + if item[consumed] == close_delim and count == 0: + merged_list.append(item[: consumed + 1]) + return item[consumed + 1 :] + raise SettingsError(f'Missing end delimiter "{close_delim}"') + + def _consume_string_or_number(self, item: str, merged_list: list[str], merge_type: type[Any] | None) -> str: + consumed = 0 if merge_type is not str else len(item) + is_find_end_quote = False + while consumed < len(item): + if item[consumed] == '"' and (consumed == 0 or item[consumed - 1] != '\\'): + is_find_end_quote = not is_find_end_quote + if not is_find_end_quote and item[consumed] == ',': + break + consumed += 1 + if is_find_end_quote: + raise SettingsError('Mismatched quotes') + val_string = item[:consumed].strip() + if merge_type in (list, str): + try: + float(val_string) + except ValueError: + if val_string == self.cli_parse_none_str: + val_string = 'null' + if val_string not in ('true', 'false', 'null') and not val_string.startswith('"'): + val_string = f'"{val_string}"' + merged_list.append(val_string) + else: + key, val = (kv for kv in val_string.split('=', 1)) + if key.startswith('"') and not key.endswith('"') and not val.startswith('"') and val.endswith('"'): + raise ValueError(f'Dictionary key=val parameter is a quoted string: {val_string}') + key, val = key.strip('"'), val.strip('"') + merged_list.append(json.dumps({key: val})) + return item[consumed:] + + def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> list[type[BaseModel]]: + field_types: tuple[Any, ...] 
= ( + (field_info.annotation,) if not get_args(field_info.annotation) else get_args(field_info.annotation) + ) + if self.cli_hide_none_type: + field_types = tuple([type_ for type_ in field_types if type_ is not type(None)]) + + sub_models: list[type[BaseModel]] = [] + for type_ in field_types: + if _annotation_contains_types(type_, (_CliSubCommand,), is_include_origin=False): + raise SettingsError(f'CliSubCommand is not outermost annotation for {model.__name__}.{field_name}') + elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False): + raise SettingsError(f'CliPositionalArg is not outermost annotation for {model.__name__}.{field_name}') + if is_model_class(type_) or is_pydantic_dataclass(type_): + sub_models.append(type_) # type: ignore + return sub_models + + def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, str, FieldInfo]]: + positional_args, subcommand_args, optional_args = [], [], [] + fields = model.__pydantic_fields__ if is_pydantic_dataclass(model) else model.model_fields + for field_name, field_info in fields.items(): + resolved_name = field_name if field_info.alias is None else field_info.alias + resolved_name = resolved_name.lower() if not self.case_sensitive else resolved_name + if _CliSubCommand in field_info.metadata: + if not field_info.is_required(): + raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value') + else: + field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)] + if len(field_types) != 1: + raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple types') + elif not is_model_class(field_types[0]): + raise SettingsError( + f'subcommand argument {model.__name__}.{resolved_name} is not derived from BaseModel' + ) + subcommand_args.append((field_name, resolved_name, field_info)) + elif _CliPositionalArg in field_info.metadata: + if not field_info.is_required(): + raise 
SettingsError(f'positional argument {model.__name__}.{field_name} has a default value') + positional_args.append((field_name, resolved_name, field_info)) + else: + optional_args.append((field_name, resolved_name, field_info)) + return positional_args + subcommand_args + optional_args + + @property + def root_parser(self) -> T: + """The connected root parser instance.""" + return self._root_parser + + def _connect_parser_method( + self, parser_method: Callable[..., Any] | None, method_name: str, *args: Any, **kwargs: Any + ) -> Callable[..., Any]: + if ( + parser_method is not None + and self.case_sensitive is False + and method_name == 'parsed_args_method' + and isinstance(self._root_parser, _CliInternalArgParser) + ): + + def parse_args_insensitive_method( + root_parser: _CliInternalArgParser, + args: list[str] | tuple[str, ...] | None = None, + namespace: Namespace | None = None, + ) -> Any: + insensitive_args = [] + for arg in shlex.split(shlex.join(args)) if args else []: + matched = re.match(r'^(--[^\s=]+)(.*)', arg) + if matched: + arg = matched.group(1).lower() + matched.group(2) + insensitive_args.append(arg) + return parser_method(root_parser, insensitive_args, namespace) # type: ignore + + return parse_args_insensitive_method + + elif parser_method is None: + + def none_parser_method(*args: Any, **kwargs: Any) -> Any: + raise SettingsError( + f'cannot connect CLI settings source root parser: {method_name} is set to `None` but is needed for connecting' + ) + + return none_parser_method + + else: + return parser_method + + def _connect_root_parser( + self, + root_parser: T, + parse_args_method: Callable[..., Any] | None = ArgumentParser.parse_args, + add_argument_method: Callable[..., Any] | None = ArgumentParser.add_argument, + add_argument_group_method: Callable[..., Any] | None = ArgumentParser.add_argument_group, + add_parser_method: Callable[..., Any] | None = _SubParsersAction.add_parser, + add_subparsers_method: Callable[..., Any] | None = 
ArgumentParser.add_subparsers, + formatter_class: Any = HelpFormatter, + ) -> None: + self._root_parser = root_parser + self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method') + self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method') + self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method') + self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method') + self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method') + self._formatter_class = formatter_class + self._cli_dict_args: dict[str, type[Any] | None] = {} + self._cli_subcommands: dict[str, list[str]] = {} + self._add_parser_args( + parser=self.root_parser, + model=self.settings_cls, + added_args=[], + arg_prefix=self.env_prefix, + subcommand_prefix=self.env_prefix, + group=None, + ) + + def _add_parser_args( + self, + parser: Any, + model: type[BaseModel], + added_args: list[str], + arg_prefix: str, + subcommand_prefix: str, + group: Any, + ) -> ArgumentParser: + subparsers: Any = None + for field_name, resolved_name, field_info in self._sort_arg_fields(model): + sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info) + if _CliSubCommand in field_info.metadata: + if subparsers is None: + subparsers = self._add_subparsers( + parser, title='subcommands', dest=f'{arg_prefix}:subcommand', required=self.cli_enforce_required + ) + self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{resolved_name}'] + else: + self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{resolved_name}') + if hasattr(subparsers, 'metavar'): + metavar = ','.join(self._cli_subcommands[f'{arg_prefix}:subcommand']) + subparsers.metavar = f'{{{metavar}}}' + + model = sub_models[0] + self._add_parser_args( + parser=self._add_parser( + subparsers, + resolved_name, + help=field_info.description, + 
formatter_class=self._formatter_class, + description=model.__doc__, + ), + model=model, + added_args=[], + arg_prefix=f'{arg_prefix}{resolved_name}.', + subcommand_prefix=f'{subcommand_prefix}{resolved_name}.', + group=None, + ) + else: + arg_flag: str = '--' + kwargs: dict[str, Any] = {} + kwargs['default'] = SUPPRESS + kwargs['help'] = self._help_format(field_info) + kwargs['dest'] = f'{arg_prefix}{resolved_name}' + kwargs['metavar'] = self._metavar_format(field_info.annotation) + kwargs['required'] = self.cli_enforce_required and field_info.is_required() + if kwargs['dest'] in added_args: + continue + if _annotation_contains_types( + field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True + ): + kwargs['action'] = 'append' + if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_strip_annotated=True): + self._cli_dict_args[kwargs['dest']] = field_info.annotation + + arg_name = ( + f'{arg_prefix}{resolved_name}' + if subcommand_prefix == self.env_prefix + else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{resolved_name}' + ) + if _CliPositionalArg in field_info.metadata: + kwargs['metavar'] = resolved_name.upper() + arg_name = kwargs['dest'] + del kwargs['dest'] + del kwargs['required'] + arg_flag = '' + + if sub_models and kwargs.get('action') != 'append': + model_group: Any = None + model_group_kwargs: dict[str, Any] = {} + model_group_kwargs['title'] = f'{arg_name} options' + model_group_kwargs['description'] = ( + sub_models[0].__doc__ + if self.cli_use_class_docs_for_groups and len(sub_models) == 1 + else field_info.description + ) + if not self.cli_avoid_json: + added_args.append(arg_name) + kwargs['help'] = f'set {arg_name} from JSON string' + model_group = self._add_argument_group(parser, **model_group_kwargs) + self._add_argument(model_group, f'{arg_flag}{arg_name}', **kwargs) + for model in sub_models: + self._add_parser_args( + parser=parser, + model=model, + added_args=added_args, + 
arg_prefix=f'{arg_prefix}{resolved_name}.', + subcommand_prefix=subcommand_prefix, + group=model_group if model_group else model_group_kwargs, + ) + elif group is not None: + if isinstance(group, dict): + group = self._add_argument_group(parser, **group) + added_args.append(arg_name) + self._add_argument(group, f'{arg_flag}{arg_name}', **kwargs) + else: + added_args.append(arg_name) + self._add_argument(parser, f'{arg_flag}{arg_name}', **kwargs) + return parser + + def _get_modified_args(self, obj: Any) -> tuple[str, ...]: + if not self.cli_hide_none_type: + return get_args(obj) + else: + return tuple([type_ for type_ in get_args(obj) if type_ is not type(None)]) + + def _metavar_format_choices(self, args: list[str], obj_qualname: str | None = None) -> str: + if 'JSON' in args: + args = args[: args.index('JSON') + 1] + [arg for arg in args[args.index('JSON') + 1 :] if arg != 'JSON'] + metavar = ','.join(args) + if obj_qualname: + return f'{obj_qualname}[{metavar}]' + else: + return metavar if len(args) == 1 else f'{{{metavar}}}' + + def _metavar_format_recurse(self, obj: Any) -> str: + """Pretty metavar representation of a type. Adapts logic from `pydantic._repr.display_as_type`.""" + obj = _strip_annotated(obj) + if isinstance(obj, FunctionType): + return obj.__name__ + elif obj is ...: + return '...' 
+ elif isinstance(obj, Representation): + return repr(obj) + elif isinstance(obj, typing_extensions.TypeAliasType): + return str(obj) + + if not isinstance(obj, (typing_base, WithArgsTypes, type)): + obj = obj.__class__ + + if origin_is_union(get_origin(obj)): + return self._metavar_format_choices(list(map(self._metavar_format_recurse, self._get_modified_args(obj)))) + elif get_origin(obj) in (typing_extensions.Literal, typing.Literal): + return self._metavar_format_choices(list(map(str, self._get_modified_args(obj)))) + elif lenient_issubclass(obj, Enum): + return self._metavar_format_choices([val.name for val in obj]) + elif isinstance(obj, WithArgsTypes): + return self._metavar_format_choices( + list(map(self._metavar_format_recurse, self._get_modified_args(obj))), obj_qualname=obj.__qualname__ + ) + elif obj is type(None): + return self.cli_parse_none_str + elif is_model_class(obj): + return 'JSON' + elif isinstance(obj, type): + return obj.__qualname__ + else: + return repr(obj).replace('typing.', '').replace('typing_extensions.', '') + + def _metavar_format(self, obj: Any) -> str: + return self._metavar_format_recurse(obj).replace(', ', ',') + + def _help_format(self, field_info: FieldInfo) -> str: + _help = field_info.description if field_info.description else '' + if field_info.is_required(): + if _CliPositionalArg not in field_info.metadata: + _help += ' (required)' if _help else '(required)' + else: + default = f'(default: {self.cli_parse_none_str})' + if field_info.default not in (PydanticUndefined, None): + default = f'(default: {field_info.default})' + elif field_info.default_factory is not None: + default = f'(default: {field_info.default_factory})' + _help += f' {default}' if _help else default + return _help + + class ConfigFileSourceMixin(ABC): def _read_files(self, files: PathType | None) -> dict[str, Any]: if files is None: @@ -938,3 +1598,25 @@ def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool: def 
_union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) + + +def _annotation_contains_types( + annotation: type[Any] | None, + types: tuple[Any, ...], + is_include_origin: bool = True, + is_strip_annotated: bool = False, +) -> bool: + if is_strip_annotated: + annotation = _strip_annotated(annotation) + if is_include_origin is True and get_origin(annotation) in types: + return True + for type_ in get_args(annotation): + if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated): + return True + return annotation in types + + +def _strip_annotated(annotation: Any) -> Any: + while get_origin(annotation) == Annotated: + annotation = get_args(annotation)[0] + return annotation diff --git a/tests/test_settings.py b/tests/test_settings.py index ea430953..df89c821 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,7 +1,10 @@ +import argparse import dataclasses import json import os +import re import sys +import typing import uuid from datetime import datetime, timezone from enum import IntEnum @@ -9,11 +12,13 @@ from typing import Any, Callable, Dict, Generic, Hashable, List, Optional, Set, Tuple, Type, TypeVar, Union import pytest +import typing_extensions from annotated_types import MinLen from pydantic import ( AliasChoices, AliasPath, BaseModel, + DirectoryPath, Discriminator, Field, HttpUrl, @@ -26,6 +31,7 @@ from pydantic import ( dataclasses as pydantic_dataclasses, ) +from pydantic._internal._repr import Representation from pydantic.fields import FieldInfo from pytest_mock import MockerFixture from typing_extensions import Annotated, Literal @@ -43,7 +49,7 @@ TomlConfigSettingsSource, YamlConfigSettingsSource, ) -from pydantic_settings.sources import SettingsError, read_env_file +from pydantic_settings.sources import CliPositionalArg, CliSettingsSource, CliSubCommand, SettingsError, read_env_file 
try: import dotenv @@ -59,6 +65,53 @@ tomli = None +def foobar(a, b, c=4): + pass + + +T = TypeVar('T') + + +class FruitsEnum(IntEnum): + pear = 0 + kiwi = 1 + lime = 2 + + +class CliDummyArgGroup(BaseModel, arbitrary_types_allowed=True): + group: argparse._ArgumentGroup + + def add_argument(self, *args, **kwargs) -> None: + self.group.add_argument(*args, **kwargs) + + +class CliDummySubParsers(BaseModel, arbitrary_types_allowed=True): + sub_parser: argparse._SubParsersAction + + def add_parser(self, *args, **kwargs) -> 'CliDummyParser': + return CliDummyParser(parser=self.sub_parser.add_parser(*args, **kwargs)) + + +class CliDummyParser(BaseModel, arbitrary_types_allowed=True): + parser: argparse.ArgumentParser = Field(default_factory=lambda: argparse.ArgumentParser()) + + def add_argument(self, *args, **kwargs) -> None: + self.parser.add_argument(*args, **kwargs) + + def add_argument_group(self, *args, **kwargs) -> CliDummyArgGroup: + return CliDummyArgGroup(group=self.parser.add_argument_group(*args, **kwargs)) + + def add_subparsers(self, *args, **kwargs) -> CliDummySubParsers: + return CliDummySubParsers(sub_parser=self.parser.add_subparsers(*args, **kwargs)) + + def parse_args(self, *args, **kwargs) -> argparse.Namespace: + return self.parser.parse_args(*args, **kwargs) + + +class LoggedVar(Generic[T]): + def get(self) -> T: ... 
+ + class SimpleSettings(BaseSettings): apple: str @@ -1897,11 +1950,6 @@ class Settings(BaseSettings): def test_env_parse_enums(env): - class FruitsEnum(IntEnum): - pear = 0 - kiwi = 1 - lime = 2 - class Settings(BaseSettings): fruit: FruitsEnum @@ -1970,6 +2018,26 @@ class NestedSettings(BaseSettings, env_nested_delimiter='__'): assert s.nested.deep['z'] is None assert s.nested.keep['z'] == 'None' + env.set('nested__deep', 'None') + + with pytest.raises(ValidationError): + s = NestedSettings() + s = NestedSettings(_env_parse_none_str='None') + assert s.nested.x is None + assert s.nested.y == 'y_override' + assert s.nested.deep['z'] is None + assert s.nested.keep['z'] == 'None' + + env.pop('nested__deep__z') + + with pytest.raises(ValidationError): + s = NestedSettings() + s = NestedSettings(_env_parse_none_str='None') + assert s.nested.x is None + assert s.nested.y == 'y_override' + assert s.nested.deep is None + assert s.nested.keep['z'] == 'None' + def test_env_json_field_dict(env): class Settings(BaseSettings): @@ -2063,6 +2131,1099 @@ class Settings(BaseSettings): assert s.data == {'foo': 'bar'} +def test_cli_nested_arg(): + class SubSubValue(BaseModel): + v6: str + + class SubValue(BaseModel): + v4: str + v5: int + sub_sub: SubSubValue + + class TopValue(BaseModel): + v1: str + v2: str + v3: str + sub: SubValue + + class Cfg(BaseSettings): + v0: str + v0_union: Union[SubValue, int] + top: TopValue + + args: List[str] = [] + args += ['--top', '{"v1": "json-1", "v2": "json-2", "sub": {"v5": "xx"}}'] + args += ['--top.sub.v5', '5'] + args += ['--v0', '0'] + args += ['--top.v2', '2'] + args += ['--top.v3', '3'] + args += ['--v0_union', '0'] + args += ['--top.sub.sub_sub.v6', '6'] + args += ['--top.sub.v4', '4'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == { + 'v0': '0', + 'v0_union': 0, + 'top': { + 'v1': 'json-1', + 'v2': '2', + 'v3': '3', + 'sub': {'v4': '4', 'v5': 5, 'sub_sub': {'v6': '6'}}, + }, + } + + +def 
test_cli_source_prioritization(env): + class CfgDefault(BaseSettings): + foo: str + + class CfgPrioritized(BaseSettings): + foo: str + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return env_settings, CliSettingsSource(settings_cls, cli_parse_args=['--foo', 'FOO FROM CLI']) + + env.set('FOO', 'FOO FROM ENV') + + cfg = CfgDefault(_cli_parse_args=['--foo', 'FOO FROM CLI']) + assert cfg.model_dump() == {'foo': 'FOO FROM CLI'} + + cfg = CfgPrioritized() + assert cfg.model_dump() == {'foo': 'FOO FROM ENV'} + + +def test_cli_alias_arg(): + class Animal(BaseModel): + name: str + + class Cfg(BaseSettings): + apple: str = Field(alias='alias') + pet: Animal = Field(alias='critter') + + cfg = Cfg(_cli_parse_args=['--alias', 'foo', '--critter.name', 'harry']) + assert cfg.model_dump() == {'apple': 'foo', 'pet': {'name': 'harry'}} + assert cfg.model_dump(by_alias=True) == {'alias': 'foo', 'critter': {'name': 'harry'}} + + +def test_cli_case_insensitve_arg(): + class Cfg(BaseSettings): + Foo: str + Bar: str + + cfg = Cfg(_cli_parse_args=['--FOO=--VAL', '--BAR', '"--VAL"']) + assert cfg.model_dump() == {'Foo': '--VAL', 'Bar': '"--VAL"'} + + cfg = Cfg(_cli_parse_args=['--Foo=--VAL', '--Bar', '"--VAL"'], _case_sensitive=True) + assert cfg.model_dump() == {'Foo': '--VAL', 'Bar': '"--VAL"'} + + with pytest.raises(SystemExit): + Cfg(_cli_parse_args=['--FOO=--VAL', '--BAR', '"--VAL"'], _case_sensitive=True) + + with pytest.raises(SettingsError) as exc_info: + CliSettingsSource(Cfg, root_parser=CliDummyParser(), case_sensitive=False) + assert str(exc_info.value) == 'Case-insensitive matching is only supported on the internal root parser' + + +def test_cli_help_differentiation(capsys, monkeypatch): + class 
Cfg(BaseSettings): + foo: str + bar: int = 123 + boo: int = Field(default_factory=lambda: 456) + + argparse_options_text = 'options' if sys.version_info >= (3, 10) else 'optional arguments' + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SystemExit): + Cfg(_cli_parse_args=True) + + assert ( + re.sub(r'0x\w+', '0xffffffff', capsys.readouterr().out, re.MULTILINE) + == f"""usage: example.py [-h] [--foo str] [--bar int] [--boo int] + +{argparse_options_text}: + -h, --help show this help message and exit + --foo str (required) + --bar int (default: 123) + --boo int (default: .Cfg. at + 0xffffffff>) +""" + ) + + +def test_cli_nested_dataclass_arg(): + @pydantic_dataclasses.dataclass + class MyDataclass: + foo: int + bar: str + + class Settings(BaseSettings): + n: MyDataclass + + s = Settings(_cli_parse_args=['--n.foo', '123', '--n.bar', 'bar value']) + assert isinstance(s.n, MyDataclass) + assert s.n.foo == 123 + assert s.n.bar == 'bar value' + + +@pytest.mark.parametrize('prefix', ['', 'child.']) +def test_cli_list_arg(prefix): + class Obj(BaseModel): + val: int + + class Child(BaseModel): + num_list: Optional[List[int]] = None + obj_list: Optional[List[Obj]] = None + str_list: Optional[List[str]] = None + union_list: Optional[List[Union[Obj, int]]] = None + + class Cfg(BaseSettings): + num_list: Optional[List[int]] = None + obj_list: Optional[List[Obj]] = None + union_list: Optional[List[Union[Obj, int]]] = None + str_list: Optional[List[str]] = None + child: Optional[Child] = None + + def check_answer(cfg, prefix, expected): + if prefix: + assert cfg.model_dump() == { + 'num_list': None, + 'obj_list': None, + 'union_list': None, + 'str_list': None, + 'child': expected, + } + else: + expected['child'] = None + assert cfg.model_dump() == expected + + args: List[str] = [] + args = [f'--{prefix}num_list', '[1,2]'] + args += [f'--{prefix}num_list', '3,4'] + args += [f'--{prefix}num_list', '5', 
f'--{prefix}num_list', '6'] + cfg = Cfg(_cli_parse_args=args) + expected = { + 'num_list': [1, 2, 3, 4, 5, 6], + 'obj_list': None, + 'union_list': None, + 'str_list': None, + } + check_answer(cfg, prefix, expected) + + args = [f'--{prefix}obj_list', '[{"val":1},{"val":2}]'] + args += [f'--{prefix}obj_list', '{"val":3},{"val":4}'] + args += [f'--{prefix}obj_list', '{"val":5}', f'--{prefix}obj_list', '{"val":6}'] + cfg = Cfg(_cli_parse_args=args) + expected = { + 'num_list': None, + 'obj_list': [{'val': 1}, {'val': 2}, {'val': 3}, {'val': 4}, {'val': 5}, {'val': 6}], + 'union_list': None, + 'str_list': None, + } + check_answer(cfg, prefix, expected) + + args = [f'--{prefix}union_list', '[{"val":1},2]', f'--{prefix}union_list', '[3,{"val":4}]'] + args += [f'--{prefix}union_list', '{"val":5},6', f'--{prefix}union_list', '7,{"val":8}'] + args += [f'--{prefix}union_list', '{"val":9}', f'--{prefix}union_list', '10'] + cfg = Cfg(_cli_parse_args=args) + expected = { + 'num_list': None, + 'obj_list': None, + 'union_list': [{'val': 1}, 2, 3, {'val': 4}, {'val': 5}, 6, 7, {'val': 8}, {'val': 9}, 10], + 'str_list': None, + } + check_answer(cfg, prefix, expected) + + args = [f'--{prefix}str_list', '["0,0","1,1"]'] + args += [f'--{prefix}str_list', '"2,2","3,3"'] + args += [f'--{prefix}str_list', '"4,4"', f'--{prefix}str_list', '"5,5"'] + cfg = Cfg(_cli_parse_args=args) + expected = { + 'num_list': None, + 'obj_list': None, + 'union_list': None, + 'str_list': ['0,0', '1,1', '2,2', '3,3', '4,4', '5,5'], + } + check_answer(cfg, prefix, expected) + + +def test_cli_list_json_value_parsing(): + class Cfg(BaseSettings): + json_list: List[Union[str, bool, None]] + + assert Cfg( + _cli_parse_args=[ + '--json_list', + 'true,"true"', + '--json_list', + 'false,"false"', + '--json_list', + 'null,"null"', + '--json_list', + 'hi,"bye"', + ] + ).model_dump() == {'json_list': [True, 'true', False, 'false', None, 'null', 'hi', 'bye']} + + assert Cfg(_cli_parse_args=['--json_list', 
'"","","",""']).model_dump() == {'json_list': ['', '', '', '']} + assert Cfg(_cli_parse_args=['--json_list', ',,,']).model_dump() == {'json_list': ['', '', '', '']} + + +@pytest.mark.parametrize('prefix', ['', 'child.']) +def test_cli_dict_arg(prefix): + class Child(BaseModel): + check_dict: Dict[str, str] + + class Cfg(BaseSettings): + check_dict: Optional[Dict[str, str]] = None + child: Optional[Child] = None + + args: List[str] = [] + args = [f'--{prefix}check_dict', '{"k1":"a","k2":"b"}'] + args += [f'--{prefix}check_dict', '{"k3":"c"},{"k4":"d"}'] + args += [f'--{prefix}check_dict', '{"k5":"e"}', f'--{prefix}check_dict', '{"k6":"f"}'] + args += [f'--{prefix}check_dict', '[k7=g,k8=h]'] + args += [f'--{prefix}check_dict', 'k9=i,k10=j'] + args += [f'--{prefix}check_dict', 'k11=k', f'--{prefix}check_dict', 'k12=l'] + args += [f'--{prefix}check_dict', '[{"k13":"m"},k14=n]', f'--{prefix}check_dict', '[k15=o,{"k16":"p"}]'] + args += [f'--{prefix}check_dict', '{"k17":"q"},k18=r', f'--{prefix}check_dict', 'k19=s,{"k20":"t"}'] + args += [f'--{prefix}check_dict', '{"k21":"u"},k22=v,{"k23":"w"}'] + args += [f'--{prefix}check_dict', 'k24=x,{"k25":"y"},k26=z'] + args += [f'--{prefix}check_dict', '[k27="x,y",k28="x,y"]'] + args += [f'--{prefix}check_dict', 'k29="x,y",k30="x,y"'] + args += [f'--{prefix}check_dict', 'k31="x,y"', f'--{prefix}check_dict', 'k32="x,y"'] + cfg = Cfg(_cli_parse_args=args) + expected: Dict[str, Any] = { + 'check_dict': { + 'k1': 'a', + 'k2': 'b', + 'k3': 'c', + 'k4': 'd', + 'k5': 'e', + 'k6': 'f', + 'k7': 'g', + 'k8': 'h', + 'k9': 'i', + 'k10': 'j', + 'k11': 'k', + 'k12': 'l', + 'k13': 'm', + 'k14': 'n', + 'k15': 'o', + 'k16': 'p', + 'k17': 'q', + 'k18': 'r', + 'k19': 's', + 'k20': 't', + 'k21': 'u', + 'k22': 'v', + 'k23': 'w', + 'k24': 'x', + 'k25': 'y', + 'k26': 'z', + 'k27': 'x,y', + 'k28': 'x,y', + 'k29': 'x,y', + 'k30': 'x,y', + 'k31': 'x,y', + 'k32': 'x,y', + } + } + if prefix: + expected = {'check_dict': None, 'child': expected} + else: + 
expected['child'] = None + assert cfg.model_dump() == expected + + with pytest.raises(SettingsError) as exc_info: + cfg = Cfg(_cli_parse_args=[f'--{prefix}check_dict', 'k9="i']) + assert str(exc_info.value) == f'Parsing error encountered for {prefix}check_dict: Mismatched quotes' + + with pytest.raises(SettingsError): + cfg = Cfg(_cli_parse_args=[f'--{prefix}check_dict', 'k9=i"']) + assert str(exc_info.value) == f'Parsing error encountered for {prefix}check_dict: Mismatched quotes' + + +def test_cli_union_dict_arg(): + class Cfg(BaseSettings): + union_str_dict: Union[str, Dict[str, Any]] + + with pytest.raises(ValidationError) as exc_info: + args = ['--union_str_dict', 'hello world', '--union_str_dict', 'hello world'] + cfg = Cfg(_cli_parse_args=args) + assert exc_info.value.errors(include_url=False) == [ + { + 'input': [ + 'hello world', + 'hello world', + ], + 'loc': ( + 'union_str_dict', + 'str', + ), + 'msg': 'Input should be a valid string', + 'type': 'string_type', + }, + { + 'input': [ + 'hello world', + 'hello world', + ], + 'loc': ( + 'union_str_dict', + 'dict[str,any]', + ), + 'msg': 'Input should be a valid dictionary', + 'type': 'dict_type', + }, + ] + + args = ['--union_str_dict', 'hello world'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_str_dict': 'hello world'} + + args = ['--union_str_dict', '{"hello": "world"}'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_str_dict': {'hello': 'world'}} + + args = ['--union_str_dict', 'hello=world'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_str_dict': {'hello': 'world'}} + + args = ['--union_str_dict', '"hello=world"'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_str_dict': 'hello=world'} + + class Cfg(BaseSettings): + union_list_dict: Union[List[str], Dict[str, Any]] + + with pytest.raises(ValidationError) as exc_info: + args = ['--union_list_dict', 'hello,world'] + cfg = Cfg(_cli_parse_args=args) + assert 
exc_info.value.errors(include_url=False) == [ + { + 'input': 'hello,world', + 'loc': ( + 'union_list_dict', + 'list[str]', + ), + 'msg': 'Input should be a valid list', + 'type': 'list_type', + }, + { + 'input': 'hello,world', + 'loc': ( + 'union_list_dict', + 'dict[str,any]', + ), + 'msg': 'Input should be a valid dictionary', + 'type': 'dict_type', + }, + ] + + args = ['--union_list_dict', 'hello,world', '--union_list_dict', 'hello,world'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_list_dict': ['hello', 'world', 'hello', 'world']} + + args = ['--union_list_dict', '[hello,world]'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_list_dict': ['hello', 'world']} + + args = ['--union_list_dict', '{"hello": "world"}'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_list_dict': {'hello': 'world'}} + + args = ['--union_list_dict', 'hello=world'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_list_dict': {'hello': 'world'}} + + with pytest.raises(ValidationError) as exc_info: + args = ['--union_list_dict', '"hello=world"'] + cfg = Cfg(_cli_parse_args=args) + assert exc_info.value.errors(include_url=False) == [ + { + 'input': 'hello=world', + 'loc': ( + 'union_list_dict', + 'list[str]', + ), + 'msg': 'Input should be a valid list', + 'type': 'list_type', + }, + { + 'input': 'hello=world', + 'loc': ( + 'union_list_dict', + 'dict[str,any]', + ), + 'msg': 'Input should be a valid dictionary', + 'type': 'dict_type', + }, + ] + + args = ['--union_list_dict', '["hello=world"]'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'union_list_dict': ['hello=world']} + + +def test_cli_nested_dict_arg(): + class Cfg(BaseSettings): + check_dict: Dict[str, Any] + + args = ['--check_dict', '{"k1":{"a": 1}},{"k2":{"b": 2}}'] + cfg = Cfg(_cli_parse_args=args) + assert cfg.model_dump() == {'check_dict': {'k1': {'a': 1}, 'k2': {'b': 2}}} + + with pytest.raises(SettingsError) as 
exc_info: + args = ['--check_dict', '{"k1":{"a": 1}},"k2":{"b": 2}}'] + cfg = Cfg(_cli_parse_args=args) + assert ( + str(exc_info.value) + == 'Parsing error encountered for check_dict: not enough values to unpack (expected 2, got 1)' + ) + + with pytest.raises(SettingsError) as exc_info: + args = ['--check_dict', '{"k1":{"a": 1}},{"k2":{"b": 2}'] + cfg = Cfg(_cli_parse_args=args) + assert str(exc_info.value) == 'Parsing error encountered for check_dict: Missing end delimiter "}"' + + +def test_cli_subcommand_with_positionals(): + class FooPlugin(BaseModel): + my_feature: bool = False + + class BarPlugin(BaseModel): + my_feature: bool = False + + class Plugins(BaseModel): + foo: CliSubCommand[FooPlugin] + bar: CliSubCommand[BarPlugin] + + class Clone(BaseModel): + repository: CliPositionalArg[str] + directory: CliPositionalArg[str] + local: bool = False + shared: bool = False + + class Init(BaseModel): + directory: CliPositionalArg[str] + quiet: bool = False + bare: bool = False + + class Git(BaseSettings): + clone: CliSubCommand[Clone] + init: CliSubCommand[Init] + plugins: CliSubCommand[Plugins] + + git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) + assert git.model_dump() == { + 'clone': None, + 'init': {'directory': 'dir/path', 'quiet': True, 'bare': False}, + 'plugins': None, + } + + git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true']) + assert git.model_dump() == { + 'clone': {'repository': 'repo', 'directory': '.', 'local': False, 'shared': True}, + 'init': None, + 'plugins': None, + } + + git = Git(_cli_parse_args=['plugins', 'bar']) + assert git.model_dump() == { + 'clone': None, + 'init': None, + 'plugins': {'foo': None, 'bar': {'my_feature': False}}, + } + + +def test_cli_union_similar_sub_models(): + class ChildA(BaseModel): + name: str = 'child a' + diff_a: str = 'child a difference' + + class ChildB(BaseModel): + name: str = 'child b' + diff_b: str = 'child b difference' + + class Cfg(BaseSettings): + child: 
Union[ChildA, ChildB] + + cfg = Cfg(_cli_parse_args=['--child.name', 'new name a', '--child.diff_a', 'new diff a']) + assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}} + + +def test_cli_enums(): + class Pet(IntEnum): + dog = 0 + cat = 1 + bird = 2 + + class Cfg(BaseSettings): + pet: Pet + + cfg = Cfg(_cli_parse_args=['--pet', 'cat']) + assert cfg.model_dump() == {'pet': Pet.cat} + + with pytest.raises(ValidationError) as exc_info: + Cfg(_cli_parse_args=['--pet', 'rock']) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'enum', + 'loc': ('pet',), + 'msg': 'Input should be 0, 1 or 2', + 'input': 'rock', + 'ctx': {'expected': '0, 1 or 2'}, + } + ] + + +def test_cli_literals(): + class Cfg(BaseSettings): + pet: Literal['dog', 'cat', 'bird'] + + cfg = Cfg(_cli_parse_args=['--pet', 'cat']) + assert cfg.model_dump() == {'pet': 'cat'} + + with pytest.raises(ValidationError) as exc_info: + Cfg(_cli_parse_args=['--pet', 'rock']) + assert exc_info.value.errors(include_url=False) == [ + { + 'ctx': {'expected': "'dog', 'cat' or 'bird'"}, + 'type': 'literal_error', + 'loc': ('pet',), + 'msg': "Input should be 'dog', 'cat' or 'bird'", + 'input': 'rock', + } + ] + + +def test_cli_annotation_exceptions(monkeypatch): + class SubCmdAlt(BaseModel): + pass + + class SubCmd(BaseModel): + pass + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SettingsError) as exc_info: + + class SubCommandNotOutermost(BaseSettings, cli_parse_args=True): + subcmd: Union[int, CliSubCommand[SubCmd]] + + SubCommandNotOutermost() + assert str(exc_info.value) == 'CliSubCommand is not outermost annotation for SubCommandNotOutermost.subcmd' + + with pytest.raises(SettingsError) as exc_info: + + class SubCommandHasDefault(BaseSettings, cli_parse_args=True): + subcmd: CliSubCommand[SubCmd] = SubCmd() + + SubCommandHasDefault() + assert str(exc_info.value) == 'subcommand argument 
SubCommandHasDefault.subcmd has a default value' + + with pytest.raises(SettingsError) as exc_info: + + class SubCommandMultipleTypes(BaseSettings, cli_parse_args=True): + subcmd: CliSubCommand[Union[SubCmd, SubCmdAlt]] + + SubCommandMultipleTypes() + assert str(exc_info.value) == 'subcommand argument SubCommandMultipleTypes.subcmd has multiple types' + + with pytest.raises(SettingsError) as exc_info: + + class SubCommandNotModel(BaseSettings, cli_parse_args=True): + subcmd: CliSubCommand[str] + + SubCommandNotModel() + assert str(exc_info.value) == 'subcommand argument SubCommandNotModel.subcmd is not derived from BaseModel' + + with pytest.raises(SettingsError) as exc_info: + + class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True): + pos_arg: Union[int, CliPositionalArg[str]] + + PositionalArgNotOutermost() + assert ( + str(exc_info.value) == 'CliPositionalArg is not outermost annotation for PositionalArgNotOutermost.pos_arg' + ) + + with pytest.raises(SettingsError) as exc_info: + + class PositionalArgHasDefault(BaseSettings, cli_parse_args=True): + pos_arg: CliPositionalArg[str] = 'bad' + + PositionalArgHasDefault() + assert str(exc_info.value) == 'positional argument PositionalArgHasDefault.pos_arg has a default value' + + with pytest.raises(SettingsError) as exc_info: + + class InvalidCliParseArgsType(BaseSettings, cli_parse_args='invalid type'): + val: int + + InvalidCliParseArgsType() + assert str(exc_info.value) == "cli_parse_args must be List[str] or Tuple[str, ...], recieved " + + +def test_cli_avoid_json(capsys, monkeypatch): + class SubModel(BaseModel): + v1: int + + class Settings(BaseSettings): + sub_model: SubModel + + model_config = SettingsConfigDict(cli_parse_args=True) + + argparse_options_text = 'options' if sys.version_info >= (3, 10) else 'optional arguments' + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SystemExit): + Settings(_cli_avoid_json=False) + + assert ( + 
capsys.readouterr().out + == f"""usage: example.py [-h] [--sub_model JSON] [--sub_model.v1 int] + +{argparse_options_text}: + -h, --help show this help message and exit + +sub_model options: + --sub_model JSON set sub_model from JSON string + --sub_model.v1 int (required) +""" + ) + + with pytest.raises(SystemExit): + Settings(_cli_avoid_json=True) + + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] [--sub_model.v1 int] + +{argparse_options_text}: + -h, --help show this help message and exit + +sub_model options: + --sub_model.v1 int (required) +""" + ) + + +def test_cli_remove_empty_groups(capsys, monkeypatch): + class SubModel(BaseModel): + pass + + class Settings(BaseSettings): + sub_model: SubModel + + model_config = SettingsConfigDict(cli_parse_args=True) + + argparse_options_text = 'options' if sys.version_info >= (3, 10) else 'optional arguments' + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SystemExit): + Settings(_cli_avoid_json=False) + + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] [--sub_model JSON] + +{argparse_options_text}: + -h, --help show this help message and exit + +sub_model options: + --sub_model JSON set sub_model from JSON string +""" + ) + + with pytest.raises(SystemExit): + Settings(_cli_avoid_json=True) + + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] + +{argparse_options_text}: + -h, --help show this help message and exit +""" + ) + + +def test_cli_hide_none_type(capsys, monkeypatch): + class Settings(BaseSettings): + v0: Optional[str] + + model_config = SettingsConfigDict(cli_parse_args=True) + + argparse_options_text = 'options' if sys.version_info >= (3, 10) else 'optional arguments' + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SystemExit): + Settings(_cli_hide_none_type=False) + + assert ( + capsys.readouterr().out + == f"""usage: example.py 
[-h] [--v0 {{str,null}}] + +{argparse_options_text}: + -h, --help show this help message and exit + --v0 {{str,null}} (required) +""" + ) + + with pytest.raises(SystemExit): + Settings(_cli_hide_none_type=True) + + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] [--v0 str] + +{argparse_options_text}: + -h, --help show this help message and exit + --v0 str (required) +""" + ) + + +def test_cli_use_class_docs_for_groups(capsys, monkeypatch): + class SubModel(BaseModel): + """The help text from the class docstring""" + + v1: int + + class Settings(BaseSettings): + """My application help text.""" + + sub_model: SubModel = Field(description='The help text from the field description') + + model_config = SettingsConfigDict(cli_parse_args=True) + + argparse_options_text = 'options' if sys.version_info >= (3, 10) else 'optional arguments' + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SystemExit): + Settings(_cli_use_class_docs_for_groups=False) + + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] [--sub_model JSON] [--sub_model.v1 int] + +My application help text. + +{argparse_options_text}: + -h, --help show this help message and exit + +sub_model options: + The help text from the field description + + --sub_model JSON set sub_model from JSON string + --sub_model.v1 int (required) +""" + ) + + with pytest.raises(SystemExit): + Settings(_cli_use_class_docs_for_groups=True) + + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] [--sub_model JSON] [--sub_model.v1 int] + +My application help text. 
+ +{argparse_options_text}: + -h, --help show this help message and exit + +sub_model options: + The help text from the class docstring + + --sub_model JSON set sub_model from JSON string + --sub_model.v1 int (required) +""" + ) + + +def test_cli_enforce_required(env): + class Settings(BaseSettings): + my_required_field: str + + env.set('MY_REQUIRED_FIELD', 'hello from environment') + + assert Settings(_cli_parse_args=[], _cli_enforce_required=False).model_dump() == { + 'my_required_field': 'hello from environment' + } + + with pytest.raises(SystemExit): + Settings(_cli_parse_args=[], _cli_enforce_required=True).model_dump() + + +@pytest.mark.parametrize('parser_type', [pytest.Parser, argparse.ArgumentParser, CliDummyParser]) +@pytest.mark.parametrize('prefix', ['', 'cfg']) +def test_cli_user_settings_source(parser_type, prefix): + class Cfg(BaseSettings): + pet: Literal['dog', 'cat', 'bird'] = 'bird' + + if parser_type is pytest.Parser: + parser = pytest.Parser(_ispytest=True) + parse_args = parser.parse + add_arg = parser.addoption + cli_cfg_settings = CliSettingsSource( + Cfg, + cli_prefix=prefix, + root_parser=parser, + parse_args_method=pytest.Parser.parse, + add_argument_method=pytest.Parser.addoption, + add_argument_group_method=pytest.Parser.getgroup, + add_parser_method=None, + add_subparsers_method=None, + formatter_class=None, + ) + elif parser_type is CliDummyParser: + parser = CliDummyParser() + parse_args = parser.parse_args + add_arg = parser.add_argument + cli_cfg_settings = CliSettingsSource( + Cfg, + cli_prefix=prefix, + root_parser=parser, + parse_args_method=CliDummyParser.parse_args, + add_argument_method=CliDummyParser.add_argument, + add_argument_group_method=CliDummyParser.add_argument_group, + add_parser_method=CliDummySubParsers.add_parser, + add_subparsers_method=CliDummyParser.add_subparsers, + ) + else: + parser = argparse.ArgumentParser() + parse_args = parser.parse_args + add_arg = parser.add_argument + cli_cfg_settings = 
CliSettingsSource(Cfg, cli_prefix=prefix, root_parser=parser) + + add_arg('--fruit', choices=['pear', 'kiwi', 'lime']) + + args = ['--fruit', 'pear'] + parsed_args = parse_args(args) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {'pet': 'bird'} + assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {'pet': 'bird'} + assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'} + + arg_prefix = f'{prefix}.' if prefix else '' + args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog'] + parsed_args = parse_args(args) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {'pet': 'dog'} + assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {'pet': 'dog'} + assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'} + + parsed_args = parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat']) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == {'pet': 'cat'} + assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'} + + +@pytest.mark.parametrize('prefix', ['', 'cfg']) +def test_cli_dummy_user_settings_with_subcommand(prefix): + class DogCommands(BaseModel): + name: str = 'Bob' + command: Literal['roll', 'bark', 'sit'] = 'sit' + + class Cfg(BaseSettings): + pet: Literal['dog', 'cat', 'bird'] = 'bird' + command: CliSubCommand[DogCommands] + + parser = CliDummyParser() + cli_cfg_settings = CliSettingsSource( + Cfg, + root_parser=parser, + cli_prefix=prefix, + parse_args_method=CliDummyParser.parse_args, + add_argument_method=CliDummyParser.add_argument, + add_argument_group_method=CliDummyParser.add_argument_group, + add_parser_method=CliDummySubParsers.add_parser, + add_subparsers_method=CliDummyParser.add_subparsers, + ) + + parser.add_argument('--fruit', choices=['pear', 'kiwi', 'lime']) + + 
args = ['--fruit', 'pear'] + parsed_args = parser.parse_args(args) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == { + 'pet': 'bird', + 'command': None, + } + assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + 'pet': 'bird', + 'command': None, + } + + arg_prefix = f'{prefix}.' if prefix else '' + args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog'] + parsed_args = parser.parse_args(args) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == { + 'pet': 'dog', + 'command': None, + } + assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + 'pet': 'dog', + 'command': None, + } + + parsed_args = parser.parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat']) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == { + 'pet': 'cat', + 'command': None, + } + + args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog', 'command', '--name', 'ralph', '--command', 'roll'] + parsed_args = parser.parse_args(args) + assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == { + 'pet': 'dog', + 'command': {'name': 'ralph', 'command': 'roll'}, + } + assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + 'pet': 'dog', + 'command': {'name': 'ralph', 'command': 'roll'}, + } + + +def test_cli_user_settings_source_exceptions(): + class Cfg(BaseSettings): + pet: Literal['dog', 'cat', 'bird'] = 'bird' + + with pytest.raises(SettingsError) as exc_info: + args = ['--pet', 'dog'] + parsed_args = {'pet': 'dog'} + cli_cfg_settings = CliSettingsSource(Cfg) + Cfg(_cli_settings_source=cli_cfg_settings(args=args, parsed_args=parsed_args)) + assert str(exc_info.value) == '`args` and `parsed_args` are mutually exclusive' + + with pytest.raises(SettingsError) as exc_info: + CliSettingsSource(Cfg, cli_prefix='.cfg') + assert str(exc_info.value) == 'CLI 
settings source prefix is invalid: .cfg' + + with pytest.raises(SettingsError) as exc_info: + CliSettingsSource(Cfg, cli_prefix='cfg.') + assert str(exc_info.value) == 'CLI settings source prefix is invalid: cfg.' + + with pytest.raises(SettingsError) as exc_info: + CliSettingsSource(Cfg, cli_prefix='123') + assert str(exc_info.value) == 'CLI settings source prefix is invalid: 123' + + class Food(BaseModel): + fruit: FruitsEnum = FruitsEnum.kiwi + + class CfgWithSubCommand(BaseSettings): + pet: Literal['dog', 'cat', 'bird'] = 'bird' + food: CliSubCommand[Food] + + with pytest.raises(SettingsError) as exc_info: + CliSettingsSource(CfgWithSubCommand, add_subparsers_method=None) + assert ( + str(exc_info.value) + == 'cannot connect CLI settings source root parser: add_subparsers_method is set to `None` but is needed for connecting' + ) + + +@pytest.mark.parametrize( + 'value,expected', + [ + (str, 'str'), + ('foobar', 'str'), + ('SomeForwardRefString', 'str'), # included to document current behavior; could be changed + (List['SomeForwardRef'], "List[ForwardRef('SomeForwardRef')]"), # noqa: F821 + (Union[str, int], '{str,int}'), + (list, 'list'), + (List, 'List'), + ([1, 2, 3], 'list'), + (List[Dict[str, int]], 'List[Dict[str,int]]'), + (Tuple[str, int, float], 'Tuple[str,int,float]'), + (Tuple[str, ...], 'Tuple[str,...]'), + (Union[int, List[str], Tuple[str, int]], '{int,List[str],Tuple[str,int]}'), + (foobar, 'foobar'), + (LoggedVar, 'LoggedVar'), + (LoggedVar(), 'LoggedVar'), + (Representation(), 'Representation()'), + (typing.Literal[1, 2, 3], '{1,2,3}'), + (typing_extensions.Literal[1, 2, 3], '{1,2,3}'), + (typing.Literal['a', 'b', 'c'], '{a,b,c}'), + (typing_extensions.Literal['a', 'b', 'c'], '{a,b,c}'), + (SimpleSettings, 'JSON'), + (Union[SimpleSettings, SettingWithIgnoreEmpty], 'JSON'), + (Union[SimpleSettings, str, SettingWithIgnoreEmpty], '{JSON,str}'), + (Union[str, SimpleSettings, SettingWithIgnoreEmpty], '{str,JSON}'), + (Annotated[SimpleSettings, 
'annotation'], 'JSON'), + (DirectoryPath, 'Path'), + (FruitsEnum, '{pear,kiwi,lime}'), + ], +) +@pytest.mark.parametrize('hide_none_type', [True, False]) +def test_cli_metavar_format(hide_none_type, value, expected): + cli_settings = CliSettingsSource(SimpleSettings, cli_hide_none_type=hide_none_type) + if hide_none_type: + if value == [1, 2, 3] or isinstance(value, LoggedVar) or isinstance(value, Representation): + pytest.skip() + if value in ('foobar', 'SomeForwardRefString'): + expected = f"ForwardRef('{value}')" # forward ref implicit cast + if typing_extensions.get_origin(value) is Union: + args = typing_extensions.get_args(value) + value = Union[args + (None,) if args else (value, None)] + else: + value = Union[(value, None)] + assert cli_settings._metavar_format(value) == expected + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python 3.10 or higher') +@pytest.mark.parametrize( + 'value_gen,expected', + [ + (lambda: str | int, '{str,int}'), + (lambda: list[int], 'list[int]'), + (lambda: List[int], 'List[int]'), + (lambda: list[dict[str, int]], 'list[dict[str,int]]'), + (lambda: list[Union[str, int]], 'list[{str,int}]'), + (lambda: list[str | int], 'list[{str,int}]'), + (lambda: LoggedVar[int], 'LoggedVar[int]'), + (lambda: LoggedVar[Dict[int, str]], 'LoggedVar[Dict[int,str]]'), + ], +) +@pytest.mark.parametrize('hide_none_type', [True, False]) +def test_cli_metavar_format_310(hide_none_type, value_gen, expected): + value = value_gen() + cli_settings = CliSettingsSource(SimpleSettings, cli_hide_none_type=hide_none_type) + if hide_none_type: + if typing_extensions.get_origin(value) is Union: + args = typing_extensions.get_args(value) + value = Union[args + (None,) if args else (value, None)] + else: + value = Union[(value, None)] + assert cli_settings._metavar_format(value) == expected + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason='requires python 3.12 or higher') +def test_cli_metavar_format_type_alias_312(): + exec( + """ 
+type TypeAliasInt = int +assert CliSettingsSource(SimpleSettings)._metavar_format(TypeAliasInt) == 'TypeAliasInt' +""" + ) + + def test_json_file(tmp_path): p = tmp_path / '.env' p.write_text(