Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Sep 5, 2024
1 parent bd6ed3e commit 79583ac
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 33 deletions.
38 changes: 14 additions & 24 deletions tests/dragon/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import typing as t
from unittest.mock import MagicMock

import pytest
import typing as t

from smartsim._core.mli.comm.channel.channel import CommChannelBase
from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder
Expand Down Expand Up @@ -121,15 +121,11 @@ def setup_worker_manager_model_bytes(
dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0)

worker_manager = WorkerManager(
EnvironmentConfigLoader(
featurestore_factory=DragonFeatureStore.from_descriptor,
callback_factory=FileSystemCommChannel.from_descriptor,
queue_factory=DragonFLIChannel.from_sender_supplied_descriptor,
),
integrated_worker_type,
config_loader=config_loader,
worker_type=integrated_worker_type,
dispatcher_queue=dispatcher_task_queue,
as_service=False,
cooldown=3,
dispatcher_queue=dispatcher_task_queue,
)

tensor_key = MessageHandler.build_feature_store_key(
Expand All @@ -138,14 +134,8 @@ def setup_worker_manager_model_bytes(
output_key = MessageHandler.build_feature_store_key(
"key", app_feature_store.descriptor
)
model = MessageHandler.build_model(b"model", "model name", "v 0.0.1")
request = MessageHandler.build_request(
test_dir, model, [tensor_key], [output_key], [], None
)
ser_request = MessageHandler.serialize_request(request)
worker_manager._task_queue.send(ser_request)

request = InferenceRequest(
inf_request = InferenceRequest(
model_key=None,
callback=None,
raw_inputs=None,
Expand All @@ -159,7 +149,7 @@ def setup_worker_manager_model_bytes(
model_id = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor)

request_batch = RequestBatch(
[request],
[inf_request],
TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]),
model_id=model_id,
)
Expand Down Expand Up @@ -194,15 +184,11 @@ def setup_worker_manager_model_key(
dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0)

worker_manager = WorkerManager(
EnvironmentConfigLoader(
featurestore_factory=DragonFeatureStore.from_descriptor,
callback_factory=FileSystemCommChannel.from_descriptor,
queue_factory=DragonFLIChannel.from_sender_supplied_descriptor,
),
config_loader=config_loader,
worker_type=integrated_worker_type,
dispatcher_queue=dispatcher_task_queue,
as_service=False,
cooldown=3,
dispatcher_queue=dispatcher_task_queue,
)

tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor)
Expand Down Expand Up @@ -327,8 +313,12 @@ def setup_request_dispatcher_model_key(
return request_dispatcher, integrated_worker_type


def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage):
def mock_stage(*args, **kwargs):
def mock_pipeline_stage(
monkeypatch: pytest.MonkeyPatch,
integrated_worker: MachineLearningWorkerBase,
stage: str,
) -> t.Callable[[t.Any], ResponseBuilder]:
def mock_stage(*args: t.Any, **kwargs: t.Any) -> None:
raise ValueError(f"Simulated error in {stage}")

monkeypatch.setattr(integrated_worker, stage, mock_stage)
Expand Down
2 changes: 1 addition & 1 deletion tests/dragon/test_featurestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def set_value_after_delay(
# )
# # p.start()
# processes.append(p)

# for p in processes:
# p.start()

Expand Down
4 changes: 1 addition & 3 deletions tests/dragon/test_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,7 @@ def test_worker_manager(prepare_environment: pathlib.Path) -> None:
to_worker_channel = dch.Channel.make_process_local()
to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)

to_worker_fli_comm_channel = DragonFLIChannel(
to_worker_fli, sender_supplied=True
)
to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli, sender_supplied=True)

# NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader
# or test environment may be unable to send messages w/queue
Expand Down
7 changes: 2 additions & 5 deletions tests/mli/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,14 @@ def clear(self) -> None:
@classmethod
def from_descriptor(
cls,
descriptor: t.Union[str, bytes],
descriptor: str,
) -> "FileSystemCommChannel":
"""A factory method that creates an instance from a descriptor string
:param descriptor: The descriptor that uniquely identifies the resource
:returns: An attached FileSystemCommChannel"""
try:
if isinstance(descriptor, str):
path = pathlib.Path(descriptor)
else:
path = pathlib.Path(descriptor.decode("utf-8"))
path = pathlib.Path(descriptor)
return FileSystemCommChannel(path)
except:
logger.warning(f"failed to create fs comm channel: {descriptor}")
Expand Down

0 comments on commit 79583ac

Please sign in to comment.