diff --git a/AudioQnA/audioqna.py b/AudioQnA/audioqna.py
index 7987620847..fbede59489 100644
--- a/AudioQnA/audioqna.py
+++ b/AudioQnA/audioqna.py
@@ -4,9 +4,11 @@
 import asyncio
 import os
 
-from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
 ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099))
@@ -16,7 +18,7 @@
 TTS_SERVICE_PORT = int(os.getenv("TTS_SERVICE_PORT", 9088))
 
 
-class AudioQnAService:
+class AudioQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -50,9 +52,43 @@ def add_remote_service(self):
         self.megaservice.add(asr).add(llm).add(tts)
         self.megaservice.flow_to(asr, llm)
         self.megaservice.flow_to(llm, tts)
-        self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+
+        chat_request = AudioChatCompletionRequest.parse_obj(data)
+        parameters = LLMParams(
+            # relatively lower max_tokens for audio conversation
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=False,  # TODO add streaming LLM output as input to TTS
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
+        )
+
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["byte_str"]
+
+        return response
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
+            input_datatype=AudioChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    audioqna = AudioQnAService(port=MEGA_SERVICE_PORT)
     audioqna.add_remote_service()
+    audioqna.start()
diff --git a/AudioQnA/audioqna_multilang.py b/AudioQnA/audioqna_multilang.py
index 8a4ffdd01a..33a1e1d61a 100644
--- a/AudioQnA/audioqna_multilang.py
+++ b/AudioQnA/audioqna_multilang.py
@@ -5,9 +5,11 @@
 import base64
 import os
 
-from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 
 WHISPER_SERVER_HOST_IP = os.getenv("WHISPER_SERVER_HOST_IP", "0.0.0.0")
@@ -52,7 +54,7 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_di
     return data
 
 
-class AudioQnAService:
+class AudioQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -90,9 +92,43 @@ def add_remote_service(self):
         self.megaservice.add(asr).add(llm).add(tts)
         self.megaservice.flow_to(asr, llm)
         self.megaservice.flow_to(llm, tts)
-        self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+
+        chat_request = AudioChatCompletionRequest.parse_obj(data)
+        parameters = LLMParams(
+            # relatively lower max_tokens for audio conversation
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=False,  # TODO add streaming LLM output as input to TTS
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
+        )
+
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["byte_str"]
+
+        return response
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
+            input_datatype=AudioChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    audioqna = AudioQnAService(port=MEGA_SERVICE_PORT)
     audioqna.add_remote_service()
+    audioqna.start()
diff --git a/AvatarChatbot/avatarchatbot.py b/AvatarChatbot/avatarchatbot.py
index b72ae49c5f..0893fc6f93 100644
--- a/AvatarChatbot/avatarchatbot.py
+++ b/AvatarChatbot/avatarchatbot.py
@@ -5,9 +5,11 @@
 import os
 import sys
 
-from comps import AvatarChatbotGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
 ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099))
@@ -27,7 +29,7 @@ def check_env_vars(env_var_list):
     print("All environment variables are set.")
 
 
-class AvatarChatbotService:
+class AvatarChatbotService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -70,7 +72,39 @@ def add_remote_service(self):
         self.megaservice.flow_to(asr, llm)
         self.megaservice.flow_to(llm, tts)
         self.megaservice.flow_to(tts, animation)
-        self.gateway = AvatarChatbotGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+
+        chat_request = AudioChatCompletionRequest.model_validate(data)
+        parameters = LLMParams(
+            # relatively lower max_tokens for audio conversation
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=False,  # TODO add streaming LLM output as input to TTS
+        )
+        # print(parameters)
+
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
+        )
+
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["video_path"]
+        return response
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.AVATAR_CHATBOT),
+            input_datatype=AudioChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
@@ -89,5 +123,6 @@ def add_remote_service(self):
         ]
     )
 
-    avatarchatbot = AvatarChatbotService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    avatarchatbot = AvatarChatbotService(port=MEGA_SERVICE_PORT)
     avatarchatbot.add_remote_service()
+    avatarchatbot.start()
diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py
index 95318e9613..695fab0a4a 100644
--- a/ChatQnA/chatqna.py
+++ b/ChatQnA/chatqna.py
@@ -6,7 +6,17 @@
 import os
 import re
 
-from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 from langchain_core.prompts import PromptTemplate
 
 
@@ -35,7 +45,6 @@ def generate_rag_prompt(question, documents):
     return template.format(context=context_str, question=question)
 
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 GUARDRAIL_SERVICE_HOST_IP = os.getenv("GUARDRAIL_SERVICE_HOST_IP", "0.0.0.0")
 GUARDRAIL_SERVICE_PORT = int(os.getenv("GUARDRAIL_SERVICE_PORT", 80))
@@ -178,13 +187,14 @@ def align_generator(self, gen, **kwargs):
     yield "data: [DONE]\n\n"
 
 
-class ChatQnAService:
+class ChatQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         ServiceOrchestrator.align_inputs = align_inputs
         ServiceOrchestrator.align_outputs = align_outputs
         ServiceOrchestrator.align_generator = align_generator
+
         self.megaservice = ServiceOrchestrator()
 
     def add_remote_service(self):
@@ -228,7 +238,6 @@ def add_remote_service(self):
         self.megaservice.flow_to(embedding, retriever)
         self.megaservice.flow_to(retriever, rerank)
         self.megaservice.flow_to(rerank, llm)
-        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
 
     def add_remote_service_without_rerank(self):
@@ -261,7 +270,6 @@ def add_remote_service_without_rerank(self):
         self.megaservice.add(embedding).add(retriever).add(llm)
         self.megaservice.flow_to(embedding, retriever)
         self.megaservice.flow_to(retriever, llm)
-        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
 
     def add_remote_service_with_guardrails(self):
         guardrail_in = MicroService(
@@ -319,7 +327,66 @@ def add_remote_service_with_guardrails(self):
         self.megaservice.flow_to(retriever, rerank)
         self.megaservice.flow_to(rerank, llm)
         # self.megaservice.flow_to(llm, guardrail_out)
-        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", True)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+        )
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
+        )
+        for node, response in result_dict.items():
+            if isinstance(response, StreamingResponse):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
+
+    def start(self):
+
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.CHAT_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
@@ -329,10 +396,12 @@ def add_remote_service_with_guardrails(self):
     args = parser.parse_args()
 
-    chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    chatqna = ChatQnAService(port=MEGA_SERVICE_PORT)
     if args.without_rerank:
         chatqna.add_remote_service_without_rerank()
     elif args.with_guardrails:
         chatqna.add_remote_service_with_guardrails()
     else:
         chatqna.add_remote_service()
+
+    chatqna.start()
diff --git a/ChatQnA/chatqna_wrapper.py b/ChatQnA/chatqna_wrapper.py
index 09062b5d27..daa9762039 100644
--- a/ChatQnA/chatqna_wrapper.py
+++ b/ChatQnA/chatqna_wrapper.py
@@ -3,7 +3,17 @@
 import os
 
-from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
 MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
@@ -17,7 +27,7 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class ChatQnAService:
+class ChatQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -60,9 +70,69 @@ def add_remote_service(self):
         self.megaservice.flow_to(embedding, retriever)
         self.megaservice.flow_to(retriever, rerank)
         self.megaservice.flow_to(rerank, llm)
-        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", True)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+        )
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
+        )
+        for node, response in result_dict.items():
+            if isinstance(response, StreamingResponse):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
+
+    def start(self):
+
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.CHAT_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    chatqna = ChatQnAService(port=MEGA_SERVICE_PORT)
     chatqna.add_remote_service()
+    chatqna.start()
diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py
index ac04cc7876..5ae4329d08 100644
--- a/CodeGen/codegen.py
+++ b/CodeGen/codegen.py
@@ -4,15 +4,24 @@
 import asyncio
 import os
 
-from comps import CodeGenGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7778))
 LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class CodeGenService:
+class CodeGenService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -28,9 +37,58 @@ def add_remote_service(self):
             service_type=ServiceType.LLM,
         )
         self.megaservice.add(llm)
-        self.gateway = CodeGenGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", True)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"query": prompt}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LLM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LLM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="codegen", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.CODE_GEN),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    chatqna = CodeGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    chatqna = CodeGenService(port=MEGA_SERVICE_PORT)
     chatqna.add_remote_service()
+    chatqna.start()
diff --git a/CodeTrans/code_translation.py b/CodeTrans/code_translation.py
index 1957485112..c163c847fa 100644
--- a/CodeTrans/code_translation.py
+++ b/CodeTrans/code_translation.py
@@ -4,15 +4,23 @@
 import asyncio
 import os
 
-from comps import CodeTransGateway, MicroService, ServiceOrchestrator
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 7777))
 LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class CodeTransService:
+class CodeTransService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -27,9 +35,59 @@ def add_remote_service(self):
             use_remote_service=True,
         )
         self.megaservice.add(llm)
-        self.gateway = CodeTransGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        language_from = data["language_from"]
+        language_to = data["language_to"]
+        source_code = data["source_code"]
+        prompt_template = """
+    ### System: Please translate the following {language_from} codes into {language_to} codes.
+
+    ### Original codes:
+    '''{language_from}
+
+    {source_code}
+
+    '''
+
+    ### Translated codes:
+    """
+        prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code)
+        result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LLM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LLM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.CODE_TRANS),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    service_ochestrator = CodeTransService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    service_ochestrator = CodeTransService(port=MEGA_SERVICE_PORT)
     service_ochestrator.add_remote_service()
+    service_ochestrator.start()
diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py
index d3adc8d352..b902b7a20e 100644
--- a/DocIndexRetriever/retrieval_tool.py
+++ b/DocIndexRetriever/retrieval_tool.py
@@ -3,10 +3,14 @@
 import asyncio
 import os
+from typing import Union
 
-from comps import MicroService, RetrievalToolGateway, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
+from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889)
 EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
 EMBEDDING_SERVICE_PORT = os.getenv("EMBEDDING_SERVICE_PORT", 6000)
@@ -16,7 +20,7 @@
 RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
 
 
-class RetrievalToolService:
+class RetrievalToolService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -51,9 +55,77 @@ def add_remote_service(self):
         self.megaservice.add(embedding).add(retriever).add(rerank)
         self.megaservice.flow_to(embedding, retriever)
         self.megaservice.flow_to(retriever, rerank)
-        self.gateway = RetrievalToolGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        def parser_input(data, TypeClass, key):
+            chat_request = None
+            try:
+                chat_request = TypeClass.parse_obj(data)
+                query = getattr(chat_request, key)
+            except:
+                query = None
+            return query, chat_request
+
+        data = await request.json()
+        query = None
+        for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
+            query, chat_request = parser_input(data, TypeClass, key)
+            if query is not None:
+                break
+        if query is None:
+            raise ValueError(f"Unknown request type: {data}")
+        if chat_request is None:
+            raise ValueError(f"Unknown request type: {data}")
+
+        if isinstance(chat_request, ChatCompletionRequest):
+            retriever_parameters = RetrieverParms(
+                search_type=chat_request.search_type if chat_request.search_type else "similarity",
+                k=chat_request.k if chat_request.k else 4,
+                distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+                fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+                lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+                score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+            )
+            reranker_parameters = RerankerParms(
+                top_n=chat_request.top_n if chat_request.top_n else 1,
+            )
+
+            initial_inputs = {
+                "messages": query,
+                "input": query,  # has to be "input" because the embedding service expects either "input" or "text"
+                "search_type": chat_request.search_type if chat_request.search_type else "similarity",
+                "k": chat_request.k if chat_request.k else 4,
+                "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
+                "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
+                "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+                "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
+                "top_n": chat_request.top_n if chat_request.top_n else 1,
+            }
+
+            result_dict, runtime_graph = await self.megaservice.schedule(
+                initial_inputs=initial_inputs,
+                retriever_parameters=retriever_parameters,
+                reranker_parameters=reranker_parameters,
+            )
+        else:
+            result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
+
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]
+        return response
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL),
+            input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
+            output_datatype=Union[RerankedDoc, LLMParamsDoc],
+        )
 
 
 if __name__ == "__main__":
-    chatqna = RetrievalToolService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    chatqna = RetrievalToolService(port=MEGA_SERVICE_PORT)
     chatqna.add_remote_service()
+    chatqna.start()
diff --git a/DocSum/docsum.py b/DocSum/docsum.py
index f6094191a0..39a13b01af 100644
--- a/DocSum/docsum.py
+++ b/DocSum/docsum.py
@@ -3,10 +3,21 @@
 import asyncio
 import os
+from typing import List
 
-from comps import DocSumGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.mega.gateway import read_text_from_file
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import File, Request, UploadFile
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 
 DATA_SERVICE_HOST_IP = os.getenv("DATA_SERVICE_HOST_IP", "0.0.0.0")
@@ -16,7 +27,7 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class DocSumService:
+class DocSumService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -41,12 +52,114 @@ def add_remote_service(self):
             use_remote_service=True,
             service_type=ServiceType.LLM,
         )
+        self.megaservice.add(llm)
 
-        self.megaservice.add(data).add(llm)
-        self.megaservice.flow_to(data, llm)
-        self.gateway = DocSumGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+    async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
+
+        if "application/json" in request.headers.get("content-type"):
+            data = await request.json()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+            prompt = self._handle_message(chat_request.messages)
+
+            initial_inputs_data = {data["type"]: prompt}
+
+        elif "multipart/form-data" in request.headers.get("content-type"):
+            data = await request.form()
+            stream_opt = data.get("stream", True)
+            chat_request = ChatCompletionRequest.model_validate(data)
+
+            data_type = data.get("type")
+
+            file_summaries = []
+            if files:
+                for file in files:
+                    file_path = f"/tmp/{file.filename}"
+
+                    if data_type is not None and data_type in ["audio", "video"]:
+                        raise ValueError(
+                            "Audio and video file uploads are not supported in DocSum with a curl request; please use the UI."
+                        )
+
+                    else:
+                        import aiofiles
+
+                        async with aiofiles.open(file_path, "wb") as f:
+                            await f.write(await file.read())
+
+                        docs = read_text_from_file(file, file_path)
+                        os.remove(file_path)
+
+                        if isinstance(docs, list):
+                            file_summaries.extend(docs)
+                        else:
+                            file_summaries.append(docs)
+
+            if file_summaries:
+                prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
+            else:
+                prompt = self._handle_message(chat_request.messages)
+
+            data_type = data.get("type")
+            if data_type is not None:
+                initial_inputs_data = {}
+                initial_inputs_data[data_type] = prompt
+            else:
+                initial_inputs_data = {"query": prompt}
+
+        else:
+            raise ValueError(f"Unknown request type: {request.headers.get('content-type')}")
+
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+            model=chat_request.model if chat_request.model else None,
+            language=chat_request.language if chat_request.language else "auto",
+        )
+
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs=initial_inputs_data, llm_parameters=parameters
+        )
+
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LLM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LLM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="docsum", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.DOC_SUMMARY),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    docsum = DocSumService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    docsum = DocSumService(port=MEGA_SERVICE_PORT)
     docsum.add_remote_service()
+    docsum.start()
diff --git a/EdgeCraftRAG/chatqna.py b/EdgeCraftRAG/chatqna.py
index 02f0a84dd8..31c701b5fa 100644
--- a/EdgeCraftRAG/chatqna.py
+++ b/EdgeCraftRAG/chatqna.py
@@ -5,7 +5,6 @@
 from comps import MicroService, ServiceOrchestrator, ServiceType
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "127.0.0.1")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 16011))
 PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1")
 PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))
@@ -23,11 +22,22 @@
 from fastapi.responses import StreamingResponse
 
 
-class EdgeCraftRagGateway(Gateway):
-    def __init__(self, megaservice, host="0.0.0.0", port=16011):
-        super().__init__(
-            megaservice, host, port, str(MegaServiceEndpoint.CHAT_QNA), ChatCompletionRequest, ChatCompletionResponse
+class EdgeCraftRagService(Gateway):
+    def __init__(self, host="0.0.0.0", port=16010):
+        self.host = host
+        self.port = port
+        self.megaservice = ServiceOrchestrator()
+
+    def add_remote_service(self):
+        edgecraftrag = MicroService(
+            name="pipeline",
+            host=PIPELINE_SERVICE_HOST_IP,
+            port=PIPELINE_SERVICE_PORT,
+            endpoint="/v1/chatqna",
+            use_remote_service=True,
+            service_type=ServiceType.LLM,
         )
+        self.megaservice.add(edgecraftrag)
 
     async def handle_request(self, request: Request):
         input = await request.json()
@@ -61,26 +71,18 @@ async def handle_request(self, request: Request):
         )
         return ChatCompletionResponse(model="edgecraftrag", choices=choices, usage=usage)
 
-
-class EdgeCraftRagService:
-    def __init__(self, host="0.0.0.0", port=16010):
-        self.host = host
-        self.port = port
-        self.megaservice = ServiceOrchestrator()
-
-    def add_remote_service(self):
-        edgecraftrag = MicroService(
-            name="pipeline",
-            host=PIPELINE_SERVICE_HOST_IP,
-            port=PIPELINE_SERVICE_PORT,
-            endpoint="/v1/chatqna",
-            use_remote_service=True,
-            service_type=ServiceType.LLM,
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.CHAT_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
         )
-        self.megaservice.add(edgecraftrag)
-        self.gateway = EdgeCraftRagGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
 
 
 if __name__ == "__main__":
-    edgecraftrag = EdgeCraftRagService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    edgecraftrag = EdgeCraftRagService(port=MEGA_SERVICE_PORT)
     edgecraftrag.add_remote_service()
+    edgecraftrag.start()
diff --git a/FaqGen/faqgen.py b/FaqGen/faqgen.py
index bfcf148711..9ebe3d673f 100644
--- a/FaqGen/faqgen.py
+++ b/FaqGen/faqgen.py
@@ -3,16 +3,27 @@
 import asyncio
 import os
+from typing import List
 
-from comps import FaqGenGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.mega.gateway import read_text_from_file
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import File, Request, UploadFile
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class FaqGenService:
+class FaqGenService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -28,9 +39,79 @@ def add_remote_service(self):
             service_type=ServiceType.LLM,
         )
         self.megaservice.add(llm)
-        self.gateway = FaqGenGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
+        data = await request.form()
+        stream_opt = data.get("stream", True)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        file_summaries = []
+        if files:
+            for file in files:
+                file_path = f"/tmp/{file.filename}"
+
+                import aiofiles
+
+                async with aiofiles.open(file_path, "wb") as f:
+                    await f.write(await file.read())
+                docs = read_text_from_file(file, file_path)
+                os.remove(file_path)
+                if isinstance(docs, list):
+                    file_summaries.extend(docs)
+                else:
+                    file_summaries.append(docs)
+
+        if file_summaries:
+            prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
+        else:
+            prompt = self._handle_message(chat_request.messages)
+
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+            model=chat_request.model if chat_request.model else None,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"query": prompt}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LLM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LLM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="faqgen", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.FAQ_GEN),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    faqgen = FaqGenService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    faqgen = FaqGenService(port=MEGA_SERVICE_PORT)
     faqgen.add_remote_service()
+    faqgen.start()
diff --git a/GraphRAG/graphrag.py b/GraphRAG/graphrag.py
index a6fd8b0939..f675095147 100644
--- a/GraphRAG/graphrag.py
+++ b/GraphRAG/graphrag.py
@@ -6,7 +6,18 @@
 import os
 import re
 
-from comps import GraphragGateway, MicroService, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    EmbeddingRequest,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams, RetrieverParms, TextDoc
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 from langchain_core.prompts import PromptTemplate
 
 
@@ -35,7 +46,6 @@ def generate_rag_prompt(question, documents):
     return template.format(context=context_str, question=question)
 
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0")
 RETRIEVER_SERVICE_PORT = int(os.getenv("RETRIEVER_SERVICE_PORT", 7000))
@@ -117,7 +127,7 @@ def align_generator(self, gen, **kwargs):
     yield "data: [DONE]\n\n"
 
 
-class GraphRAGService:
+class GraphRAGService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -146,9 +156,84 @@ def add_remote_service(self):
         )
         self.megaservice.add(retriever).add(llm)
         self.megaservice.flow_to(retriever, llm)
-        self.gateway = GraphragGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", True)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+
+        def parser_input(data, TypeClass, key):
+            chat_request = None
+            try:
+                chat_request = TypeClass.parse_obj(data)
+                query = getattr(chat_request, key)
+            except:
+                query = None
+            return query, chat_request
+
+        query = None
+        for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
+            query, chat_request = parser_input(data, TypeClass, key)
+            if query is not None:
+                break
+        if query is None:
+            raise ValueError(f"Unknown request type: {data}")
+        if chat_request is None:
+            raise ValueError(f"Unknown request type: {data}")
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+        )
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        initial_inputs = chat_request
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs=initial_inputs,
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+        )
+        for node, response in result_dict.items():
+            if isinstance(response, StreamingResponse):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response_content = result_dict[last_node]["choices"][0]["message"]["content"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response_content),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.GRAPH_RAG),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    graphrag = GraphRAGService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    graphrag = GraphRAGService(port=MEGA_SERVICE_PORT)
     graphrag.add_remote_service()
+    graphrag.start()
diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py
index ea1c104dc7..bf9b375749 100644
--- a/MultimodalQnA/multimodalqna.py
+++ b/MultimodalQnA/multimodalqna.py
@@ -1,11 +1,24 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import base64
 import os
+from io import BytesIO
 
-from comps import MicroService, MultimodalQnAGateway, ServiceOrchestrator, ServiceType
+import requests
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
+from fastapi.responses import StreamingResponse
+from PIL import Image
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 MM_EMBEDDING_SERVICE_HOST_IP = os.getenv("MM_EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
 MM_EMBEDDING_PORT_MICROSERVICE = int(os.getenv("MM_EMBEDDING_PORT_MICROSERVICE", 6000))
@@ -15,12 +28,12 @@
 LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9399))
 
 
-class MultimodalQnAService:
+class MultimodalQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
-        self.mmrag_megaservice = ServiceOrchestrator()
         self.lvm_megaservice = ServiceOrchestrator()
+        self.megaservice = ServiceOrchestrator()
 
     def add_remote_service(self):
         mm_embedding = MicroService(
@@ -50,21 +63,186 @@ def add_remote_service(self):
         )
 
         # for mmrag megaservice
-        self.mmrag_megaservice.add(mm_embedding).add(mm_retriever).add(lvm)
-        self.mmrag_megaservice.flow_to(mm_embedding, mm_retriever)
-        self.mmrag_megaservice.flow_to(mm_retriever, lvm)
+        self.megaservice.add(mm_embedding).add(mm_retriever).add(lvm)
+        self.megaservice.flow_to(mm_embedding, mm_retriever)
+        self.megaservice.flow_to(mm_retriever, lvm)
 
         # for lvm megaservice
         self.lvm_megaservice.add(lvm)
 
-        self.gateway = MultimodalQnAGateway(
-            multimodal_rag_megaservice=self.mmrag_megaservice,
-            lvm_megaservice=self.lvm_megaservice,
-            host="0.0.0.0",
+    # this overrides the _handle_message method of Gateway
+    def _handle_message(self, messages):
+        images = []
+        messages_dicts = []
+        if isinstance(messages, str):
+            prompt = messages
+        else:
+            messages_dict = {}
+            system_prompt = ""
+            prompt = ""
+            for message in messages:
+                msg_role = message["role"]
+                messages_dict = {}
+                if msg_role == "system":
+                    system_prompt = message["content"]
+                elif msg_role == "user":
+                    if type(message["content"]) == list:
+                        text = ""
+                        text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
+                        text += "\n".join(text_list)
+                        image_list = [
+                            item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
+                        ]
+                        if image_list:
+                            messages_dict[msg_role] = (text, image_list)
+                        else:
+                            messages_dict[msg_role] = text
+                    else:
+                        messages_dict[msg_role] = message["content"]
+                    messages_dicts.append(messages_dict)
+                elif msg_role == "assistant":
+                    messages_dict[msg_role] = message["content"]
+                    messages_dicts.append(messages_dict)
+                else:
+                    raise ValueError(f"Unknown role: {msg_role}")
+
+            if system_prompt:
+                prompt = system_prompt + "\n"
+            for messages_dict in messages_dicts:
+                for i, (role, message) in enumerate(messages_dict.items()):
+                    if isinstance(message, tuple):
+                        text, image_list = message
+                        if i == 0:
+                            # do not add role for the very first message.
+                            # this will be added by llava_server
+                            if text:
+                                prompt += text + "\n"
+                        else:
+                            if text:
+                                prompt += role.upper() + ": " + text + "\n"
+                            else:
+                                prompt += role.upper() + ":"
+                        for img in image_list:
+                            # URL
+                            if img.startswith("http://") or img.startswith("https://"):
+                                response = requests.get(img)
+                                image = Image.open(BytesIO(response.content)).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Local Path
+                            elif os.path.exists(img):
+                                image = Image.open(img).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Bytes
+                            else:
+                                img_b64_str = img
+
+                            images.append(img_b64_str)
+                    else:
+                        if i == 0:
+                            # do not add role for the very first message.
+                            # this will be added by llava_server
+                            if message:
+                                prompt += role.upper() + ": " + message + "\n"
+                        else:
+                            if message:
+                                prompt += role.upper() + ": " + message + "\n"
+                            else:
+                                prompt += role.upper() + ":"
+        if images:
+            return prompt, images
+        else:
+            return prompt
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = bool(data.get("stream", False))
+        if stream_opt:
+            print("[ MultimodalQnAService ] stream=True is not used; streaming is not supported yet!")
+            stream_opt = False
+        chat_request = ChatCompletionRequest.model_validate(data)
+        # Multimodal RAG QnA with videos does not yet accept an image as input during QnA.
+        prompt_and_image = self._handle_message(chat_request.messages)
+        if isinstance(prompt_and_image, tuple):
+            # print(f"This request includes an image, so it is a follow-up query. Using lvm megaservice")
+            prompt, images = prompt_and_image
+            cur_megaservice = self.lvm_megaservice
+            initial_inputs = {"prompt": prompt, "image": images[0]}
+        else:
+            # print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
+            prompt = prompt_and_image
+            cur_megaservice = self.megaservice
+            initial_inputs = {"text": prompt}
+
+        parameters = LLMParams(
+            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+        )
+        result_dict, runtime_graph = await cur_megaservice.schedule(
+            initial_inputs=initial_inputs, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # the last microservice in this megaservice is the LVM.
+            # check whether the LVM returns a StreamingResponse.
+            # Currently, LVM with LLaVA does not yet support streaming.
+            # @TODO: Will need to test this once LVM with LLaVA supports streaming
+            if (
+                isinstance(response, StreamingResponse)
+                and node == runtime_graph.all_leaves()[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LVM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+
+        if "text" in result_dict[last_node].keys():
+            response = result_dict[last_node]["text"]
+        else:
+            # "text" is not in the response message;
+            # something went wrong, for example due to empty retrieval results
+            if "detail" in result_dict[last_node].keys():
+                response = result_dict[last_node]["detail"]
+            else:
+                response = "The server failed to generate an answer to your query!"
+        if "metadata" in result_dict[last_node].keys():
+            # from retrieval results
+            metadata = result_dict[last_node]["metadata"]
+        else:
+            # follow-up question, no retrieval
+            metadata = None
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+                metadata=metadata,
+            )
+        )
+        return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
             port=self.port,
+            endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
         )
 
 
 if __name__ == "__main__":
-    mmragwithvideos = MultimodalQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    mmragwithvideos = MultimodalQnAService(port=MEGA_SERVICE_PORT)
     mmragwithvideos.add_remote_service()
+    mmragwithvideos.start()
diff --git a/SearchQnA/searchqna.py b/SearchQnA/searchqna.py
index fa620d42a6..1a04a97a60 100644
--- a/SearchQnA/searchqna.py
+++ b/SearchQnA/searchqna.py
@@ -3,9 +3,18 @@
 import os
 
-from comps import MicroService, SearchQnAGateway, ServiceOrchestrator, ServiceType
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
 EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000))
@@ -17,7 +26,7 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class SearchQnAService:
+class SearchQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -60,9 +69,58 @@ def add_remote_service(self):
         self.megaservice.flow_to(embedding, web_retriever)
         self.megaservice.flow_to(web_retriever, rerank)
         self.megaservice.flow_to(rerank, llm)
-        self.gateway = SearchQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", True)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LLM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LLM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.SEARCH_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )
 
 
 if __name__ == "__main__":
-    searchqna = SearchQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    searchqna = SearchQnAService(port=MEGA_SERVICE_PORT)
     searchqna.add_remote_service()
+    searchqna.start()
diff --git a/Translation/translation.py b/Translation/translation.py
index 86093a166e..92e8b6f0dd 100644
--- a/Translation/translation.py
+++ b/Translation/translation.py
@@ -15,15 +15,23 @@
 import asyncio
 import os
 
-from comps import MicroService, ServiceOrchestrator, ServiceType, TranslationGateway
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from fastapi import Request
+from fastapi.responses import StreamingResponse
 
-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0")
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class TranslationService:
+class TranslationService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -39,9 +47,57 @@ def add_remote_service(self):
             service_type=ServiceType.LLM,
         )
         self.megaservice.add(llm)
-        self.gateway = TranslationGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        language_from = data["language_from"]
+        language_to = data["language_to"]
+        source_language = data["source_language"]
+        prompt_template = """
+    Translate this from {language_from} to {language_to}:
+
+    {language_from}:
+    {source_language}
+
+    {language_to}:
+    """
+        prompt = prompt_template.format(
+            language_from=language_from, language_to=language_to, source_language=source_language
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LLM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LLM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="translation", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.TRANSLATION),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )


 if __name__ == "__main__":
-    translation = TranslationService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    translation = TranslationService(port=MEGA_SERVICE_PORT)
     translation.add_remote_service()
+    translation.start()
diff --git a/VideoQnA/videoqna.py b/VideoQnA/videoqna.py
index fc8fe778e5..7632a1c10b 100644
--- a/VideoQnA/videoqna.py
+++ b/VideoQnA/videoqna.py
@@ -3,9 +3,18 @@

 import os

-from comps import MicroService, ServiceOrchestrator, ServiceType, VideoQnAGateway
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
+from fastapi.responses import StreamingResponse

-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
 EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000))
@@ -17,7 +26,7 @@
 LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9000))


-class VideoQnAService:
+class VideoQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8888):
         self.host = host
         self.port = port
@@ -60,9 +69,58 @@ def add_remote_service(self):
         self.megaservice.flow_to(embedding, retriever)
         self.megaservice.flow_to(retriever, rerank)
         self.megaservice.flow_to(rerank, lvm)
-        self.gateway = VideoQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", False)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # Here we assume that the last microservice in the megaservice is the LVM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LVM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="videoqna", choices=choices, usage=usage)
+
+    def start(self):
+        super().__init__(
+            megaservice=self.megaservice,
+            host=self.host,
+            port=self.port,
+            endpoint=str(MegaServiceEndpoint.VIDEO_RAG_QNA),
+            input_datatype=ChatCompletionRequest,
+            output_datatype=ChatCompletionResponse,
+        )


 if __name__ == "__main__":
-    videoqna = VideoQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    videoqna = VideoQnAService(port=MEGA_SERVICE_PORT)
     videoqna.add_remote_service()
+    videoqna.start()
diff --git a/VisualQnA/visualqna.py b/VisualQnA/visualqna.py
index 4ad850b006..f6519c1d27 100644
--- a/VisualQnA/visualqna.py
+++ b/VisualQnA/visualqna.py
@@ -3,15 +3,24 @@

 import os

-from comps import MicroService, ServiceOrchestrator, ServiceType, VisualQnAGateway
+from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps.cores.proto.api_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
+from comps.cores.proto.docarray import LLMParams
+from fastapi import Request
+from fastapi.responses import StreamingResponse

-MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
 MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
 LVM_SERVICE_HOST_IP = os.getenv("LVM_SERVICE_HOST_IP", "0.0.0.0")
 LVM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9399))


-class VisualQnAService:
+class VisualQnAService(Gateway):
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -27,9 +36,58 @@ def add_remote_service(self):
             service_type=ServiceType.LVM,
         )
         self.megaservice.add(llm)
-        self.gateway = VisualQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", False)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt, images = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            streaming=stream_opt,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"prompt": prompt, "image": images[0]}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # Here we assume that the last microservice in the megaservice is the LVM.
+ if ( + isinstance(response, StreamingResponse) + and node == list(self.megaservice.services.keys())[-1] + and self.megaservice.services[node].service_type == ServiceType.LVM + ): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage) + + def start(self): + super().__init__( + megaservice=self.megaservice, + host=self.host, + port=self.port, + endpoint=str(MegaServiceEndpoint.VISUAL_QNA), + input_datatype=ChatCompletionRequest, + output_datatype=ChatCompletionResponse, + ) if __name__ == "__main__": - visualqna = VisualQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT) + visualqna = VisualQnAService(port=MEGA_SERVICE_PORT) visualqna.add_remote_service() + visualqna.start()
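
Editorial note, not part of the patch: every start() method above hands its megaservice, host, port, endpoint, and request/response datatypes to Gateway.__init__ from the comps package, and every handle_request() override is what actually serves that endpoint. The sketch below is a rough approximation of that contract, written only to make the diff easier to follow; it assumes the base class simply binds handle_request() to the configured endpoint and serves it, and it omits helpers the real comps Gateway provides (for example _handle_message(), which several subclasses call).

# gateway_sketch.py, illustrative only: NOT the comps.Gateway implementation.
import uvicorn
from fastapi import FastAPI, Request


class GatewaySketch:
    """Assumed contract: expose the subclass's handle_request() on `endpoint` at host:port."""

    def __init__(self, megaservice, host, port, endpoint, input_datatype, output_datatype):
        self.megaservice = megaservice
        self.host = host
        self.port = port
        self.endpoint = endpoint
        self.input_datatype = input_datatype
        self.output_datatype = output_datatype

        app = FastAPI()
        # Route POSTs on the configured endpoint (e.g. str(MegaServiceEndpoint.SEARCH_QNA))
        # to the coroutine each service in this patch overrides.
        app.add_api_route(self.endpoint, self.handle_request, methods=["POST"])
        uvicorn.run(app, host=self.host, port=self.port)

    async def handle_request(self, request: Request):
        raise NotImplementedError("Overridden by SearchQnAService, TranslationService, etc.")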
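
For hands-on verification, a minimal client sketch follows (also not part of the patch). The payload shapes come straight from the handle_request() implementations above; the /v1/translation and /v1/searchqna paths and port 8888 are assumptions based on the MegaServiceEndpoint names and the MEGA_SERVICE_PORT default, so confirm them against the deployed services.

# client_sketch.py, illustrative only.
import requests

# TranslationService.handle_request() reads raw language_from/language_to/source_language fields.
translation_payload = {
    "language_from": "English",
    "language_to": "German",
    "source_language": "I love machine translation.",
}
resp = requests.post("http://localhost:8888/v1/translation", json=translation_payload, timeout=120)
# The LLM output may be streamed depending on the default parameters, so print the raw body.
print(resp.text)

# SearchQnAService.handle_request() parses an OpenAI-style ChatCompletionRequest and
# streams by default ("stream" defaults to True), so disable streaming for one JSON reply.
search_payload = {
    "messages": [{"role": "user", "content": "What is OPEA?"}],
    "stream": False,
    "max_tokens": 256,
}
resp = requests.post("http://localhost:8888/v1/searchqna", json=search_payload, timeout=300)
print(resp.json()["choices"][0]["message"]["content"])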