Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-593): Configure the logging module's config file to use Pydantic #2834

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2834.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Configure the logging module's config file to use Pydantic
14 changes: 7 additions & 7 deletions configs/wsproxy/halfstack.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
[wsproxy]
bind_host = "0.0.0.0"
advertised_host = "127.0.0.1"
bind-host = "0.0.0.0"
advertised-host = "127.0.0.1"

bind_api_port = 5050
advertised_api_port = 5050
bind-api-port = 5050
advertised-api-port = 5050

# replace these values with your passphrase
jwt_encrypt_key = "QRX/ZX2nwKKpuTSD3ZycPA"
permit_hash_key = "gHNAohmntRV0t9zlwTVQeQ"
jwt-encrypt-key = "QRX/ZX2nwKKpuTSD3ZycPA"
permit-hash-key = "gHNAohmntRV0t9zlwTVQeQ"

api_secret = "v625xZLOgbMHhl0s49VuqQ"
api-secret = "v625xZLOgbMHhl0s49VuqQ"
Yaminyam marked this conversation as resolved.
Show resolved Hide resolved

[pyroscope]
enabled = true
Expand Down
46 changes: 23 additions & 23 deletions configs/wsproxy/sample.toml
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
[wsproxy]
ipc_base_path = "/tmp/backend.ai/ipc"
event_loop = "asyncio"
pid_file = "/run/backend.ai/wsproxy/wsproxy.pid"
ipc-base-path = "/tmp/backend.ai/ipc"
event-loop = "asyncio"
pid-file = "/run/backend.ai/wsproxy/wsproxy.pid"
id = "i-node01"
user = 501
group = 501
bind_host = "0.0.0.0"
advertised_host = "example.com"
bind_api_port = 5050
advertised_api_port = 15050
bind_proxy_port_range = [
bind-host = "0.0.0.0"
advertised-host = "example.com"
bind-api-port = 5050
advertised-api-port = 15050
bind-proxy-port-range = [
10200,
10300,
]
advertised_proxy_port_range = [
advertised-proxy-port-range = [
20200,
20300,
]
protocol = "http"

# replace these values with your passphrase
jwt_encrypt_key = "50M3G00DL00KING53CR3T"
permit_hash_key = "50M3G00DL00KING53CR3T"
api_secret = "50M3G00DL00KING53CR3T"
jwt-encrypt-key = "50M3G00DL00KING53CR3T"
permit-hash-key = "50M3G00DL00KING53CR3T"
api-secret = "50M3G00DL00KING53CR3T"

aiomonitor_termui_port = 48500
aiomonitor_webui_port = 49500
aiomonitor-termui-port = 48500
aiomonitor-webui-port = 49500

[pyroscope]
enabled = true
Expand All @@ -39,7 +39,7 @@ drivers = [
"console",
]

[logging.pkg_ns]
[logging.pkg-ns]
"" = "WARNING"
"ai.backend" = "DEBUG"
tests = "DEBUG"
Expand All @@ -52,14 +52,14 @@ format = "verbose"
[logging.file]
path = "/var/log/backend.ai"
filename = "wsproxy.log"
backup_count = 5
rotation_size = "10M"
backup-count = 5
rotation-size = "10M"
format = "verbose"

[logging.logstash]
protocol = "tcp"
ssl_enabled = true
ssl_verify = true
ssl-enabled = true
ssl-verify = true

[logging.logstash.endpoint]
host = "127.0.0.1"
Expand All @@ -69,14 +69,14 @@ port = 8001
host = "127.0.0.1"
port = 8000
level = "INFO"
ssl_verify = true
ca_certs = "/etc/ssl/ca.pem"
ssl-verify = true
ca-certs = "/etc/ssl/ca.pem"
keyfile = "/etc/backend.ai/graylog/privkey.pem"
certfile = "/etc/backend.ai/graylog/cert.pem"

[debug]
enabled = false
asyncio = false
enhanced_aiomonitor_task_info = false
log_events = false
enhanced-aiomonitor-task-info = false
log-events = false

10 changes: 9 additions & 1 deletion src/ai/backend/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@


class BaseSchema(BaseModel):
@staticmethod
def snake_to_kebab_case(string: str) -> str:
return string.replace("_", "-")

model_config = ConfigDict(
populate_by_name=True, from_attributes=True, use_enum_values=True, extra="allow"
populate_by_name=True,
from_attributes=True,
use_enum_values=True,
extra="allow",
alias_generator=snake_to_kebab_case,
)


Expand Down
7 changes: 4 additions & 3 deletions src/ai/backend/common/metrics/profiler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from dataclasses import dataclass
from typing import Optional

import pyroscope


@dataclass
class PyroscopeArgs:
enabled: bool
app_name: str
server_address: str
sample_rate: int
app_name: Optional[str]
server_address: Optional[str]
sample_rate: Optional[int]


class Profiler:
Expand Down
4 changes: 3 additions & 1 deletion src/ai/backend/logging/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ visibility_private_component(
allowed_dependents=[
"//src/ai/backend/**",
],
allowed_dependencies=[],
allowed_dependencies=[
"//src/ai/backend/common/**",
],
)

python_distribution(
Expand Down
126 changes: 126 additions & 0 deletions src/ai/backend/logging/config_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import enum
from pathlib import Path
from typing import Optional

from pydantic import ByteSize, Field

from ai.backend.common.config import BaseSchema

from .types import LogFormat, LogLevel

default_pkg_ns = {
"": LogLevel.WARNING,
"ai.backend": LogLevel.DEBUG,
"tests": LogLevel.DEBUG,
"aiohttp": LogLevel.INFO,
}


class LogDriver(str, enum.Enum):
CONSOLE = "console"
LOGSTASH = "logstash"
FILE = "file"
GRAYLOG = "graylog"


class LogstashProtocol(str, enum.Enum):
ZMQ_PUSH = "zmq.push"
ZMQ_PUB = "zmq.pub"
TCP = "tcp"
UDP = "udp"


class HostPortPair(BaseSchema):
host: str = Field(examples=["127.0.0.1"])
port: int = Field(gt=0, lt=65536, examples=[8201])

def __repr__(self) -> str:
return f"{self.host}:{self.port}"

def __str__(self) -> str:
return self.__repr__()

def __getitem__(self, *args) -> int | str:
if args[0] == 0:
return self.host
elif args[0] == 1:
return self.port
else:
raise KeyError(*args)
Comment on lines +33 to +49
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I’m wondering is why the HostPortPair class is inside the logging package.
It seems like it should be in a different package.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We support several logging tools and may support more in the future.
(ex, graylog, logstash)
Since these logging tools require the host address by default, we've just separated that part into a common setting.
ref.

t.Key("host"): t.String,



class ConsoleConfig(BaseSchema):
colored: Optional[bool] = Field(
default=None, description="Opt to print colorized log.", examples=[True]
)
format: LogFormat = Field(default=LogFormat.VERBOSE, description="Determine verbosity of log.")


class FileConfig(BaseSchema):
path: Path = Field(description="Path to store log.", examples=["/var/log/backend.ai"])
filename: str = Field(description="Log file name.", examples=["wsproxy.log"])
backup_count: int = Field(description="Number of outdated log files to retain.", default=5)
rotation_size: ByteSize = Field(
description="Maximum size for a single log file.",
default_factory=lambda: ByteSize("10MB"),
)
format: LogFormat = Field(default=LogFormat.VERBOSE, description="Determine verbosity of log.")


class LogstashConfig(BaseSchema):
endpoint: HostPortPair = Field(
description="Connection information of logstash node.",
examples=[HostPortPair(host="127.0.0.1", port=8001)],
)
protocol: LogstashProtocol = Field(
description="Protocol to communicate with logstash server.",
default=LogstashProtocol.TCP,
)
ssl_enabled: bool = Field(
description="Use TLS to communicate with logstash server.",
default=True,
)
ssl_verify: bool = Field(
description="Verify validity of TLS certificate when communicating with logstash.",
default=True,
)


class GraylogConfig(BaseSchema):
host: str = Field(description="Graylog hostname.", examples=["127.0.0.1"])
port: int = Field(description="Graylog server port number.", examples=[8000])
level: LogLevel = Field(description="Log level.", default=LogLevel.INFO)
ssl_verify: bool = Field(
description="Verify validity of TLS certificate when communicating with logstash.",
default=True,
)
ca_certs: Optional[str] = Field(
description="Path to Root CA certificate file.",
examples=["/etc/ssl/ca.pem"],
default=None,
)
keyfile: Optional[str] = Field(
description="Path to TLS private key file.",
examples=["/etc/backend.ai/graylog/privkey.pem"],
default=None,
)
certfile: Optional[str] = Field(
description="Path to TLS certificate file.",
examples=["/etc/backend.ai/graylog/cert.pem"],
default=None,
)


class LoggingConfig(BaseSchema):
level: LogLevel = Field(default=LogLevel.INFO, description="Log level.")
drivers: list[LogDriver] = Field(
default=[LogDriver.CONSOLE], description="Array of log drivers to print."
)
console: ConsoleConfig = Field(default=ConsoleConfig(colored=None, format=LogFormat.VERBOSE))
file: Optional[FileConfig] = Field(default=None)
logstash: Optional[LogstashConfig] = Field(default=None)
graylog: Optional[GraylogConfig] = Field(default=None)
pkg_ns: dict[str, LogLevel] = Field(
description="Override default log level for specific scope of package",
default=default_pkg_ns,
)
Loading
Loading