diff --git a/python/huggingface_server.Dockerfile b/python/huggingface_server.Dockerfile
index 07f1dc326aa..09933f3f8cb 100644
--- a/python/huggingface_server.Dockerfile
+++ b/python/huggingface_server.Dockerfile
@@ -23,14 +23,14 @@ RUN python3 -m venv $VIRTUAL_ENV
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
 
 COPY kserve/pyproject.toml kserve/poetry.lock kserve/
-RUN cd kserve && poetry install --no-root --no-interaction --no-cache
+RUN cd kserve && poetry install --all-extras --no-root --no-interaction --no-cache
 COPY kserve kserve
-RUN cd kserve && poetry install --no-interaction --no-cache
+RUN cd kserve && poetry install --all-extras --no-interaction --no-cache
 
 COPY huggingfaceserver/pyproject.toml huggingfaceserver/poetry.lock huggingfaceserver/
-RUN cd huggingfaceserver && poetry install --no-root --no-interaction --no-cache
+RUN cd huggingfaceserver && poetry install --all-extras --no-root --no-interaction --no-cache
 COPY huggingfaceserver huggingfaceserver
-RUN cd huggingfaceserver && poetry install --no-interaction --no-cache
+RUN cd huggingfaceserver && poetry install --all-extras --no-interaction --no-cache
 
 RUN pip3 install vllm==${VLLM_VERSION}
diff --git a/python/huggingfaceserver/Makefile b/python/huggingfaceserver/Makefile
index 86732a30460..f12bcfef0df 100644
--- a/python/huggingfaceserver/Makefile
+++ b/python/huggingfaceserver/Makefile
@@ -1,3 +1,5 @@
+IMG ?= hypermode/huggingface-model-server:latest
+
 dev_install:
 	poetry install --with test -E vllm --no-interaction
 
@@ -9,3 +11,6 @@ test: type_check
 
 type_check:
 	mypy --ignore-missing-imports huggingfaceserver
+
+build-docker:
+	docker build --platform linux/amd64 -t $(IMG) -f ../huggingface_server.Dockerfile ..
diff --git a/python/huggingfaceserver/README.md b/python/huggingfaceserver/README.md
index d8337b28775..547b7d8a4e5 100644
--- a/python/huggingfaceserver/README.md
+++ b/python/huggingfaceserver/README.md
@@ -5,6 +5,12 @@ The preprocess and post-process handlers are implemented based on different ML tasks, for example
 token-classification, text-generation, text2text generation. Based on the performance requirement, you can choose to perform
 the inference on a more optimized inference engine like triton inference server and vLLM for text generation.
 
+## Build Container Image Locally
+
+The Dockerfile is located in the parent directory:
+```bash
+docker build --platform linux/amd64 -t hypermode/huggingface-model-server:latest -f huggingface_server.Dockerfile ..
+```
 
 ## Run HuggingFace Server Locally
 
diff --git a/python/huggingfaceserver/huggingfaceserver/__main__.py b/python/huggingfaceserver/huggingfaceserver/__main__.py
index eb2b2c97469..e20f283c914 100644
--- a/python/huggingfaceserver/huggingfaceserver/__main__.py
+++ b/python/huggingfaceserver/huggingfaceserver/__main__.py
@@ -270,6 +270,7 @@ def load_model():
             return_token_type_ids=kwargs.get("return_token_type_ids", None),
             predictor_config=predictor_config,
             request_logger=request_logger,
+            classification_labels=kwargs.get("classification_labels", None),
         )
         model.load()
     return model
diff --git a/python/huggingfaceserver/huggingfaceserver/encoder_model.py b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
index 1196a818c72..fcfd431847a 100644
--- a/python/huggingfaceserver/huggingfaceserver/encoder_model.py
+++ b/python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -33,6 +33,7 @@
     AutoConfig,
     AutoTokenizer,
     BatchEncoding,
+    pipeline,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     PretrainedConfig,
@@ -84,6 +85,7 @@ def __init__(
         return_probabilities: bool = False,
         predictor_config: Optional[PredictorConfig] = None,
         request_logger: Optional[RequestLogger] = None,
+        classification_labels: Optional[Dict[int, str]] = None,
     ):
         super().__init__(model_name, predictor_config)
         self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -99,6 +101,8 @@ def __init__(
         self.trust_remote_code = trust_remote_code
         self.return_probabilities = return_probabilities
         self.request_logger = request_logger
+        self.classification_labels = classification_labels
+        self.nlp = None
 
         if model_config:
             self.model_config = model_config
@@ -140,6 +144,9 @@ def load(self) -> bool:
         if self._model._no_split_modules:
             device_map = "auto"
 
+        # somehow, setting do_lower_case to True gives worse results for the NER task
+        if self.task == MLTask.token_classification.value:
+            self.do_lower_case = False
 
         tokenizer_kwargs = {}
         model_kwargs = {}
@@ -179,6 +186,8 @@ def load(self) -> bool:
             # When adding new tokens to the vocabulary, we should make sure to also resize the token embedding
             # matrix of the model so that its embedding matrix matches the tokenizer.
             self._model.resize_token_embeddings(len(self._tokenizer))
+        if self.task == MLTask.token_classification:
+            self.nlp = pipeline("ner", model=self._model, tokenizer=self._tokenizer)
         logger.info(
             f"Successfully loaded huggingface model from path {model_id_or_path}"
         )
@@ -208,6 +217,7 @@ def preprocess(
                 truncation=True,
             )
             context["payload"] = payload
+            context["inputs"] = inputs
             context["input_ids"] = inputs["input_ids"]
             if self.task == MLTask.text_embedding:
                 context["attention_mask"] = inputs["attention_mask"]
@@ -226,6 +236,12 @@ def preprocess(
             )
             return infer_request
         else:
+            if self.task == MLTask.token_classification.value:
+                context["payload"] = payload
+                context["inputs"] = instances
+                context["input_ids"] = []
+                return instances
+
             inputs = self._tokenizer(
                 instances,
                 max_length=self.max_length,
@@ -236,6 +252,7 @@ def preprocess(
                 truncation=True,
             )
             context["payload"] = payload
+            context["inputs"] = inputs
             context["input_ids"] = inputs["input_ids"]
             if self.task == MLTask.text_embedding:
                 context["attention_mask"] = inputs["attention_mask"]
@@ -251,8 +268,12 @@ async def predict(
             # like NVIDIA triton inference server
             return await super().predict(input_batch, context)
         else:
-            input_batch = input_batch.to(self._device)
             try:
+                if self.task == MLTask.token_classification:
+                    with torch.no_grad():
+                        return self.nlp(input_batch)
+
+                input_batch = input_batch.to(self._device)
                 with torch.no_grad():
                     outputs = self._model(**input_batch)
                     if self.task == MLTask.text_embedding.value:
@@ -276,14 +297,32 @@ def postprocess(
             input_ids = torch.Tensor(input_ids)
         inferences = []
         if self.task == MLTask.sequence_classification:
-            num_rows, num_cols = outputs.shape
+            outputs = torch.nn.functional.softmax(outputs, dim=-1).detach()
+            max_indices = torch.argmax(outputs, dim=1).tolist()
+            final = outputs.tolist()
+
+            id2label = {0: 0, 1: 1}
+            if self.classification_labels:
+                id2label = {i: val for i, val in enumerate(self.classification_labels)}
+
+            num_rows, _ = outputs.shape
             for i in range(num_rows):
                 out = outputs[i].unsqueeze(0)
                 if self.return_probabilities:
                     inferences.append(dict(enumerate(out.numpy().flatten())))
                 else:
-                    predicted_idx = out.argmax().item()
-                    inferences.append(predicted_idx)
+                    max_id = max_indices[i]
+                    res = {
+                        "label": id2label[max_id],
+                        "confidence": final[i][max_id],
+                        "probabilities": [
+                            {
+                                "label": id2label[j],
+                                "probability": final[i][j]
+                            } for j in range(len(final[i]))
+                        ]
+                    }
+                    inferences.append(res)
             return get_predict_response(request, inferences, self.name)
         elif self.task == MLTask.fill_mask:
             num_rows = outputs.shape[0]
@@ -305,7 +344,7 @@ def postprocess(
                     inferences.append(self._tokenizer.decode(predicted_token_id))
             return get_predict_response(request, inferences, self.name)
         elif self.task == MLTask.token_classification:
-            num_rows = outputs.shape[0]
+            num_rows = len(outputs)
             for i in range(num_rows):
                 output = outputs[i].unsqueeze(0)
                 if self.return_probabilities:
diff --git a/python/huggingfaceserver/tests/test_model.py b/python/huggingfaceserver/tests/test_model.py
index 8aec1afa23d..c52fc7c8639 100644
--- a/python/huggingfaceserver/tests/test_model.py
+++ b/python/huggingfaceserver/tests/test_model.py
@@ -139,6 +139,17 @@ def openai_gpt_model():
     model.stop()
 
 
+@pytest.fixture(scope="module")
+def text_embedding():
+    model = HuggingfaceEncoderModel(
+        "mxbai-embed-large-v1",
+        model_id_or_path="mixedbread-ai/mxbai-embed-large-v1",
+        task=MLTask.text_embedding,
+    )
+    model.load()
+    yield model
+    model.stop()
+
 @pytest.fixture(scope="module")
 def text_embedding():
     model = HuggingfaceEncoderModel(
@@ -278,7 +289,25 @@ async def test_bert_sequence_classification(bert_base_yelp_polarity):
     response = await bert_base_yelp_polarity(
         {"instances": [request, request]}, headers={}
     )
-    assert response == {"predictions": [1, 1]}
+    assert response == {"predictions": [
+        {
+            'confidence': 0.9988189339637756,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.0011810670839622617},
+                {'label': 1, 'probability': 0.9988189339637756}
+            ]
+        },
+        {
+            'confidence': 0.9988189339637756,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.0011810670839622617},
+                {'label': 1, 'probability': 0.9988189339637756}
+            ]
+        }
+    ]
+    }
 
 
 @pytest.mark.asyncio
@@ -314,12 +343,48 @@ async def test_bert_token_classification(bert_token_classification):
         {"instances": [request, request]}, headers={}
     )
     assert response == {
-        "predictions": [
-            [[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
-            [[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+        'predictions': [
+            [
+                {'entity': 'I-ORG', 'score': 0.9972999691963196, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2},
+                {'entity': 'I-ORG', 'score': 0.9716504216194153, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7},
+                {'entity': 'I-ORG', 'score': 0.9962745904922485, 'index': 3, 'word': '##F', 'start': 7, 'end': 8},
+                {'entity': 'I-ORG', 'score': 0.993005096912384, 'index': 4, 'word': '##ace', 'start': 8, 'end': 11},
+                {'entity': 'I-LOC', 'score': 0.9940695762634277, 'index': 10, 'word': 'Paris', 'start': 34, 'end': 39},
+                {'entity': 'I-LOC', 'score': 0.9982321858406067, 'index': 12, 'word': 'New', 'start': 44, 'end': 47},
+                {'entity': 'I-LOC', 'score': 0.9975290894508362, 'index': 13, 'word': 'York', 'start': 48, 'end': 52}
+            ],
+            [
+                {'entity': 'I-ORG', 'score': 0.9972999691963196, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2},
+                {'entity': 'I-ORG', 'score': 0.9716504216194153, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7},
+                {'entity': 'I-ORG', 'score': 0.9962745904922485, 'index': 3, 'word': '##F', 'start': 7, 'end': 8},
+                {'entity': 'I-ORG', 'score': 0.993005096912384, 'index': 4, 'word': '##ace', 'start': 8, 'end': 11},
+                {'entity': 'I-LOC', 'score': 0.9940695762634277, 'index': 10, 'word': 'Paris', 'start': 34, 'end': 39},
+                {'entity': 'I-LOC', 'score': 0.9982321858406067, 'index': 12, 'word': 'New', 'start': 44, 'end': 47},
+                {'entity': 'I-LOC', 'score': 0.9975290894508362, 'index': 13, 'word': 'York', 'start': 48, 'end': 52}]
         ]
     }
 
 
+@pytest.mark.asyncio
+async def test_text_embedding(text_embedding):
+    def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        if len(a.shape) == 1:
+            a = a.unsqueeze(0)
+
+        if len(b.shape) == 1:
+            b = b.unsqueeze(0)
+
+        a_norm = F.normalize(a, p=2, dim=1)
+        b_norm = F.normalize(b, p=2, dim=1)
+        return torch.mm(a_norm, b_norm.transpose(0, 1))
+
+    requests = ["I'm happy", "I'm full of happiness", "They were born in the capital city of France, Paris"]
+    response = await text_embedding({"instances": requests}, headers={})
+    predictions = response["predictions"]
+
+    print(cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[1]))[0])
+    print(cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[2]))[0])
+    assert cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[1]))[0] > 0.9
+    assert cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[2]))[0] < 0.6
 @pytest.mark.asyncio
 async def test_bloom_completion(bloom_model: HuggingfaceGenerativeModel):
@@ -494,7 +559,25 @@ async def test_input_padding(bert_base_yelp_polarity: HuggingfaceEncoderModel):
     response = await bert_base_yelp_polarity(
         {"instances": [request_one, request_two]}, headers={}
     )
-    assert response == {"predictions": [1, 1]}
+    assert response == {"predictions": [
+        {
+            'confidence': 0.9988189339637756,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.0011810670839622617},
+                {'label': 1, 'probability': 0.9988189339637756}
+            ]
+        },
+        {
+            'confidence': 0.9963782429695129,
+            'label': 1,
+            'probabilities': [
+                {'label': 0, 'probability': 0.003621795680373907},
+                {'label': 1, 'probability': 0.9963782429695129}
+            ]
+        }
+    ]
+    }
 
 
 @pytest.mark.asyncio
@@ -504,7 +587,7 @@ async def test_input_truncation(bert_base_yelp_polarity: HuggingfaceEncoderModel):
     # unless we set truncation=True in the tokenizer
     request = "good " * 600
    response = await bert_base_yelp_polarity({"instances": [request]}, headers={})
-    assert response == {"predictions": [1]}
+    assert response == {'predictions': [{'confidence': 0.9914830327033997, 'label': 1, 'probabilities': [{'label': 0, 'probability': 0.00851691048592329}, {'label': 1, 'probability': 0.9914830327033997}]}]}
 
 
 @pytest.mark.asyncio
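For reference, the reshaped sequence-classification response shown in the tests above (`label`, `confidence`, `probabilities`) is returned through KServe's standard v1 REST protocol. The sketch below is illustrative only and not part of the diff; the host, port, and model name `bert-base-yelp-polarity` are assumptions.

```python
# Illustrative sketch, not part of the diff: query a running huggingfaceserver
# over the KServe v1 REST protocol and read the reshaped sequence-classification
# response. Host, port, and model name "bert-base-yelp-polarity" are assumptions.
import requests

resp = requests.post(
    "http://localhost:8080/v1/models/bert-base-yelp-polarity:predict",
    json={"instances": ["Hello, my dog is cute."]},
)
resp.raise_for_status()

for prediction in resp.json()["predictions"]:
    # Each prediction now carries the winning label, its confidence,
    # and the full per-label probability distribution.
    print(prediction["label"], prediction["confidence"])
    for entry in prediction["probabilities"]:
        print(f"  {entry['label']}: {entry['probability']:.4f}")
```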