Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FeatureStore descriptor to tensor & model keys #633

Merged
merged 49 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
f85db2a
Add file system descriptor to tensor & model keys
ankona Jul 18, 2024
f252806
post-merge tweaks
ankona Jul 18, 2024
09eff20
update upstream tests
ankona Jul 19, 2024
2cedfb3
dynamic fs attachment, add backbone to ML worker mgr & config loader
ankona Jul 25, 2024
24df7da
add missing from_descriptor methods,
ankona Jul 29, 2024
82fb67a
fix
ankona Jul 29, 2024
65cf4d1
fix env loader tests
ankona Jul 29, 2024
15806fe
move import below conditional
ankona Jul 29, 2024
cb962be
sort imports for dragon
ankona Jul 29, 2024
36883c9
fix feature store type interleaving bug
ankona Jul 29, 2024
2e9f146
isort
ankona Jul 29, 2024
e011b70
fix test failing new validation check
ankona Jul 29, 2024
e6dae22
revert gh workflow changes that will be merged later
ankona Jul 31, 2024
4548fec
add missing docstrings, remove commented parameters
ankona Jul 31, 2024
c29dc6b
docstring
ankona Jul 31, 2024
24cbef2
remove commented out imports
ankona Jul 31, 2024
4eb29b9
remove commented out code
ankona Jul 31, 2024
eb793b6
improve documentation on purpose of backbone fs
ankona Jul 31, 2024
318deac
improve documentation about backbone usage
ankona Jul 31, 2024
0eac344
remove deprecated & add missing docstring params
ankona Jul 31, 2024
d3b9512
fix renamed param in docstring
ankona Jul 31, 2024
6e387e8
remove commented lines
ankona Jul 31, 2024
a89f160
remove commented lines
ankona Jul 31, 2024
73c7f9b
formatting
ankona Jul 31, 2024
a5bda09
revert dupe change from upstream
ankona Jul 31, 2024
d50a540
fix confusing docstring
ankona Aug 1, 2024
86b4c2e
fix incomplete docstrings, tweak logs
ankona Aug 1, 2024
2346483
docstring fix
ankona Aug 1, 2024
d419465
validate & report env config loader attempts to call factories
ankona Aug 1, 2024
85a6ee0
report validation failures in MLI pipeline through callback
ankona Aug 1, 2024
d9a30d7
fix removal of early return on empty descriptors
ankona Aug 2, 2024
5f9c727
format docstrings to render correctly
ankona Aug 2, 2024
3fd5ed1
rename backbone env var
ankona Aug 2, 2024
6d4f2e0
debug descriptor failure on build agent
ankona Aug 2, 2024
c75dc5a
download and log original asset name on `smart build --dragon`
ankona Aug 2, 2024
ee07d94
test
ankona Aug 2, 2024
a269186
test
ankona Aug 3, 2024
e75a18f
remove test_worker_Manager
ankona Aug 3, 2024
125dc84
add test_worker_manager back into test set
ankona Aug 5, 2024
783294a
rename SSQueue env var to SS_QUEUE
ankona Aug 6, 2024
9bce16a
remove commented code, rename variable for clarity
ankona Aug 6, 2024
446d000
rename ss_queue -> ss_request_queue
ankona Aug 6, 2024
989db29
replaced log.warning w/log.error on missing components
ankona Aug 6, 2024
3e8d6eb
improve DragonFeatureStore docstrings
ankona Aug 6, 2024
0344398
ensure KeyError is logged
ankona Aug 6, 2024
d040e28
move dragon-based test into correct subdir
ankona Aug 6, 2024
5645f79
formatting fix
ankona Aug 6, 2024
1e74315
remove asset URL overrides
ankona Aug 7, 2024
6097c46
remove usage of deprecated dragon policy affinity
ankona Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Jump to:

Description

- Enable dynamic feature store selection
- Fix dragon package installation bug
- Adjust schemas for better performance
- Add TorchWorker first implementation and mock inference app example
Expand Down
21 changes: 16 additions & 5 deletions ex/high_throughput_inference/mli_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


import argparse
import os
import base64
import cloudpickle
Expand All @@ -26,11 +25,23 @@

torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")

worker_manager_rs = exp.create_run_settings(sys.executable, [worker_manager_script_name, "--device", device, "--worker_class", torch_worker_str])
worker_manager_rs = exp.create_run_settings(
sys.executable,
[
worker_manager_script_name,
"--device",
device,
"--worker_class",
torch_worker_str,
],
)
worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])

app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device])
app_rs = exp.create_run_settings(
sys.executable,
exe_args=[app_script_name, "--device", device],
)
app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])

Expand All @@ -47,4 +58,4 @@
break
time.sleep(5)

print("Exiting.")
print("Exiting.")
42 changes: 28 additions & 14 deletions ex/high_throughput_inference/mock_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,21 @@
import numbers

from collections import OrderedDict
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
DragonFeatureStore,
)
from smartsim._core.mli.message_handler import MessageHandler
from smartsim.log import get_logger

logger = get_logger("App")


class ProtoClient:
def __init__(self, timing_on: bool):
connect_to_infrastructure()
ddict_str = os.environ["SS_DRG_DDICT"]
ddict_str = os.environ["SS_INFRA_BACKBONE"]
self._ddict = DDict.attach(ddict_str)
self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor
to_worker_fli_str = None
while to_worker_fli_str is None:
try:
Expand Down Expand Up @@ -88,39 +93,45 @@ def start_timings(self, batch_size: int):
def end_timings(self):
if self._timing_on:
self._add_label_to_timings("total_time")
self._timings["total_time"].append(self._format_number(time.perf_counter()-self._start))
self._timings["total_time"].append(
self._format_number(time.perf_counter() - self._start)
)

def measure_time(self, label: str):
if self._timing_on:
self._add_label_to_timings(label)
self._timings[label].append(self._format_number(time.perf_counter()-self._interm))
self._timings[label].append(
self._format_number(time.perf_counter() - self._interm)
)
self._interm = time.perf_counter()

def print_timings(self, to_file: bool = False):
print(" ".join(self._timings.keys()))
value_array = numpy.array([value for value in self._timings.values()], dtype=float)
value_array = numpy.array(
[value for value in self._timings.values()], dtype=float
)
value_array = numpy.transpose(value_array)
for i in range(value_array.shape[0]):
print(" ".join(self._format_number(value) for value in value_array[i]))
if to_file:
numpy.save("timings.npy", value_array)
numpy.savetxt("timings.txt", value_array)


def run_model(self, model: bytes | str, batch: torch.Tensor):
tensors = [batch.numpy()]
self.start_timings(batch.shape[0])
built_tensor_desc = MessageHandler.build_tensor_descriptor(
"c", "float32", list(batch.shape))
"c", "float32", list(batch.shape)
)
self.measure_time("build_tensor_descriptor")
built_model = None
if isinstance(model, str):
model_arg = MessageHandler.build_model_key(model)
model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
else:
model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
request = MessageHandler.build_request(
reply_channel=self._from_worker_ch_serialized,
model= model_arg,
model=model_arg,
inputs=[built_tensor_desc],
outputs=[],
output_descriptors=[],
Expand All @@ -129,10 +140,12 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
self.measure_time("build_request")
request_bytes = MessageHandler.serialize_request(request)
self.measure_time("serialize_request")
with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
with self._to_worker_fli.sendh(
timeout=None, stream_channel=self._to_worker_ch
) as to_sendh:
to_sendh.send_bytes(request_bytes)
for t in tensors:
to_sendh.send_bytes(t.tobytes()) #TODO NOT FAST ENOUGH!!!
to_sendh.send_bytes(t.tobytes()) # TODO NOT FAST ENOUGH!!!
# to_sendh.send_bytes(bytes(t.data))
logger.info(f"Message size: {len(request_bytes)} bytes")

Expand All @@ -159,7 +172,7 @@ def set_model(self, key: str, model: bytes):
self._ddict[key] = model


class ResNetWrapper():
class ResNetWrapper:
def __init__(self, name: str, model: str):
self._model = torch.jit.load(model)
self._name = name
Expand All @@ -168,7 +181,7 @@ def __init__(self, name: str, model: str):
torch.jit.save(scripted, buffer)
self._serialized_model = buffer.getvalue()

def get_batch(self, batch_size: int=32):
def get_batch(self, batch_size: int = 32):
return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)

@property
Expand All @@ -179,6 +192,7 @@ def model(self):
def name(self):
return self._name


if __name__ == "__main__":

parser = argparse.ArgumentParser("Mock application")
Expand All @@ -194,8 +208,8 @@ def name(self):

for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
logger.info(f"Batch size: {batch_size}")
for iteration_number in range(total_iterations + int(batch_size==1)):
for iteration_number in range(total_iterations + int(batch_size == 1)):
logger.info(f"Iteration: {iteration_number}")
client.run_model(resnet.name, resnet.get_batch(batch_size))

client.print_timings(to_file=True)
client.print_timings(to_file=True)
29 changes: 16 additions & 13 deletions ex/high_throughput_inference/standalone_workermanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@
from dragon.data.ddict.ddict import DDict
from dragon.utils import b64decode, b64encode
from dragon.globalservices.api_setup import connect_to_infrastructure

# isort: on
import argparse
import base64
import cloudpickle
import pickle
import optparse
import os

from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import DragonFeatureStore
from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
DragonFeatureStore,
)
from smartsim._core.mli.infrastructure.control.workermanager import WorkerManager
from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader

Expand All @@ -67,30 +69,31 @@

args = parser.parse_args()
connect_to_infrastructure()
ddict_str = os.environ["SS_DRG_DDICT"]
ddict_str = os.environ["SS_INFRA_BACKBONE"]
ddict = DDict.attach(ddict_str)

to_worker_channel = Channel.make_process_local()
to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
to_worker_fli_serialized = to_worker_fli.serialize()
ddict["to_worker_fli"] = to_worker_fli_serialized

torch_worker = cloudpickle.loads(base64.b64decode(args.worker_class.encode('ascii')))()

dfs = DragonFeatureStore(ddict)
comm_channel = DragonFLIChannel(to_worker_fli_serialized)
worker_type_name = base64.b64decode(args.worker_class.encode("ascii"))
torch_worker = cloudpickle.loads(worker_type_name)()

os.environ["SSFeatureStore"] = base64.b64encode(pickle.dumps(dfs)).decode("utf-8")
os.environ["SSQueue"] = base64.b64encode(to_worker_fli_serialized).decode("utf-8")
descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8")
os.environ["SS_REQUEST_QUEUE"] = descriptor

config_loader = EnvironmentConfigLoader()
config_loader = EnvironmentConfigLoader(
featurestore_factory=DragonFeatureStore.from_descriptor,
callback_factory=DragonCommChannel,
queue_factory=DragonFLIChannel.from_descriptor,
)

worker_manager = WorkerManager(
config_loader=config_loader,
worker=torch_worker,
as_service=True,
cooldown=10,
comm_channel_type=DragonCommChannel,
device = args.device,
device=args.device,
)
worker_manager.execute()
37 changes: 33 additions & 4 deletions smartsim/_core/_cli/scripts/dragon_install.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import pathlib
import shutil
import sys
import typing as t
from urllib.request import urlretrieve

from github import Github
from github.GitReleaseAsset import GitReleaseAsset
Expand Down Expand Up @@ -160,13 +162,40 @@ def retrieve_asset(working_dir: pathlib.Path, asset: GitReleaseAsset) -> pathlib

# if we've previously downloaded the release and still have
# wheels laying around, use that cached version instead
if download_dir.exists() and list(download_dir.rglob("*.whl")):
return download_dir
if download_dir.exists() or list(download_dir.rglob("*.whl")):
# return download_dir
shutil.rmtree(str(download_dir))

download_dir.mkdir(parents=True, exist_ok=True)

# grab a copy of the complete asset
asset_path = download_dir / str(asset.name)
download_url = asset.browser_download_url
if "0.91" not in asset.name:
if "3.9" in python_version():
logger.debug("I want to snake the original w/3.9 rpm")
# download_url = "https://arti.hpc.amslabs.hpecorp.net/ui/native/dragon-rpm-master-local/dev/master/sle15_sp3_pe/x86_64/dragon-0.91-py3.11.5-1d600977c.rpm"
... # temp no-op
elif "3.10" in python_version():
logger.debug("snaking original w/3.10 rpm")
download_url = "https://drive.usercontent.google.com/download?id=1dyScGNomzoPO8-bC8i6zaIbOOhsL83Sp&export=download&authuser=0&confirm=t&uuid=6068afeb-14fd-4303-90a5-498b316d3cce&at=APZUnTWTIf9Tl7Yt8tcdKyodnydV:1722641072921"
elif "3.11" in python_version():
logger.debug("snaking original w/3.11rpm")
download_url = "https://drive.usercontent.google.com/download?id=1vhUXLIu06-RPA_N3wWmi42avnawzizZZ&export=download&authuser=0&confirm=t&uuid=04c920cb-2e66-4762-8e0f-8ad57e0cbbdf&at=APZUnTUKtCv_BgYOkWAaHqoPpGLd:1722640947383"
else:
logger.debug(f"the name was: {asset.name}")

archive = WebTGZ(asset.browser_download_url)
try:
urlretrieve(download_url, str(asset_path))
logger.debug(f"Retrieved asset {asset.name} from {download_url}")
except Exception:
logger.exception(f"Unable to download asset from: {download_url}")

# extract the asset
archive = WebTGZ(download_url)
archive.extract(download_dir)

logger.debug(f"Retrieved {asset.browser_download_url} to {download_dir}")
logger.debug(f"Extracted {download_url} to {download_dir}")
return download_dir


Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/dragon/dragonBackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _start_steps(self) -> None:
env={
**request.current_env,
**request.env,
"SS_DRG_DDICT": self.infra_ddict,
"SS_INFRA_BACKBONE": self.infra_ddict,
},
stdout=dragon_process.Popen.PIPE,
stderr=dragon_process.Popen.PIPE,
Expand Down
2 changes: 2 additions & 0 deletions smartsim/_core/mli/comm/channel/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ def __init__(self, descriptor: t.Union[str, bytes]) -> None:
@abstractmethod
def send(self, value: bytes) -> None:
"""Send a message through the underlying communication channel

:param value: The value to send"""

@abstractmethod
def recv(self) -> t.List[bytes]:
    """Receive a message through the underlying communication channel

    :returns: the received message"""

@property
Expand Down
17 changes: 17 additions & 0 deletions smartsim/_core/mli/comm/channel/dragonchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import base64
import sys
import typing as t

Expand Down Expand Up @@ -55,7 +56,23 @@ def send(self, value: bytes) -> None:

def recv(self) -> t.List[bytes]:
    """Receive a message through the underlying communication channel

    :returns: the received message"""
    with self._channel.recvh(timeout=None) as recvh:
        message_bytes: bytes = recvh.recv_bytes(timeout=None)
        return [message_bytes]

@classmethod
def from_descriptor(
    cls,
    descriptor: str,
) -> "DragonCommChannel":
    """A factory method that creates an instance from a descriptor string

    :param descriptor: A base64-encoded descriptor that uniquely identifies
        the underlying dragon channel
    :returns: An attached DragonCommChannel
    :raises Exception: re-raises any failure from decoding the descriptor or
        attaching to the channel after logging it"""
    try:
        # use cls (not a hard-coded class name) so subclasses construct correctly
        return cls(base64.b64decode(descriptor))
    except Exception:
        # bare `except:` would also trap SystemExit/KeyboardInterrupt;
        # logger.exception preserves the traceback for diagnosis
        logger.exception(f"Failed to create dragon comm channel: {descriptor}")
        raise
Loading
Loading