Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add retry mechanism for GPU device check in DockerEnv #573

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rdagent/oai/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import uuid
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, cast

import numpy as np
import tiktoken
Expand Down Expand Up @@ -153,7 +153,7 @@ def embedding_set(self, content_to_embedding_dict: dict) -> None:
def message_get(self, conversation_id: str) -> list[dict[str, Any]]:
self.c.execute("SELECT message FROM message_cache WHERE conversation_id=?", (conversation_id,))
result = self.c.fetchone()
return [] if result is None else json.loads(result[0])
return [] if result is None else cast(list[dict[str, Any]], json.loads(result[0]))

def message_set(self, conversation_id: str, message_value: list[dict[str, Any]]) -> None:
self.c.execute(
Expand Down
18 changes: 12 additions & 6 deletions rdagent/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from rdagent.core.experiment import RD_AGENT_SETTINGS
from rdagent.log import rdagent_logger as logger
from rdagent.oai.llm_utils import md5_hash
from rdagent.utils.workflow import wait_retry

ASpecificBaseModel = TypeVar("ASpecificBaseModel", bound=BaseModel)

Expand Down Expand Up @@ -313,12 +314,17 @@ def _gpu_kwargs(self, client: docker.DockerClient) -> dict: # type: ignore[no-a
[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if self.conf.enable_gpu else None
),
}
try:
client.containers.run(self.conf.image, "nvidia-smi", **gpu_kwargs)
logger.info("GPU Devices are available.")
except docker.errors.APIError:
return {}
return gpu_kwargs

@wait_retry(5, 10)
def _f() -> dict:
try:
client.containers.run(self.conf.image, "nvidia-smi", **gpu_kwargs)
logger.info("GPU Devices are available.")
except docker.errors.APIError:
return {}
return gpu_kwargs

return _f()

def replace_time_info(self, input_string: str) -> str:
"""To remove any time related information from the logs since it will destroy the cache mechanism"""
Expand Down
68 changes: 38 additions & 30 deletions rdagent/utils/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import pickle
import time
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, TypeVar, cast

from tqdm.auto import tqdm

Expand All @@ -23,7 +23,7 @@

class LoopMeta(type):
@staticmethod
def _get_steps(bases):
def _get_steps(bases: tuple[type, ...]) -> list[str]:
"""
Recursively get all the `steps` from the base classes and combine them into a single list.

Expand All @@ -40,7 +40,7 @@ def _get_steps(bases):
steps.append(step)
return steps

def __new__(cls, clsname, bases, attrs):
def __new__(mcs, clsname: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any:
"""
Create a new class with combined steps from base classes and current class.

Expand All @@ -54,13 +54,13 @@ def __new__(cls, clsname, bases, attrs):
"""
steps = LoopMeta._get_steps(bases) # all the base classes of parents
for name, attr in attrs.items():
if not name.startswith("_") and isinstance(attr, Callable):
if not name.startswith("_") and callable(attr):
if name not in steps:
# NOTE: a step that is overridden in the subclass is already in the list,
# so it is not a new step and we do not append it again.
steps.append(name)
attrs["steps"] = steps
return super().__new__(cls, clsname, bases, attrs)
return super().__new__(mcs, clsname, bases, attrs)


@dataclass
Expand All @@ -77,23 +77,21 @@ class LoopBase:
- The last step is responsible for recording information!!!!
"""

steps: list[Callable] # a list of steps to work on
steps: list[str] # a list of steps to work on
loop_trace: dict[int, list[LoopTrace]]

skip_loop_error: tuple[Exception] = field(
default_factory=tuple
) # you can define a list of error that will skip current loop
skip_loop_error: tuple[type[BaseException], ...] = () # you can define a list of error that will skip current loop

EXCEPTION_KEY = "_EXCEPTION"

def __init__(self):
def __init__(self) -> None:
self.loop_idx = 0 # current loop index
self.step_idx = 0 # the index of next step to be run
self.loop_prev_out = {} # the step results of current loop
self.loop_prev_out: dict[str, Any] = {} # the step results of current loop
self.loop_trace = defaultdict(list[LoopTrace]) # the key is the number of loop
self.session_folder = logger.log_trace_path / "__session__"

def run(self, step_n: int | None = None):
def run(self, step_n: int | None = None) -> None:
"""

Parameters
Expand All @@ -114,17 +112,20 @@ def run(self, step_n: int | None = None):
logger.info(f"Start Loop {li}, Step {si}: {name}")
with logger.tag(f"Loop_{li}.{name}"):
start = datetime.datetime.now(datetime.timezone.utc)
func = getattr(self, name)
func: Callable[..., Any] = cast(Callable[..., Any], getattr(self, name))
try:
self.loop_prev_out[name] = func(self.loop_prev_out)
# TODO: Fix the error logger.exception(f"Skip loop {li} due to {e}")
except self.skip_loop_error as e:
# FIXME: This does not support the previous demos (because their last step is not for recording)
logger.warning(f"Skip loop {li} due to {e}")
# NOTE: strong assumption! The last step is responsible for recording information
self.step_idx = len(self.steps) - 1 # directly jump to the last step.
self.loop_prev_out[self.EXCEPTION_KEY] = e
continue
except Exception as e:
if isinstance(e, self.skip_loop_error):
# FIXME: This does not support the previous demos (because their last step is not for recording)
logger.warning(f"Skip loop {li} due to {e}")
# NOTE: strong assumption! The last step is responsible for recording information
self.step_idx = len(self.steps) - 1 # directly jump to the last step.
self.loop_prev_out[self.EXCEPTION_KEY] = e
continue
else:
raise
finally:
# make sure failed steps are displayed correctly
end = datetime.datetime.now(datetime.timezone.utc)
Expand All @@ -145,27 +146,30 @@ def run(self, step_n: int | None = None):

self.dump(self.session_folder / f"{li}" / f"{si}_{name}") # save a snapshot after the session

def dump(self, path: str | Path):
def dump(self, path: str | Path) -> None:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("wb") as f:
pickle.dump(self, f)

@classmethod
def load(cls, path: str | Path):
def load(cls, path: str | Path) -> "LoopBase":
path = Path(path)
with path.open("rb") as f:
session = pickle.load(f)
session = cast(LoopBase, pickle.load(f))
logger.set_trace_path(session.session_folder.parent)

max_loop = max(session.loop_trace.keys())
logger.storage.truncate(time=session.loop_trace[max_loop][-1].end)
return session


ASpecificRet = TypeVar("ASpecificRet")


def wait_retry(
retry_n: int = 3, sleep_time: int = 1, transform_args_fn: Callable[[tuple, dict], tuple[tuple, dict]] | None = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[Callable[..., ASpecificRet]], Callable[..., ASpecificRet]]:
"""Decorator to wait and retry the function for retry_n times.

Example:
Expand All @@ -188,20 +192,24 @@ def wait_retry(
>>> counter
2
"""
assert retry_n > 0, "retry_n should be greater than 0"

def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
for i in range(retry_n):
def decorator(f: Callable[..., ASpecificRet]) -> Callable[..., ASpecificRet]:
def wrapper(*args: Any, **kwargs: Any) -> ASpecificRet:
for i in range(retry_n + 1):
try:
return f(*args, **kwargs)
except Exception as e:
print(f"Error: {e}")
time.sleep(sleep_time)
if i == retry_n - 1:
raise e
if i == retry_n:
raise
# Update args and kwargs using the transform function if provided.
if transform_args_fn is not None:
args, kwargs = transform_args_fn(args, kwargs)
else:
# just for passing mypy CI.
return f(*args, **kwargs)

return wrapper

Expand Down