Better work queue
jakep-allenai committed Nov 18, 2024
1 parent 04429b2 commit e499413
Showing 2 changed files with 82 additions and 199 deletions.
274 changes: 82 additions & 192 deletions pdelfin/beakerpipeline.py
@@ -24,9 +24,10 @@
from pypdf import PdfReader
from functools import partial
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple, List, Dict, Set
from concurrent.futures import ProcessPoolExecutor

from pdelfin.s3_queue import S3WorkQueue, WorkItem
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse
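For context, the call sites added in this commit imply the following S3WorkQueue / WorkItem surface. The sketch below is an assumption reconstructed from how the rest of the diff uses these names, not the actual contents of pdelfin/s3_queue.py:

# Sketch only: signatures inferred from call sites in this diff, bodies omitted.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class WorkItem:
    hash: str                  # stable id for one group of pdfs
    s3_work_paths: List[str]   # the pdf paths that make up the group

class S3WorkQueue:
    def __init__(self, s3_client, workspace: str): ...
    async def populate_queue(self, s3_work_paths, items_per_group: int) -> None: ...
    async def initialize_queue(self) -> None: ...
    async def get_work(self) -> Optional[WorkItem]: ...   # None once the queue is drained
    async def mark_done(self, work_item: WorkItem) -> None: ...
    @property
    def size(self) -> int: ...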
@@ -123,160 +124,6 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
}


def compute_workgroup_sha1(work_group: list[str]) -> str:
sha1 = hashlib.sha1()
# Ensure consistent ordering by sorting the list
for pdf in sorted(work_group):
sha1.update(pdf.encode('utf-8'))
return sha1.hexdigest()
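A quick usage note on the helper being removed above: because the paths are sorted before hashing, the group hash is independent of input order (the bucket and file names below are made up):

# Example: same hash regardless of ordering, since the list is sorted first.
a = compute_workgroup_sha1(["s3://bucket/a.pdf", "s3://bucket/b.pdf"])
b = compute_workgroup_sha1(["s3://bucket/b.pdf", "s3://bucket/a.pdf"])
assert a == b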


async def populate_pdf_work_queue(args):
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")

if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

all_pdfs = set(all_pdfs)
logger.info(f"Found {len(all_pdfs):,} total pdf paths")

existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

# Parse existing work items into groups
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
existing_groups[group_hash] = group_pdfs
existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)

logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

# Remove existing PDFs from all_pdfs
new_pdfs = all_pdfs - existing_pdf_set
logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")

sample_size = min(100, len(new_pdfs))
sampled_pdfs = random.sample(list(new_pdfs), sample_size)

page_counts = []

for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimial length"):
try:
# Download the PDF to a temp file
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
s3_bucket, s3_key = parse_s3_path(pdf)
pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
tmp_file.flush()
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}")

if page_counts:
avg_pages_per_pdf = sum(page_counts) / len(page_counts)
else:
logger.warning("Could not read any PDFs to estimate average page count.")
avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails

group_size = max(1, int(args.pages_per_group / avg_pages_per_pdf))
logger.info(f"Calculated group_size: {group_size} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

new_groups = []
current_group = []
for pdf in sorted(new_pdfs): # Sort for consistency
current_group.append(pdf)
if len(current_group) == group_size:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))

logger.info(f"Created {len(new_groups):,} new work groups")

# Combine existing groups with new groups
combined_groups = existing_groups.copy()
for group_hash, group_pdfs in new_groups:
combined_groups[group_hash] = group_pdfs

# Prepare lines to write back
combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]

# Upload the combined work items back to S3
if new_groups:
upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)

logger.info("Completed adding new PDFs.")

async def load_pdf_work_queue(args) -> asyncio.Queue:
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")

# Define the two blocking I/O operations
download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)

# Run both tasks concurrently
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)

# Process the work queue lines
work_queue = {
parts[0]: parts[1:]
for line in work_queue_lines
if (parts := line.strip().split(",")) and line.strip()
}

# Extract done work hashes
done_work_hashes = {
os.path.basename(item)[len('output_'):-len('.jsonl')]
for item in done_work_items
if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
}
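As a worked example of the slicing above (the s3 prefix is illustrative; the hash is the debug value that appears a few lines below):

# Example: recover the work hash from a results filename.
item = "s3://workspace/dolma_documents/output_0e779f21fbb75d38ed4242c7e5fe57fa9a636bac.jsonl"
name = os.path.basename(item)
work_hash = name[len('output_'):-len('.jsonl')]
# -> "0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"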

# Determine remaining work
remaining_work_hashes = set(work_queue) - done_work_hashes
#remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash
remaining_work_queue = {
hash_: work_queue[hash_]
for hash_ in remaining_work_hashes
}

# Populate the asyncio.Queue with remaining work
queue = asyncio.Queue()
shuffled_items = list(remaining_work_queue.items())
random.shuffle(shuffled_items)

for work, pdfs in shuffled_items:
await queue.put((work, pdfs))

return queue

async def work_item_completed(args, work_hash: str) -> bool:
# Check if work item has already been completed
output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
bucket, key = parse_s3_path(output_s3_path)

try:
# Check if the output file already exists
await asyncio.to_thread(workspace_s3.head_object, Bucket=bucket, Key=key)
return True
except workspace_s3.exceptions.ClientError as e:
pass

return False


async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = 3
@@ -442,31 +289,32 @@ def build_dolma_document(pdf_s3_path, page_results):
}
return dolma_doc

async def worker(args, queue, semaphore, worker_id):

async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
while True:
[work_hash, pdfs] = await queue.get()
# Wait until allowed to proceed
await semaphore.acquire()

try:
await tracker.clear_work(worker_id)
work_item = await work_queue.get_work()

# Wait until allowed to proceed
await semaphore.acquire()
if work_item is None:
logger.info(f"Worker {worker_id} exiting due to empty queue")
semaphore.release()
break

if await work_item_completed(args, work_hash):
logger.info(f"Work {work_hash} was already completed, skipping")
continue
else:
logger.info(f"Proceeding with {work_hash} on worker {worker_id}")
logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
await tracker.clear_work(worker_id)

try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600),
connector=aiohttp.TCPConnector(limit=1000)) as session:
async with asyncio.TaskGroup() as tg:
dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in pdfs]
logger.info(f"Created all tasks for {work_hash}")
dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
logger.info(f"Created all tasks for {work_item.hash}")

logger.info(f"Finished TaskGroup for worker on {work_hash}")
logger.info(f"Finished TaskGroup for worker on {work_item.hash}")

logger.info(f"Closed ClientSession for {work_hash}")
logger.info(f"Closed ClientSession for {work_item.hash}")

dolma_docs = []
for task in dolma_tasks:
@@ -479,7 +327,7 @@ async def worker(args, queue, semaphore, worker_id):
if result is not None:
dolma_docs.append(result)

logger.info(f"Got {len(dolma_docs)} docs for {work_hash}")
logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")

# Write the Dolma documents to a local temporary file in JSONL format
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
@@ -489,7 +337,7 @@
tf.flush()

# Define the output S3 path using the work_hash
output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
output_s3_path = os.path.join(args.workspace, 'results', f'output_{work_item.hash}.jsonl')

bucket, key = parse_s3_path(output_s3_path)
workspace_s3.upload_file(tf.name, bucket, key)
@@ -501,9 +349,10 @@
# Update last batch time
last_batch_time = time.perf_counter()
except Exception as e:
logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
finally:
queue.task_done()
await work_queue.mark_done(work_item)
semaphore.release()


async def sglang_server_task(args, semaphore):
@@ -563,6 +412,7 @@ async def process_line(line):

if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
server_printed_ready_message = True
last_semaphore_release = time.time()

match = re.search(r'#running-req: (\d+)', line)
if match:
@@ -631,10 +481,10 @@ async def sglang_server_ready():
raise Exception("sglang server did not become ready after waiting.")


async def metrics_reporter(queue):
async def metrics_reporter(work_queue):
while True:
# Leading newlines preserve table formatting in logs
logger.info(f"Queue remaining: {queue.qsize()}")
logger.info(f"Queue remaining: {work_queue.size}")
logger.info("\n" + str(metrics))
logger.info("\n" + str(await tracker.get_status_table()))
await asyncio.sleep(10)
@@ -716,13 +566,14 @@ def submit_beaker_job(args):

print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")


def print_stats(args):
import concurrent.futures
from tqdm import tqdm

# Get total work items and completed items
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")
index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "results", "*.jsonl")

work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
done_work_items = expand_s3_glob(workspace_s3, output_glob)
@@ -825,9 +676,54 @@ async def main():

check_poppler_version()

# Create work queue
work_queue = S3WorkQueue(workspace_s3, args.workspace)

if args.pdfs:
logger.info("Got --pdfs argument, going to add to the work queue")
await populate_pdf_work_queue(args)

# Expand s3 paths
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
s3_work_paths = list(filter(None, (line.strip() for line in f)))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

s3_work_paths = set(s3_work_paths)
logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add")

# Estimate average pages per pdf
sample_size = min(100, len(s3_work_paths))
sampled_pdfs = random.sample(list(s3_work_paths), sample_size)
page_counts = []

for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
try:
# Download the PDF to a temp file
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
s3_bucket, s3_key = parse_s3_path(pdf)
pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
tmp_file.flush()
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}")

if page_counts:
avg_pages_per_pdf = sum(page_counts) / len(page_counts)
else:
logger.warning("Could not read any PDFs to estimate average page count.")
avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails

items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

# Now call populate_queue
await work_queue.populate_queue(s3_work_paths, items_per_group)
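As a rough worked example of the grouping arithmetic above (both numbers below are hypothetical, not defaults from this repo):

# Hypothetical numbers: target 500 pages per group, 12.5 pages per pdf on average.
pages_per_group = 500
avg_pages_per_pdf = 12.5
items_per_group = max(1, int(pages_per_group / avg_pages_per_pdf))   # -> 40 pdfs per work item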

if args.stats:
print_stats(args)
@@ -839,6 +735,9 @@

logger.info(f"Starting pipeline with PID {os.getpid()}")

# Initialize the work queue
await work_queue.initialize_queue()

# Create a semaphore to control worker access
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
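As a simplified illustration of the gating idea described in this comment (a sketch only; in the actual pipeline the permit is also released by the server-monitoring code that parses #running-req above, rather than purely by each worker's try/finally):

# Simplified sketch of the idea: many workers, one submission permit.
import asyncio

async def demo():
    semaphore = asyncio.Semaphore(1)

    async def one_worker():
        await semaphore.acquire()        # only one worker starts a new batch at a time
        try:
            await asyncio.sleep(0.1)     # stand-in for submitting one work item's pages
        finally:
            semaphore.release()          # the real pipeline also releases from the server monitor

    await asyncio.gather(*(one_worker() for _ in range(5)))

asyncio.run(demo())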
@@ -847,9 +746,6 @@

sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))

work_queue = await load_pdf_work_queue(args)
logger.info(f"Work queue prepared with {work_queue.qsize()} items")

await sglang_server_ready()

metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@@ -860,28 +756,22 @@
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
worker_tasks.append(task)

# Wait for the queue to be fully processed
await work_queue.join()

# Cancel our worker tasks.
for task in worker_tasks:
task.cancel()
# Wait for all worker tasks to finish
await asyncio.gather(*worker_tasks)

# Wait until all worker tasks are cancelled.
await asyncio.gather(*worker_tasks, return_exceptions=True)

# Wait for server to stop
process_pool.shutdown(wait=False)

sglang_server.cancel()
metrics_task.cancel()
logger.info("Work done")


if __name__ == "__main__":
asyncio.run(main())

# TODO
# - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
# X Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
# X Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things
# - Add logging of failed pages and have the stats function read them
# X Add the page rotation check and mechanism