-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Utilities to set ML models and scripts from driver scripts #185
Changes from 12 commits
a6701c0
2deee1c
ca9c9db
9db1c96
75a1f74
2e31e62
8f6ba4b
004b85e
21bc88a
0d8d7ae
b9f83d1
352f570
1d3672c
f5de152
ec8c918
5cf503c
9b1b681
9c03527
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,8 @@ | |
from pathlib import Path | ||
from subprocess import PIPE, STDOUT | ||
|
||
from smartredis import Client | ||
from smartredis.error import RedisConnectionError | ||
from smartsim._core.utils.network import current_ip | ||
from smartsim.error import SSInternalError | ||
from smartsim.log import get_logger | ||
|
@@ -55,8 +57,107 @@ | |
def handle_signal(signo, frame): | ||
cleanup() | ||
|
||
def launch_db_model(client: Client, db_model: List[str]):
    """Parse options to launch model on local cluster

    The argument list is parsed with ``argparse`` and the model file is
    uploaded through ``client.set_model_from_file``. When
    ``--devices_per_node`` is 1 the model is set once on ``--device``;
    otherwise it is set once per device, on ``<device>:<index>``.

    :param client: SmartRedis client connected to local DB
    :type client: Client
    :param db_model: List of arguments defining the model
    :type db_model: List[str]
    :return: Name of model
    :rtype: str
    """
    parser = argparse.ArgumentParser("Set ML model on DB")
    parser.add_argument("--name", type=str)
    parser.add_argument("--file", type=str)
    parser.add_argument("--backend", type=str)
    parser.add_argument("--device", type=str)
    # Default to a single device so an omitted option does not reach
    # ``range(None)`` below.
    parser.add_argument("--devices_per_node", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=0)
    parser.add_argument("--min_batch_size", type=int, default=0)
    parser.add_argument("--tag", type=str, default="")
    parser.add_argument("--inputs", nargs="+", default=None)
    parser.add_argument("--outputs", nargs="+", default=None)

    # Unused if we use SmartRedis
    parser.add_argument("--min_batch_timeout", type=int, default=None)
    args = parser.parse_args(db_model)

    # Always bind both names: they are passed to the client even when the
    # corresponding option was not supplied.
    inputs = list(args.inputs) if args.inputs else None
    outputs = list(args.outputs) if args.outputs else None

    if args.devices_per_node == 1:
        devices = [args.device]
    else:
        devices = [
            args.device + f":{device_num}"
            for device_num in range(args.devices_per_node)
        ]

    for device in devices:
        client.set_model_from_file(
            args.name,
            args.file,
            args.backend,
            device,
            args.batch_size,
            args.min_batch_size,
            args.tag,
            inputs,
            outputs,
        )

    return args.name
|
||
def launch_db_script(client: Client, db_script: List[str]):
    """Parse options to launch script on local cluster

    The argument list is parsed with ``argparse``. If ``--func`` is given,
    its escaped ``\\n`` sequences are turned into real newlines and the
    source is uploaded with ``client.set_script``; otherwise, if ``--file``
    is given, the file is uploaded with ``client.set_script_from_file``.
    When ``--devices_per_node`` is 1 the script is set once on ``--device``;
    otherwise it is set once per device, on ``<device>:<index>``.

    :param client: SmartRedis client connected to local DB
    :type client: Client
    :param db_script: List of arguments defining the script
    :type db_script: List[str]
    :return: Name of script
    :rtype: str
    """
    parser = argparse.ArgumentParser("Set script on DB")
    parser.add_argument("--name", type=str)
    parser.add_argument("--func", type=str)
    parser.add_argument("--file", type=str)
    parser.add_argument("--backend", type=str)
    parser.add_argument("--device", type=str)
    # Default to a single device so an omitted option does not reach
    # ``range(None)`` below.
    parser.add_argument("--devices_per_node", type=int, default=1)
    args = parser.parse_args(db_script)

    # Neither an inline function nor a file was supplied: nothing to upload.
    if not (args.func or args.file):
        return args.name

    if args.devices_per_node == 1:
        devices = [args.device]
    else:
        devices = [
            args.device + f":{device_num}"
            for device_num in range(args.devices_per_node)
        ]

    # NOTE(review): errors raised by the client calls below are not handled
    # here and propagate to the caller. A reviewer asked what happens when
    # setting a bad script/model fails; the author expects the caller's
    # try/except to cover it -- confirm, and consider adding a test.
    if args.func:
        # The function source arrives with escaped newlines on the command
        # line; restore them before handing it to the client.
        func = args.func.replace("\\n", "\n")
        for device in devices:
            client.set_script(args.name, func, device)
    else:
        for device in devices:
            client.set_script_from_file(args.name, args.file, device)

    return args.name
|
||
|
||
def main(network_interface: str, db_cpus: int, command: List[str], db_models: List[List[str]], db_scripts: List[List[str]]): | ||
global DBPID | ||
|
||
try: | ||
|
@@ -102,6 +203,23 @@ def main(network_interface: str, db_cpus: int, command: List[str]): | |
f"\tCommand: {' '.join(cmd)}\n\n" | ||
))) | ||
|
||
if db_models or db_scripts: | ||
try: | ||
client = Client(cluster=False) | ||
for i, db_model in enumerate(db_models): | ||
logger.debug("Uploading model") | ||
model_name = launch_db_model(client, db_model) | ||
logger.debug(f"Added model {model_name} ({i+1}/{len(db_models)})") | ||
for i, db_script in enumerate(db_scripts): | ||
logger.debug("Uploading script") | ||
script_name = launch_db_script(client, db_script) | ||
logger.debug(f"Added script {script_name} ({i+1}/{len(db_scripts)})") | ||
# Make sure we don't keep this around | ||
del client | ||
except RedisConnectionError: | ||
raise SSInternalError("Failed to set model or script, could not connect to database") | ||
|
||
|
||
for line in iter(p.stdout.readline, b""): | ||
print(line.decode("utf-8").rstrip(), flush=True) | ||
|
||
|
@@ -144,6 +262,8 @@ def cleanup(): | |
parser.add_argument("+lockfile", type=str, help="Filename to create for single proc per host") | ||
parser.add_argument("+db_cpus", type=int, default=2, help="Number of CPUs to use for DB") | ||
parser.add_argument("+command", nargs="+", help="Command to run") | ||
parser.add_argument("+db_model", nargs="+", action="append", default=[], help="Model to set on DB") | ||
parser.add_argument("+db_script", nargs="+", action="append", default=[], help="Script to set on DB") | ||
args = parser.parse_args() | ||
|
||
tmp_lockfile = Path(tempfile.gettempdir()) / args.lockfile | ||
|
@@ -160,7 +280,7 @@ def cleanup(): | |
for sig in SIGNALS: | ||
signal.signal(sig, handle_signal) | ||
|
||
main(args.ifname, args.db_cpus, args.command) | ||
main(args.ifname, args.db_cpus, args.command, args.db_model, args.db_script) | ||
|
||
# gracefully exit the processes in the distributed application that | ||
# we do not want to have start a colocated process. Only one process | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .helpers import colorize, delete_elements, init_default, installed_redisai_backends | ||
from .redis import check_cluster_status, create_cluster | ||
from .redis import check_cluster_status, create_cluster, db_is_active |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comments here to explain what's being done. I love the operator usage, but it's not readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, done.