From 79583ac19a7daec284e44c70f95bd5d9e092f1af Mon Sep 17 00:00:00 2001 From: ankona <3595025+ankona@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:53:43 -0500 Subject: [PATCH] fixes --- tests/dragon/test_error_handling.py | 38 +++++++++++------------------ tests/dragon/test_featurestore.py | 2 +- tests/dragon/test_worker_manager.py | 4 +-- tests/mli/channel.py | 7 ++---- 4 files changed, 18 insertions(+), 33 deletions(-) diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 90e0c2fe40..7034dd2795 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -24,10 +24,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import typing as t from unittest.mock import MagicMock import pytest -import typing as t from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder @@ -121,15 +121,11 @@ def setup_worker_manager_model_bytes( dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) worker_manager = WorkerManager( - EnvironmentConfigLoader( - featurestore_factory=DragonFeatureStore.from_descriptor, - callback_factory=FileSystemCommChannel.from_descriptor, - queue_factory=DragonFLIChannel.from_sender_supplied_descriptor, - ), - integrated_worker_type, + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, as_service=False, cooldown=3, - dispatcher_queue=dispatcher_task_queue, ) tensor_key = MessageHandler.build_feature_store_key( @@ -138,14 +134,8 @@ def setup_worker_manager_model_bytes( output_key = MessageHandler.build_feature_store_key( "key", app_feature_store.descriptor ) - model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") - request = MessageHandler.build_request( - test_dir, model, [tensor_key], [output_key], [], None - ) - ser_request = MessageHandler.serialize_request(request) - worker_manager._task_queue.send(ser_request) - request = InferenceRequest( + inf_request = InferenceRequest( model_key=None, callback=None, raw_inputs=None, @@ -159,7 +149,7 @@ def setup_worker_manager_model_bytes( model_id = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) request_batch = RequestBatch( - [request], + [inf_request], TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), model_id=model_id, ) @@ -194,15 +184,11 @@ def setup_worker_manager_model_key( dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) worker_manager = WorkerManager( - EnvironmentConfigLoader( - featurestore_factory=DragonFeatureStore.from_descriptor, - callback_factory=FileSystemCommChannel.from_descriptor, - queue_factory=DragonFLIChannel.from_sender_supplied_descriptor, - ), + config_loader=config_loader, worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, as_service=False, cooldown=3, - dispatcher_queue=dispatcher_task_queue, ) tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) @@ -327,8 +313,12 @@ def setup_request_dispatcher_model_key( return request_dispatcher, integrated_worker_type -def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage): - def mock_stage(*args, **kwargs): +def mock_pipeline_stage( + monkeypatch: pytest.MonkeyPatch, + integrated_worker: MachineLearningWorkerBase, + stage: str, +) -> t.Callable[[t.Any], ResponseBuilder]: + def mock_stage(*args: t.Any, **kwargs: t.Any) -> None: raise ValueError(f"Simulated error in {stage}") monkeypatch.setattr(integrated_worker, stage, mock_stage) diff --git a/tests/dragon/test_featurestore.py b/tests/dragon/test_featurestore.py index 5fca3733b7..f41e37088a 100644 --- a/tests/dragon/test_featurestore.py +++ b/tests/dragon/test_featurestore.py @@ -302,7 +302,7 @@ def set_value_after_delay( # ) # # p.start() # processes.append(p) - + # for p in processes: # p.start() diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon/test_worker_manager.py index 4a8ccfeeba..b8352eb13c 100644 --- a/tests/dragon/test_worker_manager.py +++ b/tests/dragon/test_worker_manager.py @@ -288,9 +288,7 @@ def test_worker_manager(prepare_environment: pathlib.Path) -> None: to_worker_channel = dch.Channel.make_process_local() to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) - to_worker_fli_comm_channel = DragonFLIChannel( - to_worker_fli, sender_supplied=True - ) + to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli, sender_supplied=True) # NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader # or test environment may be unable to send messages w/queue diff --git a/tests/mli/channel.py b/tests/mli/channel.py index fcc543547f..b00ba9aa2b 100644 --- a/tests/mli/channel.py +++ b/tests/mli/channel.py @@ -107,17 +107,14 @@ def clear(self) -> None: @classmethod def from_descriptor( cls, - descriptor: t.Union[str, bytes], + descriptor: str, ) -> "FileSystemCommChannel": """A factory method that creates an instance from a descriptor string :param descriptor: The descriptor that uniquely identifies the resource :returns: An attached FileSystemCommChannel""" try: - if isinstance(descriptor, str): - path = pathlib.Path(descriptor) - else: - path = pathlib.Path(descriptor.decode("utf-8")) + path = pathlib.Path(descriptor) return FileSystemCommChannel(path) except: logger.warning(f"failed to create fs comm channel: {descriptor}")