Basic work queue from claude
jakep-allenai committed Nov 18, 2024
1 parent 995b1d1 commit 04429b2
Showing 2 changed files with 395 additions and 15 deletions.
168 changes: 153 additions & 15 deletions pdelfin/s3_queue.py
@@ -3,6 +3,7 @@
import logging
import hashlib
import tempfile
+import datetime
from typing import Optional, Tuple, List, Dict, Set
from dataclasses import dataclass
import asyncio
@@ -58,26 +59,86 @@ def __init__(self, s3_client, workspace_path: str):

        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
        self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
        self._queue = asyncio.Queue()

    @staticmethod
    def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
        """
-        Compute a deterministic hash for a group of PDFs.
+        Compute a deterministic hash for a group of paths.
        Args:
-            pdfs: List of PDF S3 paths
+            s3_work_paths: List of S3 paths
        Returns:
-            SHA1 hash of the sorted PDF paths
+            SHA1 hash of the sorted paths
        """
        sha1 = hashlib.sha1()
-        for pdf in sorted(s3_work_paths):
-            sha1.update(pdf.encode('utf-8'))
+        for path in sorted(s3_work_paths):
+            sha1.update(path.encode('utf-8'))
        return sha1.hexdigest()

-    async def populate_queue(self, s3_work_paths: str, items_per_group: int) -> None:
-        pass
+    async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None:
+        """
+        Add new items to the work queue.
+        Args:
+            s3_work_paths: Each individual s3 path that we will process over
+            items_per_group: Number of items to group together in a single work item
+        """
+        all_paths = set(s3_work_paths)
+        logger.info(f"Found {len(all_paths):,} total paths")
+
+        # Load existing work groups
+        existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_paths = parts[1:]
+                existing_groups[group_hash] = group_paths
+
+        existing_path_set = {path for paths in existing_groups.values() for path in paths}
+
+        # Find new paths to process
+        new_paths = all_paths - existing_path_set
+        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+        if not new_paths:
+            return
+
+        # Create new work groups
+        new_groups = []
+        current_group = []
+        for path in sorted(new_paths):
+            current_group.append(path)
+            if len(current_group) == items_per_group:
+                group_hash = self._compute_workgroup_hash(current_group)
+                new_groups.append((group_hash, current_group))
+                current_group = []
+        if current_group:
+            group_hash = self._compute_workgroup_hash(current_group)
+            new_groups.append((group_hash, current_group))
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine and save updated work groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_paths in new_groups:
+            combined_groups[group_hash] = group_paths
+
+        combined_lines = [
+            ",".join([group_hash] + group_paths)
+            for group_hash, group_paths in combined_groups.items()
+        ]
+
+        if new_groups:
+            await asyncio.to_thread(
+                upload_zstd_csv,
+                self.s3_client,
+                self._index_path,
+                combined_lines
+            )

    async def initialize_queue(self) -> None:
        """
Expand Down Expand Up @@ -116,7 +177,7 @@ async def initialize_queue(self) -> None:
        # Find remaining work and shuffle
        remaining_work_hashes = set(work_queue) - done_work_hashes
        remaining_items = [
-            WorkItem(hash_=hash_, pdfs=work_queue[hash_])
+            WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
            for hash_ in remaining_work_hashes
        ]
        random.shuffle(remaining_items)
@@ -127,7 +188,7 @@ async def initialize_queue(self) -> None:
            await self._queue.put(item)

        logger.info(f"Initialized queue with {self._queue.qsize()} work items")

    async def is_completed(self, work_hash: str) -> bool:
        """
        Check if a work item has been completed.
@@ -138,7 +199,7 @@ async def is_completed(self, work_hash: str) -> bool:
        Returns:
            True if the work is completed, False otherwise
        """
-        output_s3_path = ""TODO""
+        output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
        bucket, key = parse_s3_path(output_s3_path)

        try:
@@ -151,12 +212,89 @@ async def is_completed(self, work_hash: str) -> bool:
        except self.s3_client.exceptions.ClientError:
            return False

-    async def get_work(self) -> Optional[WorkItem]:
-        pass
+    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
+        """
+        Get the next available work item that isn't completed or locked.
+        Args:
+            worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
+        Returns:
+            WorkItem if work is available, None if queue is empty
+        """
+        while True:
+            try:
+                work_item = self._queue.get_nowait()
+            except asyncio.QueueEmpty:
+                return None
+
+            # Check if work is already completed
+            if await self.is_completed(work_item.hash):
+                logger.debug(f"Work item {work_item.hash} already completed, skipping")
+                self._queue.task_done()
+                continue
+
+            # Check for worker lock
+            lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+            bucket, key = parse_s3_path(lock_path)
+
+            try:
+                response = await asyncio.to_thread(
+                    self.s3_client.head_object,
+                    Bucket=bucket,
+                    Key=key
+                )
+
+                # Check if lock is stale
+                last_modified = response['LastModified']
+                if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
+                    # Lock is stale, we can take this work
+                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
+                else:
+                    # Lock is active, skip this work
+                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
+                    self._queue.task_done()
+                    continue
+
+            except self.s3_client.exceptions.ClientError:
+                # No lock exists, we can take this work
+                pass
+
+            # Create our lock file
+            try:
+                await asyncio.to_thread(
+                    self.s3_client.put_object,
+                    Bucket=bucket,
+                    Key=key,
+                    Body=b''
+                )
+            except Exception as e:
+                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
+                self._queue.task_done()
+                continue
+
+            return work_item

-    def mark_done(self, work_item: WorkItem) -> None:
-        """Mark the most recently gotten work item as complete"""
-        pass
+    async def mark_done(self, work_item: WorkItem) -> None:
+        """
+        Mark a work item as done by removing its lock file.
+        Args:
+            work_item: The WorkItem to mark as done
+        """
+        lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
+        bucket, key = parse_s3_path(lock_path)
+
+        try:
+            await asyncio.to_thread(
+                self.s3_client.delete_object,
+                Bucket=bucket,
+                Key=key
+            )
+        except Exception as e:
+            logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
+
+        self._queue.task_done()

    @property
    def size(self) -> int:
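
For orientation, and not part of the patch above: the index that populate_queue maintains is a zstd-compressed CSV in which each line is a work-group hash followed by the S3 paths in that group. Below is a minimal sketch of that line format, assuming the hashing scheme shown in _compute_workgroup_hash; the example paths are hypothetical, and the real S3 upload/download and compression are handled by the upload_zstd_csv/download_zstd_csv helpers.

import hashlib

def workgroup_hash(paths):
    # Same scheme as _compute_workgroup_hash above: SHA1 over the sorted paths
    sha1 = hashlib.sha1()
    for p in sorted(paths):
        sha1.update(p.encode("utf-8"))
    return sha1.hexdigest()

# Hypothetical group of two documents
group = ["s3://my-bucket/docs/a.pdf", "s3://my-bucket/docs/b.pdf"]

# One line of work_index_list.csv.zstd: "<hash>,<path1>,<path2>,..."
line = ",".join([workgroup_hash(group)] + group)

# populate_queue parses existing lines back the same way
parts = line.strip().split(",")
group_hash, group_paths = parts[0], parts[1:]
assert workgroup_hash(group_paths) == group_hash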
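
A hypothetical worker loop built on get_work/mark_done could look like the sketch below. The class name WorkQueue, the process_group function, and the workspace path are assumptions for illustration (the class name is not visible in this diff). If a worker crashes while holding an item, its lock file goes stale after worker_lock_timeout_secs and another worker can claim the item.

import asyncio
import boto3

from pdelfin.s3_queue import WorkQueue  # assumed class name; only its methods appear in this diff

async def process_group(paths):
    # Placeholder for real processing; results would end up in results/output_<hash>.jsonl
    for path in paths:
        print("processing", path)

async def run_worker(workspace_path: str):
    s3 = boto3.client("s3")
    queue = WorkQueue(s3, workspace_path)
    await queue.initialize_queue()

    while True:
        work_item = await queue.get_work()   # skips completed items, claims a lock file
        if work_item is None:
            break                            # queue drained
        await process_group(work_item.s3_work_paths)
        await queue.mark_done(work_item)     # deletes the lock and calls task_done

if __name__ == "__main__":
    asyncio.run(run_worker("s3://my-bucket/workspace"))  # hypothetical workspace path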
