Skip to content

Commit

Permalink
handle nvcf container timeouts (#399)
Browse files Browse the repository at this point in the history
* wrap logging around NVCF calls

* defensively wrap NVCF 404s: handle missing request IDs (NVCF-REQID), internal timeouts, and the NVCF 1-hour timeout

* also back off and retry on timeout
  • Loading branch information
leondz authored Jan 17, 2024
1 parent 334f797 commit d4721c9
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions garak/generators/nvcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

"""NVCF LLM interface"""

import random
from sqlite3 import Time
import backoff
import logging
import os
import requests
import time

from garak import _config
from garak.generators.base import Generator
Expand All @@ -24,6 +27,10 @@ class NvcfGenerator(Generator):
fetch_url_format = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/"
invoke_url_base = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/"

extra_nvcf_logging = False

timeout = 60

def __init__(self, name=None, generations=10):
self.name = name
self.fullname = f"NVCF {self.name}"
Expand All @@ -48,7 +55,16 @@ def __init__(self, name=None, generations=10):
"Accept": "application/json",
}

def _call_model(self, prompt):
@backoff.on_exception(
backoff.fibo,
(
AttributeError,
TimeoutError,
requests.exceptions.HTTPError,
),
max_value=70,
)
def _call_model(self, prompt: str) -> str:
if prompt == "":
return ""

Expand All @@ -65,17 +81,29 @@ def _call_model(self, prompt):
if self.seed is not None:
payload["seed"] = self.seed

request_time = time.time()
response = session.post(self.invoke_url, headers=self.headers, json=payload)

while response.status_code == 202:
if time.time() > request_time + self.timeout:
raise TimeoutError("NVCF Request timed out")
request_id = response.headers.get("NVCF-REQID")
if request_id is None:
msg = "Got HTTP 202 but no NVCF-REQID was returned"
logging.info("nvcf : %s", msg)
raise AttributeError(msg)
fetch_url = self.fetch_url_format + request_id
response = session.get(fetch_url, headers=self.headers)

response.raise_for_status()
response_body = response.json()
if 400 <= response.status_code < 600:
logging.warning("nvcf : returned error code %s", response.status_code)
logging.warning("nvcf : returned error body %s", response.content)
response.raise_for_status()

else:
response_body = response.json()

return response_body["choices"][0]["message"]["content"]
return response_body["choices"][0]["message"]["content"]


default_class = "NvcfGenerator"

0 comments on commit d4721c9

Please sign in to comment.