Starting on a new approach
jakep-allenai committed Nov 7, 2024
1 parent faf8659 commit 12a91ff
Showing 1 changed file with 126 additions and 213 deletions.
339 changes: 126 additions & 213 deletions pdelfin/beakerpipeline.py
@@ -1,226 +1,139 @@
import logging
import argparse
import subprocess
import signal
import sys
import os
import time
import tempfile
import redis
import redis.exceptions
import random
import boto3
import atexit

from tqdm import tqdm
from urllib.parse import urlparse
import zstandard as zstd
from io import BytesIO, TextIOWrapper

from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, put_s3_bytes

# Basic logging setup for now
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global S3 clients for the whole script; adjust client params here if you need to
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

LOCK_KEY = "queue_populating"
LOCK_TIMEOUT = 30 # seconds

def populate_queue_if_empty(queue, s3_glob_path, redis_client):
"""
Check if the queue is empty. If it is, attempt to acquire a lock to populate it.
Only one worker should populate the queue at a time.
"""
if queue.llen("work_queue") == 0:
# Attempt to acquire the lock
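# SET with nx=True, ex=LOCK_TIMEOUT is an atomic "set only if the key does not
# already exist" with an expiry, so at most one worker gets the lock and a
# crashed worker cannot hold it forever.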
lock_acquired = redis_client.set(LOCK_KEY, "locked", nx=True, ex=LOCK_TIMEOUT)
if lock_acquired:
print("Acquired lock to populate the queue.")
try:
paths = expand_s3_glob(pdf_s3, s3_glob_path)
if not paths:
print("No paths found to populate the queue.")
return
for path in paths:
queue.rpush("work_queue", path)
print("Queue populated with initial work items.")
except Exception as e:
print(f"Error populating queue: {e}")
# Optionally, handle retry logic or alerting here
finally:
# Release the lock
redis_client.delete(LOCK_KEY)
print("Released lock after populating the queue.")
else:
print("Another worker is populating the queue. Waiting for it to complete.")
# Optionally, wait until the queue is populated
wait_for_queue_population(queue)

def wait_for_queue_population(queue, wait_time=5, max_wait=60):
"""
Wait until the queue is populated by another worker.
"""
elapsed = 0
while elapsed < max_wait:
queue_length = queue.llen("work_queue")
if queue_length > 0:
print("Queue has been populated by another worker.")
return
print(f"Waiting for queue to be populated... ({elapsed + wait_time}/{max_wait} seconds)")
time.sleep(wait_time)
elapsed += wait_time
print("Timeout waiting for queue to be populated.")
sys.exit(1)

def process(item):
# Simulate a short processing delay
print(f"Processing item: {item}")
time.sleep(0.5)
print(f"Completed processing item: {item}")

def get_redis_client(sentinel, master_name, leader_ip, leader_port, max_wait=60):
"""
Obtain a Redis client using Sentinel, with retry logic.
"""
elapsed = 0
wait_interval = 1 # seconds
while elapsed < max_wait:
try:
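# master_for returns a client backed by a Sentinel-aware connection pool, so it
# re-discovers the current master if a failover happens between attempts.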
r = sentinel.master_for(master_name, socket_timeout=0.1, decode_responses=True)
r.ping()
print(f"Connected to Redis master at {leader_ip}:{leader_port}")
return r
except redis.exceptions.ConnectionError as e:
print(f"Attempt {elapsed + 1}: Unable to connect to Redis master at {leader_ip}:{leader_port}. Retrying in {wait_interval} second(s)...")
time.sleep(wait_interval)
elapsed += wait_interval
print(f"Failed to connect to Redis master at {leader_ip}:{leader_port} after {max_wait} seconds. Exiting.")
sys.exit(1)

def main():
parser = argparse.ArgumentParser(description='Set up Redis Sentinel-based worker queue.')
parser.add_argument('--leader-ip', help='IP address of the initial leader node')
parser.add_argument('--leader-port', type=int, default=6379, help='Port of the initial leader node')
parser.add_argument('--replica', type=int, required=True, help='Replica number (0 to N-1)')
parser.add_argument('--add-pdfs', help='S3 glob path for work items')

def download_zstd_csv(s3_client, s3_path):
"""Download and decompress a .zstd CSV file from S3."""
try:
compressed_data = get_s3_bytes(s3_client, s3_path)
dctx = zstd.ZstdDecompressor()
decompressed = dctx.decompress(compressed_data)
text_stream = TextIOWrapper(BytesIO(decompressed), encoding='utf-8')
lines = text_stream.readlines()
logger.info(f"Downloaded and decompressed {s3_path}")
return lines
except s3_client.exceptions.NoSuchKey:
logger.info(f"No existing {s3_path} found in s3, starting fresh.")
return []


def upload_zstd_csv(s3_client, s3_path, lines):
"""Compress and upload a list of lines as a .zstd CSV file to S3."""
joined_text = "\n".join(lines)
compressor = zstd.ZstdCompressor()
compressed = compressor.compress(joined_text.encode('utf-8'))
put_s3_bytes(s3_client, s3_path, compressed)
logger.info(f"Uploaded compressed {s3_path}")


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')

args = parser.parse_args()

replica_number = args.replica

base_redis_port = 6379
base_sentinel_port = 26379

redis_port = base_redis_port + replica_number
sentinel_port = base_sentinel_port + replica_number

if replica_number == 0:
leader_ip = args.leader_ip if args.leader_ip else '127.0.0.1'
leader_port = args.leader_port
else:
if not args.leader_ip:
print('Error: --leader-ip is required for replica nodes (replica_number >= 1)')
sys.exit(1)
leader_ip = args.leader_ip
leader_port = args.leader_port

temp_dir = tempfile.mkdtemp()
redis_conf_path = os.path.join(temp_dir, 'redis.conf')
sentinel_conf_path = os.path.join(temp_dir, 'sentinel.conf')

print("Redis config path:", redis_conf_path)

with open(redis_conf_path, 'w') as f:
f.write(f'port {redis_port}\n')
f.write(f'dbfilename dump-{replica_number}.rdb\n')
f.write(f'appendfilename "appendonly-{replica_number}.aof"\n')
f.write(f'logfile "redis-{replica_number}.log"\n')
f.write(f'dir {temp_dir}\n')
if replica_number == 0:
f.write('bind 0.0.0.0\n')
else:
f.write(f'replicaof {leader_ip} {leader_port}\n')

if args.workspace_profile:
workspace_session = boto3.Session(profile_name=args.workspace_profile)
workspace_s3 = workspace_session.client("s3")

if args.pdf_profile:
pdf_session = boto3.Session(profile_name=args.pdf_profile)
pdf_s3 = pdf_session.client("s3")

# Check the list of pdfs and reconcile it with what's already in the workspace
if args.pdfs:
if args.pdfs.startswith("s3://"):
logger.info(f"Expanding s3 glob at {args.pdfs}")
all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
elif os.path.exists(args.pdfs):
logger.info(f"Loading file at {args.pdfs}")
with open(args.pdfs, "r") as f:
all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
else:
raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

master_name = 'mymaster'
quorum = 1

with open(sentinel_conf_path, 'w') as f:
f.write(f'port {sentinel_port}\n')
f.write(f'dir {temp_dir}\n')
f.write(f'sentinel monitor {master_name} {leader_ip} {leader_port} {quorum}\n')
f.write(f'sentinel down-after-milliseconds {master_name} 5000\n')
f.write(f'sentinel failover-timeout {master_name} 10000\n')
f.write(f'sentinel parallel-syncs {master_name} 1\n')

redis_process = subprocess.Popen(['redis-server', redis_conf_path])
sentinel_process = subprocess.Popen(['redis-sentinel', sentinel_conf_path])

# Register atexit function to guarantee process termination
def terminate_processes():
print("Terminating child processes...")
redis_process.terminate()
sentinel_process.terminate()
try:
redis_process.wait(timeout=5)
sentinel_process.wait(timeout=5)
except subprocess.TimeoutExpired:
print("Forcing termination of child processes.")
redis_process.kill()
sentinel_process.kill()
print("Child processes terminated.")

atexit.register(terminate_processes)

# Also handle signal-based termination
def handle_signal(signum, frame):
print(f"Received signal {signum}. Terminating processes...")
terminate_processes()
sys.exit(0)

signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)

time.sleep(2)

# Use Sentinel to connect to the master
from redis.sentinel import Sentinel
sentinel = Sentinel([('127.0.0.1', sentinel_port)], socket_timeout=0.1)

# Initial connection to Redis master
redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)

# Populate the work queue if it's empty, using a distributed lock
populate_queue_if_empty(redis_client, args.add_pdfs, redis_client)

try:
while True:
try:
# Block for up to 60 seconds waiting for the next work item
work_item = redis_client.brpoplpush("work_queue", "processing_queue", 60)
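# BRPOPLPUSH atomically moves the item onto processing_queue while it is being
# worked on, so it is not lost if this worker dies before finishing it.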
if work_item:
try:
process(work_item)
# Remove from the processing queue if processed successfully
redis_client.lrem("processing_queue", 1, work_item)
except Exception as e:
print(f"Error processing {work_item}: {e}")
# On error, leave the item on the processing queue so it can be inspected or requeued later

queue_length = redis_client.llen("work_queue")
print(f"Total work items in queue: {queue_length}")

time.sleep(0.1)

except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
print("Lost connection to Redis. Attempting to reconnect using Sentinel...")
# Attempt to reconnect using Sentinel
while True:
try:
redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
print("Reconnected to Redis master.")
break # Exit the reconnection loop and resume work
except redis.exceptions.ConnectionError:
print("Reconnection failed. Retrying in 5 seconds...")
time.sleep(5)
except Exception as e:
print(f"Unexpected error: {e}")
handle_signal(None, None)

except KeyboardInterrupt:
handle_signal(None, None)

if __name__ == '__main__':
main()

all_pdfs = set(all_pdfs)
logger.info(f"Found {len(all_pdfs):,} total pdf paths")

index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

# Parse existing work items into groups
existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
existing_pdf_set = set(pdf for group in existing_groups for pdf in group)

logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

# Remove existing PDFs from all_pdfs
new_pdfs = all_pdfs - existing_pdf_set
logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")

# Group the new PDFs into chunks of group_size
new_groups = []
current_group = []
for pdf in sorted(new_pdfs): # Sort for consistency
current_group.append(pdf)
if len(current_group) == args.group_size:
new_groups.append(current_group)
current_group = []
if current_group:
new_groups.append(current_group)
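# Note: sorting before grouping keeps the grouping deterministic, so reruns over
# the same new pdfs produce identical work groups.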

logger.info(f"Created {len(new_groups):,} new work groups")

# Combine existing groups with new groups
combined_groups = existing_groups + new_groups

# Prepare lines to write back
combined_lines = [",".join(group) for group in combined_groups]

# Upload the combined work items back to S3
upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)

logger.info("Completed adding new PDFs.")


# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work

# Start up the sglang server

# Read in the work queue from s3
# Read in the done items from the s3 workspace

# Spawn up to N workers to do the following (a rough sketch follows below):
# In a loop, take a random work item, read in the pdfs, and queue up their requests
# Get results back, retry any failed pages
# Check periodically if that work item is already done in s3, and if so, abandon it
# Save results back to the s3 workspace output folder

# Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue
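
# A rough, non-authoritative sketch of the worker loop described above. It assumes
# the helpers defined earlier in this file (download_zstd_csv, workspace_s3) and
# invents an output layout (one object per work group under <workspace>/output/);
# the function names and layout here are hypothetical, not part of this commit.

def output_exists_for_group(s3_client, workspace, group_index):
    """Hypothetical check: has another worker already finished this group?"""
    parsed = urlparse(workspace)
    bucket = parsed.netloc
    prefix = parsed.path.strip("/")
    key = f"{prefix}/output/group_{group_index:08d}.jsonl"
    resp = s3_client.list_objects_v2(Bucket=bucket, Prefix=key)
    return resp.get("KeyCount", 0) > 0


def worker_loop(workspace):
    """Hypothetical worker: pull random groups from the index until all are done."""
    index_file_s3_path = os.path.join(workspace, "pdf_index_list.csv.zstd")
    lines = download_zstd_csv(workspace_s3, index_file_s3_path)
    groups = [line.strip().split(",") for line in lines if line.strip()]
    pending = list(range(len(groups)))
    while pending:
        # Pick a random group so concurrent workers mostly avoid colliding
        group_index = random.choice(pending)
        if output_exists_for_group(workspace_s3, workspace, group_index):
            pending.remove(group_index)
            continue
        for pdf_path in groups[group_index]:
            # TODO: render pages, queue sglang requests, retry failed pages
            pass
        # TODO: save results to the s3 workspace output folder
        pending.remove(group_index)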
