Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file upload to support multimodal Gemini API #115

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/gemini/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ prompto_run_experiment --file data/input/gemini-example.jsonl --max-queries 30

## Multimodal prompting

Multimodal prompting is available with the Gemini API. We provide an example notebook in the [Multimodal prompting with Vertex AI notebook](./gemini-multimodal.ipynb) and example experiment file in [data/input/gemini-multimodal-example.jsonl](https://github.com/alan-turing-institute/prompto/blob/main/examples/gemini/data/input/gemini-multimodal-example.jsonl). You can run it with the following command:
Multimodal prompting is available with the Gemini API. To use it, you first need to upload your media files to Google's dedicated cloud storage using the [File API](https://ai.google.dev/api/files#v1beta.files). To help with this step, we provide a [notebook](./gemini-upload.ipynb) that takes your experiment file as input and adds a corresponding `uploaded_filename` key/value to each of its `media` elements. You can test this with the example experiment file in [data/input/gemini-multimodal-example.jsonl](https://github.com/alan-turing-institute/prompto/blob/main/examples/gemini/data/input/gemini-multimodal-example.jsonl).
Once the files have been uploaded, the [Multimodal prompting with Vertex AI notebook](./gemini-multimodal.ipynb) walks through a complete example. You can run the example experiment file with the following command:
```bash
prompto_run_experiment --file data/input/gemini-multimodal-example.jsonl --max-queries 30
```
Expand Down
13 changes: 10 additions & 3 deletions examples/gemini/data/input/gemini-multimodal-example.jsonl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.wav"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 6, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.mp3"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 7, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 8, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test2.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 9, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}}
{"id": 10, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}}
Binary file added examples/gemini/data/media/test.mp3
Binary file not shown.
Binary file added examples/gemini/data/media/test.mp4
Binary file not shown.
Binary file added examples/gemini/data/media/test.wav
Binary file not shown.
Binary file added examples/gemini/data/media/test2.mp4
Binary file not shown.
206 changes: 206 additions & 0 deletions examples/gemini/gemini-upload.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Uploading media to Gemini\n",
"\n",
"This notebook processes an experiment file and associates each media element with the ID of the corresponding file once it has been uploaded using the Files API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import time\n",
"import tqdm\n",
"import base64\n",
"import hashlib\n",
"import google.generativeai as genai"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set the location of the experiment and media\n",
"\n",
"experiment_location = \"data/input\"\n",
"filename = \"gemini-multimodal-example.jsonl\"\n",
"media_location = \"data/media\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the GEMINI_API_KEY from the environment\n",
"\n",
"GEMINI_API_KEY = os.environ.get(\"GEMINI_API_KEY\")\n",
"if GEMINI_API_KEY is None:\n",
" raise ValueError(\"GEMINI_API_KEY is not set\")\n",
"\n",
"genai.configure(api_key=GEMINI_API_KEY)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_sha256_base64(file_path, chunk_size=8192):\n",
" \"\"\"\n",
" Compute the SHA256 hash of the file at 'file_path' and return it as a base64-encoded string.\n",
" \"\"\"\n",
" hasher = hashlib.sha256()\n",
" with open(file_path, \"rb\") as f:\n",
" for chunk in iter(lambda: f.read(chunk_size), b\"\"):\n",
" hasher.update(chunk)\n",
" return base64.b64encode(hasher.digest()).decode(\"utf-8\")\n",
"\n",
"\n",
"def remote_file_hash_base64(remote_file):\n",
" \"\"\"\n",
" Convert a remote file's SHA256 hash (stored as a hex-encoded UTF-8 bytes object)\n",
" to a base64-encoded string.\n",
" \"\"\"\n",
" hex_str = remote_file.sha256_hash.decode(\"utf-8\")\n",
" raw_bytes = bytes.fromhex(hex_str)\n",
" return base64.b64encode(raw_bytes).decode(\"utf-8\")\n",
"\n",
"\n",
"def wait_for_processing(file_obj, poll_interval=10):\n",
" \"\"\"\n",
" Poll until the file is no longer in the 'PROCESSING' state.\n",
" Returns the updated file object.\n",
" \"\"\"\n",
" while file_obj.state.name == \"PROCESSING\":\n",
" print(\"Waiting for file to be processed...\")\n",
" time.sleep(poll_interval)\n",
" file_obj = genai.get_file(file_obj.name)\n",
" return file_obj\n",
"\n",
"\n",
"def upload(file_path, already_uploaded_files):\n",
" \"\"\"\n",
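" Upload the file at 'file_path' if it hasn't been uploaded yet.\n",
" If a file with the same SHA256 (base64-encoded) hash is already present in\n",
" 'already_uploaded_files', nothing is uploaded and its name is reused.\n",
" Otherwise, uploads the file, waits for it to be processed, and records its\n",
" name in 'already_uploaded_files' under that hash.\n",
" Returns a tuple of (uploaded file name, updated 'already_uploaded_files').\n",
" Raises a ValueError if processing fails.\n",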
" \"\"\"\n",
" local_hash = compute_sha256_base64(file_path)\n",
"\n",
" if local_hash in already_uploaded_files:\n",
" return already_uploaded_files[local_hash], already_uploaded_files\n",
"\n",
" # Upload the file if it hasn't been found.\n",
" file_obj = genai.upload_file(path=file_path)\n",
" file_obj = wait_for_processing(file_obj)\n",
"\n",
" if file_obj.state.name == \"FAILED\":\n",
" raise ValueError(\"File processing failed\")\n",
" already_uploaded_files[local_hash] = file_obj.name\n",
" return already_uploaded_files[local_hash], already_uploaded_files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Retrieve already uploaded files\n",
"\n",
"uploaded_files = {\n",
" remote_file_hash_base64(remote_file): remote_file.name\n",
" for remote_file in genai.list_files()\n",
"}\n",
"print(f\"Found {len(uploaded_files)} files already uploaded\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"files_to_upload = set()\n",
"experiment_path = f\"{experiment_location}/{filename}\"\n",
"\n",
"# Read and collect media file paths\n",
"with open(experiment_path, \"r\") as f:\n",
" lines = f.readlines()\n",
"\n",
"data_list = []\n",
"\n",
"for line in lines:\n",
" data = json.loads(line)\n",
" data_list.append(data)\n",
"\n",
" if not isinstance(data.get(\"prompt\"), list):\n",
" continue\n",
"\n",
" files_to_upload.update(\n",
" f'{media_location}/{el[\"media\"]}'\n",
" for prompt in data[\"prompt\"]\n",
" for part in prompt.get(\"parts\", [])\n",
" if isinstance(el := part, dict) and \"media\" in el\n",
" )\n",
"\n",
"# Upload files and store mappings\n",
"genai_files = {}\n",
"for file_path in tqdm.tqdm(files_to_upload):\n",
" uploaded_filename, uploaded_files = upload(file_path, uploaded_files)\n",
" genai_files[file_path] = uploaded_filename\n",
"\n",
"# Modify data to include uploaded filenames\n",
"for data in data_list:\n",
" if isinstance(data.get(\"prompt\"), list):\n",
" for prompt in data[\"prompt\"]:\n",
" for part in prompt.get(\"parts\", []):\n",
" if isinstance(part, dict) and \"media\" in part:\n",
" file_path = f'{media_location}/{part[\"media\"]}'\n",
" if file_path in genai_files:\n",
" part[\"uploaded_filename\"] = genai_files[file_path]\n",
" else:\n",
" print(f\"Failed to find {file_path} in genai_files\")\n",
"\n",
"# Write modified data back to the JSONL file\n",
"with open(experiment_path, \"w\") as f:\n",
" for data in data_list:\n",
" f.write(json.dumps(data) + \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
15 changes: 7 additions & 8 deletions src/prompto/apis/gemini/gemini_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,28 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any:

# read multimedia type
type = part.get("type")
uploaded_filename = part.get("uploaded_filename")
if type is None:
raise ValueError("Multimedia type is not specified")
# read file location
media = part.get("media")
if media is None:
raise ValueError("File location is not specified")

# create Part object based on multimedia type
if type == "text":
return media
else:
if type == "image":
media_file_path = os.path.join(media_folder, media)
return PIL.Image.open(media_file_path)
elif type == "file":
if uploaded_filename is None:
raise ValueError(
f"File {media} not uploaded. Please upload the file first."
)
else:
try:
return get_file(name=media)
return get_file(name=uploaded_filename)
except Exception as err:
raise ValueError(
f"Failed to get file: {media} due to error: {type(err).__name__} - {err}"
)
else:
raise ValueError(f"Unsupported multimedia type: {type}")


def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[any]:
Expand Down
50 changes: 50 additions & 0 deletions tests/apis/gemini/test_gemini_image_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os

from PIL import Image

from prompto.apis.gemini.gemini_utils import parse_parts_value


def test_parse_parts_value_text():
    """A plain string part comes back from ``parse_parts_value`` unchanged."""
    text_part = "text"
    assert parse_parts_value(text_part, "media") == text_part


def test_parse_parts_value_image():
    """A bare string is passed through unchanged even when it reads like a media type."""
    string_part = "image"
    parsed = parse_parts_value(string_part, "media")
    assert parsed == string_part


def test_parse_parts_value_image_dict():
    """An image dict part is resolved to the image stored in the media folder.

    The fixture image lives in a throw-away temporary directory rather than a
    ``media`` folder relative to the current working directory, so the test
    cannot clash with (or delete) a real folder and always cleans up after
    itself, even when an assertion fails.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    import tempfile

    part = {"type": "image", "media": "pantani_giro.jpg"}

    with tempfile.TemporaryDirectory() as media_folder:
        # Create a mock image for the part to point at
        image_path = os.path.join(media_folder, part["media"])
        Image.new("RGB", (100, 100), color="red").save(image_path)

        result = parse_parts_value(part, media_folder)
        try:
            assert result.mode == "RGB"
            assert result.size == (100, 100)
            assert result.filename.endswith("pantani_giro.jpg")
        finally:
            # Release the file handle held by the opened image before the
            # temporary directory is removed (removal fails on Windows while
            # the file is still open).
            result.close()


def test_parse_parts_value_video():
    """A video part that has not been uploaded is rejected with ``ValueError``.

    ``parse_parts_value`` does not hand a dict part back as-is: for a dict it
    either returns a parsed object or raises ``ValueError``. A video part
    without an ``uploaded_filename`` therefore has to raise instead of being
    returned unchanged.
    """
    part = {"type": "video", "media": "pantani_giro.mp4"}
    media_folder = "media"

    try:
        parse_parts_value(part, media_folder)
    except ValueError:
        # Expected: the media was never uploaded, so it cannot be resolved.
        pass
    else:
        raise AssertionError(
            "expected ValueError for a video part without 'uploaded_filename'"
        )