From 8349d2428ec5df0c312953856b11ecf37b9e3e2a Mon Sep 17 00:00:00 2001 From: root Date: Sun, 20 Oct 2024 16:18:10 +0000 Subject: [PATCH 01/24] move chatqna gateway. --- ChatQnA/chatqna.py | 81 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 6 deletions(-) diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py index 3b25aeeab2..7eef231fa4 100644 --- a/ChatQnA/chatqna.py +++ b/ChatQnA/chatqna.py @@ -6,7 +6,22 @@ import os import re -from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import ( + Gateway, + MicroService, + ServiceOrchestrator, + ServiceType, + MegaServiceEndpoint, +) +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms + from langchain_core.prompts import PromptTemplate @@ -172,14 +187,21 @@ def align_generator(self, gen, **kwargs): yield "data: [DONE]\n\n" -class ChatQnAService: +class ChatQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port ServiceOrchestrator.align_inputs = align_inputs ServiceOrchestrator.align_outputs = align_outputs ServiceOrchestrator.align_generator = align_generator - self.megaservice = ServiceOrchestrator() + + super().__init__(megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CHAT_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse + ) def add_remote_service(self): @@ -222,7 +244,6 @@ def add_remote_service(self): self.megaservice.flow_to(embedding, retriever) self.megaservice.flow_to(retriever, rerank) self.megaservice.flow_to(rerank, llm) - self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) def add_remote_service_without_rerank(self): @@ -255,7 +276,6 @@ def add_remote_service_without_rerank(self): self.megaservice.add(embedding).add(retriever).add(llm) self.megaservice.flow_to(embedding, retriever) self.megaservice.flow_to(retriever, llm) - self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) def add_remote_service_with_guardrails(self): guardrail_in = MicroService( @@ -313,8 +333,55 @@ def add_remote_service_with_guardrails(self): self.megaservice.flow_to(retriever, rerank) self.megaservice.flow_to(rerank, llm) # self.megaservice.flow_to(llm, guardrail_out) - self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + chat_template=chat_request.chat_template if chat_request.chat_template else None, + ) + 
retriever_parameters = RetrieverParms( + search_type=chat_request.search_type if chat_request.search_type else "similarity", + k=chat_request.k if chat_request.k else 4, + distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, + fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, + lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, + ) + reranker_parameters = RerankerParms( + top_n=chat_request.top_n if chat_request.top_n else 1, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": prompt}, + llm_parameters=parameters, + retriever_parameters=retriever_parameters, + reranker_parameters=reranker_parameters, + ) + for node, response in result_dict.items(): + if isinstance(response, StreamingResponse): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -330,3 +397,5 @@ def add_remote_service_with_guardrails(self): chatqna.add_remote_service_with_guardrails() else: chatqna.add_remote_service() + + # chatqna.start() From 3da7d3cfb3ce4e214bda76b0e504e971fbb37513 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 16:18:11 +0000 Subject: [PATCH 02/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ChatQnA/chatqna.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py index 7eef231fa4..afff320be9 100644 --- a/ChatQnA/chatqna.py +++ b/ChatQnA/chatqna.py @@ -6,13 +6,7 @@ import os import re -from comps import ( - Gateway, - MicroService, - ServiceOrchestrator, - ServiceType, - MegaServiceEndpoint, -) +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -21,7 +15,6 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms - from langchain_core.prompts import PromptTemplate @@ -195,12 +188,13 @@ def __init__(self, host="0.0.0.0", port=8000): ServiceOrchestrator.align_outputs = align_outputs ServiceOrchestrator.align_generator = align_generator - super().__init__(megaservice=ServiceOrchestrator(), + super().__init__( + megaservice=ServiceOrchestrator(), host=self.host, port=self.port, endpoint=str(MegaServiceEndpoint.CHAT_QNA), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, ) def add_remote_service(self): @@ -383,6 +377,7 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--without-rerank", action="store_true") From eea3db3b85ac0df30d21e11752175ffe9213ad05 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 20 Oct 2024 16:24:49 +0000 Subject: [PATCH 03/24] fix import 
issue. --- ChatQnA/chatqna.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py index afff320be9..db40398ae0 100644 --- a/ChatQnA/chatqna.py +++ b/ChatQnA/chatqna.py @@ -5,6 +5,8 @@ import json import os import re +from fastapi import Request +from fastapi.responses import StreamingResponse from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( From e51a6d7c33004ee7741f6bc15b6c3e4ed11a5cf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 16:23:54 +0000 Subject: [PATCH 04/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ChatQnA/chatqna.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py index db40398ae0..e6bba4768e 100644 --- a/ChatQnA/chatqna.py +++ b/ChatQnA/chatqna.py @@ -5,8 +5,6 @@ import json import os import re -from fastapi import Request -from fastapi.responses import StreamingResponse from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( @@ -17,6 +15,8 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms +from fastapi import Request +from fastapi.responses import StreamingResponse from langchain_core.prompts import PromptTemplate From a252db33418268dca404a20612f74db41c012e79 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 20 Oct 2024 16:33:47 +0000 Subject: [PATCH 05/24] move codegen gateway. --- CodeGen/codegen.py | 63 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py index ac04cc7876..0ff99fee93 100644 --- a/CodeGen/codegen.py +++ b/CodeGen/codegen.py @@ -4,7 +4,18 @@ import asyncio import os -from comps import CodeGenGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse + MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7778)) @@ -12,11 +23,19 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class CodeGenService: +class CodeGenService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port self.megaservice = ServiceOrchestrator() + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CHAT_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): llm = MicroService( @@ -28,7 +47,45 @@ def add_remote_service(self): service_type=ServiceType.LLM, ) self.megaservice.add(llm) - self.gateway = CodeGenGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + 
parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"query": prompt}, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LLM. + if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LLM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="codegen", choices=choices, usage=usage) if __name__ == "__main__": From 11c1516b364f1b311d13ffd67310fd560232b36a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 16:33:09 +0000 Subject: [PATCH 06/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CodeGen/codegen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py index 0ff99fee93..be7bf77bec 100644 --- a/CodeGen/codegen.py +++ b/CodeGen/codegen.py @@ -16,7 +16,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse - MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7778)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") From 91bae4d3cecbc49a9fe6c487ec62b39d3c2cf3f6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Oct 2024 07:58:59 +0000 Subject: [PATCH 07/24] move code_translation gateway. 
--- CodeGen/codegen.py | 3 +- CodeTrans/code_translation.py | 64 ++++++++++++++++++++++++++++++++--- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py index be7bf77bec..0a7edb04ec 100644 --- a/CodeGen/codegen.py +++ b/CodeGen/codegen.py @@ -26,12 +26,11 @@ class CodeGenService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() super().__init__( megaservice=ServiceOrchestrator(), host=self.host, port=self.port, - endpoint=str(MegaServiceEndpoint.CHAT_QNA), + endpoint=str(MegaServiceEndpoint.CODE_GEN), input_datatype=ChatCompletionRequest, output_datatype=ChatCompletionResponse, ) diff --git a/CodeTrans/code_translation.py b/CodeTrans/code_translation.py index 1957485112..46b210b70f 100644 --- a/CodeTrans/code_translation.py +++ b/CodeTrans/code_translation.py @@ -4,7 +4,17 @@ import asyncio import os -from comps import CodeTransGateway, MicroService, ServiceOrchestrator +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from fastapi import Request +from fastapi.responses import StreamingResponse + MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7777)) @@ -12,11 +22,18 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class CodeTransService: +class CodeTransService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CODE_TRANS), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): llm = MicroService( @@ -27,7 +44,46 @@ def add_remote_service(self): use_remote_service=True, ) self.megaservice.add(llm) - self.gateway = CodeTransGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + language_from = data["language_from"] + language_to = data["language_to"] + source_code = data["source_code"] + prompt_template = """ + ### System: Please translate the following {language_from} codes into {language_to} codes. + + ### Original codes: + '''{language_from} + + {source_code} + + ''' + + ### Translated codes: + """ + prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code) + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt}) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LLM. 
+ if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LLM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage) if __name__ == "__main__": From e11773de0ad5dcf06ddb6c5255c7827911a085be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 07:58:04 +0000 Subject: [PATCH 08/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CodeTrans/code_translation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/CodeTrans/code_translation.py b/CodeTrans/code_translation.py index 46b210b70f..b65101aabf 100644 --- a/CodeTrans/code_translation.py +++ b/CodeTrans/code_translation.py @@ -15,7 +15,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse - MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7777)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") From a07a12b7fd771f5caf48244a226671e5205f00ed Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Oct 2024 09:37:53 +0000 Subject: [PATCH 09/24] update all examples gateway. --- AudioQnA/audioqna.py | 45 ++++++- DocIndexRetriever/retrieval_tool.py | 83 +++++++++++- DocSum/docsum.py | 64 ++++++++- FaqGen/faqgen.py | 64 ++++++++- MultimodalQnA/multimodalqna.py | 194 ++++++++++++++++++++++++++-- SearchQnA/searchqna.py | 64 ++++++++- Translation/translation.py | 61 ++++++++- VideoQnA/videoqna.py | 64 ++++++++- VisualQnA/visualqna.py | 64 ++++++++- 9 files changed, 660 insertions(+), 43 deletions(-) diff --git a/AudioQnA/audioqna.py b/AudioQnA/audioqna.py index 7987620847..a2be880489 100644 --- a/AudioQnA/audioqna.py +++ b/AudioQnA/audioqna.py @@ -4,7 +4,14 @@ import asyncio import os -from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + AudioChatCompletionRequest, + ChatCompletionResponse, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request + MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -16,11 +23,18 @@ TTS_SERVICE_PORT = int(os.getenv("TTS_SERVICE_PORT", 9088)) -class AudioQnAService: +class AudioQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.AUDIO_QNA), + input_datatype=AudioChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): asr = MicroService( @@ -50,7 +64,30 @@ def add_remote_service(self): self.megaservice.add(asr).add(llm).add(tts) self.megaservice.flow_to(asr, llm) self.megaservice.flow_to(llm, tts) - self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def 
handle_request(self, request: Request): + data = await request.json() + + chat_request = AudioChatCompletionRequest.parse_obj(data) + parameters = LLMParams( + # relatively lower max_tokens for audio conversation + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=False, # TODO add streaming LLM output as input to TTS + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters + ) + + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["byte_str"] + + return response if __name__ == "__main__": diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index d3adc8d352..3b6015d341 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -4,7 +4,17 @@ import asyncio import os -from comps import MicroService, RetrievalToolGateway, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest +from comps.cores.proto.docarray import ( + LLMParamsDoc, + RerankedDoc, + RerankerParms, + RetrieverParms, + TextDoc, +) +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889) @@ -16,11 +26,19 @@ RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000) -class RetrievalToolService: +class RetrievalToolService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL), + input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], + output_datatype=Union[RerankedDoc, LLMParamsDoc], + ) def add_remote_service(self): embedding = MicroService( @@ -51,7 +69,64 @@ def add_remote_service(self): self.megaservice.add(embedding).add(retriever).add(rerank) self.megaservice.flow_to(embedding, retriever) self.megaservice.flow_to(retriever, rerank) - self.gateway = RetrievalToolGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + def parser_input(data, TypeClass, key): + chat_request = None + try: + chat_request = TypeClass.parse_obj(data) + query = getattr(chat_request, key) + except: + query = None + return query, chat_request + + data = await request.json() + query = None + for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + query, chat_request = parser_input(data, TypeClass, key) + if query is not None: + break + if query is None: + raise ValueError(f"Unknown request type: {data}") + if chat_request is None: + raise ValueError(f"Unknown request type: {data}") + + if 
isinstance(chat_request, ChatCompletionRequest): + retriever_parameters = RetrieverParms( + search_type=chat_request.search_type if chat_request.search_type else "similarity", + k=chat_request.k if chat_request.k else 4, + distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, + fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, + lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, + ) + reranker_parameters = RerankerParms( + top_n=chat_request.top_n if chat_request.top_n else 1, + ) + + initial_inputs = { + "messages": query, + "input": query, # has to be input due to embedding expects either input or text + "search_type": chat_request.search_type if chat_request.search_type else "similarity", + "k": chat_request.k if chat_request.k else 4, + "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None, + "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20, + "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2, + "top_n": chat_request.top_n if chat_request.top_n else 1, + } + + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs=initial_inputs, + retriever_parameters=retriever_parameters, + reranker_parameters=reranker_parameters, + ) + else: + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query}) + + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node] + return response if __name__ == "__main__": diff --git a/DocSum/docsum.py b/DocSum/docsum.py index fe6d3229c9..010d08f006 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -4,7 +4,17 @@ import asyncio import os -from comps import DocSumGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -12,11 +22,18 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class DocSumService: +class DocSumService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.DOC_SUMMARY), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): llm = MicroService( @@ -28,7 +45,46 @@ def add_remote_service(self): service_type=ServiceType.LLM, ) self.megaservice.add(llm) - self.gateway = DocSumGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if 
chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + language=chat_request.language if chat_request.language else "auto", + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"query": prompt}, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LLM. + if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LLM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="docsum", choices=choices, usage=usage) if __name__ == "__main__": diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index bfcf148711..5ee5178a50 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -4,7 +4,17 @@ import asyncio import os -from comps import FaqGenGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -12,11 +22,19 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class FaqGenService: +class FaqGenService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.FAQ_GEN), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): llm = MicroService( @@ -28,7 +46,45 @@ def add_remote_service(self): service_type=ServiceType.LLM, ) self.megaservice.add(llm) - self.gateway = FaqGenGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if 
chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"query": prompt}, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LLM. + if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LLM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="faqgen", choices=choices, usage=usage) if __name__ == "__main__": diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py index ea1c104dc7..b4561dadc5 100644 --- a/MultimodalQnA/multimodalqna.py +++ b/MultimodalQnA/multimodalqna.py @@ -3,7 +3,17 @@ import os -from comps import MicroService, MultimodalQnAGateway, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -15,13 +25,21 @@ LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9399)) -class MultimodalQnAService: +class MultimodalQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.mmrag_megaservice = ServiceOrchestrator() self.lvm_megaservice = ServiceOrchestrator() + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + def add_remote_service(self): mm_embedding = MicroService( name="embedding", @@ -50,19 +68,173 @@ def add_remote_service(self): ) # for mmrag megaservice - self.mmrag_megaservice.add(mm_embedding).add(mm_retriever).add(lvm) - self.mmrag_megaservice.flow_to(mm_embedding, mm_retriever) - self.mmrag_megaservice.flow_to(mm_retriever, lvm) + self.megaservice.add(mm_embedding).add(mm_retriever).add(lvm) + self.megaservice.flow_to(mm_embedding, mm_retriever) + self.megaservice.flow_to(mm_retriever, lvm) # for lvm megaservice self.lvm_megaservice.add(lvm) - self.gateway = MultimodalQnAGateway( - multimodal_rag_megaservice=self.mmrag_megaservice, - lvm_megaservice=self.lvm_megaservice, - host="0.0.0.0", - port=self.port, + # this overrides _handle_message method of Gateway + def _handle_message(self, messages): + images = [] + messages_dicts = [] + if isinstance(messages, str): + prompt = messages + else: + messages_dict = {} + system_prompt = "" + prompt = "" + for message in messages: + msg_role = message["role"] + messages_dict = {} + if msg_role == "system": + system_prompt = 
message["content"] + elif msg_role == "user": + if type(message["content"]) == list: + text = "" + text_list = [item["text"] for item in message["content"] if item["type"] == "text"] + text += "\n".join(text_list) + image_list = [ + item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" + ] + if image_list: + messages_dict[msg_role] = (text, image_list) + else: + messages_dict[msg_role] = text + else: + messages_dict[msg_role] = message["content"] + messages_dicts.append(messages_dict) + elif msg_role == "assistant": + messages_dict[msg_role] = message["content"] + messages_dicts.append(messages_dict) + else: + raise ValueError(f"Unknown role: {msg_role}") + + if system_prompt: + prompt = system_prompt + "\n" + for messages_dict in messages_dicts: + for i, (role, message) in enumerate(messages_dict.items()): + if isinstance(message, tuple): + text, image_list = message + if i == 0: + # do not add role for the very first message. + # this will be added by llava_server + if text: + prompt += text + "\n" + else: + if text: + prompt += role.upper() + ": " + text + "\n" + else: + prompt += role.upper() + ":" + for img in image_list: + # URL + if img.startswith("http://") or img.startswith("https://"): + response = requests.get(img) + image = Image.open(BytesIO(response.content)).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + # Local Path + elif os.path.exists(img): + image = Image.open(img).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + # Bytes + else: + img_b64_str = img + + images.append(img_b64_str) + else: + if i == 0: + # do not add role for the very first message. + # this will be added by llava_server + if message: + prompt += role.upper() + ": " + message + "\n" + else: + if message: + prompt += role.upper() + ": " + message + "\n" + else: + prompt += role.upper() + ":" + if images: + return prompt, images + else: + return prompt + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = bool(data.get("stream", False)) + if stream_opt: + print("[ MultimodalQnAGateway ] stream=True not used, this has not support streaming yet!") + stream_opt = False + chat_request = ChatCompletionRequest.model_validate(data) + # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. + prompt_and_image = self._handle_message(chat_request.messages) + if isinstance(prompt_and_image, tuple): + # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice") + prompt, images = prompt_and_image + cur_megaservice = self.lvm_megaservice + initial_inputs = {"prompt": prompt, "image": images[0]} + else: + # print(f"This is the first query, requiring multimodal retrieval. 
Using multimodal rag megaservice") + prompt = prompt_and_image + cur_megaservice = self.megaservice + initial_inputs = {"text": prompt} + + parameters = LLMParams( + max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + chat_template=chat_request.chat_template if chat_request.chat_template else None, + ) + result_dict, runtime_graph = await cur_megaservice.schedule( + initial_inputs=initial_inputs, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # the last microservice in this megaservice is LVM. + # checking if LVM returns StreamingResponse + # Currently, LVM with LLAVA has not yet supported streaming. + # @TODO: Will need to test this once LVM with LLAVA supports streaming + if ( + isinstance(response, StreamingResponse) + and node == runtime_graph.all_leaves()[-1] + and self.megaservice.services[node].service_type == ServiceType.LVM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + + if "text" in result_dict[last_node].keys(): + response = result_dict[last_node]["text"] + else: + # text in not response message + # something wrong, for example due to empty retrieval results + if "detail" in result_dict[last_node].keys(): + response = result_dict[last_node]["detail"] + else: + response = "The server fail to generate answer to your query!" 
+ if "metadata" in result_dict[last_node].keys(): + # from retrieval results + metadata = result_dict[last_node]["metadata"] + else: + # follow-up question, no retrieval + metadata = None + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + metadata=metadata, + ) ) + return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage) if __name__ == "__main__": diff --git a/SearchQnA/searchqna.py b/SearchQnA/searchqna.py index fa620d42a6..fe4d00aa62 100644 --- a/SearchQnA/searchqna.py +++ b/SearchQnA/searchqna.py @@ -3,7 +3,17 @@ import os -from comps import MicroService, SearchQnAGateway, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -17,11 +27,19 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class SearchQnAService: +class SearchQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.SEARCH_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): embedding = MicroService( @@ -60,7 +78,45 @@ def add_remote_service(self): self.megaservice.flow_to(embedding, web_retriever) self.megaservice.flow_to(web_retriever, rerank) self.megaservice.flow_to(rerank, llm) - self.gateway = SearchQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": prompt}, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LLM. 
+ if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LLM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage) if __name__ == "__main__": diff --git a/Translation/translation.py b/Translation/translation.py index 86093a166e..54fea03f35 100644 --- a/Translation/translation.py +++ b/Translation/translation.py @@ -15,7 +15,16 @@ import asyncio import os -from comps import MicroService, ServiceOrchestrator, ServiceType, TranslationGateway +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -23,11 +32,18 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class TranslationService: +class TranslationService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.TRANSLATION), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): llm = MicroService( @@ -39,7 +55,44 @@ def add_remote_service(self): service_type=ServiceType.LLM, ) self.megaservice.add(llm) - self.gateway = TranslationGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + language_from = data["language_from"] + language_to = data["language_to"] + source_language = data["source_language"] + prompt_template = """ + Translate this from {language_from} to {language_to}: + + {language_from}: + {source_language} + + {language_to}: + """ + prompt = prompt_template.format( + language_from=language_from, language_to=language_to, source_language=source_language + ) + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt}) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LLM. 
+ if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LLM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="translation", choices=choices, usage=usage) if __name__ == "__main__": diff --git a/VideoQnA/videoqna.py b/VideoQnA/videoqna.py index fc8fe778e5..cced72a943 100644 --- a/VideoQnA/videoqna.py +++ b/VideoQnA/videoqna.py @@ -3,7 +3,17 @@ import os -from comps import MicroService, ServiceOrchestrator, ServiceType, VideoQnAGateway +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -17,11 +27,19 @@ LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9000)) -class VideoQnAService: +class VideoQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8888): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.VIDEO_RAG_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): embedding = MicroService( @@ -60,7 +78,45 @@ def add_remote_service(self): self.megaservice.flow_to(embedding, retriever) self.megaservice.flow_to(retriever, rerank) self.megaservice.flow_to(rerank, lvm) - self.gateway = VideoQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", False) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": prompt}, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LVM. 
+ if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LVM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="videoqna", choices=choices, usage=usage) if __name__ == "__main__": diff --git a/VisualQnA/visualqna.py b/VisualQnA/visualqna.py index 4ad850b006..951c9ed1e8 100644 --- a/VisualQnA/visualqna.py +++ b/VisualQnA/visualqna.py @@ -3,7 +3,17 @@ import os -from comps import MicroService, ServiceOrchestrator, ServiceType, VisualQnAGateway +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams +from fastapi import Request +from fastapi.responses import StreamingResponse MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -11,11 +21,19 @@ LVM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9399)) -class VisualQnAService: +class VisualQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - self.megaservice = ServiceOrchestrator() + + super().__init__( + megaservice=ServiceOrchestrator(), + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.VISUAL_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) def add_remote_service(self): llm = MicroService( @@ -27,7 +45,45 @@ def add_remote_service(self): service_type=ServiceType.LVM, ) self.megaservice.add(llm) - self.gateway = VisualQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", False) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt, images = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"prompt": prompt, "image": images[0]}, llm_parameters=parameters + ) + for node, response in result_dict.items(): + # Here it suppose the last microservice in the megaservice is LVM. 
+ if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LVM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage) if __name__ == "__main__": From ff1b675112cb5ff69d7d304b5c83489ba65a1d04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 09:37:16 +0000 Subject: [PATCH 10/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- AudioQnA/audioqna.py | 6 +----- DocIndexRetriever/retrieval_tool.py | 8 +------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/AudioQnA/audioqna.py b/AudioQnA/audioqna.py index a2be880489..761e779f82 100644 --- a/AudioQnA/audioqna.py +++ b/AudioQnA/audioqna.py @@ -5,14 +5,10 @@ import os from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType -from comps.cores.proto.api_protocol import ( - AudioChatCompletionRequest, - ChatCompletionResponse, -) +from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse from comps.cores.proto.docarray import LLMParams from fastapi import Request - MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0") diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index 3b6015d341..6b785099ef 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -6,13 +6,7 @@ from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest -from comps.cores.proto.docarray import ( - LLMParamsDoc, - RerankedDoc, - RerankerParms, - RetrieverParms, - TextDoc, -) +from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc from fastapi import Request from fastapi.responses import StreamingResponse From 324df2aaac0dbc23e39dfbbb67ecaf7cc37bce72 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Oct 2024 09:44:03 +0000 Subject: [PATCH 11/24] fix import issue. 
--- DocIndexRetriever/retrieval_tool.py | 1 + MultimodalQnA/multimodalqna.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index 6b785099ef..52a1a2ca52 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -3,6 +3,7 @@ import asyncio import os +from typing import Union from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py index b4561dadc5..8fdf30f93b 100644 --- a/MultimodalQnA/multimodalqna.py +++ b/MultimodalQnA/multimodalqna.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import os - +import base64 from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( ChatCompletionRequest, @@ -12,8 +12,12 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams +import requests from fastapi import Request from fastapi.responses import StreamingResponse +from PIL import Image +from io import BytesIO + MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) From cdcfbe9759e38dc99b484b3650f54b256ef9f8d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 09:43:08 +0000 Subject: [PATCH 12/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- MultimodalQnA/multimodalqna.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py index 8fdf30f93b..fe6a9f6c8f 100644 --- a/MultimodalQnA/multimodalqna.py +++ b/MultimodalQnA/multimodalqna.py @@ -1,8 +1,11 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import os import base64 +import os +from io import BytesIO + +import requests from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( ChatCompletionRequest, @@ -12,12 +15,9 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams -import requests from fastapi import Request from fastapi.responses import StreamingResponse from PIL import Image -from io import BytesIO - MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) From 866d72e56414e417b2fcdae23fa989cdbe3377db Mon Sep 17 00:00:00 2001 From: root Date: Tue, 29 Oct 2024 08:19:59 +0000 Subject: [PATCH 13/24] update service start entry. 
--- AudioQnA/audioqna.py | 20 ++++++++++++-------- ChatQnA/chatqna.py | 22 +++++++++++++--------- CodeGen/codegen.py | 20 ++++++++++++-------- CodeTrans/code_translation.py | 20 ++++++++++++-------- DocIndexRetriever/retrieval_tool.py | 21 ++++++++++++--------- DocSum/docsum.py | 20 ++++++++++++-------- FaqGen/faqgen.py | 21 ++++++++++++--------- MultimodalQnA/multimodalqna.py | 21 ++++++++++++--------- SearchQnA/searchqna.py | 21 ++++++++++++--------- Translation/translation.py | 20 ++++++++++++-------- VideoQnA/videoqna.py | 21 ++++++++++++--------- VisualQnA/visualqna.py | 21 ++++++++++++--------- 12 files changed, 145 insertions(+), 103 deletions(-) diff --git a/AudioQnA/audioqna.py b/AudioQnA/audioqna.py index 761e779f82..b6bf93dd04 100644 --- a/AudioQnA/audioqna.py +++ b/AudioQnA/audioqna.py @@ -23,14 +23,7 @@ class AudioQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.AUDIO_QNA), - input_datatype=AudioChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): asr = MicroService( @@ -85,7 +78,18 @@ async def handle_request(self, request: Request): return response + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.AUDIO_QNA), + input_datatype=AudioChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) audioqna.add_remote_service() + audioqna.start() diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py index f8458969bd..45922792ab 100644 --- a/ChatQnA/chatqna.py +++ b/ChatQnA/chatqna.py @@ -191,14 +191,7 @@ def __init__(self, host="0.0.0.0", port=8000): ServiceOrchestrator.align_outputs = align_outputs ServiceOrchestrator.align_generator = align_generator - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.CHAT_QNA), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): @@ -380,6 +373,17 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) + def start(self): + + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CHAT_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -396,4 +400,4 @@ async def handle_request(self, request: Request): else: chatqna.add_remote_service() - # chatqna.start() + chatqna.start() diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py index 0a7edb04ec..4fff2c4816 100644 --- a/CodeGen/codegen.py +++ b/CodeGen/codegen.py @@ -26,14 +26,7 @@ class CodeGenService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.CODE_GEN), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): llm = MicroService( @@ -85,7 
+78,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="codegen", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CODE_GEN), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": chatqna = CodeGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) chatqna.add_remote_service() + chatqna.start() diff --git a/CodeTrans/code_translation.py b/CodeTrans/code_translation.py index b65101aabf..dba14e2196 100644 --- a/CodeTrans/code_translation.py +++ b/CodeTrans/code_translation.py @@ -25,14 +25,7 @@ class CodeTransService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.CODE_TRANS), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): llm = MicroService( @@ -84,7 +77,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CODE_TRANS), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": service_ochestrator = CodeTransService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) service_ochestrator.add_remote_service() + service_ochestrator.start() diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index 52a1a2ca52..c0f8f3feb4 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -25,15 +25,7 @@ class RetrievalToolService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL), - input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], - output_datatype=Union[RerankedDoc, LLMParamsDoc], - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): embedding = MicroService( @@ -123,7 +115,18 @@ def parser_input(data, TypeClass, key): response = result_dict[last_node] return response + def start(): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL), + input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], + output_datatype=Union[RerankedDoc, LLMParamsDoc], + ) + if __name__ == "__main__": chatqna = RetrievalToolService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) chatqna.add_remote_service() + chatqna.start() diff --git a/DocSum/docsum.py b/DocSum/docsum.py index 010d08f006..913febd545 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -26,14 +26,7 @@ class DocSumService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.DOC_SUMMARY), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def 
add_remote_service(self): llm = MicroService( @@ -86,7 +79,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="docsum", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.DOC_SUMMARY), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": docsum = DocSumService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) docsum.add_remote_service() + docsum.start() diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index 5ee5178a50..8ee05f9eca 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -26,15 +26,7 @@ class FaqGenService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.FAQ_GEN), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): llm = MicroService( @@ -86,7 +78,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="faqgen", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.FAQ_GEN), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": faqgen = FaqGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) faqgen.add_remote_service() + faqgen.start() diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py index fe6a9f6c8f..f1d147dd84 100644 --- a/MultimodalQnA/multimodalqna.py +++ b/MultimodalQnA/multimodalqna.py @@ -34,15 +34,7 @@ def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port self.lvm_megaservice = ServiceOrchestrator() - - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): mm_embedding = MicroService( @@ -240,7 +232,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": mmragwithvideos = MultimodalQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) mmragwithvideos.add_remote_service() + mmragwithvideos.start() diff --git a/SearchQnA/searchqna.py b/SearchQnA/searchqna.py index fe4d00aa62..a4162f1b9f 100644 --- a/SearchQnA/searchqna.py +++ b/SearchQnA/searchqna.py @@ -31,15 +31,7 @@ class SearchQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.SEARCH_QNA), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): embedding = 
MicroService( @@ -118,7 +110,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.SEARCH_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": searchqna = SearchQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) searchqna.add_remote_service() + searchqna.start() diff --git a/Translation/translation.py b/Translation/translation.py index 54fea03f35..1fa2ff9d37 100644 --- a/Translation/translation.py +++ b/Translation/translation.py @@ -36,14 +36,7 @@ class TranslationService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.TRANSLATION), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): llm = MicroService( @@ -94,7 +87,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="translation", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.TRANSLATION), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": translation = TranslationService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) translation.add_remote_service() + translation.start() diff --git a/VideoQnA/videoqna.py b/VideoQnA/videoqna.py index cced72a943..f4dc1b9352 100644 --- a/VideoQnA/videoqna.py +++ b/VideoQnA/videoqna.py @@ -31,15 +31,7 @@ class VideoQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8888): self.host = host self.port = port - - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.VIDEO_RAG_QNA), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): embedding = MicroService( @@ -118,7 +110,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="videoqna", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.VIDEO_RAG_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": videoqna = VideoQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) videoqna.add_remote_service() + videoqna.start() diff --git a/VisualQnA/visualqna.py b/VisualQnA/visualqna.py index 951c9ed1e8..2f1cf04d9d 100644 --- a/VisualQnA/visualqna.py +++ b/VisualQnA/visualqna.py @@ -25,15 +25,7 @@ class VisualQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port - - super().__init__( - megaservice=ServiceOrchestrator(), - host=self.host, - port=self.port, - endpoint=str(MegaServiceEndpoint.VISUAL_QNA), - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) + self.megaservice = ServiceOrchestrator() def add_remote_service(self): llm = MicroService( @@ -85,7 +77,18 @@ async 
def handle_request(self, request: Request): ) return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage) + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.VISUAL_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) + if __name__ == "__main__": visualqna = VisualQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) visualqna.add_remote_service() + visualqna.start() From e3985272fc6249c713db0f1361f6845f4e36e071 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 29 Oct 2024 08:30:17 +0000 Subject: [PATCH 14/24] update service start entry. --- DocIndexRetriever/retrieval_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index c0f8f3feb4..bc3463fd2a 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -115,7 +115,7 @@ def parser_input(data, TypeClass, key): response = result_dict[last_node] return response - def start(): + def start(self): super().__init__( megaservice=self.megaservice, host=self.host, From f72347fa98a3a8fe7989503de0b97a42d77a2595 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Oct 2024 13:27:21 +0000 Subject: [PATCH 15/24] remove `MEGA_SERVICE_HOST_IP` which is not actually used. --- AudioQnA/audioqna.py | 3 +-- ChatQnA/chatqna.py | 3 +-- CodeGen/codegen.py | 3 +-- CodeTrans/code_translation.py | 3 +-- DocIndexRetriever/retrieval_tool.py | 3 +-- DocSum/docsum.py | 3 +-- FaqGen/faqgen.py | 3 +-- MultimodalQnA/multimodalqna.py | 3 +-- SearchQnA/searchqna.py | 3 +-- Translation/translation.py | 3 +-- VideoQnA/videoqna.py | 3 +-- VisualQnA/visualqna.py | 3 +-- 12 files changed, 12 insertions(+), 24 deletions(-) diff --git a/AudioQnA/audioqna.py b/AudioQnA/audioqna.py index b6bf93dd04..fbede59489 100644 --- a/AudioQnA/audioqna.py +++ b/AudioQnA/audioqna.py @@ -9,7 +9,6 @@ from comps.cores.proto.docarray import LLMParams from fastapi import Request -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0") ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099)) @@ -90,6 +89,6 @@ def start(self): if __name__ == "__main__": - audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + audioqna = AudioQnAService(port=MEGA_SERVICE_PORT) audioqna.add_remote_service() audioqna.start() diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py index 45922792ab..078a0a9d4e 100644 --- a/ChatQnA/chatqna.py +++ b/ChatQnA/chatqna.py @@ -45,7 +45,6 @@ def generate_rag_prompt(question, documents): return template.format(context=context_str, question=question) -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) GUARDRAIL_SERVICE_HOST_IP = os.getenv("GUARDRAIL_SERVICE_HOST_IP", "0.0.0.0") GUARDRAIL_SERVICE_PORT = int(os.getenv("GUARDRAIL_SERVICE_PORT", 80)) @@ -392,7 +391,7 @@ def start(self): args = parser.parse_args() - chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + chatqna = ChatQnAService(port=MEGA_SERVICE_PORT) if args.without_rerank: chatqna.add_remote_service_without_rerank() elif args.with_guardrails: diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py index 4fff2c4816..5ae4329d08 100644 --- a/CodeGen/codegen.py +++ b/CodeGen/codegen.py 
@@ -16,7 +16,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7778)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) @@ -90,6 +89,6 @@ def start(self): if __name__ == "__main__": - chatqna = CodeGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + chatqna = CodeGenService(port=MEGA_SERVICE_PORT) chatqna.add_remote_service() chatqna.start() diff --git a/CodeTrans/code_translation.py b/CodeTrans/code_translation.py index dba14e2196..c163c847fa 100644 --- a/CodeTrans/code_translation.py +++ b/CodeTrans/code_translation.py @@ -15,7 +15,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7777)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) @@ -89,6 +88,6 @@ def start(self): if __name__ == "__main__": - service_ochestrator = CodeTransService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + service_ochestrator = CodeTransService(port=MEGA_SERVICE_PORT) service_ochestrator.add_remote_service() service_ochestrator.start() diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py index bc3463fd2a..b902b7a20e 100644 --- a/DocIndexRetriever/retrieval_tool.py +++ b/DocIndexRetriever/retrieval_tool.py @@ -11,7 +11,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889) EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") EMBEDDING_SERVICE_PORT = os.getenv("EMBEDDING_SERVICE_PORT", 6000) @@ -127,6 +126,6 @@ def start(self): if __name__ == "__main__": - chatqna = RetrievalToolService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + chatqna = RetrievalToolService(port=MEGA_SERVICE_PORT) chatqna.add_remote_service() chatqna.start() diff --git a/DocSum/docsum.py b/DocSum/docsum.py index 913febd545..f872bcb5c3 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -16,7 +16,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) @@ -91,6 +90,6 @@ def start(self): if __name__ == "__main__": - docsum = DocSumService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + docsum = DocSumService(port=MEGA_SERVICE_PORT) docsum.add_remote_service() docsum.start() diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index 8ee05f9eca..bea78025bf 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -16,7 +16,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) @@ -90,6 +89,6 @@ def start(self): if __name__ == "__main__": - faqgen = FaqGenService(host=MEGA_SERVICE_HOST_IP, 
port=MEGA_SERVICE_PORT) + faqgen = FaqGenService(port=MEGA_SERVICE_PORT) faqgen.add_remote_service() faqgen.start() diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py index f1d147dd84..e9cdc09d4c 100644 --- a/MultimodalQnA/multimodalqna.py +++ b/MultimodalQnA/multimodalqna.py @@ -19,7 +19,6 @@ from fastapi.responses import StreamingResponse from PIL import Image -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) MM_EMBEDDING_SERVICE_HOST_IP = os.getenv("MM_EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") MM_EMBEDDING_PORT_MICROSERVICE = int(os.getenv("MM_EMBEDDING_PORT_MICROSERVICE", 6000)) @@ -244,6 +243,6 @@ def start(self): if __name__ == "__main__": - mmragwithvideos = MultimodalQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + mmragwithvideos = MultimodalQnAService(port=MEGA_SERVICE_PORT) mmragwithvideos.add_remote_service() mmragwithvideos.start() diff --git a/SearchQnA/searchqna.py b/SearchQnA/searchqna.py index a4162f1b9f..1a04a97a60 100644 --- a/SearchQnA/searchqna.py +++ b/SearchQnA/searchqna.py @@ -15,7 +15,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000)) @@ -122,6 +121,6 @@ def start(self): if __name__ == "__main__": - searchqna = SearchQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + searchqna = SearchQnAService(port=MEGA_SERVICE_PORT) searchqna.add_remote_service() searchqna.start() diff --git a/Translation/translation.py b/Translation/translation.py index 1fa2ff9d37..92e8b6f0dd 100644 --- a/Translation/translation.py +++ b/Translation/translation.py @@ -26,7 +26,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) @@ -99,6 +98,6 @@ def start(self): if __name__ == "__main__": - translation = TranslationService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + translation = TranslationService(port=MEGA_SERVICE_PORT) translation.add_remote_service() translation.start() diff --git a/VideoQnA/videoqna.py b/VideoQnA/videoqna.py index f4dc1b9352..7632a1c10b 100644 --- a/VideoQnA/videoqna.py +++ b/VideoQnA/videoqna.py @@ -15,7 +15,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000)) @@ -122,6 +121,6 @@ def start(self): if __name__ == "__main__": - videoqna = VideoQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + videoqna = VideoQnAService(port=MEGA_SERVICE_PORT) videoqna.add_remote_service() videoqna.start() diff --git a/VisualQnA/visualqna.py b/VisualQnA/visualqna.py index 2f1cf04d9d..f6519c1d27 100644 --- a/VisualQnA/visualqna.py +++ b/VisualQnA/visualqna.py @@ -15,7 +15,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) LVM_SERVICE_HOST_IP = os.getenv("LVM_SERVICE_HOST_IP", "0.0.0.0") LVM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9399)) @@ -89,6 +88,6 @@ def start(self): if __name__ == "__main__": - visualqna = VisualQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + visualqna = VisualQnAService(port=MEGA_SERVICE_PORT) visualqna.add_remote_service() visualqna.start() From 76caa00653a873e1e60b1e94863bb1ebc3c2bcde Mon Sep 17 00:00:00 2001 From: lkk Date: Thu, 5 Dec 2024 07:25:46 +0000 Subject: [PATCH 16/24] update docsum/faqgen. --- DocSum/docsum.py | 70 +++++++++++++++++++++++++++++++++++++++++++----- FaqGen/faqgen.py | 30 ++++++++++++++++++--- 2 files changed, 89 insertions(+), 11 deletions(-) diff --git a/DocSum/docsum.py b/DocSum/docsum.py index e07763c60c..f05786ef0d 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -13,8 +13,9 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams -from fastapi import Request +from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse +from typing import List MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -52,11 +53,63 @@ def add_remote_service(self): ) self.megaservice.add(llm) - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - prompt = self._handle_message(chat_request.messages) + + async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)): + + if "application/json" in request.headers.get("content-type"): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.model_validate(data) + prompt = self._handle_message(chat_request.messages) + + initial_inputs_data = {data["type"]: prompt} + + elif "multipart/form-data" in request.headers.get("content-type"): + data = await request.form() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.model_validate(data) + + data_type = data.get("type") + + file_summaries = [] + if files: + for file in files: + file_path = f"/tmp/{file.filename}" + + if data_type is not None and data_type in ["audio", "video"]: + raise ValueError( + "Audio and Video file uploads are not supported in docsum with curl request, please use the UI." 
+ ) + + else: + import aiofiles + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await file.read()) + + docs = read_text_from_file(file, file_path) + os.remove(file_path) + + if isinstance(docs, list): + file_summaries.extend(docs) + else: + file_summaries.append(docs) + + if file_summaries: + prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries) + else: + prompt = self._handle_message(chat_request.messages) + + data_type = data.get("type") + if data_type is not None: + initial_inputs_data = {} + initial_inputs_data[data_type] = prompt + else: + initial_inputs_data = {"query": prompt} + + else: + raise ValueError(f"Unknown request type: {request.headers.get('content-type')}") + parameters = LLMParams( max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, top_k=chat_request.top_k if chat_request.top_k else 10, @@ -66,11 +119,14 @@ async def handle_request(self, request: Request): presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, streaming=stream_opt, + model=chat_request.model if chat_request.model else None, language=chat_request.language if chat_request.language else "auto", ) + result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"query": prompt}, llm_parameters=parameters + initial_inputs=initial_inputs_data, llm_parameters=parameters ) + for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index bea78025bf..cf3ebee383 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -13,8 +13,9 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams -from fastapi import Request +from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse +from typing import List MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") @@ -38,11 +39,31 @@ def add_remote_service(self): ) self.megaservice.add(llm) - async def handle_request(self, request: Request): - data = await request.json() + async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)): + data = await request.form() stream_opt = data.get("stream", True) chat_request = ChatCompletionRequest.parse_obj(data) - prompt = self._handle_message(chat_request.messages) + file_summaries = [] + if files: + for file in files: + file_path = f"/tmp/{file.filename}" + + import aiofiles + + async with aiofiles.open(file_path, "wb") as f: + await f.write(await file.read()) + docs = read_text_from_file(file, file_path) + os.remove(file_path) + if isinstance(docs, list): + file_summaries.extend(docs) + else: + file_summaries.append(docs) + + if file_summaries: + prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries) + else: + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, top_k=chat_request.top_k if chat_request.top_k else 10, @@ -52,6 +73,7 @@ async def handle_request(self, request: Request): presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, streaming=stream_opt, + model=chat_request.model if chat_request.model else None, 
) result_dict, runtime_graph = await self.megaservice.schedule( initial_inputs={"query": prompt}, llm_parameters=parameters From d13bf2a41ed9b6b485fc6b793a79d04f8485e7eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:22:18 +0000 Subject: [PATCH 17/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- DocSum/docsum.py | 3 +-- FaqGen/faqgen.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/DocSum/docsum.py b/DocSum/docsum.py index f05786ef0d..039cf5c67e 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -3,6 +3,7 @@ import asyncio import os +from typing import List from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( @@ -15,7 +16,6 @@ from comps.cores.proto.docarray import LLMParams from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse -from typing import List MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -53,7 +53,6 @@ def add_remote_service(self): ) self.megaservice.add(llm) - async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)): if "application/json" in request.headers.get("content-type"): diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index cf3ebee383..40bd9af812 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -3,6 +3,7 @@ import asyncio import os +from typing import List from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType from comps.cores.proto.api_protocol import ( @@ -15,7 +16,6 @@ from comps.cores.proto.docarray import LLMParams from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse -from typing import List MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") From b1d5422c404fd2e304d474e5f728cd599e2026db Mon Sep 17 00:00:00 2001 From: lkk Date: Thu, 5 Dec 2024 07:31:02 +0000 Subject: [PATCH 18/24] update docsum/faqgen. 
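The multipart upload path added to `handle_request` in the previous DocSum/FaqGen patch already calls `read_text_from_file(file, file_path)`, but neither module imported the helper, so any file upload failed with a `NameError` at request time. Import it from `comps.cores.mega.gateway`, where the legacy gateway implementations keep it. A minimal client-side sketch of the request this unblocks (host, port, route, and form-field names are illustrative, patterned on the DocSum compose tests; the service code itself is unchanged by this patch):

    # Hypothetical client: POST a text file to the DocSum megaservice as
    # multipart/form-data. handle_request() reads the file via
    # read_text_from_file() and appends its text to the prompt.
    import requests

    with open("sample.txt", "rb") as f:
        resp = requests.post(
            "http://localhost:8888/v1/docsum",
            data={"type": "text", "messages": "summarize this document"},
            files=[("files", ("sample.txt", f, "text/plain"))],
        )
    print(resp.text)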
--- DocSum/docsum.py | 1 + FaqGen/faqgen.py | 1 + 2 files changed, 2 insertions(+) diff --git a/DocSum/docsum.py b/DocSum/docsum.py index 039cf5c67e..13f7df6623 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -14,6 +14,7 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams +from comps.cores.mega.gateway import read_text_from_file from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index 40bd9af812..a12e5ce4d0 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -14,6 +14,7 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams +from comps.cores.mega.gateway import read_text_from_file from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse From 7f1f52d3b134e54f19c1a9ae482a4f9abecda669 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 07:27:44 +0000 Subject: [PATCH 19/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- DocSum/docsum.py | 2 +- FaqGen/faqgen.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DocSum/docsum.py b/DocSum/docsum.py index 13f7df6623..39a13b01af 100644 --- a/DocSum/docsum.py +++ b/DocSum/docsum.py @@ -6,6 +6,7 @@ from typing import List from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.mega.gateway import read_text_from_file from comps.cores.proto.api_protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -14,7 +15,6 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams -from comps.cores.mega.gateway import read_text_from_file from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py index a12e5ce4d0..9ebe3d673f 100644 --- a/FaqGen/faqgen.py +++ b/FaqGen/faqgen.py @@ -6,6 +6,7 @@ from typing import List from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.mega.gateway import read_text_from_file from comps.cores.proto.api_protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -14,7 +15,6 @@ UsageInfo, ) from comps.cores.proto.docarray import LLMParams -from comps.cores.mega.gateway import read_text_from_file from fastapi import File, Request, UploadFile from fastapi.responses import StreamingResponse From 082caee1354c50da8dc235500d5bdfbdfe292365 Mon Sep 17 00:00:00 2001 From: lkk Date: Thu, 5 Dec 2024 09:35:37 +0000 Subject: [PATCH 20/24] move more gateways.
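Fold the remaining standalone gateway classes into their example services, following the same migration applied to the other examples earlier in this series: `AudioQnAService` (multilang), `AvatarChatbotService`, `EdgeCraftRagService`, and `GraphRAGService` now subclass `Gateway`, take over their former gateway's `handle_request`, and defer the FastAPI wiring to an explicit `start()` call so the orchestration graph is fully assembled before the server comes up. A minimal sketch of the shared pattern (`ExampleService` is a placeholder; each real service substitutes its own endpoint constant, such as `AUDIO_QNA`, `AVATAR_CHATBOT`, `CHAT_QNA`, or `GRAPH_RAG`, and its own request/response types):

    from comps import Gateway, MegaServiceEndpoint, ServiceOrchestrator
    from comps.cores.proto.api_protocol import ChatCompletionRequest, ChatCompletionResponse
    from fastapi import Request


    class ExampleService(Gateway):
        def __init__(self, host="0.0.0.0", port=8000):
            self.host = host
            self.port = port
            # Only build the orchestrator here; Gateway.__init__ is deferred.
            self.megaservice = ServiceOrchestrator()

        async def handle_request(self, request: Request):
            # Each example parses its own request type and schedules
            # self.megaservice here.
            ...

        def start(self):
            # Gateway.__init__ registers the endpoint and brings up the HTTP
            # service, so it must run only after add_remote_service() has
            # populated the graph.
            super().__init__(
                megaservice=self.megaservice,
                host=self.host,
                port=self.port,
                endpoint=str(MegaServiceEndpoint.CHAT_QNA),  # per-example endpoint
                input_datatype=ChatCompletionRequest,
                output_datatype=ChatCompletionResponse,
            )

Every `__main__` block accordingly becomes: construct the service, call `add_remote_service()`, then `start()`.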
--- AudioQnA/audioqna_multilang.py | 46 ++++++++++++++-- AvatarChatbot/avatarchatbot.py | 45 ++++++++++++++-- EdgeCraftRAG/chatqna.py | 48 +++++++++-------- GraphRAG/graphrag.py | 96 ++++++++++++++++++++++++++++++++-- 4 files changed, 197 insertions(+), 38 deletions(-) diff --git a/AudioQnA/audioqna_multilang.py b/AudioQnA/audioqna_multilang.py index 8a4ffdd01a..33a1e1d61a 100644 --- a/AudioQnA/audioqna_multilang.py +++ b/AudioQnA/audioqna_multilang.py @@ -5,9 +5,11 @@ import base64 import os -from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse +from comps.cores.proto.docarray import LLMParams +from fastapi import Request -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) WHISPER_SERVER_HOST_IP = os.getenv("WHISPER_SERVER_HOST_IP", "0.0.0.0") @@ -52,7 +54,7 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_di return data -class AudioQnAService: +class AudioQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port @@ -90,9 +92,43 @@ def add_remote_service(self): self.megaservice.add(asr).add(llm).add(tts) self.megaservice.flow_to(asr, llm) self.megaservice.flow_to(llm, tts) - self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + + chat_request = AudioChatCompletionRequest.parse_obj(data) + parameters = LLMParams( + # relatively lower max_tokens for audio conversation + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=False, # TODO add streaming LLM output as input to TTS + ) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters + ) + + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["byte_str"] + + return response + + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.AUDIO_QNA), + input_datatype=AudioChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) if __name__ == "__main__": - audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + audioqna = AudioQnAService(port=MEGA_SERVICE_PORT) audioqna.add_remote_service() + audioqna.start() diff --git a/AvatarChatbot/avatarchatbot.py b/AvatarChatbot/avatarchatbot.py index b72ae49c5f..0893fc6f93 100644 --- a/AvatarChatbot/avatarchatbot.py +++ b/AvatarChatbot/avatarchatbot.py @@ -5,9 +5,11 @@ import os import sys -from comps import AvatarChatbotGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from 
comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse +from comps.cores.proto.docarray import LLMParams +from fastapi import Request -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0") ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099)) @@ -27,7 +29,7 @@ def check_env_vars(env_var_list): print("All environment variables are set.") -class AvatarChatbotService: +class AvatarChatbotService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port @@ -70,7 +72,39 @@ def add_remote_service(self): self.megaservice.flow_to(asr, llm) self.megaservice.flow_to(llm, tts) self.megaservice.flow_to(tts, animation) - self.gateway = AvatarChatbotGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + async def handle_request(self, request: Request): + data = await request.json() + + chat_request = AudioChatCompletionRequest.model_validate(data) + parameters = LLMParams( + # relatively lower max_tokens for audio conversation + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, + streaming=False, # TODO add streaming LLM output as input to TTS + ) + # print(parameters) + + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters + ) + + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["video_path"] + return response + + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.AVATAR_CHATBOT), + input_datatype=AudioChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) if __name__ == "__main__": @@ -89,5 +123,6 @@ def add_remote_service(self): ] ) - avatarchatbot = AvatarChatbotService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + avatarchatbot = AvatarChatbotService(port=MEGA_SERVICE_PORT) avatarchatbot.add_remote_service() + avatarchatbot.start() diff --git a/EdgeCraftRAG/chatqna.py b/EdgeCraftRAG/chatqna.py index 02f0a84dd8..31c701b5fa 100644 --- a/EdgeCraftRAG/chatqna.py +++ b/EdgeCraftRAG/chatqna.py @@ -5,7 +5,6 @@ from comps import MicroService, ServiceOrchestrator, ServiceType -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "127.0.0.1") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 16011)) PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1") PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010)) @@ -23,11 +22,22 @@ from fastapi.responses import StreamingResponse -class EdgeCraftRagGateway(Gateway): - def __init__(self, megaservice, host="0.0.0.0", port=16011): - super().__init__( - megaservice, host, port, str(MegaServiceEndpoint.CHAT_QNA), ChatCompletionRequest, ChatCompletionResponse +class EdgeCraftRagService(Gateway): + def __init__(self, host="0.0.0.0", port=16010): + self.host = host + self.port = port + self.megaservice = ServiceOrchestrator() + + def add_remote_service(self): + edgecraftrag = MicroService( + name="pipeline", + host=PIPELINE_SERVICE_HOST_IP, + 
port=PIPELINE_SERVICE_PORT, + endpoint="/v1/chatqna", + use_remote_service=True, + service_type=ServiceType.LLM, ) + self.megaservice.add(edgecraftrag) async def handle_request(self, request: Request): input = await request.json() @@ -61,26 +71,18 @@ async def handle_request(self, request: Request): ) return ChatCompletionResponse(model="edgecraftrag", choices=choices, usage=usage) - -class EdgeCraftRagService: - def __init__(self, host="0.0.0.0", port=16010): - self.host = host - self.port = port - self.megaservice = ServiceOrchestrator() - - def add_remote_service(self): - edgecraftrag = MicroService( - name="pipeline", - host=PIPELINE_SERVICE_HOST_IP, - port=PIPELINE_SERVICE_PORT, - endpoint="/v1/chatqna", - use_remote_service=True, - service_type=ServiceType.LLM, + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CHAT_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, ) - self.megaservice.add(edgecraftrag) - self.gateway = EdgeCraftRagGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) if __name__ == "__main__": - edgecraftrag = EdgeCraftRagService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + edgecraftrag = EdgeCraftRagService(port=MEGA_SERVICE_PORT) edgecraftrag.add_remote_service() + edgecraftrag.start() diff --git a/GraphRAG/graphrag.py b/GraphRAG/graphrag.py index a6fd8b0939..eb18fa90d5 100644 --- a/GraphRAG/graphrag.py +++ b/GraphRAG/graphrag.py @@ -6,7 +6,18 @@ import os import re -from comps import GraphragGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, + EmbeddingRequest, +) +from comps.cores.proto.docarray import TextDoc, LLMParams, RetrieverParms +from fastapi import Request +from fastapi.responses import StreamingResponse from langchain_core.prompts import PromptTemplate @@ -35,7 +46,6 @@ def generate_rag_prompt(question, documents): return template.format(context=context_str, question=question) -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0") RETRIEVER_SERVICE_PORT = int(os.getenv("RETRIEVER_SERVICE_PORT", 7000)) @@ -117,7 +127,7 @@ def align_generator(self, gen, **kwargs): yield "data: [DONE]\n\n" -class GraphRAGService: +class GraphRAGService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port @@ -146,9 +156,85 @@ def add_remote_service(self): ) self.megaservice.add(retriever).add(llm) self.megaservice.flow_to(retriever, llm) - self.gateway = GraphragGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + + def parser_input(data, TypeClass, key): + chat_request = None + try: + chat_request = TypeClass.parse_obj(data) + query = getattr(chat_request, key) + except: + query = None + return query, chat_request + + query = None + for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + query, chat_request = parser_input(data, 
TypeClass, key) + if query is not None: + break + if query is None: + raise ValueError(f"Unknown request type: {data}") + if chat_request is None: + raise ValueError(f"Unknown request type: {data}") + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + chat_template=chat_request.chat_template if chat_request.chat_template else None, + ) + retriever_parameters = RetrieverParms( + search_type=chat_request.search_type if chat_request.search_type else "similarity", + k=chat_request.k if chat_request.k else 4, + distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, + fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, + lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, + ) + initial_inputs = chat_request + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs=initial_inputs, + llm_parameters=parameters, + retriever_parameters=retriever_parameters, + ) + for node, response in result_dict.items(): + if isinstance(response, StreamingResponse): + return response + last_node = runtime_graph.all_leaves()[-1] + response_content = result_dict[last_node]["choices"][0]["message"]["content"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response_content), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) + + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.GRAPH_RAG), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) if __name__ == "__main__": - graphrag = GraphRAGService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + graphrag = GraphRAGService(port=MEGA_SERVICE_PORT) graphrag.add_remote_service() + graphrag.start() From 3fbefd4f28ae8e5642ae1348b19653a9783545cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 09:32:03 +0000 Subject: [PATCH 21/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- GraphRAG/graphrag.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/GraphRAG/graphrag.py b/GraphRAG/graphrag.py index eb18fa90d5..f675095147 100644 --- a/GraphRAG/graphrag.py +++ b/GraphRAG/graphrag.py @@ -12,10 +12,10 @@ ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, - UsageInfo, EmbeddingRequest, + UsageInfo, ) -from comps.cores.proto.docarray import TextDoc, LLMParams, RetrieverParms +from comps.cores.proto.docarray import LLMParams, RetrieverParms, TextDoc from fastapi import Request from fastapi.responses import StreamingResponse from langchain_core.prompts 
import PromptTemplate @@ -157,7 +157,6 @@ def add_remote_service(self): self.megaservice.add(retriever).add(llm) self.megaservice.flow_to(retriever, llm) - async def handle_request(self, request: Request): data = await request.json() stream_opt = data.get("stream", True) From 6acdb758dfd5ce10342e52de241fc8c9a5305247 Mon Sep 17 00:00:00 2001 From: lkk Date: Fri, 6 Dec 2024 02:27:18 +0000 Subject: [PATCH 22/24] fix 2 examples. --- ChatQnA/chatqna_wrapper.py | 78 +++++++++++++++++++++++++-- DocSum/tests/test_compose_on_gaudi.sh | 6 +-- DocSum/tests/test_compose_on_xeon.sh | 8 +-- MultimodalQnA/multimodalqna.py | 2 +- 4 files changed, 82 insertions(+), 12 deletions(-) diff --git a/ChatQnA/chatqna_wrapper.py b/ChatQnA/chatqna_wrapper.py index 09062b5d27..20ea4e37ff 100644 --- a/ChatQnA/chatqna_wrapper.py +++ b/ChatQnA/chatqna_wrapper.py @@ -3,7 +3,18 @@ import os -from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType +from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) +from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms +from fastapi import Request +from fastapi.responses import StreamingResponse + MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) @@ -17,7 +28,7 @@ LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) -class ChatQnAService: +class ChatQnAService(Gateway): def __init__(self, host="0.0.0.0", port=8000): self.host = host self.port = port @@ -60,9 +71,68 @@ def add_remote_service(self): self.megaservice.flow_to(embedding, retriever) self.megaservice.flow_to(retriever, rerank) self.megaservice.flow_to(rerank, llm) - self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + chat_template=chat_request.chat_template if chat_request.chat_template else None, + ) + retriever_parameters = RetrieverParms( + search_type=chat_request.search_type if chat_request.search_type else "similarity", + k=chat_request.k if chat_request.k else 4, + distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, + fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, + lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, + ) + reranker_parameters = RerankerParms( + top_n=chat_request.top_n if chat_request.top_n else 1, + ) + result_dict, runtime_graph = await 
self.megaservice.schedule( + initial_inputs={"text": prompt}, + llm_parameters=parameters, + retriever_parameters=retriever_parameters, + reranker_parameters=reranker_parameters, + ) + for node, response in result_dict.items(): + if isinstance(response, StreamingResponse): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) + + def start(self): + + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.CHAT_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) if __name__ == "__main__": - chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + chatqna = ChatQnAService(port=MEGA_SERVICE_PORT) chatqna.add_remote_service() + chatqna.start() diff --git a/DocSum/tests/test_compose_on_gaudi.sh b/DocSum/tests/test_compose_on_gaudi.sh index 91499197dc..ae43499b8f 100644 --- a/DocSum/tests/test_compose_on_gaudi.sh +++ b/DocSum/tests/test_compose_on_gaudi.sh @@ -148,7 +148,7 @@ function validate_microservices() { # Video2Audio service validate_services \ "${host_ip}:7078/v1/video2audio" \ - "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd9L18KaAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ + "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd95t4qPAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ "dataprep-video2audio" \ "dataprep-video2audio-service" \ "{\"byte_str\": \"$(input_data_for_test "video")\"}" @@ -156,7 +156,7 @@ function validate_microservices() { # Docsum Data service - video validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well"' \ + '"query":"well' \ "dataprep-multimedia2text" \ "dataprep-multimedia2text" \ "{\"video\": \"$(input_data_for_test "video")\"}" @@ -164,7 +164,7 @@ function validate_microservices() { # Docsum Data service - audio validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well"' \ + '"query":"well' \ "dataprep-multimedia2text" \ "dataprep-multimedia2text" \ "{\"audio\": \"$(input_data_for_test "audio")\"}" diff --git a/DocSum/tests/test_compose_on_xeon.sh b/DocSum/tests/test_compose_on_xeon.sh index 555633cfca..0b3dc2319c 100644 --- a/DocSum/tests/test_compose_on_xeon.sh +++ b/DocSum/tests/test_compose_on_xeon.sh @@ -2,7 +2,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -set -xe +# set -xe IMAGE_REPO=${IMAGE_REPO:-"opea"} IMAGE_TAG=${IMAGE_TAG:-"latest"} @@ -150,7 +150,7 @@ function validate_microservices() { # Video2Audio service validate_services \ "${host_ip}:7078/v1/video2audio" \ - 
"SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd9L18KaAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ + "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd95t4qPAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ "dataprep-video2audio" \ "dataprep-video2audio-service" \ "{\"byte_str\": \"$(input_data_for_test "video")\"}" @@ -158,7 +158,7 @@ function validate_microservices() { # Docsum Data service - video validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well"' \ + '"query":"well' \ "dataprep-multimedia2text-service" \ "dataprep-multimedia2text" \ "{\"video\": \"$(input_data_for_test "video")\"}" @@ -166,7 +166,7 @@ function validate_microservices() { # Docsum Data service - audio validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well"' \ + '"query":"well' \ "dataprep-multimedia2text-service" \ "dataprep-multimedia2text" \ "{\"audio\": \"$(input_data_for_test "audio")\"}" diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py index e9cdc09d4c..bf9b375749 100644 --- a/MultimodalQnA/multimodalqna.py +++ b/MultimodalQnA/multimodalqna.py @@ -161,7 +161,7 @@ async def handle_request(self, request: Request): data = await request.json() stream_opt = bool(data.get("stream", False)) if stream_opt: - print("[ MultimodalQnAGateway ] stream=True not used, this has not support streaming yet!") + print("[ MultimodalQnAService ] stream=True not used, this has not support streaming yet!") stream_opt = False chat_request = ChatCompletionRequest.model_validate(data) # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. From 52e0d5a5e99df68f0ee066234e6e1c37c273395d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 02:23:49 +0000 Subject: [PATCH 23/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ChatQnA/chatqna_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ChatQnA/chatqna_wrapper.py b/ChatQnA/chatqna_wrapper.py index 20ea4e37ff..daa9762039 100644 --- a/ChatQnA/chatqna_wrapper.py +++ b/ChatQnA/chatqna_wrapper.py @@ -15,7 +15,6 @@ from fastapi import Request from fastapi.responses import StreamingResponse - MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") @@ -132,6 +131,7 @@ def start(self): output_datatype=ChatCompletionResponse, ) + if __name__ == "__main__": chatqna = ChatQnAService(port=MEGA_SERVICE_PORT) chatqna.add_remote_service() From 3643e296adb974441e6b6450b6302ff940134c0f Mon Sep 17 00:00:00 2001 From: lkk Date: Fri, 6 Dec 2024 06:38:32 +0000 Subject: [PATCH 24/24] revert docsum ut. 
--- DocSum/tests/test_compose_on_gaudi.sh | 6 +++--- DocSum/tests/test_compose_on_xeon.sh | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/DocSum/tests/test_compose_on_gaudi.sh b/DocSum/tests/test_compose_on_gaudi.sh index ae43499b8f..91499197dc 100644 --- a/DocSum/tests/test_compose_on_gaudi.sh +++ b/DocSum/tests/test_compose_on_gaudi.sh @@ -148,7 +148,7 @@ function validate_microservices() { # Video2Audio service validate_services \ "${host_ip}:7078/v1/video2audio" \ - "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd95t4qPAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ + "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd9L18KaAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ "dataprep-video2audio" \ "dataprep-video2audio-service" \ "{\"byte_str\": \"$(input_data_for_test "video")\"}" @@ -156,7 +156,7 @@ function validate_microservices() { # Docsum Data service - video validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well' \ + '"query":"well"' \ "dataprep-multimedia2text" \ "dataprep-multimedia2text" \ "{\"video\": \"$(input_data_for_test "video")\"}" @@ -164,7 +164,7 @@ function validate_microservices() { # Docsum Data service - audio validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well' \ + '"query":"well"' \ "dataprep-multimedia2text" \ "dataprep-multimedia2text" \ "{\"audio\": \"$(input_data_for_test "audio")\"}" diff --git a/DocSum/tests/test_compose_on_xeon.sh b/DocSum/tests/test_compose_on_xeon.sh index 0b3dc2319c..555633cfca 100644 --- a/DocSum/tests/test_compose_on_xeon.sh +++ b/DocSum/tests/test_compose_on_xeon.sh @@ -2,7 +2,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# set -xe +set -xe IMAGE_REPO=${IMAGE_REPO:-"opea"} IMAGE_TAG=${IMAGE_TAG:-"latest"} @@ -150,7 +150,7 @@ function validate_microservices() { # Video2Audio service validate_services \ "${host_ip}:7078/v1/video2audio" \ - "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd95t4qPAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ + "SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4LjI5LjEwMAAAAAAAAAAAAAAA//tQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASW5mbwAAAA8AAAAIAAAN3wAtLS0tLS0tLS0tLS1LS0tLS0tLS0tLS0tpaWlpaWlpaWlpaWlph4eHh4eHh4eHh4eHpaWlpaWlpaWlpaWlpcPDw8PDw8PDw8PDw+Hh4eHh4eHh4eHh4eH///////////////8AAAAATGF2YzU4LjU0AAAAAAAAAAAAAAAAJAYwAAAAAAAADd9L18KaAAAAAAAAAAAAAAAAAAAAAP/7kGQAAAMhClSVMEACMOAabaCMAREA" \ "dataprep-video2audio" \ "dataprep-video2audio-service" \ "{\"byte_str\": \"$(input_data_for_test "video")\"}" @@ -158,7 +158,7 @@ function validate_microservices() { # Docsum Data service - video validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well' \ + '"query":"well"' \ "dataprep-multimedia2text-service" 
\ "dataprep-multimedia2text" \ "{\"video\": \"$(input_data_for_test "video")\"}" @@ -166,7 +166,7 @@ function validate_microservices() { # Docsum Data service - audio validate_services \ "${host_ip}:7079/v1/multimedia2text" \ - '"query":"well' \ + '"query":"well"' \ "dataprep-multimedia2text-service" \ "dataprep-multimedia2text" \ "{\"audio\": \"$(input_data_for_test "audio")\"}"