Add grammar-based sampling (for webui, llamacpp, and koboldcpp) #293

Merged · 12 commits · Nov 4, 2023
43 changes: 37 additions & 6 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -6,6 +6,8 @@

from .webui.api import get_webui_completion
from .lmstudio.api import get_lmstudio_completion
from .llamacpp.api import get_llamacpp_completion
from .koboldcpp.api import get_koboldcpp_completion
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
from .utils import DotDict
from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
@@ -14,7 +16,8 @@
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
DEBUG = False
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper()
# DEBUG = True
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
has_shown_warning = False


@@ -25,6 +28,7 @@ def get_chat_completion(
function_call="auto",
):
global has_shown_warning
grammar_name = None

if HOST is None:
raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
@@ -39,21 +43,41 @@
llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
elif model == "airoboros-l2-70b-2.1":
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
elif model == "airoboros-l2-70b-2.1-grammar":
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
elif model == "dolphin-2.1-mistral-7b":
llm_wrapper = dolphin.Dolphin21MistralWrapper()
elif model == "dolphin-2.1-mistral-7b-grammar":
llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar":
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
else:
# Warn the user that we're using the fallback
if not has_shown_warning:
print(
f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
)
has_shown_warning = True
llm_wrapper = DEFAULT_WRAPPER
if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]:
# default to using a grammar
llm_wrapper = DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
else:
llm_wrapper = DEFAULT_WRAPPER()

# First step: turn the message sequence into a prompt that the model expects
if grammar_name is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
print(f"Warning: grammars are currently only supported when using llama.cpp, koboldcpp, or webui as the MemGPT local LLM backend")

# First step: turn the message sequence into a prompt that the model expects
try:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
if DEBUG:
@@ -65,17 +89,24 @@ def get_chat_completion(

try:
if HOST_TYPE == "webui":
result = get_webui_completion(prompt)
result = get_webui_completion(prompt, grammar=grammar_name)
elif HOST_TYPE == "lmstudio":
result = get_lmstudio_completion(prompt)
elif HOST_TYPE == "llamacpp":
result = get_llamacpp_completion(prompt, grammar=grammar_name)
elif HOST_TYPE == "koboldcpp":
result = get_koboldcpp_completion(prompt, grammar=grammar_name)
else:
print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
result = get_webui_completion(prompt)
raise LocalLLMError(
f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
)
except requests.exceptions.ConnectionError as e:
raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")

if result is None or result == "":
raise LocalLLMError(f"Got back an empty response string from {HOST}")
if DEBUG:
print(f"Raw LLM output:\n{result}")

try:
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
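As a quick orientation for reviewers, here is a minimal usage sketch of the routing above. It is not part of the diff: the host URL and the exact argument names of get_chat_completion beyond what the hunks show are assumptions, so treat it as illustrative only.

import os

# OPENAI_API_BASE and BACKEND_TYPE are read at module import time, so set them
# before importing the proxy (values here are hypothetical).
os.environ["OPENAI_API_BASE"] = "http://localhost:5000"
os.environ["BACKEND_TYPE"] = "webui"

from memgpt.local_llm.chat_completion_proxy import get_chat_completion

# Choosing a "-grammar" model variant selects the wrapper with
# include_opening_brace_in_prefix=False and forwards the
# json_func_calls_with_inner_thoughts grammar to the backend call.
response = get_chat_completion(
    model="dolphin-2.1-mistral-7b-grammar",
    messages=[{"role": "user", "content": "Hi there"}],
    functions=[],  # function schemas would normally be passed here
)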
26 changes: 26 additions & 0 deletions memgpt/local_llm/grammars/json.gbnf
@@ -0,0 +1,26 @@
# https://github.com/ggerganov/llama.cpp/blob/master/grammars/json.gbnf
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
25 changes: 25 additions & 0 deletions memgpt/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf
@@ -0,0 +1,25 @@
root ::= Function
Function ::= SendMessage | PauseHeartbeats | CoreMemoryAppend | CoreMemoryReplace | ConversationSearch | ConversationSearchDate | ArchivalMemoryInsert | ArchivalMemorySearch
SendMessage ::= "{" ws "\"function\":" ws "\"send_message\"," ws "\"params\":" ws SendMessageParams "}"
PauseHeartbeats ::= "{" ws "\"function\":" ws "\"pause_heartbeats\"," ws "\"params\":" ws PauseHeartbeatsParams "}"
CoreMemoryAppend ::= "{" ws "\"function\":" ws "\"core_memory_append\"," ws "\"params\":" ws CoreMemoryAppendParams "}"
CoreMemoryReplace ::= "{" ws "\"function\":" ws "\"core_memory_replace\"," ws "\"params\":" ws CoreMemoryReplaceParams "}"
ConversationSearch ::= "{" ws "\"function\":" ws "\"conversation_search\"," ws "\"params\":" ws ConversationSearchParams "}"
ConversationSearchDate ::= "{" ws "\"function\":" ws "\"conversation_search_date\"," ws "\"params\":" ws ConversationSearchDateParams "}"
ArchivalMemoryInsert ::= "{" ws "\"function\":" ws "\"archival_memory_insert\"," ws "\"params\":" ws ArchivalMemoryInsertParams "}"
ArchivalMemorySearch ::= "{" ws "\"function\":" ws "\"archival_memory_search\"," ws "\"params\":" ws ArchivalMemorySearchParams "}"
SendMessageParams ::= "{" ws InnerThoughtsParam "," ws "\"message\":" ws string ws "}"
PauseHeartbeatsParams ::= "{" ws InnerThoughtsParam "," ws "\"minutes\":" ws number ws "}"
CoreMemoryAppendParams ::= "{" ws InnerThoughtsParam "," ws "\"name\":" ws namestring "," ws "\"content\":" ws string ws "," ws RequestHeartbeatParam ws "}"
CoreMemoryReplaceParams ::= "{" ws InnerThoughtsParam "," ws "\"name\":" ws namestring "," ws "\"old_content\":" ws string "," ws "\"new_content\":" ws string ws "," ws RequestHeartbeatParam ws "}"
ConversationSearchParams ::= "{" ws InnerThoughtsParam "," ws "\"query\":" ws string ws "," ws "\"page\":" ws number ws "," ws RequestHeartbeatParam ws "}"
ConversationSearchDateParams ::= "{" ws InnerThoughtsParam "," ws "\"start_date\":" ws string ws "," ws "\"end_date\":" ws string ws "," ws "\"page\":" ws number ws "," ws RequestHeartbeatParam ws "}"
ArchivalMemoryInsertParams ::= "{" ws InnerThoughtsParam "," ws "\"content\":" ws string ws "," ws RequestHeartbeatParam ws "}"
ArchivalMemorySearchParams ::= "{" ws InnerThoughtsParam "," ws "\"query\":" ws string ws "," ws "\"page\":" ws number ws "," ws RequestHeartbeatParam ws "}"
InnerThoughtsParam ::= "\"inner_thoughts\":" ws string
RequestHeartbeatParam ::= "\"request_heartbeat\":" ws boolean
namestring ::= "\"human\"" | "\"persona\""
string ::= "\"" ([^"\[\]{}]*) "\""
boolean ::= "true" | "false"
ws ::= ""
number ::= [0-9]+
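To make the grammar above concrete, here is one completion it accepts. Note that ws expands to the empty string, so conforming output is compact JSON with no whitespace between tokens, and the string rule forbids quotes, brackets, and braces inside values:

{"function":"send_message","params":{"inner_thoughts":"The user greeted me, so I should reply warmly.","message":"Hello! How can I help you today?"}}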
59 changes: 59 additions & 0 deletions memgpt/local_llm/koboldcpp/api.py
@@ -0,0 +1,59 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
# DEBUG = False
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(s))


def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > LLM_MAX_TOKENS:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["prompt"] = prompt

# Set grammar
if grammar is not None:
request["grammar"] = load_grammar_file(grammar)

if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

try:
# NOTE: llama.cpp server returns the following when it's out of context
# curl: (52) Empty reply from server
URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["results"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+ f" Make sure that the koboldcpp server is running and reachable at {URI}."
)

except:
# TODO handle gracefully
raise

return result
25 changes: 25 additions & 0 deletions memgpt/local_llm/koboldcpp/settings.py
@@ -0,0 +1,25 @@
from ...constants import LLM_MAX_TOKENS

# see https://lite.koboldai.net/koboldcpp_api#/v1/post_v1_generate
SIMPLE = {
"stop_sequence": [
"\nUSER:",
"\nASSISTANT:",
"\nFUNCTION RETURN:",
"\nUSER",
"\nASSISTANT",
"\nFUNCTION RETURN",
"\nFUNCTION",
"\nFUNC",
"<|im_start|>",
"<|im_end|>",
"<|im_sep|>",
# '\n' +
# '</s>',
# '<|',
# '\n#',
# '\n\n\n',
],
"max_context_length": LLM_MAX_TOKENS,
"max_length": 512,
}
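Reading the two new koboldcpp files together, the body POSTed to /api/v1/generate when a grammar is supplied looks roughly like the sketch below. The concrete values are placeholders, not a captured request:

# Sketch of the request built by get_koboldcpp_completion(prompt, grammar=...):
# the SIMPLE settings dict is reused, then prompt and grammar are added.
request = {
    "stop_sequence": ["\nUSER:", "\nASSISTANT:", "<|im_end|>"],  # abbreviated
    "max_context_length": 8192,  # stands in for LLM_MAX_TOKENS
    "max_length": 512,
    "prompt": "<prompt produced by the chat-completion wrapper>",
    "grammar": "<contents of json_func_calls_with_inner_thoughts.gbnf>",
}
# The generated text is then read from response.json()["results"][0]["text"].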
59 changes: 59 additions & 0 deletions memgpt/local_llm/llamacpp/api.py
@@ -0,0 +1,59 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
LLAMACPP_API_SUFFIX = "/completion"
# DEBUG = False
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(s))


def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE):
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > LLM_MAX_TOKENS:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["prompt"] = prompt

# Set grammar
if grammar is not None:
request["grammar"] = load_grammar_file(grammar)

if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

try:
# NOTE: llama.cpp server returns the following when it's out of context
# curl: (52) Empty reply from server
URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["content"]
if DEBUG:
print(f"json API response.text: {result}")
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+ f" Make sure that the llama.cpp server is running and reachable at {URI}."
)

except:
# TODO handle gracefully
raise

return result
24 changes: 24 additions & 0 deletions memgpt/local_llm/llamacpp/settings.py
@@ -0,0 +1,24 @@
from ...constants import LLM_MAX_TOKENS

# see https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints for options
SIMPLE = {
"stop": [
"\nUSER:",
"\nASSISTANT:",
"\nFUNCTION RETURN:",
"\nUSER",
"\nASSISTANT",
"\nFUNCTION RETURN",
"\nFUNCTION",
"\nFUNC",
"<|im_start|>",
"<|im_end|>",
"<|im_sep|>",
# '\n' +
# '</s>',
# '<|',
# '\n#',
# '\n\n\n',
],
# "n_predict": 3072,
}
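The llama.cpp path is the same shape but differs in three details: it posts to /completion, its settings use "stop" rather than "stop_sequence", and the completion is read from the "content" field. A compressed sketch with a hypothetical host and placeholder values:

import requests

from memgpt.local_llm.llamacpp.settings import SIMPLE
from memgpt.local_llm.utils import load_grammar_file

request = dict(SIMPLE)  # "stop" list from the settings above
request["prompt"] = "<prompt produced by the chat-completion wrapper>"
request["grammar"] = load_grammar_file("json_func_calls_with_inner_thoughts")

# Assumes the llama.cpp example server is listening on its default port.
response = requests.post("http://localhost:8080/completion", json=request)
text = response.json()["content"]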
2 changes: 1 addition & 1 deletion memgpt/local_llm/lmstudio/api.py
@@ -53,7 +53,7 @@ def get_lmstudio_completion(prompt, settings=SIMPLE, api="chat"):
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+ f"Make sure that the LM Studio local inference server is running and reachable at {URI}."
+ f" Make sure that the LM Studio local inference server is running and reachable at {URI}."
)
except:
# TODO handle gracefully
18 changes: 18 additions & 0 deletions memgpt/local_llm/utils.py
@@ -1,3 +1,6 @@
import os


class DotDict(dict):
"""Allow dot access on properties similar to OpenAI response object"""

@@ -13,3 +16,18 @@ def __getstate__(self):

def __setstate__(self, state):
vars(self).update(state)


def load_grammar_file(grammar):
# Set grammar
grammar_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "grammars", f"{grammar}.gbnf")

# Check if the file exists
if not os.path.isfile(grammar_file):
# If the file doesn't exist, raise a FileNotFoundError
raise FileNotFoundError(f"The grammar file {grammar_file} does not exist.")

with open(grammar_file, "r") as file:
grammar_str = file.read()

return grammar_str
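A quick way to sanity-check the new helper (a sketch, assuming the package layout shown in this diff):

from memgpt.local_llm.utils import load_grammar_file

# Resolves to memgpt/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf;
# an unknown name raises FileNotFoundError.
gbnf = load_grammar_file("json_func_calls_with_inner_thoughts")
assert gbnf.startswith("root ::=")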
10 changes: 7 additions & 3 deletions memgpt/local_llm/webui/api.py
@@ -4,7 +4,7 @@
import tiktoken

from .settings import SIMPLE

from ..utils import load_grammar_file
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -18,7 +18,7 @@ def count_tokens(s: str, model: str = "gpt-4") -> int:
return len(encoding.encode(s))


def get_webui_completion(prompt, settings=SIMPLE):
def get_webui_completion(prompt, settings=SIMPLE, grammar=None):
"""See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > LLM_MAX_TOKENS:
@@ -28,6 +28,10 @@ def get_webui_completion(prompt, settings=SIMPLE):
request = settings
request["prompt"] = prompt

# Set grammar
if grammar is not None:
request["grammar_string"] = load_grammar_file(grammar)

if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

@@ -42,7 +46,7 @@
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+ f"Make sure that the web UI server is running and reachable at {URI}."
+ f" Make sure that the web UI server is running and reachable at {URI}."
)

except:
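Note that for the text-generation-webui backend the grammar travels under a different key, grammar_string. A minimal sketch of what get_webui_completion adds on top of its SIMPLE settings when a grammar is requested:

from memgpt.local_llm.webui.settings import SIMPLE
from memgpt.local_llm.utils import load_grammar_file

request = dict(SIMPLE)
request["prompt"] = "<prompt produced by the chat-completion wrapper>"
request["grammar_string"] = load_grammar_file("json_func_calls_with_inner_thoughts")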