
feat: 🐍 add mamba support #1589

Merged (1 commit, Jan 19, 2024)
5 changes: 4 additions & 1 deletion Dockerfile
@@ -13,7 +13,7 @@ ARG TARGETVARIANT

ENV BUILD_TYPE=${BUILD_TYPE}

-ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"
+ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"

ARG GO_TAGS="stablediffusion tinydream tts"

@@ -168,6 +168,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \
; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
; fi
2 changes: 2 additions & 0 deletions Makefile
@@ -419,6 +419,7 @@ protogen-python:
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/petals/ --grpc_python_out=backend/python/petals/ backend/backend.proto
+	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/mamba/ --grpc_python_out=backend/python/mamba/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama2/ --grpc_python_out=backend/python/exllama2/ backend/backend.proto

## GRPC
@@ -429,6 +430,7 @@ prepare-extra-conda-environments:
	$(MAKE) -C backend/python/coqui
	$(MAKE) -C backend/python/diffusers
	$(MAKE) -C backend/python/vllm
+	$(MAKE) -C backend/python/mamba
	$(MAKE) -C backend/python/sentencetransformers
	$(MAKE) -C backend/python/transformers
	$(MAKE) -C backend/python/transformers-musicgen
1 change: 0 additions & 1 deletion backend/backend.proto
@@ -112,7 +112,6 @@ message ModelOptions {
  int32 CLIPSkip = 33;
  string ControlNet = 48;

-  // RWKV
  string Tokenizer = 34;

  // LLM (llama.cpp)
8 changes: 7 additions & 1 deletion backend/python/coqui/coqui_server.py
@@ -33,7 +33,13 @@ def Health(self, request, context):
    def LoadModel(self, request, context):

        # Get device
-        device = "cuda" if request.CUDA else "cpu"
+        if torch.cuda.is_available():
+            print("CUDA is available", file=sys.stderr)
+            device = "cuda"
+        else:
+            print("CUDA is not available", file=sys.stderr)
+            device = "cpu"

        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")
16 changes: 16 additions & 0 deletions backend/python/mamba/Makefile
@@ -0,0 +1,16 @@
.PHONY: mamba
mamba:
	$(MAKE) -C ../common-env/transformers
	bash install.sh

.PHONY: run
run:
	@echo "Running mamba..."
	bash run.sh
	@echo "mamba run."

.PHONY: test
test:
	@echo "Testing mamba..."
	bash test.sh
	@echo "mamba tested."
5 changes: 5 additions & 0 deletions backend/python/mamba/README.md
@@ -0,0 +1,5 @@
# Creating a separate environment for the mamba project

```
make mamba
```
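
The backend's Makefile (shown above) also exposes `run` and `test` targets for the resulting environment:

```
make run
make test
```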
182 changes: 182 additions & 0 deletions backend/python/mamba/backend_mamba.py
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
from concurrent import futures
import time
import argparse
import signal
import sys
import os

import backend_pb2
import backend_pb2_grpc

import grpc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# If MAMBA_CHAT is enabled (the default), LoadModel applies chat-style eos/pad tokens below
MAMBA_CHAT = os.environ.get('MAMBA_CHAT', '1') == '1'

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """
    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        # NOTE: this helper references self.generator, which is never initialized
        # in this servicer; it appears to be carried over from an exllama-style
        # backend and is not used by the mamba code path in Predict below.
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == self.generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Loads a language model.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            tokenizerModel = request.Tokenizer
            if tokenizerModel == "":
                tokenizerModel = request.Model

            tokenizer = AutoTokenizer.from_pretrained(tokenizerModel)
            if MAMBA_CHAT:
                tokenizer.eos_token = "<|endoftext|>"
                tokenizer.pad_token = tokenizer.eos_token
            self.tokenizer = tokenizer
            self.model = MambaLMHeadModel.from_pretrained(request.Model, device="cuda", dtype=torch.float16)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)
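
For orientation, here is a minimal standalone sketch of the same load-and-generate flow outside gRPC. The checkpoint and tokenizer names are illustrative assumptions (mamba checkpoints typically do not bundle a tokenizer, which is why the servicer lets `request.Tokenizer` override `request.Model` above):

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Assumed example names; substitute any mamba_ssm-compatible checkpoint.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

input_ids = tokenizer("Mamba is", return_tensors="pt").input_ids.to("cuda")
out = model.generate(input_ids=input_ids, max_length=64, temperature=0.9,
                     top_p=0.9, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.batch_decode(out)[0])
```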

    def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict result.
        """
        if request.TopP == 0:
            request.TopP = 0.9

        max_tokens = request.Tokens

        if request.Tokens == 0:
            max_tokens = 2000

        # Tokenize to tensors and move them to the device the model was loaded on
        encoded_input = self.tokenizer(request.Prompt, return_tensors="pt").to("cuda")

        out = self.model.generate(input_ids=encoded_input["input_ids"], max_length=max_tokens,
                                  temperature=request.Temperature,
                                  top_p=request.TopP, eos_token_id=self.tokenizer.eos_token_id)

        decoded = self.tokenizer.batch_decode(out)

        generated_text = decoded[0]

        # Remove prompt from response if present
        if request.Prompt in generated_text:
            generated_text = generated_text.replace(request.Prompt, "")

        return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))

    def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict stream result.
        """
        # Implement PredictStream RPC
        # for reply in some_data_generator():
        #    yield reply
        # Not implemented yet
        return self.Predict(request, context)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()

serve(args.addr)
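
With the server running, a minimal client sketch looks like this (the address and model name are assumptions for illustration; the message and field names follow backend.proto as used above):

```python
import grpc

import backend_pb2
import backend_pb2_grpc

# Assumed address and checkpoint name, for illustration only.
channel = grpc.insecure_channel("localhost:50051")
stub = backend_pb2_grpc.BackendStub(channel)

result = stub.LoadModel(backend_pb2.ModelOptions(Model="state-spaces/mamba-130m"))
assert result.success, result.message

reply = stub.Predict(backend_pb2.PredictOptions(
    Prompt="Mamba is", Tokens=64, TopP=0.9, Temperature=0.7))
print(reply.message)
```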
61 changes: 61 additions & 0 deletions backend/python/mamba/backend_pb2.py

Generated files are not rendered by default; backend_pb2.py contains the gRPC bindings produced from backend.proto by the protogen-python target above.