Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement local mode with Unix domain socket #189

Merged
merged 9 commits into from
Feb 11, 2025
47 changes: 1 addition & 46 deletions mostlyai/sdk/_local/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from fastapi import APIRouter, Body, HTTPException, UploadFile, File
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse, HTMLResponse, RedirectResponse
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse, HTMLResponse

from mostlyai import sdk
from mostlyai.sdk._data.conversions import create_container_from_connector
Expand Down Expand Up @@ -105,51 +105,6 @@ def _html(self, title: str, body: str):
def _initialize_routes(self):
## GENERAL

@self.router.get("/", include_in_schema=False)
async def root() -> RedirectResponse:
    """Redirect requests for the root path to the `/docs` page."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)

@self.router.get("/d/connectors/{id}", response_class=HTMLResponse, include_in_schema=False)
async def show_connector(id: str) -> HTMLResponse:
    """Render a debug page that displays the stored `connector.json` of one connector.

    Args:
        id: Connector identifier; used as the sub-directory name below
            `<home_dir>/connectors`.

    Returns:
        An HTML page with the file location as heading and the file content
        inside a `<pre>` block.
    """
    # Function-scope import so the file-level import block does not need to change.
    from html import escape

    connector_json = self.home_dir / "connectors" / id / "connector.json"
    # Everything interpolated into the markup is escaped: characters such as `<`
    # or `&` inside the JSON (or the path) must be shown literally instead of
    # being interpreted as HTML by the browser.
    body = (
        f"<h3>file://{escape(str(connector_json))}</h3>"
        f"<pre>{escape(connector_json.read_text())}</pre>"
    )
    return HTMLResponse(content=self._html("Connector", body))

@self.router.get("/d/generators/{id}", response_class=HTMLResponse, include_in_schema=False)
async def show_generator(id: str) -> HTMLResponse:
    """Render a debug page for one generator: its config, job progress and job log.

    Args:
        id: Generator identifier; used as the sub-directory name below
            `<home_dir>/generators`.

    Returns:
        An HTML page with one heading + `<pre>` section each for
        `generator.json`, `job_progress.json` and `job.log`.
    """
    # Function-scope import so the file-level import block does not need to change.
    from html import escape

    generator_dir = self.home_dir / "generators" / id
    generator_json = generator_dir / "generator.json"
    progress_json = generator_dir / "job_progress.json"
    job_log_file = generator_dir / "job.log"
    # The job log is the only file that is allowed to be absent; it is shown as empty then.
    job_log_logs = job_log_file.read_text() if job_log_file.exists() else ""
    # All interpolated values are escaped: log lines and JSON strings can contain
    # `<`, `>` or `&`, which would otherwise be interpreted as markup.
    body = (
        f"<h3>file://{escape(str(generator_json))}</h3>"
        f"<pre>{escape(generator_json.read_text())}</pre>"
        f"<h3>file://{escape(str(progress_json))}</h3>"
        f"<pre>{escape(progress_json.read_text())}</pre>"
        f"<h3>file://{escape(str(job_log_file))}</h3>"
        f"<pre>{escape(job_log_logs)}</pre>"
    )
    return HTMLResponse(content=self._html("Generator", body))

@self.router.get("/d/synthetic-datasets/{id}", response_class=HTMLResponse, include_in_schema=False)
async def show_synthetic_dataset(id: str) -> HTMLResponse:
    """Render a debug page for one synthetic dataset: its config, job progress and job log.

    Args:
        id: Synthetic dataset identifier; used as the sub-directory name below
            `<home_dir>/synthetic-datasets`.

    Returns:
        An HTML page with one heading + `<pre>` section each for
        `synthetic-dataset.json`, `job_progress.json` and `job.log`.
    """
    # Function-scope import so the file-level import block does not need to change.
    from html import escape

    synthetic_dataset_dir = self.home_dir / "synthetic-datasets" / id
    synthetic_dataset_json = synthetic_dataset_dir / "synthetic-dataset.json"
    progress_json = synthetic_dataset_dir / "job_progress.json"
    job_log_file = synthetic_dataset_dir / "job.log"
    # The job log is the only file that is allowed to be absent; it is shown as empty then.
    job_log_logs = job_log_file.read_text() if job_log_file.exists() else ""
    # All interpolated values are escaped: log lines and JSON strings can contain
    # `<`, `>` or `&`, which would otherwise be interpreted as markup.
    body = (
        f"<h3>file://{escape(str(synthetic_dataset_json))}</h3>"
        f"<pre>{escape(synthetic_dataset_json.read_text())}</pre>"
        f"<h3>file://{escape(str(progress_json))}</h3>"
        f"<pre>{escape(progress_json.read_text())}</pre>"
        f"<h3>file://{escape(str(job_log_file))}</h3>"
        f"<pre>{escape(job_log_logs)}</pre>"
    )
    return HTMLResponse(content=self._html("Synthetic Dataset", body))

@self.router.get("/about", response_model=AboutService)
async def get_about_service() -> AboutService:
    """Return service information: the installed SDK version, with `assistant` set to False."""
    about = AboutService(version=sdk.__version__, assistant=False)
    return about
Expand Down
50 changes: 22 additions & 28 deletions mostlyai/sdk/_local/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import tempfile
import time
from pathlib import Path
import random

from threading import Thread

import rich
from fastapi import FastAPI
import uvicorn
from mostlyai.sdk._local.routes import Routes
import socket

import os


class LocalServer:
def __init__(
self,
home_dir: str | Path | None = None,
host: str | None = None,
port: int | None = None,
):
self.home_dir = Path(home_dir if home_dir else "~/mostlyai").expanduser()
self.host = host if host else "127.0.0.1"
self.port = port if port else self._find_available_port()
self.base_url = f"http://{self.host}:{self.port}"
self.home_dir = Path(home_dir or "~/mostlyai").expanduser()
self.home_dir.mkdir(parents=True, exist_ok=True)
self.uds = tempfile.NamedTemporaryFile(
dir=self.home_dir, prefix=".mostlyai-", suffix=".sock", delete=False
).name
self.base_url = "http://127.0.0.1"
self._app = FastAPI(
root_path="/api/v2",
title="Synthetic Data SDK ✨",
Expand All @@ -50,24 +52,13 @@ def __init__(
self._thread = None
self.start() # Automatically start the server during initialization

def _find_available_port(self) -> int:
    """Pick a random port in 49152-65535 on which nothing accepts connections.

    Up to 10 random candidates are probed against `self.host`; the first one
    whose connection attempt does not succeed is returned.

    Raises:
        ValueError: If all 10 probed ports turned out to be in use.
    """
    min_port, max_port = 49152, 65535
    failed_ports = []

    def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool:
        # `connect_ex` returns 0 only when a listener accepted the connection.
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            status = probe.connect_ex((host, port))
        return status == 0

    attempts_left = 10
    while attempts_left > 0:
        attempts_left -= 1
        candidate = random.randint(min_port, max_port)
        if is_port_in_use(candidate, self.host):
            failed_ports.append(candidate)
            continue
        return candidate
    raise ValueError(
        f"Could not find an available port in range {min_port}-{max_port} after 10 attempts. Tried ports: {failed_ports}"
    )
def _clear_socket_file(self):
    """Delete the socket file at `self.uds`; do nothing if it is already gone.

    The method is called before server creation, from `stop`, and as an
    `atexit` hook, so it may run several times for the same path. Removing
    first and tolerating `FileNotFoundError` avoids the window between an
    existence check and the removal in which another caller could have
    deleted the file already.
    """
    try:
        os.remove(self.uds)
    except FileNotFoundError:
        # Already removed (or never created) - nothing left to clean up.
        pass

# Build the uvicorn server for `self._app` and store it in `self._server`.
def _create_server(self):
# NOTE(review): `config` is assigned twice in this block; the second (uds-based)
# assignment overrides the first (host/port-based) one. This looks like the old and
# new side of a rendered diff with the +/- markers lost - confirm which line is live.
config = uvicorn.Config(self._app, host=self.host, port=self.port, log_level="error", reload=False)
# Remove any socket file left at `self.uds` before the server config is built.
self._clear_socket_file()
# The socket path is handed to uvicorn via `uds` instead of a host/port pair.
config = uvicorn.Config(self._app, uds=self.uds, log_level="error", reload=False)
self._server = uvicorn.Server(config)

def _run_server(self):
Expand All @@ -77,18 +68,21 @@ def _run_server(self):
# Create the server (if none exists yet) and run it in a daemon thread.
def start(self):
if not self._server:
# NOTE(review): two message literals follow each other inside this call; they look
# like the old and new side of a rendered diff with the +/- markers lost - confirm
# which one is live. As written, adjacent string literals would be concatenated.
rich.print(
f"Starting server on [link={self.base_url}]{self.base_url}[/] using [link=file://{self.home_dir}]file://{self.home_dir}[/]"
f"Starting Synthetic Data SDK in local mode using [link=file://{self.home_dir} dodger_blue2 underline]file://{self.home_dir}[/]"
)
self._create_server()
# daemon=True: the server thread does not keep the interpreter alive on exit
self._thread = Thread(target=self._run_server, daemon=True)
self._thread.start()
# make sure the socket file is cleaned up on exit
atexit.register(self._clear_socket_file)
time.sleep(0.5) # give the server a moment to start

# Shut down a started server, wait for its thread, and remove the socket file.
def stop(self):
# Only act when a server object exists and reports that it has started.
if self._server and self._server.started:
# NOTE(review): two consecutive prints with an old-style and a new-style message;
# presumably the old and new side of a rendered diff - confirm which one is live.
rich.print(f"Stopping server on {self.base_url} for {self.home_dir.absolute()}.")
rich.print("Stopping Synthetic Data SDK in local mode")
self._server.should_exit = True # Signal the server to shut down
self._thread.join() # Wait for the server thread to finish
# Delete the socket file at `self.uds` once the server thread has finished.
self._clear_socket_file()

def __enter__(self):
# Ensure the server is running
Expand All @@ -102,5 +96,5 @@ def __exit__(self, exc_type, exc_value, traceback):
# Finalizer: stops the server if it is still running when the object is collected.
def __del__(self):
# Backup cleanup in case `stop` was not called explicitly or via context
if self._server and self._server.started:
# NOTE(review): two consecutive prints; the first one reads `self.host` / `self.port`.
# Presumably the old and new side of a rendered diff - confirm which one is live.
print(f"Automatically shutting down server on {self.host}:{self.port}")
print("Automatically shutting down server")
self.stop()
1 change: 1 addition & 0 deletions mostlyai/sdk/client/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def job_wait(
return
# check whether we are done
if job.progress.value >= job.progress.max:
progress.refresh()
time.sleep(1) # give the system a moment to update the status
return
else:
Expand Down
28 changes: 16 additions & 12 deletions mostlyai/sdk/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ class MostlyAI(_MostlyBaseClient):
api_key: The API key for authenticating. If not provided, it would rely on environment variables.
local: Whether to run in local mode or not.
local_dir: The directory to use for local mode. If not provided, `~/mostlyai` will be used.
local_port: The port to use for local mode. If not provided, a random port will be used.
local_host: The host to use for local mode. If not provided, then 127.0.0.1 will be used.
timeout: Timeout for HTTPS requests in seconds.
ssl_verify: Whether to verify SSL certificates.
quiet: Whether to suppress rich output.
Expand All @@ -98,8 +96,6 @@ def __init__(
api_key: str | None = None,
local: bool = False,
local_dir: str | Path | None = None,
local_port: int | None = None,
local_host: str | None = None,
timeout: float = 60.0,
ssl_verify: bool = True,
quiet: bool = False,
Expand All @@ -117,27 +113,35 @@ def __init__(
from mostlyai import qa # noqa
except ImportError:
raise APIError(
"Local mode requires additional packages to be installed. Run `pip install 'mostlyai[local]'`."
"LOCAL mode requires additional packages to be installed. Run `pip install 'mostlyai[local]'`."
)

self.local = LocalServer(home_dir=local_dir, host=local_host, port=local_port)
self.local = LocalServer(home_dir=local_dir)
base_url = self.local.base_url
api_key = "local"
uds = self.local.uds
else:
uds = None

super().__init__(base_url=base_url, api_key=api_key, timeout=timeout, ssl_verify=ssl_verify)
client_kwargs = {
"base_url": self.base_url,
"api_key": self.api_key,
"timeout": self.timeout,
"ssl_verify": self.ssl_verify,
"base_url": base_url,
"api_key": api_key,
"uds": uds,
"timeout": timeout,
"ssl_verify": ssl_verify,
}
super().__init__(**client_kwargs)
self.connectors = _MostlyConnectorsClient(**client_kwargs)
self.generators = _MostlyGeneratorsClient(**client_kwargs)
self.synthetic_datasets = _MostlySyntheticDatasetsClient(**client_kwargs)
self.synthetic_probes = _MostlySyntheticProbesClient(**client_kwargs)
try:
if local:
msg = "Connected to Synthetic Data SDK in local mode"
else:
msg = f"Connected to [link={self.base_url} dodger_blue2 underline]{self.base_url}[/]"
version = self.about().version
msg = f"Connected to [link={self.base_url}]{self.base_url}[/] ({version})"
msg += f" ({version})"
email = self.me().email
if email:
msg += f" as [bold]{email}[/bold]"
Expand Down
5 changes: 4 additions & 1 deletion mostlyai/sdk/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,14 @@ def __init__(
self,
base_url: str | None = None,
api_key: str | None = None,
uds: str | None = None,
timeout: float = 60.0,
ssl_verify: bool = True,
):
self.base_url = (base_url or os.getenv("MOSTLY_BASE_URL") or DEFAULT_BASE_URL).rstrip("/")
self.api_key = api_key or os.getenv("MOSTLY_API_KEY")
self.local = uds is not None
self.transport = httpx.HTTPTransport(uds=uds) if uds else None
self.timeout = timeout
self.ssl_verify = ssl_verify
if not self.api_key:
Expand Down Expand Up @@ -137,7 +140,7 @@ def request(
kwargs["params"] = map_snake_to_camel_case(kwargs["params"])

try:
with httpx.Client(timeout=self.timeout, verify=self.ssl_verify) as client:
with httpx.Client(timeout=self.timeout, verify=self.ssl_verify, transport=self.transport) as client:
response = client.request(method=verb, url=full_url, **kwargs)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
Expand Down
5 changes: 4 additions & 1 deletion mostlyai/sdk/client/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def create(
response_type=Connector,
)
cid = connector.id
rich.print(f"Created connector [link={self.base_url}/d/connectors/{cid} blue underline]{cid}[/]")
if self.local:
rich.print(f"Created connector [dodger_blue2]{cid}[/]")
else:
rich.print(f"Created connector [link={self.base_url}/d/connectors/{cid} dodger_blue2 underline]{cid}[/]")
return connector

# PRIVATE METHODS #
Expand Down
10 changes: 8 additions & 2 deletions mostlyai/sdk/client/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def create(self, config: GeneratorConfig | dict) -> Generator:
table["columns"] = [{"name": col} if isinstance(col, str) else col for col in table["columns"]]
generator = self.request(verb=POST, path=[], json=config, response_type=Generator)
gid = generator.id
rich.print(f"Created generator [link={self.base_url}/d/generators/{gid} blue underline]{gid}[/]")
if self.local:
rich.print(f"Created generator [dodger_blue2]{gid}[/]")
else:
rich.print(f"Created generator [link={self.base_url}/d/generators/{gid} dodger_blue2 underline]{gid}[/]")
return generator

def import_from_file(
Expand Down Expand Up @@ -206,7 +209,10 @@ def import_from_file(
response_type=Generator,
)
gid = generator.id
rich.print(f"Imported generator [link={self.base_url}/d/generators/{gid} blue underline]{gid}[/]")
if self.local:
rich.print(f"Imported generator [dodger_blue2]{gid}[/]")
else:
rich.print(f"Imported generator [link={self.base_url}/d/generators/{gid} dodger_blue2 underline]{gid}[/]")
return generator

# PRIVATE METHODS #
Expand Down
9 changes: 6 additions & 3 deletions mostlyai/sdk/client/synthetic_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,12 @@ def create(self, config: SyntheticDatasetConfig | dict[str, Any]) -> SyntheticDa
)
sid = synthetic_dataset.id
gid = synthetic_dataset.generator_id
rich.print(
f"Created synthetic dataset [link={self.base_url}/d/synthetic-datasets/{sid} blue underline]{sid}[/] with generator [link={self.base_url}/d/generators/{gid} blue underline]{gid}[/]"
)
if self.local:
rich.print(f"Created synthetic dataset [dodger_blue2]{sid}[/] with generator [dodger_blue2]{gid}[/]")
else:
rich.print(
f"Created synthetic dataset [link={self.base_url}/d/synthetic-datasets/{sid} dodger_blue2 underline]{sid}[/] with generator [link={self.base_url}/d/generators/{gid} dodger_blue2 underline]{gid}[/]"
)
return synthetic_dataset

# PRIVATE METHODS #
Expand Down