Skip to content

Commit

Permalink
finish renaming experts to module_backends in ConnectionHandler (#487)
Browse files Browse the repository at this point in the history
Аinish renaming experts -> module_backends in ConnectionHandler
  • Loading branch information
justheuristic authored Jun 17, 2022
1 parent bc2cccf commit f60e34a
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions hivemind/moe/server/connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
:note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
:param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
:param experts: a dict [UID -> ModuleBackend] with all active experts
:param module_backends: a dict [UID -> ModuleBackend] with all active experts
"""

def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
super().__init__()
self.dht, self.experts = dht, experts
self.dht, self.module_backends = dht, module_backends
self._p2p: Optional[P2P] = None

self.ready = MPFuture()
Expand Down Expand Up @@ -59,7 +59,8 @@ async def _run():
logger.debug("Caught KeyboardInterrupt, shutting down")

async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(self.experts[request.uid].get_info()))
module_info = self.module_backends[request.uid].get_info()
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))

async def _gather_inputs(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
Expand Down Expand Up @@ -93,7 +94,7 @@ async def _process_inputs(

async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
expert = self.experts[request.uid]
expert = self.module_backends[request.uid]
return runtime_pb2.ExpertResponse(
tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
)
Expand All @@ -102,7 +103,7 @@ async def rpc_forward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
uid, inputs = await self._gather_inputs(requests, context)
expert = self.experts[uid]
expert = self.module_backends[uid]
output_split = [
part
for tensor in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
Expand All @@ -116,7 +117,7 @@ async def rpc_backward(
self, request: runtime_pb2.ExpertRequest, context: P2PContext
) -> runtime_pb2.ExpertResponse:
inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
expert = self.experts[request.uid]
expert = self.module_backends[request.uid]
return runtime_pb2.ExpertResponse(
tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
)
Expand All @@ -125,7 +126,7 @@ async def rpc_backward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
uid, inputs_and_grads = await self._gather_inputs(requests, context)
expert = self.experts[uid]
expert = self.module_backends[uid]
output_split = [
part
for tensor in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
Expand Down

0 comments on commit f60e34a

Please sign in to comment.