Convertsilver birr script can go in and out of S3 now
jakep-allenai committed Sep 30, 2024
1 parent b856b45 commit 963e946
Showing 1 changed file with 88 additions and 33 deletions.
pdelfin/silver_data/convertsilver_birr.py — 121 changes: 88 additions & 33 deletions
@@ -1,6 +1,3 @@
-# Converts data that was built by "buildsilver.py" into something you can feed to the mise/birr batch inference pipeline
-# to efficiently generate eval samples with against a local model
-
 import argparse
 import json
 import re
@@ -9,6 +6,8 @@
 import sys
 import logging
 
+import smart_open
+
 from pdelfin.prompts import build_finetuning_prompt
 
 
@@ -23,6 +22,11 @@ def setup_logging():
     )
 
 
+def is_s3_path(path):
+    """Check if the given path is an S3 path."""
+    return str(path).startswith('s3://')
+
+
 def transform_json_object(obj):
     """
     Transform a single JSON object by extracting and renaming specific fields.
@@ -46,21 +50,21 @@ def transform_json_object(obj):
         return None
 
 
-def process_file(input_file: Path, output_file: Path, rewrite_prompt_str: bool):
+def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
     """
     Process a single JSONL file: read, transform, and write to output.
 
     Args:
-        input_file (Path): Path to the input JSONL file.
-        output_file (Path): Path to the output JSONL file.
+        input_file (str): Path or URL to the input JSONL file.
+        output_file (str): Path or URL to the output JSONL file.
     """
     processed_count = 0
     error_count = 0
 
     try:
-        with input_file.open('r', encoding='utf-8') as infile, \
-             output_file.open('w', encoding='utf-8') as outfile:
+        with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \
+             smart_open.open(output_file, 'w', encoding='utf-8') as outfile:
 
             for line_number, line in enumerate(infile, 1):
                 line = line.strip()
                 if not line:
@@ -85,16 +89,69 @@ def process_file(input_file: Path, output_file: Path, rewrite_prompt_str: bool):
                     transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
 
                 if transformed is not None:
-                    json.dump(transformed, outfile)
-                    outfile.write('\n')
+                    outfile.write(json.dumps(transformed) + '\n')
                     processed_count += 1
                 else:
                     error_count += 1
 
-        logging.info(f"Processed '{input_file.name}': {processed_count} records transformed, {error_count} errors.")
+        logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
     except Exception as e:
         logging.error(f"Failed to process file {input_file}: {e}")
 
 
+def construct_output_file_path(input_file_path, input_dir, output_dir):
+    """
+    Given an input file path, input directory, and output directory,
+    construct the corresponding output file path.
+
+    Args:
+        input_file_path (str): Path to the input file.
+        input_dir (str): Path to the input directory.
+        output_dir (str): Path to the output directory.
+
+    Returns:
+        str: Path to the output file.
+    """
+    input_file = Path(input_file_path)
+    input_dir_path = Path(input_dir)
+    relative_path = input_file.relative_to(input_dir_path)
+
+    if is_s3_path(output_dir):
+        # For S3 output paths, construct the S3 URL manually
+        output_file_path = output_dir.rstrip('/') + '/' + str(relative_path).replace('\\', '/')
+    else:
+        # For local output paths
+        output_file_path = str(Path(output_dir) / relative_path)
+
+    return output_file_path
+
+
+def list_input_files(input_dir):
+    """
+    List all JSONL files in the input directory.
+
+    Args:
+        input_dir (str): Path to the input directory.
+
+    Returns:
+        list: List of input file paths.
+    """
+    if is_s3_path(input_dir):
+        # List objects under the S3 prefix with boto3 and keep the .jsonl ones
+        import boto3
+        s3 = boto3.resource('s3')
+        bucket_name = input_dir.split('s3://')[1].split('/')[0]
+        prefix = '/'.join(input_dir.split('s3://')[1].split('/')[1:])
+        bucket = s3.Bucket(bucket_name)
+        files = []
+        for obj in bucket.objects.filter(Prefix=prefix):
+            if obj.key.endswith('.jsonl'):
+                files.append(f's3://{bucket_name}/{obj.key}')
+        return files
+    else:
+        input_dir_path = Path(input_dir)
+        return [str(p) for p in input_dir_path.glob('*.jsonl')]
 
 
 def main():
     setup_logging()
     parser = argparse.ArgumentParser(
@@ -104,17 +161,17 @@ def main():
         '--rewrite_finetuning_prompt',
         action='store_true',
         default=False,
-        help="Rewrites the input prompt from standard OPENAI instruction format, into our finetuned format"
+        help="Rewrites the input prompt from standard OPENAI instruction format into our finetuned format"
     )
     parser.add_argument(
         'input_dir',
         type=str,
-        help='Path to the input directory containing JSONL files.'
+        help='Path to the input directory containing JSONL files. Can be a local path or S3 URL.'
     )
     parser.add_argument(
         'output_dir',
         type=str,
-        help='Path to the output directory where transformed JSONL files will be saved.'
+        help='Path to the output directory where transformed JSONL files will be saved. Can be a local path or S3 URL.'
     )
     parser.add_argument(
         '--jobs', '-j',
@@ -124,43 +181,41 @@
     )
     args = parser.parse_args()
 
-    input_dir = Path(args.input_dir)
-    output_dir = Path(args.output_dir)
+    input_dir = args.input_dir.rstrip('/')
+    output_dir = args.output_dir.rstrip('/')
     max_jobs = args.jobs
 
-    # Validate input directory
-    if not input_dir.exists() or not input_dir.is_dir():
-        logging.error(f"Input directory '{input_dir}' does not exist or is not a directory.")
-        sys.exit(1)
-
-    # Create output directory if it doesn't exist
-    output_dir.mkdir(parents=True, exist_ok=True)
+    # List input files
+    input_files = list_input_files(input_dir)
 
-    # Gather all JSONL files in the input directory
-    jsonl_files = list(input_dir.glob('*.jsonl'))
-
-    if not jsonl_files:
+    if not input_files:
         logging.warning(f"No JSONL files found in '{input_dir}'. Exiting.")
         sys.exit(0)
 
-    logging.info(f"Found {len(jsonl_files)} JSONL files to process.")
+    logging.info(f"Found {len(input_files)} JSONL files to process.")
 
+    # Prepare tasks for parallel processing
+    tasks = []
+    for input_file in input_files:
+        output_file = construct_output_file_path(input_file, input_dir, output_dir)
+        tasks.append((input_file, output_file))
+
     # Process files in parallel
     with ProcessPoolExecutor(max_workers=max_jobs) as executor:
         future_to_file = {
-            executor.submit(process_file, file, output_dir / file.name, args.rewrite_finetuning_prompt): file
-            for file in jsonl_files
+            executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
+            for input_file, output_file in tasks
         }
 
         for future in as_completed(future_to_file):
-            file = future_to_file[future]
+            input_file = future_to_file[future]
            try:
                 future.result()
             except Exception as exc:
-                logging.error(f"File {file.name} generated an exception: {exc}")
+                logging.error(f"File {input_file} generated an exception: {exc}")
 
     logging.info("All files have been processed.")
 
 
 if __name__ == "__main__":
     main()
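
For context, a minimal sketch (not part of the commit) of the smart_open round-trip pattern the reworked process_file relies on: a single open call that handles local paths and s3:// URLs for both reading and writing. The transform function and all paths below are hypothetical stand-ins.

# Sketch only: illustrates the smart_open pattern adopted above.
# `transform` stands in for transform_json_object; paths are made up.
import json

import smart_open  # pip install smart_open[s3]

def transform(obj: dict) -> dict:
    # Stand-in transform: pass records through unchanged.
    return obj

def roundtrip_jsonl(src: str, dst: str) -> None:
    # smart_open.open dispatches on the URL scheme, so the same code
    # reads and writes local files and s3:// objects alike.
    with smart_open.open(src, 'r', encoding='utf-8') as infile, \
         smart_open.open(dst, 'w', encoding='utf-8') as outfile:
        for line in infile:
            line = line.strip()
            if not line:
                continue
            outfile.write(json.dumps(transform(json.loads(line))) + '\n')

roundtrip_jsonl('s3://my-bucket/silver/input.jsonl', '/tmp/output.jsonl')

With this change the script accepts S3 URLs for either directory argument, e.g. python pdelfin/silver_data/convertsilver_birr.py s3://my-bucket/silver_in s3://my-bucket/silver_out (bucket names hypothetical).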

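Similarly, a self-contained illustration (again, not from the commit) of the mapping construct_output_file_path performs; it uses a condensed copy of the function from the diff so the asserts can run, and every directory and bucket name is made up.

# Illustration only: how construct_output_file_path maps input files
# to output locations. Condensed from the diff above; assumes POSIX paths.
from pathlib import Path

def is_s3_path(path):
    return str(path).startswith('s3://')

def construct_output_file_path(input_file_path, input_dir, output_dir):
    relative_path = Path(input_file_path).relative_to(Path(input_dir))
    if is_s3_path(output_dir):
        return output_dir.rstrip('/') + '/' + str(relative_path).replace('\\', '/')
    return str(Path(output_dir) / relative_path)

# Local in, S3 out: the relative file name is appended to the S3 prefix.
assert construct_output_file_path(
    '/data/in/a.jsonl', '/data/in', 's3://my-bucket/out'
) == 's3://my-bucket/out/a.jsonl'

# S3 in, local out. Path() collapses 's3://' to 's3:/', but since the
# file path and the input dir collapse the same way, relative_to still
# yields 'b.jsonl'.
assert construct_output_file_path(
    's3://my-bucket/in/b.jsonl', 's3://my-bucket/in', '/data/out'
) == '/data/out/b.jsonl'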