diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py
index f533a33873..dd386893bd 100644
--- a/memgpt/local_llm/chat_completion_proxy.py
+++ b/memgpt/local_llm/chat_completion_proxy.py
@@ -6,6 +6,8 @@
 from .webui.api import get_webui_completion
 from .lmstudio.api import get_lmstudio_completion
+from .llamacpp.api import get_llamacpp_completion
+from .koboldcpp.api import get_koboldcpp_completion
 from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
 from .utils import DotDict
 from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
@@ -14,7 +16,8 @@
 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 DEBUG = False
-DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper()
+# DEBUG = True
+DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
 
 has_shown_warning = False
@@ -25,6 +28,7 @@ def get_chat_completion(
     function_call="auto",
 ):
     global has_shown_warning
+    grammar = None
 
     if HOST is None:
         raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
@@ -39,10 +43,22 @@
         llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
     elif model == "airoboros-l2-70b-2.1":
         llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
+    elif model == "airoboros-l2-70b-2.1-grammar":
+        llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False)
+        # grammar_name = "json"
+        grammar_name = "json_func_calls_with_inner_thoughts"
     elif model == "dolphin-2.1-mistral-7b":
         llm_wrapper = dolphin.Dolphin21MistralWrapper()
+    elif model == "dolphin-2.1-mistral-7b-grammar":
+        llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False)
+        # grammar_name = "json"
+        grammar_name = "json_func_calls_with_inner_thoughts"
     elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
         llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
+    elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar":
+        llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False)
+        # grammar_name = "json"
+        grammar_name = "json_func_calls_with_inner_thoughts"
     else:
         # Warn the user that we're using the fallback
         if not has_shown_warning:
@@ -50,10 +66,18 @@
                 f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
             )
             has_shown_warning = True
-        llm_wrapper = DEFAULT_WRAPPER
+        if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]:
+            # make the default to use grammar
+            llm_wrapper = DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
+            # grammar_name = "json"
+            grammar_name = "json_func_calls_with_inner_thoughts"
+        else:
+            llm_wrapper = DEFAULT_WRAPPER()
 
-    # First step: turn the message sequence into a prompt that the model expects
+    if grammar is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
+        print(f"Warning: grammars are currently only supported when using llama.cpp as the MemGPT local LLM backend")
+    # First step: turn the message sequence into a prompt that the model expects
     try:
         prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
         if DEBUG:
@@ -65,17 +89,24 @@
     try:
         if HOST_TYPE == "webui":
-            result = get_webui_completion(prompt)
+            result = get_webui_completion(prompt, grammar=grammar_name)
         elif HOST_TYPE == "lmstudio":
             result = get_lmstudio_completion(prompt)
+        elif HOST_TYPE == "llamacpp":
+            result = get_llamacpp_completion(prompt, grammar=grammar_name)
+        elif HOST_TYPE == "koboldcpp":
+            result = get_koboldcpp_completion(prompt, grammar=grammar_name)
         else:
-            print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
-            result = get_webui_completion(prompt)
+            raise LocalLLMError(
+                f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
+            )
     except requests.exceptions.ConnectionError as e:
         raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")
 
     if result is None or result == "":
         raise LocalLLMError(f"Got back an empty response string from {HOST}")
+    if DEBUG:
+        print(f"Raw LLM output:\n{result}")
 
     try:
         chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
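A minimal usage sketch (not part of the patch) of how the new "-grammar" model routing is exercised. It assumes a llama.cpp server is already serving a model at the example address, that `get_chat_completion` takes the `model`/`messages`/`functions` parameters used inside its body, and that the message and function payloads follow the OpenAI ChatCompletion shapes MemGPT already uses.

```python
# Hedged sketch, not part of the diff: drive the new grammar-enabled wrapper selection.
import os

# HOST / HOST_TYPE are read at module import time, so set the environment first.
os.environ["OPENAI_API_BASE"] = "http://localhost:8080"  # example llama.cpp server address
os.environ["BACKEND_TYPE"] = "llamacpp"

from memgpt.local_llm.chat_completion_proxy import get_chat_completion

messages = [
    {"role": "system", "content": "You are MemGPT."},
    {"role": "user", "content": "Hello!"},
]
functions = []  # the real caller passes MemGPT's function schemas

# A "-grammar" model name selects a wrapper with include_opening_brace_in_prefix=False
# and attaches the json_func_calls_with_inner_thoughts grammar to the backend request.
response = get_chat_completion(
    model="dolphin-2.1-mistral-7b-grammar",
    messages=messages,
    functions=functions,
)
print(response)
```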
diff --git a/memgpt/local_llm/grammars/json.gbnf b/memgpt/local_llm/grammars/json.gbnf
new file mode 100644
index 0000000000..47afedbfbb
--- /dev/null
+++ b/memgpt/local_llm/grammars/json.gbnf
@@ -0,0 +1,26 @@
+# https://github.com/ggerganov/llama.cpp/blob/master/grammars/json.gbnf
+root ::= object
+value ::= object | array | string | number | ("true" | "false" | "null") ws
+
+object ::=
+  "{" ws (
+    string ":" ws value
+    ("," ws string ":" ws value)*
+  )? "}" ws
+
+array ::=
+  "[" ws (
+    value
+    ("," ws value)*
+  )? "]" ws
+
+string ::=
+  "\"" (
+    [^"\\] |
+    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+  )* "\"" ws
+
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= ([ \t\n] ws)?
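As a quick aside (not part of the patch), the vendored grammar can be syntax-checked locally before pointing a server at it. This sketch assumes the optional llama-cpp-python package is installed and is run from the repository root.

```python
# Hedged sketch: confirm the vendored GBNF file parses, using llama-cpp-python.
from pathlib import Path

from llama_cpp import LlamaGrammar  # optional dependency: pip install llama-cpp-python

grammar_text = Path("memgpt/local_llm/grammars/json.gbnf").read_text()
LlamaGrammar.from_string(grammar_text)  # raises if the grammar does not parse
print("json.gbnf parsed OK")
```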
diff --git a/memgpt/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf b/memgpt/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf
new file mode 100644
index 0000000000..4ab308cfd7
--- /dev/null
+++ b/memgpt/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf
@@ -0,0 +1,25 @@
+root ::= Function
+Function ::= SendMessage | PauseHeartbeats | CoreMemoryAppend | CoreMemoryReplace | ConversationSearch | ConversationSearchDate | ArchivalMemoryInsert | ArchivalMemorySearch
+SendMessage ::= "{" ws "\"function\":" ws "\"send_message\"," ws "\"params\":" ws SendMessageParams "}"
+PauseHeartbeats ::= "{" ws "\"function\":" ws "\"pause_heartbeats\"," ws "\"params\":" ws PauseHeartbeatsParams "}"
+CoreMemoryAppend ::= "{" ws "\"function\":" ws "\"core_memory_append\"," ws "\"params\":" ws CoreMemoryAppendParams "}"
+CoreMemoryReplace ::= "{" ws "\"function\":" ws "\"core_memory_replace\"," ws "\"params\":" ws CoreMemoryReplaceParams "}"
+ConversationSearch ::= "{" ws "\"function\":" ws "\"conversation_search\"," ws "\"params\":" ws ConversationSearchParams "}"
+ConversationSearchDate ::= "{" ws "\"function\":" ws "\"conversation_search_date\"," ws "\"params\":" ws ConversationSearchDateParams "}"
+ArchivalMemoryInsert ::= "{" ws "\"function\":" ws "\"archival_memory_insert\"," ws "\"params\":" ws ArchivalMemoryInsertParams "}"
+ArchivalMemorySearch ::= "{" ws "\"function\":" ws "\"archival_memory_search\"," ws "\"params\":" ws ArchivalMemorySearchParams "}"
+SendMessageParams ::= "{" ws InnerThoughtsParam "," ws "\"message\":" ws string ws "}"
+PauseHeartbeatsParams ::= "{" ws InnerThoughtsParam "," ws "\"minutes\":" ws number ws "}"
+CoreMemoryAppendParams ::= "{" ws InnerThoughtsParam "," ws "\"name\":" ws namestring "," ws "\"content\":" ws string ws "," ws RequestHeartbeatParam ws "}"
+CoreMemoryReplaceParams ::= "{" ws InnerThoughtsParam "," ws "\"name\":" ws namestring "," ws "\"old_content\":" ws string "," ws "\"new_content\":" ws string ws "," ws RequestHeartbeatParam ws "}"
+ConversationSearchParams ::= "{" ws InnerThoughtsParam "," ws "\"query\":" ws string ws "," ws "\"page\":" ws number ws "," ws RequestHeartbeatParam ws "}"
+ConversationSearchDateParams ::= "{" ws InnerThoughtsParam "," ws "\"start_date\":" ws string ws "," ws "\"end_date\":" ws string ws "," ws "\"page\":" ws number ws "," ws RequestHeartbeatParam ws "}"
+ArchivalMemoryInsertParams ::= "{" ws InnerThoughtsParam "," ws "\"content\":" ws string ws "," ws RequestHeartbeatParam ws "}"
+ArchivalMemorySearchParams ::= "{" ws InnerThoughtsParam "," ws "\"query\":" ws string ws "," ws "\"page\":" ws number ws "," ws RequestHeartbeatParam ws "}"
+InnerThoughtsParam ::= "\"inner_thoughts\":" ws string
+RequestHeartbeatParam ::= "\"request_heartbeat\":" ws boolean
+namestring ::= "\"human\"" | "\"persona\""
+string ::= "\"" ([^"\[\]{}]*) "\""
+boolean ::= "true" | "false"
+ws ::= ""
+number ::= [0-9]+
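To make the intent of this grammar concrete, here is the kind of compact output it is designed to admit (with `ws` defined as the empty string, the model emits unpadded JSON); the sample values are illustrative only.

```python
# Illustrative sample of a completion the grammar permits, parsed with the standard library.
import json

sample = (
    '{"function":"send_message",'
    '"params":{"inner_thoughts":"Greeting the user.","message":"Hello!"}}'
)
parsed = json.loads(sample)
assert parsed["function"] == "send_message"
assert "inner_thoughts" in parsed["params"]
```

Note that the `string` rule excludes quotes, brackets, and braces, so free-text arguments produced under this grammar cannot contain those characters.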
diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py
new file mode 100644
index 0000000000..1e81c59388
--- /dev/null
+++ b/memgpt/local_llm/koboldcpp/api.py
@@ -0,0 +1,59 @@
+import os
+from urllib.parse import urljoin
+import requests
+import tiktoken
+
+from .settings import SIMPLE
+from ..utils import load_grammar_file
+from ...constants import LLM_MAX_TOKENS
+
+HOST = os.getenv("OPENAI_API_BASE")
+HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
+KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
+# DEBUG = False
+DEBUG = True
+
+
+def count_tokens(s: str, model: str = "gpt-4") -> int:
+    encoding = tiktoken.encoding_for_model(model)
+    return len(encoding.encode(s))
+
+
+def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
+    """See https://lite.koboldai.net/koboldcpp_api for API spec"""
+    prompt_tokens = count_tokens(prompt)
+    if prompt_tokens > LLM_MAX_TOKENS:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
+
+    # Settings for the generation, includes the prompt + stop tokens, max length, etc
+    request = settings
+    request["prompt"] = prompt
+
+    # Set grammar
+    if grammar is not None:
+        request["grammar"] = load_grammar_file(grammar)
+
+    if not HOST.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+
+    try:
+        # NOTE: llama.cpp server returns the following when it's out of context
+        # curl: (52) Empty reply from server
+        URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
+        response = requests.post(URI, json=request)
+        if response.status_code == 200:
+            result = response.json()
+            result = result["results"][0]["text"]
+            if DEBUG:
+                print(f"json API response.text: {result}")
+        else:
+            raise Exception(
+                f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+                + f" Make sure that the koboldcpp server is running and reachable at {URI}."
+            )
+
+    except:
+        # TODO handle gracefully
+        raise
+
+    return result
diff --git a/memgpt/local_llm/koboldcpp/settings.py b/memgpt/local_llm/koboldcpp/settings.py
new file mode 100644
index 0000000000..ec2bb19514
--- /dev/null
+++ b/memgpt/local_llm/koboldcpp/settings.py
@@ -0,0 +1,25 @@
+from ...constants import LLM_MAX_TOKENS
+
+# see https://lite.koboldai.net/koboldcpp_api#/v1/post_v1_generate
+SIMPLE = {
+    "stop_sequence": [
+        "\nUSER:",
+        "\nASSISTANT:",
+        "\nFUNCTION RETURN:",
+        "\nUSER",
+        "\nASSISTANT",
+        "\nFUNCTION RETURN",
+        "\nFUNCTION",
+        "\nFUNC",
+        "<|im_start|>",
+        "<|im_end|>",
+        "<|im_sep|>",
+        # '\n' +
+        # '',
+        # '<|',
+        # '\n#',
+        # '\n\n\n',
+    ],
+    "max_context_length": LLM_MAX_TOKENS,
+    "max_length": 512,
+}
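A minimal smoke-test sketch (not part of the patch) for the new koboldcpp backend. It assumes a koboldcpp server is reachable at the example address, and that the environment is set before import since `HOST` is read at module import time.

```python
# Hedged sketch: call the new koboldcpp backend directly with the function-call grammar.
import os

os.environ["OPENAI_API_BASE"] = "http://localhost:5001"  # example koboldcpp address
os.environ["BACKEND_TYPE"] = "koboldcpp"

from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion

raw = get_koboldcpp_completion(
    prompt="USER: say hi\nASSISTANT:",
    grammar="json_func_calls_with_inner_thoughts",  # resolved to a .gbnf file by load_grammar_file
)
print(raw)
```

One detail worth keeping in mind when reading the implementation: `request = settings` binds to the shared `SIMPLE` dict, so the `prompt` and `grammar` keys written here persist on it across calls.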
diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py
new file mode 100644
index 0000000000..ea51f71759
--- /dev/null
+++ b/memgpt/local_llm/llamacpp/api.py
@@ -0,0 +1,59 @@
+import os
+from urllib.parse import urljoin
+import requests
+import tiktoken
+
+from .settings import SIMPLE
+from ..utils import load_grammar_file
+from ...constants import LLM_MAX_TOKENS
+
+HOST = os.getenv("OPENAI_API_BASE")
+HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
+LLAMACPP_API_SUFFIX = "/completion"
+# DEBUG = False
+DEBUG = True
+
+
+def count_tokens(s: str, model: str = "gpt-4") -> int:
+    encoding = tiktoken.encoding_for_model(model)
+    return len(encoding.encode(s))
+
+
+def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE):
+    """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
+    prompt_tokens = count_tokens(prompt)
+    if prompt_tokens > LLM_MAX_TOKENS:
+        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")
+
+    # Settings for the generation, includes the prompt + stop tokens, max length, etc
+    request = settings
+    request["prompt"] = prompt
+
+    # Set grammar
+    if grammar is not None:
+        request["grammar"] = load_grammar_file(grammar)
+
+    if not HOST.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+
+    try:
+        # NOTE: llama.cpp server returns the following when it's out of context
+        # curl: (52) Empty reply from server
+        URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
+        response = requests.post(URI, json=request)
+        if response.status_code == 200:
+            result = response.json()
+            result = result["content"]
+            if DEBUG:
+                print(f"json API response.text: {result}")
+        else:
+            raise Exception(
+                f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+                + f" Make sure that the llama.cpp server is running and reachable at {URI}."
+            )
+
+    except:
+        # TODO handle gracefully
+        raise
+
+    return result
diff --git a/memgpt/local_llm/llamacpp/settings.py b/memgpt/local_llm/llamacpp/settings.py
new file mode 100644
index 0000000000..870afc2a15
--- /dev/null
+++ b/memgpt/local_llm/llamacpp/settings.py
@@ -0,0 +1,24 @@
+from ...constants import LLM_MAX_TOKENS
+
+# see https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints for options
+SIMPLE = {
+    "stop": [
+        "\nUSER:",
+        "\nASSISTANT:",
+        "\nFUNCTION RETURN:",
+        "\nUSER",
+        "\nASSISTANT",
+        "\nFUNCTION RETURN",
+        "\nFUNCTION",
+        "\nFUNC",
+        "<|im_start|>",
+        "<|im_end|>",
+        "<|im_sep|>",
+        # '\n' +
+        # '',
+        # '<|',
+        # '\n#',
+        # '\n\n\n',
+    ],
+    # "n_predict": 3072,
+}
diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py
index 1715ba445f..b867774b2c 100644
--- a/memgpt/local_llm/lmstudio/api.py
+++ b/memgpt/local_llm/lmstudio/api.py
@@ -53,7 +53,7 @@ def get_lmstudio_completion(prompt, settings=SIMPLE, api="chat"):
         else:
             raise Exception(
                 f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
-                + f"Make sure that the LM Studio local inference server is running and reachable at {URI}."
+                + f" Make sure that the LM Studio local inference server is running and reachable at {URI}."
             )
     except:
         # TODO handle gracefully
diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py
index a4436ea459..2456776171 100644
--- a/memgpt/local_llm/utils.py
+++ b/memgpt/local_llm/utils.py
@@ -1,3 +1,6 @@
+import os
+
+
 class DotDict(dict):
     """Allow dot access on properties similar to OpenAI response object"""
 
@@ -13,3 +16,18 @@ def __getstate__(self):
 
     def __setstate__(self, state):
         vars(self).update(state)
+
+
+def load_grammar_file(grammar):
+    # Set grammar
+    grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "grammars", f"{grammar}.gbnf")
+
+    # Check if the file exists
+    if not os.path.isfile(grammar_file):
+        # If the file doesn't exist, raise a FileNotFoundError
+        raise FileNotFoundError(f"The grammar file {grammar_file} does not exist.")
+
+    with open(grammar_file, "r") as file:
+        grammar_str = file.read()
+
+    return grammar_str
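For completeness, a small illustration (not part of the patch) of the new helper's behavior; it assumes the package is installed so the `grammars` directory resolves next to `utils.py`.

```python
# Hedged sketch: load a bundled grammar by name; unknown names raise FileNotFoundError.
from memgpt.local_llm.utils import load_grammar_file

gbnf = load_grammar_file("json_func_calls_with_inner_thoughts")
assert gbnf.startswith("root ::=")

try:
    load_grammar_file("no_such_grammar")  # hypothetical name, expected to fail
except FileNotFoundError as err:
    print(err)
```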
- + f"Make sure that the web UI server is running and reachable at {URI}." + + f" Make sure that the web UI server is running and reachable at {URI}." ) except: