Utilities to set ML models and scripts from driver scripts #185

Merged
merged 18 commits on May 11, 2022
Changes from 12 commits
52 changes: 50 additions & 2 deletions smartsim/_core/control/controller.py
@@ -30,6 +30,7 @@
import threading
import time

from ..._core.utils.redis import db_is_active, set_ml_model, set_script
from ...database import Orchestrator
from ...entity import DBNode, EntityList, SmartSimEntity
from ...error import LauncherError, SmartSimError, SSInternalError, SSUnsupportedError
@@ -40,6 +41,10 @@
from ..utils import check_cluster_status, create_cluster
from .jobmanager import JobManager

from smartredis import Client
from smartredis.error import RedisConnectionError


logger = get_logger(__name__)

# job manager lock
@@ -286,9 +291,12 @@ def _launch(self, manifest):
raise SmartSimError(msg)
self._launch_orchestrator(orchestrator)

for rc in manifest.ray_clusters:
for rc in manifest.ray_clusters: # cov-wlm
rc._update_workers()

if self.orchestrator_active:
self._set_dbobjects(manifest)

# create all steps prior to launch
steps = []
all_entity_lists = manifest.ensembles + manifest.ray_clusters
@@ -297,7 +305,7 @@ def _launch(self, manifest):
batch_step = self._create_batch_job_step(elist)
steps.append((batch_step, elist))
else:
# if ensemble is to be run as seperate job steps, aka not in a batch
# if ensemble is to be run as separate job steps, aka not in a batch
job_steps = [(self._create_job_step(e), e) for e in elist.entities]
steps.extend(job_steps)

@@ -586,3 +594,43 @@ def reload_saved_db(self, checkpoint_file):
finally:
JM_LOCK.release()


def _set_dbobjects(self, manifest):
if not manifest.has_db_objects:
return

db_addresses = self._jobs.get_db_host_addresses()

hosts = list(set([address.split(":")[0] for address in db_addresses]))
ports = list(set([address.split(":")[-1] for address in db_addresses]))

if not db_is_active(hosts=hosts,
ports=ports,
num_shards=len(db_addresses)):
raise SSInternalError("Cannot set DB Objects, DB is not running")

client = Client(address=db_addresses[0], cluster=len(db_addresses) > 1)

for model in manifest.models:
if not model.colocated:
for db_model in model._db_models:
set_ml_model(db_model, client)
for db_script in model._db_scripts:
set_script(db_script, client)

for ensemble in manifest.ensembles:
for db_model in ensemble._db_models:
set_ml_model(db_model, client)
for db_script in ensemble._db_scripts:
set_script(db_script, client)
for entity in ensemble:
if not entity.colocated:
# Set models that belong only to the entities
# and not to the ensemble itself, but avoid
# setting duplicates
for db_model in entity._db_models:
if db_model not in ensemble._db_models:
set_ml_model(db_model, client)
for db_script in entity._db_scripts:
if db_script not in ensemble._db_scripts:
set_script(db_script, client)
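
For context, here is a minimal driver-script sketch of how these DB objects would reach a standalone Orchestrator and be picked up by _set_dbobjects(). It assumes the entity-side helpers introduced elsewhere in this PR are named add_ml_model/add_script; the names, signatures, and file paths below are illustrative, not taken from this diff.

from smartsim import Experiment

exp = Experiment("inference-example", launcher="local")

# Standalone (non-colocated) Orchestrator
db = exp.create_database(port=6780, interface="lo")

rs = exp.create_run_settings("python", "worker.py")
model = exp.create_model("worker", rs)

# Attach a serialized TorchScript file to the entity; once the orchestrator
# is active, _set_dbobjects() uploads it before the worker step launches.
model.add_ml_model("cnn", backend="TORCH", model_path="cnn.pt", device="CPU")

exp.start(db)
exp.start(model, block=True)
exp.stop(db)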
32 changes: 32 additions & 0 deletions smartsim/_core/control/manifest.py
@@ -194,3 +194,35 @@ def __str__(self):

s += "\n"
return s

@property
def has_db_objects(self):
"""Check if any entity has DBObjects to set
"""
def has_db_models(entity):
if hasattr(entity, "_db_models"):
return len(entity._db_models) > 0
def has_db_scripts(entity):
if hasattr(entity, "_db_scripts"):
return len(entity._db_scripts) > 0

has_db_objects = False
for model in self.models:
has_db_objects |= hasattr(model, "_db_models")
has_db_objects |= any([has_db_models(model) | has_db_scripts(model) for model in self.models])
if has_db_objects:
return True

ensembles = self.ensembles
if not ensembles:
return False

has_db_objects |= any([has_db_models(ensemble) | has_db_scripts(ensemble) for ensemble in ensembles])
Contributor: add comments here to explain what's being done. Love the operator usage but it's not readable.

Collaborator Author: Yep, done.

if has_db_objects:
return True
for ensemble in ensembles:
has_db_objects |= any([has_db_models(model) | has_db_scripts(model) for model in ensemble])
if has_db_objects:
return True

return has_db_objects
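
To illustrate the change requested in the review comment above, an equivalent but more explicit form of this check could look like the following. This is a readability sketch only; the commented version actually committed may differ.

@property
def has_db_objects(self):
    """Check if any entity has DBObjects to set"""
    def has_objects(entity):
        # An entity carries DB objects if it owns at least one model or script
        return len(getattr(entity, "_db_models", [])) > 0 or \
               len(getattr(entity, "_db_scripts", [])) > 0

    # Standalone models first
    if any(has_objects(model) for model in self.models):
        return True

    # Then each ensemble itself, and finally the entities inside it
    for ensemble in self.ensembles:
        if has_objects(ensemble) or any(has_objects(entity) for entity in ensemble):
            return True

    return False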
124 changes: 122 additions & 2 deletions smartsim/_core/entrypoints/colocated.py
@@ -37,6 +37,8 @@
from pathlib import Path
from subprocess import PIPE, STDOUT

from smartredis import Client
from smartredis.error import RedisConnectionError
from smartsim._core.utils.network import current_ip
from smartsim.error import SSInternalError
from smartsim.log import get_logger
@@ -55,8 +57,107 @@
def handle_signal(signo, frame):
cleanup()

def launch_db_model(client: Client, db_model: List[str]):
"""Parse options to launch model on local cluster

def main(network_interface: str, db_cpus: int, command: List[str]):
:param client: SmartRedis client connected to local DB
:type client: Client
:param db_model: List of arguments defining the model
:type db_model: List[str]
:return: Name of model
:rtype: str
"""
parser = argparse.ArgumentParser("Set ML model on DB")
parser.add_argument("--name", type=str)
parser.add_argument("--file", type=str)
parser.add_argument("--backend", type=str)
parser.add_argument("--device", type=str)
parser.add_argument("--devices_per_node", type=int)
parser.add_argument("--batch_size", type=int, default=0)
parser.add_argument("--min_batch_size", type=int, default=0)
parser.add_argument("--tag", type=str, default="")
parser.add_argument("--inputs", nargs="+", default=None)
parser.add_argument("--outputs", nargs="+", default=None)

# Unused if we use SmartRedis
parser.add_argument("--min_batch_timeout", type=int, default=None)
args = parser.parse_args(db_model)

if args.inputs:
inputs = list(args.inputs)
if args.outputs:
outputs = list(args.outputs)

if args.devices_per_node == 1:
client.set_model_from_file(args.name,
args.file,
args.backend,
args.device,
args.batch_size,
args.min_batch_size,
args.tag,
inputs,
outputs)
else:
for device_num in range(args.devices_per_node):
client.set_model_from_file(args.name,
args.file,
args.backend,
args.device+f":{device_num}",
args.batch_size,
args.min_batch_size,
args.tag,
inputs,
outputs)

return args.name
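
For reference, each db_model entry handed to launch_db_model() is the argument list that _build_db_model_cmd() (in smartsim/_core/launcher/colocated.py below) serializes after a +db_model flag. An illustrative entry, with a hypothetical model name and path:

# Illustrative only: one element of the db_models list passed to main()
db_model = [
    "--name=cnn",
    "--file=/path/to/cnn.pt",
    "--backend=TORCH",
    "--device=GPU",
    "--devices_per_node=2",
    "--batch_size=32",
    "--inputs", "args_0",
    "--outputs", "output_0",
]
model_name = launch_db_model(client, db_model)  # sets the model on GPU:0 and GPU:1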

def launch_db_script(client: Client, db_script: List[str]):
"""Parse options to launch script on local cluster

:param client: SmartRedis client connected to local DB
:type client: Client
:param db_script: List of arguments defining the script
:type db_script: List[str]
:return: Name of script
:rtype: str
"""
parser = argparse.ArgumentParser("Set script on DB")
parser.add_argument("--name", type=str)
parser.add_argument("--func", type=str)
parser.add_argument("--file", type=str)
parser.add_argument("--backend", type=str)
parser.add_argument("--device", type=str)
parser.add_argument("--devices_per_node", type=int)
args = parser.parse_args(db_script)
if args.func:
func = args.func.replace("\\n", "\n")

if args.devices_per_node == 1:
client.set_script(args.name,
func,
args.device)
else:
for device_num in range(args.devices_per_node):
client.set_script(args.name,
func,
args.device+f":{device_num}")
elif args.file:
if args.devices_per_node == 1:
Contributor: What happens if these fail? Have we tried setting a bad model?

Collaborator Author: The functions are launched in a try-catch statement, I think it should catch such an exception. Do you think we should add a test for that case?

client.set_script_from_file(args.name,
args.file,
args.device)
else:
for device_num in range(args.devices_per_node):
client.set_script_from_file(args.name,
args.file,
args.device+f":{device_num}")


return args.name


def main(network_interface: str, db_cpus: int, command: List[str], db_models: List[List[str]], db_scripts: List[List[str]]):
global DBPID

try:
@@ -102,6 +203,23 @@ def main(network_interface: str, db_cpus: int, command: List[str]):
f"\tCommand: {' '.join(cmd)}\n\n"
)))

if db_models or db_scripts:
try:
client = Client(cluster=False)
for i, db_model in enumerate(db_models):
logger.debug("Uploading model")
model_name = launch_db_model(client, db_model)
logger.debug(f"Added model {model_name} ({i+1}/{len(db_models)})")
for i, db_script in enumerate(db_scripts):
logger.debug("Uploading script")
script_name = launch_db_script(client, db_script)
logger.debug(f"Added script {script_name} ({i+1}/{len(db_scripts)})")
# Make sure we don't keep this around
del client
except RedisConnectionError:
raise SSInternalError("Failed to set model or script, could not connect to database")


for line in iter(p.stdout.readline, b""):
print(line.decode("utf-8").rstrip(), flush=True)

Expand Down Expand Up @@ -144,6 +262,8 @@ def cleanup():
parser.add_argument("+lockfile", type=str, help="Filename to create for single proc per host")
parser.add_argument("+db_cpus", type=int, default=2, help="Number of CPUs to use for DB")
parser.add_argument("+command", nargs="+", help="Command to run")
parser.add_argument("+db_model", nargs="+", action="append", default=[], help="Model to set on DB")
parser.add_argument("+db_script", nargs="+", action="append", default=[], help="Script to set on DB")
args = parser.parse_args()

tmp_lockfile = Path(tempfile.gettempdir()) / args.lockfile
@@ -160,7 +280,7 @@ def cleanup():
for sig in SIGNALS:
signal.signal(sig, handle_signal)

main(args.ifname, args.db_cpus, args.command)
main(args.ifname, args.db_cpus, args.command, args.db_model, args.db_script)

# gracefully exit the processes in the distributed application that
# we do not want to have start a colocated process. Only one process
68 changes: 66 additions & 2 deletions smartsim/_core/launcher/colocated.py
@@ -25,7 +25,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys

from ..config import CONFIG
from ...error import SSUnsupportedError
from ..utils.helpers import create_lockfile_name


@@ -65,6 +67,7 @@ def write_colocated_launch_script(file_name, db_log, colocated_settings):

f.write(f"{colocated_cmd}\n")
f.write(f"DBPID=$!\n\n")

if colocated_settings["limit_app_cpus"]:
cpus = colocated_settings["cpus"]
f.write(
@@ -129,7 +132,7 @@ def _build_colocated_wrapper_cmd(port=6780,
# add extra redisAI configurations
for arg, value in rai_args.items():
if value:
# RAI wants arguments for inference in all capps
# RAI wants arguments for inference in all caps
# ex. THREADS_PER_QUEUE=1
db_cmd.append(f"{arg.upper()} {str(value)}")

@@ -142,17 +145,78 @@
])
for db_arg, value in extra_db_args.items():
# replace "_" with "-" in the db_arg because we use kwargs
# for the extra configurations and Python doesn't allow a hypon
# for the extra configurations and Python doesn't allow a hyphen
# in a variable name. All redis and KeyDB configuration options
# use hyphens in their names.
db_arg = db_arg.replace("_", "-")
db_cmd.extend([
f"--{db_arg}",
value
])

db_models = kwargs.get("db_models", None)
if db_models:
db_model_cmd = _build_db_model_cmd(db_models)
db_cmd.extend(db_model_cmd)

db_scripts = kwargs.get("db_scripts", None)
if db_scripts:
db_script_cmd = _build_db_script_cmd(db_scripts)
db_cmd.extend(db_script_cmd)

# run colocated db in the background
db_cmd.append("&")

cmd.extend(db_cmd)
return " ".join(cmd)


def _build_db_model_cmd(db_models):
cmd = []
for db_model in db_models:
cmd.append("+db_model")
cmd.append(f"--name={db_model.name}")

# Here db_model.file is guaranteed to exist
# because we don't allow the user to pass a serialized DBModel
cmd.append(f"--file={db_model.file}")

cmd.append(f"--backend={db_model.backend}")
cmd.append(f"--device={db_model.device}")
cmd.append(f"--devices_per_node={db_model.devices_per_node}")
if db_model.batch_size:
cmd.append(f"--batch_size={db_model.batch_size}")
if db_model.min_batch_size:
cmd.append(f"--min_batch_size={db_model.min_batch_size}")
if db_model.min_batch_timeout:
cmd.append(f"--min_batch_timeout={db_model.min_batch_timeout}")
if db_model.tag:
cmd.append(f"--tag={db_model.tag}")
if db_model.inputs:
cmd.append("--inputs="+",".join(db_model.inputs))
if db_model.outputs:
cmd.append("--outputs="+",".join(db_model.outputs))

return cmd
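
As a concrete illustration (hypothetical values), a DBModel attached to a colocated entity contributes a fragment like the following to the wrapper command, which the colocated entrypoint above then re-parses with launch_db_model:

+db_model --name=cnn --file=/path/to/cnn.pt --backend=TORCH --device=GPU --devices_per_node=2 --batch_size=32 --inputs=args_0 --outputs=output_0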


def _build_db_script_cmd(db_scripts):
cmd = []
for db_script in db_scripts:
cmd.append("+db_script")
cmd.append(f"--name={db_script.name}")
if db_script.func:
# Notice that here db_script.func is guaranteed to be a str
# because we don't allow the user to pass a serialized function
sanitized_func = db_script.func.replace("\n", "\\n")
if not (sanitized_func.startswith("'") and sanitized_func.endswith("'")
or (sanitized_func.startswith('"') and sanitized_func.endswith('"'))):
sanitized_func = "\"" + sanitized_func + "\""
cmd.append(f"--func={sanitized_func}")
elif db_script.file:
cmd.append(f"--file={db_script.file}")
cmd.append(f"--device={db_script.device}")
cmd.append(f"--devices_per_node={db_script.devices_per_node}")

return cmd

2 changes: 1 addition & 1 deletion smartsim/_core/utils/__init__.py
@@ -1,2 +1,2 @@
from .helpers import colorize, delete_elements, init_default, installed_redisai_backends
from .redis import check_cluster_status, create_cluster
from .redis import check_cluster_status, create_cluster, db_is_active