Mli feature #18 (Closed)

This pull request proposed merging the following 26 commits:
d2fd6a7  Initial MLI schemas and MessageHandler class (#607)  (AlyssaCote, Jun 11, 2024)
3c9915c  Merge branch 'develop' into mli-feature  (ankona, Jun 14, 2024)
38081da  ML Worker Manager MVP (#608)  (ankona, Jun 20, 2024)
ab900b8  Remove device attribute from schemas (#619)  (AlyssaCote, Jun 25, 2024)
a9ffb14  Merge branch 'develop' into mli-feature  (ankona, Jul 2, 2024)
ee2c110  Merge branch 'develop' into mli-feature  (ankona, Jul 2, 2024)
8a2f173  Add model metadata to request schema (#624)  (AlyssaCote, Jul 3, 2024)
52abd32  Enable environment variable based configuration for ML Worker Manager…  (AlyssaCote, Jul 10, 2024)
eace71e  FLI-based Worker Manager (#622)  (al-rigazzi, Jul 15, 2024)
5fac3e2  Add ability to specify hardware policies on dragon run requests (#631)  (ankona, Jul 17, 2024)
0030a4a  Revert "Add ability to specify hardware policies on dragon run reques…  (ankona, Jul 17, 2024)
b6c2f2b  Merge latest develop into mli-feature (#640)  (ankona, Jul 18, 2024)
272a1d7  Improve error handling in worker manager (#629)  (AlyssaCote, Jul 18, 2024)
7169f1c  Schema performance improvements (#632)  (AlyssaCote, Jul 18, 2024)
84101b3  New develop merger (#645)  (al-rigazzi, Jul 19, 2024)
e225c07  merging develop  (ankona, Jul 26, 2024)
9f482b1  Merge branch 'develop' into mli-feature  (ankona, Jul 31, 2024)
263e3c7  Fix dragon installation issues (#652)  (ankona, Aug 2, 2024)
0453b8b  Add FeatureStore descriptor to tensor & model keys (#633)  (ankona, Aug 7, 2024)
99ed41c  Merge branch 'develop' into mli-feature  (ankona, Aug 8, 2024)
74d6e78  Use `torch.from_numpy` instead of `torch.tensor` to reduce a copy (#661)  (AlyssaCote, Aug 8, 2024)
391784c  MLI environment variables updated using new naming convention (#665)  (AlyssaCote, Aug 14, 2024)
f7ef49b  Remove pydantic dependency from MLI code (#667)  (AlyssaCote, Aug 20, 2024)
ef034d5  Enable specification of target hostname for a dragon task (#660)  (ankona, Aug 26, 2024)
6d5518b  fix init reordering bug (#675)  (ankona, Aug 26, 2024)
5d85995  Queue-based Worker Manager (#647)  (al-rigazzi, Aug 28, 2024)
21 changes: 18 additions & 3 deletions .github/workflows/run_tests.yml
@@ -54,7 +54,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- subset: [backends, slow_tests, group_a, group_b]
+ subset: [backends, slow_tests, group_a, group_b, dragon]
os: [macos-12, macos-14, ubuntu-22.04] # Operating systems
compiler: [8] # GNU compiler version
rai: [1.2.7] # Redis AI versions
@@ -112,9 +112,17 @@ jobs:
python -m pip install .[dev,mypy,ml]

- name: Install ML Runtimes with Smart (with pt, tf, and onnx support)
- if: contains( matrix.os, 'ubuntu' ) || contains( matrix.os, 'macos-12')
+ if: (contains( matrix.os, 'ubuntu' ) || contains( matrix.os, 'macos-12')) && ( matrix.subset != 'dragon' )
run: smart build --device cpu --onnx -v

+ - name: Install ML Runtimes with Smart (with pt, tf, dragon, and onnx support)
+   if: (contains( matrix.os, 'ubuntu' ) || contains( matrix.os, 'macos-12')) && ( matrix.subset == 'dragon' )
+   run: |
+     smart build --device cpu --onnx --dragon -v
+     SP=$(python -c 'import site; print(site.getsitepackages()[0])')/smartsim/_core/config/dragon/.env
+     LLP=$(cat $SP | grep LD_LIBRARY_PATH | awk '{split($0, array, "="); print array[2]}')
+     echo "LD_LIBRARY_PATH=$LLP:$LD_LIBRARY_PATH" >> $GITHUB_ENV

- name: Install ML Runtimes with Smart (no ONNX,TF on Apple Silicon)
if: contains( matrix.os, 'macos-14' )
run: smart build --device cpu --no_tf -v
@@ -142,9 +150,16 @@ jobs:
echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV
py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ ./tests/backends

+ # Run pytest (dragon subtests)
+ - name: Run Dragon Pytest
+   if: (matrix.subset == 'dragon' && matrix.os == 'ubuntu-22.04')
+   run: |
+     echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV
+     dragon -s py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests

# Run pytest (test subsets)
- name: Run Pytest
if: "!contains(matrix.subset, 'backends')" # if not running backend tests
if: (matrix.subset != 'backends' && matrix.subset != 'dragon') # if not running backend tests or dragon tests
run: |
echo "SMARTSIM_LOG_LEVEL=debug" >> $GITHUB_ENV
py.test -s --import-mode=importlib -o log_cli=true --cov=$(smart site) --cov-report=xml --cov-config=./tests/test_configs/cov/local_cov.cfg --ignore=tests/full_wlm/ -m ${{ matrix.subset }} ./tests
13 changes: 9 additions & 4 deletions Makefile
@@ -164,22 +164,22 @@ tutorials-prod:
# help: test - Run all tests
.PHONY: test
test:
- @python -m pytest --ignore=tests/full_wlm/
+ @python -m pytest --ignore=tests/full_wlm/ --ignore=tests/dragon

# help: test-verbose - Run all tests verbosely
.PHONY: test-verbose
test-verbose:
- @python -m pytest -vv --ignore=tests/full_wlm/
+ @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon

# help: test-debug - Run all tests with debug output
.PHONY: test-debug
test-debug:
- @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/
+ @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon

# help: test-cov - Run all tests with coverage
.PHONY: test-cov
test-cov:
- @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/
+ @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon


# help: test-full - Run all WLM tests with Python coverage (full test suite)
@@ -192,3 +192,8 @@ test-full:
.PHONY: test-wlm
test-wlm:
@python -m pytest -vv tests/full_wlm/ tests/on_wlm

+ # help: test-dragon - Run dragon-specific tests
+ .PHONY: test-dragon
+ test-dragon:
+ 	@dragon pytest tests/dragon
22 changes: 22 additions & 0 deletions doc/changelog.md
@@ -9,6 +9,28 @@ Jump to:

## SmartSim

### MLI branch

Description

- Add RequestDispatcher and the possibility of batching inference requests
- Enable hostname selection for dragon tasks
- Remove pydantic dependency from MLI code
- Update MLI environment variables using new naming convention
- Reduce a copy by using `torch.from_numpy` instead of `torch.tensor` (see the sketch after this list)
- Enable dynamic feature store selection
- Fix dragon package installation bug
- Adjust schemas for better performance
- Add TorchWorker first implementation and mock inference app example
- Add error handling in Worker Manager pipeline
- Add EnvironmentConfigLoader for ML Worker Manager
- Add Model schema with model metadata included
- Remove device from schemas, MessageHandler, and tests
- Add ML worker manager, sample worker, and feature store
- Add schemas and MessageHandler class for de/serialization of inference requests and response messages
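
A minimal sketch of the copy-reduction bullet above (an editorial aside, not part of the changelog diff, assuming only standard NumPy/PyTorch semantics): `torch.tensor` always copies its input, while `torch.from_numpy` wraps the existing NumPy buffer.

import numpy as np
import torch

arr = np.zeros((4, 3), dtype=np.float32)
t_copy = torch.tensor(arr)      # allocates a new buffer and copies the data
t_view = torch.from_numpy(arr)  # shares the existing buffer, no copy
arr[0, 0] = 1.0
assert t_view[0, 0] == 1.0      # the shared view observes the change
assert t_copy[0, 0] == 0.0      # the independent copy does not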


### Development branch

To be released at some future point in time
77 changes: 77 additions & 0 deletions ex/high_throughput_inference/mli_driver.py
@@ -0,0 +1,77 @@
import os
import base64
import cloudpickle
import sys
from smartsim import Experiment
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
from smartsim.status import TERMINAL_STATUSES
from smartsim.settings import DragonRunSettings
import time
import typing as t

DEVICE = "gpu"
NUM_RANKS = 4
NUM_WORKERS = 1
filedir = os.path.dirname(__file__)
worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py")
app_script_name = os.path.join(filedir, "mock_app.py")
model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")

transport: t.Literal["hsta", "tcp"] = "hsta"

os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport

exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}")
os.makedirs(exp_path, exist_ok=True)
exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path)

torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")

worker_manager_rs: DragonRunSettings = exp.create_run_settings(
sys.executable,
[
worker_manager_script_name,
"--device",
DEVICE,
"--worker_class",
torch_worker_str,
"--batch_size",
str(NUM_RANKS//NUM_WORKERS),
"--batch_timeout",
str(0.00),
"--num_workers",
str(NUM_WORKERS)
],
)

aff = []
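# NOTE: an empty affinity list is passed below; presumably this requests no CPU pinning.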

worker_manager_rs.set_cpu_affinity(aff)

worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])

app_rs: DragonRunSettings = exp.create_run_settings(
sys.executable,
exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)],
)
app_rs.set_tasks_per_node(NUM_RANKS)


app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])

exp.generate(worker_manager, app, overwrite=True)
exp.start(worker_manager, app, block=False)

while True:
if exp.get_status(app)[0] in TERMINAL_STATUSES:
time.sleep(10)
exp.stop(worker_manager)
break
if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES:
time.sleep(10)
exp.stop(app)
break

print("Exiting.")
189 changes: 189 additions & 0 deletions ex/high_throughput_inference/mock_app.py
@@ -0,0 +1,189 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2024, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# isort: off
import dragon
from dragon import fli
from dragon.channels import Channel
import dragon.channels
from dragon.data.ddict.ddict import DDict
from dragon.globalservices.api_setup import connect_to_infrastructure
from dragon.utils import b64decode, b64encode

# isort: on

import argparse
import io
import numpy
import os
import time
import torch

from mpi4py import MPI
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
DragonFeatureStore,
)
from smartsim._core.mli.message_handler import MessageHandler
from smartsim.log import get_logger
from smartsim._core.utils.timings import PerfTimer

torch.set_num_interop_threads(16)
torch.set_num_threads(1)

logger = get_logger("App")
logger.info("Started app")

CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False

class ProtoClient:
def __init__(self, timing_on: bool):
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
connect_to_infrastructure()
ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
self._ddict = DDict.attach(ddict_str)
self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor
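        # Poll the backbone feature store until the worker manager has published
        # the descriptor of its input FLI channel under "to_worker_fli".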
to_worker_fli_str = None
while to_worker_fli_str is None:
try:
to_worker_fli_str = self._ddict["to_worker_fli"]
self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
except KeyError:
time.sleep(1)
self._from_worker_ch = Channel.make_process_local()
self._from_worker_ch_serialized = self._from_worker_ch.serialize()
self._to_worker_ch = Channel.make_process_local()

self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")

def run_model(self, model: bytes | str, batch: torch.Tensor):
tensors = [batch.numpy()]
self.perf_timer.start_timings("batch_size", batch.shape[0])
built_tensor_desc = MessageHandler.build_tensor_descriptor(
"c", "float32", list(batch.shape)
)
self.perf_timer.measure_time("build_tensor_descriptor")
if isinstance(model, str):
model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
else:
model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
request = MessageHandler.build_request(
reply_channel=self._from_worker_ch_serialized,
model=model_arg,
inputs=[built_tensor_desc],
outputs=[],
output_descriptors=[],
custom_attributes=None,
)
self.perf_timer.measure_time("build_request")
request_bytes = MessageHandler.serialize_request(request)
self.perf_timer.measure_time("serialize_request")
with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
to_sendh.send_bytes(request_bytes)
self.perf_timer.measure_time("send_request")
for tensor in tensors:
to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!!
self.perf_timer.measure_time("send_tensors")
with self._from_worker_ch.recvh(timeout=None) as from_recvh:
resp = from_recvh.recv_bytes(timeout=None)
self.perf_timer.measure_time("receive_response")
response = MessageHandler.deserialize_response(resp)
self.perf_timer.measure_time("deserialize_response")
# list of data blobs? recv depending on the len(response.result.descriptors)?
data_blob: bytes = from_recvh.recv_bytes(timeout=None)
self.perf_timer.measure_time("receive_tensor")
result = torch.from_numpy(
numpy.frombuffer(
data_blob,
dtype=str(response.result.descriptors[0].dataType),
)
)
self.perf_timer.measure_time("deserialize_tensor")

self.perf_timer.end_timings()
return result

def set_model(self, key: str, model: bytes):
self._ddict[key] = model



class ResNetWrapper:
def __init__(self, name: str, model: str):
self._model = torch.jit.load(model)
self._name = name
buffer = io.BytesIO()
scripted = torch.jit.trace(self._model, self.get_batch())
torch.jit.save(scripted, buffer)
self._serialized_model = buffer.getvalue()

def get_batch(self, batch_size: int = 32):
return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)

@property
def model(self):
return self._serialized_model

@property
def name(self):
return self._name

if __name__ == "__main__":

parser = argparse.ArgumentParser("Mock application")
parser.add_argument("--device", default="cpu", type=str)
parser.add_argument("--log_max_batchsize", default=8, type=int)
args = parser.parse_args()

resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")

client = ProtoClient(timing_on=True)
client.set_model(resnet.name, resnet.model)

if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
# TODO: adapt to non-Nvidia devices
torch_device = args.device.replace("gpu", "cuda")
pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)

TOTAL_ITERATIONS = 100

for log2_bsize in range(args.log_max_batchsize+1):
b_size: int = 2**log2_bsize
logger.info(f"Batch size: {b_size}")
for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
logger.info(f"Iteration: {iteration_number}")
sample_batch = resnet.get_batch(b_size)
remote_result = client.run_model(resnet.name, sample_batch)
logger.info(client.perf_timer.get_last("total_time"))
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
local_res = pt_model(sample_batch.to(torch_device))
err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
torch.cuda.synchronize()

client.perf_timer.print_timings(to_file=True)