Hotfix/web retrieval exploit #590

Merged
merged 15 commits into from
Feb 14, 2025
Changes from 1 commit
Storing only the last 200 urls
richwardle committed Feb 14, 2025
commit a1cf6c0d82065e9e2d054596d085aa9121867693
46 changes: 38 additions & 8 deletions prompting/rewards/web_retrieval.py
@@ -24,6 +24,22 @@
 PAST_WEBSITES_FILE = "past_websites.csv"
 TOP_DOMAINS_FILE = "../data/top100k_domains.csv"
 
+# Define blacklisted terms
+BLACKLISTED_TERMS = {
+    "pastebin",
+    "paste",
+    "gist",
+    "github",
+    "gitlab",
+    "bitbucket",
+    "hastebin",
+    "ghostbin",
+    "privatebin",
+}
+
+# Maximum number of past URLs to store per user
+N_PAST_URLS = 200
+
 # Load the past_websites dictionary and top domains
 try:
     # Load top domains
@@ -34,8 +50,9 @@
     if os.path.exists(PAST_WEBSITES_FILE):
         past_websites_df = pd.read_csv(PAST_WEBSITES_FILE)
         past_websites = defaultdict(list)
-        for _, row in past_websites_df.iterrows():
-            past_websites[row["uid"]].append(row["domain"])
+        # Group by uid and take only the last N_PAST_URLS entries
+        for uid, group in past_websites_df.groupby("uid"):
+            past_websites[uid] = group["domain"].tolist()[-N_PAST_URLS:]
     else:
         logger.warning(f"Past websites file {PAST_WEBSITES_FILE} does not exist, creating new dictionary")
         past_websites = defaultdict(list)
@@ -45,6 +62,13 @@
     past_websites = defaultdict(list)
 
 
+def _append_to_past_websites(uid: str, domain: str):
+    """Helper function to append domain to past_websites while maintaining max size."""
+    past_websites[uid].append(domain)
+    if len(past_websites[uid]) > N_PAST_URLS:
+        past_websites[uid] = past_websites[uid][-N_PAST_URLS:]
+
+
 class WebsiteResult(BaseModel):
     url: str | None
     content: str | None
@@ -66,33 +90,39 @@ def score_website_result(

     # Extract domain from URL
     parsed_url = urlparse(response_url)
+
+    if any(term in response_url for term in BLACKLISTED_TERMS):
+        logger.debug(f"Domain {parsed_url.netloc} contains blacklisted term, scoring 0")
+        return 0
+
     netloc = parsed_url.netloc.lower()
 
     # Remove www. prefix if present
     if netloc.startswith("www."):
         netloc = netloc[4:]
 
+    print("NETLOC", parsed_url)
     # Check if URL is IP-based or has port
     if not netloc or any(c.isdigit() for c in netloc.split(".")) or ":" in netloc:
         discount_factor = 0
-        logger.info(f"URL {response_url} appears to be IP-based or on specific port, setting discount factor to 0")
+        logger.debug(f"URL {response_url} appears to be IP-based or on specific port, setting discount factor to 0")
         return 0
     else:
         domain = netloc
 
     # If domain is in top 100k, don't apply penalty
     if domain in TOP_DOMAINS:
         discount_factor = 1.0
-        logger.info(f"Domain {domain} is in top 100k domains, not applying penalty")
+        logger.debug(f"Domain {domain} is in top 100k domains, not applying penalty")
     else:
         # Count how many times this domain has been used by this miner
         domain_count = np.sum(np.array([domain == d for d in past_websites[uid]])) + 1
         discount_factor = 1.0 / domain_count
         if domain in past_websites[uid]:
-            logger.info(
+            logger.debug(
                 f"Already used domain {domain} for this UID, applying discount factor {discount_factor}"
             )
-    past_websites[uid].append(domain)
+    _append_to_past_websites(uid, domain)
 
     # Content scraped from the URL provided in the completion.
     reference_website_content = DDGDataset.extract_website_content(response_url)
@@ -101,11 +131,11 @@ def score_website_result(
         return 0
 
     if fuzz.ratio(response_content, reference_website_content) < MIN_MATCH_THRESHOLD:
-        logger.info("Miner returned text that doesn't match the website, scoring 0")
+        logger.debug("Miner returned text that doesn't match the website, scoring 0")
         return 0
 
     if len(response_relevant) > len(response_content) or len(response_relevant) < MIN_RELEVANT_CHARS:
-        logger.info(
+        logger.debug(
             f"Relevant section is too short (<{MIN_RELEVANT_CHARS} chars) or longer than the whole website content "
             f"{len(response_relevant)} > {len(response_content)}"
         )
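
The headline change in this commit is the rolling window on per-miner URL history. A minimal standalone sketch of that behaviour, reusing the constant and helper from the diff above (the history dict is local here purely for illustration; in the module it is built from past_websites.csv at import time):

from collections import defaultdict

N_PAST_URLS = 200  # same cap as in the diff

past_websites: dict[str, list[str]] = defaultdict(list)


def _append_to_past_websites(uid: str, domain: str) -> None:
    """Append a domain for this UID, keeping only the most recent N_PAST_URLS entries."""
    past_websites[uid].append(domain)
    if len(past_websites[uid]) > N_PAST_URLS:
        past_websites[uid] = past_websites[uid][-N_PAST_URLS:]


# After 250 appends, only the most recent 200 domains remain for the UID.
for i in range(250):
    _append_to_past_websites("miner-1", f"site{i}.example.com")

assert len(past_websites["miner-1"]) == N_PAST_URLS
assert past_websites["miner-1"][0] == "site50.example.com"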
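
The new BLACKLISTED_TERMS check runs against the full response URL before any domain parsing, so it matches substrings anywhere in the URL and is case-sensitive. A quick check of that behaviour with the set from the diff (the is_blacklisted wrapper and the example URLs are just for the demo):

BLACKLISTED_TERMS = {
    "pastebin", "paste", "gist", "github", "gitlab",
    "bitbucket", "hastebin", "ghostbin", "privatebin",
}


def is_blacklisted(url: str) -> bool:
    # Same membership test as in score_website_result: any term appearing anywhere in the URL.
    return any(term in url for term in BLACKLISTED_TERMS)


print(is_blacklisted("https://gist.github.com/someuser/abc123"))  # True
print(is_blacklisted("https://example.com/copy-paste-tips"))      # True, "paste" matches mid-path
print(is_blacklisted("https://en.wikipedia.org/wiki/Pastebin"))   # False, the comparison is case-sensitive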
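
Further down, the repeat-domain penalty for non-top-100k domains works out to 1/(number of times the miner already used the domain + 1). A short illustration of that arithmetic on a hypothetical history list (not the module's real state):

import numpy as np

# Hypothetical history for one UID: "example.com" has already been returned twice.
history = ["example.com", "other.org", "example.com"]
domain = "example.com"

# Same counting logic as the diff: past occurrences plus the current use.
domain_count = np.sum(np.array([domain == d for d in history])) + 1
discount_factor = 1.0 / domain_count

print(domain_count, discount_factor)  # 3 0.3333333333333333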