Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve fetching and caching job details #3194

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,14 @@ def sync_placement_groups():
"STOPPED",
"SUSPENDED",
"COMPLETING",
"PENDING",
]
)

keep_jobs = {
str(job["job_id"])
for job in json.loads(run(f"{lookup().scontrol} show jobs --json").stdout)["jobs"]
if "job_state" in job and set(job["job_state"]) & keep_states
str(job.id)
for job in lookup().get_jobs()
if job.job_state in keep_states
}
keep_jobs.add("0") # Job 0 is a placeholder for static node placement

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mock import Mock
from common import TstNodeset, TstCfg # needed to import util
import util
from datetime import timedelta
from google.api_core.client_options import ClientOptions # noqa: E402

# Note: need to install pytest-mock
Expand Down Expand Up @@ -158,14 +159,14 @@ def test_nodeset_reservation_err(nodeset, err):
with pytest.raises(err):
lkp.nodeset_reservation(nodeset)
lkp._get_reservation.assert_not_called()

@pytest.mark.parametrize(
"nodeset,policies,expected",
[
(TstNodeset(), [], None), # no reservation
(TstNodeset(
reservation_name="projects/bobin/reservations/robin",
zone_policy_allow=["eine"]),
zone_policy_allow=["eine"]),
[],
util.ReservationDetails(
project="bobin",
Expand All @@ -175,7 +176,7 @@ def test_nodeset_reservation_err(nodeset, err):
bulk_insert_name="projects/bobin/reservations/robin")),
(TstNodeset(
reservation_name="projects/bobin/reservations/robin",
zone_policy_allow=["eine"]),
zone_policy_allow=["eine"]),
["seven/wanders", "five/red/apples", "yum"],
util.ReservationDetails(
project="bobin",
Expand All @@ -185,7 +186,7 @@ def test_nodeset_reservation_err(nodeset, err):
bulk_insert_name="projects/bobin/reservations/robin")),
(TstNodeset(
reservation_name="projects/bobin/reservations/robin/snek/cheese-brie-6",
zone_policy_allow=["eine"]),
zone_policy_allow=["eine"]),
[],
util.ReservationDetails(
project="bobin",
Expand All @@ -199,16 +200,76 @@ def test_nodeset_reservation_err(nodeset, err):
def test_nodeset_reservation_ok(nodeset, policies, expected):
    """Check `Lookup.nodeset_reservation` against a mocked `_get_reservation`.

    When `expected` is falsy the lookup must return None without querying
    the reservation; otherwise it must return `expected` and query the
    reservation exactly once with the expected project, zone and name.
    """
    lookup = util.Lookup(TstCfg())
    lookup._get_reservation = Mock()

    if expected:
        # The mocked reservation carries the parametrized policies keyed by index.
        lookup._get_reservation.return_value = {
            "resourcePolicies": dict(enumerate(policies)),
        }
        assert lookup.nodeset_reservation(nodeset) == expected
        lookup._get_reservation.assert_called_once_with(
            expected.project, expected.zone, expected.name
        )
    else:
        assert lookup.nodeset_reservation(nodeset) is None
        lookup._get_reservation.assert_not_called()




@pytest.mark.parametrize(
    "job_info,expected_job",
    [
        # Every supported field is present.
        (
            """JobId=123
TimeLimit=02:00:00
JobName=myjob
JobState=PENDING
ReqNodeList=node-[1-10]""",
            util.Job(
                id=123,
                duration=timedelta(days=0, hours=2, minutes=0, seconds=0),
                name="myjob",
                job_state="PENDING",
                required_nodes="node-[1-10]"
            ),
        ),
        # No TimeLimit line: duration is expected to be None.
        (
            """JobId=456
JobName=anotherjob
JobState=PENDING
ReqNodeList=node-group1""",
            util.Job(
                id=456,
                duration=None,
                name="anotherjob",
                job_state="PENDING",
                required_nodes="node-group1"
            ),
        ),
        # No JobName and no ReqNodeList: both are expected to be None.
        (
            """JobId=789
TimeLimit=00:30:00
JobState=COMPLETED""",
            util.Job(
                id=789,
                duration=timedelta(minutes=30),
                name=None,
                job_state="COMPLETED",
                required_nodes=None
            ),
        ),
        # Day-prefixed TimeLimit ("D-HH:MM:SS"), a trailing comma after the
        # job state (expected to be dropped), and a multi-range node list.
        (
            """JobId=101112
TimeLimit=1-00:30:00
JobState=COMPLETED,
ReqNodeList=node-[1-10],grob-pop-[2,1,44-77]""",
            util.Job(
                id=101112,
                duration=timedelta(days=1, hours=0, minutes=30, seconds=0),
                name=None,
                job_state="COMPLETED",
                required_nodes="node-[1-10],grob-pop-[2,1,44-77]"
            ),
        ),
    ],
)
def test_parse_job_info(job_info, expected_job):
    """Check that `Lookup._parse_job_info` turns job text into the expected `util.Job`."""
    lkp = util.Lookup(TstCfg())
    assert lkp._parse_job_info(job_info) == expected_job
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def _fill_cfg_defaults(cfg: NSDict) -> NSDict:
"mount_options": "defaults,hard,intr,_netdev",
}
)

network_storage_iter = filter(
None,
(
Expand Down Expand Up @@ -474,8 +474,8 @@ def _download(bs) -> List[Any]:
), hash

def _assemble_config(
core: Any,
partitions: List[Any],
core: Any,
partitions: List[Any],
nodesets: List[Any],
nodesets_dyn: List[Any],
nodesets_tpu: List[Any],
Expand Down Expand Up @@ -510,17 +510,17 @@ def _add_nodesets(yamls: List[Any], target: dict):
for ns_name in chain(p.partition_nodeset, p.partition_nodeset_dyn, p.partition_nodeset_tpu):
if ns_name not in ns_names:
raise DeffetiveStoredConfigError(f"nodeset {ns_name} not defined in config")

return _fill_cfg_defaults(cfg)

def fetch_config() -> Tuple[bool, NSDict]:
"""
Fetches config from bucket and saves it locally
Fetches config from bucket and saves it locally
Returns True if new (updated) config was fetched
"""
hash_file = Path("/slurm/scripts/.config.hash")
old_hash = hash_file.read_text() if hash_file.exists() else None

cfg_and_hash = _fetch_config(old_hash=old_hash)
if not cfg_and_hash:
return False, _load_config()
Expand Down Expand Up @@ -1460,8 +1460,12 @@ class ReservationDetails:
@dataclass
class Job:
    """Details of a single job; every field except `id` is optional.

    `Lookup._parse_job_info` fills the fields from the `JobId`, `JobName`,
    `ReqNodeList`, `JobState` and `TimeLimit` entries of the job text, and
    leaves a field as None when its entry is not found.
    """
    id: int  # numeric job id (from "JobId=")
    name: Optional[str] = None  # job name (from "JobName=")
    required_nodes: Optional[str] = None  # raw node-list expression (from "ReqNodeList=")
    job_state: Optional[str] = None  # state name, e.g. "PENDING" (from "JobState=")
    duration: Optional[timedelta] = None  # time limit (from "TimeLimit=[D-]HH:MM:SS")


class Lookup:
"""Wrapper class for cached data access"""

Expand Down Expand Up @@ -1757,11 +1761,11 @@ def _get_reservation(self, project: str, zone: str, name: str) -> object:
"""See https://cloud.google.com/compute/docs/reference/rest/v1/reservations"""
return self.compute.reservations().get(
project=project, zone=zone, reservation=name).execute()

def nodeset_reservation(self, nodeset: object) -> Optional[ReservationDetails]:
if not nodeset.reservation_name:
return None

zones = list(nodeset.zone_policy_allow or [])
assert len(zones) == 1, "Only single zone is supported if using a reservation"
zone = zones[0]
Expand All @@ -1771,7 +1775,7 @@ def nodeset_reservation(self, nodeset: object) -> Optional[ReservationDetails]:
raise ValueError(
f"Invalid reservation name: '{nodeset.reservation_name}', expected format is 'projects/PROJECT/reservations/NAME'"
)

project, name = match.group("project", "reservation")
reservation = self._get_reservation(project, zone, name)

Expand Down Expand Up @@ -1928,26 +1932,54 @@ def nodeset_map(self, hostnames: list):
nodeset_map[self.node_nodeset_name(node)].append(node)
return nodeset_map

def _parse_job_info(self, job_info: str) -> Job:
    """Extract job details from the text of one job record.

    Args:
        job_info: `Key=Value` text describing a single job, as printed by
            `scontrol show job`.

    Returns:
        A `Job` whose optional fields are None when the matching key is
        absent (or, for `TimeLimit`, not in `[D-]HH:MM:SS` form).

    Raises:
        ValueError: if no `JobId=<number>` entry is present.
    """
    # `\b` keeps this from matching inside another key such as `ArrayJobId=`.
    if not (match := re.search(r"\bJobId=(\d+)", job_info)):
        raise ValueError(f"Job ID not found in the job info: {job_info}")
    job_id = int(match.group(1))

    duration = None
    if match := re.search(r"TimeLimit=(?:(\d+)-)?(\d{2}):(\d{2}):(\d{2})", job_info):
        days, hours, minutes, seconds = match.groups()
        duration = timedelta(
            days=int(days) if days else 0,
            hours=int(hours),
            minutes=int(minutes),
            seconds=int(seconds),
        )

    # `\S+` rather than `\w+`: job names commonly contain `-`, `.`, `/` etc.,
    # which `\w+` would silently cut off. The value still ends at the first
    # whitespace character.
    match = re.search(r"JobName=(\S+)", job_info)
    name = match.group(1) if match else None

    # `\w+` on purpose: stops before any trailing punctuation after the state.
    match = re.search(r"JobState=(\w+)", job_info)
    job_state = match.group(1) if match else None

    # `\S+` rather than `[^ ]+`: the value must also end at a newline or tab,
    # otherwise a field at the end of a line swallows the lines that follow.
    match = re.search(r"ReqNodeList=(\S+)", job_info)
    required_nodes = match.group(1) if match else None

    return Job(
        id=job_id,
        duration=duration,
        name=name,
        job_state=job_state,
        required_nodes=required_nodes,
    )

@lru_cache
def job(self, job_id: int) -> Optional[Job]:
jobInfo = run(f"{self.scontrol} show jobid {job_id}", check=False).stdout.rstrip()
if not jobInfo:
return None
def get_jobs(self) -> List[Job]:
res = run(f"{self.scontrol} show jobs", timeout=30)

timePattern = r"TimeLimit=(?:(\d+)-)?(\d{2}):(\d{2}):(\d{2})"
match = re.search(timePattern, jobInfo)
return [self._parse_job_info(job) for job in res.stdout.split("\n\n")[:-1]]

if not match:
return Job(id=job_id)
@lru_cache
def job(self, job_id: int) -> Optional[Job]:
job_info = run(f"{self.scontrol} show jobid {job_id}", check=False).stdout.rstrip()
if not job_info:
return None

days, hours, minutes, seconds = match.groups()
job_duration = timedelta(
days=int(days) if days else 0,
hours=int(hours),
minutes=int(minutes),
seconds=int(seconds)
)
return Job(id=job_id, duration=job_duration)
return self._parse_job_info(job_info=job_info)

@property
def etc_dir(self) -> Path:
Expand Down
Loading