forked from gcorso/DiffDock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcp_diffdock_service.py
195 lines (167 loc) · 7.54 KB
/
cp_diffdock_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import asyncio
import json
import logging
import re
import os
from websockets.server import serve
import concurrent.futures
from argparse import ArgumentParser
from cp_diffdock_protocol import DiffDockProtocol
from cp_diffdock_api import DiffDockApi
from cp_ws_helpers import extractWsAppMessage, formWsAppMessage, formCompletedMessage
log = logging.getLogger('diffdock_service')
logging.basicConfig(level=logging.INFO)

#
# Websocket server that services DiffDock docking requests.
# It listens on a port and puts each incoming JSON request onto a queue; a pool of worker
# tasks takes requests off the queue, runs DiffDock, and replies on the requester's websocket.
#
async def handleRequest(websocket, queue):
    """
    Receive one request from ``websocket``, validate it, and enqueue it for a worker.

    The websocket is kept alongside the request (as ``requestContext``) so the worker can send
    the response when docking finishes. This coroutine does not return until the worker signals
    the request's ``asyncio.Condition``, which keeps the websocket from closing early.

    Args:
        websocket: the client connection; ``recv()`` is called here and the connection is
            handed to the send helpers.
        queue: the ``asyncio.Queue`` consumed by ``queueWorker``; items are
            ``(requestObj, requestContext, condition)`` tuples.
    """
    try:
        message = await websocket.recv()
    except Exception as ex:
        log.error(f"Exception reading from websocket: {ex}. Can't handle this request.")
        return
    log.info("Received a request.")

    # Optional special handling for messages from bfd-server or other Bioleap WsApps.
    # requestId is non-empty if it is from a wsApp, and will need to be included in responses.
    requestId, cmdName, requestData = extractWsAppMessage(message)
    requestContext = websocket, requestId

    # Parse
    try:
        requestObj = DiffDockProtocol.Request.from_json(requestData)
    except Exception:
        log.exception(f"Rejected a request:\nMessage: {message}\nrequestData: {requestData}")
        await sendError(requestContext, "Invalid request")
        return

    # Enqueue
    log.info("handleRequest queuing a request...")
    try:
        # Acknowledge before queuing so this status can never reach the client after the
        # worker's "Working request..." status.
        await sendStatus(requestContext, "Request received")
        # The condition is resolved by the worker when the request finishes; waiting on it
        # protects the websocket from closing early.
        condition = asyncio.Condition()
        # Hold the condition's lock from before the put until wait() atomically releases it.
        # The worker must take this lock to call notify_all(), so it cannot notify before this
        # handler is waiting. Without this, a request that fails quickly could be notified
        # during an await between put() and wait(), and this handler would wait forever.
        async with condition:
            await queue.put((requestObj, requestContext, condition))
            log.info(f"handleRequest queue size is now {queue.qsize()}")
            await condition.wait()
        log.info("handleRequest finished processing a request")
    except Exception as ex:
        log.exception("handleRequest encountered exception")
        await sendError(requestContext, f"Service exception: {ex}")
    finally:
        log.info("handleRequest finished")
async def queueWorker(queue):
    """
    Handle docking requests from ``queue`` forever, one at a time.

    Each item is a ``(requestObj, requestContext, condition)`` tuple produced by
    ``handleRequest``. DiffDock is synchronous, so it runs in a thread pool; this keeps the
    event loop free to accept other connections and enqueue their requests. After every item,
    whether it succeeded, failed, or the worker is being cancelled, the queue is told the item
    is done and ``condition`` is notified so the waiting request handler can return.
    """
    loop = asyncio.get_running_loop()
    # NOTE(review): leaving this ``with`` block calls executor.shutdown(wait=True), which blocks
    # the event loop until an in-flight DiffDock run completes (e.g. when the task is cancelled).
    with concurrent.futures.ThreadPoolExecutor() as executor:
        while True:
            # Wait for an item from the queue
            requestObj, requestContext, condition = await queue.get()
            log.info(f"queueWorker running a request... (queue size is now {queue.qsize()})")
            try:
                await sendStatus(requestContext, "Working request...")
                response = await loop.run_in_executor(executor, DiffDockApi.run_diffdock, requestObj)
                log.info("queueWorker resolved response")
                await sendResults(requestContext, response)
                log.info("queueWorker sent response.")
            except Exception as ex:
                log.exception("queueWorker encountered exception")
                await sendError(requestContext, f"Service exception: {ex}")
            finally:
                # Always release the item, so a failure while reporting an error (or a
                # cancellation) cannot leave the request handler waiting forever.
                queue.task_done()
                async with condition:
                    condition.notify_all()
async def main(host="localhost", port=9002, max_size=2**24, worker_count=5):
    """
    DiffDock websocket server main loop.

    Sets up the request queue and ``worker_count`` worker tasks, listens for connections at
    ``host:port`` (accepting messages up to ``max_size`` bytes), and serves until cancelled or
    interrupted. Shutdown always ends in ``os._exit``, so this coroutine never returns normally.
    The exit status is 1 if the server could not be started, otherwise 0.
    """
    queue = asyncio.Queue()
    workers = [asyncio.create_task(queueWorker(queue)) for _ in range(worker_count)]

    # ``path`` is optional so the handler works with both the older (websocket, path) and the
    # newer (websocket) handler calling conventions of the websockets library.
    async def handlerWrapper(websocket, path=None):
        await handleRequest(websocket, queue)

    log.info(f"DiffDock service running with pid {os.getpid()}, listening at {host}:{port}...")
    server = None
    exitCode = 0
    try:
        # Await the server so bind errors (e.g. port already in use) are raised here instead of
        # being hidden in a background task, and "ready" is only logged once actually listening.
        server = await serve(handlerWrapper, host, port, max_size=max_size)
        log.info("DiffDock ready.")
        await asyncio.Future()
    except (KeyboardInterrupt, asyncio.CancelledError):
        log.info("Received KeyboardInterrupt")
    except Exception:
        log.exception(f"DiffDock service failed to start listening at {host}:{port}")
        exitCode = 1

    # SHUTDOWN
    try:
        log.info("DiffDock service shutting down...let this finish naturally if you can")
        if server is not None:
            # Stop accepting connections. wait_closed() is deliberately not awaited: it waits for
            # open request handlers, which may be blocked on requests that will never complete.
            server.close()
        for task in workers:
            log.info(f"Canceling task {task}...")
            task.cancel()
        log.info("Waiting for tasks to finish...")
        await asyncio.gather(*workers, return_exceptions=True)
        log.info("Done with all tasks.")
    except BaseException:
        # Includes a second interrupt/cancellation during shutdown; we exit regardless below.
        log.exception("DiffDock service caught exception shutting down. Forcing termination.")
    finally:
        log.info("DiffDock service finished.")
        # Hard exit: executor threads still running DiffDock would otherwise keep the process alive.
        os._exit(exitCode)
async def sendToWebsocket(websocket, data):
    """
    Send ``data`` over ``websocket`` without letting a send failure propagate.

    Returns whatever ``websocket.send`` returns; if sending raises, the error is logged and
    None is returned instead.
    """
    try:
        outcome = await websocket.send(data)
    except Exception as err:
        log.error(f"Exception encountered when writing to a websocket: {err}")
        return None
    return outcome
async def sendPacket(requestContext, responseObj):
    """
    Send ``responseObj`` back to the client that made the request.

    ``requestContext`` is the ``(websocket, requestId)`` pair built by the request handler.
    A truthy ``requestId`` marks a WsApp (bfd-server) request, which is answered in the WsApp
    message format; otherwise the response's JSON is sent on the websocket as-is.
    """
    websocket, requestId = requestContext
    if not requestId:
        return await sendToWebsocket(websocket, responseObj.to_json())
    return await sendWsAppPacket(requestContext, responseObj)
async def sendWsAppPacket(requestContext, responseObj):
    """
    Send ``responseObj`` to a Bioleap WsApp (bfd-server) client.

    WsApp messages look like `#<request id> <cmd name> <message>`, and a request must be
    terminated with a `completed` message. Error and results responses end the transaction, so
    a completed message follows them; status responses do not.

    Returns the result of sending the completed message when one is sent, otherwise None.
    """
    websocket, requestId = requestContext
    messageType = responseObj.messageType
    # (message type, WsApp command name, whether this response ends the transaction)
    routes = (
        (DiffDockProtocol.MessageType.ERROR, "diffdock-error", True),
        (DiffDockProtocol.MessageType.RESULTS, "diffdock-results", True),
        (DiffDockProtocol.MessageType.STATUS, "diffdock-status", False),
    )
    # NOTE(review): an unrecognized message type is sent with a None command name.
    responseCmd, needsComplete = next(
        ((cmd, endsTransaction) for kind, cmd, endsTransaction in routes if messageType == kind),
        (None, False),
    )
    payload = formWsAppMessage(requestId, responseCmd, responseObj.to_json())
    log.info(f'Sending {payload}')
    await sendToWebsocket(websocket, payload)
    if not needsComplete:
        return None
    # wsApp messages need a `complete` at the end of the transaction
    log.info(f'Sending completed')
    return await sendToWebsocket(websocket, formCompletedMessage(requestId))
async def sendStatus(requestContext, status):
    """Send the status message ``status`` to the requester as a STATUS response."""
    statusResponse = DiffDockProtocol.Response.makeStatus(status)
    sendOutcome = await sendPacket(requestContext, statusResponse)
    return sendOutcome
async def sendError(requestContext, error):
    """Send the error description ``error`` to the requester as an ERROR response."""
    errorResponse = DiffDockProtocol.Response.makeError(error)
    sendOutcome = await sendPacket(requestContext, errorResponse)
    return sendOutcome
async def sendResults(requestContext, responseObj):
    """Send a finished docking response to the requester; ``responseObj`` is sent unchanged."""
    sendOutcome = await sendPacket(requestContext, responseObj)
    return sendOutcome
# ENTRYPOINT
# Guarded so that importing this module (e.g. from tests or tooling) does not parse the
# command line and start a blocking server as a side effect.
if __name__ == "__main__":
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--host', default='localhost', help='Host/interface to listen on')
    arg_parser.add_argument('--port', type=int, default=9002, help='Port to listen on')
    arg_parser.add_argument('--worker_count', type=int, default=5, help='Number of worker threads to run')
    args = arg_parser.parse_args()
    log.info("Starting inference service...")
    asyncio.run(main(host=args.host, port=args.port, worker_count=args.worker_count))