fix: Ensure model provided in vLLM inference #820

Merged: 23 commits, Jan 14, 2025
2 changes: 1 addition & 1 deletion presets/ragengine/config.py
@@ -38,7 +38,7 @@
"""

# LLM (Large Language Model) configuration
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/chat")
LLM_INFERENCE_URL = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions")
LLM_ACCESS_SECRET = os.getenv("LLM_ACCESS_SECRET", "default-access-secret")
# LLM_RESPONSE_FIELD = os.getenv("LLM_RESPONSE_FIELD", "result") # Uncomment if needed in the future
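
Note: a minimal sketch of how this default resolves at runtime; the override shown is an assumption about typical usage, not code from this PR.

import os

# Without an override, the engine now targets vLLM's OpenAI-compatible completions route:
url = os.getenv("LLM_INFERENCE_URL", "http://localhost:5000/v1/completions")

# Deployments serving a different endpoint can still export the variable before startup,
# e.g. LLM_INFERENCE_URL=https://api.openai.com/v1/completions
print(url)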

118 changes: 101 additions & 17 deletions presets/ragengine/inference/inference.py
@@ -1,18 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Any
from dataclasses import field
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from requests.exceptions import HTTPError
from urllib.parse import urlparse, urljoin
from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

OPENAI_URL_PREFIX = "https://api.openai.com"
HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"
DEFAULT_HEADERS = {
"Authorization": f"Bearer {LLM_ACCESS_SECRET}",
"Content-Type": "application/json"
}

class Inference(CustomLLM):
params: dict = {}
_default_model: str = None
_model_retrieval_attempted: bool = False

def set_params(self, params: dict) -> None:
self.params = params
@@ -25,7 +39,7 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
pass

@llm_completion_callback()
def complete(self, prompt: str, **kwargs) -> CompletionResponse:
def complete(self, prompt: str, formatted: bool, **kwargs) -> CompletionResponse:
try:
if LLM_INFERENCE_URL.startswith(OPENAI_URL_PREFIX):
return self._openai_complete(prompt, **kwargs, **self.params)
@@ -38,29 +52,99 @@ def complete(self, prompt: str, **kwargs) -> CompletionResponse:
self.params = {}

def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
llm = OpenAI(
api_key=LLM_ACCESS_SECRET,
**kwargs # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc.
)
return llm.complete(prompt)
return OpenAI(api_key=LLM_ACCESS_SECRET, **kwargs).complete(prompt)
Collaborator commented:
I suppose this part does something similar to the vLLM part, but if you want to merge it now for efficiency, that is fine.

def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
data = {"messages": [{"role": "user", "content": prompt}]}
response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response_data = response.json()
return CompletionResponse(text=str(response_data))
return self._post_request(
{"messages": [{"role": "user", "content": prompt}]},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
)

def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
model = kwargs.pop("model", self._get_default_model())
data = {"prompt": prompt, **kwargs}
if model:
data["model"] = model # Include the model only if it is not None

# DEBUG: Call the debugging function
# self._debug_curl_command(data)
try:
return self._post_request(data, headers=DEFAULT_HEADERS)
except HTTPError as e:
if e.response.status_code == 400:
logger.warning(
f"Potential issue with 'model' parameter in API response. "
f"Response: {str(e)}. Attempting to update the model name as a mitigation..."
)
self._default_model = self._fetch_default_model() # Fetch default model dynamically
if self._default_model:
logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...")
data["model"] = self._default_model
return self._post_request(data, headers=DEFAULT_HEADERS)
else:
logger.error("Failed to fetch a default model. Aborting retry.")
raise # Re-raise the exception if not recoverable
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
raise
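
For concreteness, a hedged illustration of the JSON body this method ends up POSTing to a vLLM-style /v1/completions endpoint; the prompt, sampling values, and model name below are placeholders, not values from this PR.

# Hypothetical payload; "model" is filled from kwargs or from _get_default_model().
example_payload = {
    "prompt": "What is KAITO?",
    "temperature": 0.7,
    "max_tokens": 128,
    "model": "example-model-id",  # resolved via /v1/models when the caller omits it
}
# self._post_request(example_payload, headers=DEFAULT_HEADERS) sends this JSON
# to LLM_INFERENCE_URL (default: http://localhost:5000/v1/completions).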

def _get_models_endpoint(self) -> str:
"""
Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL.
"""
parsed = urlparse(LLM_INFERENCE_URL)
return urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models")
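
A small, self-contained illustration of the URL derivation above, using the default inference URL as an assumed input:

from urllib.parse import urlparse, urljoin

inference_url = "http://localhost:5000/v1/completions"  # assumed LLM_INFERENCE_URL
parsed = urlparse(inference_url)
models_url = urljoin(f"{parsed.scheme}://{parsed.netloc}", "/v1/models")
assert models_url == "http://localhost:5000/v1/models"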

def _fetch_default_model(self) -> str:
"""
Fetch the default model from the /v1/models endpoint.
"""
try:
models_url = self._get_models_endpoint()
response = requests.get(models_url, headers=DEFAULT_HEADERS)
response.raise_for_status() # Raise an exception for HTTP errors (includes 404)

models = response.json().get("data", [])
return models[0].get("id") if models else None
except Exception as e:
Collaborator commented:
Can we check for the 404 status code here? Sometimes the service may still be in its initialization phase; I'm thinking we should re-raise every error except a 404.

Collaborator (Author) commented:
The current approach (catching Exception and handling 404 implicitly) is sufficient. It already logs the 404, which can occur if vLLM is initializing or if the endpoint doesn't exist (e.g., a non-vLLM server). Let me know if you have a specific log message in mind to add.

Collaborator commented:
If vLLM is initializing, you will get a network error instead of a 404, because the HTTP server has not been launched yet.

logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.")
return None
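
Following the review discussion above, a hedged sketch of a stricter variant that distinguishes connection failures (server still starting) and 404s (endpoint not exposed) from other HTTP errors; this is a suggestion-shaped example, not code in this PR.

import logging
from typing import Optional

import requests
from requests.exceptions import HTTPError

logger = logging.getLogger(__name__)

def fetch_default_model_strict(models_url: str, headers: dict) -> Optional[str]:
    try:
        response = requests.get(models_url, headers=headers)
        response.raise_for_status()
        models = response.json().get("data", [])
        return models[0].get("id") if models else None
    except requests.exceptions.ConnectionError as e:
        # vLLM may still be initializing, so the HTTP server is not up yet.
        logger.warning(f"Could not reach {models_url} (server may be starting): {e}")
        return None
    except HTTPError as e:
        if e.response is not None and e.response.status_code == 404:
            # Endpoint not exposed (e.g., a non-vLLM backend); omit the "model" parameter.
            logger.warning(f"{models_url} returned 404; 'model' will not be included.")
            return None
        raise  # Any other HTTP error is unexpected and should propagate.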

def _get_default_model(self) -> str:
"""
Returns the cached default model if available, otherwise fetches and caches it.
"""
if not self._default_model and not self._model_retrieval_attempted:
self._model_retrieval_attempted = True
self._default_model = self._fetch_default_model()
return self._default_model

def _post_request(self, data: dict, headers: dict) -> CompletionResponse:
try:
response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers)
response.raise_for_status() # Raise exception for HTTP errors
response_data = response.json()
return CompletionResponse(text=str(response_data))
except requests.RequestException as e:
logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}")
raise

def _debug_curl_command(self, data: dict) -> None:
"""
Constructs and prints the equivalent curl command for debugging purposes.
"""
import json
# Construct curl command
curl_command = (
f"curl -X POST {LLM_INFERENCE_URL} "
+ " ".join([f'-H "{key}: {value}"' for key, value in {
"Authorization": f"Bearer {LLM_ACCESS_SECRET}",
"Content-Type": "application/json"
}.items()])
+ f" -d '{json.dumps(data)}'"
)
logger.info("Equivalent curl command:")
logger.info(curl_command)

@property
def metadata(self) -> LLMMetadata:
14 changes: 10 additions & 4 deletions presets/ragengine/main.py
@@ -60,11 +60,17 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh
@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
try:
llm_params = request.llm_params or {} # Default to empty dict if no params provided
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params)
llm_params = request.llm_params or {} # Default to empty dict if no params provided
rerank_params = request.rerank_params or {} # Default to empty dict if no params provided
return rag_ops.query(
request.index_name, request.query, request.top_k, llm_params, rerank_params
)
except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve)) # Validation issue
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)
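
For reference, a hedged client-side sketch showing how the new error mapping surfaces; the host and port are assumptions about a local deployment, and the expected 400 detail mirrors the updated test below.

import requests

# Assumed local address for the RAG engine service; adjust to your deployment.
resp = requests.post(
    "http://localhost:8000/query",
    json={"index_name": "non_existent_index", "query": "hello", "top_k": 3},
)
# Validation-style failures (e.g., an unknown index) now return 400 rather than 500.
print(resp.status_code, resp.json().get("detail"))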

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
33 changes: 29 additions & 4 deletions presets/ragengine/models.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, model_validator

from pydantic import BaseModel

class Document(BaseModel):
text: str
@@ -22,8 +23,32 @@ class QueryRequest(BaseModel):
index_name: str
query: str
top_k: int = 10
llm_params: Optional[Dict] = None # Accept a dictionary for parameters
rerank_params: Optional[Dict] = None # Accept a dictionary for parameters
# Accept a dictionary for our LLM parameters
llm_params: Optional[Dict[str, Any]] = Field(
default_factory=dict,
description="Optional parameters for the language model, e.g., temperature, top_p",
)
# Accept a dictionary for rerank parameters
rerank_params: Optional[Dict[str, Any]] = Field(
default_factory=dict,
description="Optional parameters for reranking, e.g., top_n, batch_size",
)

@model_validator(mode="before")
def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
llm_params = values.get("llm_params", {})
rerank_params = values.get("rerank_params", {})

# Validate LLM parameters
if "temperature" in llm_params and not (0.0 <= llm_params["temperature"] <= 1.0):
raise ValueError("Temperature must be between 0.0 and 1.0.")
# TODO: More LLM Param Validations here
# Validate rerank parameters
top_k = values.get("top_k")
if "top_n" in rerank_params and rerank_params["top_n"] > top_k:
raise ValueError("Invalid configuration: 'top_n' for reranking cannot exceed 'top_k' from the RAG query.")

return values
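
To illustrate the new validation, a short sketch exercising QueryRequest directly; the import path mirrors the module above, and the parameter values are illustrative only.

from pydantic import ValidationError
from ragengine.models import QueryRequest

# Accepted: optional params within the validated ranges.
QueryRequest(
    index_name="docs",
    query="What is KAITO?",
    top_k=5,
    llm_params={"temperature": 0.7, "max_tokens": 128},
    rerank_params={"top_n": 3},
)

# Rejected: temperature outside [0.0, 1.0].
try:
    QueryRequest(index_name="docs", query="hi", llm_params={"temperature": 1.5})
except ValidationError as e:
    print(e)

# Rejected: rerank 'top_n' larger than the query's 'top_k'.
try:
    QueryRequest(index_name="docs", query="hi", top_k=2, rerank_params={"top_n": 5})
except ValidationError as e:
    print(e)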

class ListDocumentsResponse(BaseModel):
documents: Dict[str, Dict[str, Dict[str, str]]]
2 changes: 1 addition & 1 deletion presets/ragengine/tests/api/test_main.py
@@ -175,7 +175,7 @@ def test_query_index_failure():
}

response = client.post("/query", json=request_data)
assert response.status_code == 500
assert response.status_code == 400
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


2 changes: 1 addition & 1 deletion presets/ragengine/tests/vector_store/test_base_store.py
@@ -90,7 +90,7 @@ def test_query_documents(self, mock_post, vector_store_manager):
mock_post.assert_called_once_with(
LLM_INFERENCE_URL,
json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"}
headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", 'Content-Type': 'application/json'}
)

def test_add_document(self, vector_store_manager):