main.py
"""Starts the activation server. Methods on the server are defined in separate files."""
import datetime
import os
import re
import signal
import fire
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from starlette.exceptions import HTTPException as StarletteHTTPException
from neuron_explainer.activation_server.explainer_routes import (
AttentionExplainAndScoreMethodId,
NeuronExplainAndScoreMethodId,
define_explainer_routes,
)
from neuron_explainer.activation_server.inference_routes import define_inference_routes
from neuron_explainer.activation_server.interactive_model import InteractiveModel
from neuron_explainer.activation_server.read_routes import define_read_routes
from neuron_explainer.activation_server.requests_and_responses import GroupId
from neuron_explainer.models.autoencoder_context import AutoencoderContext # noqa: F401
from neuron_explainer.models.autoencoder_context import MultiAutoencoderContext
from neuron_explainer.models.model_context import StandardModelContext, get_default_device
from neuron_explainer.models.model_registry import make_autoencoder_context

def main(
    host_name: str = "localhost",
    port: int = 80,
    model_name: str = "gpt2-small",
    mlp_autoencoder_name: str | None = None,
    attn_autoencoder_name: str | None = None,
    run_model: bool = True,
    neuron_method: str = "baseline",
    attention_head_method: str = "baseline",
    cuda_memory_debugging: bool = False,
) -> None:
    neuron_method_id = NeuronExplainAndScoreMethodId.from_string(neuron_method)
    attention_head_method_id = AttentionExplainAndScoreMethodId.from_string(attention_head_method)

    def custom_generate_unique_id(route: APIRoute) -> str:
        return f"{route.tags[0]}-{route.name}"
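
    # Note: FastAPI uses generate_unique_id_function to set each route's OpenAPI operationId, so
    # the "<tag>-<route name>" scheme above is what client code generators (such as the TypeScript
    # client generation mentioned below) see as the operation name.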

    app = FastAPI(generate_unique_id_function=custom_generate_unique_id)

    allow_origin_regex_str = r"https?://localhost(:\d+)?$"
    allow_origin_regex = re.compile(allow_origin_regex_str)
    app.add_middleware(
        CORSMiddleware,
        allow_origin_regex=allow_origin_regex_str,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # We don't just want to disable CORS for successful responses: we also want to do it for error
    # responses, which FastAPI's middleware doesn't cover. This allows the client to see helpful
    # information like the HTTP status code, which is otherwise hidden from it. To do this, we add
    # two exception handlers. It's possible we could just get away with the first one, but GPT-4
    # thought it was good to include both and who am I to disagree?
    def add_access_control_headers(request: Request, response: JSONResponse) -> JSONResponse:
        origin = request.headers.get("origin")
        # This logic does something similar to what the standard CORSMiddleware does. You can't
        # use a regex in the actual response header, but you can run the regex on the server and
        # then choose to include the header if it matches the origin.
        if origin and allow_origin_regex.fullmatch(origin):
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Methods"] = "*"
            response.headers["Access-Control-Allow-Headers"] = "*"
        return response

    @app.exception_handler(Exception)
    async def handle_unhandled_exception(request: Request, exc: Exception) -> JSONResponse:
        print("************** Handling an unhandled exception ***********************")
        print(f"Exception type: {type(exc).__name__}")
        print(f"Exception details: {exc}")
        response = add_access_control_headers(
            request,
            JSONResponse(status_code=500, content={"message": "Unhandled server exception"}),
        )
        # Check if this exception is a cuda OOM, which is unrecoverable. If it is, we should kill
        # the server.
        if isinstance(exc, torch.cuda.OutOfMemoryError):
            print("***** Killing server due to cuda OOM *****")
            # Use SIGKILL so the return code of the top-level process is *not* 0.
            os.kill(os.getpid(), signal.SIGKILL)
        return response

    @app.exception_handler(StarletteHTTPException)
    async def handle_starlette_http_exception(request: Request, exc: HTTPException) -> JSONResponse:
        return add_access_control_headers(
            request, JSONResponse(status_code=exc.status_code, content={"message": exc.detail})
        )
@app.get("/", tags=["hello_world"])
def read_root() -> dict[str, str]:
return {"Hello": "World"}

    # The FastAPI client code generation setup only generates TypeScript classes for types
    # referenced from top-level endpoints. In some cases we want to share a type across client and
    # server that isn't referenced in this way. For example, GroupId is used in requests, but only
    # as a key in a dictionary, and the generated TypeScript for dictionaries treats enum values as
    # strings, so GroupId isn't referenced in the generated TypeScript. To work around this, we
    # define a dummy endpoint that references GroupId, which causes the client code generation to
    # generate a TypeScript class for it. The same trick can be used for other types in the future.
    @app.get("/force_client_code_generation", tags=["hello_world"])
    def force_client_code_generation(group_id: GroupId) -> None:
        return None
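
    # A hypothetical sketch of reusing the same trick for another shared type (OtherSharedType is
    # a placeholder, not a real type in this codebase): a dummy endpoint whose signature references
    # the type forces the client code generator to emit a TypeScript class for it.
    #
    # @app.get("/force_client_code_generation_other", tags=["hello_world"])
    # def force_client_code_generation_other(other: OtherSharedType) -> None:
    #     return None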
@app.get("/dump_memory_snapshot", tags=["memory"])
def dump_memory_snapshot() -> str:
if not cuda_memory_debugging:
raise HTTPException(
status_code=400,
detail="The cuda_memory_debugging flag must be set to dump a memory snapshot",
)
formatted_time = datetime.datetime.now().strftime("%H%M%S")
filename = f"torch_memory_snapshot_{formatted_time}.pkl"
torch.cuda.memory._dump_snapshot(filename)
return f"Dumped cuda memory snapshot to {filename}"

    if run_model:
        if cuda_memory_debugging:
            torch.cuda.memory._record_memory_history(max_entries=100000)
        device = get_default_device()
        standard_model_context = StandardModelContext(model_name, device=device)
        if mlp_autoencoder_name is not None or attn_autoencoder_name is not None:
            autoencoder_context_list = [
                make_autoencoder_context(
                    model_name=model_name,
                    autoencoder_name=autoencoder_name,
                    device=device,
                    omit_dead_latents=True,
                )
                for autoencoder_name in [mlp_autoencoder_name, attn_autoencoder_name]
                if autoencoder_name is not None
            ]
            multi_autoencoder_context = MultiAutoencoderContext.from_autoencoder_context_list(
                autoencoder_context_list
            )
            multi_autoencoder_context.warmup()
            model = InteractiveModel.from_standard_model_context_and_autoencoder_context(
                standard_model_context, multi_autoencoder_context
            )
        else:
            model = InteractiveModel.from_standard_model_context(standard_model_context)
    else:
        model = None

    define_read_routes(app)
    define_explainer_routes(
        app=app,
        neuron_method_id=neuron_method_id,
        attention_head_method_id=attention_head_method_id,
    )
    define_inference_routes(
        app=app,
        model=model,
        mlp_autoencoder_name=mlp_autoencoder_name,
        attn_autoencoder_name=attn_autoencoder_name,
    )

    # TODO(sbills): Make reload=True work. We need to pass something like "main:app" as a string
    # rather than passing a FastAPI object directly.
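    # One possible approach (untested sketch; create_app is hypothetical): build the app in a
    # module-level factory and call
    # uvicorn.run("neuron_explainer.activation_server.main:create_app", factory=True, reload=True),
    # since uvicorn's reloader needs an import string rather than an app object.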
    uvicorn.run(app, host=host_name, port=port)


if __name__ == "__main__":
    fire.Fire(main)

"""
For local testing without running a subject model:
python neuron_explainer/activation_server/main.py --run_model False --port 8000
"""