
merge fixes
ankona committed Sep 5, 2024
1 parent a95acd5 commit 9a1baf7
Showing 12 changed files with 188 additions and 151 deletions.
32 changes: 22 additions & 10 deletions ex/high_throughput_inference/mock_app.py
@@ -46,6 +46,7 @@
from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
DragonFeatureStore,
)
from smartsim.log import get_logger
from smartsim._core.utils.timings import PerfTimer

torch.set_num_interop_threads(16)
@@ -64,7 +65,7 @@
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
BackboneFeatureStore,
EventBroadcaster,
EventPublisher,
EventProducer,
OnWriteFeatureStore,
)
from smartsim.error.errors import SmartSimError
@@ -79,6 +80,7 @@

CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False


class ProtoClient:
def __init__(self, timing_on: bool):
comm = MPI.COMM_WORLD
@@ -130,12 +132,12 @@ def _create_worker_channels(self) -> t.Tuple[DragonCommChannel, DragonCommChanne

return _from_worker_ch, _to_worker_ch

def _create_publisher(self) -> EventPublisher:
def _create_publisher(self) -> EventProducer:
"""Create an event publisher that will broadcast updates to
other MLI components.
:returns: the event publisher instance"""
publisher: EventPublisher = EventBroadcaster(
publisher: EventProducer = EventBroadcaster(
self._backbone, DragonCommChannel.from_descriptor
)
return publisher
@@ -174,7 +176,9 @@ def __init__(self, timing_on: bool, wait_timeout: float = 0):

self._publisher = self._create_publisher()

self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
self.perf_timer: PerfTimer = PerfTimer(
debug=False, timing_on=timing_on, prefix=f"a{rank}_"
)
self._start = None
self._interm = None
self._timings: _TIMING_DICT = OrderedDict()
@@ -288,7 +292,6 @@ def set_model(self, key: str, model: bytes):
self._publisher.send(event)



class ResNetWrapper:
def __init__(self, name: str, model: str):
self._model = torch.jit.load(model)
@@ -309,6 +312,7 @@ def model(self):
def name(self):
return self._name


if __name__ == "__main__":

parser = argparse.ArgumentParser("Mock application")
@@ -324,24 +328,32 @@ def name(self):
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
# TODO: adapt to non-Nvidia devices
torch_device = args.device.replace("gpu", "cuda")
pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)
pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(
torch_device
)

TOTAL_ITERATIONS = 100

for log2_bsize in range(args.log_max_batchsize+1):
for log2_bsize in range(args.log_max_batchsize + 1):
b_size: int = 2**log2_bsize
logger.info(f"Batch size: {b_size}")
for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
logger.info(f"Iteration: {iteration_number}")
sample_batch = resnet.get_batch(b_size)
remote_result = client.run_model(resnet.name, sample_batch)
logger.info(client.perf_timer.get_last("total_time"))
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
local_res = pt_model(sample_batch.to(torch_device))
err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
err_norm = torch.linalg.vector_norm(
torch.flatten(remote_result).to(torch_device)
- torch.flatten(local_res),
ord=1,
).cpu()
res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
logger.info(
f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}"
)
torch.cuda.synchronize()

client.perf_timer.print_timings(to_file=True)
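
The central change in this file is the rename of EventPublisher to EventProducer. A minimal sketch of how the broadcaster created in ProtoClient._create_publisher might be driven on its own follows; the DragonCommChannel import path and the OnWriteFeatureStore constructor arguments are assumptions not confirmed by this diff, while the EventBroadcaster construction and send() call mirror the code above.

# Sketch only, not part of the commit. Assumed: the DragonCommChannel import
# path and the OnWriteFeatureStore(descriptor, key) signature.
from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
    BackboneFeatureStore,
    EventBroadcaster,
    EventProducer,
    OnWriteFeatureStore,
)


def notify_feature_store_write(backbone: BackboneFeatureStore, key: str) -> None:
    """Broadcast a feature store write so other MLI components can react."""
    # EventBroadcaster resolves subscriber channels via the supplied factory,
    # exactly as ProtoClient._create_publisher does above.
    publisher: EventProducer = EventBroadcaster(
        backbone, DragonCommChannel.from_descriptor
    )
    event = OnWriteFeatureStore(backbone.descriptor, key)
    publisher.send(event)
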
1 change: 0 additions & 1 deletion tests/dragon/feature_store.py
@@ -79,7 +79,6 @@ def __init__(
"""Initialize the FileSystemFeatureStore instance
:param storage_dir: (optional) root directory to store all data relative to"""
super().__init__()
if isinstance(storage_dir, str):
storage_dir = pathlib.Path(storage_dir)
self._storage_dir = storage_dir
58 changes: 38 additions & 20 deletions tests/dragon/test_error_handling.py
@@ -27,6 +27,10 @@
from unittest.mock import MagicMock

import pytest
import typing as t

from smartsim._core.mli.comm.channel.channel import CommChannelBase
from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder

dragon = pytest.importorskip("dragon")

@@ -62,6 +66,7 @@
InferenceReply,
InferenceRequest,
LoadModelResult,
MachineLearningWorkerBase,
RequestBatch,
TransformInputResult,
TransformOutputResult,
@@ -92,7 +97,7 @@ def app_feature_store() -> FeatureStore:

@pytest.fixture
def setup_worker_manager_model_bytes(
test_dir,
test_dir: str,
monkeypatch: pytest.MonkeyPatch,
backbone_descriptor: str,
app_feature_store: FeatureStore,
@@ -113,17 +118,18 @@ def setup_worker_manager_model_bytes(
queue_factory=DragonFLIChannel.from_descriptor,
)

dispatcher_task_queue = mp.Queue(maxsize=0)
dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0)

worker_manager = WorkerManager(
EnvironmentConfigLoader(
featurestore_factory=DragonFeatureStore.from_descriptor,
callback_factory=FileSystemCommChannel.from_descriptor,
queue_factory=DragonFLIChannel.from_sender_supplied_descriptor,
),
integrated_worker,
integrated_worker_type,
as_service=False,
cooldown=3,
dispatcher_queue=dispatcher_task_queue,
)

tensor_key = MessageHandler.build_feature_store_key(
@@ -185,16 +191,18 @@ def setup_worker_manager_model_key(
queue_factory=DragonFLIChannel.from_descriptor,
)

dispatcher_task_queue = mp.Queue(maxsize=0)
dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0)

worker_manager = WorkerManager(
EnvironmentConfigLoader(
featurestore_factory=DragonFeatureStore.from_descriptor,
callback_factory=FileSystemCommChannel.from_descriptor,
queue_factory=DragonFLIChannel.from_sender_supplied_descriptor,
),
worker_type=integrated_worker_type,
as_service=False,
cooldown=3,
dispatcher_queue=dispatcher_task_queue,
)

tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor)
@@ -223,7 +231,7 @@ def setup_worker_manager_model_key(

@pytest.fixture
def setup_request_dispatcher_model_bytes(
test_dir,
test_dir: str,
monkeypatch: pytest.MonkeyPatch,
backbone_descriptor: str,
app_feature_store: FeatureStore,
@@ -252,8 +260,12 @@ def setup_request_dispatcher_model_bytes(
)
request_dispatcher._on_start()

tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
tensor_key = MessageHandler.build_feature_store_key(
"key", app_feature_store.descriptor
)
output_key = MessageHandler.build_feature_store_key(
"key", app_feature_store.descriptor
)
model = MessageHandler.build_model(b"model", "model name", "v 0.0.1")
request = MessageHandler.build_request(
test_dir, model, [tensor_key], [output_key], [], None
@@ -267,7 +279,7 @@

@pytest.fixture
def setup_request_dispatcher_model_key(
test_dir,
test_dir: str,
monkeypatch: pytest.MonkeyPatch,
backbone_descriptor: str,
app_feature_store: FeatureStore,
@@ -296,9 +308,13 @@ def setup_request_dispatcher_model_key(
)
request_dispatcher._on_start()

tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor)
model_key = MessageHandler.build_model_key(
tensor_key = MessageHandler.build_feature_store_key(
"key", app_feature_store.descriptor
)
output_key = MessageHandler.build_feature_store_key(
"key", app_feature_store.descriptor
)
model_key = MessageHandler.build_feature_store_key(
key="model key", feature_store_descriptor=app_feature_store.descriptor
)
request = MessageHandler.build_request(
Expand All @@ -325,8 +341,10 @@ def mock_stage(*args, **kwargs):
mock_reply_channel = MagicMock()
mock_reply_channel.send = MagicMock()

def mock_exception_handler(exc, reply_channel, failure_message):
return exception_handler(exc, mock_reply_channel, failure_message)
def mock_exception_handler(
exc: Exception, reply_channel: CommChannelBase, failure_message: str
) -> None:
exception_handler(exc, mock_reply_channel, failure_message)

monkeypatch.setattr(
"smartsim._core.mli.infrastructure.control.worker_manager.exception_handler",
@@ -373,12 +391,12 @@ def mock_exception_handler(exc, reply_channel, failure_message):
],
)
def test_wm_pipeline_stage_errors_handled(
request,
setup_worker_manager,
request: pytest.FixtureRequest,
setup_worker_manager: str,
monkeypatch: pytest.MonkeyPatch,
stage: str,
error_message: str,
):
) -> None:
"""Ensures that the worker manager does not crash after a failure in various pipeline stages"""
worker_manager, integrated_worker_type = request.getfixturevalue(
setup_worker_manager
@@ -457,12 +475,12 @@ def test_wm_pipeline_stage_errors_handled(
],
)
def test_dispatcher_pipeline_stage_errors_handled(
request,
setup_request_dispatcher,
request: pytest.FixtureRequest,
setup_request_dispatcher: str,
monkeypatch: pytest.MonkeyPatch,
stage: str,
error_message: str,
):
) -> None:
"""Ensures that the request dispatcher does not crash after a failure in various pipeline stages"""
request_dispatcher, integrated_worker_type = request.getfixturevalue(
setup_request_dispatcher
@@ -484,7 +502,7 @@ def test_dispatcher_pipeline_stage_errors_handled(
mock_reply_fn.assert_called_with("fail", error_message)


def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch):
def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch) -> None:
"""Ensures that the worker manager does not crash after a failure in the
execute pipeline stage"""
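
The helpers above route every pipeline failure through a mocked reply channel so the tests can assert on what would have been sent back to the client. A condensed sketch of that substitution pattern follows; it records the failure message on the mock rather than delegating to the real exception_handler as the fixtures in this file do, and the monkeypatch target string is copied from the diff above.

# Sketch only, not part of the commit: a simplified stand-in for the
# exception-handler patching used by these tests.
from unittest.mock import MagicMock

import pytest

from smartsim._core.mli.comm.channel.channel import CommChannelBase


def patch_exception_handler(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
    """Swap in a handler that records failures on a mocked reply channel."""
    mock_reply_channel = MagicMock()
    mock_reply_channel.send = MagicMock()

    def mock_exception_handler(
        exc: Exception, reply_channel: CommChannelBase, failure_message: str
    ) -> None:
        # Capture the failure instead of writing to a real Dragon channel.
        mock_reply_channel.send(failure_message)

    monkeypatch.setattr(
        "smartsim._core.mli.infrastructure.control.worker_manager.exception_handler",
        mock_exception_handler,
    )
    return mock_reply_channel
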
