
Commit

jakep-allenai committed Dec 3, 2024
2 parents 1eda300 + 535181d commit 37cdb9e
Showing 2 changed files with 88 additions and 27 deletions.
113 changes: 87 additions & 26 deletions pdelfin/beakerpipeline.py
@@ -20,6 +20,7 @@
import multiprocessing

from tqdm import tqdm
from urllib.parse import urlparse
from io import BytesIO
from PIL import Image
from pypdf import PdfReader
@@ -71,8 +72,8 @@
metrics = MetricsKeeper(window=60*5)
tracker = WorkerTracker()

# Process pool for offloading cpu bound work, like calculating anchor texts
process_pool = ProcessPoolExecutor(mp_context=multiprocessing.get_context('spawn'))
# Process pool for offloading cpu bound work, like calculating anchor texts, max 32 workers, otherwise it can spawn way too many workers on a big machine
process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context('spawn'))

# Filter object, cached so it will only get loaded when/if you need it
get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True))
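
For reference, the process_pool above exists so that CPU-bound work (such as anchor-text extraction) can be pushed off the event loop with run_in_executor. A minimal, self-contained sketch of that pattern, with a hypothetical compute_anchor_text helper standing in for the real work:

    import asyncio
    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    def compute_anchor_text(pdf_path: str, page: int) -> str:
        # Placeholder for the CPU-heavy parsing; runs inside a worker process
        return f"anchor text for {pdf_path} page {page}"

    async def main():
        pool = ProcessPoolExecutor(
            max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32),
            mp_context=multiprocessing.get_context("spawn"),
        )
        loop = asyncio.get_running_loop()
        # run_in_executor keeps the event loop free while the pool does the work
        text = await loop.run_in_executor(pool, compute_anchor_text, "example.pdf", 1)
        pool.shutdown()
        print(text)

    if __name__ == "__main__":
        asyncio.run(main())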
@@ -131,7 +132,72 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
}


async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
# Manual, simple implementation of an HTTP POST
# It may feel strange, but httpx and aiohttp are very complex beasts
# e.g. the session pool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, they deadlock in different strange ways
async def apost(url, json_data):
parsed_url = urlparse(url)
host = parsed_url.hostname
port = parsed_url.port or 80
path = parsed_url.path or "/"

writer = None
try:
reader, writer = await asyncio.open_connection(host, port)

json_payload = json.dumps(json_data)
request = (
f"POST {path} HTTP/1.1\r\n"
f"Host: {host}\r\n"
f"Content-Type: application/json\r\n"
f"Content-Length: {len(json_payload)}\r\n"
f"Connection: close\r\n\r\n"
f"{json_payload}"
)
writer.write(request.encode())
await writer.drain()

# Read status line
status_line = await reader.readline()
if not status_line:
raise ConnectionError("No response from server")
status_parts = status_line.decode().strip().split(' ', 2)
if len(status_parts) < 2:
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
status_code = int(status_parts[1])

# Read headers
headers = {}
while True:
line = await reader.readline()
if line in (b'\r\n', b'\n', b''):
break
key, _, value = line.decode().partition(':')
headers[key.strip().lower()] = value.strip()

# Read response body
if 'content-length' in headers:
body_length = int(headers['content-length'])
response_body = await reader.readexactly(body_length)
else:
raise ConnectionError("Anything other than fixed content length responses are not implemented yet")

return status_code, response_body
except Exception as e:
# Pass through errors
raise e
finally:
# But just make sure to close the socket on your way out
if writer is not None:
try:
writer.close()
await writer.wait_closed()
except:
pass
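
A usage sketch for apost, where the endpoint, port, and payload are placeholders rather than values from this commit:

    async def example_call():
        # Hypothetical local endpoint and request body, for illustration only
        status_code, response_body = await apost(
            "http://localhost:8000/v1/chat/completions",
            json_data={"messages": [{"role": "user", "content": "Hello"}]},
        )
        if status_code != 200:
            raise ValueError(f"Error http status {status_code}")
        return json.loads(response_body)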


async def process_page(args, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries

@@ -150,22 +216,20 @@ async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_
image_rotation=local_image_rotation
)

try:
response = await session.post(COMPLETION_URL, json=query)
logger.info(f"Built page query for {pdf_s3_path}-{page_num}")

try:
if response.status_code == 400:
raise ValueError(f"Got BadRequestError from server: {response.text}, skipping this response")
elif response.status_code == 500:
raise ValueError(f"Got InternalServerError from server: {response.text}, skipping this response")
else:
response.raise_for_status()
try:
status_code, response_body = await apost(COMPLETION_URL, json_data=query)

base_response_data = response.json()
finally:
await response.aclose()

if status_code == 400:
raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
elif status_code == 500:
raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
elif status_code != 200:
raise ValueError(f"Error http status {status_code}")

base_response_data = json.loads(response_body)

if base_response_data["usage"]["total_tokens"] > args.model_max_context:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_s3_path}-{page_num}")
Expand All @@ -191,7 +255,7 @@ async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_
output_tokens=base_response_data["usage"].get("completion_tokens", 0),
is_fallback=False,
)
except (httpx.TransportError, asyncio.TimeoutError) as e:
except (ConnectionError, OSError, asyncio.TimeoutError) as e:
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} {e}")

# Now we want to do exponential backoff, and not count this as an actual page retry
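
The backoff code itself is collapsed in this view; a generic sketch of exponential backoff with jitter, the pattern the comment above refers to (not the commit's own code), might look like:

    import asyncio
    import random

    async def sleep_with_backoff(attempt: int, base_delay: float = 1.0, max_delay: float = 60.0):
        # Delay grows exponentially with the attempt number, capped and jittered
        delay = min(max_delay, base_delay * (2 ** attempt))
        await asyncio.sleep(delay * random.uniform(0.5, 1.5))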
@@ -229,7 +293,7 @@ async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_
)


async def process_pdf(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str):
async def process_pdf(args, worker_id: int, pdf_s3_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
# TODO Switch to aioboto3 or something
data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_s3_path))
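
The asyncio.to_thread call above keeps the blocking boto3 download off the event loop. A standalone sketch of that pattern, with a placeholder bucket and key:

    import asyncio
    import boto3

    def fetch_pdf_bytes(bucket: str, key: str) -> bytes:
        # Blocking boto3 call, safe to run off the event loop in a thread
        s3 = boto3.client("s3")
        return s3.get_object(Bucket=bucket, Key=key)["Body"].read()

    async def download():
        data = await asyncio.to_thread(fetch_pdf_bytes, "example-bucket", "docs/sample.pdf")
        print(f"downloaded {len(data)} bytes")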
@@ -256,7 +320,7 @@ async def process_pdf(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_p
try:
async with asyncio.TaskGroup() as tg:
for page_num in range(1, num_pages + 1):
task = tg.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
task = tg.create_task(process_page(args, worker_id, pdf_s3_path, tf.name, page_num))
page_tasks.append(task)

# Collect the results from the entire task group, assuming no exceptions
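
asyncio.TaskGroup (Python 3.11+) waits for every task when the async with block exits, after which each result can be read with task.result(). A minimal fan-out/collect sketch of that pattern, independent of this commit:

    import asyncio

    async def work(page_num: int) -> str:
        await asyncio.sleep(0.01)
        return f"page {page_num} done"

    async def run_all(num_pages: int) -> list[str]:
        tasks = []
        async with asyncio.TaskGroup() as tg:
            for page_num in range(1, num_pages + 1):
                tasks.append(tg.create_task(work(page_num)))
        # The TaskGroup has finished here, so every task result is available
        return [t.result() for t in tasks]

    print(asyncio.run(run_all(3)))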
@@ -348,14 +412,11 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
await tracker.clear_work(worker_id)

try:
async with httpx.AsyncClient(timeout=3600, limits=httpx.Limits(max_keepalive_connections=0, max_connections=6000)) as session:
async with asyncio.TaskGroup() as tg:
dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
logger.info(f"Created all tasks for {work_item.hash}")

logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
async with asyncio.TaskGroup() as tg:
dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.s3_work_paths]
logger.info(f"Created all tasks for {work_item.hash}")

logger.info(f"Closed ClientSession for {work_item.hash}")
logger.info(f"Finished TaskGroup for worker on {work_item.hash}")

dolma_docs = []
for task in dolma_tasks:
2 changes: 1 addition & 1 deletion pdelfin/version.py
@@ -2,7 +2,7 @@
_MINOR = "1"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "50"
_PATCH = "51"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
