diff --git a/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-completed-gemini-example.jsonl b/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-completed-gemini-example.jsonl new file mode 100644 index 00000000..3258f8ea --- /dev/null +++ b/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-completed-gemini-example.jsonl @@ -0,0 +1,6 @@ +{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "timestamp_sent": "18-10-2024-09-57-31", "response": "Technology impacts us in profound and multifaceted ways, shaping our lives in every aspect from communication and work to entertainment and healthcare. Here's a breakdown of key impacts:\n\n**Positive Impacts:**\n\n* **Enhanced Communication:** Technology has revolutionized communication. We can connect with people across the globe instantly through video calls, instant messaging, and social media. This fosters stronger relationships, facilitates business, and enables global collaboration.\n* **Increased Efficiency and Productivity:** Technology automates tasks, streamlines processes,", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "MAX_TOKENS"}} +{"id": 1, "api": "gemini", "model_name": "gemini-1.0-pro", "prompt": "How does technology impact us?", "safety_filter": "few", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "timestamp_sent": "18-10-2024-09-57-33", "response": "**Positive Impacts of Technology**\n\n* **Enhanced Communication:** Social media, messaging apps, and video conferencing facilitate seamless connection and information exchange globally.\n* **Increased Accessibility:** Technology provides equal access to education, healthcare, and information for individuals regardless of location or circumstances.\n* **Improved Productivity:** Automation, artificial intelligence, and data analytics streamline processes, improve efficiency, and reduce time consumption.\n* **Medical Advancements:** Medical devices, surgical robots, and telemedicine enhance patient care, improve", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "MAX_TOKENS"}} +{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": ["How does international trade create jobs?", "I want a joke about that"], "safety_filter": "some", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "timestamp_sent": "18-10-2024-09-57-35", "response": ["International trade creates jobs in a multitude of ways, boosting both domestic and global economies. Here's a breakdown of how it works:\n\n**1. Increased Production and Export Jobs:**\n\n* **Specialization and Comparative Advantage:** Countries specialize in producing goods and services where they have a comparative advantage (lower opportunity cost). 
This leads to increased efficiency and higher production levels.\n* **Export Jobs:** Producers of these specialized goods and services find new markets abroad, leading to increased demand for their products", "Why did the economist get a job in international trade?\n\nBecause he was tired of being a \"jobless\" economist! \ud83d\ude1c \n"], "safety_attributes": [{"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "MAX_TOKENS"}, {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "2", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}]} +{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "What efforts are being made to keep the hakka language alive?"}], "safety_filter": "default", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "timestamp_sent": "18-10-2024-09-57-37", "response": "Efforts to preserve Hakka include language schools, cultural festivals, and online resources. \n", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}} +{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": "What efforts are being made to keep the hakka language alive?"}, {"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "timestamp_sent": "18-10-2024-09-57-39", "response": "TypeError - if api == 'gemini', then the prompt must be a str, list[str], or list[dict[str,str]] where the dictionary contains the keys 'role' and 'parts' only, and the values for 'role' must be one of 'user' or 'model', except for the first message in the list of dictionaries can be a system message with the key 'role' set to 'system'."} +{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "timestamp_sent": "18-10-2024-09-57-41", "response": "You will be 7 years old next year! 
\n", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}} diff --git a/examples/gemini/data/input/gemini-example.jsonl b/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-input-gemini-example.jsonl similarity index 100% rename from examples/gemini/data/input/gemini-example.jsonl rename to examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-input-gemini-example.jsonl diff --git a/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-log-gemini-example.txt b/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-log-gemini-example.txt new file mode 100644 index 00000000..3459ed9a --- /dev/null +++ b/examples/gemini/data/output/gemini-example/18-10-2024-09-57-29-log-gemini-example.txt @@ -0,0 +1,2 @@ +18-10-2024, 09:57: Error (i=5, id=4): TypeError - if api == 'gemini', then the prompt must be a str, list[str], or list[dict[str,str]] where the dictionary contains the keys 'role' and 'parts' only, and the values for 'role' must be one of 'user' or 'model', except for the first message in the list of dictionaries can be a system message with the key 'role' set to 'system'. +18-10-2024, 09:57: Completed experiment: gemini-example.jsonl! Experiment processing time: 12.673 seconds, Average time per query: 2.112 seconds diff --git a/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-completed-gemini-example.jsonl b/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-completed-gemini-example.jsonl deleted file mode 100644 index 68d780ae..00000000 --- a/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-completed-gemini-example.jsonl +++ /dev/null @@ -1,6 +0,0 @@ -{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "response": "Technology impacts us in countless ways, both positive and negative. Here's a breakdown of some key areas:\n\n**Positive Impacts:**\n\n* **Enhanced Communication:** Technology has revolutionized how we connect with others. 
Instant messaging, video calls, and social media platforms allow us to communicate with friends, family, and colleagues across the globe.\n* **Increased Efficiency and Productivity:** Tools like computers, automation, and artificial intelligence (AI) streamline tasks, increase production, and improve overall efficiency in", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}} -{"id": 1, "api": "gemini", "model_name": "gemini-1.0-pro", "prompt": "How does technology impact us?", "safety_filter": "few", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "response": "**Positive Impacts:**\n\n* **Enhanced communication and information access:** Social media, messaging apps, and the internet have facilitated seamless communication, knowledge sharing, and access to vast amounts of information.\n* **Increased productivity and efficiency:** Automation, software tools, and remote work capabilities have enhanced productivity in various industries, saving time and effort.\n* **Improved healthcare:** Medical devices, telemedicine, and wearable health trackers have revolutionized healthcare, allowing for more accessible, personalized, and efficient treatments.\n", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}} -{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": ["How does international trade create jobs?", "I want a joke about that"], "safety_filter": "some", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "response": ["International trade creates jobs in several ways:\n\n**1. Expanding Markets:**\n\n* **Exports:** When a country exports goods or services, it creates demand for those products from other countries. This leads to increased production, which in turn requires more workers to produce and ship those goods. This can be in manufacturing, agriculture, or service sectors.\n* **Imports:** While imports might seem like they take away jobs, they actually stimulate other sectors of the economy. For example, importing cheaper raw materials", "Why did the economist get lost in the forest?\n\nBecause he kept following the path of least resistance, and it led him to a dead end! \n\n(This joke plays on the economic concept of \"following the path of least resistance\", which can be applied to trade and globalization.) 
\n"], "safety_attributes": [{"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}, {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}]} -{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "What efforts are being made to keep the hakka language alive?"}], "safety_filter": "default", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "response": "Efforts to preserve Hakka include:\n\n* **Language schools and classes:** Offering instruction to children and adults.\n* **Cultural events:** Promoting Hakka music, dance, and literature.\n* **Media:** Radio stations, online platforms, and publications in Hakka.\n* **Community initiatives:** Encouraging use of Hakka within families and gatherings. \n", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}} -{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": "What efforts are being made to keep the hakka language alive?"}, {"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "response": "TypeError - if api == 'gemini', then the prompt must be a str, list[str], or list[dict[str,str]] where the dictionary contains the keys 'role' and 'parts' only, and the values for 'role' must be one of 'user' or 'model', except for the first message in the list of dictionaries can be a system message with the key 'role' set to 'system'."} -{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}, "response": "You will be 7 years old next year! 
\n", "safety_attributes": {"HARM_CATEGORY_SEXUALLY_EXPLICIT": "1", "HARM_CATEGORY_HATE_SPEECH": "1", "HARM_CATEGORY_HARASSMENT": "1", "HARM_CATEGORY_DANGEROUS_CONTENT": "1", "blocked": "[False, False, False, False]", "finish_reason": "STOP"}} diff --git a/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-input-gemini-example.jsonl b/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-input-gemini-example.jsonl deleted file mode 100644 index a8c913a6..00000000 --- a/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-input-gemini-example.jsonl +++ /dev/null @@ -1,6 +0,0 @@ -{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} -{"id": 1, "api": "gemini", "model_name": "gemini-1.0-pro", "prompt": "How does technology impact us?", "safety_filter": "few", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} -{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": ["How does international trade create jobs?", "I want a joke about that"], "safety_filter": "some", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} -{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "What efforts are being made to keep the hakka language alive?"}], "safety_filter": "default", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} -{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": "What efforts are being made to keep the hakka language alive?"}, {"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} -{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} diff --git a/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-log-gemini-example.txt b/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-log-gemini-example.txt deleted file mode 100644 index d6171cda..00000000 --- a/examples/gemini/data/output/gemini-example/25-06-2024-20-17-17-log-gemini-example.txt +++ /dev/null @@ -1,2 +0,0 @@ -25-06-2024, 20:21: Error (i=5) [id=4]. TypeError - if api == 'gemini', then the prompt must be a str, list[str], or list[dict[str,str]] where the dictionary contains the keys 'role' and 'parts' only, and the values for 'role' must be one of 'user' or 'model', except for the first message in the list of dictionaries can be a system message with the key 'role' set to 'system'. -25-06-2024, 20:21: Completed experiment gemini-example.jsonl! 
Experiment processing time: 13.046 seconds, Average time per query: 2.174 seconds diff --git a/examples/gemini/gemini.ipynb index 2d3d1397..a1d5ef70 100644 --- a/examples/gemini/gemini.ipynb +++ b/examples/gemini/gemini.ipynb @@ -289,15 +289,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Sending 6 queries (attempt 1/3): 0%| | 0/6 [00:00<?, ?it/s]\n", diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ ... @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: i.e. multi-turn chat with history. """ if prompt_dict["prompt"][0]["role"] == "system": - prompt, model_name, model, safety_settings, generation_config, _ = ( + prompt, model_name, model, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=prompt_dict["prompt"][0]["parts"], ) ) - chat = model.start_chat(history=prompt[1:-1]) + chat = model.start_chat( + history=[ + convert_dict_to_input( + content_dict=x, media_folder=self.settings.media_folder + ) + for x in prompt[1:-1] + ] + ) else: - prompt, model_name, model, safety_settings, generation_config, _ = ( + prompt, model_name, model, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat(history=prompt[:-1]) + chat = model.start_chat( + history=[ + convert_dict_to_input( + content_dict=x, media_folder=self.settings.media_folder + ) + for x in prompt[:-1] + ] + ) try: response = await chat.send_message_async( - content=prompt[-1], + content=convert_dict_to_input( + content_dict=prompt[-1], media_folder=self.settings.media_folder + ), generation_config=generation_config, safety_settings=safety_settings, stream=False, diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index 509c2a79..7d18c02d 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -6,7 +6,7 @@ gemini_chat_roles = set(["user", "model"]) -def parse_multimedia_dict(multimedia_dict: dict, media_folder: str) -> any: +def parse_parts_value(part: dict | str, media_folder: str) -> any: """ Parse a "parts" value and create a multimedia data object. Expected keys: @@ -15,20 +15,23 @@ Parameters ---------- - multimedia_dict : dict - Dictionary with multimedia data to parse and create Part object + part : dict | str + Either a dictionary or a string that defines a multimodal object. Returns ------- any Multimedia data object """ + if isinstance(part, str): + return part + # read multimedia type - type = multimedia_dict.get("type") + type = part.get("type") if type is None: raise ValueError("Multimedia type is not specified") # read file location - media = multimedia_dict.get("media") + media = part.get("media") if media is None: raise ValueError("File location is not specified") @@ -50,7 +53,7 @@ raise ValueError(f"Unsupported multimedia type: {type}") -def parse_multimedia(multimedia: list[dict] | dict, media_folder: str) -> list[any]: +def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[any]: """ Parse multimedia data and create a list of multimedia data objects. If parts is a single dictionary or string, a list with a single multimedia data object is returned. 
@@ -66,10 +69,43 @@ def parse_multimedia(multimedia: list[dict] | dict, media_folder: str) -> list[a list[any] | any List of multimedia data object(s) created from the input multimedia data """ - if isinstance(multimedia, dict): - return [parse_multimedia_dict(multimedia, media_folder=media_folder)] - else: - return [parse_multimedia_dict(m, media_folder=media_folder) for m in multimedia] + # convert to list[dict | str] + if isinstance(parts, dict) or isinstance(parts, str): + parts = [parts] + + return [parse_parts_value(p, media_folder=media_folder) for p in parts] + + +def convert_dict_to_input(content_dict: dict, media_folder: str) -> dict: + """ + Convert a content dictionary to an input that can be used by the Gemini API. + The output is a dictionary with keys "role" and "parts". + + Parameters + ---------- + content_dict : dict + Content dictionary with keys "role" and "parts". + media_folder : str + Folder where any multimedia files referenced in "parts" are stored. + + Returns + ------- + dict + Dictionary with keys "role" and "parts", where the value of + "role" is either "user" or "model" and the value of + "parts" is a list of inputs making up the content (which can include + text or image/video inputs). + """ + if "role" not in content_dict: + raise KeyError("Role key is missing in content dictionary") + if "parts" not in content_dict: + raise KeyError("Parts key is missing in content dictionary") + + return { + "role": content_dict["role"], + "parts": parse_parts( + content_dict["parts"], + media_folder=media_folder, + ), + } def process_response(response: dict) -> str: diff --git a/src/prompto/apis/vertexai/vertexai.py b/src/prompto/apis/vertexai/vertexai.py index b9fa9a9e..43db2073 100644 --- a/src/prompto/apis/vertexai/vertexai.py +++ b/src/prompto/apis/vertexai/vertexai.py @@ -7,7 +7,6 @@ GenerativeModel, HarmBlockThreshold, HarmCategory, - Part, ) from prompto.apis.base import AsyncAPI @@ -16,7 +15,7 @@ process_response, process_safety_attributes, ) -from prompto.apis.vertexai.vertexai_utils import dict_to_content, parse_multimedia +from prompto.apis.vertexai.vertexai_utils import convert_dict_to_input from prompto.settings import Settings from prompto.utils import ( FILE_WRITE_LOCK, @@ -287,16 +286,7 @@ async def _obtain_model_inputs( f"parameters must be a dictionary, not {type(generation_config)}" ) - # parse multimedia data (if any) - multimedia_dict = prompt_dict.get("multimedia", None) - if multimedia_dict is not None: - multimedia = parse_multimedia( - multimedia_dict, media_folder=self.settings.media_folder - ) - else: - multimedia = None - - return prompt, model_name, model, safety_settings, generation_config, multimedia + return prompt, model_name, model, safety_settings, generation_config @@ -304,22 +294,15 @@ async def _query_string(self, prompt_dict: dict, index: int | str): """ Async method for querying the model with string prompt (prompt_dict["prompt"] is a string), i.e. single-turn completion or chat. 
""" - prompt, model_name, model, safety_settings, generation_config, multimedia = ( + prompt, model_name, model, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - # prepare the contents to send to the model - if multimedia is not None: - # prepend the multimedia to the prompt - contents = multimedia + [Part.from_text(prompt)] - else: - contents = [Part.from_text(prompt)] - try: response = await model.generate_content_async( - contents=contents, + contents=prompt, generation_config=generation_config, safety_settings=safety_settings, stream=False, @@ -391,14 +374,13 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a list of strings to sequentially send to the model), i.e. multi-turn chat with history. """ - prompt, model_name, model, safety_settings, generation_config, _ = ( + prompt, model_name, model, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) chat = model.start_chat(history=[]) - response_list = [] safety_attributes_list = [] try: @@ -495,24 +477,40 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: i.e. multi-turn chat with history. """ if prompt_dict["prompt"][0]["role"] == "system": - prompt, model_name, model, safety_settings, generation_config, _ = ( + prompt, model_name, model, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=prompt_dict["prompt"][0]["parts"], ) ) - chat = model.start_chat(history=[dict_to_content(x) for x in prompt[1:-1]]) + chat = model.start_chat( + history=[ + convert_dict_to_input( + content_dict=x, media_folder=self.settings.media_folder + ) + for x in prompt[1:-1] + ] + ) else: - prompt, model_name, model, safety_settings, generation_config, _ = ( + prompt, model_name, model, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat(history=[dict_to_content(x) for x in prompt[:-1]]) + chat = model.start_chat( + history=[ + convert_dict_to_input( + content_dict=x, media_folder=self.settings.media_folder + ) + for x in prompt[:-1] + ] + ) try: response = await chat.send_message_async( - content=dict_to_content(prompt[-1]), + content=convert_dict_to_input( + content_dict=prompt[-1], media_folder=self.settings.media_folder + ), generation_config=generation_config, safety_settings=safety_settings, stream=False, diff --git a/src/prompto/apis/vertexai/vertexai_utils.py b/src/prompto/apis/vertexai/vertexai_utils.py index 53cb9a3b..61279aa2 100644 --- a/src/prompto/apis/vertexai/vertexai_utils.py +++ b/src/prompto/apis/vertexai/vertexai_utils.py @@ -3,30 +3,36 @@ from vertexai.generative_models import Content, Image, Part -def parse_multimedia_dict(multimedia_dict: dict, media_folder: str) -> Part: +def parse_parts_value(part: dict | str, media_folder: str) -> Part: """ - Parse multimedia dictionary and create Vertex AI Part object. - Expected keys: + Create Vertex AI Part objects from a dictionary or string. + If the parts is a string, a Part object with text is created. 
+ Otherwise, the part is a dictionary with the following keys: - type: str, multimedia type, one of ["image", "video", "uri" "text"] - - media: str, file location (if type is image or video) or text (if type is text) - - mime_type: str, mime type of the video file, only required for video and uri types + - media: str, file location (if type is image or video) or text (if type is text). + This can be either a local file path (relative to the media folder) or a GCS URI. + - mime_type: str, mime type of the image or video file, only required + if using a GCS URI or a local video file. Parameters ---------- - multimedia_dict : dict - Dictionary with multimedia data to parse and create Part object + part : dict | str + Either a dictionary or a string that defines a Part object. Returns ------- Part Vertex AI Part object created from multimedia data """ + if isinstance(part, str): + return Part.from_text(part) + # read multimedia type - type = multimedia_dict.get("type") + type = part.get("type") if type is None: raise ValueError("Multimedia type is not specified") # read file location - media = multimedia_dict.get("media") + media = part.get("media") if media is None: raise ValueError("File location is not specified") @@ -35,12 +41,24 @@ def parse_multimedia_dict(multimedia_dict: dict, media_folder: str) -> Part: return Part.from_text(media) else: media_file_path = os.path.join(media_folder, media) - mime_type = multimedia_dict.get("mime_type") + mime_type = part.get("mime_type") if type == "image": + if media.startswith("gs://"): + if mime_type is None: + raise ValueError( + "Mime type is not specified. Required for image type if media is a GCS URI" + ) + + return Part.from_uri(uri=media, mime_type=mime_type) + return Part.from_image(Image.load_from_file(media_file_path)) elif type == "video": if mime_type is None: - raise ValueError("Mime type is not specified. Required for video") + raise ValueError("Mime type is not specified. Required for video type") + + if media.startswith("gs://"): + return Part.from_uri(uri=media, mime_type=mime_type) + - return Part.from_data( - open(media_file_path, "rb").read(), mime_type="video/mp4" - ) + return Part.from_data( + open(media_file_path, "rb").read(), mime_type=mime_type + ) @@ -52,29 +70,31 @@ raise ValueError(f"Unsupported multimedia type: {type}") -def parse_multimedia(multimedia: list[dict] | dict, media_folder: str) -> list[Part]: +def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[Part]: """ - Parse multimedia data and create a list of Vertex AI Part object(s). - If multimedia is a single dictionary, a list with a single Part object is returned. + Parse "parts" value and create a list of Vertex AI Part object(s). + If parts is a single dictionary or a string, a list with a single Part + object is returned; otherwise, a list of multiple Part objects is created. Parameters ---------- - multimedia : list[dict] | dict - Multimedia data to parse and create Part object(s). - Can be a list of multimedia dictionaries or a single multimedia dictionary. + parts : list[dict | str] | dict | str + Corresponds to the "parts" value in the prompt. + Can be a list of dictionaries and strings, or a single dictionary or string. 
Returns ------- - list[Part] | Part - List of Vertex AI Part object(s) created from multimedia data + list[Part] + List of Vertex AI Part object(s) created from "parts" value in a prompt """ - if isinstance(multimedia, dict): - return [parse_multimedia_dict(multimedia, media_folder=media_folder)] - else: - return [parse_multimedia_dict(m, media_folder=media_folder) for m in multimedia] + # convert to list[dict | str] + if isinstance(parts, dict) or isinstance(parts, str): + parts = [parts] + + return [parse_parts_value(p, media_folder=media_folder) for p in parts] -def dict_to_content(content_dict: dict) -> Content: +def convert_dict_to_input(content_dict: dict, media_folder: str) -> Content: """ Convert content dictionary to Vertex AI Content object. @@ -95,5 +115,9 @@ raise KeyError("Parts key is missing in content dictionary") return Content( - role=content_dict["role"], parts=[Part.from_text(content_dict["parts"])] + role=content_dict["role"], + parts=parse_parts( + content_dict["parts"], + media_folder=media_folder, + ), ) diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 9100d74f..60846b3e 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -363,7 +363,7 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): } ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -371,7 +371,6 @@ assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} - assert test_case[5] is None # test for case where no parameters in prompt_dict test_case = await gemini_api._obtain_model_inputs( { "id": 0, "api": "gemini", "model_name": "gemini_model_name", "prompt": "test prompt", } ) @@ -383,7 +382,7 @@ assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -391,7 +390,6 @@ assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) assert test_case[4] == {} - assert test_case[5] is None # test for case where system_instruction is provided test_case = await 
gemini_api._obtain_model_inputs( @@ -404,7 +402,7 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): system_instruction="hello", ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -412,7 +410,6 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert test_case[2]._system_instruction is not None assert isinstance(test_case[3], dict) assert test_case[4] == {} - assert test_case[5] is None # test error catching when parameters are not a dictionary with pytest.raises( @@ -472,7 +469,7 @@ async def test_gemini_obtain_model_inputs_safety_filters( } ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -480,7 +477,6 @@ async def test_gemini_obtain_model_inputs_safety_filters( assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} - assert test_case[5] is None # test error if safety filter is not recognised with pytest.raises( diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index 3629341e..b621c112 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -151,7 +151,6 @@ async def test_gemini_query_history_check_chat_init( ), DEFAULT_SAFETY_SETTINGS, prompt_dict_chat["parameters"], - None, ) # error will be raised as we've mocked the start_chat method diff --git a/tests/apis/gemini/test_gemini_history_input.py b/tests/apis/gemini/test_gemini_history_input.py index cb9a3a6d..5b0f8fe5 100644 --- a/tests/apis/gemini/test_gemini_history_input.py +++ b/tests/apis/gemini/test_gemini_history_input.py @@ -77,7 +77,7 @@ async def test_gemini_query_history( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - content=prompt_dict_history["prompt"][1], + content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, generation_config=prompt_dict_history["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -119,7 +119,7 @@ async def test_gemini_query_history_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - content=prompt_dict_history["prompt"][1], + content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, generation_config=prompt_dict_history["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -160,7 +160,7 @@ async def test_gemini_query_history_index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - content=prompt_dict_history["prompt"][1], + content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, generation_config=prompt_dict_history["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -206,7 +206,6 @@ async def test_gemini_query_history_check_chat_init( ), DEFAULT_SAFETY_SETTINGS, prompt_dict_history["parameters"], - None, ) # error will be raised as we've mocked the start_chat method @@ -263,7 +262,10 @@ async def 
test_gemini_query_history_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - content=prompt_dict_history_no_system["prompt"][2], + content={ + "role": "user", + "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], + }, generation_config=prompt_dict_history_no_system["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -309,7 +311,10 @@ async def test_gemini_query_history_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - content=prompt_dict_history_no_system["prompt"][2], + content={ + "role": "user", + "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], + }, generation_config=prompt_dict_history_no_system["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -356,7 +361,10 @@ async def test_gemini_query_history_index_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - content=prompt_dict_history_no_system["prompt"][2], + content={ + "role": "user", + "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], + }, generation_config=prompt_dict_history_no_system["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -402,7 +410,6 @@ async def test_gemini_query_history_no_system_check_chat_init( ), DEFAULT_SAFETY_SETTINGS, prompt_dict_history_no_system["parameters"], - None, ) # error will be raised as we've mocked the start_chat method @@ -415,7 +422,13 @@ async def test_gemini_query_history_no_system_check_chat_init( ) mock_start_chat.assert_called_once_with( history=[ - prompt_dict_history_no_system["prompt"][0], - prompt_dict_history_no_system["prompt"][1], + { + "role": "user", + "parts": [prompt_dict_history_no_system["prompt"][0]["parts"]], + }, + { + "role": "model", + "parts": [prompt_dict_history_no_system["prompt"][1]["parts"]], + }, ] ) diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index f5237ca5..0554f50c 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -74,7 +74,7 @@ async def test_gemini_query_string( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=[prompt_dict_string["prompt"]], + contents=prompt_dict_string["prompt"], generation_config=prompt_dict_string["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -123,7 +123,7 @@ async def test_gemini_query_string__index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=[prompt_dict_string["prompt"]], + contents=prompt_dict_string["prompt"], generation_config=prompt_dict_string["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, @@ -164,7 +164,7 @@ async def test_gemini_query_string_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=[prompt_dict_string["prompt"]], + contents=prompt_dict_string["prompt"], generation_config=prompt_dict_string["parameters"], safety_settings=DEFAULT_SAFETY_SETTINGS, stream=False, diff --git a/tests/apis/vertexai/test_vertexai.py b/tests/apis/vertexai/test_vertexai.py index 04463551..9ac5e921 100644 --- a/tests/apis/vertexai/test_vertexai.py +++ 
b/tests/apis/vertexai/test_vertexai.py @@ -399,7 +399,7 @@ async def test_vertexai_obtain_model_inputs(temporary_data_folders, monkeypatch) } ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "vertexai_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -407,7 +407,6 @@ async def test_vertexai_obtain_model_inputs(temporary_data_folders, monkeypatch) assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} - assert test_case[5] is None # test for case where no parameters in prompt_dict test_case = await vertexai_api._obtain_model_inputs( @@ -419,7 +418,7 @@ async def test_vertexai_obtain_model_inputs(temporary_data_folders, monkeypatch) }, ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "vertexai_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -427,7 +426,6 @@ async def test_vertexai_obtain_model_inputs(temporary_data_folders, monkeypatch) assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) assert test_case[4] == {} - assert test_case[5] is None # test for case where system_instruction is provided test_case = await vertexai_api._obtain_model_inputs( @@ -440,7 +438,7 @@ async def test_vertexai_obtain_model_inputs(temporary_data_folders, monkeypatch) system_instruction="hello", ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "vertexai_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -448,7 +446,6 @@ async def test_vertexai_obtain_model_inputs(temporary_data_folders, monkeypatch) assert test_case[2]._system_instruction is not None assert isinstance(test_case[3], dict) assert test_case[4] == {} - assert test_case[5] is None # test error catching when parameters are not a dictionary with pytest.raises( @@ -507,7 +504,7 @@ async def test_vertexai_obtain_model_inputs_safety_filters( } ) assert isinstance(test_case, tuple) - assert len(test_case) == 6 + assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "vertexai_model_name" assert isinstance(test_case[2], GenerativeModel) @@ -517,7 +514,6 @@ async def test_vertexai_obtain_model_inputs_safety_filters( assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} - assert test_case[5] is None # test error if safety filter is not recognised with pytest.raises( diff --git a/tests/apis/vertexai/test_vertexai_chat_input.py b/tests/apis/vertexai/test_vertexai_chat_input.py index 5dab0c04..c81ebf8d 100644 --- a/tests/apis/vertexai/test_vertexai_chat_input.py +++ b/tests/apis/vertexai/test_vertexai_chat_input.py @@ -135,7 +135,6 @@ async def test_vertexai_query_history_check_chat_init( ), DEFAULT_SAFETY_SETTINGS, prompt_dict_chat["parameters"], - None, ) # error will be raised as we've mocked the start_chat method diff --git a/tests/apis/vertexai/test_vertexai_history_input.py b/tests/apis/vertexai/test_vertexai_history_input.py index 446f5143..1251a2a9 100644 --- a/tests/apis/vertexai/test_vertexai_history_input.py +++ b/tests/apis/vertexai/test_vertexai_history_input.py @@ -227,7 +227,6 @@ async def 
test_vertexai_query_history_check_chat_init( ), DEFAULT_SAFETY_SETTINGS, prompt_dict_history["parameters"], - None, ) # error will be raised as we've mocked the start_chat method @@ -467,7 +466,6 @@ async def test_vertexai_query_history_no_system_check_chat_init( ), DEFAULT_SAFETY_SETTINGS, prompt_dict_history_no_system["parameters"], - None, ) # error will be raised as we've mocked the start_chat method diff --git a/tests/apis/vertexai/test_vertexai_string_input.py b/tests/apis/vertexai/test_vertexai_string_input.py index 6ba4264d..1266ead6 100644 --- a/tests/apis/vertexai/test_vertexai_string_input.py +++ b/tests/apis/vertexai/test_vertexai_string_input.py @@ -56,25 +56,12 @@ async def test_vertexai_query_string( mock_vertexai_call.assert_called_once() mock_vertexai_call.assert_awaited_once() - await_kwargs = mock_vertexai_call.await_args.kwargs - # assert await arguments - slightly more complicated since "contents" is a - # list of Part objects which are not directly comparable so we convert to dict and compare - assert ( - "contents" in await_kwargs.keys() - and isinstance(await_kwargs["contents"], list) - and len(await_kwargs["contents"]) == 1 - and await_kwargs["contents"][0].to_dict() - == {"text": prompt_dict_string["prompt"]} + mock_vertexai_call.assert_awaited_once_with( + contents=prompt_dict_string["prompt"], + generation_config=prompt_dict_string["parameters"], + safety_settings=DEFAULT_SAFETY_SETTINGS, + stream=False, ) - assert ( - "generation_config" in await_kwargs.keys() - and await_kwargs["generation_config"] == prompt_dict_string["parameters"] - ) - assert ( - "safety_settings" in await_kwargs.keys() - and await_kwargs["safety_settings"] == DEFAULT_SAFETY_SETTINGS - ) - assert "stream" in await_kwargs.keys() and await_kwargs["stream"] is False mock_process_response.assert_called_once_with(mock_vertexai_call.return_value) mock_process_safety_attr.assert_called_once_with(mock_vertexai_call.return_value) @@ -120,25 +107,12 @@ async def test_vertexai_query_string__index_error( mock_vertexai_call.assert_called_once() mock_vertexai_call.assert_awaited_once() - await_kwargs = mock_vertexai_call.await_args.kwargs - # assert await arguments - slightly more complicated since "contents" is a - # list of Part objects which are not directly comparable so we convert to dict and compare - assert ( - "contents" in await_kwargs.keys() - and isinstance(await_kwargs["contents"], list) - and len(await_kwargs["contents"]) == 1 - and await_kwargs["contents"][0].to_dict() - == {"text": prompt_dict_string["prompt"]} - ) - assert ( - "generation_config" in await_kwargs.keys() - and await_kwargs["generation_config"] == prompt_dict_string["parameters"] + mock_vertexai_call.assert_awaited_once_with( + contents=prompt_dict_string["prompt"], + generation_config=prompt_dict_string["parameters"], + safety_settings=DEFAULT_SAFETY_SETTINGS, + stream=False, ) - assert ( - "safety_settings" in await_kwargs.keys() - and await_kwargs["safety_settings"] == DEFAULT_SAFETY_SETTINGS - ) - assert "stream" in await_kwargs.keys() and await_kwargs["stream"] is False expected_log_message = ( f"Error with model VertexAI ({prompt_dict_string['model_name']}) " @@ -176,25 +150,12 @@ async def test_vertexai_query_string_error( mock_vertexai_call.assert_called_once() mock_vertexai_call.assert_awaited_once() - await_kwargs = mock_vertexai_call.await_args.kwargs - # assert await arguments - slightly more complicated since "contents" is a - # list of Part objects which are not directly comparable so we convert to dict and 
compare - assert ( - "contents" in await_kwargs.keys() - and isinstance(await_kwargs["contents"], list) - and len(await_kwargs["contents"]) == 1 - and await_kwargs["contents"][0].to_dict() - == {"text": prompt_dict_string["prompt"]} - ) - assert ( - "generation_config" in await_kwargs.keys() - and await_kwargs["generation_config"] == prompt_dict_string["parameters"] - ) - assert ( - "safety_settings" in await_kwargs.keys() - and await_kwargs["safety_settings"] == DEFAULT_SAFETY_SETTINGS + mock_vertexai_call.assert_awaited_once_with( + contents=prompt_dict_string["prompt"], + generation_config=prompt_dict_string["parameters"], + safety_settings=DEFAULT_SAFETY_SETTINGS, + stream=False, ) - assert "stream" in await_kwargs.keys() and await_kwargs["stream"] is False expected_log_message = ( f"Error with model VertexAI ({prompt_dict_string['model_name']}) "