-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrouter.py
136 lines (117 loc) · 4.92 KB
/
router.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import json
import uvicorn
from pydantic import BaseSettings
from fastapi import FastAPI, Depends
from fastapi.responses import StreamingResponse
from fastapi.exceptions import HTTPException
from text_generation.errors import OverloadedError, UnknownError, ValidationError
from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
from spitfight.colosseum.common import (
COLOSSEUM_MODELS_ROUTE,
COLOSSEUM_PROMPT_ROUTE,
COLOSSEUM_RESP_VOTE_ROUTE,
COLOSSEUM_ENERGY_VOTE_ROUTE,
COLOSSEUM_HEALTH_ROUTE,
ModelsResponse,
PromptRequest,
ResponseVoteRequest,
ResponseVoteResponse,
EnergyVoteRequest,
EnergyVoteResponse,
)
from spitfight.colosseum.controller.controller import (
Controller,
init_global_controller,
get_global_controller,
)
from spitfight.utils import prepend_generator
class ControllerConfig(BaseSettings):
    """Controller settings automatically loaded from environment variables.

    Every field below carries a default that is used when no override is
    supplied. The populated instance is handed to ``init_global_controller``
    at application startup.
    """

    # --- Controller ---------------------------------------------------------
    # NOTE(review): the interval/expiration values are presumably seconds;
    # confirm against the Controller implementation.
    background_task_interval: int = 300
    max_num_req_states: int = 10000
    req_state_expiration_time: int = 600
    compose_files: list[str] = [
        "deployment/docker-compose-0.yaml",
        "deployment/docker-compose-1.yaml",
    ]

    # --- Logging ------------------------------------------------------------
    # The three file names are joined onto ``log_dir`` when loggers are set up.
    log_dir: str = "/logs"
    controller_log_file: str = "controller.log"
    request_log_file: str = "requests.log"
    uvicorn_log_file: str = "uvicorn.log"

    # --- Generation ---------------------------------------------------------
    # NOTE(review): presumably forwarded as sampling parameters to the model
    # backend by the Controller; not consumed in this file.
    max_new_tokens: int = 512
    do_sample: bool = True
    temperature: float = 1.0
    repetition_penalty: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()
# Settings instance built from the ControllerConfig defaults, which the class
# documents as overridable through environment variables.
settings = ControllerConfig()
# Module-level logger used by the request handlers in this file.
logger = get_logger("spitfight.colosseum.controller.router")
@app.on_event("startup")
async def startup_event():
    """Set up the log files and the global controller when the app starts.

    Each (logger name, file name) pair is registered through
    ``init_queued_root_logger`` with the file placed under
    ``settings.log_dir``; afterwards the global controller is initialized
    from ``settings``.
    """
    log_targets = (
        ("uvicorn", settings.uvicorn_log_file),
        ("spitfight.colosseum.controller", settings.controller_log_file),
        ("colosseum_requests", settings.request_log_file),
    )
    for logger_name, file_name in log_targets:
        log_path = os.path.join(settings.log_dir, file_name)
        init_queued_root_logger(logger_name, log_path)

    init_global_controller(settings)
@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the global controller, then the queued root loggers."""
    controller = get_global_controller()
    controller.shutdown()
    # Loggers are torn down last, after the controller has been shut down.
    shutdown_queued_root_loggers()
@app.get(COLOSSEUM_MODELS_ROUTE, response_model=ModelsResponse)
async def models(controller: Controller = Depends(get_global_controller)):
    """Return the models the controller currently reports as available.

    Args:
        controller: The global controller, injected by FastAPI.

    Returns:
        A ``ModelsResponse`` wrapping ``controller.get_available_models()``.
    """
    available = controller.get_available_models()
    return ModelsResponse(available_models=available)
@app.post(COLOSSEUM_PROMPT_ROUTE)
async def prompt(
    request: PromptRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Stream a model's response to a prompt.

    The first chunk is pulled from the controller's async generator before
    the response is started, so that errors raised while producing it can
    still be reported as proper HTTP status codes instead of failing
    mid-stream.

    Args:
        request: Prompt request body; its request ID, prompt, model index and
            model preference are forwarded to ``controller.prompt``.
        controller: The global controller, injected by FastAPI.

    Returns:
        A ``StreamingResponse`` yielding the first chunk followed by the rest
        of the generator's output, or a single placeholder chunk when the
        generator finished without yielding anything.

    Raises:
        HTTPException: 429 on ``OverloadedError``, 422 on ``ValidationError``,
            500 on ``UnknownError``.
    """
    generator = controller.prompt(
        request.request_id,
        request.prompt,
        request.model_index,
        request.model_preference,
    )

    # First try to get the first token in order to catch TGI errors.
    try:
        first_token = await generator.__anext__()
    except OverloadedError:
        # Look up the model's name from the stored request state for the log line.
        name = controller.request_states[request.request_id].model_names[request.model_index]
        logger.warning("Model %s is overloaded. Failed request: %s", name, repr(request))
        raise HTTPException(status_code=429, detail="Model overloaded. Please try again later.")
    except ValidationError as e:
        logger.info("TGI returned validation error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=422, detail=str(e))
    except StopAsyncIteration:
        # The generator ended without producing a single chunk: answer with a
        # placeholder message rather than an empty body.
        logger.info("TGI returned empty response. Failed request: %s", repr(request))
        # NOTE(review): the JSON-encoded string followed by a NUL byte
        # presumably mirrors the framing of the controller's regular chunks;
        # confirm against Controller.prompt.
        return StreamingResponse(
            iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
        )
    except UnknownError as e:
        logger.error("TGI returned unknown error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=500, detail=str(e))

    # Re-attach the chunk that was consumed above to the front of the stream.
    return StreamingResponse(prepend_generator(first_token, generator))
@app.post(COLOSSEUM_RESP_VOTE_ROUTE, response_model=ResponseVoteResponse)
async def response_vote(
    request: ResponseVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Record a response vote and return the data stored for that request.

    Args:
        request: Vote body carrying the request ID and the victory index.
        controller: The global controller, injected by FastAPI.

    Returns:
        A ``ResponseVoteResponse`` with the energy consumptions and model
        names taken from the state returned by the controller.

    Raises:
        HTTPException: 410 when the controller returns ``None`` for the vote.
    """
    state = controller.response_vote(request.request_id, request.victory_index)
    if state is None:
        raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

    return ResponseVoteResponse(
        energy_consumptions=state.energy_consumptions,
        model_names=state.model_names,
    )
@app.post(COLOSSEUM_ENERGY_VOTE_ROUTE, response_model=EnergyVoteResponse)
async def energy_vote(
    request: EnergyVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Record an energy vote and return the model names for that request.

    Args:
        request: Vote body carrying the request ID and the ``is_worth`` flag.
        controller: The global controller, injected by FastAPI.

    Returns:
        An ``EnergyVoteResponse`` with the model names taken from the state
        returned by the controller.

    Raises:
        HTTPException: 410 when the controller returns ``None`` for the vote.
    """
    state = controller.energy_vote(request.request_id, request.is_worth)
    if state is None:
        raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

    return EnergyVoteResponse(model_names=state.model_names)
@app.get(COLOSSEUM_HEALTH_ROUTE)
async def health():
    """Health-check endpoint; always answers with the string ``"OK"``."""
    status = "OK"
    return status
if __name__ == "__main__":
    # Serve the app on all interfaces when the module is run as a script.
    # NOTE(review): log_config=None presumably keeps uvicorn from installing
    # its own logging configuration, leaving the "uvicorn" logger to the
    # startup hook; confirm against the uvicorn settings documentation.
    uvicorn.run(
        app,
        host="0.0.0.0",
        log_config=None,
    )