Skip to content

Commit

Permalink
Merge pull request #3194 from harshthakkar01/cache-jobs
Browse files Browse the repository at this point in the history
Improve fetching and caching job details
  • Loading branch information
harshthakkar01 authored Nov 7, 2024
2 parents 56096a0 + d88cd50 commit ba8b179
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,14 @@ def sync_placement_groups():
"STOPPED",
"SUSPENDED",
"COMPLETING",
"PENDING",
]
)

keep_jobs = {
str(job["job_id"])
for job in json.loads(run(f"{lookup().scontrol} show jobs --json").stdout)["jobs"]
if "job_state" in job and set(job["job_state"]) & keep_states
str(job.id)
for job in lookup().get_jobs()
if job.job_state in keep_states
}
keep_jobs.add("0") # Job 0 is a placeholder for static node placement

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mock import Mock
from common import TstNodeset, TstCfg # needed to import util
import util
from datetime import timedelta
from google.api_core.client_options import ClientOptions # noqa: E402

# Note: need to install pytest-mock
Expand Down Expand Up @@ -158,14 +159,14 @@ def test_nodeset_reservation_err(nodeset, err):
with pytest.raises(err):
lkp.nodeset_reservation(nodeset)
lkp._get_reservation.assert_not_called()

@pytest.mark.parametrize(
"nodeset,policies,expected",
[
(TstNodeset(), [], None), # no reservation
(TstNodeset(
reservation_name="projects/bobin/reservations/robin",
zone_policy_allow=["eine"]),
zone_policy_allow=["eine"]),
[],
util.ReservationDetails(
project="bobin",
Expand All @@ -175,7 +176,7 @@ def test_nodeset_reservation_err(nodeset, err):
bulk_insert_name="projects/bobin/reservations/robin")),
(TstNodeset(
reservation_name="projects/bobin/reservations/robin",
zone_policy_allow=["eine"]),
zone_policy_allow=["eine"]),
["seven/wanders", "five/red/apples", "yum"],
util.ReservationDetails(
project="bobin",
Expand All @@ -185,7 +186,7 @@ def test_nodeset_reservation_err(nodeset, err):
bulk_insert_name="projects/bobin/reservations/robin")),
(TstNodeset(
reservation_name="projects/bobin/reservations/robin/snek/cheese-brie-6",
zone_policy_allow=["eine"]),
zone_policy_allow=["eine"]),
[],
util.ReservationDetails(
project="bobin",
Expand All @@ -199,16 +200,76 @@ def test_nodeset_reservation_err(nodeset, err):
def test_nodeset_reservation_ok(nodeset, policies, expected):
lkp = util.Lookup(TstCfg())
lkp._get_reservation = Mock()

if not expected:
assert lkp.nodeset_reservation(nodeset) is None
lkp._get_reservation.assert_not_called()
return

lkp._get_reservation.return_value = {
"resourcePolicies": {i: p for i, p in enumerate(policies)},
}
assert lkp.nodeset_reservation(nodeset) == expected
lkp._get_reservation.assert_called_once_with(expected.project, expected.zone, expected.name)




@pytest.mark.parametrize(
"job_info,expected_job",
[
(
"""JobId=123
TimeLimit=02:00:00
JobName=myjob
JobState=PENDING
ReqNodeList=node-[1-10]""",
util.Job(
id=123,
duration=timedelta(days=0, hours=2, minutes=0, seconds=0),
name="myjob",
job_state="PENDING",
required_nodes="node-[1-10]"
),
),
(
"""JobId=456
JobName=anotherjob
JobState=PENDING
ReqNodeList=node-group1""",
util.Job(
id=456,
duration=None,
name="anotherjob",
job_state="PENDING",
required_nodes="node-group1"
),
),
(
"""JobId=789
TimeLimit=00:30:00
JobState=COMPLETED""",
util.Job(
id=789,
duration=timedelta(minutes=30),
name=None,
job_state="COMPLETED",
required_nodes=None
),
),
(
"""JobId=101112
TimeLimit=1-00:30:00
JobState=COMPLETED,
ReqNodeList=node-[1-10],grob-pop-[2,1,44-77]""",
util.Job(
id=101112,
duration=timedelta(days=1, hours=0, minutes=30, seconds=0),
name=None,
job_state="COMPLETED",
required_nodes="node-[1-10],grob-pop-[2,1,44-77]"
),
),
],
)
def test_parse_job_info(job_info, expected_job):
lkp = util.Lookup(TstCfg())
assert lkp._parse_job_info(job_info) == expected_job
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def _fill_cfg_defaults(cfg: NSDict) -> NSDict:
"mount_options": "defaults,hard,intr,_netdev",
}
)

network_storage_iter = filter(
None,
(
Expand Down Expand Up @@ -474,8 +474,8 @@ def _download(bs) -> List[Any]:
), hash

def _assemble_config(
core: Any,
partitions: List[Any],
core: Any,
partitions: List[Any],
nodesets: List[Any],
nodesets_dyn: List[Any],
nodesets_tpu: List[Any],
Expand Down Expand Up @@ -510,17 +510,17 @@ def _add_nodesets(yamls: List[Any], target: dict):
for ns_name in chain(p.partition_nodeset, p.partition_nodeset_dyn, p.partition_nodeset_tpu):
if ns_name not in ns_names:
raise DeffetiveStoredConfigError(f"nodeset {ns_name} not defined in config")

return _fill_cfg_defaults(cfg)

def fetch_config() -> Tuple[bool, NSDict]:
"""
Fetches config from bucket and saves it locally
Fetches config from bucket and saves it locally
Returns True if new (updated) config was fetched
"""
hash_file = Path("/slurm/scripts/.config.hash")
old_hash = hash_file.read_text() if hash_file.exists() else None

cfg_and_hash = _fetch_config(old_hash=old_hash)
if not cfg_and_hash:
return False, _load_config()
Expand Down Expand Up @@ -1464,8 +1464,12 @@ class ReservationDetails:
@dataclass
class Job:
id: int
name: Optional[str] = None
required_nodes: Optional[str] = None
job_state: Optional[str] = None
duration: Optional[timedelta] = None


class Lookup:
"""Wrapper class for cached data access"""

Expand Down Expand Up @@ -1761,11 +1765,11 @@ def _get_reservation(self, project: str, zone: str, name: str) -> object:
"""See https://cloud.google.com/compute/docs/reference/rest/v1/reservations"""
return self.compute.reservations().get(
project=project, zone=zone, reservation=name).execute()

def nodeset_reservation(self, nodeset: object) -> Optional[ReservationDetails]:
if not nodeset.reservation_name:
return None

zones = list(nodeset.zone_policy_allow or [])
assert len(zones) == 1, "Only single zone is supported if using a reservation"
zone = zones[0]
Expand All @@ -1775,7 +1779,7 @@ def nodeset_reservation(self, nodeset: object) -> Optional[ReservationDetails]:
raise ValueError(
f"Invalid reservation name: '{nodeset.reservation_name}', expected format is 'projects/PROJECT/reservations/NAME'"
)

project, name = match.group("project", "reservation")
reservation = self._get_reservation(project, zone, name)

Expand Down Expand Up @@ -1932,26 +1936,54 @@ def nodeset_map(self, hostnames: list):
nodeset_map[self.node_nodeset_name(node)].append(node)
return nodeset_map

def _parse_job_info(self, job_info: str) -> Job:
"""Extract job details"""
if match:= re.search(r"JobId=(\d+)", job_info):
job_id = int(match.group(1))
else:
raise ValueError(f"Job ID not found in the job info: {job_info}")

if match:= re.search(r"TimeLimit=(?:(\d+)-)?(\d{2}):(\d{2}):(\d{2})", job_info):
days, hours, minutes, seconds = match.groups()
duration = timedelta(
days=int(days) if days else 0,
hours=int(hours),
minutes=int(minutes),
seconds=int(seconds)
)
else:
duration = None

if match := re.search(r"JobName=(\w+)", job_info):
name = match.group(1)
else:
name = None

if match := re.search(r"JobState=(\w+)", job_info):
job_state = match.group(1)
else:
job_state = None

if match := re.search(r"ReqNodeList=([^ ]+)", job_info):
required_nodes = match.group(1)
else:
required_nodes = None

return Job(id=job_id, duration=duration, name=name, job_state=job_state, required_nodes=required_nodes)

@lru_cache
def job(self, job_id: int) -> Optional[Job]:
jobInfo = run(f"{self.scontrol} show jobid {job_id}", check=False).stdout.rstrip()
if not jobInfo:
return None
def get_jobs(self) -> List[Job]:
res = run(f"{self.scontrol} show jobs", timeout=30)

timePattern = r"TimeLimit=(?:(\d+)-)?(\d{2}):(\d{2}):(\d{2})"
match = re.search(timePattern, jobInfo)
return [self._parse_job_info(job) for job in res.stdout.split("\n\n")[:-1]]

if not match:
return Job(id=job_id)
@lru_cache
def job(self, job_id: int) -> Optional[Job]:
job_info = run(f"{self.scontrol} show jobid {job_id}", check=False).stdout.rstrip()
if not job_info:
return None

days, hours, minutes, seconds = match.groups()
job_duration = timedelta(
days=int(days) if days else 0,
hours=int(hours),
minutes=int(minutes),
seconds=int(seconds)
)
return Job(id=job_id, duration=job_duration)
return self._parse_job_info(job_info=job_info)

@property
def etc_dir(self) -> Path:
Expand Down

0 comments on commit ba8b179

Please sign in to comment.