diff --git a/doc/changelog.md b/doc/changelog.md
index 495cff3ed..1c91705ad 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -13,6 +13,7 @@ Jump to:
 
 Description
 
+- Adjust schemas for better performance
 - Add TorchWorker first implementation and mock inference app example
 - Add error handling in Worker Manager pipeline
 - Add EnvironmentConfigLoader for ML Worker Manager
diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py
index 45246db2e..e244c93e0 100644
--- a/ex/high_throughput_inference/mock_app.py
+++ b/ex/high_throughput_inference/mock_app.py
@@ -108,10 +108,11 @@ def print_timings(self, to_file: bool = False):
 
 
     def run_model(self, model: bytes | str, batch: torch.Tensor):
+        tensors = [batch.numpy()]
         self.start_timings(batch.shape[0])
-        built_tensor = MessageHandler.build_tensor(
-            batch.numpy(), "c", "float32", list(batch.shape))
-        self.measure_time("build_tensor")
+        built_tensor_desc = MessageHandler.build_tensor_descriptor(
+            "c", "float32", list(batch.shape))
+        self.measure_time("build_tensor_descriptor")
         built_model = None
         if isinstance(model, str):
             model_arg = MessageHandler.build_model_key(model)
@@ -120,7 +121,7 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
         request = MessageHandler.build_request(
             reply_channel=self._from_worker_ch_serialized,
             model= model_arg,
-            inputs=[built_tensor],
+            inputs=[built_tensor_desc],
             outputs=[],
             output_descriptors=[],
             custom_attributes=None,
@@ -130,6 +131,9 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
         self.measure_time("serialize_request")
         with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
             to_sendh.send_bytes(request_bytes)
+            for t in tensors:
+                to_sendh.send_bytes(t.tobytes()) #TODO NOT FAST ENOUGH!!!
+                # to_sendh.send_bytes(bytes(t.data))
         logger.info(f"Message size: {len(request_bytes)} bytes")
         self.measure_time("send")
 
@@ -138,10 +142,12 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
         self.measure_time("receive")
         response = MessageHandler.deserialize_response(resp)
         self.measure_time("deserialize_response")
+        # list of data blobs? recv depending on the len(response.result.descriptors)?
+        data_blob = from_recvh.recv_bytes(timeout=None)
         result = torch.from_numpy(
             numpy.frombuffer(
-                response.result.data[0].blob,
-                dtype=str(response.result.data[0].tensorDescriptor.dataType),
+                data_blob,
+                dtype=str(response.result.descriptors[0].dataType),
             )
         )
         self.measure_time("deserialize_tensor")
diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py
index 2318896a9..a3cce2181 100644
--- a/smartsim/_core/mli/comm/channel/channel.py
+++ b/smartsim/_core/mli/comm/channel/channel.py
@@ -45,7 +45,7 @@ def send(self, value: bytes) -> None:
         :param value: The value to send"""
 
     @abstractmethod
-    def recv(self) -> bytes:
+    def recv(self) -> t.List[bytes]:
         """Receieve a message through the underlying communication channel
 
         :returns: the received message"""
diff --git a/smartsim/_core/mli/comm/channel/dragonchannel.py b/smartsim/_core/mli/comm/channel/dragonchannel.py
index 1409747a9..672fce75b 100644
--- a/smartsim/_core/mli/comm/channel/dragonchannel.py
+++ b/smartsim/_core/mli/comm/channel/dragonchannel.py
@@ -25,6 +25,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import sys +import typing as t import smartsim._core.mli.comm.channel.channel as cch from smartsim.log import get_logger @@ -52,9 +53,9 @@ def send(self, value: bytes) -> None: with self._channel.sendh(timeout=None) as sendh: sendh.send_bytes(value) - def recv(self) -> bytes: + def recv(self) -> t.List[bytes]: """Receieve a message through the underlying communication channel :returns: the received message""" with self._channel.recvh(timeout=None) as recvh: message_bytes: bytes = recvh.recv_bytes(timeout=None) - return message_bytes + return [message_bytes] diff --git a/smartsim/_core/mli/comm/channel/dragonfli.py b/smartsim/_core/mli/comm/channel/dragonfli.py index 75f8fb4bf..28b4c2bf3 100644 --- a/smartsim/_core/mli/comm/channel/dragonfli.py +++ b/smartsim/_core/mli/comm/channel/dragonfli.py @@ -57,13 +57,16 @@ def send(self, value: bytes) -> None: with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh: sendh.send_bytes(value) - def recv(self) -> bytes: + def recv(self) -> t.List[bytes]: """Receieve a message through the underlying communication channel :returns: the received message""" + messages = [] + eot = False with self._fli.recvh(timeout=None) as recvh: - try: - request_bytes: bytes - request_bytes, _ = recvh.recv_bytes(timeout=None) - return request_bytes - except fli.FLIEOT as exc: - return b"" + while not eot: + try: + message, _ = recvh.recv_bytes(timeout=None) + messages.append(message) + except fli.FLIEOT as exc: + eot = True + return messages diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 8e3ed3fb4..27f5bfc97 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -58,6 +58,7 @@ from smartsim._core.mli.mli_schemas.model.model_capnp import Model from smartsim._core.mli.mli_schemas.response.response_capnp import Status + from smartsim._core.mli.mli_schemas.tensor.tensor_capnp import TensorDescriptor logger = get_logger(__name__) @@ -88,25 +89,23 @@ def deserialize_message( elif request.model.which() == "data": model_bytes = request.model.data - callback_key = request.replyChannel.reply + callback_key = request.replyChannel.descriptor # todo: shouldn't this be `CommChannel.find` instead of `DragonCommChannel` comm_channel = channel_type(callback_key) # comm_channel = DragonCommChannel(request.replyChannel) input_keys: t.Optional[t.List[str]] = None - input_bytes: t.Optional[t.List[bytes]] = ( - None # these will really be tensors already - ) + input_bytes: t.Optional[t.List[bytes]] = None + output_keys: t.Optional[t.List[str]] = None - input_meta: t.List[t.Any] = [] + input_meta: t.Optional[t.List[TensorDescriptor]] = None if request.input.which() == "keys": input_keys = [input_key.key for input_key in request.input.keys] - elif request.input.which() == "data": - input_bytes = [data.blob for data in request.input.data] - input_meta = [data.tensorDescriptor for data in request.input.data] + elif request.input.which() == "descriptors": + input_meta = request.input.descriptors # type: ignore if request.output: output_keys = [tensor_key.key for tensor_key in request.output] @@ -142,20 +141,13 @@ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]: msg_key = MessageHandler.build_tensor_key(key) prepared_outputs.append(msg_key) elif reply.outputs: - arrays: t.List[np.ndarray[t.Any, np.dtype[t.Any]]] = [ - output.numpy() for output in reply.outputs - ] - for tensor in arrays: - # 
todo: need to have the output attributes specified in the req? - # maybe, add `MessageHandler.dtype_of(tensor)`? - # can `build_tensor` do dtype and shape? - msg_tensor = MessageHandler.build_tensor( - tensor, + for _ in reply.outputs: + msg_tensor_desc = MessageHandler.build_tensor_descriptor( "c", "float32", [1], ) - prepared_outputs.append(msg_tensor) + prepared_outputs.append(msg_tensor_desc) return prepared_outputs @@ -272,13 +264,28 @@ def _on_iteration(self) -> None: return timings = [] # timing - # perform default deserialization of the message envelope - request_bytes: bytes = self._task_queue.recv() + + bytes_list: t.List[bytes] = self._task_queue.recv() + + if not bytes_list: + exception_handler( + ValueError("No request data found"), + None, + "No request data found.", + ) + return + + request_bytes = bytes_list[0] + tensor_bytes_list = bytes_list[1:] interm = time.perf_counter() # timing request = deserialize_message( request_bytes, self._comm_channel_type, self._device ) + + if request.input_meta and tensor_bytes_list: + request.raw_inputs = tensor_bytes_list + if not self._validate_request(request): return @@ -430,7 +437,12 @@ def _on_iteration(self) -> None: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing if request.callback: + # send serialized response request.callback.send(serialized_resp) + if reply.outputs: + # send tensor data after response + for output in reply.outputs: + request.callback.send(output) timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py index a4e725ab9..e732ecd2c 100644 --- a/smartsim/_core/mli/infrastructure/worker/torch_worker.py +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -110,10 +110,16 @@ def transform_output( result_device: str, ) -> TransformOutputResult: if result_device != "cpu": - transformed = [item.to("cpu") for item in execute_result.predictions] + transformed = [ + item.to("cpu").numpy().tobytes() for item in execute_result.predictions + ] + # todo: need the shape from latest schemas added here. return TransformOutputResult(transformed, None, "c", "float32") # fixme return TransformOutputResult( - execute_result.predictions, None, "c", "float32" + [item.numpy().tobytes() for item in execute_result.predictions], + None, + "c", + "float32", ) # fixme diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index dd874abe3..bb8d82231 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -59,7 +59,7 @@ def __init__( self.model_key = model_key self.raw_model = raw_model self.callback = callback - self.raw_inputs = raw_inputs + self.raw_inputs = raw_inputs or [] self.input_keys = input_keys or [] self.input_meta = input_meta or [] self.output_keys = output_keys or [] diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py index 4fe2bef3a..00670dce8 100644 --- a/smartsim/_core/mli/message_handler.py +++ b/smartsim/_core/mli/message_handler.py @@ -25,8 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import typing as t -import numpy as np - from .mli_schemas.data import data_references_capnp from .mli_schemas.model import model_capnp from .mli_schemas.request import request_capnp @@ -38,17 +36,15 @@ class MessageHandler: @staticmethod - def build_tensor( - tensor: np.ndarray[t.Any, np.dtype[t.Any]], + def build_tensor_descriptor( order: "tensor_capnp.Order", data_type: "tensor_capnp.NumericalType", dimensions: t.List[int], - ) -> tensor_capnp.Tensor: + ) -> tensor_capnp.TensorDescriptor: """ - Builds a Tensor message using the provided data, + Builds a TensorDescriptor message using the provided order, data type, and dimensions. - :param tensor: Tensor to build the message around :param order: Order of the tensor, such as row-major (c) or column-major (f) :param data_type: Data type of the tensor :param dimensions: Dimensions of the tensor @@ -59,15 +55,12 @@ def build_tensor( description.order = order description.dataType = data_type description.dimensions = dimensions - built_tensor = tensor_capnp.Tensor.new_message() - built_tensor.blob = tensor.tobytes() # tensor channel instead? - built_tensor.tensorDescriptor = description except Exception as e: raise ValueError( - "Error building tensor." + "Error building tensor descriptor." ) from e # TODO: create custom exception - return built_tensor + return description @staticmethod def build_output_tensor_descriptor( @@ -240,7 +233,7 @@ def _assign_reply_channel( :raises ValueError: if building fails """ try: - request.replyChannel.reply = reply_channel + request.replyChannel.descriptor = reply_channel except Exception as e: raise ValueError("Error building reply channel portion of request.") from e @@ -248,7 +241,8 @@ def _assign_reply_channel( def _assign_inputs( request: request_capnp.Request, inputs: t.Union[ - t.List[data_references_capnp.TensorKey], t.List[tensor_capnp.Tensor] + t.List[data_references_capnp.TensorKey], + t.List[tensor_capnp.TensorDescriptor], ], ) -> None: """ @@ -262,14 +256,13 @@ def _assign_inputs( if inputs: display_name = inputs[0].schema.node.displayName # type: ignore input_class_name = display_name.split(":")[-1] - if input_class_name == "Tensor": - request.input.data = inputs # type: ignore + if input_class_name == "TensorDescriptor": + request.input.descriptors = inputs # type: ignore elif input_class_name == "TensorKey": request.input.keys = inputs # type: ignore else: - raise ValueError( - "Invalid input class name. Expected 'Tensor' or 'TensorKey'." - ) + raise ValueError("""Invalid input class name. 
Expected + 'TensorDescriptor' or 'TensorKey'.""") except Exception as e: raise ValueError("Error building inputs portion of request.") from e @@ -351,7 +344,8 @@ def build_request( reply_channel: bytes, model: t.Union[data_references_capnp.ModelKey, model_capnp.Model], inputs: t.Union[ - t.List[data_references_capnp.TensorKey], t.List[tensor_capnp.Tensor] + t.List[data_references_capnp.TensorKey], + t.List[tensor_capnp.TensorDescriptor], ], outputs: t.List[data_references_capnp.TensorKey], output_descriptors: t.List[tensor_capnp.OutputDescriptor], @@ -437,7 +431,8 @@ def _assign_message(response: response_capnp.Response, message: str) -> None: def _assign_result( response: response_capnp.Response, result: t.Union[ - t.List[tensor_capnp.Tensor], t.List[data_references_capnp.TensorKey] + t.List[tensor_capnp.TensorDescriptor], + t.List[data_references_capnp.TensorKey], ], ) -> None: """ @@ -452,13 +447,13 @@ def _assign_result( first_result = result[0] display_name = first_result.schema.node.displayName # type: ignore result_class_name = display_name.split(":")[-1] - if result_class_name == "Tensor": - response.result.data = result # type: ignore + if result_class_name == "TensorDescriptor": + response.result.descriptors = result # type: ignore elif result_class_name == "TensorKey": response.result.keys = result # type: ignore else: raise ValueError("""Invalid custom attribute class name. - Expected 'Tensor' or 'TensorKey'.""") + Expected 'TensorDescriptor' or 'TensorKey'.""") except Exception as e: raise ValueError("Error assigning result to response.") from e @@ -501,7 +496,8 @@ def build_response( status: "response_capnp.Status", message: str, result: t.Union[ - t.List[tensor_capnp.Tensor], t.List[data_references_capnp.TensorKey] + t.List[tensor_capnp.TensorDescriptor], + t.List[data_references_capnp.TensorKey], ], custom_attributes: t.Union[ response_attributes_capnp.TorchResponseAttributes, diff --git a/smartsim/_core/mli/mli_schemas/request/request.capnp b/smartsim/_core/mli/mli_schemas/request/request.capnp index f9508cb54..4be1cfa21 100644 --- a/smartsim/_core/mli/mli_schemas/request/request.capnp +++ b/smartsim/_core/mli/mli_schemas/request/request.capnp @@ -32,7 +32,7 @@ using DataRef = import "../data/data_references.capnp"; using Models = import "../model/model.capnp"; struct ChannelDescriptor { - reply @0 :Data; + descriptor @0 :Data; } struct Request { @@ -43,7 +43,7 @@ struct Request { } input :union { keys @3 :List(DataRef.TensorKey); - data @4 :List(Tensors.Tensor); + descriptors @4 :List(Tensors.TensorDescriptor); } output @5 :List(DataRef.TensorKey); outputDescriptors @6 :List(Tensors.OutputDescriptor); diff --git a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi index 39093f61a..a4ad631f9 100644 --- a/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi +++ b/smartsim/_core/mli/mli_schemas/request/request_capnp.pyi @@ -47,9 +47,9 @@ from ..tensor.tensor_capnp import ( OutputDescriptor, OutputDescriptorBuilder, OutputDescriptorReader, - Tensor, - TensorBuilder, - TensorReader, + TensorDescriptor, + TensorDescriptorBuilder, + TensorDescriptorReader, ) from .request_attributes.request_attributes_capnp import ( TensorFlowRequestAttributes, @@ -61,7 +61,7 @@ from .request_attributes.request_attributes_capnp import ( ) class ChannelDescriptor: - reply: bytes + descriptor: bytes @staticmethod @contextmanager def from_bytes( @@ -143,8 +143,10 @@ class Request: class Input: keys: Sequence[TensorKey | 
TensorKeyBuilder | TensorKeyReader] - data: Sequence[Tensor | TensorBuilder | TensorReader] - def which(self) -> Literal["keys", "data"]: ... + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + def which(self) -> Literal["keys", "descriptors"]: ... @staticmethod @contextmanager def from_bytes( @@ -164,12 +166,14 @@ class Request: class InputReader(Request.Input): keys: Sequence[TensorKeyReader] - data: Sequence[TensorReader] + descriptors: Sequence[TensorDescriptorReader] def as_builder(self) -> Request.InputBuilder: ... class InputBuilder(Request.Input): keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] - data: Sequence[Tensor | TensorBuilder | TensorReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] @staticmethod def from_dict(dictionary: dict) -> Request.InputBuilder: ... def copy(self) -> Request.InputBuilder: ... diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp index 83aa05a41..7194524cd 100644 --- a/smartsim/_core/mli/mli_schemas/response/response.capnp +++ b/smartsim/_core/mli/mli_schemas/response/response.capnp @@ -42,7 +42,7 @@ struct Response { message @1 :Text; result :union { keys @2 :List(DataRef.TensorKey); - data @3 :List(Tensors.Tensor); + descriptors @3 :List(Tensors.TensorDescriptor); } customAttributes :union { torch @4 :ResponseAttributes.TorchResponseAttributes; diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi index f19bdefe0..6b4c50fd0 100644 --- a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi @@ -35,7 +35,11 @@ from io import BufferedWriter from typing import Iterator, Literal, Sequence, overload from ..data.data_references_capnp import TensorKey, TensorKeyBuilder, TensorKeyReader -from ..tensor.tensor_capnp import Tensor, TensorBuilder, TensorReader +from ..tensor.tensor_capnp import ( + TensorDescriptor, + TensorDescriptorBuilder, + TensorDescriptorReader, +) from .response_attributes.response_attributes_capnp import ( TensorFlowResponseAttributes, TensorFlowResponseAttributesBuilder, @@ -50,8 +54,10 @@ Status = Literal["complete", "fail", "timeout", "running"] class Response: class Result: keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] - data: Sequence[Tensor | TensorBuilder | TensorReader] - def which(self) -> Literal["keys", "data"]: ... + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] + def which(self) -> Literal["keys", "descriptors"]: ... @staticmethod @contextmanager def from_bytes( @@ -71,12 +77,14 @@ class Response: class ResultReader(Response.Result): keys: Sequence[TensorKeyReader] - data: Sequence[TensorReader] + descriptors: Sequence[TensorDescriptorReader] def as_builder(self) -> Response.ResultBuilder: ... class ResultBuilder(Response.Result): keys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] - data: Sequence[Tensor | TensorBuilder | TensorReader] + descriptors: Sequence[ + TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader + ] @staticmethod def from_dict(dictionary: dict) -> Response.ResultBuilder: ... def copy(self) -> Response.ResultBuilder: ... 
diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp index aca1ce083..4b2218b16 100644 --- a/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor.capnp @@ -58,12 +58,7 @@ enum ReturnNumericalType { float32 @8; float64 @9; none @10; - auto @ 11; -} - -struct Tensor { - blob @0 :Data; - tensorDescriptor @1 :TensorDescriptor; + auto @11; } struct TensorDescriptor { diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py index aa7f1e7b1..8c9d6c902 100644 --- a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.py @@ -33,9 +33,6 @@ capnp.remove_import_hook() here = os.path.dirname(os.path.abspath(__file__)) module_file = os.path.abspath(os.path.join(here, "tensor.capnp")) -Tensor = capnp.load(module_file).Tensor -TensorBuilder = Tensor -TensorReader = Tensor TensorDescriptor = capnp.load(module_file).TensorDescriptor TensorDescriptorBuilder = TensorDescriptor TensorDescriptorReader = TensorDescriptor diff --git a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi index 7e7222ef5..b55f26b45 100644 --- a/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi +++ b/smartsim/_core/mli/mli_schemas/tensor/tensor_capnp.pyi @@ -101,49 +101,6 @@ class TensorDescriptorBuilder(TensorDescriptor): @staticmethod def write_packed(file: BufferedWriter) -> None: ... -class Tensor: - blob: bytes - tensorDescriptor: ( - TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader - ) - def init(self, name: Literal["tensorDescriptor"]) -> TensorDescriptor: ... - @staticmethod - @contextmanager - def from_bytes( - data: bytes, - traversal_limit_in_words: int | None = ..., - nesting_limit: int | None = ..., - ) -> Iterator[TensorReader]: ... - @staticmethod - def from_bytes_packed( - data: bytes, - traversal_limit_in_words: int | None = ..., - nesting_limit: int | None = ..., - ) -> TensorReader: ... - @staticmethod - def new_message() -> TensorBuilder: ... - def to_dict(self) -> dict: ... - -class TensorReader(Tensor): - tensorDescriptor: TensorDescriptorReader - def as_builder(self) -> TensorBuilder: ... - -class TensorBuilder(Tensor): - tensorDescriptor: ( - TensorDescriptor | TensorDescriptorBuilder | TensorDescriptorReader - ) - @staticmethod - def from_dict(dictionary: dict) -> TensorBuilder: ... - def copy(self) -> TensorBuilder: ... - def to_bytes(self) -> bytes: ... - def to_bytes_packed(self) -> bytes: ... - def to_segments(self) -> list[bytes]: ... - def as_reader(self) -> TensorReader: ... - @staticmethod - def write(file: BufferedWriter) -> None: ... - @staticmethod - def write_packed(file: BufferedWriter) -> None: ... 
- class OutputDescriptor: order: Order optionalKeys: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader] diff --git a/tests/mli/test_torch_worker.py b/tests/mli/test_torch_worker.py index 0b1cd4ccf..b73e4a31b 100644 --- a/tests/mli/test_torch_worker.py +++ b/tests/mli/test_torch_worker.py @@ -95,17 +95,18 @@ def create_torch_model(): def get_request() -> InferenceRequest: tensors = [get_batch() for _ in range(2)] - serialized_tensors = [ - MessageHandler.build_tensor(tensor.numpy(), "c", "float32", list(tensor.shape)) + tensor_numpy = [tensor.numpy() for tensor in tensors] + serialized_tensors_descriptors = [ + MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape)) for tensor in tensors ] return InferenceRequest( model_key="model", callback=None, - raw_inputs=[s_tensor.blob for s_tensor in serialized_tensors], + raw_inputs=tensor_numpy, input_keys=None, - input_meta=[s_tensor.tensorDescriptor for s_tensor in serialized_tensors], + input_meta=serialized_tensors_descriptors, output_keys=None, raw_model=create_torch_model(), batch_size=0, @@ -167,7 +168,9 @@ def test_transform_output(mlutils): sample_request, execute_result, torch_device[mlutils.get_test_device().lower()] ) - assert transformed_output.outputs == execute_result.predictions + assert transformed_output.outputs == [ + item.numpy().tobytes() for item in execute_result.predictions + ] assert transformed_output.shape == None assert transformed_output.order == "c" assert transformed_output.dtype == "float32" diff --git a/tests/test_message_handler/test_build_tensor.py b/tests/test_message_handler/test_build_tensor.py deleted file mode 100644 index aa7bd4e6e..000000000 --- a/tests/test_message_handler/test_build_tensor.py +++ /dev/null @@ -1,185 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import pytest - -try: - import tensorflow as tf -except ImportError: - should_run_tf = False -else: - should_run_tf = True - - small_tf_tensor = tf.zeros((3, 2, 5), dtype=tf.int8) - small_tf_tensor = small_tf_tensor.numpy() - medium_tf_tensor = tf.ones((1040, 1040, 3), dtype=tf.int64) - medium_tf_tensor = medium_tf_tensor.numpy() - - -try: - import torch -except ImportError: - should_run_torch = False -else: - should_run_torch = True - - small_torch_tensor = torch.zeros((3, 2, 5), dtype=torch.int8) - small_torch_tensor = small_torch_tensor.numpy() - medium_torch_tensor = torch.ones((1040, 1040, 3), dtype=torch.int64) - medium_torch_tensor = medium_torch_tensor.numpy() - -from smartsim._core.mli.message_handler import MessageHandler - -# The tests in this file belong to the group_a group -pytestmark = pytest.mark.group_a - -handler = MessageHandler() - - -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") -@pytest.mark.parametrize( - "tensor, dtype, order, dimension", - [ - pytest.param( - small_torch_tensor, - "int8", - "c", - [3, 2, 5], - id="small torch tensor", - ), - pytest.param( - medium_torch_tensor, - "int64", - "c", - [1040, 1040, 3], - id="medium torch tensor", - ), - ], -) -def test_build_torch_tensor_successful(tensor, dtype, order, dimension): - built_tensor = handler.build_tensor(tensor, order, dtype, dimension) - assert built_tensor is not None - assert type(built_tensor.blob) == bytes - assert built_tensor.tensorDescriptor.order == order - assert built_tensor.tensorDescriptor.dataType == dtype - for i, j in zip(built_tensor.tensorDescriptor.dimensions, dimension): - assert i == j - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -@pytest.mark.parametrize( - "tensor, dtype, order, dimension", - [ - pytest.param( - small_tf_tensor, - "int8", - "c", - [3, 2, 5], - id="small tf tensor", - ), - pytest.param( - medium_tf_tensor, - "int64", - "c", - [1040, 1040, 3], - id="medium tf tensor", - ), - ], -) -def test_build_tf_tensor_successful(tensor, dtype, order, dimension): - built_tensor = handler.build_tensor(tensor, order, dtype, dimension) - assert built_tensor is not None - assert type(built_tensor.blob) == bytes - assert built_tensor.tensorDescriptor.order == order - assert built_tensor.tensorDescriptor.dataType == dtype - for i, j in zip(built_tensor.tensorDescriptor.dimensions, dimension): - assert i == j - - -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") -@pytest.mark.parametrize( - "tensor, dtype, order, dimension", - [ - pytest.param([1, 2, 4], "c", "int8", [1, 2, 3], id="bad tensor type"), - pytest.param( - small_torch_tensor, - "bad_order", - "int8", - [3, 2, 5], - id="bad order type", - ), - pytest.param( - small_torch_tensor, - "f", - "bad_num_type", - [3, 2, 5], - id="bad numerical type", - ), - pytest.param( - small_torch_tensor, - "f", - "int8", - "bad shape type", - id="bad shape type", - ), - ], -) -def test_build_torch_tensor_bad_input(tensor, dtype, order, dimension): - with pytest.raises(ValueError): - built_tensor = handler.build_tensor(tensor, order, dtype, dimension) - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -@pytest.mark.parametrize( - "tensor, dtype, order, dimension", - [ - pytest.param([1, 2, 4], "c", "int8", [1, 2, 3], id="bad tensor type"), - pytest.param( - small_tf_tensor, - "bad_order", - "int8", - [3, 2, 5], - id="bad order type", - ), - pytest.param( - small_tf_tensor, - "f", - "bad_num_type", - [3, 2, 5], - id="bad numerical 
type", - ), - pytest.param( - small_tf_tensor, - "f", - "int8", - "bad shape type", - id="bad shape type", - ), - ], -) -def test_build_tf_tensor_bad_input(tensor, dtype, order, dimension): - with pytest.raises(ValueError): - built_tensor = handler.build_tensor(tensor, order, dtype, dimension) diff --git a/tests/test_message_handler/test_build_tensor_desc.py b/tests/test_message_handler/test_build_tensor_desc.py new file mode 100644 index 000000000..45126fb16 --- /dev/null +++ b/tests/test_message_handler/test_build_tensor_desc.py @@ -0,0 +1,90 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + +handler = MessageHandler() + + +@pytest.mark.parametrize( + "dtype, order, dimension", + [ + pytest.param( + "int8", + "c", + [3, 2, 5], + id="small torch tensor", + ), + pytest.param( + "int64", + "c", + [1040, 1040, 3], + id="medium torch tensor", + ), + ], +) +def test_build_tensor_descriptor_successful(dtype, order, dimension): + built_tensor_descriptor = handler.build_tensor_descriptor(order, dtype, dimension) + assert built_tensor_descriptor is not None + assert built_tensor_descriptor.order == order + assert built_tensor_descriptor.dataType == dtype + for i, j in zip(built_tensor_descriptor.dimensions, dimension): + assert i == j + + +@pytest.mark.parametrize( + "dtype, order, dimension", + [ + pytest.param( + "bad_order", + "int8", + [3, 2, 5], + id="bad order type", + ), + pytest.param( + "f", + "bad_num_type", + [3, 2, 5], + id="bad numerical type", + ), + pytest.param( + "f", + "int8", + "bad shape type", + id="bad shape type", + ), + ], +) +def test_build_tensor_descriptor_unsuccessful(dtype, order, dimension): + with pytest.raises(ValueError): + built_tensor_descriptor = handler.build_tensor_descriptor( + order, dtype, dimension + ) diff --git a/tests/test_message_handler/test_request.py b/tests/test_message_handler/test_request.py index b1fedaa02..4cfc11584 100644 --- a/tests/test_message_handler/test_request.py +++ b/tests/test_message_handler/test_request.py @@ -28,46 +28,6 @@ from smartsim._core.mli.message_handler import MessageHandler -try: - import tensorflow as tf -except ImportError: - should_run_tf = False -else: - should_run_tf = True - tflow1 = tf.zeros((3, 2, 5), dtype=tf.int8) - tflow2 = tf.ones((10, 10, 3), dtype=tf.int64) - - tensor_3 = MessageHandler.build_tensor( - tflow1.numpy(), "c", "int8", list(tflow1.shape) - ) - tensor_4 = MessageHandler.build_tensor( - tflow2.numpy(), "c", "int64", list(tflow2.shape) - ) - - tf_attributes = MessageHandler.build_tf_request_attributes( - name="tf", tensor_type="sparse" - ) - - -try: - import torch -except ImportError: - should_run_torch = False -else: - should_run_torch = True - - torch1 = torch.zeros((3, 2, 5), dtype=torch.int8) - torch2 = torch.ones((10, 10, 3), dtype=torch.int64) - - tensor_1 = MessageHandler.build_tensor( - torch1.numpy(), "c", "int8", list(torch1.shape) - ) - tensor_2 = MessageHandler.build_tensor( - torch2.numpy(), "c", "int64", list(torch2.shape) - ) - - torch_attributes = MessageHandler.build_torch_request_attributes("sparse") - # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a @@ -87,123 +47,54 @@ output_descriptor3 = MessageHandler.build_output_tensor_descriptor( "c", [output_key1], "none", [1, 2, 3] ) +torch_attributes = MessageHandler.build_torch_request_attributes("sparse") +tf_attributes = MessageHandler.build_tf_request_attributes( + name="tf", tensor_type="sparse" +) +tensor_1 = MessageHandler.build_tensor_descriptor("c", "int8", [1]) +tensor_2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2]) +tensor_3 = MessageHandler.build_tensor_descriptor("f", "int8", [1]) +tensor_4 = MessageHandler.build_tensor_descriptor("f", "int64", [3, 2]) -if should_run_tf: - tf_indirect_request = MessageHandler.build_request( - b"reply", - model, - [input_key1, input_key2], - [output_key1, output_key2], - [output_descriptor1, output_descriptor2, output_descriptor3], - tf_attributes, - ) - 
tf_direct_request = MessageHandler.build_request( - b"reply", - model, - [tensor_3, tensor_4], - [], - [output_descriptor1, output_descriptor2], - tf_attributes, - ) +tf_indirect_request = MessageHandler.build_request( + b"reply", + model, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1, output_descriptor2, output_descriptor3], + tf_attributes, +) -if should_run_torch: - torch_indirect_request = MessageHandler.build_request( - b"reply", - model, - [input_key1, input_key2], - [output_key1, output_key2], - [output_descriptor1, output_descriptor2, output_descriptor3], - torch_attributes, - ) - torch_direct_request = MessageHandler.build_request( - b"reply", - model, - [tensor_1, tensor_2], - [], - [output_descriptor1, output_descriptor2], - torch_attributes, - ) +tf_direct_request = MessageHandler.build_request( + b"reply", + model, + [tensor_3, tensor_4], + [], + [output_descriptor1, output_descriptor2], + tf_attributes, +) +torch_indirect_request = MessageHandler.build_request( + b"reply", + model, + [input_key1, input_key2], + [output_key1, output_key2], + [output_descriptor1, output_descriptor2, output_descriptor3], + torch_attributes, +) -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -@pytest.mark.parametrize( - "reply_channel, model, input, output, output_descriptors, custom_attributes", - [ - pytest.param( - b"reply channel", - model_key, - [input_key1, input_key2], - [output_key1, output_key2], - [output_descriptor1], - tf_attributes, - ), - pytest.param( - b"another reply channel", - model, - [input_key1], - [output_key2], - [output_descriptor1], - tf_attributes, - ), - pytest.param( - b"another reply channel", - model, - [input_key1], - [output_key2], - [output_descriptor1], - tf_attributes, - ), - pytest.param( - b"reply channel", - model_key, - [input_key1], - [output_key1], - [output_descriptor1], - None, - ), - ], +torch_direct_request = MessageHandler.build_request( + b"reply", + model, + [tensor_1, tensor_2], + [], + [output_descriptor1, output_descriptor2], + torch_attributes, ) -def test_build_request_indirect_tf_successful( - reply_channel, model, input, output, output_descriptors, custom_attributes -): - built_request = MessageHandler.build_request( - reply_channel, - model, - input, - output, - output_descriptors, - custom_attributes, - ) - assert built_request is not None - assert built_request.replyChannel.reply == reply_channel - if built_request.model.which() == "key": - assert built_request.model.key.key == model.key - else: - assert built_request.model.data.data == model.data - assert built_request.model.data.name == model.name - assert built_request.model.data.version == model.version - assert built_request.input.which() == "keys" - assert built_request.input.keys[0].key == input[0].key - assert len(built_request.input.keys) == len(input) - assert len(built_request.output) == len(output) - for i, j in zip(built_request.outputDescriptors, output_descriptors): - assert i.order == j.order - if built_request.customAttributes.which() == "tf": - assert ( - built_request.customAttributes.tf.tensorType == custom_attributes.tensorType - ) - elif built_request.customAttributes.which() == "torch": - assert ( - built_request.customAttributes.torch.tensorType - == custom_attributes.tensorType - ) - else: - assert built_request.customAttributes.none == custom_attributes -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") @pytest.mark.parametrize( "reply_channel, model, input, output, 
output_descriptors, custom_attributes", [ @@ -221,7 +112,7 @@ def test_build_request_indirect_tf_successful( [input_key1], [output_key2], [output_descriptor1], - torch_attributes, + tf_attributes, ), pytest.param( b"another reply channel", @@ -241,7 +132,7 @@ def test_build_request_indirect_tf_successful( ), ], ) -def test_build_request_indirect_torch_successful( +def test_build_request_indirect_successful( reply_channel, model, input, output, output_descriptors, custom_attributes ): built_request = MessageHandler.build_request( @@ -253,7 +144,7 @@ def test_build_request_indirect_torch_successful( custom_attributes, ) assert built_request is not None - assert built_request.replyChannel.reply == reply_channel + assert built_request.replyChannel.descriptor == reply_channel if built_request.model.which() == "key": assert built_request.model.key.key == model.key else: @@ -279,108 +170,6 @@ def test_build_request_indirect_torch_successful( assert built_request.customAttributes.none == custom_attributes -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") -@pytest.mark.parametrize( - "reply_channel, model, input, output, output_descriptors, custom_attributes", - [ - pytest.param( - [], - model_key, - [input_key1, input_key2], - [output_key1, output_key2], - [output_descriptor1], - torch_attributes, - id="bad channel", - ), - pytest.param( - b"reply channel", - "bad model", - [input_key1], - [output_key2], - [output_descriptor1], - torch_attributes, - id="bad model", - ), - pytest.param( - b"reply channel", - model_key, - ["input_key1", "input_key2"], - [output_key1, output_key2], - [output_descriptor1], - torch_attributes, - id="bad inputs", - ), - pytest.param( - b"reply channel", - model_key, - [model_key], - [output_key1, output_key2], - [output_descriptor1], - torch_attributes, - id="bad input schema type", - ), - pytest.param( - b"reply channel", - model_key, - [input_key1], - ["output_key1", "output_key2"], - [output_descriptor1], - torch_attributes, - id="bad outputs", - ), - pytest.param( - b"reply channel", - model_key, - [input_key1], - [model_key], - [output_descriptor1], - torch_attributes, - id="bad output schema type", - ), - pytest.param( - b"reply channel", - model_key, - [input_key1], - [output_key1, output_key2], - [output_descriptor1], - "bad attributes", - id="bad custom attributes", - ), - pytest.param( - b"reply channel", - model_key, - [input_key1], - [output_key1, output_key2], - [output_descriptor1], - model_key, - id="bad custom attributes schema type", - ), - pytest.param( - b"reply channel", - model_key, - [input_key1], - [output_key1, output_key2], - "bad descriptors", - torch_attributes, - id="bad output descriptors", - ), - ], -) -def test_build_request_indirect_torch_unsuccessful( - reply_channel, model, input, output, output_descriptors, custom_attributes -): - with pytest.raises(ValueError): - built_request = MessageHandler.build_request( - reply_channel, - model, - input, - output, - output_descriptors, - custom_attributes, - ) - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") @pytest.mark.parametrize( "reply_channel, model, input, output, output_descriptors, custom_attributes", [ @@ -399,7 +188,7 @@ def test_build_request_indirect_torch_unsuccessful( [input_key1], [output_key2], [output_descriptor1], - tf_attributes, + torch_attributes, id="bad model", ), pytest.param( @@ -417,7 +206,7 @@ def test_build_request_indirect_torch_unsuccessful( [model_key], [output_key1, output_key2], [output_descriptor1], - 
tf_attributes, + torch_attributes, id="bad input schema type", ), pytest.param( @@ -462,12 +251,12 @@ def test_build_request_indirect_torch_unsuccessful( [input_key1], [output_key1, output_key2], "bad descriptors", - tf_attributes, + torch_attributes, id="bad output descriptors", ), ], ) -def test_build_request_indirect_tf_unsuccessful( +def test_build_request_indirect_unsuccessful( reply_channel, model, input, output, output_descriptors, custom_attributes ): with pytest.raises(ValueError): @@ -481,7 +270,6 @@ def test_build_request_indirect_tf_unsuccessful( ) -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") @pytest.mark.parametrize( "reply_channel, model, input, output, output_descriptors, custom_attributes", [ @@ -499,88 +287,12 @@ def test_build_request_indirect_tf_unsuccessful( [tensor_1], [], [output_descriptor3], - torch_attributes, - ), - pytest.param( - b"another reply channel", - model, - [tensor_2], - [], - [output_descriptor1], - torch_attributes, - ), - pytest.param( - b"another reply channel", - model, - [tensor_1], - [], - [output_descriptor1], - None, - ), - ], -) -def test_build_request_direct_torch_successful( - reply_channel, model, input, output, output_descriptors, custom_attributes -): - built_request = MessageHandler.build_request( - reply_channel, - model, - input, - output, - output_descriptors, - custom_attributes, - ) - assert built_request is not None - assert built_request.replyChannel.reply == reply_channel - if built_request.model.which() == "key": - assert built_request.model.key.key == model.key - else: - assert built_request.model.data.data == model.data - assert built_request.model.data.name == model.name - assert built_request.model.data.version == model.version - assert built_request.input.which() == "data" - assert built_request.input.data[0].blob == input[0].blob - assert len(built_request.input.data) == len(input) - assert len(built_request.output) == len(output) - for i, j in zip(built_request.outputDescriptors, output_descriptors): - assert i.order == j.order - if built_request.customAttributes.which() == "tf": - assert ( - built_request.customAttributes.tf.tensorType == custom_attributes.tensorType - ) - elif built_request.customAttributes.which() == "torch": - assert ( - built_request.customAttributes.torch.tensorType - == custom_attributes.tensorType - ) - else: - assert built_request.customAttributes.none == custom_attributes - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -@pytest.mark.parametrize( - "reply_channel, model, input, output, output_descriptors, custom_attributes", - [ - pytest.param( - b"reply channel", - model_key, - [tensor_3, tensor_4], - [], - [output_descriptor2], tf_attributes, ), pytest.param( b"another reply channel", model, - [tensor_4], - [], - [output_descriptor3], - tf_attributes, - ), - pytest.param( - b"another reply channel", - model, - [tensor_4], + [tensor_2], [], [output_descriptor1], tf_attributes, @@ -588,14 +300,14 @@ def test_build_request_direct_torch_successful( pytest.param( b"another reply channel", model, - [tensor_3], + [tensor_1], [], [output_descriptor1], None, ), ], ) -def test_build_request_direct_tf_successful( +def test_build_request_direct_successful( reply_channel, model, input, output, output_descriptors, custom_attributes ): built_request = MessageHandler.build_request( @@ -607,16 +319,15 @@ def test_build_request_direct_tf_successful( custom_attributes, ) assert built_request is not None - assert built_request.replyChannel.reply == 
reply_channel + assert built_request.replyChannel.descriptor == reply_channel if built_request.model.which() == "key": assert built_request.model.key.key == model.key else: assert built_request.model.data.data == model.data assert built_request.model.data.name == model.name assert built_request.model.data.version == model.version - assert built_request.input.which() == "data" - assert built_request.input.data[0].blob == input[0].blob - assert len(built_request.input.data) == len(input) + assert built_request.input.which() == "descriptors" + assert len(built_request.input.descriptors) == len(input) assert len(built_request.output) == len(output) for i, j in zip(built_request.outputDescriptors, output_descriptors): assert i.order == j.order @@ -633,81 +344,6 @@ def test_build_request_direct_tf_successful( assert built_request.customAttributes.none == custom_attributes -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") -@pytest.mark.parametrize( - "reply_channel, model, input, output, output_descriptors, custom_attributes", - [ - pytest.param( - [], - model_key, - [tensor_1, tensor_2], - [], - [output_descriptor2], - torch_attributes, - id="bad channel", - ), - pytest.param( - b"reply channel", - "bad model", - [tensor_1], - [], - [output_descriptor2], - torch_attributes, - id="bad model", - ), - pytest.param( - b"reply channel", - model_key, - ["input_key1", "input_key2"], - [], - [output_descriptor2], - torch_attributes, - id="bad inputs", - ), - pytest.param( - b"reply channel", - model_key, - [], - ["output_key1", "output_key2"], - [output_descriptor2], - torch_attributes, - id="bad outputs", - ), - pytest.param( - b"reply channel", - model_key, - [tensor_1], - [], - [output_descriptor2], - "bad attributes", - id="bad custom attributes", - ), - pytest.param( - b"reply_channel", - model_key, - [tensor_1, tensor_2], - [], - ["output_descriptor2"], - torch_attributes, - id="bad output descriptors", - ), - ], -) -def test_build_torch_request_direct_unsuccessful( - reply_channel, model, input, output, output_descriptors, custom_attributes -): - with pytest.raises(ValueError): - built_request = MessageHandler.build_request( - reply_channel, - model, - input, - output, - output_descriptors, - custom_attributes, - ) - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") @pytest.mark.parametrize( "reply_channel, model, input, output, output_descriptors, custom_attributes", [ @@ -735,7 +371,7 @@ def test_build_torch_request_direct_unsuccessful( ["input_key1", "input_key2"], [], [output_descriptor2], - tf_attributes, + torch_attributes, id="bad inputs", ), pytest.param( @@ -762,12 +398,12 @@ def test_build_torch_request_direct_unsuccessful( [tensor_3, tensor_4], [], ["output_descriptor2"], - tf_attributes, + torch_attributes, id="bad output descriptors", ), ], ) -def test_build_tf_request_direct_unsuccessful( +def test_build_request_direct_unsuccessful( reply_channel, model, input, output, output_descriptors, custom_attributes ): with pytest.raises(ValueError): @@ -781,31 +417,16 @@ def test_build_tf_request_direct_unsuccessful( ) -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") @pytest.mark.parametrize( "req", [ + pytest.param(tf_indirect_request, id="tf indirect"), + pytest.param(tf_direct_request, id="tf direct"), pytest.param(torch_indirect_request, id="indirect"), pytest.param(torch_direct_request, id="direct"), ], ) -def test_serialize_torch_request_successful(req): - serialized = MessageHandler.serialize_request(req) - 
assert type(serialized) == bytes - - deserialized = MessageHandler.deserialize_request(serialized) - assert deserialized.to_dict() == req.to_dict() - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -@pytest.mark.parametrize( - "req", - [ - pytest.param(tf_indirect_request, id="indirect"), - pytest.param(tf_direct_request, id="direct"), - ], -) -def test_serialize_tf_request_successful(req): +def test_serialize_request_successful(req): serialized = MessageHandler.serialize_request(req) assert type(serialized) == bytes diff --git a/tests/test_message_handler/test_response.py b/tests/test_message_handler/test_response.py index 9d59a1879..03bd9ba73 100644 --- a/tests/test_message_handler/test_response.py +++ b/tests/test_message_handler/test_response.py @@ -28,60 +28,6 @@ from smartsim._core.mli.message_handler import MessageHandler -try: - import tensorflow as tf -except ImportError: - should_run_tf = False -else: - should_run_tf = True - - tflow1 = tf.zeros((3, 2, 5), dtype=tf.int8) - tflow2 = tf.ones((1040, 1040, 3), dtype=tf.int64) - - small_tf_tensor = MessageHandler.build_tensor( - tflow1.numpy(), "c", "int8", list(tflow1.shape) - ) - medium_tf_tensor = MessageHandler.build_tensor( - tflow2.numpy(), "c", "int64", list(tflow2.shape) - ) - - tf_attributes = MessageHandler.build_tf_response_attributes() - - tf_direct_response = MessageHandler.build_response( - "complete", - "Success again!", - [small_tf_tensor, medium_tf_tensor], - tf_attributes, - ) - - -try: - import torch -except ImportError: - should_run_torch = False -else: - should_run_torch = True - - torch1 = torch.zeros((3, 2, 5), dtype=torch.int8) - torch2 = torch.ones((1040, 1040, 3), dtype=torch.int64) - - small_torch_tensor = MessageHandler.build_tensor( - torch1.numpy(), "c", "int8", list(torch1.shape) - ) - medium_torch_tensor = MessageHandler.build_tensor( - torch2.numpy(), "c", "int64", list(torch2.shape) - ) - - torch_attributes = MessageHandler.build_torch_response_attributes() - - torch_direct_response = MessageHandler.build_response( - "complete", - "Success again!", - [small_torch_tensor, medium_torch_tensor], - torch_attributes, - ) - - # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a @@ -89,86 +35,51 @@ result_key1 = MessageHandler.build_tensor_key("result_key1") result_key2 = MessageHandler.build_tensor_key("result_key2") +torch_attributes = MessageHandler.build_torch_response_attributes() +tf_attributes = MessageHandler.build_tf_response_attributes() -if should_run_tf: - tf_indirect_response = MessageHandler.build_response( - "complete", - "Success!", - [result_key1, result_key2], - tf_attributes, - ) +tensor1 = MessageHandler.build_tensor_descriptor("c", "int8", [1]) +tensor2 = MessageHandler.build_tensor_descriptor("c", "int64", [3, 2]) -if should_run_torch: - torch_indirect_response = MessageHandler.build_response( - "complete", - "Success!", - [result_key1, result_key2], - torch_attributes, - ) +tf_indirect_response = MessageHandler.build_response( + "complete", + "Success!", + [result_key1, result_key2], + tf_attributes, +) -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") -@pytest.mark.parametrize( - "status, status_message, result, custom_attribute", - [ - pytest.param( - 200, - "Yay, it worked!", - [small_torch_tensor, medium_torch_tensor], - None, - id="tensor list", - ), - pytest.param( - 200, - "Yay, it worked!", - [small_torch_tensor], - torch_attributes, - id="small tensor", - ), - pytest.param( - 200, - "Yay, it 
worked!", - [result_key1, result_key2], - torch_attributes, - id="tensor key list", - ), - ], +tf_direct_response = MessageHandler.build_response( + "complete", + "Success again!", + [tensor2, tensor1], + tf_attributes, +) + +torch_indirect_response = MessageHandler.build_response( + "complete", + "Success!", + [result_key1, result_key2], + torch_attributes, +) + +torch_direct_response = MessageHandler.build_response( + "complete", + "Success again!", + [tensor1, tensor2], + torch_attributes, ) -def test_build_torch_response_successful( - status, status_message, result, custom_attribute -): - response = MessageHandler.build_response( - status=status, - message=status_message, - result=result, - custom_attributes=custom_attribute, - ) - assert response is not None - assert response.status == status - assert response.message == status_message - if response.result.which() == "keys": - assert response.result.keys[0].to_dict() == result[0].to_dict() - else: - assert response.result.data[0].to_dict() == result[0].to_dict() -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") @pytest.mark.parametrize( "status, status_message, result, custom_attribute", [ pytest.param( 200, "Yay, it worked!", - [small_tf_tensor, medium_tf_tensor], + [tensor1, tensor2], None, - id="tensor list", - ), - pytest.param( - 200, - "Yay, it worked!", - [small_tf_tensor], - tf_attributes, - id="small tensor", + id="tensor descriptor list", ), pytest.param( 200, @@ -179,7 +90,7 @@ def test_build_torch_response_successful( ), ], ) -def test_build_tf_response_successful(status, status_message, result, custom_attribute): +def test_build_response_successful(status, status_message, result, custom_attribute): response = MessageHandler.build_response( status=status, message=status_message, @@ -192,25 +103,24 @@ def test_build_tf_response_successful(status, status_message, result, custom_att if response.result.which() == "keys": assert response.result.keys[0].to_dict() == result[0].to_dict() else: - assert response.result.data[0].to_dict() == result[0].to_dict() + assert response.result.descriptors[0].to_dict() == result[0].to_dict() -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") @pytest.mark.parametrize( "status, status_message, result, custom_attribute", [ pytest.param( "bad status", "Yay, it worked!", - [small_tf_tensor, medium_tf_tensor], + [tensor1, tensor2], None, id="bad status", ), pytest.param( "complete", 200, - [small_tf_tensor], - tf_attributes, + [tensor2], + torch_attributes, id="bad status message", ), pytest.param( @@ -230,110 +140,36 @@ def test_build_tf_response_successful(status, status_message, result, custom_att pytest.param( "complete", "Yay, it worked!", - [small_tf_tensor, medium_tf_tensor], - "custom attributes", - id="bad custom attributes", - ), - pytest.param( - "complete", - "Yay, it worked!", - [small_tf_tensor, medium_tf_tensor], - result_key1, - id="bad custom attributes type", - ), - ], -) -def test_build_tf_response_unsuccessful( - status, status_message, result, custom_attribute -): - with pytest.raises(ValueError): - response = MessageHandler.build_response( - status, status_message, result, custom_attribute - ) - - -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") -@pytest.mark.parametrize( - "status, status_message, result, custom_attribute", - [ - pytest.param( - "bad status", - "Yay, it worked!", - [small_torch_tensor, medium_torch_tensor], - None, - id="bad status", - ), - pytest.param( - "complete", - 200, - 
[small_torch_tensor], - torch_attributes, - id="bad status message", - ), - pytest.param( - "complete", - "Yay, it worked!", - ["result_key1", "result_key2"], - torch_attributes, - id="bad result", - ), - pytest.param( - "complete", - "Yay, it worked!", - [torch_attributes], - torch_attributes, - id="bad result type", - ), - pytest.param( - "complete", - "Yay, it worked!", - [small_torch_tensor, medium_torch_tensor], + [tensor2, tensor1], "custom attributes", id="bad custom attributes", ), pytest.param( "complete", "Yay, it worked!", - [small_torch_tensor, medium_torch_tensor], + [tensor2, tensor1], result_key1, id="bad custom attributes type", ), ], ) -def test_build_torch_response_unsuccessful( - status, status_message, result, custom_attribute -): +def test_build_response_unsuccessful(status, status_message, result, custom_attribute): with pytest.raises(ValueError): response = MessageHandler.build_response( status, status_message, result, custom_attribute ) -@pytest.mark.skipif(not should_run_torch, reason="Test needs Torch to run") @pytest.mark.parametrize( "response", [ pytest.param(torch_indirect_response, id="indirect"), pytest.param(torch_direct_response, id="direct"), + pytest.param(tf_indirect_response, id="tf indirect"), + pytest.param(tf_direct_response, id="tf direct"), ], ) -def test_torch_serialize_response(response): - serialized = MessageHandler.serialize_response(response) - assert type(serialized) == bytes - - deserialized = MessageHandler.deserialize_response(serialized) - assert deserialized.to_dict() == response.to_dict() - - -@pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -@pytest.mark.parametrize( - "response", - [ - pytest.param(tf_indirect_response, id="indirect"), - pytest.param(tf_direct_response, id="direct"), - ], -) -def test_tf_serialize_response(response): +def test_serialize_response(response): serialized = MessageHandler.serialize_response(response) assert type(serialized) == bytes
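
Taken together, this patch replaces the monolithic Tensor message (descriptor plus blob) with a lightweight TensorDescriptor in the capnp schemas and moves the raw tensor bytes out of band: the client sends the serialized request first and then streams each tensor's bytes as separate messages, and the worker replies with the serialized response followed by one blob per result descriptor. Below is a minimal client-side sketch of that flow, modeled on mock_app.py; it is illustrative only. The handles to_sendh and from_recvh are assumed to be already-open Dragon send/receive handles (as in mock_app.py), reply_channel is assumed to be the serialized descriptor of the channel behind from_recvh, and run_inference is a hypothetical helper, not part of this changeset.

import numpy as np

from smartsim._core.mli.message_handler import MessageHandler

def run_inference(
    to_sendh, from_recvh, reply_channel: bytes, model_key: str, batch: np.ndarray
) -> np.ndarray:
    # Describe the tensor (order, dtype, shape) without embedding its bytes.
    # The dtype name must be one of the schema's NumericalType values
    # (e.g. "float32", "int64"); str(batch.dtype) yields those for common dtypes.
    descriptor = MessageHandler.build_tensor_descriptor(
        "c", str(batch.dtype), list(batch.shape)
    )
    request = MessageHandler.build_request(
        reply_channel=reply_channel,
        model=MessageHandler.build_model_key(model_key),
        inputs=[descriptor],
        outputs=[],
        output_descriptors=[],
        custom_attributes=None,
    )
    # Send the small capnp message first, then the raw tensor bytes as a
    # separate message on the same stream.
    to_sendh.send_bytes(MessageHandler.serialize_request(request))
    to_sendh.send_bytes(batch.tobytes())

    # The worker sends the serialized response first, then one blob per
    # entry in response.result.descriptors.
    response = MessageHandler.deserialize_response(
        from_recvh.recv_bytes(timeout=None)
    )
    blob = from_recvh.recv_bytes(timeout=None)
    dtype = str(response.result.descriptors[0].dataType)
    # The worker does not yet fill in the real output shape (see the todo in
    # prepare_outputs), so the result is returned flat, as in mock_app.py.
    return np.frombuffer(blob, dtype=dtype)

On the worker side, the recv() implementation in dragonfli.py drains the stream into a list of byte messages, and the worker manager treats the first element as the serialized request and attaches the remaining elements to request.raw_inputs.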