Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model metadata to request schema #624

Merged
merged 12 commits into from
Jul 3, 2024
3 changes: 2 additions & 1 deletion doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ Jump to:

Description

- Add Model schema with model metadata included
- Removed device from schemas, MessageHandler and tests
- Add ML worker manager, sample worker, and feature store
- Added schemas and MessageHandler class for de/serialization of
inference requests and response messages
- Removed device from schemas, MessageHandler and tests


### Development branch
Expand Down
23 changes: 12 additions & 11 deletions smartsim/_core/mli/infrastructure/control/workermanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from smartsim.log import get_logger

if t.TYPE_CHECKING:
from smartsim._core.mli.mli_schemas.model.model_capnp import Model
from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum

logger = get_logger(__name__)
Expand All @@ -65,12 +66,12 @@
request = MessageHandler.deserialize_request(data_blob)
# return request
model_key: t.Optional[str] = None
model_bytes: t.Optional[bytes] = None
model_bytes: t.Optional[Model] = None

Check warning on line 69 in smartsim/_core/mli/infrastructure/control/workermanager.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/workermanager.py#L69

Added line #L69 was not covered by tests

if request.model.which() == "modelKey":
model_key = request.model.modelKey.key
elif request.model.which() == "modelData":
model_bytes = request.model.modelData
if request.model.which() == "key":
model_key = request.model.key.key
elif request.model.which() == "data":
model_bytes = request.model.data

Check warning on line 74 in smartsim/_core/mli/infrastructure/control/workermanager.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/workermanager.py#L71-L74

Added lines #L71 - L74 were not covered by tests

callback_key = request.replyChannel.reply

Expand All @@ -91,19 +92,19 @@
# # end client
input_meta: t.List[t.Any] = []

if request.input.which() == "inputKeys":
input_keys = [input_key.key for input_key in request.input.inputKeys]
elif request.input.which() == "inputData":
input_bytes = [data.blob for data in request.input.inputData]
input_meta = [data.tensorDescriptor for data in request.input.inputData]
if request.input.which() == "keys":
input_keys = [input_key.key for input_key in request.input.keys]
elif request.input.which() == "data":
input_bytes = [data.blob for data in request.input.data]
input_meta = [data.tensorDescriptor for data in request.input.data]

Check warning on line 99 in smartsim/_core/mli/infrastructure/control/workermanager.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/mli/infrastructure/control/workermanager.py#L95-L99

Added lines #L95 - L99 were not covered by tests

inference_request = InferenceRequest(
model_key=model_key,
callback=comm_channel,
raw_inputs=input_bytes,
input_meta=input_meta,
input_keys=input_keys,
raw_model=model_bytes,
raw_model=model_bytes.data if model_bytes is not None else None,
batch_size=0,
)
return inference_request
Expand Down
50 changes: 37 additions & 13 deletions smartsim/_core/mli/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np

from .mli_schemas.data import data_references_capnp
from .mli_schemas.model import model_capnp
from .mli_schemas.request import request_capnp
from .mli_schemas.request.request_attributes import request_attributes_capnp
from .mli_schemas.response import response_capnp
Expand Down Expand Up @@ -112,6 +113,25 @@
raise ValueError("Error building tensor key.") from e
return tensor_key

@staticmethod
def build_model(data: bytes, name: str, version: str) -> model_capnp.Model:
    """
    Builds a new Model message with the provided data, name, and version.

    :param data: Serialized model data
    :param name: Model name
    :param version: Model version
    :returns: the built Model message
    :raises ValueError: if building fails
    """
    try:
        model = model_capnp.Model.new_message()
        model.data = data
        model.name = name
        model.version = version
    except Exception as e:
        # Wrap any capnp failure in a ValueError so callers handle one type
        raise ValueError("Error building model.") from e
    return model

@staticmethod
def build_model_key(key: str) -> data_references_capnp.ModelKey:
"""
Expand Down Expand Up @@ -187,7 +207,7 @@
@staticmethod
def _assign_model(
    request: request_capnp.Request,
    model: t.Union[data_references_capnp.ModelKey, model_capnp.Model],
) -> None:
    """
    Assigns a model to the supplied request.

    :param request: Request being built
    :param model: Model message (inline data) or ModelKey reference to assign
    :raises ValueError: if building fails
    """
    try:
        # Distinguish Model vs ModelKey via the capnp schema display name,
        # e.g. "model.capnp:Model" -> "Model"
        class_name = model.schema.node.displayName.split(":")[-1]  # type: ignore
        if class_name == "Model":
            request.model.data = model  # type: ignore
        elif class_name == "ModelKey":
            request.model.key = model  # type: ignore
        else:
            # was: "Invalid custom attribute class name" — copy-paste error
            # from the custom-attributes helper; this branch concerns the model
            raise ValueError(
                "Invalid model class name. Expected 'Model' or 'ModelKey'."
            )
    except Exception as e:
        raise ValueError("Error building model portion of request.") from e

@staticmethod
def _assign_reply_channel(
request: request_capnp.Request, reply_channel: t.ByteString
request: request_capnp.Request, reply_channel: bytes
) -> None:
"""
Assigns a reply channel to the supplied request.
Expand Down Expand Up @@ -239,9 +263,9 @@
display_name = inputs[0].schema.node.displayName # type: ignore
input_class_name = display_name.split(":")[-1]
if input_class_name == "Tensor":
request.input.inputData = inputs # type: ignore
request.input.data = inputs # type: ignore
elif input_class_name == "TensorKey":
request.input.inputKeys = inputs # type: ignore
request.input.keys = inputs # type: ignore
else:
raise ValueError(
"Invalid input class name. Expected 'Tensor' or 'TensorKey'."
Expand Down Expand Up @@ -324,8 +348,8 @@

@staticmethod
def build_request(
reply_channel: t.ByteString,
model: t.Union[data_references_capnp.ModelKey, t.ByteString],
reply_channel: bytes,
model: t.Union[data_references_capnp.ModelKey, model_capnp.Model],
inputs: t.Union[
t.List[data_references_capnp.TensorKey], t.List[tensor_capnp.Tensor]
],
Expand Down Expand Up @@ -357,7 +381,7 @@
return request

@staticmethod
def serialize_request(request: request_capnp.RequestBuilder) -> t.ByteString:
def serialize_request(request: request_capnp.RequestBuilder) -> bytes:
"""
Serializes a built request message.

Expand All @@ -366,7 +390,7 @@
return request.to_bytes()

@staticmethod
def deserialize_request(request_bytes: t.ByteString) -> request_capnp.Request:
def deserialize_request(request_bytes: bytes) -> request_capnp.Request:
"""
Deserializes a serialized request message.

Expand Down Expand Up @@ -499,14 +523,14 @@
return response

@staticmethod
def serialize_response(response: response_capnp.ResponseBuilder) -> t.ByteString:
def serialize_response(response: response_capnp.ResponseBuilder) -> bytes:
"""
Serializes a built response message.
"""
return response.to_bytes()

@staticmethod
def deserialize_response(response_bytes: t.ByteString) -> response_capnp.Response:
def deserialize_response(response_bytes: bytes) -> response_capnp.Response:
"""
Deserializes a serialized response message.
"""
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/mli/mli_schemas/data/data_references.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ struct ModelKey {

struct TensorKey {
key @0 :Text;
}
}
33 changes: 33 additions & 0 deletions smartsim/_core/mli/mli_schemas/model/model.capnp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# BSD 2-Clause License

# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@0xaefb9301e14ba4bd;

# Model payload carried inline in a request: the model bytes plus
# identifying metadata (name and version).
struct Model {
  # Raw model bytes — presumably a framework-serialized model; the schema
  # does not interpret the contents. TODO confirm against the worker code.
  data @0 :Data;
  # Human-readable model name
  name @1 :Text;
  # Model version string
  version @2 :Text;
}
12 changes: 12 additions & 0 deletions smartsim/_core/mli/mli_schemas/model/model_capnp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""This is an automatically generated stub for `model.capnp`."""

import os

import capnp # type: ignore

capnp.remove_import_hook()
here = os.path.dirname(os.path.abspath(__file__))
module_file = os.path.abspath(os.path.join(here, "model.capnp"))
Model = capnp.load(module_file).Model
ModelBuilder = Model
ModelReader = Model
46 changes: 46 additions & 0 deletions smartsim/_core/mli/mli_schemas/model/model_capnp.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""This is an automatically generated stub for `model.capnp`."""

# mypy: ignore-errors

from __future__ import annotations

from contextlib import contextmanager
from io import BufferedWriter
from typing import Iterator

class Model:
    """Type stub for the capnp `Model` message (see `model.capnp`)."""

    # Schema fields declared in model.capnp
    data: bytes
    name: str
    version: str
    @staticmethod
    @contextmanager
    def from_bytes(
        data: bytes,
        traversal_limit_in_words: int | None = ...,
        nesting_limit: int | None = ...,
    ) -> Iterator[ModelReader]: ...
    @staticmethod
    def from_bytes_packed(
        data: bytes,
        traversal_limit_in_words: int | None = ...,
        nesting_limit: int | None = ...,
    ) -> ModelReader: ...
    @staticmethod
    def new_message() -> ModelBuilder: ...
    def to_dict(self) -> dict: ...

class ModelReader(Model):
    """Reader side of the `Model` message; convert via `as_builder()`."""

    def as_builder(self) -> ModelBuilder: ...

class ModelBuilder(Model):
    """Builder side of the `Model` message; supports (de)serialization."""

    @staticmethod
    def from_dict(dictionary: dict) -> ModelBuilder: ...
    def copy(self) -> ModelBuilder: ...
    def to_bytes(self) -> bytes: ...
    def to_bytes_packed(self) -> bytes: ...
    def to_segments(self) -> list[bytes]: ...
    def as_reader(self) -> ModelReader: ...
    @staticmethod
    def write(file: BufferedWriter) -> None: ...
    @staticmethod
    def write_packed(file: BufferedWriter) -> None: ...
11 changes: 6 additions & 5 deletions smartsim/_core/mli/mli_schemas/request/request.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using Tensors = import "../tensor/tensor.capnp";
using RequestAttributes = import "request_attributes/request_attributes.capnp";
using DataRef = import "../data/data_references.capnp";
using Models = import "../model/model.capnp";

struct ChannelDescriptor {
reply @0 :Data;
Expand All @@ -37,12 +38,12 @@ struct ChannelDescriptor {
struct Request {
replyChannel @0 :ChannelDescriptor;
model :union {
modelKey @1 :DataRef.ModelKey;
modelData @2 :Data;
key @1 :DataRef.ModelKey;
data @2 :Models.Model;
}
input :union {
inputKeys @3 :List(DataRef.TensorKey);
inputData @4 :List(Tensors.Tensor);
keys @3 :List(DataRef.TensorKey);
data @4 :List(Tensors.Tensor);
}
output @5 :List(DataRef.TensorKey);
outputDescriptors @6 :List(Tensors.OutputDescriptor);
Expand All @@ -51,4 +52,4 @@ struct Request {
tf @8 :RequestAttributes.TensorFlowRequestAttributes;
none @9 :Void;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ struct TorchRequestAttributes {
struct TensorFlowRequestAttributes {
name @0 :Text;
tensorType @1 :TFTensorType;
}
}
Loading
Loading