Commit 12a91ff (1 parent: faf8659)
Showing 1 changed file with 126 additions and 213 deletions.
@@ -1,226 +1,139 @@
import logging
import argparse
import subprocess
import signal
import sys
import os
import time
import tempfile
import redis
import redis.exceptions
import random
import boto3
import atexit
import os

from tqdm import tqdm
from urllib.parse import urlparse
import zstandard as zstd
from io import BytesIO, TextIOWrapper

from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, put_s3_bytes

from pdelfin.s3_utils import expand_s3_glob
# Basic logging setup for now
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global s3 client for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

LOCK_KEY = "queue_populating"
LOCK_TIMEOUT = 30  # seconds

def populate_queue_if_empty(queue, s3_glob_path, redis_client):
    """
    Check if the queue is empty. If it is, attempt to acquire a lock to populate it.
    Only one worker should populate the queue at a time.
    """
    if queue.llen("work_queue") == 0:
        # Attempt to acquire the lock
        lock_acquired = redis_client.set(LOCK_KEY, "locked", nx=True, ex=LOCK_TIMEOUT)
        if lock_acquired:
            print("Acquired lock to populate the queue.")
            try:
                paths = expand_s3_glob(pdf_s3, s3_glob_path)
                if not paths:
                    print("No paths found to populate the queue.")
                    return
                for path in paths:
                    queue.rpush("work_queue", path)
                print("Queue populated with initial work items.")
            except Exception as e:
                print(f"Error populating queue: {e}")
                # Optionally, handle retry logic or alerting here
            finally:
                # Release the lock
                redis_client.delete(LOCK_KEY)
                print("Released lock after populating the queue.")
        else:
            print("Another worker is populating the queue. Waiting for it to complete.")
            # Optionally, wait until the queue is populated
            wait_for_queue_population(queue)

def wait_for_queue_population(queue, wait_time=5, max_wait=60):
    """
    Wait until the queue is populated by another worker.
    """
    elapsed = 0
    while elapsed < max_wait:
        queue_length = queue.llen("work_queue")
        if queue_length > 0:
            print("Queue has been populated by another worker.")
            return
        print(f"Waiting for queue to be populated... ({elapsed + wait_time}/{max_wait} seconds)")
        time.sleep(wait_time)
        elapsed += wait_time
    print("Timeout waiting for queue to be populated.")
    sys.exit(1)

def process(item):
    # Simulate a short amount of processing time
    print(f"Processing item: {item}")
    time.sleep(0.5)
    print(f"Completed processing item: {item}")

def get_redis_client(sentinel, master_name, leader_ip, leader_port, max_wait=60):
    """
    Obtain a Redis client using Sentinel, with retry logic.
    """
    elapsed = 0
    wait_interval = 1  # seconds
    while elapsed < max_wait:
        try:
            r = sentinel.master_for(master_name, socket_timeout=0.1, decode_responses=True)
            r.ping()
            print(f"Connected to Redis master at {leader_ip}:{leader_port}")
            return r
        except redis.exceptions.ConnectionError as e:
            print(f"Attempt {elapsed + 1}: Unable to connect to Redis master at {leader_ip}:{leader_port}. Retrying in {wait_interval} second(s)...")
            time.sleep(wait_interval)
            elapsed += wait_interval
    print(f"Failed to connect to Redis master at {leader_ip}:{leader_port} after {max_wait} seconds. Exiting.")
    sys.exit(1)

def main():
    parser = argparse.ArgumentParser(description='Set up Redis Sentinel-based worker queue.')
    parser.add_argument('--leader-ip', help='IP address of the initial leader node')
    parser.add_argument('--leader-port', type=int, default=6379, help='Port of the initial leader node')
    parser.add_argument('--replica', type=int, required=True, help='Replica number (0 to N-1)')
    parser.add_argument('--add-pdfs', help='S3 glob path for work items')

def download_zstd_csv(s3_client, s3_path):
    """Download and decompress a .zstd CSV file from S3."""
    try:
        compressed_data = get_s3_bytes(s3_client, s3_path)
        dctx = zstd.ZstdDecompressor()
        decompressed = dctx.decompress(compressed_data)
        text_stream = TextIOWrapper(BytesIO(decompressed), encoding='utf-8')
        lines = text_stream.readlines()
        logger.info(f"Downloaded and decompressed {s3_path}")
        return lines
    except s3_client.exceptions.NoSuchKey:
        logger.info(f"No existing {s3_path} found in s3, starting fresh.")
        return []


def upload_zstd_csv(s3_client, s3_path, lines):
    """Compress and upload a list of lines as a .zstd CSV file to S3."""
    joined_text = "\n".join(lines)
    compressor = zstd.ZstdCompressor()
    compressed = compressor.compress(joined_text.encode('utf-8'))
    put_s3_bytes(s3_client, s3_path, compressed)
    logger.info(f"Uploaded compressed {s3_path}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')

    args = parser.parse_args()

    replica_number = args.replica

    base_redis_port = 6379
    base_sentinel_port = 26379

    redis_port = base_redis_port + replica_number
    sentinel_port = base_sentinel_port + replica_number

    if replica_number == 0:
        leader_ip = args.leader_ip if args.leader_ip else '127.0.0.1'
        leader_port = args.leader_port
    else:
        if not args.leader_ip:
            print('Error: --leader-ip is required for replica nodes (replica_number >= 1)')
            sys.exit(1)
        leader_ip = args.leader_ip
        leader_port = args.leader_port

    temp_dir = tempfile.mkdtemp()
    redis_conf_path = os.path.join(temp_dir, 'redis.conf')
    sentinel_conf_path = os.path.join(temp_dir, 'sentinel.conf')

    print("Redis config path:", redis_conf_path)

    with open(redis_conf_path, 'w') as f:
        f.write(f'port {redis_port}\n')
        f.write(f'dbfilename dump-{replica_number}.rdb\n')
        f.write(f'appendfilename "appendonly-{replica_number}.aof"\n')
        f.write(f'logfile "redis-{replica_number}.log"\n')
        f.write(f'dir {temp_dir}\n')
        if replica_number == 0:
            f.write('bind 0.0.0.0\n')
    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    # Check list of pdfs and that it matches what's in the workspace
    if args.pdfs:
        if args.pdfs.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {args.pdfs}")
            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
        elif os.path.exists(args.pdfs):
            logger.info(f"Loading file at {args.pdfs}")
            with open(args.pdfs, "r") as f:
                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
        else:
            f.write(f'replicaof {leader_ip} {leader_port}\n')

    master_name = 'mymaster'
    quorum = 1

    with open(sentinel_conf_path, 'w') as f:
        f.write(f'port {sentinel_port}\n')
        f.write(f'dir {temp_dir}\n')
        f.write(f'sentinel monitor {master_name} {leader_ip} {leader_port} {quorum}\n')
        f.write(f'sentinel down-after-milliseconds {master_name} 5000\n')
        f.write(f'sentinel failover-timeout {master_name} 10000\n')
        f.write(f'sentinel parallel-syncs {master_name} 1\n')
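
    # Illustrative example of the sentinel.conf this block writes for --replica 0
    # with the defaults above (leader 127.0.0.1:6379, sentinel port 26379); the
    # dir line points at the per-run temp dir, shown here as a placeholder path:
    #
    #   port 26379
    #   dir /tmp/tmpXXXXXXXX
    #   sentinel monitor mymaster 127.0.0.1 6379 1
    #   sentinel down-after-milliseconds mymaster 5000
    #   sentinel failover-timeout mymaster 10000
    #   sentinel parallel-syncs mymaster 1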

    redis_process = subprocess.Popen(['redis-server', redis_conf_path])
    sentinel_process = subprocess.Popen(['redis-sentinel', sentinel_conf_path])

    # Register atexit function to guarantee process termination
    def terminate_processes():
        print("Terminating child processes...")
        redis_process.terminate()
        sentinel_process.terminate()
        try:
            redis_process.wait(timeout=5)
            sentinel_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            print("Forcing termination of child processes.")
            redis_process.kill()
            sentinel_process.kill()
        print("Child processes terminated.")

    atexit.register(terminate_processes)

    # Also handle signal-based termination
    def handle_signal(signum, frame):
        print(f"Received signal {signum}. Terminating processes...")
        terminate_processes()
        sys.exit(0)

    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    time.sleep(2)

    # Use Sentinel to connect to the master
    from redis.sentinel import Sentinel
    sentinel = Sentinel([('127.0.0.1', sentinel_port)], socket_timeout=0.1)

    # Initial connection to Redis master
    redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)

    # Populate the work queue if it's empty, using a distributed lock
    populate_queue_if_empty(redis_client, args.add_pdfs, redis_client)
            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a path to a local file containing pdf paths (one per line)")

    try:
        while True:
            try:
                # Try to get an item from the queue with a 1-minute timeout for processing
                work_item = redis_client.brpoplpush("work_queue", "processing_queue", 60)
                if work_item:
                    try:
                        process(work_item)
                        # Remove from the processing queue if processed successfully
                        redis_client.lrem("processing_queue", 1, work_item)
                    except Exception as e:
                        print(f"Error processing {work_item}: {e}")
                        # If an error occurs, let it be requeued after timeout

                queue_length = redis_client.llen("work_queue")
                print(f"Total work items in queue: {queue_length}")

                time.sleep(0.1)

            except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
                print("Lost connection to Redis. Attempting to reconnect using Sentinel...")
                # Attempt to reconnect using Sentinel
                while True:
                    try:
                        redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
                        print("Reconnected to Redis master.")
                        break  # Exit the reconnection loop and resume work
                    except redis.exceptions.ConnectionError:
                        print("Reconnection failed. Retrying in 5 seconds...")
                        time.sleep(5)
            except Exception as e:
                print(f"Unexpected error: {e}")
                handle_signal(None, None)

    except KeyboardInterrupt:
        handle_signal(None, None)
    all_pdfs = set(all_pdfs)
    logger.info(f"Found {len(all_pdfs):,} total pdf paths")

    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
    existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

if __name__ == '__main__':
    main()
    # Parse existing work items into groups
    existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
    existing_pdf_set = set(pdf for group in existing_groups for pdf in group)

    logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

    # Remove existing PDFs from all_pdfs
    new_pdfs = all_pdfs - existing_pdf_set
    logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")

    # Group the new PDFs into chunks of group_size
    new_groups = []
    current_group = []
    for pdf in sorted(new_pdfs):  # Sort for consistency
        current_group.append(pdf)
        if len(current_group) == args.group_size:
            new_groups.append(current_group)
            current_group = []
    if current_group:
        new_groups.append(current_group)

    logger.info(f"Created {len(new_groups):,} new work groups")

    # Combine existing groups with new groups
    combined_groups = existing_groups + new_groups

    # Prepare lines to write back
    combined_lines = [",".join(group) for group in combined_groups]

    # Upload the combined work items back to S3
    upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)

    logger.info("Completed adding new PDFs.")


    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
    # If not, then your job is to do the actual work

    # Start up the sglang server

    # Read in the work queue from s3
    # Read in the done items from the s3 workspace

    # Spawn up to N workers to do:
    # In a loop, take a random work item, read in the pdfs, queue in their requests
    # Get results back, retry any failed pages
    # Check periodically if that work is done in s3, if so, then abandon this work
    # Save results back to s3 workspace output folder

    # Possible future addon, in beaker, discover other nodes on this same job
    # Send them a message when you take a work item off the queue
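
A minimal sketch of the worker loop these trailing notes describe, assuming hypothetical helpers claim_work_item, process_pdf_group, and work_is_done_in_s3 (none of these are provided by pdelfin); the real pipeline would pull work groups from the s3 workspace index and write results to the workspace output folder.

import random
import time


def claim_work_item(pending, done):
    """Hypothetical: pick a random work group that is not already finished."""
    candidates = [item for item in pending if item not in done]
    return random.choice(candidates) if candidates else None


def process_pdf_group(item):
    """Hypothetical: render the group's pdfs, queue inference requests, retry failed pages."""
    time.sleep(0.1)  # stand-in for the real batch inference work
    return f"results-for-{item}"


def work_is_done_in_s3(item, done):
    """Hypothetical: re-check the s3 workspace to see if another node already finished this item."""
    return item in done


def worker_loop(pending, done):
    while True:
        item = claim_work_item(pending, done)
        if item is None:
            break
        results = process_pdf_group(item)
        if work_is_done_in_s3(item, done):
            continue  # another worker finished it first; abandon this work item
        # The real pipeline would save results to the s3 workspace output folder here.
        done.add(item)
        print(f"Finished {item}: {results}")


# Example: worker_loop(pending={"group-0", "group-1"}, done=set())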