Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for CPU/GPU choice and initialization before starting the app #2

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion open/text/embeddings/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
"""
import uvicorn
import os
from open.text.embeddings.server.app import create_app
from open.text.embeddings.server.app import create_app, initialize_embeddings

if __name__ == "__main__":
initialize_embeddings()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will move the initialize_embeddings() to create_app() as the current approach will break the code of aws.py.

app = create_app()

uvicorn.run(
Expand Down
74 changes: 42 additions & 32 deletions open/text/embeddings/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
import os
import torch

router = APIRouter()

Expand Down Expand Up @@ -62,39 +63,48 @@ class CreateEmbeddingResponse(BaseModel):

embeddings = None


def _create_embedding(
model: Optional[str],
input: Union[str, List[str]]
):
def initialize_embeddings(model: Optional[str] = None):
    """Load the text-embeddings model into the module-level ``embeddings`` global.

    Intended to be called once at startup so the first request does not pay
    the model-loading cost.

    Args:
        model: Optional model name. Ignored when it is falsy or the OpenAI
            placeholder ``"text-embedding-ada-002"``, in which case the
            ``MODEL`` environment variable selects the model.

    Environment variables:
        MODEL: Hugging Face model name used when *model* is not given.
        DEVICE: Explicit device override (e.g. ``"cpu"``, ``"cuda"``).
        NORMALIZE_EMBEDDINGS: Truthy string ("1", "true", "yes", "on",
            case-insensitive) to L2-normalize embeddings.

    Raises:
        KeyError: If no usable *model* is given and ``MODEL`` is unset.
    """
    global embeddings

    # Honor an explicit DEVICE override; otherwise prefer CUDA when
    # available and fall back to CPU.
    if "DEVICE" in os.environ:
        device = os.environ["DEVICE"]
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # "text-embedding-ada-002" is the OpenAI default clients often send;
    # treat it as "use the server-configured model".
    if model and model != "text-embedding-ada-002":
        model_name = model
    else:
        model_name = os.environ["MODEL"]
    print("Loading model:", model_name)

    # Bug fix: bool() on a raw env string treats ANY non-empty value as
    # True, so NORMALIZE_EMBEDDINGS="false" or "0" *enabled* normalization.
    # Parse common truthy spellings explicitly instead.
    normalize = os.environ.get(
        "NORMALIZE_EMBEDDINGS", "").strip().lower() in ("1", "true", "yes", "on")
    encode_kwargs = {
        "normalize_embeddings": normalize
    }
    print("encode_kwargs", encode_kwargs)

    model_kwargs = {"device": device}
    # Pick the wrapper class whose query/embed instructions match the
    # model family (E5, BGE-en, BGE-zh); plain HuggingFaceEmbeddings
    # otherwise.
    if "e5" in model_name:
        embeddings = HuggingFaceInstructEmbeddings(model_name=model_name,
                                                   embed_instruction=E5_EMBED_INSTRUCTION,
                                                   query_instruction=E5_QUERY_INSTRUCTION,
                                                   encode_kwargs=encode_kwargs,
                                                   model_kwargs=model_kwargs)
    elif model_name.startswith("BAAI/bge-") and model_name.endswith("-en"):
        embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
                                              query_instruction=BGE_EN_QUERY_INSTRUCTION,
                                              encode_kwargs=encode_kwargs,
                                              model_kwargs=model_kwargs)
    elif model_name.startswith("BAAI/bge-") and model_name.endswith("-zh"):
        embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
                                              query_instruction=BGE_ZH_QUERY_INSTRUCTION,
                                              encode_kwargs=encode_kwargs,
                                              model_kwargs=model_kwargs)
    else:
        embeddings = HuggingFaceEmbeddings(model_name=model_name,
                                           encode_kwargs=encode_kwargs,
                                           model_kwargs=model_kwargs)


def _create_embedding(input: Union[str, List[str]]):
global embeddings

if isinstance(input, str):
return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(input))])
Expand All @@ -112,5 +122,5 @@ async def create_embedding(
request: CreateEmbeddingRequest
):
return await run_in_threadpool(
_create_embedding, **request.dict(exclude={"user"})
_create_embedding, **request.dict(exclude={"user", "model", "model_config"})
)