Commit

Minor fixes
jakep-allenai committed Nov 11, 2024
1 parent 9ff107b commit da1b23f
Showing 2 changed files with 33 additions and 5 deletions.
22 changes: 20 additions & 2 deletions pdelfin/beakerpipeline.py
@@ -7,6 +7,7 @@
import time
import subprocess
import hashlib
+import json
import base64
import atexit
import asyncio
@@ -205,7 +206,7 @@ async def process_page(session, pdf_path, page_num, args) -> PageResponse:

try:
base_response_data = await response.json()
-                        model_response_json = orjson.loads(base_response_data["outputs"][0]["text"])
+                        model_response_json = json.loads(base_response_data["outputs"][0]["text"])
page_response = PageResponse(**model_response_json)
except Exception as e:
logger.warning(f"Could not parse response for {pdf_path}-{page_num}")
@@ -239,6 +240,7 @@ async def process_pdf(args, pdf_s3_path):
# If we failed to build a page, then this document is toast
# TODO Abort earlier, if a page returns a None, then we can stop processing the whole pdf
if any(page is None for page in page_results):
logger.warning(f"PDF {pdf_s3_path} was not able to complete, not able to process a page")
return None

# Build the document text and page spans
@@ -305,6 +307,17 @@ async def sglang_server_task(args):
# TODO cache locally
#download_directory(args.model, model_cache_dir)

+    # Check the rope config and make sure it's got the proper key
+    with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
+        config_data = json.load(cfin)
+
+    if "rope_type" in config_data["rope_scaling"]:
+        del config_data["rope_scaling"]["rope_type"]
+        config_data["rope_scaling"]["type"] = "mrope"
+
+    with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
+        json.dump(config_data, cfout)

proc = await asyncio.create_subprocess_exec(
"python3",

@@ -315,7 +328,12 @@
)

# Make really sure we kill this subprocess on exit
-    atexit.register(lambda: proc.kill())
+    def _kill_proc():
+        proc.terminate()
+        time.sleep(3)
+        proc.kill()
+
+    atexit.register(_kill_proc)

await proc.wait()

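The atexit change above swaps an immediate proc.kill() for terminate-then-kill, giving the sglang server a short grace period to shut down. Since atexit handlers cannot await, the commit uses a fixed time.sleep(3); below is a hedged sketch of the same pattern for a synchronous subprocess.Popen, where wait(timeout=...) can replace the blind sleep. This variant is an assumption, not code from this commit:

import atexit
import subprocess

def register_graceful_kill(proc: subprocess.Popen) -> None:
    """Terminate at exit, escalating to SIGKILL only if the process lingers."""
    def _kill_proc():
        if proc.poll() is not None:
            return  # process already exited; nothing to clean up
        proc.terminate()  # polite shutdown request (SIGTERM)
        try:
            proc.wait(timeout=3)  # wait up to 3s for a clean exit
        except subprocess.TimeoutExpired:
            proc.kill()  # still running: force kill (SIGKILL)
    atexit.register(_kill_proc)
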
16 changes: 13 additions & 3 deletions pdelfin/s3_utils.py
@@ -13,6 +13,7 @@
from google.cloud import storage
from botocore.config import Config
from botocore.exceptions import NoCredentialsError
+from boto3.s3.transfer import TransferConfig
from typing import Optional
from urllib.parse import urlparse
import zstandard as zstd
@@ -209,7 +210,7 @@ def download_dir_from_gcs(gcs_path: str, local_dir: str):
def download_dir_from_s3(s3_path: str, local_dir: str):
"""Download model files from S3 to a local directory."""
boto3_config = Config(
-        max_pool_connections=50  # Adjust this number based on your requirements
+        max_pool_connections=500  # Adjust this number based on your requirements
)
s3_client = boto3.client('s3', config=boto3_config)
bucket, prefix = parse_s3_path(s3_path)
@@ -251,10 +252,18 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
# Configure the boto3 client for Weka
weka_endpoint = "https://weka-aus.beaker.org:9000"
boto3_config = Config(
-        max_pool_connections=50,  # Adjust this number based on your requirements
+        max_pool_connections=500,  # Adjust this number based on your requirements
signature_version='s3v4',
retries={'max_attempts': 10, 'mode': 'standard'}
)
+    # Configure transfer settings for multipart download
+    transfer_config = TransferConfig(
+        multipart_threshold=8 * 1024 * 1024,  # 8MB threshold for multipart downloads
+        multipart_chunksize=8 * 1024 * 1024,  # 8MB per part
+        max_concurrency=100,  # Number of threads for each file download
+        use_threads=True  # Enable threading
+    )
+
s3_client = boto3.client(
's3',
endpoint_url=weka_endpoint,
@@ -263,6 +272,7 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
config=boto3_config
)


bucket, prefix = parse_s3_path(weka_path)
paginator = s3_client.get_paginator("list_objects_v2")
try:
@@ -285,7 +295,7 @@ def download_dir_from_weka(weka_path: str, local_dir: str):
relative_path = os.path.relpath(key, prefix)
local_file_path = os.path.join(local_dir, relative_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-                futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
+                futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path, Config=transfer_config))

# Use tqdm to display progress
for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
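For reference, the TransferConfig added above is consumed by boto3's transfer manager inside download_file: any object at or above multipart_threshold is fetched as parallel ranged GETs of multipart_chunksize bytes, using up to max_concurrency threads per file. A minimal sketch of a single download with the same settings (the bucket, key, and local path are placeholders):

import boto3
from boto3.s3.transfer import TransferConfig

transfer_config = TransferConfig(
    multipart_threshold=8 * 1024 * 1024,  # split objects of 8MB or more into parts
    multipart_chunksize=8 * 1024 * 1024,  # fetch 8MB per ranged GET
    max_concurrency=100,                  # threads devoted to this one file
    use_threads=True,
)

s3 = boto3.client("s3")
# Placeholder bucket/key/path: parts are downloaded in parallel and
# reassembled into the local file by the transfer manager.
s3.download_file("example-bucket", "models/config.json", "/tmp/config.json",
                 Config=transfer_config)

Worth noting: each future submitted to the ThreadPoolExecutor can now spin up to 100 transfer threads of its own, which is presumably why max_pool_connections was raised from 50 to 500 in the same commit.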
