Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing robustness on SLURM machines #381

Merged
merged 73 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 71 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
c48db14
Update tests to pass on_wlm
ashao Oct 2, 2023
7e02516
Define make_test_dir and get_test_dir fixtures
al-rigazzi Oct 12, 2023
2f9882b
More permissive naming for caller_function
al-rigazzi Oct 12, 2023
978e177
Style
al-rigazzi Oct 12, 2023
0e25a4a
Update tests to pass on_wlm
ashao Oct 2, 2023
c6930c0
Respond to review feedback
ashao Oct 16, 2023
86e51f8
Modify for mpirun with PBS
ashao Oct 16, 2023
cc8c986
Merge branch 'fix_tests' of https://github.com/ashao/SmartSim into fi…
ashao Oct 16, 2023
89c2008
Fix db shutdown and some fixtures
al-rigazzi Oct 16, 2023
765c571
Update DBModel tests
al-rigazzi Oct 17, 2023
fde1efe
Merge branch 'develop' into test-tmp-dir
al-rigazzi Oct 17, 2023
ea6bac5
Begin adding a context manager for orchestrators in multidb cases
ashao Oct 17, 2023
662e4c5
Merge branch 'test-tmp-dir' of https://github.com/al-rigazzi/SmartSim…
ashao Oct 17, 2023
b46ccb9
More multidb tests wokring, stopping at test_multidb.py::test_multidb…
ashao Oct 18, 2023
30d3fe4
Fix start_in_context
al-rigazzi Oct 18, 2023
131310f
Fix fixture usage
al-rigazzi Oct 18, 2023
dc301bb
Fix get_status
al-rigazzi Oct 18, 2023
380c8ef
Fix mypy issues
al-rigazzi Oct 18, 2023
62b79b0
Fix a couple of tests
ashao Oct 18, 2023
34ee0c3
Merge branch 'fix_tests' of https://github.com/ashao/SmartSim into fi…
ashao Oct 18, 2023
ee39204
tests are passing on PBS
ashao Oct 18, 2023
9f8a623
Fix one last typo
ashao Oct 18, 2023
3769c90
Make reset_hosts work on LSF
al-rigazzi Oct 19, 2023
c959e17
Comply to mypy syntax for union
al-rigazzi Oct 19, 2023
336bebd
Update signatures in conftest.py
al-rigazzi Oct 25, 2023
a343e27
Address reviewer's comments
al-rigazzi Oct 26, 2023
2038471
Fix name collision in FileUtils
al-rigazzi Oct 26, 2023
9ff094a
Fix fixture usage
al-rigazzi Oct 26, 2023
ac89651
Fix lock scope
al-rigazzi Oct 27, 2023
f7be14a
Replace repeated module level function
ashao Oct 30, 2023
e2b9238
Remove extraneous print in add_batch_resources
ashao Oct 30, 2023
aa95d5e
Enforce type for batch resources
ashao Oct 30, 2023
9f4eac0
Reset license text after inadvertent find/replace
ashao Oct 31, 2023
e697f58
Fix SS env vars
al-rigazzi Nov 1, 2023
225b5b2
Merge branch 'develop' of https://github.com/CrayLabs/SmartSim into f…
al-rigazzi Nov 1, 2023
cce73db
Use db_identifier
al-rigazzi Nov 1, 2023
bd3ff35
Disable key prefixing for test colocated entities
al-rigazzi Nov 1, 2023
2eab465
Add test to validate uds socket file name
al-rigazzi Nov 2, 2023
fcdc5db
Remove "test_dir = test_dir"
al-rigazzi Nov 2, 2023
8524819
Fix typehinting
ashao Nov 2, 2023
be38db2
Address feedback from @drozt
ashao Nov 3, 2023
99f47f8
Separate db name and db id
al-rigazzi Nov 3, 2023
a9b3fff
Add test for db ids and names
al-rigazzi Nov 3, 2023
f664477
Fix db node test for local
al-rigazzi Nov 3, 2023
1b29f12
Addresse reviewer's comments
al-rigazzi Nov 9, 2023
bebbb04
Make socket filename unique in tests
al-rigazzi Nov 23, 2023
0b67edd
Fix smartredis test scripts
al-rigazzi Nov 24, 2023
05fe0b2
Make some asserts more helpful
al-rigazzi Nov 24, 2023
e66c65c
Patch TF multigpu tests
al-rigazzi Nov 25, 2023
a3842a4
Add info about num_test_devices
al-rigazzi Nov 25, 2023
05093b7
Add details to failing asserts in test_dbmodel
al-rigazzi Nov 26, 2023
1e101e7
Add mem cap to dataloader tests
al-rigazzi Nov 26, 2023
94d4790
Fix number of devices if not GPU
al-rigazzi Nov 26, 2023
dc78094
Merge branch 'develop' of https://github.com/CrayLabs/SmartSim into f…
al-rigazzi Nov 27, 2023
9af0c75
MyPy
al-rigazzi Nov 27, 2023
c68b2af
Merge branch 'develop' of https://github.com/CrayLabs/SmartSim into f…
al-rigazzi Nov 30, 2023
762db80
Spawn in TF saving/serializing in a new process to avoid a locked GPU
ashao Dec 1, 2023
31520a9
Revert "Spawn in TF saving/serializing in a new process to avoid a lo…
ashao Dec 1, 2023
b703bc9
Simplify the logic in QsubBatchSettings
ashao Dec 7, 2023
48defa2
Delete extraneous scripts
ashao Dec 8, 2023
b929ba8
Refactor QsubBatchSettings resources
ashao Dec 8, 2023
a0d8328
Merge branch 'develop' into fix_tests
al-rigazzi Dec 8, 2023
8e0e82f
Merge branch 'fix_tests' of https://github.com/ashao/SmartSim into fi…
al-rigazzi Dec 8, 2023
0979ced
Merge branch 'develop' into fix_tests
al-rigazzi Dec 8, 2023
2e72604
Delete misleading comment
al-rigazzi Dec 8, 2023
14c420d
Correct type hints and more robust resource validation
ashao Dec 8, 2023
bd4345e
Fix one use of | insteado t.Union
ashao Dec 9, 2023
f2123fa
Fix an incorrect typehint
ashao Dec 9, 2023
5ff7007
Yet another | instead of t.Union
ashao Dec 9, 2023
f485ad1
Remove extraneous assignment and blackify
ashao Dec 9, 2023
2824d69
Remove now invalid test and update type checking
ashao Dec 9, 2023
2358c48
Fix accidental collision with default value
ashao Dec 11, 2023
0663f5d
Update behaviour for test_create_pbs_batch
ashao Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 103 additions & 121 deletions conftest.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions smartsim/_core/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os
import psutil
import typing as t
Expand Down Expand Up @@ -179,6 +180,20 @@ def test_num_gpus(self) -> int: # pragma: no cover
def test_port(self) -> int: # pragma: no cover
return int(os.environ.get("SMARTSIM_TEST_PORT", 6780))

@property
def test_batch_resources(self) -> t.Dict[t.Any,t.Any]: # pragma: no cover
resource_str = os.environ.get("SMARTSIM_TEST_BATCH_RESOURCES", "{}")
resources = json.loads(resource_str)
if not isinstance(resources, dict):
raise TypeError(
(
"SMARTSIM_TEST_BATCH_RESOURCES was not interpreted as a "
"dictionary, check to make sure that it is a valid "
f"JSON string: {resource_str}"
)
)
return resources

@property
def test_interface(self) -> t.List[str]: # pragma: no cover
if interfaces_cfg := os.environ.get("SMARTSIM_TEST_INTERFACE", None):
Expand Down
63 changes: 39 additions & 24 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from __future__ import annotations

import itertools
import os.path as osp
import pathlib
import pickle
Expand All @@ -39,9 +40,16 @@

from smartredis import Client, ConfigOptions

from smartsim._core.utils.network import get_ip_from_host

from ..._core.launcher.step import Step
from ..._core.utils.helpers import unpack_colo_db_identifier, unpack_db_identifier
from ..._core.utils.redis import db_is_active, set_ml_model, set_script, shutdown_db
from ..._core.utils.redis import (
db_is_active,
set_ml_model,
set_script,
shutdown_db_node,
)
from ...database import Orchestrator
from ...entity import (
Ensemble,
Expand Down Expand Up @@ -235,12 +243,22 @@ def stop_db(self, db: Orchestrator) -> None:
if db.batch:
self.stop_entity(db)
else:
shutdown_db(db.hosts, db.ports)
with JM_LOCK:
for entity in db:
job = self._jobs[entity.name]
job.set_status(STATUS_CANCELLED, "", 0, output=None, error=None)
self._jobs.move_to_completed(job)
for node in db.entities:
for host_ip, port in itertools.product(
(get_ip_from_host(host) for host in node.hosts), db.ports
):
Comment on lines +247 to +250
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably out of scope for this PR, but this feels like the intent of this loop could be better represented if it were made a method(s) on the orc/dbnode classes, rather than manually preforming this nested iteration.

What would you think about doing something like:

class Orchestrator:
    def get_running_hosts_and_ports(self):  # TODO: give this method a less dumb name
        return {
            node.get_running_hosts_and_ports()
            for node 
            in itertools.chain.from_iterable(self.entities)}

class DBNode:
    def get_running_hosts_and_ports(self):  # TODO: same here
        return {
            (shard.ip_address, shard.port)
            for shard 
            in self._get_launched_shard_info()}

class LaunchedShardData:
    @property
    def ip_address(self):
        return get_ip_from_host(self.hostname)

?

This way this we could abstract away the nested iteration, and the whole loop would just become:

for host, port in db.get_running_hosts_and_ports():
    retcode, _, _ = shutdown_db_node(host, port)
    ...  # etc.

which I tend to find a hair more readable

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair, but with orchestrator.py being close to 1000 lines, we may have to postpone this change to when the class will be restructured.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^^ Agreed. Let's kick this conversation out into a dedicated DB refactor ticket!

retcode, _, _ = shutdown_db_node(host_ip, port)
# Sometimes the DB will not shutdown (unless we force NOSAVE)
if retcode != 0:
self.stop_entity(node)
continue

job = self._jobs[node.name]
job.set_status(STATUS_CANCELLED, "", 0, output=None, error=None)
self._jobs.move_to_completed(job)

db.reset_hosts()

def stop_entity_list(self, entity_list: EntitySequence[SmartSimEntity]) -> None:
"""Stop an instance of an entity list
Expand Down Expand Up @@ -358,9 +376,9 @@ def _launch(
for orchestrator in manifest.dbs:
for key in self._jobs.get_db_host_addresses():
_, db_id = unpack_db_identifier(key, "_")
if orchestrator.name == db_id:
if orchestrator.db_identifier == db_id:
raise SSDBIDConflictError(
f"Database identifier {orchestrator.name}"
f"Database identifier {orchestrator.db_identifier}"
" has already been used. Pass in a unique"
" name for db_identifier"
)
Expand Down Expand Up @@ -600,30 +618,27 @@ def _prep_entity_client_env(self, entity: Model) -> None:

for db_id, addresses in address_dict.items():
db_name, _ = unpack_db_identifier(db_id, "_")

if addresses:
if len(addresses) <= 128:
client_env[f"SSDB{db_name}"] = ",".join(addresses)
else:
# Cap max length of SSDB
client_env[f"SSDB{db_name}"] = ",".join(addresses[:128])
if entity.incoming_entities:
client_env[f"SSKEYIN{db_name}"] = ",".join(
[in_entity.name for in_entity in entity.incoming_entities]
)
if entity.query_key_prefixing():
client_env[f"SSKEYOUT{db_name}"] = entity.name
# Cap max length of SSDB
client_env[f"SSDB{db_name}"] = ",".join(addresses[:128])

# Retrieve num_shards to append to client env
client_env[f"SR_DB_TYPE{db_name}"] = (
CLUSTERED if len(addresses) > 1 else STANDALONE
)

# Retrieve num_shards to append to client env
client_env[f"SR_DB_TYPE{db_name}"] = (
CLUSTERED if len(addresses) > 1 else STANDALONE
if entity.incoming_entities:
client_env["SSKEYIN"] = ",".join(
[in_entity.name for in_entity in entity.incoming_entities]
)
if entity.query_key_prefixing():
client_env["SSKEYOUT"] = entity.name

# Set address to local if it's a colocated model
if entity.colocated and entity.run_settings.colocated_db_settings is not None:
db_name_colo = entity.run_settings.colocated_db_settings["db_identifier"]

for key in self._jobs.get_db_host_addresses():
for key in address_dict:
_, db_id = unpack_db_identifier(key, "_")
if db_name_colo == db_id:
raise SSDBIDConflictError(
Expand Down
29 changes: 17 additions & 12 deletions smartsim/_core/control/jobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@

from ...database import Orchestrator
from ...entity import DBNode, SmartSimEntity, EntitySequence
from ...error import SmartSimError
from ...log import get_logger
from ...status import TERMINAL_STATUSES
from ...status import TERMINAL_STATUSES, STATUS_NEVER_STARTED
from ..config import CONFIG
from ..launcher import LocalLauncher, Launcher
from ..utils.network import get_ip_from_host
Expand Down Expand Up @@ -160,6 +159,13 @@ def __call__(self) -> t.Dict[str, Job]:
all_jobs = {**self.jobs, **self.db_jobs}
return all_jobs

def __contains__(self, key: str) -> bool:
try:
self[key] # pylint: disable=pointless-statement
return True
except KeyError:
return False

def add_job(
self,
job_name: str,
Expand Down Expand Up @@ -242,17 +248,14 @@ def get_status(
:returns: tuple of status
"""
with self._lock:
try:
if entity.name in self.completed:
return self.completed[entity.name].status
if entity.name in self.completed:
return self.completed[entity.name].status

if entity.name in self:
job: Job = self[entity.name] # locked
except KeyError:
raise SmartSimError(
f"Entity {entity.name} has not been launched in this Experiment"
) from None
return job.status

return job.status
return STATUS_NEVER_STARTED

def set_launcher(self, launcher: Launcher) -> None:
"""Set the launcher of the job manager to a specific launcher instance
Expand Down Expand Up @@ -312,7 +315,7 @@ def get_db_host_addresses(self) -> t.Dict[str, t.List[str]]:
:rtype: Dict[str, list]
"""

address_dict = {}
address_dict: t.Dict[str, t.List[str]] = {}
for db_job in self.db_jobs.values():
addresses = []
if isinstance(db_job.entity, (DBNode, Orchestrator)):
Expand All @@ -321,7 +324,9 @@ def get_db_host_addresses(self) -> t.Dict[str, t.List[str]]:
ip_addr = get_ip_from_host(combine[0])
addresses.append(":".join((ip_addr, str(combine[1]))))

address_dict.update({db_entity.name: addresses})
dict_entry: t.List[str] = address_dict.get(db_entity.db_identifier, [])
dict_entry.extend(addresses)
address_dict[db_entity.db_identifier] = dict_entry

return address_dict

Expand Down
24 changes: 12 additions & 12 deletions smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,28 @@


def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]:
"""Unpack the unformatted database identifier using the token,
and format for env variable suffix
:db_id: the unformatted database identifier eg. identifier_1_0
:token: character '_' or '-' to use to unpack the database identifier
:return: db suffix, and formatted db_id eg. _identifier_1, identifier_1
"""Unpack the unformatted database identifier
and format for env variable suffix using the token
:param db_id: the unformatted database identifier eg. identifier_1
:type db_id: str
:param token: character to use to construct the db suffix
:type token: str
:return: db id suffix and formatted db_id e.g. ("_identifier_1", "identifier_1")
:rtype: (str, str)
"""

if db_id == "orchestrator":
return "", ""
db_id = "_".join(db_id.split(token)[:-1])
# if unpacked db_id is default, return empty
if db_id == "orchestrator":
# if db_id is default after unpack, return empty
return "", ""
db_name_suffix = "_" + db_id
db_name_suffix = token + db_id
return db_name_suffix, db_id


def unpack_colo_db_identifier(db_id: str) -> str:
"""Create database identifier suffix for colocated database
:db_id: the unformatted database identifier
:param db_id: the unformatted database identifier
:type db_id: str
:return: db suffix
:rtype: str
"""
return "_" + db_id if db_id else ""

Expand Down
44 changes: 21 additions & 23 deletions smartsim/_core/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import itertools
import logging
import redis
import time
Expand Down Expand Up @@ -219,31 +218,30 @@ def set_script(db_script: DBScript, client: Client) -> None:
raise error


def shutdown_db(hosts: t.List[str], ports: t.List[int]) -> None: # cov-wlm
"""Send shutdown signal to cluster instances.
def shutdown_db_node(host_ip: str, port: int) -> t.Tuple[int, str, str]: # cov-wlm
"""Send shutdown signal to DB node.

Should only be used in the case where cluster deallocation
needs to occur manually. Usually, the SmartSim task manager
needs to occur manually. Usually, the SmartSim job manager
will take care of this automatically.

:param hosts: List of hostnames to connect to
:type hosts: List[str]
:param ports: List of ports for each hostname
:type ports: List[int]
:raises SmartSimError: if cluster creation fails
:param host_ip: IP of host to connect to
:type hosts: str
:param ports: Port to which node is listening
:type ports: int
:return: returncode, output, and error of the process
:rtype: tuple of (int, str, str)
"""
for host_ip, port in itertools.product(
(get_ip_from_host(host) for host in hosts), ports
):
# call cluster command
redis_cli = CONFIG.database_cli
cmd = [redis_cli, "-h", host_ip, "-p", str(port), "shutdown"]
returncode, out, err = execute_cmd(
cmd, proc_input="yes", shell=False, timeout=10
)
redis_cli = CONFIG.database_cli
cmd = [redis_cli, "-h", host_ip, "-p", str(port), "shutdown"]
returncode, out, err = execute_cmd(
cmd, proc_input="yes", shell=False, timeout=10
)

if returncode != 0:
logger.error(out)
logger.error(err)
elif out:
logger.debug(out)

if returncode != 0:
logger.error(out)
logger.error(err)
else:
logger.debug(out)
return returncode, out, err
Loading