bring back enhancements after getting kserve up-to-date (#42)
* improve dockerfile, makefile, readme

* support custom classification labels, refactor postprocess

* support text embedding task

* improve support for token classification (named entity recognition)
kevinmingtarja committed Sep 26, 2024
1 parent 4553c11 commit 15b16e2
Showing 6 changed files with 149 additions and 15 deletions.
8 changes: 4 additions & 4 deletions python/huggingface_server.Dockerfile
@@ -23,14 +23,14 @@ RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

COPY kserve/pyproject.toml kserve/poetry.lock kserve/
RUN cd kserve && poetry install --no-root --no-interaction --no-cache
RUN cd kserve && poetry install --all-extras --no-root --no-interaction --no-cache
COPY kserve kserve
RUN cd kserve && poetry install --no-interaction --no-cache
RUN cd kserve && poetry install --all-extras --no-interaction --no-cache

COPY huggingfaceserver/pyproject.toml huggingfaceserver/poetry.lock huggingfaceserver/
RUN cd huggingfaceserver && poetry install --no-root --no-interaction --no-cache
RUN cd huggingfaceserver && poetry install --all-extras --no-root --no-interaction --no-cache
COPY huggingfaceserver huggingfaceserver
RUN cd huggingfaceserver && poetry install --no-interaction --no-cache
RUN cd huggingfaceserver && poetry install --all-extras --no-interaction --no-cache

RUN pip3 install vllm==${VLLM_VERSION}

5 changes: 5 additions & 0 deletions python/huggingfaceserver/Makefile
@@ -1,3 +1,5 @@
IMG ?= hypermode/huggingface-model-server:latest

dev_install:
poetry install --with test -E vllm --no-interaction

@@ -9,3 +11,6 @@ test: type_check

type_check:
mypy --ignore-missing-imports huggingfaceserver

build-docker:
docker build --platform linux/amd64 -t $(IMG) -f ../huggingface_server.Dockerfile ..
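Since `IMG` is declared with `?=`, the tag can be overridden per invocation (command-line variables take precedence in make); a quick usage sketch of the new target:

```bash
# Build with the default tag from the Makefile
make build-docker

# Override the image tag for a local/dev build
make build-docker IMG=hypermode/huggingface-model-server:dev
```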
6 changes: 6 additions & 0 deletions python/huggingfaceserver/README.md
@@ -5,6 +5,12 @@

The preprocess and post-process handlers are implemented based on different ML tasks, for example sequence classification,
token-classification, text-generation, and text2text generation. Depending on your performance requirements, you can run
the inference on a more optimized inference engine such as Triton Inference Server, or vLLM for text generation.
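For illustration (the model name `bert-base-yelp-polarity` is hypothetical), a sequence-classification request against the standard KServe v1 predict endpoint looks like this:

```bash
curl -s -H "Content-Type: application/json" \
  -d '{"instances": ["Hello, my dog is cute"]}' \
  http://localhost:8080/v1/models/bert-base-yelp-polarity:predict
```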

## Build Container Image Locally

The Dockerfile lives in the parent directory:
```bash
docker build --platform linux/amd64 -t hypermode/huggingface-model-server:latest -f huggingface_server.Dockerfile ..
```

## Run HuggingFace Server Locally

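A minimal sketch, assuming the server's standard `--model_id` and `--model_name` flags:

```bash
python -m huggingfaceserver --model_id=bert-base-uncased --model_name=bert
```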
1 change: 1 addition & 0 deletions python/huggingfaceserver/huggingfaceserver/__main__.py
@@ -270,6 +270,7 @@ def load_model():
return_token_type_ids=kwargs.get("return_token_type_ids", None),
predictor_config=predictor_config,
request_logger=request_logger,
classification_labels=kwargs.get("classification_labels", None),
)
model.load()
return model
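For context, a hedged sketch of constructing the encoder model directly with the new kwarg (the model id and label names are illustrative, and the import paths are assumed from this package's layout):

```python
from huggingfaceserver.encoder_model import HuggingfaceEncoderModel
from huggingfaceserver.task import MLTask

# classification_labels is index-aligned with the model's output classes;
# postprocess() enumerates it, so a plain list of label names works.
model = HuggingfaceEncoderModel(
    "bert-yelp-polarity",
    model_id_or_path="textattack/bert-base-uncased-yelp-polarity",
    task=MLTask.sequence_classification,
    classification_labels=["negative", "positive"],
)
model.load()
```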
49 changes: 44 additions & 5 deletions python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -33,6 +33,7 @@
AutoConfig,
AutoTokenizer,
BatchEncoding,
pipeline,
PreTrainedModel,
PreTrainedTokenizerBase,
PretrainedConfig,
@@ -84,6 +85,7 @@ def __init__(
return_probabilities: bool = False,
predictor_config: Optional[PredictorConfig] = None,
request_logger: Optional[RequestLogger] = None,
classification_labels: Optional[Dict[int, str]] = None,
):
super().__init__(model_name, predictor_config)
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -99,6 +101,8 @@ def __init__(
self.trust_remote_code = trust_remote_code
self.return_probabilities = return_probabilities
self.request_logger = request_logger
self.classification_labels = classification_labels
self.nlp = None

if model_config:
self.model_config = model_config
@@ -140,6 +144,9 @@ def load(self) -> bool:

if self._model._no_split_modules:
device_map = "auto"
# Somehow, setting this to True gives worse results for the NER task.
if self.task == MLTask.token_classification.value:
self.do_lower_case = False

tokenizer_kwargs = {}
model_kwargs = {}
@@ -179,6 +186,8 @@ def load(self) -> bool:
# When adding new tokens to the vocabulary, we should make sure to also resize the token embedding
# matrix of the model so that its embedding matrix matches the tokenizer.
self._model.resize_token_embeddings(len(self._tokenizer))
if self.task == MLTask.token_classification:
self.nlp = pipeline("ner", model=self._model, tokenizer=self._tokenizer)
logger.info(
f"Successfully loaded huggingface model from path {model_id_or_path}"
)
@@ -208,6 +217,7 @@ def preprocess(
truncation=True,
)
context["payload"] = payload
context["inputs"] = inputs
context["input_ids"] = inputs["input_ids"]
if self.task == MLTask.text_embedding:
context["attention_mask"] = inputs["attention_mask"]
@@ -226,6 +236,12 @@
)
return infer_request
else:
if self.task == MLTask.token_classification.value:
context["payload"] = payload
context["inputs"] = instances
context["input_ids"] = []
return instances

inputs = self._tokenizer(
instances,
max_length=self.max_length,
@@ -236,6 +252,7 @@
truncation=True,
)
context["payload"] = payload
context["inputs"] = inputs
context["input_ids"] = inputs["input_ids"]
if self.task == MLTask.text_embedding:
context["attention_mask"] = inputs["attention_mask"]
@@ -251,8 +268,12 @@ async def predict(
# like NVIDIA triton inference server
return await super().predict(input_batch, context)
else:
input_batch = input_batch.to(self._device)
try:
if self.task == MLTask.token_classification:
with torch.no_grad():
return self.nlp(input_batch)

input_batch = input_batch.to(self._device)
with torch.no_grad():
outputs = self._model(**input_batch)
if self.task == MLTask.text_embedding.value:
@@ -276,14 +297,32 @@ def postprocess(
input_ids = torch.Tensor(input_ids)
inferences = []
if self.task == MLTask.sequence_classification:
num_rows, num_cols = outputs.shape
outputs = torch.nn.functional.softmax(outputs, dim=-1).detach()
max_indices = torch.argmax(outputs, dim=1).tolist()
final = outputs.tolist()

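# Default to raw class indices; when classification_labels is provided
# (an index-aligned sequence of label names), map each class id to its name.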
id2label = {0: 0, 1: 1}
if self.classification_labels:
id2label = {i: val for i, val in enumerate(self.classification_labels)}

num_rows, _ = outputs.shape
for i in range(num_rows):
out = outputs[i].unsqueeze(0)
if self.return_probabilities:
inferences.append(dict(enumerate(out.numpy().flatten())))
else:
predicted_idx = out.argmax().item()
inferences.append(predicted_idx)
max_id = max_indices[i]
res = {
"label": id2label[max_id],
"confidence": final[i][max_id],
"probabilities": [
{
"label": id2label[j],
"probability": final[i][j]
} for j in range(len(final[i]))
]
}
inferences.append(res)
return get_predict_response(request, inferences, self.name)
elif self.task == MLTask.fill_mask:
num_rows = outputs.shape[0]
@@ -305,7 +344,7 @@
inferences.append(self._tokenizer.decode(predicted_token_id))
return get_predict_response(request, inferences, self.name)
elif self.task == MLTask.token_classification:
num_rows = outputs.shape[0]
num_rows = len(outputs)
for i in range(num_rows):
output = outputs[i].unsqueeze(0)
if self.return_probabilities:
95 changes: 89 additions & 6 deletions python/huggingfaceserver/tests/test_model.py
@@ -139,6 +139,17 @@ def openai_gpt_model():
model.stop()


@pytest.fixture(scope="module")
def text_embedding():
model = HuggingfaceEncoderModel(
"mxbai-embed-large-v1",
model_id_or_path="mixedbread-ai/mxbai-embed-large-v1",
task=MLTask.text_embedding,
)
model.load()
yield model
model.stop()

@@ -278,7 +289,25 @@ async def test_bert_sequence_classification(bert_base_yelp_polarity):
response = await bert_base_yelp_polarity(
{"instances": [request, request]}, headers={}
)
assert response == {"predictions": [1, 1]}
assert response == {"predictions": [
{
'confidence': 0.9988189339637756,
'label': 1,
'probabilities': [
{'label': 0, 'probability': 0.0011810670839622617},
{'label': 1, 'probability': 0.9988189339637756}
]
},
{
'confidence': 0.9988189339637756,
'label': 1,
'probabilities': [
{'label': 0, 'probability': 0.0011810670839622617},
{'label': 1, 'probability': 0.9988189339637756}
]
}
]
}


@pytest.mark.asyncio
@@ -314,12 +343,48 @@ async def test_bert_token_classification(bert_token_classification):
{"instances": [request, request]}, headers={}
)
assert response == {
"predictions": [
[[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
'predictions': [
[
{'entity': 'I-ORG', 'score': 0.9972999691963196, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2},
{'entity': 'I-ORG', 'score': 0.9716504216194153, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7},
{'entity': 'I-ORG', 'score': 0.9962745904922485, 'index': 3, 'word': '##F', 'start': 7, 'end': 8},
{'entity': 'I-ORG', 'score': 0.993005096912384, 'index': 4, 'word': '##ace', 'start': 8, 'end': 11},
{'entity': 'I-LOC', 'score': 0.9940695762634277, 'index': 10, 'word': 'Paris', 'start': 34, 'end': 39},
{'entity': 'I-LOC', 'score': 0.9982321858406067, 'index': 12, 'word': 'New', 'start': 44, 'end': 47},
{'entity': 'I-LOC', 'score': 0.9975290894508362, 'index': 13, 'word': 'York', 'start': 48, 'end': 52}
],
[
{'entity': 'I-ORG', 'score': 0.9972999691963196, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2},
{'entity': 'I-ORG', 'score': 0.9716504216194153, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7},
{'entity': 'I-ORG', 'score': 0.9962745904922485, 'index': 3, 'word': '##F', 'start': 7, 'end': 8},
{'entity': 'I-ORG', 'score': 0.993005096912384, 'index': 4, 'word': '##ace', 'start': 8, 'end': 11},
{'entity': 'I-LOC', 'score': 0.9940695762634277, 'index': 10, 'word': 'Paris', 'start': 34, 'end': 39},
{'entity': 'I-LOC', 'score': 0.9982321858406067, 'index': 12, 'word': 'New', 'start': 44, 'end': 47},
{'entity': 'I-LOC', 'score': 0.9975290894508362, 'index': 13, 'word': 'York', 'start': 48, 'end': 52}
        ]
]
}

@pytest.mark.asyncio
async def test_text_embedding(text_embedding):
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
if len(a.shape) == 1:
a = a.unsqueeze(0)

if len(b.shape) == 1:
b = b.unsqueeze(0)

a_norm = F.normalize(a, p=2, dim=1)
b_norm = F.normalize(b, p=2, dim=1)
return torch.mm(a_norm, b_norm.transpose(0, 1))

requests = ["I'm happy", "I'm full of happiness", "They were born in the capital city of France, Paris"]
response = await text_embedding({"instances": requests}, headers={})
predictions = response["predictions"]

print(cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[1]))[0])
print(cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[2]))[0])
assert cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[1]))[0] > 0.9
assert cosine_similarity(torch.tensor(predictions[0]), torch.tensor(predictions[2]))[0] < 0.6

@pytest.mark.asyncio
async def test_bloom_completion(bloom_model: HuggingfaceGenerativeModel):
@@ -494,7 +559,25 @@ async def test_input_padding(bert_base_yelp_polarity: HuggingfaceEncoderModel):
response = await bert_base_yelp_polarity(
{"instances": [request_one, request_two]}, headers={}
)
assert response == {"predictions": [1, 1]}
assert response == {"predictions": [
{
'confidence': 0.9988189339637756,
'label': 1,
'probabilities': [
{'label': 0, 'probability': 0.0011810670839622617},
{'label': 1, 'probability': 0.9988189339637756}
]
},
{
'confidence': 0.9963782429695129,
'label': 1,
'probabilities': [
{'label': 0, 'probability': 0.003621795680373907},
{'label': 1, 'probability': 0.9963782429695129}
]
}
]
}


@pytest.mark.asyncio
@@ -504,7 +587,7 @@ async def test_input_truncation(bert_base_yelp_polarity: HuggingfaceEncoderModel):
# unless we set truncation=True in the tokenizer
request = "good " * 600
response = await bert_base_yelp_polarity({"instances": [request]}, headers={})
assert response == {"predictions": [1]}
assert response == {
    'predictions': [
        {
            'confidence': 0.9914830327033997,
            'label': 1,
            'probabilities': [
                {'label': 0, 'probability': 0.00851691048592329},
                {'label': 1, 'probability': 0.9914830327033997}
            ]
        }
    ]
}


@pytest.mark.asyncio
