From 23f38c00c6371f2d7e91d7d78f095285a7a12ddf Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 14 Jan 2024 08:57:47 +0000 Subject: [PATCH] feat(mamba): Initial import This is a first iteration of the mamba backend, loosely based on mamba-chat(https://github.com/havenhq/mamba-chat). --- Dockerfile | 5 +- Makefile | 2 + backend/backend.proto | 1 - backend/python/coqui/coqui_server.py | 8 +- backend/python/mamba/Makefile | 16 + backend/python/mamba/README.md | 5 + backend/python/mamba/backend_mamba.py | 182 +++++++++++ backend/python/mamba/backend_pb2.py | 61 ++++ backend/python/mamba/backend_pb2_grpc.py | 363 +++++++++++++++++++++ backend/python/mamba/install.sh | 21 ++ backend/python/mamba/run.sh | 14 + backend/python/mamba/test.sh | 11 + backend/python/mamba/test_backend_mamba.py | 76 +++++ 13 files changed, 762 insertions(+), 3 deletions(-) create mode 100644 backend/python/mamba/Makefile create mode 100644 backend/python/mamba/README.md create mode 100644 backend/python/mamba/backend_mamba.py create mode 100644 backend/python/mamba/backend_pb2.py create mode 100644 backend/python/mamba/backend_pb2_grpc.py create mode 100644 backend/python/mamba/install.sh create mode 100755 backend/python/mamba/run.sh create mode 100644 backend/python/mamba/test.sh create mode 100644 backend/python/mamba/test_backend_mamba.py diff --git a/Dockerfile b/Dockerfile index b5217da0ead2..ab63d442b2b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ ARG TARGETVARIANT ENV BUILD_TYPE=${BUILD_TYPE} -ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh" +ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh" ARG GO_TAGS="stablediffusion tinydream tts" @@ -168,6 +168,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \ ; fi +RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ + PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \ + ; fi RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \ ; fi diff --git a/Makefile b/Makefile index 23fcd53629b2..9958b94414f6 100644 --- a/Makefile +++ b/Makefile @@ -419,6 +419,7 @@ protogen-python: python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/petals/ --grpc_python_out=backend/python/petals/ backend/backend.proto + python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/mamba/ --grpc_python_out=backend/python/mamba/ backend/backend.proto python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama2/ --grpc_python_out=backend/python/exllama2/ backend/backend.proto ## GRPC @@ -429,6 +430,7 @@ prepare-extra-conda-environments: $(MAKE) -C backend/python/coqui $(MAKE) -C backend/python/diffusers $(MAKE) -C backend/python/vllm + $(MAKE) -C backend/python/mamba $(MAKE) -C backend/python/sentencetransformers $(MAKE) -C backend/python/transformers $(MAKE) -C backend/python/transformers-musicgen diff --git a/backend/backend.proto b/backend/backend.proto index 2124cebf15e7..dff5ffe77407 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -112,7 +112,6 @@ message ModelOptions { int32 CLIPSkip = 33; string ControlNet = 48; - // RWKV string Tokenizer = 34; // LLM (llama.cpp) diff --git a/backend/python/coqui/coqui_server.py b/backend/python/coqui/coqui_server.py index 1c83c4edbf82..c6432208f5e9 100644 --- a/backend/python/coqui/coqui_server.py +++ b/backend/python/coqui/coqui_server.py @@ -33,7 +33,13 @@ def Health(self, request, context): def LoadModel(self, request, context): # Get device - device = "cuda" if request.CUDA else "cpu" + # device = "cuda" if request.CUDA else "cpu" + if torch.cuda.is_available(): + print("CUDA is available", file=sys.stderr) + device = "cuda" + else: + print("CUDA is not available", file=sys.stderr) + device = "cpu" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") diff --git a/backend/python/mamba/Makefile b/backend/python/mamba/Makefile new file mode 100644 index 000000000000..3ff00346cd73 --- /dev/null +++ b/backend/python/mamba/Makefile @@ -0,0 +1,16 @@ +.PHONY: mamba +mamba: + $(MAKE) -C ../common-env/transformers + bash install.sh + +.PHONY: run +run: + @echo "Running mamba..." + bash run.sh + @echo "mamba run." + +.PHONY: test +test: + @echo "Testing mamba..." + bash test.sh + @echo "mamba tested." \ No newline at end of file diff --git a/backend/python/mamba/README.md b/backend/python/mamba/README.md new file mode 100644 index 000000000000..d6ead9176e34 --- /dev/null +++ b/backend/python/mamba/README.md @@ -0,0 +1,5 @@ +# Creating a separate environment for the mamba project + +``` +make mamba +``` \ No newline at end of file diff --git a/backend/python/mamba/backend_mamba.py b/backend/python/mamba/backend_mamba.py new file mode 100644 index 000000000000..e99bd727204a --- /dev/null +++ b/backend/python/mamba/backend_mamba.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +from concurrent import futures +import time +import argparse +import signal +import sys +import os + +import backend_pb2 +import backend_pb2_grpc + +import grpc + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) +MAMBA_CHAT= os.environ.get('MAMBA_CHAT', '1') == '1' + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + """ + A gRPC servicer that implements the Backend service defined in backend.proto. + """ + def generate(self,prompt, max_new_tokens): + """ + Generates text based on the given prompt and maximum number of new tokens. + + Args: + prompt (str): The prompt to generate text from. + max_new_tokens (int): The maximum number of new tokens to generate. + + Returns: + str: The generated text. + """ + self.generator.end_beam_search() + + # Tokenizing the input + ids = self.generator.tokenizer.encode(prompt) + + self.generator.gen_begin_reuse(ids) + initial_len = self.generator.sequence[0].shape[0] + has_leading_space = False + decoded_text = '' + for i in range(max_new_tokens): + token = self.generator.gen_single_token() + if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'): + has_leading_space = True + + decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:]) + if has_leading_space: + decoded_text = ' ' + decoded_text + + if token.item() == self.generator.tokenizer.eos_token_id: + break + return decoded_text + + def Health(self, request, context): + """ + Returns a health check message. + + Args: + request: The health check request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The health check reply. + """ + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + """ + Loads a language model. + + Args: + request: The load model request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The load model result. + """ + try: + tokenizerModel = request.Tokenizer + if tokenizerModel == "": + tokenizerModel = request.Model + + tokenizer = AutoTokenizer.from_pretrained(tokenizerModel) + if MAMBA_CHAT: + tokenizer.eos_token = "<|endoftext|>" + tokenizer.pad_token = tokenizer.eos_token + self.tokenizer = tokenizer + self.model = MambaLMHeadModel.from_pretrained(request.Model, device="cuda", dtype=torch.float16) + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + return backend_pb2.Result(message="Model loaded successfully", success=True) + + def Predict(self, request, context): + """ + Generates text based on the given prompt and sampling parameters. + + Args: + request: The predict request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The predict result. + """ + if request.TopP == 0: + request.TopP = 0.9 + + max_tokens = request.Tokens + + if request.Tokens == 0: + max_tokens = 2000 + + encoded_input = self.tokenizer(request.Prompt) + + out = self.model.generate(input_ids=encoded_input["input_ids"], max_length=max_tokens, temperature=request.Temperratur, + top_p=request.TopP, eos_token_id=self.tokenizer.eos_token_id) + + decoded = self.tokenizer.batch_decode(out) + + generated_text = decoded[0] + + # Remove prompt from response if present + if request.Prompt in generated_text: + generated_text = generated_text.replace(request.Prompt, "") + + return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8')) + + def PredictStream(self, request, context): + """ + Generates text based on the given prompt and sampling parameters, and streams the results. + + Args: + request: The predict stream request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The predict stream result. + """ + # Implement PredictStream RPC + #for reply in some_data_generator(): + # yield reply + # Not implemented yet + return self.Predict(request, context) + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + + # Define the signal handler function + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + # Set the signal handlers for SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." + ) + args = parser.parse_args() + + serve(args.addr) diff --git a/backend/python/mamba/backend_pb2.py b/backend/python/mamba/backend_pb2.py new file mode 100644 index 000000000000..a4a46e04ab97 --- /dev/null +++ b/backend/python/mamba/backend_pb2.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: backend.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa6\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\x12\x0e\n\x06NDraft\x18) \x01(\x05\x12\x0e\n\x06Images\x18* \x03(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\xad\x07\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\x12\x10\n\x08\x43\x46GScale\x18\x1d \x01(\x02\x12\x0f\n\x07IMG2IMG\x18\x1e \x01(\x08\x12\x11\n\tCLIPModel\x18\x1f \x01(\t\x12\x15\n\rCLIPSubfolder\x18 \x01(\t\x12\x10\n\x08\x43LIPSkip\x18! \x01(\x05\x12\x12\n\nControlNet\x18\x30 \x01(\t\x12\x11\n\tTokenizer\x18\" \x01(\t\x12\x10\n\x08LoraBase\x18# \x01(\t\x12\x13\n\x0bLoraAdapter\x18$ \x01(\t\x12\x11\n\tLoraScale\x18* \x01(\x02\x12\x11\n\tNoMulMatQ\x18% \x01(\x08\x12\x12\n\nDraftModel\x18\' \x01(\t\x12\x11\n\tAudioPath\x18& \x01(\t\x12\x14\n\x0cQuantization\x18( \x01(\t\x12\x0e\n\x06MMProj\x18) \x01(\t\x12\x13\n\x0bRopeScaling\x18+ \x01(\t\x12\x15\n\rYarnExtFactor\x18, \x01(\x02\x12\x16\n\x0eYarnAttnFactor\x18- \x01(\x02\x12\x14\n\x0cYarnBetaFast\x18. \x01(\x02\x12\x14\n\x0cYarnBetaSlow\x18/ \x01(\x02\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\xd7\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\x12\x0b\n\x03src\x18\t \x01(\t\x12\x18\n\x10\x45nableParameters\x18\n \x01(\t\x12\x10\n\x08\x43LIPSkip\x18\x0b \x01(\x05\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t\"6\n\x14TokenizationResponse\x12\x0e\n\x06length\x18\x01 \x01(\x05\x12\x0e\n\x06tokens\x18\x02 \x03(\x05\"\x8e\x01\n\x0fMemoryUsageData\x12\r\n\x05total\x18\x01 \x01(\x04\x12:\n\tbreakdown\x18\x02 \x03(\x0b\x32\'.backend.MemoryUsageData.BreakdownEntry\x1a\x30\n\x0e\x42reakdownEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x04:\x02\x38\x01\"\xad\x01\n\x0eStatusResponse\x12,\n\x05state\x18\x01 \x01(\x0e\x32\x1d.backend.StatusResponse.State\x12(\n\x06memory\x18\x02 \x01(\x0b\x32\x18.backend.MemoryUsageData\"C\n\x05State\x12\x11\n\rUNINITIALIZED\x10\x00\x12\x08\n\x04\x42USY\x10\x01\x12\t\n\x05READY\x10\x02\x12\x12\n\x05\x45RROR\x10\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x32\xf4\x04\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x12J\n\x0eTokenizeString\x12\x17.backend.PredictOptions\x1a\x1d.backend.TokenizationResponse\"\x00\x12;\n\x06Status\x12\x16.backend.HealthMessage\x1a\x17.backend.StatusResponse\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._options = None + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_options = b'8\001' + _globals['_HEALTHMESSAGE']._serialized_start=26 + _globals['_HEALTHMESSAGE']._serialized_end=41 + _globals['_PREDICTOPTIONS']._serialized_start=44 + _globals['_PREDICTOPTIONS']._serialized_end=850 + _globals['_REPLY']._serialized_start=852 + _globals['_REPLY']._serialized_end=876 + _globals['_MODELOPTIONS']._serialized_start=879 + _globals['_MODELOPTIONS']._serialized_end=1820 + _globals['_RESULT']._serialized_start=1822 + _globals['_RESULT']._serialized_end=1864 + _globals['_EMBEDDINGRESULT']._serialized_start=1866 + _globals['_EMBEDDINGRESULT']._serialized_end=1903 + _globals['_TRANSCRIPTREQUEST']._serialized_start=1905 + _globals['_TRANSCRIPTREQUEST']._serialized_end=1972 + _globals['_TRANSCRIPTRESULT']._serialized_start=1974 + _globals['_TRANSCRIPTRESULT']._serialized_end=2052 + _globals['_TRANSCRIPTSEGMENT']._serialized_start=2054 + _globals['_TRANSCRIPTSEGMENT']._serialized_end=2143 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=2146 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=2361 + _globals['_TTSREQUEST']._serialized_start=2363 + _globals['_TTSREQUEST']._serialized_end=2417 + _globals['_TOKENIZATIONRESPONSE']._serialized_start=2419 + _globals['_TOKENIZATIONRESPONSE']._serialized_end=2473 + _globals['_MEMORYUSAGEDATA']._serialized_start=2476 + _globals['_MEMORYUSAGEDATA']._serialized_end=2618 + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_start=2570 + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_end=2618 + _globals['_STATUSRESPONSE']._serialized_start=2621 + _globals['_STATUSRESPONSE']._serialized_end=2794 + _globals['_STATUSRESPONSE_STATE']._serialized_start=2727 + _globals['_STATUSRESPONSE_STATE']._serialized_end=2794 + _globals['_BACKEND']._serialized_start=2797 + _globals['_BACKEND']._serialized_end=3425 +# @@protoc_insertion_point(module_scope) diff --git a/backend/python/mamba/backend_pb2_grpc.py b/backend/python/mamba/backend_pb2_grpc.py new file mode 100644 index 000000000000..79a7677fb27f --- /dev/null +++ b/backend/python/mamba/backend_pb2_grpc.py @@ -0,0 +1,363 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import backend_pb2 as backend__pb2 + + +class BackendStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Health = channel.unary_unary( + '/backend.Backend/Health', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Predict = channel.unary_unary( + '/backend.Backend/Predict', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.LoadModel = channel.unary_unary( + '/backend.Backend/LoadModel', + request_serializer=backend__pb2.ModelOptions.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.PredictStream = channel.unary_stream( + '/backend.Backend/PredictStream', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Embedding = channel.unary_unary( + '/backend.Backend/Embedding', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.EmbeddingResult.FromString, + ) + self.GenerateImage = channel.unary_unary( + '/backend.Backend/GenerateImage', + request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.AudioTranscription = channel.unary_unary( + '/backend.Backend/AudioTranscription', + request_serializer=backend__pb2.TranscriptRequest.SerializeToString, + response_deserializer=backend__pb2.TranscriptResult.FromString, + ) + self.TTS = channel.unary_unary( + '/backend.Backend/TTS', + request_serializer=backend__pb2.TTSRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.TokenizeString = channel.unary_unary( + '/backend.Backend/TokenizeString', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.TokenizationResponse.FromString, + ) + self.Status = channel.unary_unary( + '/backend.Backend/Status', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.StatusResponse.FromString, + ) + + +class BackendServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Health(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def LoadModel(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PredictStream(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embedding(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateImage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AudioTranscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TTS(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TokenizeString(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Status(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BackendServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Health': grpc.unary_unary_rpc_method_handler( + servicer.Health, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'LoadModel': grpc.unary_unary_rpc_method_handler( + servicer.LoadModel, + request_deserializer=backend__pb2.ModelOptions.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'PredictStream': grpc.unary_stream_rpc_method_handler( + servicer.PredictStream, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Embedding': grpc.unary_unary_rpc_method_handler( + servicer.Embedding, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.EmbeddingResult.SerializeToString, + ), + 'GenerateImage': grpc.unary_unary_rpc_method_handler( + servicer.GenerateImage, + request_deserializer=backend__pb2.GenerateImageRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'AudioTranscription': grpc.unary_unary_rpc_method_handler( + servicer.AudioTranscription, + request_deserializer=backend__pb2.TranscriptRequest.FromString, + response_serializer=backend__pb2.TranscriptResult.SerializeToString, + ), + 'TTS': grpc.unary_unary_rpc_method_handler( + servicer.TTS, + request_deserializer=backend__pb2.TTSRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'TokenizeString': grpc.unary_unary_rpc_method_handler( + servicer.TokenizeString, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.TokenizationResponse.SerializeToString, + ), + 'Status': grpc.unary_unary_rpc_method_handler( + servicer.Status, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.StatusResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'backend.Backend', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Backend(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Health(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadModel(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', + backend__pb2.ModelOptions.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PredictStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Embedding(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.EmbeddingResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateImage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', + backend__pb2.GenerateImageRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AudioTranscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', + backend__pb2.TranscriptRequest.SerializeToString, + backend__pb2.TranscriptResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TTS(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS', + backend__pb2.TTSRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TokenizeString(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.TokenizationResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Status(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/backend/python/mamba/install.sh b/backend/python/mamba/install.sh new file mode 100644 index 000000000000..b69e22e7a22c --- /dev/null +++ b/backend/python/mamba/install.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +## +## A bash script installs the required dependencies of VALL-E-X and prepares the environment +export PATH=$PATH:/opt/conda/bin + +if [ "$BUILD_TYPE" != "cublas" ]; then + echo "[mamba] Attention!!! nvcc is required - skipping installation" + exit 0 +fi + +# Activate conda environment +source activate transformers + +echo $CONDA_PREFIX + +pip install causal-conv1d==1.0.0 mamba-ssm==1.0.1 + +if [ "$PIP_CACHE_PURGE" = true ] ; then + pip cache purge +fi \ No newline at end of file diff --git a/backend/python/mamba/run.sh b/backend/python/mamba/run.sh new file mode 100755 index 000000000000..3fee29314630 --- /dev/null +++ b/backend/python/mamba/run.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +## +## A bash script wrapper that runs the diffusers server with conda + +export PATH=$PATH:/opt/conda/bin + +# Activate conda environment +source activate transformers + +# get the directory where the bash script is located +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +python $DIR/backend_mamba.py $@ \ No newline at end of file diff --git a/backend/python/mamba/test.sh b/backend/python/mamba/test.sh new file mode 100644 index 000000000000..b1ff55917f2f --- /dev/null +++ b/backend/python/mamba/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +## +## A bash script wrapper that runs the transformers server with conda + +# Activate conda environment +source activate transformers + +# get the directory where the bash script is located +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +python -m unittest $DIR/test_backend_mamba.py \ No newline at end of file diff --git a/backend/python/mamba/test_backend_mamba.py b/backend/python/mamba/test_backend_mamba.py new file mode 100644 index 000000000000..7760f8163f45 --- /dev/null +++ b/backend/python/mamba/test_backend_mamba.py @@ -0,0 +1,76 @@ +import unittest +import subprocess +import time +import backend_pb2 +import backend_pb2_grpc + +import grpc + +import unittest +import subprocess +import time +import grpc +import backend_pb2_grpc +import backend_pb2 + +class TestBackendServicer(unittest.TestCase): + """ + TestBackendServicer is the class that tests the gRPC service. + + This class contains methods to test the startup and shutdown of the gRPC service. + """ + def setUp(self): + self.service = subprocess.Popen(["python", "backend_vllm.py", "--addr", "localhost:50051"]) + time.sleep(10) + + def tearDown(self) -> None: + self.service.terminate() + self.service.wait() + + def test_server_startup(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + self.fail("Server failed to start") + finally: + self.tearDown() + def test_load_model(self): + """ + This method tests if the model is loaded successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + self.assertEqual(response.message, "Model loaded successfully") + except Exception as err: + print(err) + self.fail("LoadModel service failed") + finally: + self.tearDown() + + def test_text(self): + """ + This method tests if the embeddings are generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + req = backend_pb2.PredictOptions(Prompt="The capital of France is") + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + except Exception as err: + print(err) + self.fail("text service failed") + finally: + self.tearDown() \ No newline at end of file