Skip to content

Commit

Permalink
Adding prompt length histogram to a script
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 8, 2024
1 parent adc702c commit 57d9a21
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion pdelfin/silver_data/convertsilver_birr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from pdelfin.prompts import build_finetuning_prompt

# Import Plotly for plotting
import plotly.express as px


def setup_logging():
"""Configure logging for the script."""
Expand Down Expand Up @@ -57,9 +60,11 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
Args:
input_file (str): Path or URL to the input JSONL file.
output_file (str): Path or URL to the output JSONL file.
rewrite_prompt_str (bool): Flag to rewrite the prompt string.
"""
processed_count = 0
error_count = 0
prompt_lengths = []

try:
with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \
Expand Down Expand Up @@ -89,15 +94,21 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)

if transformed is not None:
prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
prompt_length = len(prompt_text)
prompt_lengths.append(prompt_length)

outfile.write(json.dumps(transformed) + '\n')
processed_count += 1
else:
error_count += 1

logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
return prompt_lengths
except Exception as e:
logging.exception(e)
logging.error(f"Failed to process file {input_file}: {e}")
return []


def construct_output_file_path(input_file_path, input_dir, output_dir):
Expand Down Expand Up @@ -230,6 +241,7 @@ def main():
tasks.append((input_file, output_file))

# Process files in parallel
all_prompt_lengths = []
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
Expand All @@ -239,12 +251,29 @@ def main():
for future in as_completed(future_to_file):
input_file = future_to_file[future]
try:
future.result()
prompt_lengths = future.result()
all_prompt_lengths.extend(prompt_lengths)
except Exception as exc:
logging.error(f"File {input_file} generated an exception: {exc}")

logging.info("All files have been processed.")

# Plot histogram of prompt lengths
if all_prompt_lengths:
fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths")
fig.update_xaxes(title="Prompt Length")
fig.update_yaxes(title="Frequency")
try:
fig.write_image("prompt_lengths_histogram.png")
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.")
except Exception as e:
logging.error(f"Failed to save the histogram image: {e}")
logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).")
fig.write_html("prompt_lengths_histogram.html")
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.")
else:
logging.warning("No prompt lengths were collected; histogram will not be generated.")


if __name__ == "__main__":
main()

0 comments on commit 57d9a21

Please sign in to comment.