Queue-based Worker Manager (CrayLabs#647)
This PR adds the `RequestDispatcher` to the MLI. The `RequestDispatcher`
batches incoming inference requests and dispatches the resulting batches to `WorkerManager`s.
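
For context, the sketch below illustrates the general shape of a queue-based batching dispatcher: requests are pulled from an incoming queue, grouped until a batch size or timeout is reached, and the resulting batch is placed on a queue that a worker consumes. This is a minimal stand-alone sketch of the pattern, not the MLI `RequestDispatcher`; every class and attribute name here is illustrative only.

```python
import queue
import time
from dataclasses import dataclass, field


@dataclass
class InferenceRequest:
    """Illustrative stand-in for a single inference request."""
    model_key: str
    payload: bytes


@dataclass
class RequestBatch:
    """A group of requests handed to a worker as one unit."""
    requests: list[InferenceRequest] = field(default_factory=list)


class ToyDispatcher:
    """Collects requests from a queue and dispatches them in batches.

    Illustrative only -- the actual MLI RequestDispatcher has more structure
    than this sketch shows.
    """

    def __init__(self, batch_size: int, batch_timeout: float,
                 worker_queue: queue.Queue[RequestBatch]) -> None:
        self._incoming: queue.Queue[InferenceRequest] = queue.Queue()
        self._batch_size = batch_size
        self._batch_timeout = batch_timeout
        self._worker_queue = worker_queue

    def submit(self, request: InferenceRequest) -> None:
        """Called by clients to enqueue a request."""
        self._incoming.put(request)

    def dispatch_one_batch(self) -> None:
        """Block for the first request, then keep filling the batch until it
        is full or the batch timeout expires, and hand it to the worker queue."""
        batch = RequestBatch(requests=[self._incoming.get()])
        deadline = time.perf_counter() + self._batch_timeout
        while len(batch.requests) < self._batch_size:
            remaining = deadline - time.perf_counter()
            if remaining <= 0:
                break
            try:
                batch.requests.append(self._incoming.get(timeout=remaining))
            except queue.Empty:
                break
        self._worker_queue.put(batch)
```

In the driver diff below, the corresponding knobs are exposed as the `--batch_size` and `--batch_timeout` arguments passed to the standalone worker manager script.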

[ committed by @al-rigazzi ]
[ reviewed by @mellis13 @ankona @AlyssaCote ]
al-rigazzi authored Aug 28, 2024
1 parent 6d5518b commit 5d85995
Showing 26 changed files with 2,426 additions and 655 deletions.
1 change: 1 addition & 0 deletions doc/changelog.md
@@ -13,6 +13,7 @@ Jump to:

Description

- Add RequestDispatcher and the possibility of batching inference requests
- Enable hostname selection for dragon tasks
- Remove pydantic dependency from MLI code
- Update MLI environment variables using new naming convention
34 changes: 25 additions & 9 deletions ex/high_throughput_inference/mli_driver.py
@@ -1,19 +1,21 @@
import argparse
import os
import base64
import cloudpickle
import sys
from smartsim import Experiment
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
from smartsim.status import TERMINAL_STATUSES
from smartsim.settings import DragonRunSettings
import time
import typing as t

device = "gpu"
DEVICE = "gpu"
NUM_RANKS = 4
NUM_WORKERS = 1
filedir = os.path.dirname(__file__)
worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py")
app_script_name = os.path.join(filedir, "mock_app.py")
model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")

transport: t.Literal["hsta", "tcp"] = "hsta"

@@ -25,37 +27,51 @@

torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")

worker_manager_rs = exp.create_run_settings(
worker_manager_rs: DragonRunSettings = exp.create_run_settings(
sys.executable,
[
worker_manager_script_name,
"--device",
device,
DEVICE,
"--worker_class",
torch_worker_str,
"--batch_size",
str(NUM_RANKS//NUM_WORKERS),
"--batch_timeout",
str(0.00),
"--num_workers",
str(NUM_WORKERS)
],
)

aff = []

worker_manager_rs.set_cpu_affinity(aff)

worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])

app_rs = exp.create_run_settings(
app_rs: DragonRunSettings = exp.create_run_settings(
sys.executable,
exe_args=[app_script_name, "--device", device],
exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)],
)
app_rs.set_tasks_per_node(NUM_RANKS)


app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])


exp.generate(worker_manager, app, overwrite=True)
exp.start(worker_manager, app, block=False)

while True:
if exp.get_status(app)[0] in TERMINAL_STATUSES:
time.sleep(10)
exp.stop(worker_manager)
break
if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES:
time.sleep(10)
exp.stop(app)
break
time.sleep(5)

print("Exiting.")
136 changes: 56 additions & 80 deletions ex/high_throughput_inference/mock_app.py
@@ -41,20 +41,27 @@
import os
import time
import torch
import numbers

from collections import OrderedDict
from mpi4py import MPI
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
DragonFeatureStore,
)
from smartsim._core.mli.message_handler import MessageHandler
from smartsim.log import get_logger
from smartsim._core.utils.timings import PerfTimer

torch.set_num_interop_threads(16)
torch.set_num_threads(1)

logger = get_logger("App")
logger.info("Started app")

CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False

class ProtoClient:
def __init__(self, timing_on: bool):
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
connect_to_infrastructure()
ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
self._ddict = DDict.attach(ddict_str)
@@ -70,61 +77,15 @@ def __init__(self, timing_on: bool):
self._from_worker_ch_serialized = self._from_worker_ch.serialize()
self._to_worker_ch = Channel.make_process_local()

self._start = None
self._interm = None
self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict()
self._timing_on = timing_on

def _add_label_to_timings(self, label: str):
if label not in self._timings:
self._timings[label] = []

@staticmethod
def _format_number(number: numbers.Number):
return f"{number:0.4e}"

def start_timings(self, batch_size: int):
if self._timing_on:
self._add_label_to_timings("batch_size")
self._timings["batch_size"].append(batch_size)
self._start = time.perf_counter()
self._interm = time.perf_counter()

def end_timings(self):
if self._timing_on:
self._add_label_to_timings("total_time")
self._timings["total_time"].append(
self._format_number(time.perf_counter() - self._start)
)

def measure_time(self, label: str):
if self._timing_on:
self._add_label_to_timings(label)
self._timings[label].append(
self._format_number(time.perf_counter() - self._interm)
)
self._interm = time.perf_counter()

def print_timings(self, to_file: bool = False):
print(" ".join(self._timings.keys()))
value_array = numpy.array(
[value for value in self._timings.values()], dtype=float
)
value_array = numpy.transpose(value_array)
for i in range(value_array.shape[0]):
print(" ".join(self._format_number(value) for value in value_array[i]))
if to_file:
numpy.save("timings.npy", value_array)
numpy.savetxt("timings.txt", value_array)
self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")

def run_model(self, model: bytes | str, batch: torch.Tensor):
tensors = [batch.numpy()]
self.start_timings(batch.shape[0])
self.perf_timer.start_timings("batch_size", batch.shape[0])
built_tensor_desc = MessageHandler.build_tensor_descriptor(
"c", "float32", list(batch.shape)
)
self.measure_time("build_tensor_descriptor")
built_model = None
self.perf_timer.measure_time("build_tensor_descriptor")
if isinstance(model, str):
model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
else:
@@ -137,39 +98,39 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
output_descriptors=[],
custom_attributes=None,
)
self.measure_time("build_request")
self.perf_timer.measure_time("build_request")
request_bytes = MessageHandler.serialize_request(request)
self.measure_time("serialize_request")
with self._to_worker_fli.sendh(
timeout=None, stream_channel=self._to_worker_ch
) as to_sendh:
self.perf_timer.measure_time("serialize_request")
with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
to_sendh.send_bytes(request_bytes)
for t in tensors:
to_sendh.send_bytes(t.tobytes()) # TODO NOT FAST ENOUGH!!!
# to_sendh.send_bytes(bytes(t.data))
logger.info(f"Message size: {len(request_bytes)} bytes")

self.measure_time("send")
self.perf_timer.measure_time("send_request")
for tensor in tensors:
to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!!
self.perf_timer.measure_time("send_tensors")
with self._from_worker_ch.recvh(timeout=None) as from_recvh:
resp = from_recvh.recv_bytes(timeout=None)
self.measure_time("receive")
self.perf_timer.measure_time("receive_response")
response = MessageHandler.deserialize_response(resp)
self.measure_time("deserialize_response")
self.perf_timer.measure_time("deserialize_response")
# list of data blobs? recv depending on the len(response.result.descriptors)?
data_blob = from_recvh.recv_bytes(timeout=None)
result = numpy.frombuffer(
data_blob,
dtype=str(response.result.descriptors[0].dataType),
data_blob: bytes = from_recvh.recv_bytes(timeout=None)
self.perf_timer.measure_time("receive_tensor")
result = torch.from_numpy(
numpy.frombuffer(
data_blob,
dtype=str(response.result.descriptors[0].dataType),
)
)
self.measure_time("deserialize_tensor")
self.perf_timer.measure_time("deserialize_tensor")

self.end_timings()
self.perf_timer.end_timings()
return result

def set_model(self, key: str, model: bytes):
self._ddict[key] = model



class ResNetWrapper:
def __init__(self, name: str, model: str):
self._model = torch.jit.load(model)
@@ -190,24 +151,39 @@ def model(self):
def name(self):
return self._name


if __name__ == "__main__":

parser = argparse.ArgumentParser("Mock application")
parser.add_argument("--device", default="cpu")
parser.add_argument("--device", default="cpu", type=str)
parser.add_argument("--log_max_batchsize", default=8, type=int)
args = parser.parse_args()

resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")

client = ProtoClient(timing_on=True)
client.set_model(resnet.name, resnet.model)

total_iterations = 100
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
# TODO: adapt to non-Nvidia devices
torch_device = args.device.replace("gpu", "cuda")
pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)

for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
logger.info(f"Batch size: {batch_size}")
for iteration_number in range(total_iterations + int(batch_size == 1)):
logger.info(f"Iteration: {iteration_number}")
client.run_model(resnet.name, resnet.get_batch(batch_size))
TOTAL_ITERATIONS = 100

client.print_timings(to_file=True)
for log2_bsize in range(args.log_max_batchsize+1):
b_size: int = 2**log2_bsize
logger.info(f"Batch size: {b_size}")
for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
logger.info(f"Iteration: {iteration_number}")
sample_batch = resnet.get_batch(b_size)
remote_result = client.run_model(resnet.name, sample_batch)
logger.info(client.perf_timer.get_last("total_time"))
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
local_res = pt_model(sample_batch.to(torch_device))
err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
torch.cuda.synchronize()

client.perf_timer.print_timings(to_file=True)
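
The timing helpers that were previously inlined in `ProtoClient` now live in the shared `PerfTimer` utility. Below is a minimal sketch of the call pattern used above; the method names and constructor arguments are taken from this diff, while the label strings and the comments describing each call are inferred from the surrounding code rather than from PerfTimer's documentation.

```python
from smartsim._core.utils.timings import PerfTimer

# Constructor arguments mirror the diff above (debug, timing_on, prefix).
perf_timer = PerfTimer(debug=False, timing_on=True, prefix="a0_")

perf_timer.start_timings("batch_size", 32)   # open a timed iteration, recording the batch size
# ... build and send the request ...
perf_timer.measure_time("build_request")     # time elapsed since the previous mark
# ... receive and deserialize the response ...
perf_timer.measure_time("deserialize_response")
perf_timer.end_timings()                     # close the iteration, recording total_time

print(perf_timer.get_last("total_time"))     # most recent value recorded for a label
perf_timer.print_timings(to_file=True)       # dump all collected timings, also to file
```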
28 changes: 15 additions & 13 deletions ex/high_throughput_inference/mock_app_redis.py
@@ -29,7 +29,9 @@
import numpy
import time
import torch
from mpi4py import MPI
from smartsim.log import get_logger
from smartsim._core.utils.timings import PerfTimer
from smartredis import Client

logger = get_logger("App")
@@ -56,6 +58,9 @@ def name(self):

if __name__ == "__main__":

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

parser = argparse.ArgumentParser("Mock application")
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
@@ -65,24 +70,21 @@ def name(self):
client = Client(cluster=False, address=None)
client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper())

perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"redis{rank}_")

total_iterations = 100
timings=[]
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
logger.info(f"Batch size: {batch_size}")
for iteration_number in range(total_iterations + int(batch_size==1)):
timing = [batch_size]
perf_timer.start_timings("batch_size", batch_size)
logger.info(f"Iteration: {iteration_number}")
start = time.perf_counter()
client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy())
client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"])
result = client.get_tensor(name="result")
end = time.perf_counter()
timing.append(end-start)
timings.append(timing)

input_name = f"batch_{rank}"
output_name = f"result_{rank}"
client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy())
client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name])
result = client.get_tensor(name=output_name)
perf_timer.end_timings()


timings_np = numpy.asarray(timings)
numpy.save("timings.npy", timings_np)
for timing in timings:
print(" ".join(str(t) for t in timing))
perf_timer.print_timings(True)
15 changes: 8 additions & 7 deletions ex/high_throughput_inference/redis_driver.py
@@ -29,23 +29,24 @@
from smartsim import Experiment
from smartsim.status import TERMINAL_STATUSES
import time
import typing as t

device = "gpu"
DEVICE = "gpu"
filedir = os.path.dirname(__file__)
app_script_name = os.path.join(filedir, "mock_app_redis.py")
model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")


exp_path = os.path.join(filedir, "redis_ai")
exp_path = os.path.join(filedir, "redis_ai_multi")
os.makedirs(exp_path, exist_ok=True)
exp = Experiment("redis_ai", launcher="slurm", exp_path=exp_path)
exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path)

db = exp.create_database(interface="hsn0")

app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device])
app_rs = exp.create_run_settings(
sys.executable, exe_args = [app_script_name, "--device", DEVICE]
)
app_rs.set_nodes(1)
app_rs.set_tasks(1)
app_rs.set_tasks(4)
app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
