Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement local mode with Unix domain socket #189

Merged
merged 9 commits into from
Feb 11, 2025
73 changes: 44 additions & 29 deletions mostlyai/sdk/_local/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,63 +11,74 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import tempfile
import time
from pathlib import Path
import random

# import random
from threading import Thread

import rich
from fastapi import FastAPI
import uvicorn
from mostlyai.sdk._local.routes import Routes
import socket

# import socket
import os


class LocalServer:
def __init__(
self,
home_dir: str | Path | None = None,
host: str | None = None,
port: int | None = None,
# host: str | None = None,
# port: int | None = None,
uds: str | None = None,
):
self.home_dir = Path(home_dir if home_dir else "~/mostlyai").expanduser()
self.host = host if host else "127.0.0.1"
self.port = port if port else self._find_available_port()
self.base_url = f"http://{self.host}:{self.port}"
self.home_dir = Path(home_dir or "~/mostlyai").expanduser()
# self.host = host or "127.0.0.1"
# self.port = port or self._find_available_port()
self.uds = uds or tempfile.NamedTemporaryFile(prefix="mostlyai-", suffix=".sock", delete=False).name
self.base_url = "http://127.0.0.1"
self._app = FastAPI(
root_path="/api/v2",
title="Synthetic Data SDK ✨",
description="Welcome! This is your Local Server instance of the Synthetic Data SDK. "
"Connect via the MOSTLY AI client to train models and generate synthetic data locally. "
"Share the knowledge of your synthetic data generators with your team or the world by "
"deploying these then to a MOSTLY AI platform. Enjoy!",
version="1.0.0",
version="4.1.0",
)
routes = Routes(self.home_dir)
self._app.include_router(routes.router)
self._server = None
self._thread = None
self.start() # Automatically start the server during initialization

def _find_available_port(self) -> int:
def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex((host, port)) == 0
# def _find_available_port(self) -> int:
# def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool:
# with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# return s.connect_ex((host, port)) == 0
#
# failed_ports = []
# min_port, max_port = 49152, 65535
# for _ in range(10):
# port = random.randint(min_port, max_port)
# if not is_port_in_use(port, self.host):
# return port
# failed_ports.append(port)
# raise ValueError(
# f"Could not find an available port in range {min_port}-{max_port} after 10 attempts. Tried ports: {failed_ports}"
# )

failed_ports = []
min_port, max_port = 49152, 65535
for _ in range(10):
port = random.randint(min_port, max_port)
if not is_port_in_use(port, self.host):
return port
failed_ports.append(port)
raise ValueError(
f"Could not find an available port in range {min_port}-{max_port} after 10 attempts. Tried ports: {failed_ports}"
)
def _clear_socket_file(self):
if os.path.exists(self.uds):
os.remove(self.uds)

def _create_server(self):
config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="error", reload=False)
self._clear_socket_file()
config = uvicorn.Config(self._app, uds=self.uds, log_level="error", reload=False)
self._server = uvicorn.Server(config)

def _run_server(self):
Expand All @@ -76,19 +87,23 @@ def _run_server(self):

def start(self):
if not self._server:
rich.print(
f"Starting server on [link={self.base_url}]{self.base_url}[/] using [link=file://{self.home_dir}]file://{self.home_dir}[/]"
)
# rich.print(
# f"Starting server on [link={self.base_url}]{self.base_url}[/] using [link=file://{self.home_dir}]file://{self.home_dir}[/]"
# )
rich.print(f"Starting server using [link=file://{self.home_dir}]file://{self.home_dir}[/]")
self._create_server()
self._thread = Thread(target=self._run_server, daemon=True)
self._thread.start()
# make sure the socket file is cleaned up on exit
atexit.register(self._clear_socket_file)
time.sleep(0.5) # give the server a moment to start

def stop(self):
if self._server and self._server.started:
rich.print(f"Stopping server on {self.base_url} for {self.home_dir.absolute()}.")
self._server.should_exit = True # Signal the server to shut down
self._thread.join() # Wait for the server thread to finish
self._clear_socket_file()

def __enter__(self):
# Ensure the server is running
Expand All @@ -102,5 +117,5 @@ def __exit__(self, exc_type, exc_value, traceback):
def __del__(self):
# Backup cleanup in case `stop` was not called explicitly or via context
if self._server and self._server.started:
print(f"Automatically shutting down server on {self.host}:{self.port}")
print("Automatically shutting down server")
self.stop()
21 changes: 13 additions & 8 deletions mostlyai/sdk/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ def __init__(
api_key: str | None = None,
local: bool = False,
local_dir: str | Path | None = None,
local_port: int | None = None,
local_host: str | None = None,
# local_port: int | None = None,
# local_host: str | None = None,
local_uds: str | None = None,
timeout: float = 60.0,
ssl_verify: bool = True,
quiet: bool = False,
Expand All @@ -120,17 +121,21 @@ def __init__(
"Local mode requires additional packages to be installed. Run `pip install 'mostlyai[local]'`."
)

self.local = LocalServer(home_dir=local_dir, host=local_host, port=local_port)
self.local = LocalServer(home_dir=local_dir, uds=local_uds)
base_url = self.local.base_url
api_key = "local"
uds = self.local.uds
else:
uds = None

super().__init__(base_url=base_url, api_key=api_key, timeout=timeout, ssl_verify=ssl_verify)
client_kwargs = {
"base_url": self.base_url,
"api_key": self.api_key,
"timeout": self.timeout,
"ssl_verify": self.ssl_verify,
"base_url": base_url,
"api_key": api_key,
"uds": uds,
"timeout": timeout,
"ssl_verify": ssl_verify,
}
super().__init__(**client_kwargs)
self.connectors = _MostlyConnectorsClient(**client_kwargs)
self.generators = _MostlyGeneratorsClient(**client_kwargs)
self.synthetic_datasets = _MostlySyntheticDatasetsClient(**client_kwargs)
Expand Down
4 changes: 3 additions & 1 deletion mostlyai/sdk/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ def __init__(
self,
base_url: str | None = None,
api_key: str | None = None,
uds: str | None = None,
timeout: float = 60.0,
ssl_verify: bool = True,
):
self.base_url = (base_url or os.getenv("MOSTLY_BASE_URL") or DEFAULT_BASE_URL).rstrip("/")
self.api_key = api_key or os.getenv("MOSTLY_API_KEY")
self.transport = httpx.HTTPTransport(uds=uds) if uds else None
self.timeout = timeout
self.ssl_verify = ssl_verify
if not self.api_key:
Expand Down Expand Up @@ -137,7 +139,7 @@ def request(
kwargs["params"] = map_snake_to_camel_case(kwargs["params"])

try:
with httpx.Client(timeout=self.timeout, verify=self.ssl_verify) as client:
with httpx.Client(timeout=self.timeout, verify=self.ssl_verify, transport=self.transport) as client:
response = client.request(method=verb, url=full_url, **kwargs)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
Expand Down