
Control http session at the worker level
jakep-allenai committed Nov 12, 2024
1 parent fbacdd0 · commit 6598e2d
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
@@ -287,7 +287,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
    raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")


-async def process_pdf(args, worker_id: int, pdf_s3_path: str):
+async def process_pdf(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str):
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
        # TODO Switch to aioboto3 or something
        data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
@@ -301,20 +301,17 @@ async def process_pdf(args, worker_id: int, pdf_s3_path: str):

        # List to hold the tasks for processing each page
        page_tasks = []
-        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600), connector=aiohttp.TCPConnector(limit=50)) as session:
-            for page_num in range(1, num_pages + 1):
-                # Create a task for each page
-                task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
-                page_tasks.append(task)
-
-            # Gather results from all page processing tasks
-            try:
-                page_results: list[PageResult] = await asyncio.gather(*page_tasks)
-            except:
-                logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
-                return None
+        for page_num in range(1, num_pages + 1):
+            # Create a task for each page
+            task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
+            page_tasks.append(task)
+
+        # Gather results from all page processing tasks
+        try:
+            page_results: list[PageResult] = await asyncio.gather(*page_tasks)
+        except:
+            logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
+            return None

        # Build the document text and page spans
        document_text = ""
@@ -369,8 +366,10 @@ async def worker(args, queue, semaphore, worker_id):
        # Wait until allowed to proceed
        await semaphore.acquire()

-        dolma_docs = await asyncio.gather(*[process_pdf(args, worker_id, pdf) for pdf in pdfs])
-        dolma_docs = [doc for doc in dolma_docs if doc is not None]
+        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600),
+                                         connector=aiohttp.TCPConnector(limit=1000)) as session:
+            dolma_docs = await asyncio.gather(*[process_pdf(args, session, worker_id, pdf) for pdf in pdfs])
+            dolma_docs = [doc for doc in dolma_docs if doc is not None]

        # Write the Dolma documents to a local temporary file in JSONL format
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
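
The effect of the change, as a minimal self-contained sketch (the function names fetch_page and process_doc, the example URL, and the page counts are illustrative assumptions, not pdelfin's actual code): the worker opens a single aiohttp.ClientSession and passes it down to every per-document and per-page task, so the TCPConnector limit, raised here from 50 per document to 1000 per worker, caps concurrent connections across everything the worker has in flight.

import asyncio
import aiohttp

async def fetch_page(session: aiohttp.ClientSession, page_num: int) -> str:
    # Every request goes through the shared session, so all page fetches
    # draw from one connection pool instead of one pool per document.
    # (Hypothetical endpoint for illustration only.)
    async with session.get(f"https://example.com/pages/{page_num}") as resp:
        resp.raise_for_status()
        return await resp.text()

async def process_doc(session: aiohttp.ClientSession, doc_id: int) -> list[str]:
    # Fan out one task per page, all sharing the worker's session.
    tasks = [asyncio.create_task(fetch_page(session, n)) for n in range(1, 4)]
    return await asyncio.gather(*tasks)

async def worker(doc_ids: list[int]) -> None:
    # One session per worker: the TCPConnector limit now bounds concurrent
    # connections across *all* documents this worker processes at once.
    timeout = aiohttp.ClientTimeout(total=3600)
    connector = aiohttp.TCPConnector(limit=1000)
    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        docs = await asyncio.gather(*[process_doc(session, d) for d in doc_ids])
        print(f"processed {len(docs)} documents")

asyncio.run(worker([1, 2, 3]))

Owning the session at the worker level also means TCP connections are reused across PDFs rather than torn down with each document's session, and there is a single place to tune timeouts and pool size.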
