
Commit

fix send-multiple items behavior with no sender supplied FLI factory
ankona committed Oct 9, 2024
1 parent 608d6bd commit 170c9ea
Showing 5 changed files with 59 additions and 35 deletions.
30 changes: 22 additions & 8 deletions smartsim/_core/mli/comm/channel/dragon_fli.py
@@ -26,7 +26,6 @@

# isort: off
from dragon import fli
import dragon.channels as dch

# isort: on

@@ -59,9 +58,6 @@ def __init__(

self._fli = fli_
"""The underlying dragon FLInterface used by this CommChannel for communications"""
self._channel: t.Optional["dch.Channel"] = None
"""The underlying dragon Channel used by a sender-side DragonFLIChannel
to attach to the main FLI channel"""
self._buffer_size: int = buffer_size
"""Maximum number of messages that can be buffered before sending"""

@@ -73,18 +69,36 @@ def send(self, value: bytes, timeout: float = 0.001) -> None:
:raises SmartSimError: If sending message fails
"""
try:
if self._channel is None:
self._channel = drg_util.create_local(self._buffer_size)
channel = drg_util.create_local(self._buffer_size)

with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
with self._fli.sendh(timeout=None, stream_channel=channel) as sendh:
sendh.send_bytes(value, timeout=timeout)
logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
except Exception as e:
self._channel = None
raise SmartSimError(
f"Error sending via DragonFLIChannel {self.descriptor}"
) from e

def send_multiple(self, values: t.Sequence[bytes], timeout: float = 0.001) -> None:
"""Send a message through the underlying communication channel.
:param values: The values to send
:param timeout: Maximum time to wait (in seconds) for messages to send
:raises SmartSimError: If sending message fails
"""
try:
channel = drg_util.create_local(self._buffer_size)

with self._fli.sendh(timeout=None, stream_channel=channel) as sendh:
for value in values:
sendh.send_bytes(value)
logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
except Exception as e:
self._channel = None
raise SmartSimError(
f"Error sending via DragonFLIChannel {self.descriptor} {e}"
) from e

def recv(self, timeout: float = 0.001) -> t.List[bytes]:
"""Receives message(s) through the underlying communication channel.
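A minimal sender-side sketch of how the reworked send_multiple might be used, following the pattern in tests/dragon/utils/msg_pump.py further down; the push_request helper and its parameters are illustrative, and the descriptor is assumed to come from an already-deployed dispatcher FLI:

from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel

def push_request(
    dispatch_fli_descriptor: str, request_bytes: bytes, tensor_bytes: bytes
) -> None:
    """Send a request header and its tensor payload over one send handle."""
    # attach to the dispatcher's FLI using its serialized descriptor
    channel = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor)
    # send_multiple streams every payload through a single, freshly created
    # local stream channel, so the header and body arrive together
    channel.send_multiple([request_bytes, tensor_bytes])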
1 change: 1 addition & 0 deletions (file path not shown in this view)
@@ -371,6 +371,7 @@ def _on_iteration(self) -> None:
None,
)

logger.debug(f"Dispatcher is processing {len(bytes_list)} messages")
request_bytes = bytes_list[0]
tensor_bytes_list = bytes_list[1:]
self._perf_timer.start_timings()
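Because the sender streams the request header and its tensors through one send handle, the dispatcher receives them as a single list and can split them positionally, as the hunk above does. A small sketch of that split; the split_batch helper is hypothetical and only the indexing mirrors the diff:

import typing as t

def split_batch(bytes_list: t.List[bytes]) -> t.Tuple[bytes, t.List[bytes]]:
    """Separate one received batch into its request header and tensor payloads."""
    if not bytes_list:
        raise ValueError("expected at least a serialized request")
    request_bytes = bytes_list[0]       # serialized inference request
    tensor_bytes_list = bytes_list[1:]  # zero or more tensor payloads
    return request_bytes, tensor_bytes_list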
18 changes: 12 additions & 6 deletions tests/dragon/conftest.py
@@ -50,9 +50,10 @@
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
BackboneFeatureStore,
)
from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
DragonFeatureStore,
)
from smartsim.log import get_logger

logger = get_logger(__name__)
msg_pump_path = pathlib.Path(__file__).parent / "utils" / "msg_pump.py"

class MsgPumpRequest(t.NamedTuple):
"""Fields required for starting a simulated inference request producer."""
@@ -116,17 +117,22 @@ def run_message_pump(request: MsgPumpRequest) -> subprocess.Popen:
:param request: A request containing all parameters required to
invoke the message pump entrypoint
:returns: The Popen object for the subprocess that was started"""
# <smartsim_dir>/tests/dragon/utils/msg_pump.py
msg_pump_script = "tests/dragon/utils/msg_pump.py"
msg_pump_path = pathlib.Path(__file__).parent / msg_pump_script
assert request.backbone_descriptor
assert request.callback_descriptor
assert request.work_queue_descriptor

# <smartsim_dir>/tests/dragon/utils/msg_pump.py
cmd = [sys.executable, str(msg_pump_path.absolute()), *request.as_command()]
logger.info(f"Executing msg_pump with command: {cmd}")

popen = subprocess.Popen(
args=cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

assert popen is not None
assert popen.returncode is None
return popen

return run_message_pump
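For reference, a hypothetical test-side use of the fixture above; the MsgPumpRequest field names other than the three descriptors asserted in run_message_pump (e.g. parent_iteration) are assumptions based on the arguments parsed in utils/msg_pump.py:

def test_pump_round_trip(msg_pump_factory, backbone, callback, work_queue):
    # field names beyond the asserted descriptors are assumed for illustration
    request = MsgPumpRequest(
        backbone_descriptor=backbone.descriptor,
        work_queue_descriptor=work_queue.descriptor,
        callback_descriptor=callback.descriptor,
        parent_iteration=0,
    )
    pump = msg_pump_factory(request)
    # the pump entrypoint exits 0 on success (see utils/msg_pump.py)
    assert pump.wait() == 0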
18 changes: 7 additions & 11 deletions tests/dragon/test_request_dispatcher.py
@@ -73,7 +73,6 @@
from smartsim.log import get_logger

logger = get_logger(__name__)
mock_msg_pump_path = pathlib.Path(__file__).parent / "utils" / "msg_pump.py"
_MsgPumpFactory = t.Callable[[conftest.MsgPumpRequest], sp.Popen]

# The tests in this file belong to the dragon group
@@ -129,8 +128,8 @@ def test_request_dispatcher(
)

request_dispatcher._on_start()
pump_processes: t.List[sp.Popen] = []

# put some messages into the back queue for the dispatcher to pickup
for i in range(num_iterations):
batch: t.Optional[RequestBatch] = None
mem_allocs = []
@@ -149,18 +148,22 @@
)

msg_pump = msg_pump_factory(request)
pump_processes.append(msg_pump)

assert msg_pump is not None, "Msg Pump Process Creation Failed"
assert msg_pump.wait() == 0

time.sleep(1)

for _ in range(200):
for i in range(15):
try:
request_dispatcher._on_iteration()
batch = request_dispatcher.task_queue.get(timeout=0.1)
break
except Empty:
logger.warning(f"Task queue is empty on iteration {i}")
continue
except Exception as exc:
logger.error(f"Task queue exception on iteration {i}")
raise exc

assert batch is not None
Expand Down Expand Up @@ -219,13 +222,6 @@ def test_request_dispatcher(
assert model_key not in request_dispatcher._active_queues
assert model_key not in request_dispatcher._queues

msg_pump.wait()

for msg_pump in pump_processes:
if msg_pump.returncode is not None:
continue
msg_pump.terminate()

# Try to remove the dispatcher and free the memory
del request_dispatcher
gc.collect()
27 changes: 17 additions & 10 deletions tests/dragon/utils/msg_pump.py
@@ -27,7 +27,7 @@
import io
import logging
import pathlib
import time
import sys
import typing as t

import pytest
@@ -44,7 +44,6 @@

# isort: on

from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
BackboneFeatureStore,
@@ -124,6 +123,8 @@ def mock_messages(
feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor)
request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor)

feature_store[model_key] = load_model()

for iteration_number in range(2):
logged_iteration = offset + iteration_number
logger.debug(f"Sending mock message {logged_iteration}")
@@ -163,9 +164,9 @@ def mock_messages(

logger.info(
f"Retrieving {iteration_number} from callback channel: {callback_descriptor}"
)
callback_channel = DragonCommChannel.from_descriptor(callback_descriptor)

# send the header & body together so they arrive together
request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()])
# Results will be empty. The test pulls messages off the queue before they
# can be serviced by a worker. Just ensure the callback channel works.
results = callback_channel.recv(timeout=0.1)
@@ -185,9 +186,15 @@

args = args.parse_args()

mock_messages(
args.dispatch_fli_descriptor,
args.fs_descriptor,
args.parent_iteration,
args.callback_descriptor,
)
try:
mock_messages(
args.dispatch_fli_descriptor,
args.fs_descriptor,
args.parent_iteration,
args.callback_descriptor,
)
except Exception as ex:
logger.exception("The message pump did not execute properly")
sys.exit(100)

sys.exit(0)
