Feature: config builder middleware #4

Open · wants to merge 11 commits into base: main
8 changes: 8 additions & 0 deletions .env.example.global
@@ -0,0 +1,8 @@
export LITELLM_MASTER_KEY=<master-key starting with sk->
export DATABASE_URL=<db url>
export STORE_MODEL_IN_DB='True'
export LITELLM_SALT_KEY=<salt-key>
export REDIS_HOST=localhost
export REDIS_PORT=6379
export REDIS_PASSWORD=<password>
export BUDSERVE_APP_BASEURL="http://localhost:8000"
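These variables are read at runtime via os.getenv in litellm/custom_auth.py and litellm/proxy/budserve_middleware.py below. A minimal sketch for confirming the environment is complete before starting the proxy (the variable list simply mirrors the file above; the check itself is not part of this PR):

import os

required = [
    "LITELLM_MASTER_KEY",
    "DATABASE_URL",
    "LITELLM_SALT_KEY",
    "REDIS_HOST",
    "REDIS_PORT",
    "REDIS_PASSWORD",
    "BUDSERVE_APP_BASEURL",
]
# fail fast with the names of any unset variables
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing environment variables: {missing}")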
212 changes: 212 additions & 0 deletions litellm/custom_auth.py
@@ -0,0 +1,212 @@
import asyncio
import os
from datetime import datetime, timedelta, timezone

import httpx
from fastapi import HTTPException, Request, status
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_utils import get_request_route, pre_db_read_auth_checks
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body


async def fetch_data(url: str):
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        return response


budserve_app_baseurl = os.getenv("BUDSERVE_APP_BASEURL", "http://localhost:9000")


async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    """
    Custom auth dependency for user API key authentication.
    We receive a BudServe API key and check whether it is valid.

    Steps:

    1. Check api-key in cache
    2. Get api-key details from the authentication service
    3. Check expiry
    4. Check budget
    5. Check model budget
    """
    try:
        from litellm.proxy.proxy_server import (
            master_key,
            prisma_client,
            user_api_key_cache,
        )

        if prisma_client is None:
            raise Exception("Prisma client not initialized")

        api_key = f"sk-{api_key}"

        route: str = get_request_route(request=request)
        # get the request body
        request_data = await _read_request_body(request=request)
        await pre_db_read_auth_checks(
            request_data=request_data,
            request=request,
            route=route,
        )

        # look for the key in the user_api_key_auth cache
        verbose_proxy_logger.debug(f"API key sent in request >>> {api_key}")
        hashed_token = hash_token(api_key)
        valid_token: Optional[UserAPIKeyAuth] = (
            await user_api_key_cache.async_get_cache(key=hashed_token)
        )

        verbose_proxy_logger.info(
            f"Valid token from cache for key: {hashed_token} >>> {valid_token}"
        )
        if valid_token is None:
            # get token details from the authentication service
            url = f"{budserve_app_baseurl}/credentials/details/{api_key.removeprefix('sk-')}"
            credential_details_response = await fetch_data(url)
            if credential_details_response.status_code != 200:
                # no credential found in the authentication service
                raise Exception("Invalid api key passed")
            credential_dict = credential_details_response.json()["result"]
            # credential_dict = {
            #     "key": api_key.removeprefix("sk-"),
            #     "expiry": (datetime.now() + timedelta(days=1)).strftime(
            #         "%Y-%m-%d %H:%M:%S"
            #     ),
            #     "max_budget": 1,
            #     "model_budgets": {"gpt-4": 0.003, "gpt-3.5-turbo": 0.002},
            # }
            valid_token = UserAPIKeyAuth(
                api_key=f"sk-{credential_dict['key']}",
                expires=credential_dict["expiry"],
                max_budget=credential_dict["max_budget"],
                model_max_budget=credential_dict["model_budgets"] or {},
            )
            api_key_spend = await prisma_client.db.litellm_spendlogs.group_by(
                by=["api_key"],
                sum={"spend": True},
                where={
                    "AND": [
                        {"api_key": valid_token.token},
                    ]
                },  # type: ignore
            )
            if (
                len(api_key_spend) > 0
                and "_sum" in api_key_spend[0]
                and "spend" in api_key_spend[0]["_sum"]
                and api_key_spend[0]["_sum"]["spend"]
            ):
                valid_token.spend = api_key_spend[0]["_sum"]["spend"]
            # add hashed token to cache
            verbose_proxy_logger.info(
                f"Storing valid token in cache for key: {valid_token.token}"
            )
            await user_api_key_cache.async_set_cache(
                key=valid_token.token,
                value=valid_token,
            )
            verbose_proxy_logger.info(f"Valid token from DB >>> {valid_token}")
            verbose_proxy_logger.info(f"Valid token spend >> {valid_token.spend}")
        if valid_token is not None:
            if valid_token.expires is not None:
                current_time = datetime.now(timezone.utc)
                expiry_time = datetime.fromisoformat(valid_token.expires)
                if (
                    expiry_time.tzinfo is None
                    or expiry_time.tzinfo.utcoffset(expiry_time) is None
                ):
                    expiry_time = expiry_time.replace(tzinfo=timezone.utc)
                verbose_proxy_logger.debug(
                    f"Checking if token expired, expiry time {expiry_time} and current time {current_time}"
                )
                if expiry_time < current_time:
                    # token exists but is expired
                    raise ProxyException(
                        message=f"Authentication Error - Expired Key. Key Expiry time {expiry_time} and current time {current_time}",
                        type=ProxyErrorTypes.expired_key,
                        code=400,
                        param=api_key,
                    )
            if valid_token.spend is not None and valid_token.max_budget is not None:
                if valid_token.spend >= valid_token.max_budget:
                    raise litellm.BudgetExceededError(
                        current_cost=valid_token.spend,
                        max_budget=valid_token.max_budget,
                    )
            max_budget_per_model = valid_token.model_max_budget
            current_model = request_data.get("model", None)
            if (
                max_budget_per_model is not None
                and isinstance(max_budget_per_model, dict)
                and len(max_budget_per_model) > 0
                and prisma_client is not None
                and current_model is not None
                and valid_token.token is not None
            ):
                ## GET THE SPEND FOR THIS MODEL
                twenty_eight_days_ago = datetime.now() - timedelta(days=28)
                model_spend = await prisma_client.db.litellm_spendlogs.group_by(
                    by=["model"],
                    sum={"spend": True},
                    where={
                        "AND": [
                            {"api_key": valid_token.token},
                            {"startTime": {"gt": twenty_eight_days_ago}},
                            {"model": current_model},
                        ]
                    },  # type: ignore
                )
                verbose_proxy_logger.debug(f"model spends >> {model_spend}")
                if (
                    len(model_spend) > 0
                    and max_budget_per_model.get(current_model, None) is not None
                ):
                    if (
                        "model" in model_spend[0]
                        and model_spend[0].get("model") == current_model
                        and "_sum" in model_spend[0]
                        and "spend" in model_spend[0]["_sum"]
                        and model_spend[0]["_sum"]["spend"]
                        >= max_budget_per_model[current_model]
                    ):
                        current_model_spend = model_spend[0]["_sum"]["spend"]
                        current_model_budget = max_budget_per_model[current_model]
                        raise litellm.BudgetExceededError(
                            current_cost=current_model_spend,
                            max_budget=current_model_budget,
                        )
            return valid_token
        else:
            # no valid token found
            raise Exception("Invalid api key passed")

    except Exception as e:
        if isinstance(e, litellm.BudgetExceededError):
            raise ProxyException(
                message=e.message,
                type=ProxyErrorTypes.budget_exceeded,
                param=None,
                code=400,
            )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_401_UNAUTHORIZED,
        )
99 changes: 99 additions & 0 deletions litellm/proxy/budserve_middleware.py
@@ -0,0 +1,99 @@
import json
import os

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.auth_utils import get_request_route
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body


class BudServeMiddleware(BaseHTTPMiddleware):
    llm_request_list = [
        "/chat/completions",
        "/completions",
        "/embeddings",
        "/images/generation",
        "/audio/speech",
        "/audio/transcriptions",
    ]

    async def get_api_key(self, request):
        # expected header format: "Authorization: Bearer <api-key>"
        authorization_header = request.headers.get("Authorization", "")
        parts = authorization_header.split(" ")
        return parts[1] if len(parts) > 1 else ""

    async def dispatch(
        self,
        request,
        call_next,
    ):
        """
        Steps to prepare user_config

        1. Using api_key and model (endpoint_name), fetch all endpoint details: model_list
        2. Using the models involved in the endpoint details, fetch proprietary credentials
        3. Create user_config from model_configuration (endpoint model) and router_config (project model)
        4. Add validations for fallbacks
        """
        route: str = get_request_route(request=request)
        verbose_proxy_logger.info(f"Request: {route}")
        run_through_middleware = any(
            each_route in route for each_route in self.llm_request_list
        )
        verbose_proxy_logger.info(f"Run Through Middleware: {run_through_middleware}")
        if not run_through_middleware:
            return await call_next(request)

        # get the request body
        request_data = await _read_request_body(request=request)
        api_key = await self.get_api_key(request)
        endpoint_name = request_data.get("model")

        # get endpoint details to fill cache_params;
        # redis connection params are set as kubernetes env variables
        # and fetched with os.getenv
        request_data["user_config"] = {
            "cache_responses": False,
            "redis_host": os.getenv("REDIS_HOST", "localhost"),
            "redis_port": int(os.getenv("REDIS_PORT", "6379")),
            "redis_password": os.getenv("REDIS_PASSWORD", ""),
            "endpoint_cache_settings": {
                "cache": False,
                "type": "redis-semantic",  # gpt_cache_redis
                "cache_params": {
                    "host": os.getenv("REDIS_HOST", "localhost"),
                    "port": int(os.getenv("REDIS_PORT", "6379")),
                    "password": os.getenv("REDIS_PASSWORD", ""),
                    "similarity_threshold": 0.8,
                    "redis_semantic_cache_use_async": False,
                    "redis_semantic_cache_embedding_model": "sentence-transformers/all-mpnet-base-v2",
                    "eviction_policy": {"policy": "ttl", "max_size": 100, "ttl": 600},
                },
            },
            "model_list": [
                {
                    "model_name": "gpt4",
                    "litellm_params": {
                        "model": "openai/gpt-3.5-turbo",
                        "api_key": os.getenv("OPENAI_API_KEY", "dummy"),
                        "rpm": 100,
                        "request_timeout": 120,
                    },
                    "model_info": {"id": "model_id:123"},
                },
                {
                    "model_name": "gpt4",
                    "litellm_params": {
                        "model": "openai/gpt-4",
                        "api_key": os.getenv("OPENAI_API_KEY", "dummy"),
                        "tpm": 10000,
                    },
                    "model_info": {"id": "model_id:456"},
                },
            ],
        }
        request._body = json.dumps(request_data).encode("utf-8")
        return await call_next(request)
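For illustration, a request like the following is intercepted by the middleware and has user_config injected into its body before it reaches the router. This is a sketch assuming the proxy listens on localhost:4000 (litellm's default port) and <budserve-key> is a valid BudServe credential:

import httpx

# "gpt4" matches the model_name entries the middleware currently hard-codes
response = httpx.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer <budserve-key>"},
    json={"model": "gpt4", "messages": [{"role": "user", "content": "Hello"}]},
)
print(response.json())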
2 changes: 2 additions & 0 deletions litellm/proxy/proxy_server.py
@@ -135,6 +135,7 @@ def generate_feedback_box():
     get_team_models,
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.budserve_middleware import BudServeMiddleware

 ## Import All Misc routes here ##
 from litellm.proxy.caching_routes import router as caching_router
@@ -450,6 +451,7 @@ async def redirect_ui_middleware(request: Request, call_next):
     allow_methods=["*"],
     allow_headers=["*"],
 )
+app.add_middleware(BudServeMiddleware)


 from typing import Dict
20 changes: 17 additions & 3 deletions litellm/router.py
@@ -207,6 +207,7 @@ def __init__(
         router_general_settings: Optional[
             RouterGeneralSettings
         ] = RouterGeneralSettings(),
+        endpoint_cache_settings: Optional[dict] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -308,7 +309,13 @@ def __init__(
             and redis_port is not None
             and redis_password is not None
         ):
-            cache_type = "redis"
+            cache_type = (
+                "redis"
+                if endpoint_cache_settings is None
+                else endpoint_cache_settings.get("cache_params", {}).get(
+                    "type", "redis"
+                )
+            )

             if redis_url is not None:
                 cache_config["url"] = redis_url
@@ -327,9 +334,16 @@ def __init__(
             redis_cache = RedisCache(**cache_config)

         if cache_responses:
-            if litellm.cache is None:
+            if litellm.cache is None and endpoint_cache_settings is not None:
                 # the cache can be initialized on the proxy server. We should not overwrite it
-                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
+                # user_config : enabled cache
+                enable_cache = endpoint_cache_settings.get("cache", False)
+                if enable_cache:
+                    endpoint_cache_config = endpoint_cache_settings.get(
+                        "cache_params", {}
+                    )
+                    if endpoint_cache_config:
+                        litellm.cache = litellm.Cache(type=cache_type, **endpoint_cache_config)  # type: ignore
             self.cache_responses = cache_responses
         self.cache = DualCache(
             redis_cache=redis_cache, in_memory_cache=InMemoryCache()
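A standalone sketch of the new endpoint_cache_settings parameter in use; the values mirror the middleware's user_config above and are illustrative. Note that this constructor looks up "type" inside cache_params, so the sketch places it there (the middleware currently sets "type" one level higher):

import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt4",
            "litellm_params": {
                "model": "openai/gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY", "dummy"),
            },
        }
    ],
    redis_host=os.getenv("REDIS_HOST", "localhost"),
    redis_port=int(os.getenv("REDIS_PORT", "6379")),
    redis_password=os.getenv("REDIS_PASSWORD", ""),
    cache_responses=True,
    endpoint_cache_settings={
        "cache": True,  # enables the per-endpoint cache branch added above
        "cache_params": {
            "type": "redis-semantic",
            "host": os.getenv("REDIS_HOST", "localhost"),
            "port": int(os.getenv("REDIS_PORT", "6379")),
            "password": os.getenv("REDIS_PASSWORD", ""),
            "similarity_threshold": 0.8,
        },
    },
)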
8 changes: 8 additions & 0 deletions litellm_config.yaml
@@ -0,0 +1,8 @@
general_settings:
  store_model_in_db: True
  custom_auth: litellm.custom_auth.user_api_key_auth
router_settings:
  cache_responses: False
  redis_host: "os.environ/REDIS_HOST"
  redis_port: "os.environ/REDIS_PORT"
  redis_password: "os.environ/REDIS_PASSWORD"
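With this file in place, the proxy can be started with "litellm --config litellm_config.yaml". The custom_auth entry points at the dependency defined in litellm/custom_auth.py above, and the "os.environ/..." values are litellm's convention for resolving a setting from an environment variable at load time, matching the variables in .env.example.global.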