-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathworker.py
166 lines (143 loc) · 6.49 KB
/
worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import asyncio
import random
import shlex
from functools import cached_property
from typing import Literal

import httpx
import yaml
from pydantic import BaseModel
from text_generation import AsyncClient

from spitfight.log import get_logger
# Module-level logger from spitfight.log, keyed by this module's name.
logger = get_logger(__name__)
class Worker(BaseModel):
    """A worker that serves a model.

    Each worker is reached over HTTP at ``http://{hostname}:{port}``.
    ``status`` is updated in place by :meth:`audit` and :meth:`check_status`.
    """

    # Worker's container name, since we're using Overlay networks.
    hostname: str
    # For TGI, this would always be 80.
    port: int
    # User-friendly model name, e.g. "Llama2-7B".
    model_name: str
    # Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
    model_id: str
    # Whether the model worker container is good.
    status: Literal["up", "down"]

    class Config:
        # Pydantic (v1-style config) must leave `cached_property` descriptors
        # alone rather than treating them as model fields.
        keep_untouched = (cached_property,)

    @cached_property
    def url(self) -> str:
        """Base HTTP URL of the worker (computed once, then cached)."""
        return f"http://{self.hostname}:{self.port}"

    def get_client(self) -> AsyncClient:
        """Return a new text-generation `AsyncClient` pointed at this worker."""
        return AsyncClient(base_url=self.url)

    def audit(self) -> None:
        """Make sure the worker is running and information is as expected.

        Assumed to be called on app startup when workers are initialized.
        This method will just raise `ValueError`s if audit fails in order to
        prevent the controller from starting if anything is wrong.
        On success, `self.status` is set to "up".

        Raises:
            ValueError: If GET /info fails at the transport level, returns a
                non-200 status, is not valid JSON, or reports a model ID
                different from `self.model_id`.
        """
        try:
            response = httpx.get(self.url + "/info")
        except httpx.TransportError as e:
            # TransportError covers ConnectError and TimeoutException as well
            # as protocol/read errors, which previously escaped as non-ValueErrors.
            raise ValueError(f"Could not connect to {self!r}: {e!r}") from e
        if response.status_code != 200:
            raise ValueError(f"Could not get /info from {self!r}.")
        # A non-JSON body raises json.JSONDecodeError, itself a ValueError.
        info = response.json()
        # Missing key / non-object payloads are reported as a mismatch
        # (ValueError) rather than crashing with KeyError.
        reported_id = info.get("model_id") if isinstance(info, dict) else None
        if reported_id != self.model_id:
            raise ValueError(f"Model name mismatch: {reported_id} != {self.model_id}")
        self.status = "up"
        logger.info("%s is up.", repr(self))

    async def check_status(self) -> None:
        """Check worker status and update `self.status` accordingly.

        Sets "down" (and logs a warning) when /info is unreachable, returns a
        non-200 status, is not valid JSON, or reports a different model ID;
        otherwise sets "up". Worker-side failures never propagate out of this
        coroutine, so one bad worker cannot abort a batch of concurrent checks.
        """
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.url + "/info")
            except httpx.TransportError as e:
                self.status = "down"
                logger.warning("%s is down: %s", repr(self), repr(e))
                return
            if response.status_code != 200:
                self.status = "down"
                # Error bodies are not guaranteed to be JSON; log the raw text
                # instead of calling `.json()`, which could raise here.
                logger.warning(
                    "GET /info from %s returned %s: %s.",
                    repr(self),
                    response.status_code,
                    response.text,
                )
                return
            try:
                info = response.json()
            except ValueError as e:
                self.status = "down"
                logger.warning("GET /info from %s returned invalid JSON: %s", repr(self), repr(e))
                return
            reported_id = info.get("model_id") if isinstance(info, dict) else None
            if reported_id != self.model_id:
                self.status = "down"
                logger.warning(
                    "Model name mismatch for %s: %s != %s",
                    repr(self),
                    reported_id,
                    self.model_id,
                )
                return
            logger.info("%s is up.", repr(self))
            self.status = "up"
class WorkerService:
    """A service that manages model serving workers.

    Worker objects are only created once and shared across the
    entire application. Especially, changing the status of a worker
    will immediately take effect on the result of `choose_two`
    and `choose_based_on_preference`.

    Attributes:
        workers (list[Worker]): The list of workers.
    """

    def __init__(self, compose_files: list[str]) -> None:
        """Initialize the worker service from compose files.

        Each entry under `services` in every compose file becomes one
        `Worker`. The service key is the model name, `container_name` is the
        hostname, the port is fixed at 80, and the model ID is taken from the
        `--model-id` argument of the service's `command`. Every worker is
        audited (a blocking HTTP call) before it is registered.

        Args:
            compose_files: Paths of the YAML compose files to read.

        Raises:
            ValueError: If a service has no `--model-id` value, a model name
                appears more than once, or a worker fails its audit.
        """
        self.workers: list[Worker] = []
        seen_model_names: set[str] = set()
        for compose_file in compose_files:
            # Use a context manager so the file handle is closed promptly.
            with open(compose_file) as f:
                spec = yaml.safe_load(f)
            for model_name, service_spec in spec["services"].items():
                # Detect duplicates before auditing, so we fail fast without
                # doing network calls for a configuration that is invalid anyway.
                if model_name in seen_model_names:
                    raise ValueError(
                        f"Model names must be unique; {model_name!r} appears more than once."
                    )
                seen_model_names.add(model_name)
                model_id = self._find_model_id(service_spec["command"])
                worker = Worker(
                    hostname=service_spec["container_name"],
                    port=80,
                    model_name=model_name,
                    model_id=model_id,
                    status="down",
                )
                worker.audit()
                self.workers.append(worker)

    @staticmethod
    def _find_model_id(command) -> str:
        """Return the value of the `--model-id` argument in a service `command`.

        `command` may be a list of arguments or a single shell-style string
        (split with `shlex`). Both `--model-id X` and `--model-id=X` are
        recognized.

        Raises:
            ValueError: If no `--model-id` value is present.
        """
        if isinstance(command, str):
            args = shlex.split(command)
        else:
            args = [str(arg) for arg in command]
        prefix = "--model-id="
        for i, arg in enumerate(args):
            if arg == "--model-id":
                # Guard against the flag being the last argument (no value).
                if i + 1 < len(args):
                    return args[i + 1]
                break
            if arg.startswith(prefix):
                return arg[len(prefix):]
        raise ValueError(f"Could not find model ID in {command!r}")

    def get_worker(self, model_name: str) -> Worker:
        """Get a worker by model name.

        Raises:
            RuntimeError: If the worker exists but its status is "down".
            ValueError: If no worker has this model name.
        """
        for worker in self.workers:
            if worker.model_name == model_name:
                if worker.status == "down":
                    # This is an unfortunate case where, when the two models were chosen,
                    # the worker was up, but after that went down before the request
                    # completed. We'll just raise a 500 internal error and have the user
                    # try again. This won't be common.
                    raise RuntimeError(f"The worker with model name {model_name} is down.")
                return worker
        raise ValueError(f"Worker with model name {model_name} does not exist.")

    def choose_two(self) -> tuple[Worker, Worker]:
        """Choose two different workers whose status is "up", uniformly at random.

        Good place to use the Strategy Pattern when we want to
        implement different strategies for choosing workers.

        Raises:
            ValueError: If fewer than two workers are up.
        """
        live_workers = [worker for worker in self.workers if worker.status == "up"]
        if len(live_workers) < 2:
            raise ValueError("Not enough live workers to choose from.")
        worker_a, worker_b = random.sample(live_workers, 2)
        return worker_a, worker_b

    def choose_based_on_preference(self, preference: str) -> tuple[Worker, Worker]:
        """Choose two different workers based on user preference.

        Specifically, if `preference` is `"Random"`, this is equivalent to
        choosing two models at random. Otherwise, if `preference` is a model
        name, this is equivalent to choosing that model and another live model
        at random. In that case, the order of the two models is also randomized.

        Raises:
            RuntimeError: If the preferred worker is down (from `get_worker`).
            ValueError: If the preferred model does not exist, or there are
                not enough live workers to form a pair.
        """
        if preference == "Random":
            return self.choose_two()
        worker_a = self.get_worker(preference)
        # Only pair with workers that are currently up, consistent with
        # `choose_two`; a down partner would fail later in `get_worker`.
        candidates = [
            worker
            for worker in self.workers
            if worker is not worker_a and worker.status == "up"
        ]
        if not candidates:
            raise ValueError("Not enough live workers to choose from.")
        worker_b = random.choice(candidates)
        return tuple(random.sample([worker_a, worker_b], 2))

    async def check_workers(self) -> None:
        """Check the status of all workers concurrently."""
        await asyncio.gather(*[worker.check_status() for worker in self.workers])