From da99090f6523a53ac8a3357ee7249671cada526f Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Fri, 13 Dec 2024 09:31:11 +0800 Subject: [PATCH] remove examples gateway. (#979) * remove examples gateway. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove gateway. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refine service code. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update http_service.py * remove gateway ut. * remove gateway ut. * fix conflict service name. * Update http_service.py * add handle message ut. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove `multiprocessing.Process` start server code. * fix ut. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove multiprocessing and enhance ut for coverage. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: chen, suyue --- README.md | 14 - comps/__init__.py | 17 - comps/cores/mega/gateway.py | 1117 ----------------- comps/cores/mega/http_service.py | 32 +- comps/cores/mega/micro_service.py | 103 +- comps/cores/mega/utils.py | 73 ++ tests/cores/mega/test_aio.py | 16 +- tests/cores/mega/test_base_statistics.py | 5 +- tests/cores/mega/test_dynamic_batching.py | 5 +- tests/cores/mega/test_handle_message.py | 133 ++ .../mega/test_hybrid_service_orchestrator.py | 11 +- ...t_hybrid_service_orchestrator_with_yaml.py | 5 +- tests/cores/mega/test_microservice.py | 23 +- .../cores/mega/test_multimodalqna_gateway.py | 213 ---- tests/cores/mega/test_runtime_graph.py | 18 +- tests/cores/mega/test_service_orchestrator.py | 9 +- .../test_service_orchestrator_protocol.py | 5 +- .../test_service_orchestrator_streaming.py | 9 +- .../test_service_orchestrator_with_gateway.py | 52 - ...orchestrator_with_retriever_rerank_fake.py | 15 +- ...rvice_orchestrator_with_videoqnagateway.py | 73 -- .../test_service_orchestrator_with_yaml.py | 10 +- 22 files changed, 358 insertions(+), 1600 deletions(-) delete mode 100644 comps/cores/mega/gateway.py create mode 100644 tests/cores/mega/test_handle_message.py delete mode 100644 tests/cores/mega/test_multimodalqna_gateway.py delete mode 100644 tests/cores/mega/test_service_orchestrator_with_gateway.py delete mode 100644 tests/cores/mega/test_service_orchestrator_with_videoqnagateway.py diff --git a/README.md b/README.md index f8c2557849..3a4b17bad9 100644 --- a/README.md +++ b/README.md @@ -125,20 +125,6 @@ class ExampleService: self.megaservice.flow_to(embedding, llm) ``` -## Gateway - -The `Gateway` serves as the interface for users to access the `Megaservice`, providing customized access based on user requirements. It acts as the entry point for incoming requests, routing them to the appropriate `Microservices` within the `Megaservice` architecture. - -`Gateways` support API definition, API versioning, rate limiting, and request transformation, allowing for fine-grained control over how users interact with the underlying `Microservices`. By abstracting the complexity of the underlying infrastructure, `Gateways` provide a seamless and user-friendly experience for interacting with the `Megaservice`. 
- -For example, the `Gateway` for `ChatQnA` can be built like this: - -```python -from comps import ChatQnAGateway - -self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) -``` - ## Contributing to OPEA Welcome to the OPEA open-source community! We are thrilled to have you here and excited about the potential contributions you can bring to the OPEA platform. Whether you are fixing bugs, adding new GenAI components, improving documentation, or sharing your unique use cases, your contributions are invaluable. diff --git a/comps/__init__.py b/comps/__init__.py index 8fe3ac5fdf..240302c75c 100644 --- a/comps/__init__.py +++ b/comps/__init__.py @@ -47,23 +47,6 @@ from comps.cores.mega.orchestrator import ServiceOrchestrator from comps.cores.mega.orchestrator_with_yaml import ServiceOrchestratorWithYaml from comps.cores.mega.micro_service import MicroService, register_microservice, opea_microservices -from comps.cores.mega.gateway import ( - Gateway, - ChatQnAGateway, - CodeGenGateway, - CodeTransGateway, - DocSumGateway, - TranslationGateway, - SearchQnAGateway, - AudioQnAGateway, - RetrievalToolGateway, - FaqGenGateway, - VideoQnAGateway, - VisualQnAGateway, - MultimodalQnAGateway, - GraphragGateway, - AvatarChatbotGateway, -) # Telemetry from comps.cores.telemetry.opea_telemetry import opea_telemetry diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py deleted file mode 100644 index 29642eea55..0000000000 --- a/comps/cores/mega/gateway.py +++ /dev/null @@ -1,1117 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import base64 -import os -from io import BytesIO -from typing import List, Union - -import requests -from fastapi import File, Request, UploadFile -from fastapi.responses import StreamingResponse -from PIL import Image - -from ..proto.api_protocol import ( - AudioChatCompletionRequest, - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatMessage, - DocSumChatCompletionRequest, - EmbeddingRequest, - UsageInfo, -) -from ..proto.docarray import DocSumDoc, LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc -from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType -from .micro_service import MicroService - - -def read_pdf(file): - from langchain.document_loaders import PyPDFLoader - - loader = PyPDFLoader(file) - docs = loader.load_and_split() - return docs - - -def read_text_from_file(file, save_file_name): - import docx2txt - from langchain.text_splitter import CharacterTextSplitter - - # read text file - if file.headers["content-type"] == "text/plain": - file.file.seek(0) - content = file.file.read().decode("utf-8") - # Split text - text_splitter = CharacterTextSplitter() - texts = text_splitter.split_text(content) - # Create multiple documents - file_content = texts - # read pdf file - elif file.headers["content-type"] == "application/pdf": - documents = read_pdf(save_file_name) - file_content = [doc.page_content for doc in documents] - # read docx file - elif ( - file.headers["content-type"] == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - or file.headers["content-type"] == "application/octet-stream" - ): - file_content = docx2txt.process(save_file_name) - - return file_content - - -class Gateway: - def __init__( - self, - megaservice, - host="0.0.0.0", - port=8888, - endpoint=str(MegaServiceEndpoint.CHAT_QNA), - input_datatype=ChatCompletionRequest, - 
output_datatype=ChatCompletionResponse, - ): - self.megaservice = megaservice - self.host = host - self.port = port - self.endpoint = endpoint - self.input_datatype = input_datatype - self.output_datatype = output_datatype - self.service = MicroService( - self.__class__.__name__, - service_role=ServiceRoleType.MEGASERVICE, - service_type=ServiceType.GATEWAY, - host=self.host, - port=self.port, - endpoint=self.endpoint, - input_datatype=self.input_datatype, - output_datatype=self.output_datatype, - ) - self.define_routes() - self.service.start() - - def define_routes(self): - self.service.app.router.add_api_route(self.endpoint, self.handle_request, methods=["POST"]) - self.service.app.router.add_api_route(str(MegaServiceEndpoint.LIST_SERVICE), self.list_service, methods=["GET"]) - self.service.app.router.add_api_route( - str(MegaServiceEndpoint.LIST_PARAMETERS), self.list_parameter, methods=["GET"] - ) - - def add_route(self, endpoint, handler, methods=["POST"]): - self.service.app.router.add_api_route(endpoint, handler, methods=methods) - - def stop(self): - self.service.stop() - - async def handle_request(self, request: Request): - raise NotImplementedError("Subclasses must implement this method") - - def list_service(self): - response = {} - for node, service in self.megaservice.services.items(): - # Check if the service has a 'description' attribute and it is not None - if hasattr(service, "description") and service.description: - response[node] = {"description": service.description} - # Check if the service has an 'endpoint' attribute and it is not None - if hasattr(service, "endpoint") and service.endpoint: - if node in response: - response[node]["endpoint"] = service.endpoint - else: - response[node] = {"endpoint": service.endpoint} - # If neither 'description' nor 'endpoint' is available, add an error message for the node - if node not in response: - response[node] = {"error": f"Service node {node} does not have 'description' or 'endpoint' attribute."} - return response - - def list_parameter(self): - pass - - def _handle_message(self, messages): - images = [] - if isinstance(messages, str): - prompt = messages - else: - messages_dict = {} - system_prompt = "" - prompt = "" - for message in messages: - msg_role = message["role"] - if msg_role == "system": - system_prompt = message["content"] - elif msg_role == "user": - if type(message["content"]) == list: - text = "" - text_list = [item["text"] for item in message["content"] if item["type"] == "text"] - text += "\n".join(text_list) - image_list = [ - item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" - ] - if image_list: - messages_dict[msg_role] = (text, image_list) - else: - messages_dict[msg_role] = text - else: - messages_dict[msg_role] = message["content"] - elif msg_role == "assistant": - messages_dict[msg_role] = message["content"] - else: - raise ValueError(f"Unknown role: {msg_role}") - - if system_prompt: - prompt = system_prompt + "\n" - for role, message in messages_dict.items(): - if isinstance(message, tuple): - text, image_list = message - if text: - prompt += role + ": " + text + "\n" - else: - prompt += role + ":" - for img in image_list: - # URL - if img.startswith("http://") or img.startswith("https://"): - response = requests.get(img) - image = Image.open(BytesIO(response.content)).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Local Path - elif os.path.exists(img): - image 
= Image.open(img).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Bytes - else: - img_b64_str = img - - images.append(img_b64_str) - else: - if message: - prompt += role + ": " + message + "\n" - else: - prompt += role + ":" - if images: - return prompt, images - else: - return prompt - - -class ChatQnAGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.CHAT_QNA), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - print("data in handle request", data) - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - print("chat request in handle request", chat_request) - prompt = self._handle_message(chat_request.messages) - print("prompt in gateway", prompt) - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - chat_template=chat_request.chat_template if chat_request.chat_template else None, - model=( - chat_request.model - if chat_request.model - else os.getenv("MODEL_ID") if os.getenv("MODEL_ID") else "Intel/neural-chat-7b-v3-3" - ), - ) - retriever_parameters = RetrieverParms( - search_type=chat_request.search_type if chat_request.search_type else "similarity", - k=chat_request.k if chat_request.k else 4, - distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, - fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, - lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, - ) - reranker_parameters = RerankerParms( - top_n=chat_request.top_n if chat_request.top_n else 1, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"text": prompt}, - llm_parameters=parameters, - retriever_parameters=retriever_parameters, - reranker_parameters=reranker_parameters, - ) - for node, response in result_dict.items(): - if isinstance(response, StreamingResponse): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) - - -class CodeGenGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.CODE_GEN), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - prompt = 
self._handle_message(chat_request.messages) - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.2, - streaming=stream_opt, - model=chat_request.model if chat_request.model else None, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"query": prompt}, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LLM. - if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LLM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="codegen", choices=choices, usage=usage) - - -class CodeTransGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.CODE_TRANS), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - language_from = data["language_from"] - language_to = data["language_to"] - source_code = data["source_code"] - prompt_template = """ - ### System: Please translate the following {language_from} codes into {language_to} codes. Don't output any other content except translated codes. - - ### Original {language_from} codes: - ''' - - {source_code} - - ''' - - ### Translated {language_to} codes: - - """ - prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code) - - parameters = LLMParams( - max_tokens=data.get("max_tokens", 1024), - top_k=data.get("top_k", 10), - top_p=data.get("top_p", 0.95), - temperature=data.get("temperature", 0.01), - repetition_penalty=data.get("repetition_penalty", 1.03), - streaming=data.get("stream", True), - ) - - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"query": prompt}, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LLM. 
- if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LLM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage) - - -class TranslationGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.TRANSLATION), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - language_from = data["language_from"] - language_to = data["language_to"] - source_language = data["source_language"] - prompt_template = """ - Translate this from {language_from} to {language_to}: - - {language_from}: - {source_language} - - {language_to}: - """ - prompt = prompt_template.format( - language_from=language_from, language_to=language_to, source_language=source_language - ) - result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt}) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LLM. - if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LLM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="translation", choices=choices, usage=usage) - - -class DocSumGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, - host, - port, - str(MegaServiceEndpoint.DOC_SUMMARY), - input_datatype=DocSumChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) - - async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)): - - if "application/json" in request.headers.get("content-type"): - data = await request.json() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.model_validate(data) - prompt = self._handle_message(chat_request.messages) - - initial_inputs_data = {data["type"]: prompt} - - elif "multipart/form-data" in request.headers.get("content-type"): - data = await request.form() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.model_validate(data) - - data_type = data.get("type") - - file_summaries = [] - if files: - for file in files: - file_path = f"/tmp/{file.filename}" - - if data_type is not None and data_type in ["audio", "video"]: - raise ValueError( - "Audio and Video file uploads are not supported in docsum with curl request, please use the UI." 
- ) - - else: - import aiofiles - - async with aiofiles.open(file_path, "wb") as f: - await f.write(await file.read()) - - docs = read_text_from_file(file, file_path) - os.remove(file_path) - - if isinstance(docs, list): - file_summaries.extend(docs) - else: - file_summaries.append(docs) - - if file_summaries: - prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries) - else: - prompt = self._handle_message(chat_request.messages) - - data_type = data.get("type") - if data_type is not None: - initial_inputs_data = {} - initial_inputs_data[data_type] = prompt - else: - initial_inputs_data = {"query": prompt} - - else: - raise ValueError(f"Unknown request type: {request.headers.get('content-type')}") - - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - model=chat_request.model if chat_request.model else None, - language=chat_request.language if chat_request.language else "auto", - ) - - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs=initial_inputs_data, llm_parameters=parameters - ) - - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LLM. - if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LLM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="docsum", choices=choices, usage=usage) - - -class AudioQnAGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, - host, - port, - str(MegaServiceEndpoint.AUDIO_QNA), - AudioChatCompletionRequest, - ChatCompletionResponse, - ) - - async def handle_request(self, request: Request): - data = await request.json() - - chat_request = AudioChatCompletionRequest.parse_obj(data) - parameters = LLMParams( - # relatively lower max_tokens for audio conversation - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=False, # TODO add streaming LLM output as input to TTS - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters - ) - - last_node = 
runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["byte_str"] - - return response - - -class SearchQnAGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.SEARCH_QNA), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - prompt = self._handle_message(chat_request.messages) - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"text": prompt}, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LLM. - if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LLM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage) - - -class FaqGenGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.FAQ_GEN), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)): - data = await request.form() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - file_summaries = [] - if files: - for file in files: - file_path = f"/tmp/{file.filename}" - - import aiofiles - - async with aiofiles.open(file_path, "wb") as f: - await f.write(await file.read()) - docs = read_text_from_file(file, file_path) - os.remove(file_path) - if isinstance(docs, list): - file_summaries.extend(docs) - else: - file_summaries.append(docs) - - if file_summaries: - prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries) - else: - prompt = self._handle_message(chat_request.messages) - - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - 
streaming=stream_opt, - model=chat_request.model if chat_request.model else None, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"query": prompt}, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LLM. - if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LLM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="faqgen", choices=choices, usage=usage) - - -class VisualQnAGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.VISUAL_QNA), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", False) - chat_request = ChatCompletionRequest.parse_obj(data) - prompt, images = self._handle_message(chat_request.messages) - parameters = LLMParams( - max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"prompt": prompt, "image": images[0]}, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LVM. 
- if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LVM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage) - - -class VideoQnAGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, - host, - port, - str(MegaServiceEndpoint.VIDEO_RAG_QNA), - ChatCompletionRequest, - ChatCompletionResponse, - ) - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", False) - chat_request = ChatCompletionRequest.parse_obj(data) - prompt = self._handle_message(chat_request.messages) - parameters = LLMParams( - max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"text": prompt}, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # Here it suppose the last microservice in the megaservice is LVM. 
- if ( - isinstance(response, StreamingResponse) - and node == list(self.megaservice.services.keys())[-1] - and self.megaservice.services[node].service_type == ServiceType.LVM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="videoqna", choices=choices, usage=usage) - - -class RetrievalToolGateway(Gateway): - """embed+retrieve+rerank.""" - - def __init__(self, megaservice, host="0.0.0.0", port=8889): - super().__init__( - megaservice, - host, - port, - str(MegaServiceEndpoint.RETRIEVALTOOL), - Union[TextDoc, ChatCompletionRequest], - Union[RerankedDoc, LLMParamsDoc], - ) - - async def handle_request(self, request: Request): - def parser_input(data, TypeClass, key): - chat_request = None - try: - chat_request = TypeClass.parse_obj(data) - query = getattr(chat_request, key) - except: - query = None - return query, chat_request - - data = await request.json() - query = None - for key, TypeClass in zip(["text", "messages"], [TextDoc, ChatCompletionRequest]): - query, chat_request = parser_input(data, TypeClass, key) - if query is not None: - break - if query is None: - raise ValueError(f"Unknown request type: {data}") - if chat_request is None: - raise ValueError(f"Unknown request type: {data}") - - if isinstance(chat_request, ChatCompletionRequest): - retriever_parameters = RetrieverParms( - search_type=chat_request.search_type if chat_request.search_type else "similarity", - k=chat_request.k if chat_request.k else 4, - distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, - fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, - lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, - ) - reranker_parameters = RerankerParms( - top_n=chat_request.top_n if chat_request.top_n else 1, - ) - - initial_inputs = { - "messages": query, - "input": query, # has to be input due to embedding expects either input or text - "search_type": chat_request.search_type if chat_request.search_type else "similarity", - "k": chat_request.k if chat_request.k else 4, - "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None, - "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20, - "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2, - "top_n": chat_request.top_n if chat_request.top_n else 1, - } - - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs=initial_inputs, - retriever_parameters=retriever_parameters, - reranker_parameters=reranker_parameters, - ) - else: - result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query}) - - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node] - return response - - -class MultimodalQnAGateway(Gateway): - def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999): - self.lvm_megaservice = lvm_megaservice - super().__init__( - multimodal_rag_megaservice, - host, - port, - str(MegaServiceEndpoint.MULTIMODAL_QNA), - ChatCompletionRequest, - 
ChatCompletionResponse, - ) - - # this overrides _handle_message method of Gateway - def _handle_message(self, messages): - images = [] - messages_dicts = [] - if isinstance(messages, str): - prompt = messages - else: - messages_dict = {} - system_prompt = "" - prompt = "" - for message in messages: - msg_role = message["role"] - messages_dict = {} - if msg_role == "system": - system_prompt = message["content"] - elif msg_role == "user": - if type(message["content"]) == list: - text = "" - text_list = [item["text"] for item in message["content"] if item["type"] == "text"] - text += "\n".join(text_list) - image_list = [ - item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" - ] - if image_list: - messages_dict[msg_role] = (text, image_list) - else: - messages_dict[msg_role] = text - else: - messages_dict[msg_role] = message["content"] - messages_dicts.append(messages_dict) - elif msg_role == "assistant": - messages_dict[msg_role] = message["content"] - messages_dicts.append(messages_dict) - else: - raise ValueError(f"Unknown role: {msg_role}") - - if system_prompt: - prompt = system_prompt + "\n" - for messages_dict in messages_dicts: - for i, (role, message) in enumerate(messages_dict.items()): - if isinstance(message, tuple): - text, image_list = message - if i == 0: - # do not add role for the very first message. - # this will be added by llava_server - if text: - prompt += text + "\n" - else: - if text: - prompt += role.upper() + ": " + text + "\n" - else: - prompt += role.upper() + ":" - for img in image_list: - # URL - if img.startswith("http://") or img.startswith("https://"): - response = requests.get(img) - image = Image.open(BytesIO(response.content)).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Local Path - elif os.path.exists(img): - image = Image.open(img).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Bytes - else: - img_b64_str = img - - images.append(img_b64_str) - else: - if i == 0: - # do not add role for the very first message. - # this will be added by llava_server - if message: - prompt += role.upper() + ": " + message + "\n" - else: - if message: - prompt += role.upper() + ": " + message + "\n" - else: - prompt += role.upper() + ":" - if images: - return prompt, images - else: - return prompt - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = bool(data.get("stream", False)) - if stream_opt: - print("[ MultimodalQnAGateway ] stream=True not used, this has not support streaming yet!") - stream_opt = False - chat_request = ChatCompletionRequest.model_validate(data) - # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. - prompt_and_image = self._handle_message(chat_request.messages) - if isinstance(prompt_and_image, tuple): - # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice") - prompt, images = prompt_and_image - cur_megaservice = self.lvm_megaservice - initial_inputs = {"prompt": prompt, "image": images[0]} - else: - # print(f"This is the first query, requiring multimodal retrieval. 
Using multimodal rag megaservice") - prompt = prompt_and_image - cur_megaservice = self.megaservice - initial_inputs = {"text": prompt} - - parameters = LLMParams( - max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - chat_template=chat_request.chat_template if chat_request.chat_template else None, - ) - result_dict, runtime_graph = await cur_megaservice.schedule( - initial_inputs=initial_inputs, llm_parameters=parameters - ) - for node, response in result_dict.items(): - # the last microservice in this megaservice is LVM. - # checking if LVM returns StreamingResponse - # Currently, LVM with LLAVA has not yet supported streaming. - # @TODO: Will need to test this once LVM with LLAVA supports streaming - if ( - isinstance(response, StreamingResponse) - and node == runtime_graph.all_leaves()[-1] - and self.megaservice.services[node].service_type == ServiceType.LVM - ): - return response - last_node = runtime_graph.all_leaves()[-1] - - if "text" in result_dict[last_node].keys(): - response = result_dict[last_node]["text"] - else: - # text in not response message - # something wrong, for example due to empty retrieval results - if "detail" in result_dict[last_node].keys(): - response = result_dict[last_node]["detail"] - else: - response = "The server fail to generate answer to your query!" 
- if "metadata" in result_dict[last_node].keys(): - # from retrieval results - metadata = result_dict[last_node]["metadata"] - else: - # follow-up question, no retrieval - metadata = None - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - metadata=metadata, - ) - ) - return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage) - - -class AvatarChatbotGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, - host, - port, - str(MegaServiceEndpoint.AVATAR_CHATBOT), - AudioChatCompletionRequest, - ChatCompletionResponse, - ) - - async def handle_request(self, request: Request): - data = await request.json() - - chat_request = AudioChatCompletionRequest.model_validate(data) - parameters = LLMParams( - # relatively lower max_tokens for audio conversation - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, - streaming=False, # TODO add streaming LLM output as input to TTS - ) - # print(parameters) - - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters - ) - - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["video_path"] - return response - - -class GraphragGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=8888): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.GRAPH_RAG), ChatCompletionRequest, ChatCompletionResponse - ) - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - - def parser_input(data, TypeClass, key): - chat_request = None - try: - chat_request = TypeClass.parse_obj(data) - query = getattr(chat_request, key) - except: - query = None - return query, chat_request - - query = None - for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): - query, chat_request = parser_input(data, TypeClass, key) - if query is not None: - break - if query is None: - raise ValueError(f"Unknown request type: {data}") - if chat_request is None: - raise ValueError(f"Unknown request type: {data}") - prompt = self._handle_message(chat_request.messages) - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - streaming=stream_opt, - chat_template=chat_request.chat_template if chat_request.chat_template else None, - ) - retriever_parameters = RetrieverParms( - search_type=chat_request.search_type if chat_request.search_type else 
"similarity", - k=chat_request.k if chat_request.k else 4, - distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, - fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, - lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, - ) - initial_inputs = chat_request - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs=initial_inputs, - llm_parameters=parameters, - retriever_parameters=retriever_parameters, - ) - for node, response in result_dict.items(): - if isinstance(response, StreamingResponse): - return response - last_node = runtime_graph.all_leaves()[-1] - response_content = result_dict[last_node]["choices"][0]["message"]["content"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response_content), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) diff --git a/comps/cores/mega/http_service.py b/comps/cores/mega/http_service.py index 283540f493..799cc5c80c 100644 --- a/comps/cores/mega/http_service.py +++ b/comps/cores/mega/http_service.py @@ -1,7 +1,9 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import asyncio import logging +import multiprocessing import re from typing import Optional @@ -83,6 +85,11 @@ async def _get_statistics(): return app + def add_startup_event(self, func): + @self.app.on_event("startup") + async def startup_event(): + asyncio.create_task(func) + async def initialize_server(self): """Initialize and return HTTP server.""" self.logger.info("Setting up HTTP server") @@ -110,11 +117,9 @@ async def start_server(self, **kwargs): """ await self.main_loop() - app = self.app - self.server = UviServer( config=Config( - app=app, + app=self.app, host=self.host_address, port=self.primary_port, log_level="info", @@ -137,6 +142,24 @@ async def terminate_server(self): await self.server.shutdown() self.logger.info("Server termination completed") + def _async_setup(self): + self.event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.event_loop) + self.event_loop.run_until_complete(self.initialize_server()) + + def start(self): + """Running method to block the main thread. + + This method runs the event loop until a Future is done. It is designed to be called in the main thread to keep it busy. + """ + self.event_loop.run_until_complete(self.execute_server()) + + def stop(self): + self.event_loop.run_until_complete(self.terminate_server()) + self.event_loop.stop() + self.event_loop.close() + self.logger.close() + @staticmethod def check_server_readiness(ctrl_address: str, timeout: float = 1.0, logger=None, **kwargs) -> bool: """Check if server status is ready. @@ -170,3 +193,6 @@ async def async_check_server_readiness(ctrl_address: str, timeout: float = 1.0, :return: True if status is ready else False. 
""" return HTTPService.check_server_readiness(ctrl_address, timeout, logger=logger) + + def add_route(self, endpoint, handler, methods=["POST"]): + self.app.router.add_api_route(endpoint, handler, methods=methods) diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py index 458b097102..2d79d6414f 100644 --- a/comps/cores/mega/micro_service.py +++ b/comps/cores/mega/micro_service.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import multiprocessing import os from collections import defaultdict, deque from enum import Enum @@ -10,6 +9,7 @@ from ..proto.docarray import TextDoc from .constants import ServiceRoleType, ServiceType +from .http_service import HTTPService from .logger import CustomLogger from .utils import check_ports_availability @@ -19,12 +19,12 @@ logflag = os.getenv("LOGFLAG", False) -class MicroService: +class MicroService(HTTPService): """MicroService class to create a microservice.""" def __init__( self, - name: str, + name: str = "", service_role: ServiceRoleType = ServiceRoleType.MICROSERVICE, service_type: ServiceType = ServiceType.LLM, protocol: str = "http", @@ -44,7 +44,6 @@ def __init__( dynamic_batching_max_batch_size: int = 32, ): """Init the microservice.""" - self.name = f"{name}/{self.__class__.__name__}" if name else self.__class__.__name__ self.service_role = service_role self.service_type = service_type self.protocol = protocol @@ -67,24 +66,35 @@ def __init__( self.uvicorn_kwargs["ssl_certfile"] = ssl_certfile if not use_remote_service: + + if self.protocol.lower() == "http": + if not (check_ports_availability(self.host, self.port)): + raise RuntimeError(f"port:{self.port}") + self.provider = provider self.provider_endpoint = provider_endpoint self.endpoints = [] - self.server = self._get_server() - self.app = self.server.app + runtime_args = { + "protocol": self.protocol, + "host": self.host, + "port": self.port, + "title": name, + "description": "OPEA Microservice Infrastructure", + } + + super().__init__(uvicorn_kwargs=self.uvicorn_kwargs, runtime_args=runtime_args) + # create a batch request processor loop if using dynamic batching if self.dynamic_batching: self.buffer_lock = asyncio.Lock() self.request_buffer = defaultdict(deque) + self.add_startup_event(self._dynamic_batch_processor()) - @self.app.on_event("startup") - async def startup_event(): - asyncio.create_task(self._dynamic_batch_processor()) + self._async_setup() - self.event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.event_loop) - self.event_loop.run_until_complete(self._async_setup()) + # overwrite name + self.name = f"{name}/{self.__class__.__name__}" if name else self.__class__.__name__ async def _dynamic_batch_processor(self): if logflag: @@ -125,75 +135,6 @@ def _validate_env(self): "set use_remote_service to False if you want to use a local micro service!" ) - def _get_server(self): - """Get the server instance based on the protocol. - - This method currently only supports HTTP services. It creates an instance of the HTTPService class with the - necessary arguments. - In the future, it will also support gRPC services. - """ - self._validate_env() - from .http_service import HTTPService - - runtime_args = { - "protocol": self.protocol, - "host": self.host, - "port": self.port, - "title": self.name, - "description": "OPEA Microservice Infrastructure", - } - - return HTTPService(uvicorn_kwargs=self.uvicorn_kwargs, runtime_args=runtime_args) - - async def _async_setup(self): - """The async method setup the runtime. 
- - This method is responsible for setting up the server. It first checks if the port is available, then it gets - the server instance and initializes it. - """ - self._validate_env() - if self.protocol.lower() == "http": - if not (check_ports_availability(self.host, self.port)): - raise RuntimeError(f"port:{self.port}") - - await self.server.initialize_server() - - async def _async_run_forever(self): - """Running method of the server.""" - self._validate_env() - await self.server.execute_server() - - def run(self): - """Running method to block the main thread. - - This method runs the event loop until a Future is done. It is designed to be called in the main thread to keep it busy. - """ - self._validate_env() - self.event_loop.run_until_complete(self._async_run_forever()) - - def start(self, in_single_process=False): - self._validate_env() - if in_single_process: - # Resolve HPU segmentation fault and potential tokenizer issues by limiting to same process - self.run() - else: - self.process = multiprocessing.Process(target=self.run, daemon=False, name=self.name) - self.process.start() - - async def _async_teardown(self): - """Shutdown the server.""" - self._validate_env() - await self.server.terminate_server() - - def stop(self): - self._validate_env() - self.event_loop.run_until_complete(self._async_teardown()) - self.event_loop.stop() - self.event_loop.close() - self.server.logger.close() - if self.process.is_alive(): - self.process.terminate() - @property def endpoint_path(self): return f"{self.protocol}://{self.host}:{self.port}{self.endpoint}" diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index e5b2df4f5f..6749e66dea 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -1,15 +1,18 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import base64 import ipaddress import json import multiprocessing import os import random +from io import BytesIO from socket import AF_INET, SOCK_STREAM, socket from typing import List, Optional, Union import requests +from PIL import Image from .logger import CustomLogger @@ -258,3 +261,73 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if exc_type: self.context_to_manage.__exit__(exc_type, exc_val, exc_tb) + + +def handle_message(messages): + images = [] + if isinstance(messages, str): + prompt = messages + else: + messages_dict = {} + system_prompt = "" + prompt = "" + for message in messages: + msg_role = message["role"] + if msg_role == "system": + system_prompt = message["content"] + elif msg_role == "user": + if type(message["content"]) == list: + text = "" + text_list = [item["text"] for item in message["content"] if item["type"] == "text"] + text += "\n".join(text_list) + image_list = [ + item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" + ] + if image_list: + messages_dict[msg_role] = (text, image_list) + else: + messages_dict[msg_role] = text + else: + messages_dict[msg_role] = message["content"] + elif msg_role == "assistant": + messages_dict[msg_role] = message["content"] + else: + raise ValueError(f"Unknown role: {msg_role}") + + if system_prompt: + prompt = system_prompt + "\n" + for role, message in messages_dict.items(): + if isinstance(message, tuple): + text, image_list = message + if text: + prompt += role + ": " + text + "\n" + else: + prompt += role + ":" + for img in image_list: + # URL + if img.startswith("http://") or img.startswith("https://"): + response = requests.get(img) + image = 
Image.open(BytesIO(response.content)).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + # Local Path + elif os.path.exists(img): + image = Image.open(img).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + # Bytes + else: + img_b64_str = img + + images.append(img_b64_str) + else: + if message: + prompt += role + ": " + message + "\n" + else: + prompt += role + ":" + if images: + return prompt, images + else: + return prompt diff --git a/tests/cores/mega/test_aio.py b/tests/cores/mega/test_aio.py index fc735e70aa..4187cb0349 100644 --- a/tests/cores/mega/test_aio.py +++ b/tests/cores/mega/test_aio.py @@ -14,6 +14,7 @@ import asyncio import json +import multiprocessing import time import unittest @@ -55,9 +56,14 @@ def setUp(self): self.s1 = opea_microservices["s1"] self.s2 = opea_microservices["s2"] self.s3 = opea_microservices["s3"] - self.s1.start() - self.s2.start() - self.s3.start() + + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process2 = multiprocessing.Process(target=self.s2.start, daemon=False, name="s2") + self.process3 = multiprocessing.Process(target=self.s3.start, daemon=False, name="s2") + + self.process1.start() + self.process2.start() + self.process3.start() self.service_builder = ServiceOrchestrator() @@ -70,6 +76,10 @@ def tearDown(self): self.s2.stop() self.s3.stop() + self.process1.terminate() + self.process2.terminate() + self.process3.terminate() + async def test_schedule(self): t = time.time() task1 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hello, "})) diff --git a/tests/cores/mega/test_base_statistics.py b/tests/cores/mega/test_base_statistics.py index ef4e7da3e0..878b3016c5 100644 --- a/tests/cores/mega/test_base_statistics.py +++ b/tests/cores/mega/test_base_statistics.py @@ -3,6 +3,7 @@ import asyncio import json +import multiprocessing import time import unittest @@ -34,13 +35,15 @@ async def s1_add(request: TextDoc) -> TextDoc: class TestBaseStatistics(unittest.IsolatedAsyncioTestCase): def setUp(self): self.s1 = opea_microservices["s1"] - self.s1.start() + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process1.start() self.service_builder = ServiceOrchestrator() self.service_builder.add(opea_microservices["s1"]) def tearDown(self): self.s1.stop() + self.process1.terminate() async def test_base_statistics(self): for _ in range(2): diff --git a/tests/cores/mega/test_dynamic_batching.py b/tests/cores/mega/test_dynamic_batching.py index 945054fb0f..bcb185b8fa 100644 --- a/tests/cores/mega/test_dynamic_batching.py +++ b/tests/cores/mega/test_dynamic_batching.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import multiprocessing import unittest from enum import Enum @@ -67,10 +68,12 @@ async def fetch(session, url, data): class TestMicroService(unittest.IsolatedAsyncioTestCase): def setUp(self): - opea_microservices["s1"].start() + self.process1 = multiprocessing.Process(target=opea_microservices["s1"].start, daemon=False, name="s1") + self.process1.start() def tearDown(self): opea_microservices["s1"].stop() + self.process1.terminate() async def test_dynamic_batching(self): url1 = "http://localhost:8080/v1/add1" diff --git a/tests/cores/mega/test_handle_message.py b/tests/cores/mega/test_handle_message.py new file 
mode 100644 index 0000000000..078bcdcd06 --- /dev/null +++ b/tests/cores/mega/test_handle_message.py @@ -0,0 +1,133 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest +from typing import Union + +from comps.cores.mega.utils import handle_message + + +class TestHandleMessage(unittest.IsolatedAsyncioTestCase): + + def test_handle_message(self): + messages = [ + {"role": "user", "content": "opea project! "}, + ] + prompt = handle_message(messages) + self.assertEqual(prompt, "user: opea project! \n") + + def test_handle_message_with_system_prompt(self): + messages = [ + {"role": "system", "content": "System Prompt"}, + {"role": "user", "content": "opea project! "}, + ] + prompt = handle_message(messages) + self.assertEqual(prompt, "System Prompt\nuser: opea project! \n") + + def test_handle_message_with_image(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello, "}, + { + "type": "image_url", + "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + }, + ], + }, + ] + prompt, image = handle_message(messages) + self.assertEqual(prompt, "user: hello, \n") + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": ""}, + { + "type": "image_url", + "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + }, + ], + }, + ] + prompt, image = handle_message(messages) + self.assertEqual(prompt, "user:") + + def test_handle_message_with_image_str(self): + self.img_b64_str = ( + "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello, "}, + { + "type": "image_url", + "image_url": {"url": self.img_b64_str}, + }, + ], + }, + ] + prompt, image = handle_message(messages) + self.assertEqual(image[0], self.img_b64_str) + + def test_handle_message_with_image_local(self): + img_b64_str = ( + "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC" + ) + import base64 + import io + + from PIL import Image + + img = Image.open(io.BytesIO(base64.decodebytes(bytes(img_b64_str, "utf-8")))) + img.save("./test.png") + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello, "}, + { + "type": "image_url", + "image_url": {"url": "./test.png"}, + }, + ], + }, + ] + prompt, image = handle_message(messages) + self.assertEqual(prompt, "user: hello, \n") + + def test_handle_message_with_content_list(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hello, "}, + ], + }, + {"role": "assistant", "content": "opea project! "}, + {"role": "user", "content": ""}, + ] + prompt = handle_message(messages) + self.assertEqual(prompt, "user:assistant: opea project! \n") + + def test_handle_string_message(self): + messages = "hello, " + prompt = handle_message(messages) + self.assertEqual(prompt, "hello, ") + + def test_handle_message_with_invalid_role(self): + messages = [ + {"role": "user_test", "content": "opea project! 
"}, + ] + self.assertRaises(ValueError, handle_message, messages) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cores/mega/test_hybrid_service_orchestrator.py b/tests/cores/mega/test_hybrid_service_orchestrator.py index 0838d25ec8..89522eac3e 100644 --- a/tests/cores/mega/test_hybrid_service_orchestrator.py +++ b/tests/cores/mega/test_hybrid_service_orchestrator.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import unittest from comps import MicroService, ServiceOrchestrator, TextDoc, opea_microservices, register_microservice @@ -19,23 +20,21 @@ async def s1_add(request: TextDoc) -> TextDoc: class TestServiceOrchestrator(unittest.TestCase): def setUp(self): self.s1 = opea_microservices["s1"] - self.s1.start() + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process1.start() self.service_builder = ServiceOrchestrator() def tearDown(self): self.s1.stop() + self.process1.terminate() def test_add_remote_service(self): s2 = MicroService(name="s2", host="fakehost", port=8008, endpoint="/v1/add", use_remote_service=True) self.service_builder.add(opea_microservices["s1"]).add(s2) self.service_builder.flow_to(self.s1, s2) self.assertEqual(s2.endpoint_path, "http://fakehost:8008/v1/add") - # Check whether the right exception is raise when init/stop remote service - try: - s2.start() - except Exception as e: - self.assertTrue("Method not allowed" in str(e)) + self.assertRaises(Exception, s2._validate_env, "N/A") if __name__ == "__main__": diff --git a/tests/cores/mega/test_hybrid_service_orchestrator_with_yaml.py b/tests/cores/mega/test_hybrid_service_orchestrator_with_yaml.py index bd23201841..8d70ab43f0 100644 --- a/tests/cores/mega/test_hybrid_service_orchestrator_with_yaml.py +++ b/tests/cores/mega/test_hybrid_service_orchestrator_with_yaml.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import unittest from comps import ServiceOrchestratorWithYaml, TextDoc, opea_microservices, register_microservice @@ -19,10 +20,12 @@ async def s1_add(request: TextDoc) -> TextDoc: class TestYAMLOrchestrator(unittest.TestCase): def setUp(self) -> None: self.s1 = opea_microservices["s1"] - self.s1.start() + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process1.start() def tearDown(self): self.s1.stop() + self.process1.terminate() def test_add_remote_service(self): service_builder = ServiceOrchestratorWithYaml(yaml_file_path="megaservice_hybrid.yaml") diff --git a/tests/cores/mega/test_microservice.py b/tests/cores/mega/test_microservice.py index dbaff9a760..b621dda5ae 100644 --- a/tests/cores/mega/test_microservice.py +++ b/tests/cores/mega/test_microservice.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import unittest from fastapi.testclient import TestClient -from comps import TextDoc, opea_microservices, register_microservice +from comps import MicroService, TextDoc, opea_microservices, register_microservice @register_microservice(name="s1", host="0.0.0.0", port=8080, endpoint="/v1/add") @@ -18,14 +19,24 @@ async def add(request: TextDoc) -> TextDoc: return {"text": text} +def sum_test(): + return 1 + 1 + + class TestMicroService(unittest.TestCase): def setUp(self): self.client = TestClient(opea_microservices["s1"].app) - opea_microservices["s1"].start() + opea_microservices["s1"].add_route("/v1/sum", sum_test, methods=["GET"]) + self.process1 = 
multiprocessing.Process(target=opea_microservices["s1"].start, daemon=False, name="s1") + + self.process1.start() + + self.assertRaises(RuntimeError, MicroService, name="s2", host="0.0.0.0", port=8080, endpoint="/v1/add") def tearDown(self): opea_microservices["s1"].stop() + self.process1.terminate() def test_add_route(self): response = self.client.post("/v1/add", json={"text": "Hello, "}) @@ -34,6 +45,14 @@ def test_add_route(self): response = self.client.get("/metrics") self.assertEqual(response.status_code, 200) + response = self.client.get("/v1/health_check") + self.assertEqual( + response.json(), {"Service Title": "s1", "Service Description": "OPEA Microservice Infrastructure"} + ) + + response = self.client.get("/v1/sum") + self.assertEqual(response.json(), 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py deleted file mode 100644 index c05bf57bdd..0000000000 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import json -import unittest -from typing import Union - -import requests -from fastapi import Request - -from comps import ( - EmbedDoc, - EmbedMultimodalDoc, - LVMDoc, - LVMSearchedMultimodalDoc, - MultimodalDoc, - MultimodalQnAGateway, - SearchedMultimodalDoc, - ServiceOrchestrator, - TextDoc, - opea_microservices, - register_microservice, -) - - -@register_microservice(name="mm_embedding", host="0.0.0.0", port=8083, endpoint="/v1/mm_embedding") -async def mm_embedding_add(request: MultimodalDoc) -> EmbedDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - text = req_dict["text"] - res = {} - res["text"] = text - res["embedding"] = [0.12, 0.45] - return res - - -@register_microservice(name="mm_retriever", host="0.0.0.0", port=8084, endpoint="/v1/mm_retriever") -async def mm_retriever_add(request: EmbedMultimodalDoc) -> SearchedMultimodalDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - text = req_dict["text"] - res = {} - res["retrieved_docs"] = [] - res["initial_query"] = text - res["top_n"] = 1 - res["metadata"] = [ - { - "b64_img_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", - "transcript_for_inference": "yellow image", - } - ] - res["chat_template"] = "The caption of the image is: '{context}'. {question}" - return res - - -@register_microservice(name="lvm", host="0.0.0.0", port=8085, endpoint="/v1/lvm") -async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - if isinstance(request, LVMSearchedMultimodalDoc): - print("request is the output of multimodal retriever") - text = req_dict["initial_query"] - text += "opea project!" 
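
The updated test_microservice.py above also exercises dynamic route registration through `add_route` alongside the built-in `/v1/health_check` and `/metrics` endpoints. A minimal sketch of that usage with FastAPI's TestClient, assuming a microservice registered under the illustrative name "s1" on port 8080 (handler names and the `/v1/ping` route are hypothetical):

```python
import json

from fastapi.testclient import TestClient

from comps import TextDoc, opea_microservices, register_microservice


@register_microservice(name="s1", host="0.0.0.0", port=8080, endpoint="/v1/add")
async def add(request: TextDoc) -> TextDoc:
    # Toy handler in the same style as the tests: append a suffix to the incoming text.
    req_dict = json.loads(request.model_dump_json())
    return {"text": req_dict["text"] + "world"}


def ping():
    # Plain callable attached at runtime via add_route.
    return "pong"


# Register an extra GET route on the already-declared service.
opea_microservices["s1"].add_route("/v1/ping", ping, methods=["GET"])

# TestClient drives the FastAPI app directly, so no server process is needed.
client = TestClient(opea_microservices["s1"].app)
assert client.post("/v1/add", json={"text": "hello "}).json()["text"] == "hello world"
assert client.get("/v1/ping").json() == "pong"
assert client.get("/v1/health_check").status_code == 200
assert client.get("/metrics").status_code == 200
```
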
- - else: - print("request is from user.") - text = req_dict["prompt"] - text = f"\nUSER: {text}\nASSISTANT:" - - res = {} - res["text"] = text - return res - - -class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls): - cls.mm_embedding = opea_microservices["mm_embedding"] - cls.mm_retriever = opea_microservices["mm_retriever"] - cls.lvm = opea_microservices["lvm"] - cls.mm_embedding.start() - cls.mm_retriever.start() - cls.lvm.start() - - cls.service_builder = ServiceOrchestrator() - - cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add( - opea_microservices["lvm"] - ) - cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever) - cls.service_builder.flow_to(cls.mm_retriever, cls.lvm) - - cls.follow_up_query_service_builder = ServiceOrchestrator() - cls.follow_up_query_service_builder.add(cls.lvm) - - cls.gateway = MultimodalQnAGateway(cls.service_builder, cls.follow_up_query_service_builder, port=9898) - - @classmethod - def tearDownClass(cls): - cls.mm_embedding.stop() - cls.mm_retriever.stop() - cls.lvm.stop() - cls.gateway.stop() - - async def test_service_builder_schedule(self): - result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) - self.assertEqual(result_dict[self.lvm.name]["text"], "hello, opea project!") - - async def test_follow_up_query_service_builder_schedule(self): - result_dict, _ = await self.follow_up_query_service_builder.schedule( - initial_inputs={"prompt": "chao, ", "image": "some image"} - ) - # print(result_dict) - self.assertEqual(result_dict[self.lvm.name]["text"], "\nUSER: chao, \nASSISTANT:") - - def test_MultimodalQnAGateway_gateway(self): - json_data = {"messages": "hello, "} - response = requests.post("http://0.0.0.0:9898/v1/multimodalqna", json=json_data) - response = response.json() - self.assertEqual(response["choices"][-1]["message"]["content"], "hello, opea project!") - - def test_follow_up_MultimodalQnAGateway_gateway(self): - json_data = { - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "hello, "}, - { - "type": "image_url", - "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, - }, - ], - }, - {"role": "assistant", "content": "opea project! "}, - {"role": "user", "content": "chao, "}, - ], - "max_tokens": 300, - } - response = requests.post("http://0.0.0.0:9898/v1/multimodalqna", json=json_data) - response = response.json() - self.assertEqual( - response["choices"][-1]["message"]["content"], - "\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", - ) - - def test_handle_message(self): - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "hello, "}, - { - "type": "image_url", - "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, - }, - ], - }, - {"role": "assistant", "content": "opea project! "}, - {"role": "user", "content": "chao, "}, - ] - prompt, images = self.gateway._handle_message(messages) - self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n") - - def test_handle_message_with_system_prompt(self): - messages = [ - {"role": "system", "content": "System Prompt"}, - { - "role": "user", - "content": [ - {"type": "text", "text": "hello, "}, - { - "type": "image_url", - "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, - }, - ], - }, - {"role": "assistant", "content": "opea project! 
"}, - {"role": "user", "content": "chao, "}, - ] - prompt, images = self.gateway._handle_message(messages) - self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") - - async def test_handle_request(self): - json_data = { - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "hello, "}, - { - "type": "image_url", - "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, - }, - ], - }, - {"role": "assistant", "content": "opea project! "}, - {"role": "user", "content": "chao, "}, - ], - "max_tokens": 300, - } - mock_request = Request(scope={"type": "http"}) - mock_request._json = json_data - res = await self.gateway.handle_request(mock_request) - res = json.loads(res.json()) - self.assertEqual( - res["choices"][-1]["message"]["content"], - "\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/cores/mega/test_runtime_graph.py b/tests/cores/mega/test_runtime_graph.py index 9a140e0b12..e1449d7fc9 100644 --- a/tests/cores/mega/test_runtime_graph.py +++ b/tests/cores/mega/test_runtime_graph.py @@ -1,6 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import multiprocessing import unittest from fastapi.testclient import TestClient @@ -54,10 +55,15 @@ def setUp(self): self.s3 = opea_microservices["s3"] self.s4 = opea_microservices["s4"] - self.s1.start() - self.s2.start() - self.s3.start() - self.s4.start() + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process2 = multiprocessing.Process(target=self.s2.start, daemon=False, name="s2") + self.process3 = multiprocessing.Process(target=self.s3.start, daemon=False, name="s3") + self.process4 = multiprocessing.Process(target=self.s4.start, daemon=False, name="s4") + + self.process1.start() + self.process2.start() + self.process3.start() + self.process4.start() self.service_builder = ServiceOrchestrator() self.service_builder.add(self.s1).add(self.s2).add(self.s3).add(self.s4) @@ -70,6 +76,10 @@ def tearDown(self): self.s2.stop() self.s3.stop() self.s4.stop() + self.process1.terminate() + self.process2.terminate() + self.process3.terminate() + self.process4.terminate() async def test_add_route(self): result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Hi!"}) diff --git a/tests/cores/mega/test_service_orchestrator.py b/tests/cores/mega/test_service_orchestrator.py index bd19d77945..bb3e15df57 100644 --- a/tests/cores/mega/test_service_orchestrator.py +++ b/tests/cores/mega/test_service_orchestrator.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import unittest from comps import ServiceOrchestrator, TextDoc, opea_microservices, register_microservice @@ -30,8 +31,10 @@ class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): def setUpClass(cls): cls.s1 = opea_microservices["s1"] cls.s2 = opea_microservices["s2"] - cls.s1.start() - cls.s2.start() + cls.process1 = multiprocessing.Process(target=cls.s1.start, daemon=False, name="s1") + cls.process2 = multiprocessing.Process(target=cls.s2.start, daemon=False, name="s2") + cls.process1.start() + cls.process2.start() cls.service_builder = ServiceOrchestrator() @@ -42,6 +45,8 @@ def setUpClass(cls): def tearDownClass(cls): cls.s1.stop() cls.s2.stop() + cls.process1.terminate() + cls.process2.terminate() async def test_schedule(self): result_dict, _ = await 
self.service_builder.schedule(initial_inputs={"text": "hello, "}) diff --git a/tests/cores/mega/test_service_orchestrator_protocol.py b/tests/cores/mega/test_service_orchestrator_protocol.py index 9ee2034892..db6cfead8c 100644 --- a/tests/cores/mega/test_service_orchestrator_protocol.py +++ b/tests/cores/mega/test_service_orchestrator_protocol.py @@ -1,6 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import multiprocessing import unittest from comps import ServiceOrchestrator, opea_microservices, register_microservice @@ -16,7 +17,8 @@ async def s1_add(request: ChatCompletionRequest) -> ChatCompletionRequest: class TestServiceOrchestratorProtocol(unittest.IsolatedAsyncioTestCase): def setUp(self): self.s1 = opea_microservices["s1"] - self.s1.start() + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process1.start() self.service_builder = ServiceOrchestrator() @@ -24,6 +26,7 @@ def setUp(self): def tearDown(self): self.s1.stop() + self.process1.terminate() async def test_schedule(self): input_data = ChatCompletionRequest(messages=[{"role": "user", "content": "What's up man?"}], seed=None) diff --git a/tests/cores/mega/test_service_orchestrator_streaming.py b/tests/cores/mega/test_service_orchestrator_streaming.py index d2331dab62..e2d11b1af5 100644 --- a/tests/cores/mega/test_service_orchestrator_streaming.py +++ b/tests/cores/mega/test_service_orchestrator_streaming.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import time import unittest @@ -38,8 +39,10 @@ class TestServiceOrchestratorStreaming(unittest.IsolatedAsyncioTestCase): def setUpClass(cls): cls.s0 = opea_microservices["s0"] cls.s1 = opea_microservices["s1"] - cls.s0.start() - cls.s1.start() + cls.process1 = multiprocessing.Process(target=cls.s0.start, daemon=False, name="s0") + cls.process2 = multiprocessing.Process(target=cls.s1.start, daemon=False, name="s1") + cls.process1.start() + cls.process2.start() cls.service_builder = ServiceOrchestrator() @@ -50,6 +53,8 @@ def setUpClass(cls): def tearDownClass(cls): cls.s0.stop() cls.s1.stop() + cls.process1.terminate() + cls.process2.terminate() async def test_schedule(self): result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) diff --git a/tests/cores/mega/test_service_orchestrator_with_gateway.py b/tests/cores/mega/test_service_orchestrator_with_gateway.py deleted file mode 100644 index 42bad2a2f6..0000000000 --- a/tests/cores/mega/test_service_orchestrator_with_gateway.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import json -import unittest - -from comps import Gateway, ServiceOrchestrator, TextDoc, opea_microservices, register_microservice - - -@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add") -async def s1_add(request: TextDoc) -> TextDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - text = req_dict["text"] - text += "opea " - return {"text": text} - - -@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add") -async def s2_add(request: TextDoc) -> TextDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - text = req_dict["text"] - text += "project!" 
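
Across these orchestrator tests the migration is the same: `MicroService.start()` no longer forks its own process, so each test wraps `start` in an explicit `multiprocessing.Process` in `setUp` and terminates that process in `tearDown` after calling `stop()`. A minimal sketch of the pattern, assuming a handler has already been registered under the illustrative name "s1":

```python
import multiprocessing
import unittest

from comps import ServiceOrchestrator, opea_microservices


class ExampleOrchestratorTest(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.s1 = opea_microservices["s1"]
        # Run the microservice's event loop in a separate process.
        self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1")
        self.process1.start()

        self.service_builder = ServiceOrchestrator()
        self.service_builder.add(self.s1)

    def tearDown(self):
        # Shut the service down, then terminate the helper process.
        self.s1.stop()
        self.process1.terminate()

    async def test_schedule(self):
        result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
        self.assertIn(self.s1.name, result_dict)
```
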
- return {"text": text} - - -class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): - def setUp(self): - self.s1 = opea_microservices["s1"] - self.s2 = opea_microservices["s2"] - self.s1.start() - self.s2.start() - - self.service_builder = ServiceOrchestrator() - - self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"]) - self.service_builder.flow_to(self.s1, self.s2) - self.gateway = Gateway(self.service_builder, port=9898) - - def tearDown(self): - self.s1.stop() - self.s2.stop() - self.gateway.stop() - - async def test_schedule(self): - result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) - self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py b/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py index eb74c5fb19..bc0fe48231 100644 --- a/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py +++ b/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import unittest -from comps import EmbedDoc, Gateway, ServiceOrchestrator, TextDoc, opea_microservices, register_microservice +from comps import EmbedDoc, ServiceOrchestrator, TextDoc, opea_microservices, register_microservice from comps.cores.mega.constants import ServiceType from comps.cores.proto.docarray import RerankerParms, RetrieverParms @@ -45,8 +46,12 @@ class TestServiceOrchestratorParams(unittest.IsolatedAsyncioTestCase): def setUp(self): self.s1 = opea_microservices["s1"] self.s2 = opea_microservices["s2"] - self.s1.start() - self.s2.start() + + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process2 = multiprocessing.Process(target=self.s2.start, daemon=False, name="s2") + + self.process1.start() + self.process2.start() ServiceOrchestrator.align_inputs = align_inputs ServiceOrchestrator.align_outputs = align_outputs @@ -54,12 +59,12 @@ def setUp(self): self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"]) self.service_builder.flow_to(self.s1, self.s2) - self.gateway = Gateway(self.service_builder, port=9898) def tearDown(self): self.s1.stop() self.s2.stop() - self.gateway.stop() + self.process1.terminate() + self.process2.terminate() async def test_retriever_schedule(self): result_dict, _ = await self.service_builder.schedule( diff --git a/tests/cores/mega/test_service_orchestrator_with_videoqnagateway.py b/tests/cores/mega/test_service_orchestrator_with_videoqnagateway.py deleted file mode 100644 index 4905120fbb..0000000000 --- a/tests/cores/mega/test_service_orchestrator_with_videoqnagateway.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import json -import unittest - -from fastapi.responses import StreamingResponse - -from comps import ServiceOrchestrator, ServiceType, TextDoc, VideoQnAGateway, opea_microservices, register_microservice -from comps.cores.proto.docarray import LLMParams - - -@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add") -async def s1_add(request: TextDoc) -> TextDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - text = req_dict["text"] - text += "opea " - return {"text": text} - - -@register_microservice(name="s2", host="0.0.0.0", port=8084, 
endpoint="/v1/add", service_type=ServiceType.LVM) -async def s2_add(request: TextDoc) -> TextDoc: - req = request.model_dump_json() - req_dict = json.loads(req) - text = req_dict["text"] - - def streamer(text): - yield f"{text}".encode("utf-8") - for i in range(3): - yield "project!".encode("utf-8") - - return StreamingResponse(streamer(text), media_type="text/event-stream") - - -class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): - def setUp(self): - self.s1 = opea_microservices["s1"] - self.s2 = opea_microservices["s2"] - self.s1.start() - self.s2.start() - - self.service_builder = ServiceOrchestrator() - - self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"]) - self.service_builder.flow_to(self.s1, self.s2) - self.gateway = VideoQnAGateway(self.service_builder, port=9898) - - def tearDown(self): - self.s1.stop() - self.s2.stop() - self.gateway.stop() - - async def test_schedule(self): - result_dict, _ = await self.service_builder.schedule( - initial_inputs={"text": "hello, "}, llm_parameters=LLMParams(streaming=True) - ) - streaming_response = result_dict[self.s2.name] - - if isinstance(streaming_response, StreamingResponse): - content = b"" - async for chunk in streaming_response.body_iterator: - content += chunk - final_text = content.decode("utf-8") - - print("Streamed content from s2: ", final_text) - - expected_result = "hello, opea project!project!project!" - self.assertEqual(final_text, expected_result) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/cores/mega/test_service_orchestrator_with_yaml.py b/tests/cores/mega/test_service_orchestrator_with_yaml.py index 3a3c6509d3..9da5a77919 100644 --- a/tests/cores/mega/test_service_orchestrator_with_yaml.py +++ b/tests/cores/mega/test_service_orchestrator_with_yaml.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import multiprocessing import unittest from comps import ServiceOrchestratorWithYaml, TextDoc, opea_microservices, register_microservice @@ -29,12 +30,17 @@ class TestYAMLOrchestrator(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.s1 = opea_microservices["s1"] self.s2 = opea_microservices["s2"] - self.s1.start() - self.s2.start() + + self.process1 = multiprocessing.Process(target=self.s1.start, daemon=False, name="s1") + self.process2 = multiprocessing.Process(target=self.s2.start, daemon=False, name="s2") + self.process1.start() + self.process2.start() def tearDown(self): self.s1.stop() self.s2.stop() + self.process1.terminate() + self.process2.terminate() async def test_schedule(self): service_builder = ServiceOrchestratorWithYaml(yaml_file_path="megaservice.yaml")
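
With the gateway layer gone, message flattening now lives in `comps.cores.mega.utils.handle_message`, exercised by the new test_handle_message.py above. A minimal usage sketch, assuming OpenAI-style chat messages (the tiny base64 PNG string below is only illustrative):

```python
from comps.cores.mega.utils import handle_message

# Text-only messages return just the assembled prompt string.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello, "},
]
prompt = handle_message(messages)
assert prompt == "You are a helpful assistant.\nuser: hello, \n"

# A user turn carrying image_url items yields a (prompt, images) tuple.
# http(s) URLs and local paths are re-encoded as base64 PNG strings, while
# strings that are already base64 pass through unchanged.
tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"
mm_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "describe this image"},
            {"type": "image_url", "image_url": {"url": tiny_png_b64}},
        ],
    },
]
prompt, images = handle_message(mm_messages)
assert prompt == "user: describe this image\n"
assert images == [tiny_png_b64]
```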