Skip to content

Commit

Permalink
Better stats
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 12, 2024
1 parent 9ce28c0 commit ae9b1c4
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 12 deletions.
20 changes: 14 additions & 6 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.check import check_poppler_version
from pdelfin.metrics import MetricsKeeper
from pdelfin.metrics import MetricsKeeper, WorkerTracker

# Initialize logger
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,6 +62,7 @@

# Global variables for token statistics
metrics = MetricsKeeper(window=60*5)
tracker = WorkerTracker()

# Process pool for offloading cpu bound work, like calculating anchor texts
process_pool = ProcessPoolExecutor()
Expand Down Expand Up @@ -229,11 +230,12 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
return queue


async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
MAX_RETRIES = 3

attempt = 0
await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")

while attempt < MAX_RETRIES:
query = await build_page_query(
Expand All @@ -255,6 +257,7 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)

await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
return PageResult(
pdf_s3_path,
page_num,
Expand Down Expand Up @@ -282,8 +285,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")

await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "errored")

async def process_pdf(args, pdf_s3_path: str):
async def process_pdf(args, worker_id: int, pdf_s3_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
# TODO Switch to aioboto3 or something
data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
Expand All @@ -299,7 +303,7 @@ async def process_pdf(args, pdf_s3_path: str):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600), connector=aiohttp.TCPConnector(limit=10)) as session:
for page_num in range(1, num_pages + 1):
# Create a task for each page
task = asyncio.create_task(process_page(args, session, pdf_s3_path, tf.name, page_num))
task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
page_tasks.append(task)

# Gather results from all page processing tasks
Expand Down Expand Up @@ -362,7 +366,7 @@ async def worker(args, queue, semaphore, worker_id):
# Wait until allowed to proceed
await semaphore.acquire()

dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
dolma_docs = await asyncio.gather(*[process_pdf(args, worker_id, pdf) for pdf in pdfs])
dolma_docs = [doc for doc in dolma_docs if doc is not None]

# Write the Dolma documents to a local temporary file in JSONL format
Expand Down Expand Up @@ -501,11 +505,15 @@ async def sglang_server_ready():

raise Exception("sglang server did not become ready after waiting.")


async def metrics_reporter():
while True:
logger.info(metrics)
# Leading newlines preserve table formatting in logs
logger.info("\n" + str(metrics))
logger.info("\n" + str(await tracker.get_status_table()))
await asyncio.sleep(10)


async def main():
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
Expand Down
90 changes: 84 additions & 6 deletions pdelfin/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import time
import asyncio
from collections import deque, defaultdict
from dataclasses import dataclass, field
from typing import Dict

class MetricsKeeper:
def __init__(self, window=60*5):
Expand Down Expand Up @@ -47,20 +50,95 @@ def __str__(self):
Returns a formatted string of metrics showing tokens/sec since start and within the window.
Returns:
str: Formatted metrics string.
str: Formatted metrics string as a table.
"""
current_time = time.time()
elapsed_time = current_time - self.start_time
window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero

metrics_strings = []
# Header
header = f"{'Metric Name':<20} {'Lifetime (tokens/sec)':>25} {'Window (tokens/sec)':>25}"
separator = "-" * len(header)
lines = [header, separator]

# Sort metrics alphabetically for consistency
for key in sorted(self.total_metrics.keys()):
total = self.total_metrics[key]
window = self.window_sum.get(key, 0)
total_rate = total / elapsed_time if elapsed_time > 0 else 0
window_rate = window / window_time if window_time > 0 else 0
metrics_strings.append(
f"{key}: {total_rate:.2f}/sec (last {int(window_time)}s: {window_rate:.2f}/sec)"
)
line = f"{key:<20} {total_rate:>25.2f} {window_rate:>25.2f}"
lines.append(line)

return "\n".join(lines)


class WorkerTracker:
def __init__(self):
"""
Initializes the WorkerTracker with a default dictionary.
Each worker ID maps to another dictionary that holds counts for each state.
"""
# Mapping from worker_id to a dictionary of state counts
self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
self.lock = asyncio.Lock()

async def track_work(self, worker_id: int, work_item_id: str, state: str):
"""
Update the state count for a specific worker.
Args:
worker_id (int): The ID of the worker.
work_item_id (str): The unique identifier of the work item (unused in this implementation).
state (str): The state to increment for the work item.
"""
async with self.lock:
self.worker_status[worker_id][state] += 1
logger.debug(f"Worker {worker_id} - State '{state}' incremented to {self.worker_status[worker_id][state]}.")

async def get_status_table(self) -> str:
"""
Generate a formatted table of the current status of all workers.
Returns:
str: A string representation of the workers' statuses.
"""
async with self.lock:
# Determine all unique states across all workers
all_states = set()
for states in self.worker_status.values():
all_states.update(states.keys())
all_states = sorted(all_states)

headers = ["Worker ID"] + all_states
rows = []
for worker_id, states in sorted(self.worker_status.items()):
row = [str(worker_id)]
for state in all_states:
count = states.get(state, 0)
row.append(str(count))
rows.append(row)

# Calculate column widths
col_widths = [len(header) for header in headers]
for row in rows:
for idx, cell in enumerate(row):
col_widths[idx] = max(col_widths[idx], len(cell))

return ", ".join(metrics_strings)
# Create the table header
header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
separator = "-+-".join('-' * col_widths[idx] for idx in range(len(headers)))

# Create the table rows
row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]

# Combine all parts
table = "\n".join([header_line, separator] + row_lines)
return table

def __str__(self):
"""
String representation is not directly supported.
Use 'await get_status_table()' to retrieve the status table.
"""
raise NotImplementedError("Use 'await get_status_table()' to get the status table.")

0 comments on commit ae9b1c4

Please sign in to comment.