More reliable weka

jakep-allenai committed Nov 27, 2024
1 parent 6872105 commit d4f3cff
Showing 2 changed files with 54 additions and 30 deletions.
82 changes: 53 additions & 29 deletions pdelfin/s3_utils.py
@@ -8,7 +8,7 @@
 import time
 import requests
 import concurrent.futures
-import hashlib # Added for MD5 hash computation
+import hashlib
 
 from urllib.parse import urlparse
 from pathlib import Path
@@ -81,6 +81,7 @@ def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end
 
     return obj['Body'].read()
 
+
 def get_s3_bytes_with_backoff(s3_client, pdf_s3_path, max_retries: int = 8, backoff_factor: int = 2):
     attempt = 0
 
@@ -106,6 +107,7 @@ def get_s3_bytes_with_backoff(s3_client, pdf_s3_path, max_retries: int = 8, backoff_factor: int = 2):
     logger.error(f"Failed to get_s3_bytes for {pdf_s3_path} after {max_retries} retries.")
     raise Exception("Failed to get_s3_bytes after retries")
 
+
 def put_s3_bytes(s3_client, s3_path: str, data: bytes):
     bucket, key = parse_s3_path(s3_path)
 
@@ -160,6 +162,7 @@ def is_running_on_gcp():
     except requests.RequestException:
         return False
 
+
 def download_directory(model_choices: List[str], local_dir: str):
     """
     Download the model to a specified local directory.
@@ -242,7 +245,12 @@ def should_download(blob, local_file_path):
             return compare_hashes_gcs(blob, local_file_path)
 
         def download_blob(blob, local_file_path):
-            blob.download_to_filename(local_file_path)
+            try:
+                blob.download_to_filename(local_file_path)
+                logger.info(f"Successfully downloaded {blob.name} to {local_file_path}")
+            except Exception as e:
+                logger.error(f"Failed to download {blob.name} to {local_file_path}: {e}")
+                raise
 
         items = blobs
     elif storage_type in ('s3', 'weka'):
@@ -272,21 +280,30 @@ def download_blob(blob, local_file_path):
         for page in pages:
             if 'Contents' in page:
                 objects.extend(page['Contents'])
+            else:
+                logger.warning(f"No contents found in page: {page}")
         total_files = len(objects)
         logger.info(f"Found {total_files} files in {'Weka' if storage_type == 'weka' else 'S3'} bucket '{bucket_name}' with prefix '{prefix}'.")
 
         transfer_config = TransferConfig(
             multipart_threshold=8 * 1024 * 1024,
             multipart_chunksize=8 * 1024 * 1024,
-            max_concurrency=100,
+            max_concurrency=10, # Reduced for WekaFS compatibility
             use_threads=True
         )
 
         def should_download(obj, local_file_path):
-            return compare_hashes_s3(obj, local_file_path)
+            return compare_hashes_s3(obj, local_file_path, storage_type)
 
         def download_blob(obj, local_file_path):
-            s3_client.download_file(bucket_name, obj['Key'], local_file_path, Config=transfer_config)
+            logger.info(f"Starting download of {obj['Key']} to {local_file_path}")
+            try:
+                with open(local_file_path, 'wb') as f:
+                    s3_client.download_fileobj(bucket_name, obj['Key'], f, Config=transfer_config)
+                logger.info(f"Successfully downloaded {obj['Key']} to {local_file_path}")
+            except Exception as e:
+                logger.error(f"Failed to download {obj['Key']} to {local_file_path}: {e}")
+                raise
 
         items = objects
     else:
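For context, the hunk above swaps download_file for download_fileobj with an explicitly opened handle and cuts max_concurrency from 100 to 10. A minimal standalone sketch of the same boto3 transfer settings, with an invented bucket, key, and local path for illustration:

    import boto3
    from boto3.s3.transfer import TransferConfig

    # Same tuning as the diff: multipart transfers start above 8 MiB, and
    # concurrency is kept low so Weka's S3 gateway is not hit with a burst
    # of parallel range requests.
    transfer_config = TransferConfig(
        multipart_threshold=8 * 1024 * 1024,
        multipart_chunksize=8 * 1024 * 1024,
        max_concurrency=10,
        use_threads=True,
    )

    s3 = boto3.client('s3')
    with open('/tmp/example-model.bin', 'wb') as f:
        # download_fileobj streams the object into a caller-owned file
        # handle using the config above.
        s3.download_fileobj('example-bucket', 'models/example.bin', f, Config=transfer_config)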
@@ -307,8 +324,11 @@ def download_blob(obj, local_file_path):
                 total_files -= 1 # Decrement total_files as we're skipping this file
 
         if total_files > 0:
-            for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc=f"Downloading from {storage_type.upper()}"):
-                pass
+            for future in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc=f"Downloading from {storage_type.upper()}"):
+                try:
+                    future.result()
+                except Exception as e:
+                    logger.error(f"Error occurred during download: {e}")
         else:
             logger.info("All files are up-to-date. No downloads needed.")
 
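The loop change above matters because concurrent.futures.as_completed() only yields futures as they finish; an exception raised inside a worker stays stored in its future until result() is called, so the old pass-only loop silently swallowed failed downloads. A small self-contained sketch of the pattern, with an invented failing task for illustration:

    import concurrent.futures

    def flaky_download():
        raise IOError("simulated download failure")

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(flaky_download) for _ in range(3)]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # re-raises the exception from the worker thread
            except Exception as e:
                print(f"Error occurred during download: {e}")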
@@ -336,31 +356,35 @@ def compare_hashes_gcs(blob, local_file_path: str) -> bool:
         return True
 
 
-def compare_hashes_s3(obj, local_file_path: str) -> bool:
+def compare_hashes_s3(obj, local_file_path: str, storage_type: str) -> bool:
     """Compare MD5 hashes or sizes for S3 objects (including Weka)."""
     if os.path.exists(local_file_path):
-        etag = obj['ETag'].strip('"')
-        if '-' in etag:
-            remote_size = obj['Size']
-            local_size = os.path.getsize(local_file_path)
-            if remote_size == local_size:
-                logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
-                return False
-            else:
-                logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
-                return True
+        if storage_type == 'weka':
+            return True
         else:
-            hash_md5 = hashlib.md5()
-            with open(local_file_path, "rb") as f:
-                for chunk in iter(lambda: f.read(8192), b""):
-                    hash_md5.update(chunk)
-            local_md5 = hash_md5.hexdigest()
-            if etag == local_md5:
-                logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
-                return False
+            etag = obj['ETag'].strip('"')
+            if '-' in etag:
+                # Multipart upload, compare sizes
+                remote_size = obj['Size']
+                local_size = os.path.getsize(local_file_path)
+                if remote_size == local_size:
+                    logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
+                    return False
+                else:
+                    logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
+                    return True
             else:
-                logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
-                return True
+                hash_md5 = hashlib.md5()
+                with open(local_file_path, "rb") as f:
+                    for chunk in iter(lambda: f.read(8192), b""):
+                        hash_md5.update(chunk)
+                local_md5 = hash_md5.hexdigest()
+                if etag == local_md5:
+                    logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
+                    return False
+                else:
+                    logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
+                    return True
     else:
         logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
         return True
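Background on the ETag check in compare_hashes_s3: for single-part uploads S3's ETag is the MD5 of the object, while multipart ETags have the form "<md5-of-part-md5s>-<part-count>" and cannot be reproduced by hashing the local file, hence the fall-back to size comparison. A minimal sketch of that heuristic (the helper name is invented):

    import hashlib
    import os

    def local_copy_matches(etag: str, remote_size: int, local_path: str) -> bool:
        etag = etag.strip('"')
        if '-' in etag:
            # Multipart upload: the ETag is not an MD5 of the data, so the
            # best cheap check is a size comparison.
            return os.path.getsize(local_path) == remote_size
        md5 = hashlib.md5()
        with open(local_path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                md5.update(chunk)
        return md5.hexdigest() == etag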
2 changes: 1 addition & 1 deletion pdelfin/version.py
@@ -2,7 +2,7 @@
 _MINOR = "1"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "49"
+_PATCH = "50"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""