Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local LLMs #1306

Merged
merged 24 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions gui/pages/Content/Models/ModelForm.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ import {BeatLoader, ClipLoader} from "react-spinners";
import {ToastContainer, toast} from 'react-toastify';

export default function ModelForm({internalId, getModels, sendModelData}){
const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm'];
const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM'];
const [selectedModel, setSelectedModel] = useState('Select a Model');
const [modelName, setModelName] = useState('');
const [modelDescription, setModelDescription] = useState('');
const [modelTokenLimit, setModelTokenLimit] = useState(4096);
const [modelEndpoint, setModelEndpoint] = useState('');
const [modelDropdown, setModelDropdown] = useState(false);
const [modelVersion, setModelVersion] = useState('');
const [modelContextLength, setContextLength] = useState(4096);
const [tokenError, setTokenError] = useState(false);
const [lockAddition, setLockAddition] = useState(true);
const [isLoading, setIsLoading] = useState(false)
Expand Down Expand Up @@ -87,7 +88,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){
}

const storeModelDetails = (modelProviderId) => {
storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{
storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{
setIsLoading(false)
let data = response.data
if (data.error) {
Expand Down Expand Up @@ -155,6 +156,12 @@ export default function ModelForm({internalId, getModels, sendModelData}){
onChange={(event) => setModelVersion(event.target.value)}/>
</div>}

{(selectedModel === 'Local LLM') && <div className="mt_24">
<span>Model Context Length</span>
<input className="input_medium mt_8" type="number" placeholder="Enter Model Context Length" value={modelContextLength}
onChange={(event) => setContextLength(event.target.value)}/>
</div>}

<div className="mt_24">
<span>Token Limit</span>
<input className="input_medium mt_8" type="number" placeholder="Enter Model Token Limit" value={modelTokenLimit}
Expand Down
4 changes: 2 additions & 2 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,8 @@ export const verifyEndPoint = (model_api_key, end_point, model_provider) => {
});
}

/**
 * Persist a custom model's details to the backend.
 *
 * @param {string} model_name - Display name of the model.
 * @param {string} description - Free-text description.
 * @param {string} end_point - Inference endpoint URL (may be empty for local models).
 * @param {number} model_provider_id - Id of the provider config row.
 * @param {number} token_limit - Max tokens the model may generate.
 * @param {string} type - Model type, e.g. "Custom".
 * @param {string} version - Provider-specific model version.
 * @param {number} context_length - Context window size (used by Local LLM provider).
 * @returns {Promise} axios promise for the POST request.
 */
export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => {
  return api.post(`/models_controller/store_model`, {model_name, description, end_point, model_provider_id, token_limit, type, version, context_length});
}

export const fetchModels = () => {
Expand Down
11 changes: 10 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
from superagi.models.agent_template import AgentTemplate
from superagi.models.models_config import ModelsConfig
from superagi.models.organisation import Organisation
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
Expand Down Expand Up @@ -215,6 +216,13 @@ def register_toolkit_for_master_organisation():
Organisation.id == marketplace_organisation_id).first()
if marketplace_organisation is not None:
register_marketplace_toolkits(session, marketplace_organisation)

def local_llm_model_config():
    # Seed a 'Local LLM' provider config for the default organisation,
    # unless one already exists (idempotent on restart).
    already_configured = session.query(ModelsConfig).filter(
        ModelsConfig.org_id == default_user.organisation_id,
        ModelsConfig.provider == 'Local LLM',
    ).first() is not None
    if not already_configured:
        session.add(ModelsConfig(org_id=default_user.organisation_id,
                                 provider='Local LLM',
                                 api_key="EMPTY"))
        session.commit()

IterationWorkflowSeed.build_single_step_agent(session)
IterationWorkflowSeed.build_task_based_agents(session)
Expand All @@ -238,7 +246,8 @@ def register_toolkit_for_master_organisation():
# AgentWorkflowSeed.doc_search_and_code(session)
# AgentWorkflowSeed.build_research_email_workflow(session)
replace_old_iteration_workflows(session)

local_llm_model_config()

if env != "PROD":
register_toolkit_for_all_organisation()
else:
Expand Down
84 changes: 84 additions & 0 deletions migrations/versions/9270eb5a8475_local_llms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""local_llms

Revision ID: 9270eb5a8475
Revises: 3867bb00a495
Create Date: 2023-10-04 09:26:33.865424

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '9270eb5a8475'
down_revision = '3867bb00a495'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply schema changes for the local-LLMs feature.

    The substantive change is adding ``models.context_length``; the rest are
    auto-generated drift corrections (dropped indexes, relaxed NOT NULL
    constraints, removal of ``agent_workflows.organisation_id``).
    NOTE(review): verify the auto-generated drops below are intentional and
    not accidental autogenerate drift before shipping.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index('ix_agent_schedule_agent_id', table_name='agent_schedule')
    op.drop_index('ix_agent_schedule_expiry_date', table_name='agent_schedule')
    op.drop_index('ix_agent_schedule_status', table_name='agent_schedule')
    op.alter_column('agent_workflow_steps', 'unique_id',
               existing_type=sa.VARCHAR(),
               nullable=True)
    op.alter_column('agent_workflow_steps', 'step_type',
               existing_type=sa.VARCHAR(),
               nullable=True)
    op.drop_column('agent_workflows', 'organisation_id')
    op.drop_index('ix_events_agent_id', table_name='events')
    op.drop_index('ix_events_event_property', table_name='events')
    op.drop_index('ix_events_org_id', table_name='events')
    op.alter_column('knowledge_configs', 'knowledge_id',
               existing_type=sa.INTEGER(),
               nullable=True)
    op.alter_column('knowledges', 'name',
               existing_type=sa.VARCHAR(),
               nullable=True)
    # Core change for this revision: context window size for local models.
    op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True))
    op.alter_column('vector_db_configs', 'vector_db_id',
               existing_type=sa.INTEGER(),
               nullable=True)
    op.alter_column('vector_db_indices', 'name',
               existing_type=sa.VARCHAR(),
               nullable=True)
    op.alter_column('vector_dbs', 'name',
               existing_type=sa.VARCHAR(),
               nullable=True)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert revision 9270eb5a8475.

    Restores the NOT NULL constraints, recreates the dropped indexes,
    re-adds ``agent_workflows.organisation_id`` and drops
    ``models.context_length``. Data held in columns dropped by
    ``upgrade()`` is not recoverable.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Re-tighten nullability relaxed by upgrade().
    op.alter_column('vector_dbs', 'name',
               existing_type=sa.VARCHAR(),
               nullable=False)
    op.alter_column('vector_db_indices', 'name',
               existing_type=sa.VARCHAR(),
               nullable=False)
    op.alter_column('vector_db_configs', 'vector_db_id',
               existing_type=sa.INTEGER(),
               nullable=False)
    # Remove the column added for local-LLM context window support.
    op.drop_column('models', 'context_length')
    op.alter_column('knowledges', 'name',
               existing_type=sa.VARCHAR(),
               nullable=False)
    op.alter_column('knowledge_configs', 'knowledge_id',
               existing_type=sa.INTEGER(),
               nullable=False)
    # Recreate the indexes dropped by upgrade().
    op.create_index('ix_events_org_id', 'events', ['org_id'], unique=False)
    op.create_index('ix_events_event_property', 'events', ['event_property'], unique=False)
    op.create_index('ix_events_agent_id', 'events', ['agent_id'], unique=False)
    # Column comes back empty; its previous contents are lost.
    op.add_column('agent_workflows', sa.Column('organisation_id', sa.INTEGER(), autoincrement=False, nullable=True))
    op.alter_column('agent_workflow_steps', 'step_type',
               existing_type=sa.VARCHAR(),
               nullable=False)
    op.alter_column('agent_workflow_steps', 'unique_id',
               existing_type=sa.VARCHAR(),
               nullable=False)
    op.create_index('ix_agent_schedule_status', 'agent_schedule', ['status'], unique=False)
    op.create_index('ix_agent_schedule_expiry_date', 'agent_schedule', ['expiry_date'], unique=False)
    op.create_index('ix_agent_schedule_agent_id', 'agent_schedule', ['agent_id'], unique=False)
    # ### end Alembic commands ###
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,4 @@ google-generativeai==0.1.0
unstructured==0.8.1
ai21==1.2.6
typing-extensions==4.5.0
llama_cpp_python==0.2.7
37 changes: 35 additions & 2 deletions superagi/controllers/models_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.helper.models_helper import ModelsHelper
from superagi.apm.call_log_helper import CallLogHelper
from superagi.lib.logger import logger
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
from superagi.controllers.types.models_types import ModelsTypes
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
from superagi.helper.llm_loader import LLMLoader

router = APIRouter()

Expand All @@ -26,6 +28,7 @@ class StoreModelRequest(BaseModel):
token_limit: int
type: str
version: str
context_length: int

class ModelName (BaseModel):
model: str
Expand Down Expand Up @@ -69,7 +72,9 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod
@router.post("/store_model", status_code=200)
async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)):
    """Persist a custom model's details for the caller's organisation.

    Returns whatever Models.store_model_details returns; maps any failure
    to a 500 so internal details are not leaked to the client.
    """
    try:
        # NOTE: do not log the full request object here — it may contain
        # endpoint/credential details that should not end up in logs.
        return Models.store_model_details(db.session, organisation.id, request.model_name, request.description,
                                          request.end_point, request.model_provider_id, request.token_limit,
                                          request.type, request.version, request.context_length)
    except Exception as e:
        logging.error(f"Error storing the Model Details: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
Expand Down Expand Up @@ -164,4 +169,32 @@ def get_models_details(page: int = 0):
marketplace_models = Models.fetch_marketplace_list(page)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id,
ModelsTypes.MARKETPLACE.value)
return marketplace_models_with_install
return marketplace_models_with_install

@router.get("/test_local_llm", status_code=200)
def test_local_llm():
    """Smoke-test the local LLM setup.

    Loads the model and the JSON grammar via LLMLoader and runs a single
    chat completion. Raises 404 with a descriptive detail when the model
    or grammar cannot be loaded.
    """
    try:
        llm_loader = LLMLoader(context_length=4096)
        llm_model = llm_loader.model
        llm_grammar = llm_loader.grammar
        if llm_model is None:
            logger.error("Model not found.")
            raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
        if llm_grammar is None:
            logger.error("Grammar not found.")
            raise HTTPException(status_code=404, detail="Grammar not found.")

        messages = [
            {"role": "system",
             "content": "You are an AI assistant. Give response in a proper JSON format"},
            {"role": "user",
             "content": "Hi!"}
        ]
        response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar)
        content = response["choices"][0]["message"]["content"]
        logger.info(content)
        return "Model loaded successfully."

    except HTTPException:
        # Bug fix: the blanket handler below used to swallow the specific
        # HTTPExceptions raised above (e.g. "Grammar not found.") and
        # replace their detail message. Re-raise them unchanged.
        raise
    except Exception as e:
        # Bug fix: was logger.info("Error: ", e) — a stray positional arg
        # logged at the wrong level. Log at error level with the message.
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
38 changes: 38 additions & 0 deletions superagi/helper/llm_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from llama_cpp import Llama
from llama_cpp import LlamaGrammar
from superagi.config.config import get_config
from superagi.lib.logger import logger


class LLMLoader:
    """Lazily-loading singleton wrapper around a local llama.cpp model.

    The model and the JSON grammar are loaded at most once and cached on
    the class. Because the cached model is built with the context length
    supplied at first instantiation, later instantiations keep that first
    value instead of silently overwriting it (the old behavior mutated
    ``context_length`` on every call while the model stayed stale).
    """
    _instance = None   # the singleton instance
    _model = None      # cached Llama model; None until first successful load
    _grammar = None    # cached LlamaGrammar; None until first successful load

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(LLMLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self, context_length):
        # __init__ runs on every instantiation of the singleton; only the
        # first context_length is honored so it stays consistent with the
        # cached model.
        if not hasattr(self, 'context_length'):
            self.context_length = context_length

    @property
    def model(self):
        """The loaded Llama model, or None if loading failed (error is logged)."""
        if self._model is None:
            try:
                # assumes the model file is mounted at this container path
                self._model = Llama(
                    model_path="/app/local_model_path", n_ctx=self.context_length)
            except Exception as e:
                logger.error(e)
        return self._model

    @property
    def grammar(self):
        """LlamaGrammar constraining output to JSON, or None on failure (logged)."""
        if self._grammar is None:
            try:
                self._grammar = LlamaGrammar.from_file(
                    "superagi/llms/grammar/json.gbnf")
            except Exception as e:
                logger.error(e)
        return self._grammar
3 changes: 3 additions & 0 deletions superagi/jobs/agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta

from sqlalchemy.orm import sessionmaker
from superagi.llms.local_llm import LocalLLM

import superagi.worker
from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
Expand Down Expand Up @@ -135,6 +136,8 @@ def get_embedding(cls, model_source, model_api_key):
return HuggingFace(api_key=model_api_key)
if "Replicate" in model_source:
return Replicate(api_key=model_api_key)
if "Custom" in model_source:
return LocalLLM()
return None

def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id):
Expand Down
25 changes: 25 additions & 0 deletions superagi/llms/grammar/json.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# GBNF grammar constraining llama.cpp output to valid JSON.
# A completion must be a single JSON object (root ::= object).
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

# Object: zero or more comma-separated "key": value pairs in braces.
object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

# Array: zero or more comma-separated values in brackets.
array ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

# String: double-quoted, with standard JSON escape sequences.
string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

# Number: optional sign, integer part, optional fraction and exponent.
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
6 changes: 6 additions & 0 deletions superagi/llms/llm_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from superagi.llms.google_palm import GooglePalm
from superagi.llms.local_llm import LocalLLM
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
Expand Down Expand Up @@ -33,6 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs):
elif provider_name == 'Hugging Face':
print("Provider is Hugging Face")
return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs)
elif provider_name == 'Local LLM':
print("Provider is Local LLM")
return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length)
else:
print('Unknown provider.')

Expand All @@ -45,5 +49,7 @@ def build_model_with_api_key(provider_name, api_key):
return GooglePalm(api_key=api_key)
elif provider_name.lower() == 'hugging face':
return HuggingFace(api_key=api_key)
elif provider_name.lower() == 'local llm':
return LocalLLM(api_key=api_key)
else:
print('Unknown provider.')
Loading