Skip to content

Commit

Permalink
fix: exllama2 backend
Browse files Browse the repository at this point in the history
Signed-off-by: Sertac Ozercan <[email protected]>
  • Loading branch information
sozercan committed Dec 23, 2023
1 parent 939187a commit 8db3e30
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions backend/python/exllama2/exllama2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import argparse
import signal
import sys
import os, glob
import os
import glob

from pathlib import Path
import torch
Expand All @@ -21,7 +22,7 @@
)


from exllamav2 import(
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
Expand All @@ -37,9 +38,12 @@
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods


class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

def LoadModel(self, request, context):
try:
model_directory = request.ModelFile
Expand All @@ -50,7 +54,7 @@ def LoadModel(self, request, context):

model = ExLlamaV2(config)

cache = ExLlamaV2Cache(model, lazy = True)
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)
Expand All @@ -59,7 +63,7 @@ def LoadModel(self, request, context):

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

self.generator= generator
self.generator = generator

generator.warmup()
self.model = model
Expand All @@ -85,17 +89,18 @@ def Predict(self, request, context):

if request.Tokens != 0:
tokens = request.Tokens
output = self.generator.generate_simple(request.Prompt, settings, tokens, seed = self.seed)
output = self.generator.generate_simple(
request.Prompt, settings, tokens)

# Remove prompt from response if present
if request.Prompt in output:
output = output.replace(request.Prompt, "")

return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
return backend_pb2.Result(message=bytes(output, encoding='utf-8'))

def PredictStream(self, request, context):
# Implement PredictStream RPC
#for reply in some_data_generator():
# for reply in some_data_generator():
# yield reply
# Not implemented yet
return self.Predict(request, context)
Expand Down Expand Up @@ -124,11 +129,12 @@ def signal_handler(sig, frame):
except KeyboardInterrupt:
server.stop(0)


if __name__ == "__main__":
    # CLI entry point: parse the bind address and start the gRPC server.
    arg_parser = argparse.ArgumentParser(description="Run the gRPC server.")
    arg_parser.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    cli_args = arg_parser.parse_args()
    serve(cli_args.addr)

0 comments on commit 8db3e30

Please sign in to comment.