Working on task groups
jakep-allenai committed Nov 14, 2024
1 parent a58efea commit 827b77e
Showing 2 changed files with 67 additions and 56 deletions.
121 changes: 66 additions & 55 deletions pdelfin/beakerpipeline.py
@@ -298,6 +298,10 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
        except json.JSONDecodeError as e:
            logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
            attempt += 1
+        except asyncio.CancelledError:
+            logger.info(f"Process page {pdf_s3_path}-{page_num} cancelled")
+            await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "cancelled")
+            raise
        except Exception as e:
            logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
            attempt += 1
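Background on the new except clause: since Python 3.8, asyncio.CancelledError derives from BaseException, so the except Exception handler below it would not catch a cancellation; handling it explicitly lets the tracker record the page as cancelled before the re-raise lets the cancellation propagate. A minimal standalone sketch of the pattern, with hypothetical names not from this repo:

import asyncio

async def cancellable_page_job():
    try:
        await asyncio.sleep(60)  # stand-in for a long-running model request
    except asyncio.CancelledError:
        print("job cancelled, recording bookkeeping")  # tracker update would go here
        raise  # re-raise so the awaiting caller still sees the cancellation

async def main():
    task = asyncio.create_task(cancellable_page_job())
    await asyncio.sleep(0.1)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("cancellation propagated as expected")

asyncio.run(main())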
@@ -326,61 +330,65 @@ async def process_pdf(args, session: aiohttp.ClientSession, worker_id: int, pdf_

    # List to hold the tasks for processing each page
    page_tasks = []
-    for page_num in range(1, num_pages + 1):
-        # Create a task for each page
-        task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
-        page_tasks.append(task)
+    page_results = []

-    # Gather results from all page processing tasks
    try:
-        page_results: list[PageResult] = await asyncio.gather(*page_tasks)
-    except:
-        logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
-        return None
+        async with asyncio.TaskGroup() as tg:
+            for page_num in range(1, num_pages + 1):
+                task = tg.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
+                page_tasks.append(task)
+
+        # Collect the results from the entire task group, assuming no exceptions
+        page_results = [task.result() for task in page_tasks]
+    except Exception as e:
+        logger.exception(f"Exception in process_pdf for {pdf_s3_path}: {e}")
+        raise

-    # Build the document text and page spans
-    document_text = ""
-    pdf_page_spans = []
-    current_char_pos = 0
-
-    for index, page_result in enumerate(page_results):
-        if page_result.response.natural_text is not None:
-            content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
-        else:
-            content = ""
-
-        start_pos = current_char_pos
-        document_text += content
-        current_char_pos = len(document_text)
-        pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])
-
-    if not document_text:
-        return None  # Return None if the document text is empty
-
-    # Build the Dolma document
-    metadata = {
-        "Source-File": pdf_s3_path,
-        "pdf-total-pages": num_pages,
-        "total-input-tokens": sum(page.input_tokens for page in page_results),
-        "total-output-tokens": sum(page.output_tokens for page in page_results)
-    }
-
-    id_ = hashlib.sha1(document_text.encode()).hexdigest()
-
-    dolma_doc = {
-        "id": id_,
-        "text": document_text,
-        "source": "pdelfin",
-        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
-        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
-        "metadata": metadata,
-        "attributes": {
-            "pdf_page_numbers": pdf_page_spans
-        }
-    }
-
-    return dolma_doc
+    return build_dolma_document(pdf_s3_path, page_results)
+
+
+def build_dolma_document(pdf_s3_path, page_results):
+    # Build the document text and page spans
+    document_text = ""
+    pdf_page_spans = []
+    current_char_pos = 0
+
+    for index, page_result in enumerate(page_results):
+        if page_result.response.natural_text is not None:
+            content = page_result.response.natural_text + ("\n" if index < len(page_results) - 1 else "")
+        else:
+            content = ""
+
+        start_pos = current_char_pos
+        document_text += content
+        current_char_pos = len(document_text)
+        pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])
+
+    if not document_text:
+        return None  # Return None if the document text is empty
+
+    # Build the Dolma document
+    metadata = {
+        "Source-File": pdf_s3_path,
+        "pdf-total-pages": len(page_results),
+        "total-input-tokens": sum(page.input_tokens for page in page_results),
+        "total-output-tokens": sum(page.output_tokens for page in page_results)
+    }
+
+    id_ = hashlib.sha1(document_text.encode()).hexdigest()
+
+    dolma_doc = {
+        "id": id_,
+        "text": document_text,
+        "source": "pdelfin",
+        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
+        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
+        "metadata": metadata,
+        "attributes": {
+            "pdf_page_numbers": pdf_page_spans
+        }
+    }
+
+    return dolma_doc

async def worker(args, queue, semaphore, worker_id):
while True:
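For reference, the contract process_pdf now relies on: asyncio.TaskGroup (Python 3.11+) cancels the remaining sibling tasks as soon as one task fails and re-raises the failures as an ExceptionGroup, so task.result() is only reached once the async with block has exited with every task finished. A minimal sketch of that behavior, using a hypothetical fetch_page rather than the repo's code:

import asyncio

async def fetch_page(i: int) -> str:
    if i == 2:
        raise ValueError(f"page {i} failed")
    await asyncio.sleep(0.01)
    return f"page {i}"

async def main():
    tasks = []
    try:
        async with asyncio.TaskGroup() as tg:
            for i in range(5):
                tasks.append(tg.create_task(fetch_page(i)))
        print([t.result() for t in tasks])  # reached only if every task succeeded
    except* ValueError as eg:
        print("failed:", [str(e) for e in eg.exceptions])  # siblings were cancelled

asyncio.run(main())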
@@ -396,10 +404,12 @@ async def worker(args, queue, semaphore, worker_id):
logger.info(f"Work {work_hash} was already completed, skipping")
continue

-        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600),
-                                         connector=aiohttp.TCPConnector(limit=1000)) as session:
-            dolma_docs = await asyncio.gather(*[process_pdf(args, session, worker_id, pdf) for pdf in pdfs])
-            dolma_docs = [doc for doc in dolma_docs if doc is not None]
+        async with asyncio.TaskGroup() as tg, \
+                   aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600),
+                                         connector=aiohttp.TCPConnector(limit=1000)) as session:
+            dolma_docs = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in pdfs]
+
+        dolma_docs = [task.result() for task in dolma_docs if task.result() is not None]

# Write the Dolma documents to a local temporary file in JSONL format
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
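The worker applies the same idea while sharing one HTTP session across tasks. One subtlety: async with exits its context managers in reverse order, so entering the session first guarantees the TaskGroup exit (which waits for all tasks) runs before the session closes. A minimal sketch under that ordering, with a placeholder URL and coroutine that are not from the diff:

import asyncio
import aiohttp

async def fetch_status(session: aiohttp.ClientSession, url: str) -> int:
    # One GET per task; the shared session provides connection pooling
    async with session.get(url) as resp:
        return resp.status

async def main():
    urls = ["https://example.com"] * 3  # placeholder URLs
    async with aiohttp.ClientSession() as session, \
               asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(fetch_status(session, u)) for u in urls]
    print([t.result() for t in tasks])  # all tasks done, session closed afterwards

asyncio.run(main())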
@@ -539,9 +549,10 @@ async def sglang_server_ready():
raise Exception("sglang server did not become ready after waiting.")


-async def metrics_reporter():
+async def metrics_reporter(queue):
    while True:
+        logger.info(f"Queue remaining: {queue.qsize()}")
        # Leading newlines preserve table formatting in logs
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)
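metrics_reporter now receives the work queue so the number of outstanding work items is logged next to the metrics tables; asyncio.Queue.qsize() is a synchronous call, so no await is needed. A minimal sketch of such a periodic reporter, with a hypothetical setup not taken from the diff:

import asyncio

async def reporter(queue: asyncio.Queue):
    while True:
        print(f"Queue remaining: {queue.qsize()}")
        await asyncio.sleep(10)

async def main():
    queue = asyncio.Queue()
    for item in range(3):
        queue.put_nowait(item)
    task = asyncio.create_task(reporter(queue))
    await asyncio.sleep(0.1)  # let the reporter log once
    task.cancel()

asyncio.run(main())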
@@ -697,7 +708,7 @@ async def main():

    await sglang_server_ready()

-    metrics_task = asyncio.create_task(metrics_reporter())
+    metrics_task = asyncio.create_task(metrics_reporter(work_queue))

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
2 changes: 1 addition & 1 deletion pdelfin/version.py
@@ -2,7 +2,7 @@
_MINOR = "1"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "9"
_PATCH = "10"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
