Downloads from s3 based on hash
jakep-allenai committed Nov 12, 2024
1 parent 6598e2d commit 910c2eb
Showing 2 changed files with 152 additions and 121 deletions.
6 changes: 3 additions & 3 deletions pdelfin/beakerpipeline.py
@@ -398,8 +398,7 @@ async def worker(args, queue, semaphore, worker_id):

 async def sglang_server_task(args, semaphore):
     model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    # TODO cache locally
-    #download_directory(args.model, model_cache_dir)
+    download_directory(args.model, model_cache_dir)
 
     # Check the rope config and make sure it's got the proper key
     with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
@@ -484,6 +483,7 @@ async def read_stream(stream):
 async def sglang_server_host(args, semaphore):
     while True:
         await sglang_server_task(args, semaphore)
+        logger.warning("SGLang server task ended")
 
 
 async def sglang_server_ready():
@@ -525,7 +525,7 @@ async def main():
     parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
     parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
-    parser.add_argument('--workers', type=int, default=3, help='Number of workers to run at a time')
+    parser.add_argument('--workers', type=int, default=5, help='Number of workers to run at a time')
 
     parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                         default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
267 changes: 149 additions & 118 deletions pdelfin/s3_utils.py
@@ -6,6 +6,7 @@
 import boto3
 import requests
 import concurrent.futures
+import hashlib  # Added for MD5 hash computation
 
 from urllib.parse import urlparse
 from pathlib import Path
@@ -14,7 +15,7 @@
 from botocore.config import Config
 from botocore.exceptions import NoCredentialsError
 from boto3.s3.transfer import TransferConfig
-from typing import Optional
+from typing import Optional, List
 from urllib.parse import urlparse
 import zstandard as zstd
 from io import BytesIO, TextIOWrapper
@@ -133,21 +134,19 @@ def is_running_on_gcp():
     except requests.RequestException:
         return False
 
-
-def download_directory(model_choices: list[str], local_dir: str):
+def download_directory(model_choices: List[str], local_dir: str):
     """
     Download the model to a specified local directory.
     The function will attempt to download from the first available source in the provided list.
     Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.
     Args:
-        model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
+        model_choices (List[str]): List of model paths (weka://, gs://, or s3://).
         local_dir (str): Local directory path where the model will be downloaded.
     Raises:
         ValueError: If no valid model path is found in the provided choices.
     """
-    # Ensure the local directory exists
     local_path = Path(os.path.expanduser(local_dir))
     local_path.mkdir(parents=True, exist_ok=True)
     logger.info(f"Local directory set to: {local_path}")
@@ -157,148 +156,180 @@ def download_directory(model_choices: list[str], local_dir: str):
     other_choices = [path for path in model_choices if not path.startswith("weka://")]
     prioritized_choices = weka_choices + other_choices
 
-    # Iterate through the provided choices and attempt to download from the first available source
     for model_path in prioritized_choices:
         logger.info(f"Attempting to download from: {model_path}")
         try:
             if model_path.startswith("weka://"):
-                download_dir_from_weka(model_path, str(local_path))
+                download_dir_from_storage(
+                    model_path, str(local_path), storage_type='weka')
                 logger.info(f"Successfully downloaded model from Weka: {model_path}")
                 return
             elif model_path.startswith("gs://"):
-                download_dir_from_gcs(model_path, str(local_path))
+                download_dir_from_storage(
+                    model_path, str(local_path), storage_type='gcs')
                 logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                 return
             elif model_path.startswith("s3://"):
-                download_dir_from_s3(model_path, str(local_path))
+                download_dir_from_storage(
+                    model_path, str(local_path), storage_type='s3')
                 logger.info(f"Successfully downloaded model from S3: {model_path}")
                 return
             else:
                 logger.warning(f"Unsupported model path scheme: {model_path}")
         except Exception as e:
             logger.error(f"Failed to download from {model_path}: {e}")
-            continue  # Try the next available source
+            continue
 
     raise ValueError("Failed to download the model from all provided sources.")


-def download_dir_from_gcs(gcs_path: str, local_dir: str):
-    """Download model files from Google Cloud Storage to a local directory."""
-    client = storage.Client()
-    bucket_name, prefix = parse_s3_path(gcs_path.replace("gs://", "s3://"))
-    bucket = client.bucket(bucket_name)
-
-    blobs = list(bucket.list_blobs(prefix=prefix))
-    total_files = len(blobs)
-    logger.info(f"Found {total_files} files in GCS bucket '{bucket_name}' with prefix '{prefix}'.")
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for blob in blobs:
-            relative_path = os.path.relpath(blob.name, prefix)
-            local_file_path = os.path.join(local_dir, relative_path)
-            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-            futures.append(executor.submit(blob.download_to_filename, local_file_path))
-
-        # Use tqdm to display progress
-        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from GCS"):
-            pass
-
-    logger.info(f"Downloaded model from Google Cloud Storage to {local_dir}")

-def download_dir_from_s3(s3_path: str, local_dir: str):
-    """Download model files from S3 to a local directory."""
-    boto3_config = Config(
-        max_pool_connections=500  # Adjust this number based on your requirements
-    )
-    s3_client = boto3.client('s3', config=boto3_config)
-    bucket, prefix = parse_s3_path(s3_path)
-    paginator = s3_client.get_paginator("list_objects_v2")
-    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
-
-    objects = []
-    for page in pages:
-        if 'Contents' in page:
-            objects.extend(page['Contents'])
-
-    total_files = len(objects)
-    logger.info(f"Found {total_files} files in S3 bucket '{bucket}' with prefix '{prefix}'.")
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for obj in objects:
-            key = obj["Key"]
-            relative_path = os.path.relpath(key, prefix)
-            local_file_path = os.path.join(local_dir, relative_path)
-            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
-
-        # Use tqdm to display progress
-        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from S3"):
-            pass
-
-    logger.info(f"Downloaded model from S3 to {local_dir}")
-
-
-def download_dir_from_weka(weka_path: str, local_dir: str):
-    """Download model files from Weka to a local directory."""
-    # Retrieve Weka credentials from environment variables
-    weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
-    weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
-    if not weka_access_key or not weka_secret_key:
-        raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
-
-    # Configure the boto3 client for Weka
-    weka_endpoint = "https://weka-aus.beaker.org:9000"
-    boto3_config = Config(
-        max_pool_connections=500,  # Adjust this number based on your requirements
-        signature_version='s3v4',
-        retries={'max_attempts': 10, 'mode': 'standard'}
-    )
-    # Configure transfer settings for multipart download
-    transfer_config = TransferConfig(
-        multipart_threshold=8 * 1024 * 1024,  # 8MB threshold for multipart downloads
-        multipart_chunksize=8 * 1024 * 1024,  # 8MB per part
-        max_concurrency=100,  # Number of threads for each file download
-        use_threads=True  # Enable threading
-    )
-
-    s3_client = boto3.client(
-        's3',
-        endpoint_url=weka_endpoint,
-        aws_access_key_id=weka_access_key,
-        aws_secret_access_key=weka_secret_key,
-        config=boto3_config
-    )
-
-    bucket, prefix = parse_s3_path(weka_path)
-    paginator = s3_client.get_paginator("list_objects_v2")
-    try:
-        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
-    except s3_client.exceptions.NoSuchBucket:
-        raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
-
-    objects = []
-    for page in pages:
-        if 'Contents' in page:
-            objects.extend(page['Contents'])
-
-    total_files = len(objects)
-    logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
-        for obj in objects:
-            key = obj["Key"]
-            relative_path = os.path.relpath(key, prefix)
-            local_file_path = os.path.join(local_dir, relative_path)
-            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
-            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path, Config=transfer_config))
-
-        # Use tqdm to display progress
-        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
-            pass
-
-    logger.info(f"Downloaded model from Weka to {local_dir}")
+def download_dir_from_storage(storage_path: str, local_dir: str, storage_type: str):
+    """
+    Generalized function to download model files from different storage services
+    to a local directory, syncing using MD5 hashes where possible.
+    Args:
+        storage_path (str): The path to the storage location (weka://, gs://, or s3://).
+        local_dir (str): The local directory where files will be downloaded.
+        storage_type (str): Type of storage ('weka', 'gcs', or 's3').
+    Raises:
+        ValueError: If the storage type is unsupported or credentials are missing.
+    """
+    bucket_name, prefix = parse_s3_path(storage_path)
+    total_files = 0
+    objects = []
+
+    if storage_type == 'gcs':
+        client = storage.Client()
+        bucket = client.bucket(bucket_name)
+        blobs = list(bucket.list_blobs(prefix=prefix))
+        total_files = len(blobs)
+        logger.info(f"Found {total_files} files in GCS bucket '{bucket_name}' with prefix '{prefix}'.")
+
+        def should_download(blob, local_file_path):
+            return compare_hashes_gcs(blob, local_file_path)
+
+        def download_blob(blob, local_file_path):
+            blob.download_to_filename(local_file_path)
+
+        items = blobs
+    elif storage_type in ('s3', 'weka'):
+        if storage_type == 'weka':
+            weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
+            weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
+            if not weka_access_key or not weka_secret_key:
+                raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY must be set for Weka access.")
+            endpoint_url = "https://weka-aus.beaker.org:9000"
+            boto3_config = Config(
+                max_pool_connections=500,
+                signature_version='s3v4',
+                retries={'max_attempts': 10, 'mode': 'standard'}
+            )
+            s3_client = boto3.client(
+                's3',
+                endpoint_url=endpoint_url,
+                aws_access_key_id=weka_access_key,
+                aws_secret_access_key=weka_secret_key,
+                config=boto3_config
+            )
+        else:
+            s3_client = boto3.client('s3', config=Config(max_pool_connections=500))
+
+        paginator = s3_client.get_paginator("list_objects_v2")
+        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
+        for page in pages:
+            if 'Contents' in page:
+                objects.extend(page['Contents'])
+        total_files = len(objects)
+        logger.info(f"Found {total_files} files in {'Weka' if storage_type == 'weka' else 'S3'} bucket '{bucket_name}' with prefix '{prefix}'.")
+
+        transfer_config = TransferConfig(
+            multipart_threshold=8 * 1024 * 1024,
+            multipart_chunksize=8 * 1024 * 1024,
+            max_concurrency=100,
+            use_threads=True
+        )
+
+        def should_download(obj, local_file_path):
+            return compare_hashes_s3(obj, local_file_path)
+
+        def download_blob(obj, local_file_path):
+            s3_client.download_file(bucket_name, obj['Key'], local_file_path, Config=transfer_config)
+
+        items = objects
+    else:
+        raise ValueError(f"Unsupported storage type: {storage_type}")
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = []
+        for item in items:
+            if storage_type == 'gcs':
+                relative_path = os.path.relpath(item.name, prefix)
+            else:
+                relative_path = os.path.relpath(item['Key'], prefix)
+            local_file_path = os.path.join(local_dir, relative_path)
+            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+            if should_download(item, local_file_path):
+                futures.append(executor.submit(download_blob, item, local_file_path))
+            else:
+                total_files -= 1  # Decrement total_files as we're skipping this file
+
+        if total_files > 0:
+            for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc=f"Downloading from {storage_type.upper()}"):
+                pass
+        else:
+            logger.info("All files are up-to-date. No downloads needed.")
+
+    logger.info(f"Downloaded model from {storage_type.upper()} to {local_dir}")


+def compare_hashes_gcs(blob, local_file_path: str) -> bool:
+    """Compare MD5 hashes for GCS blobs."""
+    if os.path.exists(local_file_path):
+        remote_md5_base64 = blob.md5_hash
+        hash_md5 = hashlib.md5()
+        with open(local_file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                hash_md5.update(chunk)
+        local_md5 = hash_md5.digest()
+        remote_md5 = base64.b64decode(remote_md5_base64)
+        if remote_md5 == local_md5:
+            logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
+            return False
+        else:
+            logger.info(f"File '{local_file_path}' differs from GCS. Downloading.")
+            return True
+    else:
+        logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
+        return True
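A small self-contained sketch of why the decode step is needed, using only the standard library: GCS reports blob.md5_hash as a base64-encoded MD5 digest, so it must be decoded before comparing with the raw digest computed locally.

# Standalone illustration, not part of this module.
import base64
import hashlib

data = b"example file contents"
local_digest = hashlib.md5(data).digest()             # raw 16-byte digest, as computed above
remote_md5_base64 = base64.b64encode(local_digest)    # the base64 form GCS exposes in blob.md5_hash
assert base64.b64decode(remote_md5_base64) == local_digest  # match -> download is skipped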


+def compare_hashes_s3(obj, local_file_path: str) -> bool:
+    """Compare MD5 hashes or sizes for S3 objects (including Weka)."""
+    if os.path.exists(local_file_path):
+        etag = obj['ETag'].strip('"')
+        if '-' in etag:
+            remote_size = obj['Size']
+            local_size = os.path.getsize(local_file_path)
+            if remote_size == local_size:
+                logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
+                return False
+            else:
+                logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
+                return True
+        else:
+            hash_md5 = hashlib.md5()
+            with open(local_file_path, "rb") as f:
+                for chunk in iter(lambda: f.read(8192), b""):
+                    hash_md5.update(chunk)
+            local_md5 = hash_md5.hexdigest()
+            if etag == local_md5:
+                logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
+                return False
+            else:
+                logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
+                return True
+    else:
+        logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
+        return True
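The size-only fallback exists because a multipart ETag is not a plain MD5 of the whole file. As a hedged sketch (not part of this module), the multipart ETag can be reconstructed by hashing the concatenated per-part digests, but only if the part size used at upload time is known; 8 MB is assumed here to match the TransferConfig above, so the cheaper size check is the more robust default.

# Standalone illustration of how S3-style multipart ETags are formed.
import hashlib

def multipart_etag(file_path: str, part_size: int = 8 * 1024 * 1024) -> str:
    # MD5 each part, then MD5 the concatenated digests and append '-<part count>'.
    part_digests = []
    with open(file_path, "rb") as f:
        while True:
            chunk = f.read(part_size)
            if not chunk:
                break
            part_digests.append(hashlib.md5(chunk).digest())
    combined = hashlib.md5(b"".join(part_digests)).hexdigest()
    return f"{combined}-{len(part_digests)}"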
