fix: support seed and fix absolute path #278

Merged
merged 3 commits on Sep 19, 2024
Changes from 1 commit
lint
you-n-g committed Sep 19, 2024
commit 1435e501f03244d2835f5ef026da650d76fe0216
136 changes: 81 additions & 55 deletions rdagent/oai/llm_utils.py
@@ -72,7 +72,7 @@ def _rotate_files(self) -> None:
n = int(m.group(1))
pairs.append((n, f))
pairs.sort(key=lambda x: x[0])
for n, f in pairs[:self.recent_n][::-1]:
for n, f in pairs[: self.recent_n][::-1]:
if (self.path / f"{n+1}.json").exists():
(self.path / f"{n+1}.json").unlink()
f.rename(self.path / f"{n+1}.json")
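For context on this hunk: the slice being re-spaced sits inside the cache-file rotation, which renames existing 0.json, 1.json, … files one index higher, iterating from the highest kept index down so nothing is overwritten. A minimal standalone sketch of that pattern (illustrative code, not from the repository; recent_n and the directory layout are assumptions):

import re
from pathlib import Path

def rotate_json_files(path: Path, recent_n: int) -> None:
    # Collect (index, file) pairs for files named "0.json", "1.json", ...
    pairs = []
    for f in path.iterdir():
        m = re.match(r"(\d+)\.json$", f.name)
        if m:
            pairs.append((int(m.group(1)), f))
    pairs.sort(key=lambda x: x[0])
    # Rename from the highest kept index down to 0 so no target is clobbered.
    for n, f in pairs[:recent_n][::-1]:
        target = path / f"{n + 1}.json"
        if target.exists():
            target.unlink()
        f.rename(target)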
@@ -85,7 +85,6 @@ def append(self, conv: tuple[list, str]) -> None:


class SQliteLazyCache(SingletonBaseClass):

def __init__(self, cache_location: str) -> None:
super().__init__()
self.cache_location = cache_location
@@ -100,21 +99,24 @@ def __init__(self, cache_location: str) -> None:
md5_key TEXT PRIMARY KEY,
chat TEXT
)
""",)
""",
)
self.c.execute(
"""
CREATE TABLE embedding_cache (
md5_key TEXT PRIMARY KEY,
embedding TEXT
)
""",)
""",
)
self.c.execute(
"""
CREATE TABLE message_cache (
conversation_id TEXT PRIMARY KEY,
message TEXT
)
""",)
""",
)
self.conn.commit()

def chat_get(self, key: str) -> str | None:
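For context: SQliteLazyCache is a singleton wrapper around one SQLite file with three key-value tables (chat_cache, embedding_cache, message_cache). The chat_get signature closing this hunk follows the usual MD5-keyed lookup pattern; a rough standard-library sketch of that pattern (illustrative code, not from the repository — the database path and exact SQL are assumptions):

import hashlib
import sqlite3

conn = sqlite3.connect("prompt_cache.db")  # illustrative location
c = conn.cursor()
c.execute("CREATE TABLE IF NOT EXISTS chat_cache (md5_key TEXT PRIMARY KEY, chat TEXT)")

def chat_set(key: str, value: str) -> None:
    md5_key = hashlib.md5(key.encode()).hexdigest()
    c.execute("INSERT OR REPLACE INTO chat_cache (md5_key, chat) VALUES (?, ?)", (md5_key, value))
    conn.commit()

def chat_get(key: str) -> str | None:
    md5_key = hashlib.md5(key.encode()).hexdigest()
    row = c.execute("SELECT chat FROM chat_cache WHERE md5_key = ?", (md5_key,)).fetchone()
    return None if row is None else row[0]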
@@ -166,7 +168,6 @@ def message_set(self, conversation_id: str, message_value: list[str]) -> None:


class SessionChatHistoryCache(SingletonBaseClass):

def __init__(self) -> None:
"""load all history conversation json file from self.session_cache_location"""
self.cache = SQliteLazyCache(cache_location=RD_AGENT_SETTINGS.prompt_cache_path)
@@ -179,7 +180,6 @@ def message_set(self, conversation_id: str, message_value: list[str]) -> None:


class ChatSession:

def __init__(self, api_backend: Any, conversation_id: str | None = None, system_prompt: str | None = None) -> None:
self.conversation_id = str(uuid.uuid4()) if conversation_id is None else conversation_id
self.cfg = RD_AGENT_SETTINGS
@@ -191,10 +191,12 @@ def build_chat_completion_message(self, user_prompt: str) -> list[dict[str, Any]
messages = history_message
if not messages:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({
"role": "user",
"content": user_prompt,
},)
messages.append(
{
"role": "user",
"content": user_prompt,
},
)
return messages

def build_chat_completion_message_and_calculate_token(self, user_prompt: str) -> Any:
@@ -215,10 +217,12 @@ def build_chat_completion(self, user_prompt: str, **kwargs: Any) -> str:
**kwargs,
)

messages.append({
"role": "assistant",
"content": response,
},)
messages.append(
{
"role": "assistant",
"content": response,
},
)
SessionChatHistoryCache().message_set(self.conversation_id, messages)
return response

@@ -231,7 +235,6 @@ def display_history(self) -> None:


class APIBackend:

def __init__( # noqa: C901, PLR0912, PLR0915
self,
*,
@@ -301,10 +304,18 @@ def __init__( # noqa: C901, PLR0912, PLR0915

# Priority: chat_api_key/embedding_api_key > openai_api_key > os.environ.get("OPENAI_API_KEY")
# TODO: Simplify the key design. Consider Pandatic's field alias & priority.
self.chat_api_key = (chat_api_key or self.cfg.chat_openai_api_key or self.cfg.openai_api_key or
os.environ.get("OPENAI_API_KEY"))
self.embedding_api_key = (embedding_api_key or self.cfg.embedding_openai_api_key or
self.cfg.openai_api_key or os.environ.get("OPENAI_API_KEY"))
self.chat_api_key = (
chat_api_key
or self.cfg.chat_openai_api_key
or self.cfg.openai_api_key
or os.environ.get("OPENAI_API_KEY")
)
self.embedding_api_key = (
embedding_api_key
or self.cfg.embedding_openai_api_key
or self.cfg.openai_api_key
or os.environ.get("OPENAI_API_KEY")
)

self.chat_model = self.cfg.chat_model if chat_model is None else chat_model
self.encoder = tiktoken.encoding_for_model(self.chat_model)
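For reference, the reflowed or-chains implement the priority spelled out in the comment above: explicit argument, then the chat/embedding-specific setting, then the generic OpenAI key, then the OPENAI_API_KEY environment variable. A toy evaluation with made-up values (illustrative, not project code):

import os

chat_api_key_arg = None                      # explicit argument (highest priority)
cfg_chat_openai_api_key = None               # chat-specific config
cfg_openai_api_key = "sk-config-level-key"   # generic config (illustrative value)

resolved = (
    chat_api_key_arg
    or cfg_chat_openai_api_key
    or cfg_openai_api_key
    or os.environ.get("OPENAI_API_KEY")      # environment fallback
)
print(resolved)  # -> "sk-config-level-key"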
@@ -314,10 +325,12 @@ def __init__( # noqa: C901, PLR0912, PLR0915
self.chat_seed = self.cfg.chat_seed

self.embedding_model = self.cfg.embedding_model if embedding_model is None else embedding_model
self.embedding_api_base = (self.cfg.embedding_azure_api_base
if embedding_api_base is None else embedding_api_base)
self.embedding_api_version = (self.cfg.embedding_azure_api_version
if embedding_api_version is None else embedding_api_version)
self.embedding_api_base = (
self.cfg.embedding_azure_api_base if embedding_api_base is None else embedding_api_base
)
self.embedding_api_version = (
self.cfg.embedding_azure_api_version if embedding_api_version is None else embedding_api_version
)

if self.use_azure:
if self.use_azure_token_provider:
@@ -356,8 +369,9 @@ def __init__( # noqa: C901, PLR0912, PLR0915

self.dump_chat_cache = self.cfg.dump_chat_cache if dump_chat_cache is None else dump_chat_cache
self.use_chat_cache = self.cfg.use_chat_cache if use_chat_cache is None else use_chat_cache
self.dump_embedding_cache = (self.cfg.dump_embedding_cache
if dump_embedding_cache is None else dump_embedding_cache)
self.dump_embedding_cache = (
self.cfg.dump_embedding_cache if dump_embedding_cache is None else dump_embedding_cache
)
self.use_embedding_cache = self.cfg.use_embedding_cache if use_embedding_cache is None else use_embedding_cache
if self.dump_chat_cache or self.use_chat_cache or self.dump_embedding_cache or self.use_embedding_cache:
self.cache_file_location = self.cfg.prompt_cache_path
@@ -389,7 +403,7 @@ def build_messages(
) -> list[dict]:
"""
build the messages to avoid implementing several redundant lines of code

"""
if former_messages is None:
former_messages = []
@@ -407,11 +421,13 @@ def build_messages(
"content": system_prompt,
},
]
messages.extend(former_messages[-1 * self.cfg.max_past_message_include:])
messages.append({
"role": "user",
"content": user_prompt,
},)
messages.extend(former_messages[-1 * self.cfg.max_past_message_include :])
messages.append(
{
"role": "user",
"content": user_prompt,
},
)
return messages

def build_messages_and_create_chat_completion(
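The reflowed build_messages assembles a system message, the most recent max_past_message_include turns of history, and the new user prompt. Roughly, the resulting structure looks like this (made-up values, not project output):

system_prompt = "You are a helpful assistant."
former_messages = [
    {"role": "user", "content": "earlier question"},
    {"role": "assistant", "content": "earlier answer"},
]
max_past_message_include = 10  # illustrative config value

messages = [{"role": "system", "content": system_prompt}]
messages.extend(former_messages[-max_past_message_include:])
messages.append({"role": "user", "content": "new question"})
# -> [system message, ...recent history..., new user message]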
@@ -460,10 +476,12 @@ def _create_chat_completion_auto_continue(self, messages: list, **kwargs: dict)
if finish_reason == "length":
new_message = deepcopy(messages)
new_message.append({"role": "assistant", "content": response})
new_message.append({
"role": "user",
"content": "continue the former output with no overlap",
},)
new_message.append(
{
"role": "user",
"content": "continue the former output with no overlap",
},
)
new_response, finish_reason = self._create_chat_completion_inner_function(messages=new_message, **kwargs)
return response + new_response
return response
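This hunk is the auto-continue path: when the model stops with finish_reason == "length", the truncated answer is appended as an assistant turn, a follow-up user turn asks for the remainder, and the two responses are concatenated. A minimal runnable sketch of the pattern; complete() is a hypothetical stand-in for the real _create_chat_completion_inner_function call:

from copy import deepcopy

def complete(messages: list[dict]) -> tuple[str, str]:
    # Stand-in backend: pretends the first reply was cut off at the token limit.
    if messages[-1]["content"].startswith("continue"):
        return " world", "stop"
    return "hello", "length"

def auto_continue(messages: list[dict]) -> str:
    response, finish_reason = complete(messages)
    if finish_reason == "length":  # reply truncated -> ask for the remainder
        follow_up = deepcopy(messages)
        follow_up.append({"role": "assistant", "content": response})
        follow_up.append({"role": "user", "content": "continue the former output with no overlap"})
        new_response, _ = complete(follow_up)
        return response + new_response
    return response

print(auto_continue([{"role": "user", "content": "say hello world"}]))  # -> "hello world"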
@@ -491,7 +509,7 @@ def _try_create_chat_completion_or_embedding(
kwargs["add_json_in_prompt"] = True
elif embedding and "maximum context length" in e.message:
kwargs["input_content_list"] = [
content[:len(content) // 2] for content in kwargs.get("input_content_list", [])
content[: len(content) // 2] for content in kwargs.get("input_content_list", [])
]
except Exception as e: # noqa: BLE001
logger.warning(e)
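In this retry loop, an embedding call that fails with a "maximum context length" error is retried with every input cut to its first half. The effect of the reformatted slice in isolation (illustrative inputs, not project data):

# Illustrative stand-ins for kwargs["input_content_list"]
input_content_list = ["a" * 10_000, "b" * 8_000]

# Keep only the first half of each string before the next attempt.
input_content_list = [content[: len(content) // 2] for content in input_content_list]
print([len(c) for c in input_content_list])  # [5000, 4000]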
@@ -500,8 +518,9 @@ def _create_embedding_inner_function(self, input_content_list: list[str],
error_message = f"Failed to create chat completion after {max_retry} retries."
raise RuntimeError(error_message)

def _create_embedding_inner_function(self, input_content_list: list[str],
**kwargs: Any) -> list[Any]: # noqa: ARG002
def _create_embedding_inner_function(
self, input_content_list: list[str], **kwargs: Any
) -> list[Any]: # noqa: ARG002
content_to_embedding_dict = {}
filtered_input_content_list = []
if self.use_embedding_cache:
@@ -535,10 +554,12 @@ def _create_embedding_inner_function(self, input_content_list: list[str],
def _build_log_messages(self, messages: list[dict]) -> str:
log_messages = ""
for m in messages:
log_messages += (f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}"
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n")
log_messages += (
f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}"
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n"
)
return log_messages

def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915
@@ -565,8 +586,9 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915
logger.info(self._build_log_messages(messages), tag="llm_messages")
# TODO: fail to use loguru adaptor due to stream response
input_content_json = json.dumps(messages)
input_content_json = (chat_cache_prefix + input_content_json + f"<seed={seed}/>"
) # FIXME this is a hack to make sure the cache represents the round index
input_content_json = (
chat_cache_prefix + input_content_json + f"<seed={seed}/>"
) # FIXME this is a hack to make sure the cache represents the round index
if self.use_chat_cache:
cache_result = self.cache.chat_get(input_content_json)
if cache_result is not None:
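This is the hunk most directly tied to the PR title: the seed is appended to the serialized messages before the cache lookup, so runs with different seeds (or different round indices, per the FIXME) do not collide on the same cache entry. A toy illustration of why the suffix changes the key, assuming the MD5 keying suggested by the md5_key column earlier in the file (illustrative code, not from the repository):

import hashlib
import json

messages = [{"role": "user", "content": "hello"}]
base = json.dumps(messages)

key_seed_0 = hashlib.md5((base + "<seed=0/>").encode()).hexdigest()
key_seed_1 = hashlib.md5((base + "<seed=1/>").encode()).hexdigest()
assert key_seed_0 != key_seed_1  # different seeds -> different cache entries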
@@ -606,7 +628,9 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915
"max_new_tokens": self.gcr_endpoint_max_token,
},
},
},),)
},
),
)

req = urllib.request.Request(self.gcr_endpoint, body, self.headers) # noqa: S310
response = urllib.request.urlopen(req) # noqa: S310
@@ -640,8 +664,11 @@ def _create_chat_completion_inner_function( # noqa: C901, PLR0912, PLR0915
logger.info(f"{LogColors.CYAN}Response:{LogColors.END}", tag="llm_messages")

for chunk in response:
content = (chunk.choices[0].delta.content
if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None else "")
content = (
chunk.choices[0].delta.content
if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None
else ""
)
if self.cfg.log_llm_chat_content:
logger.info(LogColors.CYAN + content + LogColors.END, raw=True, tag="llm_messages")
resp += content
@@ -693,10 +720,9 @@ def build_messages_and_calculate_token(
) -> int:
if former_messages is None:
former_messages = []
messages = self.build_messages(user_prompt,
system_prompt,
former_messages,
shrink_multiple_break=shrink_multiple_break)
messages = self.build_messages(
user_prompt, system_prompt, former_messages, shrink_multiple_break=shrink_multiple_break
)
return self.calculate_token_from_messages(messages)


Expand All @@ -709,7 +735,7 @@ def create_embedding_with_multiprocessing(str_list: list, slice_count: int = 50,

pool = multiprocessing.Pool(nproc)
result_list = [
pool.apply_async(calculate_embedding_process, (str_list[index:index + slice_count],))
pool.apply_async(calculate_embedding_process, (str_list[index : index + slice_count],))
for index in range(0, len(str_list), slice_count)
]
pool.close()
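The reformatted slice above is the chunking step: the input strings are split into batches of slice_count and each batch is dispatched to a worker with apply_async. A standalone sketch of the same fan-out/fan-in shape; fake_embed is a hypothetical stand-in for the real calculate_embedding_process worker:

import multiprocessing

def fake_embed(batch: list[str]) -> list[list[float]]:
    # Stand-in worker: one "embedding" per string.
    return [[float(len(s))] for s in batch]

def embed_in_batches(str_list: list[str], slice_count: int = 50, nproc: int = 8) -> list[list[float]]:
    pool = multiprocessing.Pool(nproc)
    results = [
        pool.apply_async(fake_embed, (str_list[index : index + slice_count],))
        for index in range(0, len(str_list), slice_count)
    ]
    pool.close()
    pool.join()
    # Flatten per-batch results back into one list, preserving order.
    return [emb for r in results for emb in r.get()]

if __name__ == "__main__":
    print(len(embed_in_batches([f"text-{i}" for i in range(120)], slice_count=50, nproc=4)))  # 120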
@@ -728,8 +754,8 @@ def calculate_embedding_distance_between_str_list(
return [[]]

embeddings = create_embedding_with_multiprocessing(source_str_list + target_str_list, slice_count=50, nproc=8)
source_embeddings = embeddings[:len(source_str_list)]
target_embeddings = embeddings[len(source_str_list):]
source_embeddings = embeddings[: len(source_str_list)]
target_embeddings = embeddings[len(source_str_list) :]

source_embeddings_np = np.array(source_embeddings)
target_embeddings_np = np.array(target_embeddings)
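After this split, the source and target embeddings are stacked into numpy arrays; the rest of the function falls outside the hunk, but a pairwise cosine-similarity matrix is the usual next step for an embedding-distance helper, so the following sketch assumes that behaviour (illustrative code, not from the repository):

import numpy as np

source_embeddings_np = np.array([[1.0, 0.0], [0.0, 1.0]])  # shape (n_source, dim)
target_embeddings_np = np.array([[1.0, 0.0], [1.0, 1.0]])  # shape (n_target, dim)

# Normalize rows, then one matrix product yields all pairwise cosine similarities.
src = source_embeddings_np / np.linalg.norm(source_embeddings_np, axis=1, keepdims=True)
tgt = target_embeddings_np / np.linalg.norm(target_embeddings_np, axis=1, keepdims=True)
similarity = src @ tgt.T  # shape (n_source, n_target)
print(similarity.round(3))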