From 038b7f0afa32fe5dccaa91b3b06c751f71914539 Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 19 Sep 2024 09:49:36 +0000 Subject: [PATCH 1/3] fix: support seed and fix absolute path --- .../coder/factor_coder/CoSTEER/evaluators.py | 1 + rdagent/oai/llm_utils.py | 153 ++++++++---------- 2 files changed, 71 insertions(+), 83 deletions(-) diff --git a/rdagent/components/coder/factor_coder/CoSTEER/evaluators.py b/rdagent/components/coder/factor_coder/CoSTEER/evaluators.py index ccd08b456..ef480b4b1 100644 --- a/rdagent/components/coder/factor_coder/CoSTEER/evaluators.py +++ b/rdagent/components/coder/factor_coder/CoSTEER/evaluators.py @@ -505,6 +505,7 @@ def evaluate( user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True, + seed=attempts, # in case of useless retrying when cache enabled. ), ) final_decision = final_evaluation_dict["final_decision"] diff --git a/rdagent/oai/llm_utils.py b/rdagent/oai/llm_utils.py index 427ff0ab5..b3b99ec65 100644 --- a/rdagent/oai/llm_utils.py +++ b/rdagent/oai/llm_utils.py @@ -13,7 +13,7 @@ import uuid from copy import deepcopy from pathlib import Path -from typing import Any +from typing import Any, Optional import numpy as np import tiktoken @@ -72,7 +72,7 @@ def _rotate_files(self) -> None: n = int(m.group(1)) pairs.append((n, f)) pairs.sort(key=lambda x: x[0]) - for n, f in pairs[: self.recent_n][::-1]: + for n, f in pairs[:self.recent_n][::-1]: if (self.path / f"{n+1}.json").exists(): (self.path / f"{n+1}.json").unlink() f.rename(self.path / f"{n+1}.json") @@ -85,6 +85,7 @@ def append(self, conv: tuple[list, str]) -> None: class SQliteLazyCache(SingletonBaseClass): + def __init__(self, cache_location: str) -> None: super().__init__() self.cache_location = cache_location @@ -99,24 +100,21 @@ def __init__(self, cache_location: str) -> None: md5_key TEXT PRIMARY KEY, chat TEXT ) - """, - ) + """,) self.c.execute( """ CREATE TABLE embedding_cache ( md5_key TEXT PRIMARY KEY, embedding TEXT ) - """, - ) + """,) self.c.execute( """ CREATE TABLE message_cache ( conversation_id TEXT PRIMARY KEY, message TEXT ) - """, - ) + """,) self.conn.commit() def chat_get(self, key: str) -> str | None: @@ -168,6 +166,7 @@ def message_set(self, conversation_id: str, message_value: list[str]) -> None: class SessionChatHistoryCache(SingletonBaseClass): + def __init__(self) -> None: """load all history conversation json file from self.session_cache_location""" self.cache = SQliteLazyCache(cache_location=RD_AGENT_SETTINGS.prompt_cache_path) @@ -180,6 +179,7 @@ def message_set(self, conversation_id: str, message_value: list[str]) -> None: class ChatSession: + def __init__(self, api_backend: Any, conversation_id: str | None = None, system_prompt: str | None = None) -> None: self.conversation_id = str(uuid.uuid4()) if conversation_id is None else conversation_id self.cfg = RD_AGENT_SETTINGS @@ -191,12 +191,10 @@ def build_chat_completion_message(self, user_prompt: str) -> list[dict[str, Any] messages = history_message if not messages: messages.append({"role": "system", "content": self.system_prompt}) - messages.append( - { - "role": "user", - "content": user_prompt, - }, - ) + messages.append({ + "role": "user", + "content": user_prompt, + },) return messages def build_chat_completion_message_and_calculate_token(self, user_prompt: str) -> Any: @@ -217,12 +215,10 @@ def build_chat_completion(self, user_prompt: str, **kwargs: Any) -> str: **kwargs, ) - messages.append( - { - "role": "assistant", - "content": response, - }, - ) + messages.append({ + "role": "assistant", + 
"content": response, + },) SessionChatHistoryCache().message_set(self.conversation_id, messages) return response @@ -235,6 +231,7 @@ def display_history(self) -> None: class APIBackend: + def __init__( # noqa: C901, PLR0912, PLR0915 self, *, @@ -304,18 +301,10 @@ def __init__( # noqa: C901, PLR0912, PLR0915 # Priority: chat_api_key/embedding_api_key > openai_api_key > os.environ.get("OPENAI_API_KEY") # TODO: Simplify the key design. Consider Pandatic's field alias & priority. - self.chat_api_key = ( - chat_api_key - or self.cfg.chat_openai_api_key - or self.cfg.openai_api_key - or os.environ.get("OPENAI_API_KEY") - ) - self.embedding_api_key = ( - embedding_api_key - or self.cfg.embedding_openai_api_key - or self.cfg.openai_api_key - or os.environ.get("OPENAI_API_KEY") - ) + self.chat_api_key = (chat_api_key or self.cfg.chat_openai_api_key or self.cfg.openai_api_key or + os.environ.get("OPENAI_API_KEY")) + self.embedding_api_key = (embedding_api_key or self.cfg.embedding_openai_api_key or + self.cfg.openai_api_key or os.environ.get("OPENAI_API_KEY")) self.chat_model = self.cfg.chat_model if chat_model is None else chat_model self.encoder = tiktoken.encoding_for_model(self.chat_model) @@ -325,12 +314,10 @@ def __init__( # noqa: C901, PLR0912, PLR0915 self.chat_seed = self.cfg.chat_seed self.embedding_model = self.cfg.embedding_model if embedding_model is None else embedding_model - self.embedding_api_base = ( - self.cfg.embedding_azure_api_base if embedding_api_base is None else embedding_api_base - ) - self.embedding_api_version = ( - self.cfg.embedding_azure_api_version if embedding_api_version is None else embedding_api_version - ) + self.embedding_api_base = (self.cfg.embedding_azure_api_base + if embedding_api_base is None else embedding_api_base) + self.embedding_api_version = (self.cfg.embedding_azure_api_version + if embedding_api_version is None else embedding_api_version) if self.use_azure: if self.use_azure_token_provider: @@ -369,9 +356,8 @@ def __init__( # noqa: C901, PLR0912, PLR0915 self.dump_chat_cache = self.cfg.dump_chat_cache if dump_chat_cache is None else dump_chat_cache self.use_chat_cache = self.cfg.use_chat_cache if use_chat_cache is None else use_chat_cache - self.dump_embedding_cache = ( - self.cfg.dump_embedding_cache if dump_embedding_cache is None else dump_embedding_cache - ) + self.dump_embedding_cache = (self.cfg.dump_embedding_cache + if dump_embedding_cache is None else dump_embedding_cache) self.use_embedding_cache = self.cfg.use_embedding_cache if use_embedding_cache is None else use_embedding_cache if self.dump_chat_cache or self.use_chat_cache or self.dump_embedding_cache or self.use_embedding_cache: self.cache_file_location = self.cfg.prompt_cache_path @@ -401,7 +387,10 @@ def build_messages( *, shrink_multiple_break: bool = False, ) -> list[dict]: - """build the messages to avoid implementing several redundant lines of code""" + """ + build the messages to avoid implementing several redundant lines of code + + """ if former_messages is None: former_messages = [] # shrink multiple break will recursively remove multiple breaks(more than 2) @@ -418,13 +407,11 @@ def build_messages( "content": system_prompt, }, ] - messages.extend(former_messages[-1 * self.cfg.max_past_message_include :]) - messages.append( - { - "role": "user", - "content": user_prompt, - }, - ) + messages.extend(former_messages[-1 * self.cfg.max_past_message_include:]) + messages.append({ + "role": "user", + "content": user_prompt, + },) return messages def 
build_messages_and_create_chat_completion( @@ -440,7 +427,10 @@ def build_messages_and_create_chat_completion( if former_messages is None: former_messages = [] messages = self.build_messages( - user_prompt, system_prompt, former_messages, shrink_multiple_break=shrink_multiple_break + user_prompt, + system_prompt, + former_messages, + shrink_multiple_break=shrink_multiple_break, ) return self._try_create_chat_completion_or_embedding( messages=messages, @@ -470,12 +460,10 @@ def _create_chat_completion_auto_continue(self, messages: list, **kwargs: dict) if finish_reason == "length": new_message = deepcopy(messages) new_message.append({"role": "assistant", "content": response}) - new_message.append( - { - "role": "user", - "content": "continue the former output with no overlap", - }, - ) + new_message.append({ + "role": "user", + "content": "continue the former output with no overlap", + },) new_response, finish_reason = self._create_chat_completion_inner_function(messages=new_message, **kwargs) return response + new_response return response @@ -503,7 +491,7 @@ def _try_create_chat_completion_or_embedding( kwargs["add_json_in_prompt"] = True elif embedding and "maximum context length" in e.message: kwargs["input_content_list"] = [ - content[: len(content) // 2] for content in kwargs.get("input_content_list", []) + content[:len(content) // 2] for content in kwargs.get("input_content_list", []) ] except Exception as e: # noqa: BLE001 logger.warning(e) @@ -512,9 +500,8 @@ def _try_create_chat_completion_or_embedding( error_message = f"Failed to create chat completion after {max_retry} retries." raise RuntimeError(error_message) - def _create_embedding_inner_function( - self, input_content_list: list[str], **kwargs: Any - ) -> list[Any]: # noqa: ARG002 + def _create_embedding_inner_function(self, input_content_list: list[str], + **kwargs: Any) -> list[Any]: # noqa: ARG002 content_to_embedding_dict = {} filtered_input_content_list = [] if self.use_embedding_cache: @@ -548,12 +535,10 @@ def _create_embedding_inner_function( def _build_log_messages(self, messages: list[dict]) -> str: log_messages = "" for m in messages: - log_messages += ( - f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}" - f"{LogColors.CYAN}{m['role']}{LogColors.END}\n" - f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} " - f"{LogColors.CYAN}{m['content']}{LogColors.END}\n" - ) + log_messages += (f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}" + f"{LogColors.CYAN}{m['role']}{LogColors.END}\n" + f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} " + f"{LogColors.CYAN}{m['content']}{LogColors.END}\n") return log_messages def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 @@ -567,15 +552,21 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 *, json_mode: bool = False, add_json_in_prompt: bool = False, + seed: Optional[int] = None, ) -> str: + """ + seed : Optional[int] + When retrying with cache enabled, it will keep returning the same results. + To make retries useful, we need to enable a seed. + This seed is different from `self.chat_seed` for GPT. It is for the local cache mechanism enabled by RD-Agent locally. 
+ """ # TODO: we can add this function back to avoid so much `self.cfg.log_llm_chat_content` if self.cfg.log_llm_chat_content: logger.info(self._build_log_messages(messages), tag="llm_messages") # TODO: fail to use loguru adaptor due to stream response input_content_json = json.dumps(messages) - input_content_json = ( - chat_cache_prefix + input_content_json - ) # FIXME this is a hack to make sure the cache represents the round index + input_content_json = (chat_cache_prefix + input_content_json + f"" + ) # FIXME this is a hack to make sure the cache represents the round index if self.use_chat_cache: cache_result = self.cache.chat_get(input_content_json) if cache_result is not None: @@ -615,9 +606,7 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 "max_new_tokens": self.gcr_endpoint_max_token, }, }, - }, - ), - ) + },),) req = urllib.request.Request(self.gcr_endpoint, body, self.headers) # noqa: S310 response = urllib.request.urlopen(req) # noqa: S310 @@ -651,11 +640,8 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 logger.info(f"{LogColors.CYAN}Response:{LogColors.END}", tag="llm_messages") for chunk in response: - content = ( - chunk.choices[0].delta.content - if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None - else "" - ) + content = (chunk.choices[0].delta.content + if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None else "") if self.cfg.log_llm_chat_content: logger.info(LogColors.CYAN + content + LogColors.END, raw=True, tag="llm_messages") resp += content @@ -707,9 +693,10 @@ def build_messages_and_calculate_token( ) -> int: if former_messages is None: former_messages = [] - messages = self.build_messages( - user_prompt, system_prompt, former_messages, shrink_multiple_break=shrink_multiple_break - ) + messages = self.build_messages(user_prompt, + system_prompt, + former_messages, + shrink_multiple_break=shrink_multiple_break) return self.calculate_token_from_messages(messages) @@ -722,7 +709,7 @@ def create_embedding_with_multiprocessing(str_list: list, slice_count: int = 50, pool = multiprocessing.Pool(nproc) result_list = [ - pool.apply_async(calculate_embedding_process, (str_list[index : index + slice_count],)) + pool.apply_async(calculate_embedding_process, (str_list[index:index + slice_count],)) for index in range(0, len(str_list), slice_count) ] pool.close() @@ -741,8 +728,8 @@ def calculate_embedding_distance_between_str_list( return [[]] embeddings = create_embedding_with_multiprocessing(source_str_list + target_str_list, slice_count=50, nproc=8) - source_embeddings = embeddings[: len(source_str_list)] - target_embeddings = embeddings[len(source_str_list) :] + source_embeddings = embeddings[:len(source_str_list)] + target_embeddings = embeddings[len(source_str_list):] source_embeddings_np = np.array(source_embeddings) target_embeddings_np = np.array(target_embeddings) From c8800304c9dda78c27f90b2408845964257dd25e Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 19 Sep 2024 09:52:21 +0000 Subject: [PATCH 2/3] Absolute path --- rdagent/components/coder/factor_coder/factor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rdagent/components/coder/factor_coder/factor.py b/rdagent/components/coder/factor_coder/factor.py index f794e4c1f..95c0b53d3 100644 --- a/rdagent/components/coder/factor_coder/factor.py +++ b/rdagent/components/coder/factor_coder/factor.py @@ -89,7 +89,7 @@ def __init__( @staticmethod def link_data_to_workspace(data_path: Path, 
workspace_path: Path): - data_path = Path(data_path) + data_path = Path(data_path).absolute() # in case of relative path that will be invalid when we change cwd. workspace_path = Path(workspace_path) for data_file_path in data_path.iterdir(): workspace_data_file_path = workspace_path / data_file_path.name From 1435e501f03244d2835f5ef026da650d76fe0216 Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 19 Sep 2024 13:59:10 +0000 Subject: [PATCH 3/3] lint --- rdagent/oai/llm_utils.py | 136 +++++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 55 deletions(-) diff --git a/rdagent/oai/llm_utils.py b/rdagent/oai/llm_utils.py index b3b99ec65..8cf410b61 100644 --- a/rdagent/oai/llm_utils.py +++ b/rdagent/oai/llm_utils.py @@ -72,7 +72,7 @@ def _rotate_files(self) -> None: n = int(m.group(1)) pairs.append((n, f)) pairs.sort(key=lambda x: x[0]) - for n, f in pairs[:self.recent_n][::-1]: + for n, f in pairs[: self.recent_n][::-1]: if (self.path / f"{n+1}.json").exists(): (self.path / f"{n+1}.json").unlink() f.rename(self.path / f"{n+1}.json") @@ -85,7 +85,6 @@ def append(self, conv: tuple[list, str]) -> None: class SQliteLazyCache(SingletonBaseClass): - def __init__(self, cache_location: str) -> None: super().__init__() self.cache_location = cache_location @@ -100,21 +99,24 @@ def __init__(self, cache_location: str) -> None: md5_key TEXT PRIMARY KEY, chat TEXT ) - """,) + """, + ) self.c.execute( """ CREATE TABLE embedding_cache ( md5_key TEXT PRIMARY KEY, embedding TEXT ) - """,) + """, + ) self.c.execute( """ CREATE TABLE message_cache ( conversation_id TEXT PRIMARY KEY, message TEXT ) - """,) + """, + ) self.conn.commit() def chat_get(self, key: str) -> str | None: @@ -166,7 +168,6 @@ def message_set(self, conversation_id: str, message_value: list[str]) -> None: class SessionChatHistoryCache(SingletonBaseClass): - def __init__(self) -> None: """load all history conversation json file from self.session_cache_location""" self.cache = SQliteLazyCache(cache_location=RD_AGENT_SETTINGS.prompt_cache_path) @@ -179,7 +180,6 @@ def message_set(self, conversation_id: str, message_value: list[str]) -> None: class ChatSession: - def __init__(self, api_backend: Any, conversation_id: str | None = None, system_prompt: str | None = None) -> None: self.conversation_id = str(uuid.uuid4()) if conversation_id is None else conversation_id self.cfg = RD_AGENT_SETTINGS @@ -191,10 +191,12 @@ def build_chat_completion_message(self, user_prompt: str) -> list[dict[str, Any] messages = history_message if not messages: messages.append({"role": "system", "content": self.system_prompt}) - messages.append({ - "role": "user", - "content": user_prompt, - },) + messages.append( + { + "role": "user", + "content": user_prompt, + }, + ) return messages def build_chat_completion_message_and_calculate_token(self, user_prompt: str) -> Any: @@ -215,10 +217,12 @@ def build_chat_completion(self, user_prompt: str, **kwargs: Any) -> str: **kwargs, ) - messages.append({ - "role": "assistant", - "content": response, - },) + messages.append( + { + "role": "assistant", + "content": response, + }, + ) SessionChatHistoryCache().message_set(self.conversation_id, messages) return response @@ -231,7 +235,6 @@ def display_history(self) -> None: class APIBackend: - def __init__( # noqa: C901, PLR0912, PLR0915 self, *, @@ -301,10 +304,18 @@ def __init__( # noqa: C901, PLR0912, PLR0915 # Priority: chat_api_key/embedding_api_key > openai_api_key > os.environ.get("OPENAI_API_KEY") # TODO: Simplify the key design. 
Consider Pandatic's field alias & priority. - self.chat_api_key = (chat_api_key or self.cfg.chat_openai_api_key or self.cfg.openai_api_key or - os.environ.get("OPENAI_API_KEY")) - self.embedding_api_key = (embedding_api_key or self.cfg.embedding_openai_api_key or - self.cfg.openai_api_key or os.environ.get("OPENAI_API_KEY")) + self.chat_api_key = ( + chat_api_key + or self.cfg.chat_openai_api_key + or self.cfg.openai_api_key + or os.environ.get("OPENAI_API_KEY") + ) + self.embedding_api_key = ( + embedding_api_key + or self.cfg.embedding_openai_api_key + or self.cfg.openai_api_key + or os.environ.get("OPENAI_API_KEY") + ) self.chat_model = self.cfg.chat_model if chat_model is None else chat_model self.encoder = tiktoken.encoding_for_model(self.chat_model) @@ -314,10 +325,12 @@ def __init__( # noqa: C901, PLR0912, PLR0915 self.chat_seed = self.cfg.chat_seed self.embedding_model = self.cfg.embedding_model if embedding_model is None else embedding_model - self.embedding_api_base = (self.cfg.embedding_azure_api_base - if embedding_api_base is None else embedding_api_base) - self.embedding_api_version = (self.cfg.embedding_azure_api_version - if embedding_api_version is None else embedding_api_version) + self.embedding_api_base = ( + self.cfg.embedding_azure_api_base if embedding_api_base is None else embedding_api_base + ) + self.embedding_api_version = ( + self.cfg.embedding_azure_api_version if embedding_api_version is None else embedding_api_version + ) if self.use_azure: if self.use_azure_token_provider: @@ -356,8 +369,9 @@ def __init__( # noqa: C901, PLR0912, PLR0915 self.dump_chat_cache = self.cfg.dump_chat_cache if dump_chat_cache is None else dump_chat_cache self.use_chat_cache = self.cfg.use_chat_cache if use_chat_cache is None else use_chat_cache - self.dump_embedding_cache = (self.cfg.dump_embedding_cache - if dump_embedding_cache is None else dump_embedding_cache) + self.dump_embedding_cache = ( + self.cfg.dump_embedding_cache if dump_embedding_cache is None else dump_embedding_cache + ) self.use_embedding_cache = self.cfg.use_embedding_cache if use_embedding_cache is None else use_embedding_cache if self.dump_chat_cache or self.use_chat_cache or self.dump_embedding_cache or self.use_embedding_cache: self.cache_file_location = self.cfg.prompt_cache_path @@ -389,7 +403,7 @@ def build_messages( ) -> list[dict]: """ build the messages to avoid implementing several redundant lines of code - + """ if former_messages is None: former_messages = [] @@ -407,11 +421,13 @@ def build_messages( "content": system_prompt, }, ] - messages.extend(former_messages[-1 * self.cfg.max_past_message_include:]) - messages.append({ - "role": "user", - "content": user_prompt, - },) + messages.extend(former_messages[-1 * self.cfg.max_past_message_include :]) + messages.append( + { + "role": "user", + "content": user_prompt, + }, + ) return messages def build_messages_and_create_chat_completion( @@ -460,10 +476,12 @@ def _create_chat_completion_auto_continue(self, messages: list, **kwargs: dict) if finish_reason == "length": new_message = deepcopy(messages) new_message.append({"role": "assistant", "content": response}) - new_message.append({ - "role": "user", - "content": "continue the former output with no overlap", - },) + new_message.append( + { + "role": "user", + "content": "continue the former output with no overlap", + }, + ) new_response, finish_reason = self._create_chat_completion_inner_function(messages=new_message, **kwargs) return response + new_response return response @@ -491,7 +509,7 @@ def 
_try_create_chat_completion_or_embedding( kwargs["add_json_in_prompt"] = True elif embedding and "maximum context length" in e.message: kwargs["input_content_list"] = [ - content[:len(content) // 2] for content in kwargs.get("input_content_list", []) + content[: len(content) // 2] for content in kwargs.get("input_content_list", []) ] except Exception as e: # noqa: BLE001 logger.warning(e) @@ -500,8 +518,9 @@ def _try_create_chat_completion_or_embedding( error_message = f"Failed to create chat completion after {max_retry} retries." raise RuntimeError(error_message) - def _create_embedding_inner_function(self, input_content_list: list[str], - **kwargs: Any) -> list[Any]: # noqa: ARG002 + def _create_embedding_inner_function( + self, input_content_list: list[str], **kwargs: Any + ) -> list[Any]: # noqa: ARG002 content_to_embedding_dict = {} filtered_input_content_list = [] if self.use_embedding_cache: @@ -535,10 +554,12 @@ def _create_embedding_inner_function(self, input_content_list: list[str], def _build_log_messages(self, messages: list[dict]) -> str: log_messages = "" for m in messages: - log_messages += (f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}" - f"{LogColors.CYAN}{m['role']}{LogColors.END}\n" - f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} " - f"{LogColors.CYAN}{m['content']}{LogColors.END}\n") + log_messages += ( + f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}" + f"{LogColors.CYAN}{m['role']}{LogColors.END}\n" + f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} " + f"{LogColors.CYAN}{m['content']}{LogColors.END}\n" + ) return log_messages def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 @@ -565,8 +586,9 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 logger.info(self._build_log_messages(messages), tag="llm_messages") # TODO: fail to use loguru adaptor due to stream response input_content_json = json.dumps(messages) - input_content_json = (chat_cache_prefix + input_content_json + f"" - ) # FIXME this is a hack to make sure the cache represents the round index + input_content_json = ( + chat_cache_prefix + input_content_json + f"" + ) # FIXME this is a hack to make sure the cache represents the round index if self.use_chat_cache: cache_result = self.cache.chat_get(input_content_json) if cache_result is not None: @@ -606,7 +628,9 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 "max_new_tokens": self.gcr_endpoint_max_token, }, }, - },),) + }, + ), + ) req = urllib.request.Request(self.gcr_endpoint, body, self.headers) # noqa: S310 response = urllib.request.urlopen(req) # noqa: S310 @@ -640,8 +664,11 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915 logger.info(f"{LogColors.CYAN}Response:{LogColors.END}", tag="llm_messages") for chunk in response: - content = (chunk.choices[0].delta.content - if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None else "") + content = ( + chunk.choices[0].delta.content + if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None + else "" + ) if self.cfg.log_llm_chat_content: logger.info(LogColors.CYAN + content + LogColors.END, raw=True, tag="llm_messages") resp += content @@ -693,10 +720,9 @@ def build_messages_and_calculate_token( ) -> int: if former_messages is None: former_messages = [] - messages = self.build_messages(user_prompt, - system_prompt, - former_messages, - shrink_multiple_break=shrink_multiple_break) + messages = self.build_messages( + 
user_prompt, system_prompt, former_messages, shrink_multiple_break=shrink_multiple_break + ) return self.calculate_token_from_messages(messages) @@ -709,7 +735,7 @@ def create_embedding_with_multiprocessing(str_list: list, slice_count: int = 50, pool = multiprocessing.Pool(nproc) result_list = [ - pool.apply_async(calculate_embedding_process, (str_list[index:index + slice_count],)) + pool.apply_async(calculate_embedding_process, (str_list[index : index + slice_count],)) for index in range(0, len(str_list), slice_count) ] pool.close() @@ -728,8 +754,8 @@ def calculate_embedding_distance_between_str_list( return [[]] embeddings = create_embedding_with_multiprocessing(source_str_list + target_str_list, slice_count=50, nproc=8) - source_embeddings = embeddings[:len(source_str_list)] - target_embeddings = embeddings[len(source_str_list):] + source_embeddings = embeddings[: len(source_str_list)] + target_embeddings = embeddings[len(source_str_list) :] source_embeddings_np = np.array(source_embeddings) target_embeddings_np = np.array(target_embeddings)
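
A note on how the two fixes in this series behave in practice.

Seed-aware prompt cache (patch 1): with use_chat_cache enabled, every retry in the factor evaluator hashed to the same cache entry, so a rejected answer was simply replayed on each attempt. Passing seed=attempts gives each retry its own cache slot while leaving the single-attempt behaviour unchanged. The sketch below only illustrates the idea: TinyChatCache, cache_key, fake_llm and chat_with_cache are hypothetical stand-ins, not the real SQliteLazyCache / APIBackend API, and the exact marker the patch folds into the cache key is not visible in this diff, so the "<seed=...>" suffix here is an assumption.

from __future__ import annotations

import hashlib
import json
import sqlite3


class TinyChatCache:
    """Minimal SQLite-backed prompt cache keyed by an MD5 of the request."""

    def __init__(self, path: str = ":memory:") -> None:
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS chat_cache (md5_key TEXT PRIMARY KEY, chat TEXT)"
        )

    def get(self, key: str) -> str | None:
        row = self.conn.execute(
            "SELECT chat FROM chat_cache WHERE md5_key = ?", (key,)
        ).fetchone()
        return None if row is None else row[0]

    def set(self, key: str, value: str) -> None:
        self.conn.execute(
            "INSERT OR REPLACE INTO chat_cache (md5_key, chat) VALUES (?, ?)", (key, value)
        )
        self.conn.commit()


def cache_key(messages: list[dict], seed: int | None = None) -> str:
    # Folding the seed into the hashed payload gives each attempt its own cache
    # slot; seed=None keeps the original single-slot behaviour.
    payload = json.dumps(messages) + (f"<seed={seed}>" if seed is not None else "")
    return hashlib.md5(payload.encode()).hexdigest()


def fake_llm(messages: list[dict], attempt: int) -> str:
    # Stand-in for the real chat-completion call; varies per attempt so the
    # effect of the seed is visible.
    return f"answer for attempt {attempt}"


def chat_with_cache(messages: list[dict], cache: TinyChatCache, attempt: int) -> str:
    key = cache_key(messages, seed=attempt)
    hit = cache.get(key)
    if hit is not None:
        return hit
    response = fake_llm(messages, attempt)
    cache.set(key, response)
    return response


if __name__ == "__main__":
    cache = TinyChatCache()
    msgs = [{"role": "user", "content": "evaluate this factor"}]
    # Without the per-attempt seed, all three calls would replay the first
    # cached answer; with it, each retry can produce a fresh response.
    for attempt in range(3):
        print(chat_with_cache(msgs, cache, attempt))

Absolute-path fix (patch 2): link_data_to_workspace previously kept data_path relative, so the files linked into the workspace became invalid once the process changed its working directory. A simplified version of the fixed helper might look like the following; only the signature, the absolute() call and the iteration are taken from the diff, while the mkdir and symlink_to calls are assumptions about how the rest of the helper materialises files.

from pathlib import Path


def link_data_to_workspace(data_path: Path, workspace_path: Path) -> None:
    # Resolve to an absolute path up front so a later chdir cannot invalidate
    # the source location of the linked files.
    data_path = Path(data_path).absolute()
    workspace_path = Path(workspace_path)
    workspace_path.mkdir(parents=True, exist_ok=True)
    for data_file_path in data_path.iterdir():
        workspace_data_file_path = workspace_path / data_file_path.name
        if not workspace_data_file_path.exists():
            workspace_data_file_path.symlink_to(data_file_path)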