
implement mamba
mudler committed Jan 19, 2024
1 parent 42be925 commit 62c0e52
Showing 5 changed files with 41 additions and 11 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
@@ -13,7 +13,7 @@ ARG TARGETVARIANT

ENV BUILD_TYPE=${BUILD_TYPE}

ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"

ARG GO_TAGS="stablediffusion tinydream tts"

@@ -168,6 +168,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
    PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
    ; fi
+RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
+    PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \
+    ; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
    PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
    ; fi
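The EXTERNAL_GRPC_BACKENDS value is a comma-separated list of name:path pairs, each mapping a backend name to the launcher script for its gRPC server; this commit registers mamba alongside the existing backends. A minimal Python sketch of how such a value decomposes (shortened to two entries for readability):

spec = "vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh"
backends = dict(pair.split(":", 1) for pair in spec.split(","))
assert backends["mamba"] == "/build/backend/python/mamba/run.sh"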
2 changes: 2 additions & 0 deletions Makefile
@@ -419,6 +419,7 @@ protogen-python:
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/petals/ --grpc_python_out=backend/python/petals/ backend/backend.proto
+python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/mamba/ --grpc_python_out=backend/python/mamba/ backend/backend.proto
python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama2/ --grpc_python_out=backend/python/exllama2/ backend/backend.proto

## GRPC
@@ -429,6 +430,7 @@ prepare-extra-conda-environments:
$(MAKE) -C backend/python/coqui
$(MAKE) -C backend/python/diffusers
$(MAKE) -C backend/python/vllm
+$(MAKE) -C backend/python/mamba
$(MAKE) -C backend/python/sentencetransformers
$(MAKE) -C backend/python/transformers
$(MAKE) -C backend/python/transformers-musicgen
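The protogen-python target shells out to grpc_tools.protoc once per backend; the new mamba line generates backend_pb2.py and backend_pb2_grpc.py under backend/python/mamba/. The same step can be driven from Python directly — a sketch, assuming it is run from the repository root:

from grpc_tools import protoc

# Mirrors the Makefile invocation above for the mamba backend.
protoc.main([
    "protoc",
    "-Ibackend/",
    "--python_out=backend/python/mamba/",
    "--grpc_python_out=backend/python/mamba/",
    "backend/backend.proto",
])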
1 change: 0 additions & 1 deletion backend/backend.proto
@@ -112,7 +112,6 @@ message ModelOptions {
int32 CLIPSkip = 33;
string ControlNet = 48;

-// RWKV
string Tokenizer = 34;

// LLM (llama.cpp)
8 changes: 7 additions & 1 deletion backend/python/coqui/coqui_server.py
@@ -33,7 +33,13 @@ def Health(self, request, context):
    def LoadModel(self, request, context):

        # Get device
-        device = "cuda" if request.CUDA else "cpu"
+        # device = "cuda" if request.CUDA else "cpu"
+        if torch.cuda.is_available():
+            print("CUDA is available", file=sys.stderr)
+            device = "cuda"
+        else:
+            print("CUDA is not available", file=sys.stderr)
+            device = "cpu"

        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")
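The change replaces the request-driven device choice with a probe of the actual hardware. The same logic, factored as a reusable helper (a sketch, not part of the commit):

import sys

import torch

def pick_device() -> str:
    # Prefer CUDA when the runtime actually has it; log the decision to stderr.
    if torch.cuda.is_available():
        print("CUDA is available", file=sys.stderr)
        return "cuda"
    print("CUDA is not available", file=sys.stderr)
    return "cpu"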
36 changes: 28 additions & 8 deletions backend/python/mamba/backend_mamba.py
@@ -10,12 +10,16 @@
import backend_pb2_grpc

import grpc
-from vllm import LLM, SamplingParams

+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+MAMBA_CHAT = os.environ.get('MAMBA_CHAT', '1') == '1'

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
@@ -80,10 +84,16 @@ def LoadModel(self, request, context):
            backend_pb2.Result: The load model result.
        """
        try:
-            if request.Quantization != "":
-                self.llm = LLM(model=request.Model, quantization=request.Quantization)
-            else:
-                self.llm = LLM(model=request.Model)
+            tokenizerModel = request.Tokenizer
+            if tokenizerModel == "":
+                tokenizerModel = request.Model
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizerModel)
+            if MAMBA_CHAT:
+                tokenizer.eos_token = "<|endoftext|>"
+                tokenizer.pad_token = tokenizer.eos_token
+            self.tokenizer = tokenizer
+            self.model = MambaLMHeadModel.from_pretrained(request.Model, device="cuda", dtype=torch.float16)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)
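Standalone, the load path pairs a Hugging Face tokenizer with a mamba_ssm checkpoint; when no explicit tokenizer is configured, the model name is reused. A hedged sketch (both names below are illustrative assumptions, not taken from the commit; mamba checkpoints commonly reuse the GPT-NeoX tokenizer):

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer
tokenizer.eos_token = "<|endoftext|>"   # chat-style EOS, as MAMBA_CHAT enables above
tokenizer.pad_token = tokenizer.eos_token
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m",   # assumed checkpoint
                                         device="cuda", dtype=torch.float16)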
@@ -102,10 +112,20 @@ def Predict(self, request, context):
        if request.TopP == 0:
            request.TopP = 0.9

-        sampling_params = SamplingParams(temperature=request.Temperature, top_p=request.TopP)
-        outputs = self.llm.generate([request.Prompt], sampling_params)
-        generated_text = outputs[0].outputs[0].text
+        max_tokens = request.Tokens
+
+        if request.Tokens == 0:
+            max_tokens = 2000
+
+        encoded_input = self.tokenizer(request.Prompt, return_tensors="pt").to(device="cuda")
+
+        out = self.model.generate(input_ids=encoded_input["input_ids"], max_length=max_tokens,
+                                  temperature=request.Temperature, top_p=request.TopP,
+                                  eos_token_id=self.tokenizer.eos_token_id)
+
+        decoded = self.tokenizer.batch_decode(out)
+
+        generated_text = decoded[0]

        # Remove prompt from response if present
        if request.Prompt in generated_text:
            generated_text = generated_text.replace(request.Prompt, "")
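With the backend running via its run.sh launcher, the two RPCs touched here can be exercised from a minimal client. A sketch under stated assumptions: the generated stubs are on the import path, the server listens on localhost:50051 (address assumed, not from the commit), and the checkpoint name is illustrative:

import grpc

import backend_pb2
import backend_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")  # assumed address
stub = backend_pb2_grpc.BackendStub(channel)

# Load an illustrative mamba checkpoint, then request a completion.
stub.LoadModel(backend_pb2.ModelOptions(Model="state-spaces/mamba-130m"))
reply = stub.Predict(backend_pb2.PredictOptions(Prompt="Mamba is", Tokens=64))
print(reply.message)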
