Skip to content

Commit

Permalink
MLI helper methods (#709)
Browse files Browse the repository at this point in the history
Helper methods added to InferenceReply and InferenceRequest.

[ committed by @AlyssaCote ]
[ reviewed by @al-rigazzi ]
  • Loading branch information
AlyssaCote authored Sep 19, 2024
1 parent 0ebd5ab commit d43f7c7
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 16 deletions.
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Jump to:

Description

- Add helper methods to MLI classes
- Update error handling for consistency
- Parameterize installation of dragon package with `smart build`
- Update docstrings
Expand Down
20 changes: 14 additions & 6 deletions smartsim/_core/mli/infrastructure/control/request_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ def __init__(
self._perf_timer = PerfTimer(prefix="r_", debug=False, timing_on=True)
"""Performance timer"""

@property
def has_featurestore_factory(self) -> bool:
    """Whether a FeatureStore factory has been configured on this dispatcher.

    :returns: True if there is a FeatureStore factory, False otherwise
    """
    factory = self._featurestore_factory
    return factory is not None

def _check_feature_stores(self, request: InferenceRequest) -> bool:
"""Ensures that all feature stores required by the request are available.
Expand All @@ -272,7 +280,7 @@ def _check_feature_stores(self, request: InferenceRequest) -> bool:
fs_actual = {item.descriptor for item in self._feature_stores.values()}
fs_missing = fs_desired - fs_actual

if self._featurestore_factory is None:
if self.has_featurestore_factory:
logger.error("No feature store factory configured")
return False

Expand All @@ -292,7 +300,7 @@ def _check_model(self, request: InferenceRequest) -> bool:
:param request: The request to validate
:returns: False if model validation fails for the request, True otherwise
"""
if request.model_key or request.raw_model:
if request.has_model_key or request.has_raw_model:
return True

logger.error("Unable to continue without model bytes or feature store key")
Expand All @@ -305,7 +313,7 @@ def _check_inputs(self, request: InferenceRequest) -> bool:
:param request: The request to validate
:returns: False if input validation fails for the request, True otherwise
"""
if request.input_keys or request.raw_inputs:
if request.has_input_keys or request.has_raw_inputs:
return True

logger.error("Unable to continue without input bytes or feature store keys")
Expand All @@ -318,7 +326,7 @@ def _check_callback(self, request: InferenceRequest) -> bool:
:param request: The request to validate
:returns: False if callback validation fails for the request, True otherwise
"""
if request.callback is not None:
if request.callback:
return True

logger.error("No callback channel provided in request")
Expand Down Expand Up @@ -362,7 +370,7 @@ def _on_iteration(self) -> None:
request = self._worker.deserialize_message(
request_bytes, self._callback_factory
)
if request.input_meta and tensor_bytes_list:
if request.has_input_meta and tensor_bytes_list:
request.raw_inputs = tensor_bytes_list

self._perf_timer.measure_time("deserialize_message")
Expand Down Expand Up @@ -445,7 +453,7 @@ def dispatch(self, request: InferenceRequest) -> None:
:param request: The request to place
"""
if request.raw_model is not None:
if request.has_raw_model:
logger.debug("Direct inference requested, creating tmp queue")
tmp_id = f"_tmp_{str(uuid.uuid4())}"
tmp_queue: BatchQueue = BatchQueue(
Expand Down
22 changes: 15 additions & 7 deletions smartsim/_core/mli/infrastructure/control/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ def __init__(
self._perf_timer = PerfTimer(prefix="w_", debug=False, timing_on=True)
"""Performance timer"""

@property
def has_featurestore_factory(self) -> bool:
    """Whether this WorkerManager was given a FeatureStore factory.

    :returns: True if there is a FeatureStore factory, False otherwise
    """
    # None means no factory was supplied at construction time
    return self._featurestore_factory is not None

def _on_start(self) -> None:
"""Called on initial entry into Service `execute` event loop before
`_on_iteration` is invoked."""
Expand All @@ -132,7 +140,7 @@ def _check_feature_stores(self, batch: RequestBatch) -> bool:
fs_actual = {item.descriptor for item in self._feature_stores.values()}
fs_missing = fs_desired - fs_actual

if self._featurestore_factory is None:
if not self.has_featurestore_factory:
logger.error("No feature store factory configured")
return False

Expand All @@ -151,7 +159,7 @@ def _validate_batch(self, batch: RequestBatch) -> bool:
:param batch: The batch of requests to validate
:returns: False if the request fails any validation checks, True otherwise
"""
if batch is None or len(batch.requests) == 0:
if batch is None or not batch.has_valid_requests:
return False

return self._check_feature_stores(batch)
Expand Down Expand Up @@ -179,7 +187,7 @@ def _on_iteration(self) -> None:
)
return

if self._device_manager is None:
if not self._device_manager:
for request in batch.requests:
msg = "No Device Manager found. WorkerManager._on_start() "
"must be called after initialization. If possible, "
Expand Down Expand Up @@ -225,7 +233,7 @@ def _on_iteration(self) -> None:
return
self._perf_timer.measure_time("load_model")

if batch.inputs is None:
if not batch.inputs:
for request in batch.requests:
exception_handler(
ValueError("Error batching inputs"),
Expand Down Expand Up @@ -258,7 +266,7 @@ def _on_iteration(self) -> None:

for request, transformed_output in zip(batch.requests, transformed_outputs):
reply = InferenceReply()
if request.output_keys:
if request.has_output_keys:
try:
reply.output_keys = self._worker.place_output(
request,
Expand All @@ -274,7 +282,7 @@ def _on_iteration(self) -> None:
reply.outputs = transformed_output.outputs
self._perf_timer.measure_time("assign_output")

if reply.outputs is None or not reply.outputs:
if not reply.has_outputs:
response = build_failure_reply("fail", "Outputs not found.")
else:
reply.status_enum = "complete"
Expand All @@ -296,7 +304,7 @@ def _on_iteration(self) -> None:

if request.callback:
request.callback.send(serialized_resp)
if reply.outputs:
if reply.has_outputs:
# send tensor data after response
for output in reply.outputs:
request.callback.send(output)
Expand Down
76 changes: 73 additions & 3 deletions smartsim/_core/mli/infrastructure/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,58 @@ def __init__(
self.batch_size = batch_size
"""The batch size to apply when batching"""

@property
def has_raw_model(self) -> bool:
    """Whether this request carries an in-band model payload.

    :returns: True if raw_model is not None, False otherwise
    """
    model = self.raw_model
    return model is not None

@property
def has_model_key(self) -> bool:
    """Whether this request references a model by feature-store key.

    :returns: True if model_key is not None, False otherwise
    """
    key = self.model_key
    return key is not None

@property
def has_raw_inputs(self) -> bool:
    """Check if the InferenceRequest contains raw_inputs.

    :returns: True if raw_inputs is not None and is not an empty list,
    False otherwise
    """
    # bool(None) and bool([]) are both False, so a single truthiness
    # check covers the "is not None and non-empty" contract
    return bool(self.raw_inputs)

@property
def has_input_keys(self) -> bool:
    """Whether this request references inputs by feature-store key.

    :returns: True if input_keys is not None and is not an empty list,
    False otherwise
    """
    # truthiness covers both the None case and the empty-list case
    return bool(self.input_keys)

@property
def has_output_keys(self) -> bool:
    """Whether this request names feature-store keys for its outputs.

    :returns: True if output_keys is not None and is not an empty list,
    False otherwise
    """
    keys = self.output_keys
    return keys is not None and len(keys) > 0

@property
def has_input_meta(self) -> bool:
    """Whether this request includes metadata describing its inputs.

    :returns: True if input_meta is not None and is not an empty list,
    False otherwise
    """
    # truthiness collapses the None check and the emptiness check
    return bool(self.input_meta)


class InferenceReply:
"""Internal representation of the reply to a client request for inference."""
Expand Down Expand Up @@ -121,6 +173,24 @@ def __init__(
self.message = message
"""Status message that corresponds with the status enum"""

@property
def has_outputs(self) -> bool:
    """Whether this reply carries output tensors to send back.

    :returns: True if outputs is not None and is not an empty list,
    False otherwise
    """
    outputs = self.outputs
    return outputs is not None and len(outputs) > 0

@property
def has_output_keys(self) -> bool:
    """Whether this reply references its outputs by feature-store key.

    :returns: True if output_keys is not None and is not an empty list,
    False otherwise
    """
    # truthiness covers both None and the empty-list case
    return bool(self.output_keys)


class LoadModelResult:
"""A wrapper around a loaded model."""
Expand Down Expand Up @@ -372,13 +442,13 @@ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]:
information needed in the reply
"""
prepared_outputs: t.List[t.Any] = []
if reply.output_keys:
if reply.has_output_keys:
for value in reply.output_keys:
if not value:
continue
msg_key = MessageHandler.build_tensor_key(value.key, value.descriptor)
prepared_outputs.append(msg_key)
elif reply.outputs:
elif reply.has_outputs:
for _ in reply.outputs:
msg_tensor_desc = MessageHandler.build_tensor_descriptor(
"c",
Expand Down Expand Up @@ -448,7 +518,7 @@ def fetch_inputs(
if not feature_stores:
raise ValueError("No input and no feature store provided")

if request.input_keys:
if request.has_input_keys:
data: t.List[bytes] = []

for fs_key in request.input_keys:
Expand Down
6 changes: 6 additions & 0 deletions tests/dragon/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ def test_dispatcher_pipeline_stage_errors_handled(

mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage)

monkeypatch.setattr(
request_dispatcher,
"_validate_request",
MagicMock(return_value=True),
)

if stage not in ["fetch_inputs"]:
monkeypatch.setattr(
integrated_worker,
Expand Down
76 changes: 76 additions & 0 deletions tests/dragon/test_inference_reply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

dragon = pytest.importorskip("dragon")

from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStoreKey
from smartsim._core.mli.infrastructure.worker.worker import InferenceReply
from smartsim._core.mli.message_handler import MessageHandler

# The tests in this file belong to the dragon group
pytestmark = pytest.mark.dragon

handler = MessageHandler()


@pytest.fixture
def inference_reply() -> InferenceReply:
    """Return a fresh, empty InferenceReply for each test."""
    return InferenceReply()


@pytest.fixture
def fs_key() -> FeatureStoreKey:
    """Return a sample FeatureStoreKey.

    NOTE(review): referencing this fixture by its bare name inside a
    ``parametrize`` argument list yields the fixture *function* object,
    not the key it returns — request it as a test argument to get the
    actual FeatureStoreKey.
    """
    return FeatureStoreKey("key", "descriptor")


@pytest.mark.parametrize(
    "outputs, expected",
    [
        ([b"output bytes"], True),  # non-empty list -> has outputs
        (None, False),  # unset -> no outputs
        ([], False),  # empty list -> no outputs
    ],
)
def test_has_outputs(monkeypatch, inference_reply, outputs, expected):
    """Test the has_outputs property with different values for outputs."""
    # monkeypatch (rather than direct assignment) so the fixture-provided
    # reply is restored untouched after the test
    monkeypatch.setattr(inference_reply, "outputs", outputs)
    assert inference_reply.has_outputs == expected


@pytest.mark.parametrize(
    "output_keys, expected",
    [
        # Use a real FeatureStoreKey: the previous parametrization passed
        # ``fs_key`` (the pytest fixture *function* object) — any non-empty
        # list is truthy, so the test passed without exercising a real key.
        ([FeatureStoreKey("key", "descriptor")], True),
        (None, False),
        ([], False),
    ],
)
def test_has_output_keys(monkeypatch, inference_reply, output_keys, expected):
    """Test the has_output_keys property with different values for output_keys."""
    monkeypatch.setattr(inference_reply, "output_keys", output_keys)
    assert inference_reply.has_output_keys == expected
Loading

0 comments on commit d43f7c7

Please sign in to comment.