diff --git a/changes/2834.feature.md b/changes/2834.feature.md new file mode 100644 index 00000000000..1b864159b70 --- /dev/null +++ b/changes/2834.feature.md @@ -0,0 +1 @@ +Configure the logging module's config file to use Pydantic diff --git a/configs/wsproxy/halfstack.toml b/configs/wsproxy/halfstack.toml index ce552ba3bb3..b0360933e76 100644 --- a/configs/wsproxy/halfstack.toml +++ b/configs/wsproxy/halfstack.toml @@ -1,15 +1,15 @@ [wsproxy] -bind_host = "0.0.0.0" -advertised_host = "127.0.0.1" +bind-host = "0.0.0.0" +advertised-host = "127.0.0.1" -bind_api_port = 5050 -advertised_api_port = 5050 +bind-api-port = 5050 +advertised-api-port = 5050 # replace these values with your passphrase -jwt_encrypt_key = "QRX/ZX2nwKKpuTSD3ZycPA" -permit_hash_key = "gHNAohmntRV0t9zlwTVQeQ" +jwt-encrypt-key = "QRX/ZX2nwKKpuTSD3ZycPA" +permit-hash-key = "gHNAohmntRV0t9zlwTVQeQ" -api_secret = "v625xZLOgbMHhl0s49VuqQ" +api-secret = "v625xZLOgbMHhl0s49VuqQ" [pyroscope] enabled = true diff --git a/configs/wsproxy/sample.toml b/configs/wsproxy/sample.toml index 14ecab137d3..e8df4123316 100644 --- a/configs/wsproxy/sample.toml +++ b/configs/wsproxy/sample.toml @@ -1,31 +1,31 @@ [wsproxy] -ipc_base_path = "/tmp/backend.ai/ipc" -event_loop = "asyncio" -pid_file = "/run/backend.ai/wsproxy/wsproxy.pid" +ipc-base-path = "/tmp/backend.ai/ipc" +event-loop = "asyncio" +pid-file = "/run/backend.ai/wsproxy/wsproxy.pid" id = "i-node01" user = 501 group = 501 -bind_host = "0.0.0.0" -advertised_host = "example.com" -bind_api_port = 5050 -advertised_api_port = 15050 -bind_proxy_port_range = [ +bind-host = "0.0.0.0" +advertised-host = "example.com" +bind-api-port = 5050 +advertised-api-port = 15050 +bind-proxy-port-range = [ 10200, 10300, ] -advertised_proxy_port_range = [ +advertised-proxy-port-range = [ 20200, 20300, ] protocol = "http" # replace these values with your passphrase -jwt_encrypt_key = "50M3G00DL00KING53CR3T" -permit_hash_key = "50M3G00DL00KING53CR3T" -api_secret = 
"50M3G00DL00KING53CR3T" +jwt-encrypt-key = "50M3G00DL00KING53CR3T" +permit-hash-key = "50M3G00DL00KING53CR3T" +api-secret = "50M3G00DL00KING53CR3T" -aiomonitor_termui_port = 48500 -aiomonitor_webui_port = 49500 +aiomonitor-termui-port = 48500 +aiomonitor-webui-port = 49500 [pyroscope] enabled = true @@ -39,7 +39,7 @@ drivers = [ "console", ] -[logging.pkg_ns] +[logging.pkg-ns] "" = "WARNING" "ai.backend" = "DEBUG" tests = "DEBUG" @@ -52,14 +52,14 @@ format = "verbose" [logging.file] path = "/var/log/backend.ai" filename = "wsproxy.log" -backup_count = 5 -rotation_size = "10M" +backup-count = 5 +rotation-size = "10M" format = "verbose" [logging.logstash] protocol = "tcp" -ssl_enabled = true -ssl_verify = true +ssl-enabled = true +ssl-verify = true [logging.logstash.endpoint] host = "127.0.0.1" @@ -69,14 +69,14 @@ port = 8001 host = "127.0.0.1" port = 8000 level = "INFO" -ssl_verify = true -ca_certs = "/etc/ssl/ca.pem" +ssl-verify = true +ca-certs = "/etc/ssl/ca.pem" keyfile = "/etc/backend.ai/graylog/privkey.pem" certfile = "/etc/backend.ai/graylog/cert.pem" [debug] enabled = false asyncio = false -enhanced_aiomonitor_task_info = false -log_events = false +enhanced-aiomonitor-task-info = false +log-events = false diff --git a/src/ai/backend/common/config.py b/src/ai/backend/common/config.py index 5d3537ccd57..76ae2edb846 100644 --- a/src/ai/backend/common/config.py +++ b/src/ai/backend/common/config.py @@ -36,8 +36,16 @@ class BaseSchema(BaseModel): + @staticmethod + def snake_to_kebab_case(string: str) -> str: + return string.replace("_", "-") + model_config = ConfigDict( - populate_by_name=True, from_attributes=True, use_enum_values=True, extra="allow" + populate_by_name=True, + from_attributes=True, + use_enum_values=True, + extra="allow", + alias_generator=snake_to_kebab_case, ) diff --git a/src/ai/backend/common/metrics/profiler.py b/src/ai/backend/common/metrics/profiler.py index e973d9f4fa3..236d26fdeb6 100644 --- a/src/ai/backend/common/metrics/profiler.py 
+++ b/src/ai/backend/common/metrics/profiler.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional import pyroscope @@ -6,9 +7,9 @@ @dataclass class PyroscopeArgs: enabled: bool - app_name: str - server_address: str - sample_rate: int + app_name: Optional[str] + server_address: Optional[str] + sample_rate: Optional[int] class Profiler: diff --git a/src/ai/backend/logging/BUILD b/src/ai/backend/logging/BUILD index 21d1acf5461..fae884f1d70 100644 --- a/src/ai/backend/logging/BUILD +++ b/src/ai/backend/logging/BUILD @@ -9,7 +9,9 @@ visibility_private_component( allowed_dependents=[ "//src/ai/backend/**", ], - allowed_dependencies=[], + allowed_dependencies=[ + "//src/ai/backend/common/**", + ], ) python_distribution( diff --git a/src/ai/backend/logging/config_pydantic.py b/src/ai/backend/logging/config_pydantic.py new file mode 100644 index 00000000000..9c2a161b968 --- /dev/null +++ b/src/ai/backend/logging/config_pydantic.py @@ -0,0 +1,126 @@ +import enum +from pathlib import Path +from typing import Optional + +from pydantic import ByteSize, Field + +from ai.backend.common.config import BaseSchema + +from .types import LogFormat, LogLevel + +default_pkg_ns = { + "": LogLevel.WARNING, + "ai.backend": LogLevel.DEBUG, + "tests": LogLevel.DEBUG, + "aiohttp": LogLevel.INFO, +} + + +class LogDriver(str, enum.Enum): + CONSOLE = "console" + LOGSTASH = "logstash" + FILE = "file" + GRAYLOG = "graylog" + + +class LogstashProtocol(str, enum.Enum): + ZMQ_PUSH = "zmq.push" + ZMQ_PUB = "zmq.pub" + TCP = "tcp" + UDP = "udp" + + +class HostPortPair(BaseSchema): + host: str = Field(examples=["127.0.0.1"]) + port: int = Field(gt=0, lt=65536, examples=[8201]) + + def __repr__(self) -> str: + return f"{self.host}:{self.port}" + + def __str__(self) -> str: + return self.__repr__() + + def __getitem__(self, *args) -> int | str: + if args[0] == 0: + return self.host + elif args[0] == 1: + return self.port + else: + raise KeyError(*args) + + +class 
ConsoleConfig(BaseSchema): + colored: Optional[bool] = Field( + default=None, description="Opt to print colorized log.", examples=[True] + ) + format: LogFormat = Field(default=LogFormat.VERBOSE, description="Determine verbosity of log.") + + +class FileConfig(BaseSchema): + path: Path = Field(description="Path to store log.", examples=["/var/log/backend.ai"]) + filename: str = Field(description="Log file name.", examples=["wsproxy.log"]) + backup_count: int = Field(description="Number of outdated log files to retain.", default=5) + rotation_size: ByteSize = Field( + description="Maximum size for a single log file.", + default_factory=lambda: ByteSize(10_000_000), + ) + format: LogFormat = Field(default=LogFormat.VERBOSE, description="Determine verbosity of log.") + + +class LogstashConfig(BaseSchema): + endpoint: HostPortPair = Field( + description="Connection information of logstash node.", + examples=[HostPortPair(host="127.0.0.1", port=8001)], + ) + protocol: LogstashProtocol = Field( + description="Protocol to communicate with logstash server.", + default=LogstashProtocol.TCP, + ) + ssl_enabled: bool = Field( + description="Use TLS to communicate with logstash server.", + default=True, + ) + ssl_verify: bool = Field( + description="Verify validity of TLS certificate when communicating with logstash.", + default=True, + ) + + +class GraylogConfig(BaseSchema): + host: str = Field(description="Graylog hostname.", examples=["127.0.0.1"]) + port: int = Field(description="Graylog server port number.", examples=[8000]) + level: LogLevel = Field(description="Log level.", default=LogLevel.INFO) + ssl_verify: bool = Field( + description="Verify validity of TLS certificate when communicating with graylog.", + default=True, + ) + ca_certs: Optional[str] = Field( + description="Path to Root CA certificate file.", + examples=["/etc/ssl/ca.pem"], + default=None, + ) + keyfile: Optional[str] = Field( + description="Path to TLS private key file.", +
examples=["/etc/backend.ai/graylog/privkey.pem"], + default=None, + ) + certfile: Optional[str] = Field( + description="Path to TLS certificate file.", + examples=["/etc/backend.ai/graylog/cert.pem"], + default=None, + ) + + +class LoggingConfig(BaseSchema): + level: LogLevel = Field(default=LogLevel.INFO, description="Log level.") + drivers: list[LogDriver] = Field( + default=[LogDriver.CONSOLE], description="Array of log drivers to print." + ) + console: ConsoleConfig = Field(default=ConsoleConfig(colored=None, format=LogFormat.VERBOSE)) + file: Optional[FileConfig] = Field(default=None) + logstash: Optional[LogstashConfig] = Field(default=None) + graylog: Optional[GraylogConfig] = Field(default=None) + pkg_ns: dict[str, LogLevel] = Field( + description="Override default log level for specific scope of package", + default=default_pkg_ns, + ) diff --git a/src/ai/backend/wsproxy/config.py b/src/ai/backend/wsproxy/config.py index 993258449f3..1b21ae51589 100644 --- a/src/ai/backend/wsproxy/config.py +++ b/src/ai/backend/wsproxy/config.py @@ -1,4 +1,3 @@ -import enum import os import pwd import socket @@ -8,13 +7,10 @@ from dataclasses import dataclass from pathlib import Path from pprint import pformat -from typing import Annotated, Any +from typing import Annotated, Any, Optional import click from pydantic import ( - BaseModel, - ByteSize, - ConfigDict, Field, GetCoreSchemaHandler, GetJsonSchemaHandler, @@ -24,48 +20,23 @@ from pydantic_core import PydanticUndefined, core_schema from ai.backend.common import config -from ai.backend.logging import LogFormat, LogLevel +from ai.backend.common.config import BaseSchema +from ai.backend.logging import LogLevel +from ai.backend.logging.config_pydantic import LoggingConfig from .types import EventLoopType, ProxyProtocol _file_perm = (Path(__file__).parent / "server.py").stat() -class BaseSchema(BaseModel): - model_config = ConfigDict( - populate_by_name=True, - from_attributes=True, - use_enum_values=True, - ) - - -class 
HostPortPair(BaseSchema): - host: Annotated[str, Field(examples=["127.0.0.1"])] - port: Annotated[int, Field(gt=0, lt=65536, examples=[8201])] - - def __repr__(self) -> str: - return f"{self.host}:{self.port}" - - def __str__(self) -> str: - return self.__repr__() - - def __getitem__(self, *args) -> int | str: - if args[0] == 0: - return self.host - elif args[0] == 1: - return self.port - else: - raise KeyError(*args) - - @dataclass class UserID: - default_uid: int | None = None + default_uid: Optional[int] = None @classmethod def uid_validator( cls, - value: int | str | None, + value: Optional[int | str], ) -> int: if value is None: assert cls.default_uid, "value is None but default_uid not provided" @@ -201,170 +172,31 @@ def __get_pydantic_json_schema__( ) -class LogDriver(str, enum.Enum): - CONSOLE = "console" - LOGSTASH = "logstash" - FILE = "file" - GRAYLOG = "graylog" - - -class LogstashProtocol(str, enum.Enum): - ZMQ_PUSH = "zmq.push" - ZMQ_PUB = "zmq.pub" - TCP = "tcp" - UDP = "udp" - - -default_pkg_ns = {"": "WARNING", "ai.backend": "DEBUG", "tests": "DEBUG", "aiohttp": "INFO"} - - -class ConsoleLogConfig(BaseSchema): - colored: Annotated[ - bool | None, Field(default=None, description="Opt to print colorized log.", examples=[True]) - ] - format: Annotated[ - LogFormat, Field(default=LogFormat.VERBOSE, description="Determine verbosity of log.") - ] - - -class FileLogConfig(BaseSchema): - path: Annotated[Path, Field(description="Path to store log.", examples=["/var/log/backend.ai"])] - filename: Annotated[str, Field(description="Log file name.", examples=["wsproxy.log"])] - backup_count: Annotated[ - int, Field(description="Number of outdated log files to retain.", default=5) - ] - rotation_size: Annotated[ - ByteSize, Field(description="Maximum size for a single log file.", default=10 * 1024 * 1024) - ] - format: Annotated[ - LogFormat, Field(default=LogFormat.VERBOSE, description="Determine verbosity of log.") - ] - - -class LogstashConfig(BaseSchema): - 
endpoint: Annotated[ - HostPortPair, - Field( - description="Connection information of logstash node.", - examples=[HostPortPair(host="127.0.0.1", port=8001)], - ), - ] - protocol: Annotated[ - LogstashProtocol, - Field( - description="Protocol to communicate with logstash server.", - default=LogstashProtocol.TCP, - ), - ] - ssl_enabled: Annotated[ - bool, Field(description="Use TLS to communicate with logstash server.", default=True) - ] - ssl_verify: Annotated[ - bool, - Field( - description="Verify validity of TLS certificate when communicating with logstash.", - default=True, - ), - ] - - -class GraylogConfig(BaseSchema): - host: Annotated[str, Field(description="Graylog hostname.", examples=["127.0.0.1"])] - port: Annotated[int, Field(description="Graylog server port number.", examples=[8000])] - level: Annotated[LogLevel, Field(description="Log level.", default=LogLevel.INFO)] - ssl_verify: Annotated[ - bool, - Field( - description="Verify validity of TLS certificate when communicating with logstash.", - default=True, - ), - ] - ca_certs: Annotated[ - str | None, - Field( - description="Path to Root CA certificate file.", - examples=["/etc/ssl/ca.pem"], - default=None, - ), - ] - keyfile: Annotated[ - str | None, - Field( - description="Path to TLS private key file.", - examples=["/etc/backend.ai/graylog/privkey.pem"], - default=None, - ), - ] - certfile: Annotated[ - str | None, - Field( - description="Path to TLS certificate file.", - examples=["/etc/backend.ai/graylog/cert.pem"], - default=None, - ), - ] - - -class PyroscopeConfig(BaseSchema): - enabled: Annotated[bool, Field(default=False, description="Enable pyroscope profiler.")] - app_name: Annotated[str, Field(default=None, description="Pyroscope app name.")] - server_addr: Annotated[str, Field(default=None, description="Pyroscope server address.")] - sample_rate: Annotated[int, Field(default=None, description="Pyroscope sample rate.")] - - -class LoggingConfig(BaseSchema): - level: 
Annotated[LogLevel, Field(default=LogLevel.INFO, description="Log level.")] - pkg_ns: Annotated[ - dict[str, LogLevel], - Field( - description="Override default log level for specific scope of package", - default=default_pkg_ns, - ), - ] - drivers: Annotated[ - list[LogDriver], - Field(default=[LogDriver.CONSOLE], description="Array of log drivers to print."), - ] - console: Annotated[ - ConsoleLogConfig, Field(default=ConsoleLogConfig(colored=None, format=LogFormat.VERBOSE)) - ] - file: Annotated[FileLogConfig | None, Field(default=None)] - logstash: Annotated[LogstashConfig | None, Field(default=None)] - graylog: Annotated[GraylogConfig | None, Field(default=None)] - - class DebugConfig(BaseSchema): - enabled: Annotated[bool, Field(default=False)] - asyncio: Annotated[bool, Field(default=False)] - enhanced_aiomonitor_task_info: Annotated[bool, Field(default=False)] - log_events: Annotated[bool, Field(default=False)] + enabled: bool = Field(default=False) + asyncio: bool = Field(default=False) + enhanced_aiomonitor_task_info: bool = Field(default=False) + log_events: bool = Field(default=False) class WSProxyConfig(BaseSchema): - ipc_base_path: Annotated[ - Path, - Field( - default=Path("/tmp/backend.ai/ipc"), - description="Directory to store temporary UNIX sockets.", - ), - ] - event_loop: Annotated[ - EventLoopType, - Field(default=EventLoopType.ASYNCIO, description="Type of event loop to use."), - ] - pid_file: Annotated[ - Path, - Field( - default=Path(os.devnull), - description="Place to store process PID.", - examples=["/run/backend.ai/wsproxy/wsproxy.pid"], - ), - ] + ipc_base_path: Path = Field( + default=Path("/tmp/backend.ai/ipc"), + description="Directory to store temporary UNIX sockets.", + ) + event_loop: EventLoopType = Field( + default=EventLoopType.ASYNCIO, + description="Type of event loop to use.", + ) + pid_file: Path = Field( + default=Path(os.devnull), + description="Place to store process PID.", + 
examples=["/run/backend.ai/wsproxy/wsproxy.pid"], + ) - id: Annotated[ - str, - Field(default=f"i-{socket.gethostname()}", examples=["i-node01"], description="Node id."), - ] + id: str = Field( + default=f"i-{socket.gethostname()}", examples=["i-node01"], description="Node id." + ) user: Annotated[ int, UserID(default_uid=_file_perm.st_uid), @@ -376,73 +208,80 @@ class WSProxyConfig(BaseSchema): Field(default=_file_perm.st_uid, description="Process group."), ] - bind_host: Annotated[ - str, - Field( - default="0.0.0.0", - description="Bind address of the port opened on behalf of wsproxy worker", - ), - ] - advertised_host: Annotated[ - str, Field(examples=["example.com"], description="Hostname to be advertised to client") - ] + bind_host: str = Field( + default="0.0.0.0", + description="Bind address of the port opened on behalf of wsproxy worker", + ) + advertised_host: str = Field( + examples=["example.com"], + description="Hostname to be advertised to client", + ) - bind_api_port: Annotated[ - int, Field(default=5050, description="Port number to bind for API server") - ] - advertised_api_port: Annotated[ - int | None, - Field(default=None, examples=[15050], description="API port number reachable from client"), - ] + bind_api_port: int = Field( + default=5050, + description="Port number to bind for API server", + ) + advertised_api_port: Optional[int] = Field( + default=None, + examples=[15050], + description="API port number reachable from client", + ) - bind_proxy_port_range: Annotated[ - tuple[int, int], - Field(default=[10200, 10300], description="Port number to bind for actual traffic"), - ] - advertised_proxy_port_range: Annotated[ - tuple[int, int] | None, - Field( - default=None, - examples=[[20200, 20300]], - description="Traffic port range reachable from client", - ), - ] + bind_proxy_port_range: tuple[int, int] = Field( + default=(10200, 10300), + description="Port number to bind for actual traffic", + ) + advertised_proxy_port_range: 
Optional[tuple[int, int]] = Field( + default=None, + examples=[[20200, 20300]], + description="Traffic port range reachable from client", + ) - protocol: Annotated[ - ProxyProtocol, Field(default=ProxyProtocol.HTTP, description="Proxy protocol") - ] + protocol: ProxyProtocol = Field(default=ProxyProtocol.HTTP, description="Proxy protocol") - jwt_encrypt_key: Annotated[ - str, Field(examples=["50M3G00DL00KING53CR3T"], description="JWT encryption key") - ] - permit_hash_key: Annotated[ - str, Field(examples=["50M3G00DL00KING53CR3T"], description="Permit hash key") - ] + jwt_encrypt_key: str = Field( + examples=["50M3G00DL00KING53CR3T"], + description="JWT encryption key", + ) + permit_hash_key: str = Field( + examples=["50M3G00DL00KING53CR3T"], + description="Permit hash key", + ) - api_secret: Annotated[str, Field(examples=["50M3G00DL00KING53CR3T"], description="API secret")] + api_secret: str = Field( + examples=["50M3G00DL00KING53CR3T"], + description="API secret", + ) - aiomonitor_termui_port: Annotated[ - int, - Field( - gt=0, lt=65536, description="Port number for aiomonitor termui server.", default=48500 - ), - ] - aiomonitor_webui_port: Annotated[ - int, - Field( - gt=0, lt=65536, description="Port number for aiomonitor webui server.", default=49500 - ), - ] + aiomonitor_termui_port: int = Field( + gt=0, + lt=65536, + description="Port number for aiomonitor termui server.", + default=48500, + ) + aiomonitor_webui_port: int = Field( + gt=0, + lt=65536, + description="Port number for aiomonitor webui server.", + default=49500, + ) + + +class PyroscopeConfig(BaseSchema): + enabled: bool = Field(default=False, description="Enable pyroscope profiler.") + app_name: Optional[str] = Field(default=None, description="Pyroscope app name.") + server_addr: Optional[str] = Field(default=None, description="Pyroscope server address.") + sample_rate: Optional[int] = Field(default=None, description="Pyroscope sample rate.") class ServerConfig(BaseSchema): wsproxy: 
Annotated[WSProxyConfig, Field(default_factory=WSProxyConfig)] - pyroscope: Annotated[PyroscopeConfig, Field(default_factory=PyroscopeConfig)] - logging: Annotated[LoggingConfig, Field(default_factory=LoggingConfig)] - debug: Annotated[DebugConfig, Field(default_factory=DebugConfig)] + pyroscope: PyroscopeConfig = Field(default_factory=PyroscopeConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) + debug: DebugConfig = Field(default_factory=DebugConfig) -def load(config_path: Path | None = None, log_level: LogLevel = LogLevel.NOTSET) -> ServerConfig: +def load(config_path: Optional[Path] = None, log_level: LogLevel = LogLevel.NOTSET) -> ServerConfig: # Determine where to read configuration. raw_cfg, _ = config.read_from_file(config_path, "wsproxy")