From 9608f1acb0acc398a0957b9a209ebe97f546ead2 Mon Sep 17 00:00:00 2001
From: Kevin Mingtarja
Date: Sat, 27 Apr 2024 16:01:39 +0700
Subject: [PATCH 1/4] improve dockerfile, makefile, readme

---
 python/huggingface_server.Dockerfile | 8 ++++----
 python/huggingfaceserver/Makefile    | 5 +++++
 python/huggingfaceserver/README.md   | 6 ++++++
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/python/huggingface_server.Dockerfile b/python/huggingface_server.Dockerfile
index 35650cfd06e..26b4f2ee402 100644
--- a/python/huggingface_server.Dockerfile
+++ b/python/huggingface_server.Dockerfile
@@ -22,14 +22,14 @@ RUN python3 -m venv $VIRTUAL_ENV
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 
 COPY kserve/pyproject.toml kserve/poetry.lock kserve/
-RUN cd kserve && poetry install --no-root --no-interaction --no-cache
+RUN cd kserve && poetry install --all-extras --no-root --no-interaction --no-cache
 COPY kserve kserve
-RUN cd kserve && poetry install --no-interaction --no-cache
+RUN cd kserve && poetry install --all-extras --no-interaction --no-cache
 
 COPY huggingfaceserver/pyproject.toml huggingfaceserver/poetry.lock huggingfaceserver/
-RUN cd huggingfaceserver && poetry install --no-root --no-interaction --no-cache
+RUN cd huggingfaceserver && poetry install --all-extras --no-root --no-interaction --no-cache
 COPY huggingfaceserver huggingfaceserver
-RUN cd huggingfaceserver && poetry install --no-interaction --no-cache
+RUN cd huggingfaceserver && poetry install --all-extras --no-interaction --no-cache
 
 RUN pip3 install vllm==${VLLM_VERSION}

diff --git a/python/huggingfaceserver/Makefile b/python/huggingfaceserver/Makefile
index 92cf90452a8..f0f89a1b2a0 100644
--- a/python/huggingfaceserver/Makefile
+++ b/python/huggingfaceserver/Makefile
@@ -1,3 +1,5 @@
+IMG ?= hypermode/huggingface-model-server:latest
+
 dev_install:
 	poetry install --with test
 
@@ -9,3 +11,6 @@ test: type_check
 
 type_check:
 	mypy --ignore-missing-imports huggingfaceserver
+
+build-docker:
+	docker build --platform linux/amd64 -t $(IMG) -f ../huggingface_server.Dockerfile ..

diff --git a/python/huggingfaceserver/README.md b/python/huggingfaceserver/README.md
index 18172509b0e..f93ff334ed1 100644
--- a/python/huggingfaceserver/README.md
+++ b/python/huggingfaceserver/README.md
@@ -5,6 +5,12 @@ The preprocess and post-process handlers are implemented based on different ML t
 token-classification, text-generation, text2text generation. Based on the performance requirement, you can choose to perform
 the inference on a more optimized inference engine like triton inference server and vLLM for text generation.
 
+## Build Container Image Locally
+
+The Dockerfile is located in the parent directory:
+```bash
+docker build --platform linux/amd64 -t hypermode/huggingface-model-server:latest -f huggingface_server.Dockerfile ..
+```
 
 ## Run Huggingface Server Locally
 

From aab81d7c34f2b6ba24322d57b1f123407140c743 Mon Sep 17 00:00:00 2001
From: Kevin Mingtarja
Date: Sat, 27 Apr 2024 16:02:10 +0700
Subject: [PATCH 2/4] support custom classification labels, refactor postprocess

---
 .../huggingfaceserver/__main__.py      |  1 +
 .../huggingfaceserver/encoder_model.py | 27 ++++++++++--
 .../huggingfaceserver/test_model.py    | 42 +++++++++++++++++--
 3 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/python/huggingfaceserver/huggingfaceserver/__main__.py b/python/huggingfaceserver/huggingfaceserver/__main__.py
index b0dde134f02..efa1962e119 100644
--- a/python/huggingfaceserver/huggingfaceserver/__main__.py
+++ b/python/huggingfaceserver/huggingfaceserver/__main__.py
@@ -189,6 +189,7 @@ def load_model():
             tensor_input_names=kwargs.get("tensor_input_names", None),
             return_token_type_ids=kwargs.get("return_token_type_ids", None),
             predictor_config=predictor_config,
+            classification_labels=kwargs.get("classification_labels", None),
         )
         model.load()
         return model
diff --git a/python/huggingfaceserver/huggingfaceserver/encoder_model.py b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
index 638a35e0d60..36bef1ba7fc 100644
--- a/python/huggingfaceserver/huggingfaceserver/encoder_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -79,6 +79,7 @@ def __init__(
         tokenizer_revision: Optional[str] = None,
         trust_remote_code: bool = False,
         predictor_config: Optional[PredictorConfig] = None,
+        classification_labels: Optional[Dict[int, str]] = None,
     ):
         super().__init__(model_name, predictor_config)
         self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -91,6 +92,7 @@ def __init__(
         self.model_revision = model_revision
         self.tokenizer_revision = tokenizer_revision
         self.trust_remote_code = trust_remote_code
+        self.classification_labels = classification_labels
 
         if model_config:
             self.model_config = model_config
@@ -246,11 +248,28 @@ def postprocess(
             input_ids = torch.Tensor(input_ids)
         inferences = []
         if self.task == MLTask.sequence_classification:
-            num_rows, num_cols = outputs.shape
+            outputs = torch.nn.functional.softmax(outputs, dim=-1).detach()
+            max_indices = torch.argmax(outputs, dim=1).tolist()
+            final = outputs.tolist()
+
+            id2label = {0: 0, 1: 1}
+            if self.classification_labels:
+                id2label = {i: val for i, val in enumerate(self.classification_labels)}
+
+            num_rows, _ = outputs.shape
             for i in range(num_rows):
-                out = outputs[i].unsqueeze(0)
-                predicted_idx = out.argmax().item()
-                inferences.append(predicted_idx)
+                max_id = max_indices[i]
+                res = {
+                    "label": id2label[max_id],
+                    "confidence": final[i][max_id],
+                    "probabilities": [
+                        {
+                            "label": id2label[j],
+                            "probability": final[i][j]
+                        } for j in range(len(final[i]))
+                    ]
+                }
+                inferences.append(res)
             return get_predict_response(request, inferences, self.name)
         elif self.task == MLTask.fill_mask:
             num_rows = outputs.shape[0]
diff --git a/python/huggingfaceserver/huggingfaceserver/test_model.py b/python/huggingfaceserver/huggingfaceserver/test_model.py
index 45be9062697..7111ebb1ed5 100644
--- a/python/huggingfaceserver/huggingfaceserver/test_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/test_model.py
@@ -200,7 +200,25 @@ async def test_bert_sequence_classification(bert_base_yelp_polarity):
     response = await bert_base_yelp_polarity(
         {"instances": [request, request]}, headers={}
     )
-    assert response == {"predictions": [1, 1]}
+    assert response == {"predictions": [
+        {
+            'confidence': 0.9988189339637756,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.0011810670839622617},
+                {'label': 1, 'probability': 0.9988189339637756}
+            ]
+        },
+        {
+            'confidence': 0.9988189339637756,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.0011810670839622617},
+                {'label': 1, 'probability': 0.9988189339637756}
+            ]
+        }
+    ]
+    }
 
 
 @pytest.mark.asyncio
@@ -311,7 +329,25 @@ async def test_input_padding(bert_base_yelp_polarity: HuggingfaceEncoderModel):
     response = await bert_base_yelp_polarity(
         {"instances": [request_one, request_two]}, headers={}
     )
-    assert response == {"predictions": [1, 1]}
+    assert response == {"predictions": [
+        {
+            'confidence': 0.9988189339637756,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.0011810670839622617},
+                {'label': 1, 'probability': 0.9988189339637756}
+            ]
+        },
+        {
+            'confidence': 0.9963782429695129,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.003621795680373907},
+                {'label': 1, 'probability': 0.9963782429695129}
+            ]
+        }
+    ]
+    }
 
 
 @pytest.mark.asyncio
@@ -321,4 +357,4 @@ async def test_input_truncation(bert_base_yelp_polarity: HuggingfaceEncoderModel
     # unless we set truncation=True in the tokenizer
     request = "good " * 600
     response = await bert_base_yelp_polarity({"instances": [request]}, headers={})
-    assert response == {"predictions": [1]}
+    assert response == {'predictions': [{'confidence': 0.9914830327033997, 'label': 1, 'probabilities': [{'label': 0, 'probability': 0.00851691048592329}, {'label': 1, 'probability': 0.9914830327033997}]}]}

From 64e55e6646b3f235d3f7437c15fa234f91180189 Mon Sep 17 00:00:00 2001
From: Kevin Mingtarja
Date: Sat, 27 Apr 2024 16:22:55 +0700
Subject: [PATCH 3/4] support text embedding task

---
 .../huggingfaceserver/encoder_model.py | 22 +++++++++++-
 .../huggingfaceserver/task.py          |  3 ++
 .../huggingfaceserver/test_model.py    | 34 +++++++++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/python/huggingfaceserver/huggingfaceserver/encoder_model.py b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
index 36bef1ba7fc..c3ed05f5b2d 100644
--- a/python/huggingfaceserver/huggingfaceserver/encoder_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -16,6 +16,7 @@
 from typing import Any, Dict, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from accelerate import init_empty_weights
 from kserve import Model
 from kserve.errors import InferenceError
@@ -215,6 +216,7 @@ def preprocess(
                 truncation=True,
             )
             context["payload"] = payload
+            context["inputs"] = inputs
             context["input_ids"] = inputs["input_ids"]
             return inputs
 
@@ -231,7 +233,10 @@ async def predict(
             input_batch = input_batch.to(self._device)
             try:
                 with torch.no_grad():
-                    outputs = self._model(**input_batch).logits
+                    if self.task == MLTask.text_embedding.value:
+                        outputs = self._model(**input_batch)
+                    else:
+                        outputs = self._model(**input_batch).logits
                     return outputs
             except Exception as e:
                 raise InferenceError(str(e))
@@ -286,7 +291,22 @@ def postprocess(
                 predictions = torch.argmax(output, dim=2)
                 inferences.append(predictions.tolist())
             return get_predict_response(request, inferences, self.name)
+        elif self.task == MLTask.text_embedding.value:
+            # Perform pooling
+            outputs = mean_pooling(outputs, context["inputs"]["attention_mask"])
+            # Normalize embeddings
+            outputs = F.normalize(outputs, p=2, dim=1)
+
+            num_rows, _ = outputs.shape
+            for i in range(num_rows):
+                inferences.append(outputs[i].tolist())
+            return get_predict_response(request, inferences, self.name)
         else:
             raise ValueError(
f"Unsupported task {self.task}. Please check the supported `task` option." ) + +# Mean Pooling - Take attention mask into account for correct averaging +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) diff --git a/python/huggingfaceserver/huggingfaceserver/task.py b/python/huggingfaceserver/huggingfaceserver/task.py index 96213488f6c..c919c6e3cbc 100644 --- a/python/huggingfaceserver/huggingfaceserver/task.py +++ b/python/huggingfaceserver/huggingfaceserver/task.py @@ -42,6 +42,7 @@ class MLTask(str, Enum): text_generation = auto() text2text_generation = auto() multiple_choice = auto() + text_embedding = auto() @classmethod def _missing_(cls, value: str): @@ -74,6 +75,7 @@ def _missing_(cls, value: str): MLTask.text_generation: AutoModelForCausalLM, MLTask.text2text_generation: AutoModelForSeq2SeqLM, MLTask.multiple_choice: AutoModelForMultipleChoice, + MLTask.text_embedding: AutoModel, } SUPPORTED_TASKS = { @@ -82,6 +84,7 @@ def _missing_(cls, value: str): MLTask.fill_mask, MLTask.text_generation, MLTask.text2text_generation, + MLTask.text_embedding } diff --git a/python/huggingfaceserver/huggingfaceserver/test_model.py b/python/huggingfaceserver/huggingfaceserver/test_model.py index 7111ebb1ed5..239438500b3 100644 --- a/python/huggingfaceserver/huggingfaceserver/test_model.py +++ b/python/huggingfaceserver/huggingfaceserver/test_model.py @@ -14,6 +14,8 @@ import pytest +import torch.nn.functional as F +import torch from kserve.model import PredictorConfig from kserve.protocol.rest.openai import ChatCompletionRequest, CompletionRequest from kserve.protocol.rest.openai.types import ( @@ -88,6 +90,17 @@ def bert_token_classification(): yield model model.stop() +@pytest.fixture(scope="module") +def text_embedding(): + model = HuggingfaceEncoderModel( + "mxbai-embed-large-v1", + model_id_or_path="mixedbread-ai/mxbai-embed-large-v1", + task=MLTask.text_embedding, + ) + model.load() + yield model + model.stop() + def test_unsupported_model(): config = AutoConfig.from_pretrained("google/tapas-base-finetuned-wtq") @@ -234,6 +247,27 @@ async def test_bert_token_classification(bert_token_classification): ] } +@pytest.mark.asyncio +async def test_text_embedding(text_embedding): + def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + a_norm = F.normalize(a, p=2, dim=1) + b_norm = F.normalize(b, p=2, dim=1) + return torch.mm(a_norm, b_norm.transpose(0, 1)) + + requests = ["I'm happy", "I'm full of happiness", "They were born in the capital city of France, Paris"] + response = await text_embedding({"instances": requests}, headers={}) + predictions = response["predictions"] + + print(cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[1]))[0]) + print(cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[2]))[0]) + assert cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[1]))[0] > 0.9 + assert cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[2]))[0] < 0.6 @pytest.mark.asyncio async def test_bloom_completion(bloom_model: HuggingfaceGenerativeModel): From b48a443ff041e3e111f4fff6723cf3da538b031c Mon Sep 17 
From: Kevin Mingtarja
Date: Sat, 27 Apr 2024 16:46:24 +0700
Subject: [PATCH 4/4] improve support for token classification (named entity recognition)

---
 .../huggingfaceserver/encoder_model.py | 32 ++++++++++++++++---
 .../huggingfaceserver/test_model.py    | 21 ++++++++++--
 2 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/python/huggingfaceserver/huggingfaceserver/encoder_model.py b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
index c3ed05f5b2d..866064cfb19 100644
--- a/python/huggingfaceserver/huggingfaceserver/encoder_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -34,6 +34,7 @@
     AutoModel,
     AutoTokenizer,
     BatchEncoding,
+    pipeline,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     PretrainedConfig,
@@ -94,6 +95,7 @@ def __init__(
         self.tokenizer_revision = tokenizer_revision
         self.trust_remote_code = trust_remote_code
         self.classification_labels = classification_labels
+        self.nlp = None
 
         if model_config:
             self.model_config = model_config
@@ -134,6 +136,9 @@ def load(self) -> bool:
         if self._model._no_split_modules:
             device_map = "auto"
 
+        # somehow, setting do_lower_case to True gives worse results for the NER task
+        if self.task == MLTask.token_classification.value:
+            self.do_lower_case = False
 
         tokenizer_kwargs = {}
         model_kwargs = {}
@@ -165,6 +170,8 @@ def load(self) -> bool:
         )
         self._model.eval()
         self._model.to(self._device)
+        if self.task == MLTask.token_classification:
+            self.nlp = pipeline("ner", model=self._model, tokenizer=self._tokenizer)
         logger.info(
             f"Successfully loaded huggingface model from path {model_id_or_path}"
         )
@@ -190,6 +197,7 @@ def preprocess(
                 truncation=True,
             )
             context["payload"] = payload
+            context["inputs"] = inputs
             context["input_ids"] = inputs["input_ids"]
             infer_inputs = []
             for key, input_tensor in inputs.items():
@@ -206,6 +214,12 @@ def preprocess(
             )
             return infer_request
         else:
+            if self.task == MLTask.token_classification.value:
+                context["payload"] = payload
+                context["inputs"] = instances
+                context["input_ids"] = []
+                return instances
+
             inputs = self._tokenizer(
                 instances,
                 max_length=self.max_length,
@@ -230,8 +244,12 @@ async def predict(
             # like NVIDIA triton inference server
             return await super().predict(input_batch, context)
         else:
-            input_batch = input_batch.to(self._device)
             try:
+                if self.task == MLTask.token_classification:
+                    with torch.no_grad():
+                        return self.nlp(input_batch)
+
+                input_batch = input_batch.to(self._device)
                 with torch.no_grad():
                     if self.task == MLTask.text_embedding.value:
                         outputs = self._model(**input_batch)
@@ -285,11 +303,15 @@ def postprocess(
                 inferences.append(self._tokenizer.decode(predicted_token_id))
             return get_predict_response(request, inferences, self.name)
         elif self.task == MLTask.token_classification:
-            num_rows = outputs.shape[0]
+            num_rows = len(outputs)
             for i in range(num_rows):
-                output = outputs[i].unsqueeze(0)
-                predictions = torch.argmax(output, dim=2)
-                inferences.append(predictions.tolist())
+                output = outputs[i]
+                for entity in output:
+                    # without this, it fails with
+                    # ValueError: [TypeError("'numpy.float32' object is not iterable"), TypeError('vars() argument must have __dict__ attribute')]
+                    entity["score"] = float(entity["score"])
+                predictions = output
+                inferences.append(predictions)
             return get_predict_response(request, inferences, self.name)
         elif self.task == MLTask.text_embedding.value:
             # Perform pooling
diff --git a/python/huggingfaceserver/huggingfaceserver/test_model.py b/python/huggingfaceserver/huggingfaceserver/test_model.py
index 239438500b3..95b62abc347 100644
--- a/python/huggingfaceserver/huggingfaceserver/test_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/test_model.py
@@ -241,9 +241,24 @@ async def test_bert_token_classification(bert_token_classification):
         {"instances": [request, request]}, headers={}
     )
     assert response == {
-        "predictions": [
-            [[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
-            [[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+        'predictions': [
+            [
+                {'entity': 'I-ORG', 'score': 0.9972999691963196, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2},
+                {'entity': 'I-ORG', 'score': 0.9716504216194153, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7},
+                {'entity': 'I-ORG', 'score': 0.9962745904922485, 'index': 3, 'word': '##F', 'start': 7, 'end': 8},
+                {'entity': 'I-ORG', 'score': 0.993005096912384, 'index': 4, 'word': '##ace', 'start': 8, 'end': 11},
+                {'entity': 'I-LOC', 'score': 0.9940695762634277, 'index': 10, 'word': 'Paris', 'start': 34, 'end': 39},
+                {'entity': 'I-LOC', 'score': 0.9982321858406067, 'index': 12, 'word': 'New', 'start': 44, 'end': 47},
+                {'entity': 'I-LOC', 'score': 0.9975290894508362, 'index': 13, 'word': 'York', 'start': 48, 'end': 52}
+            ],
+            [
+                {'entity': 'I-ORG', 'score': 0.9972999691963196, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2},
+                {'entity': 'I-ORG', 'score': 0.9716504216194153, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7},
+                {'entity': 'I-ORG', 'score': 0.9962745904922485, 'index': 3, 'word': '##F', 'start': 7, 'end': 8},
+                {'entity': 'I-ORG', 'score': 0.993005096912384, 'index': 4, 'word': '##ace', 'start': 8, 'end': 11},
+                {'entity': 'I-LOC', 'score': 0.9940695762634277, 'index': 10, 'word': 'Paris', 'start': 34, 'end': 39},
+                {'entity': 'I-LOC', 'score': 0.9982321858406067, 'index': 12, 'word': 'New', 'start': 44, 'end': 47},
+                {'entity': 'I-LOC', 'score': 0.9975290894508362, 'index': 13, 'word': 'York', 'start': 48, 'end': 52}]
         ]
     }
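
Illustrative client sketch (not part of the patch series): the tests above pin down the new response shapes, so a caller can be written against them. Only the `{"predictions": [...]}` structure, the `label` / `confidence` / `probabilities` fields, and the fact that `text_embedding` predictions are L2-normalized vectors come from these patches; the classification values below are copied from the tests, and the embedding vectors are made-up placeholders for illustration.

```python
import math
from typing import Dict, List


def top_label(prediction: Dict) -> Dict:
    """Pick the winning label and its score from a sequence-classification prediction."""
    return {"label": prediction["label"], "confidence": prediction["confidence"]}


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two embedding vectors from the text_embedding task."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


# Sequence-classification response, shaped as asserted in test_bert_sequence_classification.
classification_response = {
    "predictions": [
        {
            "confidence": 0.9988189339637756,
            "label": 1,
            "probabilities": [
                {"label": 0, "probability": 0.0011810670839622617},
                {"label": 1, "probability": 0.9988189339637756},
            ],
        }
    ]
}
print(top_label(classification_response["predictions"][0]))  # {'label': 1, 'confidence': 0.998...}

# Text-embedding response: one vector per instance, L2-normalized by the server
# (mean pooling + F.normalize in postprocess). Placeholder vectors, not real model output.
embedding_response = {"predictions": [[0.6, 0.8, 0.0], [0.8, 0.6, 0.0]]}
vec_a, vec_b = embedding_response["predictions"]
print(cosine_similarity(vec_a, vec_b))  # 0.96
```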