diff --git a/examples/gemini/README.md b/examples/gemini/README.md index 769f48ef..a884132c 100644 --- a/examples/gemini/README.md +++ b/examples/gemini/README.md @@ -11,7 +11,8 @@ prompto_run_experiment --file data/input/gemini-example.jsonl --max-queries 30 ## Multimodal prompting -Multimodal prompting is available with the Gemini API. We provide an example notebook in the [Multimodal prompting with Vertex AI notebook](./gemini-multimodal.ipynb) and example experiment file in [data/input/gemini-multimodal-example.jsonl](https://github.com/alan-turing-institute/prompto/blob/main/examples/gemini/data/input/gemini-multimodal-example.jsonl). You can run it with the following command: +Multimodal prompting is available with the Gemini API. To use it, you first need to upload your files to a dedicated cloud storage using the [File API](https://ai.google.dev/api/files#v1beta.files). To support you with this step, we provide a [notebook](./gemini-upload.ipynb) which takes your multimedia prompts as input and will add to each of your `media` elements a corresponding `uploaded_filename` key/value. You can test this with the example experiment file in [data/input/gemini-multimodal-example.jsonl](https://github.com/alan-turing-institute/prompto/blob/main/examples/gemini/data/input/gemini-multimodal-example.jsonl). +Then, we provide an example notebook in the [Multimodal prompting with the Gemini API notebook](./gemini-multimodal.ipynb). 
You can run it with the following command: ```bash prompto_run_experiment --file data/input/gemini-multimodal-example.jsonl --max-queries 30 ``` diff --git a/examples/gemini/data/input/gemini-multimodal-example.jsonl b/examples/gemini/data/input/gemini-multimodal-example.jsonl index 8d649647..ee0a0061 100644 --- a/examples/gemini/data/input/gemini-multimodal-example.jsonl +++ b/examples/gemini/data/input/gemini-multimodal-example.jsonl @@ -1,3 +1,10 @@ -{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} -{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} -{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? 
what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.wav"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 6, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.mp3"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 7, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", 
{"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 8, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test2.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} +{"id": 9, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} +{"id": 10, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} diff --git a/examples/gemini/data/media/test.mp3 b/examples/gemini/data/media/test.mp3 new file mode 100644 index 00000000..0764d7a2 Binary files /dev/null and b/examples/gemini/data/media/test.mp3 differ diff --git a/examples/gemini/data/media/test.mp4 b/examples/gemini/data/media/test.mp4 new file mode 100644 index 00000000..5030fc49 Binary files /dev/null and b/examples/gemini/data/media/test.mp4 differ diff --git a/examples/gemini/data/media/test.wav b/examples/gemini/data/media/test.wav new file mode 100644 index 00000000..9bdf83cb Binary files /dev/null and b/examples/gemini/data/media/test.wav differ diff --git a/examples/gemini/data/media/test2.mp4 b/examples/gemini/data/media/test2.mp4 new file mode 100644 index 00000000..5030fc49 Binary files /dev/null and b/examples/gemini/data/media/test2.mp4 differ diff --git a/examples/gemini/gemini-upload.ipynb b/examples/gemini/gemini-upload.ipynb new file mode 100644 index 
00000000..59f1613a --- /dev/null +++ b/examples/gemini/gemini-upload.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Uploading media to Gemini\n", + "\n", + "This notebook processes an experiment file and associates each media element with the id of the file when uploaded using the Files API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import time\n", + "import tqdm\n", + "import base64\n", + "import hashlib\n", + "import google.generativeai as genai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set the location of the experiment and media\n", + "\n", + "experiment_location = \"data/input\"\n", + "filename = \"gemini-multimodal-example.jsonl\"\n", + "media_location = \"data/media\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the GEMINI_API_KEY from the environment\n", + "\n", + "GEMINI_API_KEY = os.environ.get(\"GEMINI_API_KEY\")\n", + "if GEMINI_API_KEY is None:\n", + " raise ValueError(\"GEMINI_API_KEY is not set\")\n", + "\n", + "genai.configure(api_key=GEMINI_API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_sha256_base64(file_path, chunk_size=8192):\n", + " \"\"\"\n", + " Compute the SHA256 hash of the file at 'file_path' and return it as a base64-encoded string.\n", + " \"\"\"\n", + " hasher = hashlib.sha256()\n", + " with open(file_path, \"rb\") as f:\n", + " for chunk in iter(lambda: f.read(chunk_size), b\"\"):\n", + " hasher.update(chunk)\n", + " return base64.b64encode(hasher.digest()).decode(\"utf-8\")\n", + "\n", + "\n", + "def remote_file_hash_base64(remote_file):\n", + " \"\"\"\n", + " Convert a remote file's SHA256 hash (stored as a hex-encoded UTF-8 
bytes object)\n", + " to a base64-encoded string.\n", + " \"\"\"\n", + " hex_str = remote_file.sha256_hash.decode(\"utf-8\")\n", + " raw_bytes = bytes.fromhex(hex_str)\n", + " return base64.b64encode(raw_bytes).decode(\"utf-8\")\n", + "\n", + "\n", + "def wait_for_processing(file_obj, poll_interval=10):\n", + " \"\"\"\n", + " Poll until the file is no longer in the 'PROCESSING' state.\n", + " Returns the updated file object.\n", + " \"\"\"\n", + " while file_obj.state.name == \"PROCESSING\":\n", + " print(\"Waiting for file to be processed...\")\n", + " time.sleep(poll_interval)\n", + " file_obj = genai.get_file(file_obj.name)\n", + " return file_obj\n", + "\n", + "\n", + "def upload(file_path, already_uploaded_files):\n", + " \"\"\"\n", + " Upload the file at 'file_path' if it hasn't been uploaded yet.\n", + " If a file with the same SHA256 (base64-encoded) hash exists, returns its name.\n", + " Otherwise, uploads the file, waits for it to be processed,\n", + " and returns the new file's name. 
Raises a ValueError if processing fails.\n", + " \"\"\"\n", + " local_hash = compute_sha256_base64(file_path)\n", + "\n", + " if local_hash in already_uploaded_files:\n", + " return already_uploaded_files[local_hash], already_uploaded_files\n", + "\n", + " # Upload the file if it hasn't been found.\n", + " file_obj = genai.upload_file(path=file_path)\n", + " file_obj = wait_for_processing(file_obj)\n", + "\n", + " if file_obj.state.name == \"FAILED\":\n", + " raise ValueError(\"File processing failed\")\n", + " already_uploaded_files[local_hash] = file_obj.name\n", + " return already_uploaded_files[local_hash], already_uploaded_files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Retrieve already uploaded files\n", + "\n", + "uploaded_files = {\n", + " remote_file_hash_base64(remote_file): remote_file.name\n", + " for remote_file in genai.list_files()\n", + "}\n", + "print(f\"Found {len(uploaded_files)} files already uploaded\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "files_to_upload = set()\n", + "experiment_path = f\"{experiment_location}/{filename}\"\n", + "\n", + "# Read and collect media file paths\n", + "with open(experiment_path, \"r\") as f:\n", + " lines = f.readlines()\n", + "\n", + "data_list = []\n", + "\n", + "for line in lines:\n", + " data = json.loads(line)\n", + " data_list.append(data)\n", + "\n", + " if not isinstance(data.get(\"prompt\"), list):\n", + " continue\n", + "\n", + " files_to_upload.update(\n", + " f'{media_location}/{el[\"media\"]}'\n", + " for prompt in data[\"prompt\"]\n", + " for part in prompt.get(\"parts\", [])\n", + " if isinstance(el := part, dict) and \"media\" in el\n", + " )\n", + "\n", + "# Upload files and store mappings\n", + "genai_files = {}\n", + "for file_path in tqdm.tqdm(files_to_upload):\n", + " uploaded_filename, uploaded_files = upload(file_path, uploaded_files)\n", + 
" genai_files[file_path] = uploaded_filename\n", + "\n", + "# Modify data to include uploaded filenames\n", + "for data in data_list:\n", + " if isinstance(data.get(\"prompt\"), list):\n", + " for prompt in data[\"prompt\"]:\n", + " for part in prompt.get(\"parts\", []):\n", + " if isinstance(part, dict) and \"media\" in part:\n", + " file_path = f'{media_location}/{part[\"media\"]}'\n", + " if file_path in genai_files:\n", + " part[\"uploaded_filename\"] = genai_files[file_path]\n", + " else:\n", + " print(f\"Failed to find {file_path} in genai_files\")\n", + "\n", + "# Write modified data back to the JSONL file\n", + "with open(experiment_path, \"w\") as f:\n", + " for data in data_list:\n", + " f.write(json.dumps(data) + \"\\n\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index 5ed6dc87..b57d8540 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -31,6 +31,7 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any: # read multimedia type type = part.get("type") + uploaded_filename = part.get("uploaded_filename") if type is None: raise ValueError("Multimedia type is not specified") # read file location @@ -38,22 +39,20 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any: if media is None: raise ValueError("File location is not specified") - # create Part object based on multimedia type if type == "text": return media else: - if type == "image": - media_file_path = os.path.join(media_folder, media) - return 
PIL.Image.open(media_file_path) - elif type == "file": + if uploaded_filename is None: + raise ValueError( + f"File {media} not uploaded. Please upload the file first." + ) + else: try: - return get_file(name=media) + return get_file(name=uploaded_filename) except Exception as err: raise ValueError( f"Failed to get file: {media} due to error: {type(err).__name__} - {err}" ) - else: - raise ValueError(f"Unsupported multimedia type: {type}") def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[any]: diff --git a/tests/apis/gemini/test_gemini_image_input.py b/tests/apis/gemini/test_gemini_image_input.py new file mode 100644 index 00000000..c1113aba --- /dev/null +++ b/tests/apis/gemini/test_gemini_image_input.py @@ -0,0 +1,50 @@ +import os + +from PIL import Image + +from prompto.apis.gemini.gemini_utils import parse_parts_value + + +def test_parse_parts_value_text(): + part = "text" + media_folder = "media" + result = parse_parts_value(part, media_folder) + assert result == part + + +def test_parse_parts_value_image(): + part = "image" + media_folder = "media" + result = parse_parts_value(part, media_folder) + assert result == part + + +def test_parse_parts_value_image_dict(): + part = {"type": "image", "media": "pantani_giro.jpg"} + media_folder = "media" + + # Create a mock image + if not os.path.exists(media_folder): + os.makedirs(media_folder) + image_path = os.path.join(media_folder, "pantani_giro.jpg") + image = Image.new("RGB", (100, 100), color="red") + image.save(image_path) + + result = parse_parts_value(part, media_folder) + + # Assert the result + assert result.mode == "RGB" + assert result.size == (100, 100) + assert result.filename.endswith("pantani_giro.jpg") + + # Clean up the mock image + os.remove(image_path) + os.rmdir(media_folder) + + +def test_parse_parts_value_video(): + part = {"type": "video", "media": "pantani_giro.mp4"} + media_folder = "media" + + result = parse_parts_value(part, media_folder) + assert result == 
part