Add sentence transformers example (#8425)
Signed-off-by: Ben Wilson <[email protected]>
BenWilson2 authored May 12, 2023
1 parent a0203a1 commit ba17e5a
Showing 3 changed files with 63 additions and 10 deletions.
18 changes: 8 additions & 10 deletions examples/pip_requirements/pip_requirements.py
@@ -47,24 +47,22 @@ def main():
run_id = run.info.run_id

# Get the expected mlflow version
- mlflow_version = Version(mlflow.__version__)
- mlflow_version_range = (
-     f"mlflow<{mlflow_version.major + 1},>={mlflow_version.major}.{mlflow_version.minor}"
- )
+ mlflow_version_raw = Version(mlflow.__version__)
+ mlflow_version = f"mlflow=={mlflow_version_raw.major}.{mlflow_version_raw.minor}"

# Default (both `pip_requirements` and `extra_pip_requirements` are unspecified)
artifact_path = "default"
mlflow.xgboost.log_model(model, artifact_path, signature=signature)
pip_reqs = get_pip_requirements(run_id, artifact_path)
- assert pip_reqs.issuperset([mlflow_version_range, xgb_req]), pip_reqs
+ assert pip_reqs.issuperset([mlflow_version, xgb_req]), pip_reqs

# Overwrite the default set of pip requirements using `pip_requirements`
artifact_path = "pip_requirements"
mlflow.xgboost.log_model(
model, artifact_path, pip_requirements=[sklearn_req], signature=signature
)
pip_reqs = get_pip_requirements(run_id, artifact_path)
- assert pip_reqs == {mlflow_version_range, sklearn_req}, pip_reqs
+ assert pip_reqs == {mlflow_version, sklearn_req}, pip_reqs

# Add extra pip requirements on top of the default set of pip requirements
# using `extra_pip_requirements`
@@ -73,7 +71,7 @@ def main():
model, artifact_path, extra_pip_requirements=[sklearn_req], signature=signature
)
pip_reqs = get_pip_requirements(run_id, artifact_path)
- assert pip_reqs.issuperset([mlflow_version_range, xgb_req, sklearn_req]), pip_reqs
+ assert pip_reqs.issuperset([mlflow_version, xgb_req, sklearn_req]), pip_reqs

# Specify pip requirements using a requirements file
with tempfile.NamedTemporaryFile("w", suffix=".requirements.txt") as f:
@@ -86,7 +84,7 @@ def main():
model, artifact_path, pip_requirements=f.name, signature=signature
)
pip_reqs = get_pip_requirements(run_id, artifact_path)
- assert pip_reqs == {mlflow_version_range, sklearn_req}, pip_reqs
+ assert pip_reqs == {mlflow_version, sklearn_req}, pip_reqs

# List of pip requirement strings
artifact_path = "requirements_file_list"
@@ -97,7 +95,7 @@ def main():
signature=signature,
)
pip_reqs = get_pip_requirements(run_id, artifact_path)
- assert pip_reqs == {mlflow_version_range, xgb_req, sklearn_req}, pip_reqs
+ assert pip_reqs == {mlflow_version, xgb_req, sklearn_req}, pip_reqs

# Using a constraints file
with tempfile.NamedTemporaryFile("w", suffix=".constraints.txt") as f:
@@ -114,7 +112,7 @@ def main():
pip_reqs, pip_cons = get_pip_requirements(
run_id, artifact_path, return_constraints=True
)
- assert pip_reqs == {mlflow_version_range, xgb_req, "-c constraints.txt"}, pip_reqs
+ assert pip_reqs == {mlflow_version, xgb_req, "-c constraints.txt"}, pip_reqs
assert pip_cons == {sklearn_req}, pip_cons


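For concreteness, with MLflow 2.3.2 installed the removed lines build the range requirement "mlflow<3,>=2.3", while the added lines expect the exact major.minor pin "mlflow==2.3".

The assertions rely on a get_pip_requirements helper defined earlier in pip_requirements.py, outside the hunks shown above. A minimal sketch of what such a helper can look like (not the committed implementation; it assumes MLflow 2.x, where mlflow.artifacts.download_artifacts is available and logged models carry requirements.txt and, when used, constraints.txt artifacts):

import mlflow


def get_pip_requirements(run_id, artifact_path, return_constraints=False):
    def read_lines(artifact_file):
        # Download a text artifact from the logged model and return its non-empty lines as a set.
        local_path = mlflow.artifacts.download_artifacts(
            run_id=run_id, artifact_path=f"{artifact_path}/{artifact_file}"
        )
        with open(local_path) as f:
            return {line.strip() for line in f if line.strip()}

    requirements = read_lines("requirements.txt")
    if not return_constraints:
        return requirements
    return requirements, read_lines("constraints.txt")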
54 changes: 54 additions & 0 deletions examples/transformers/sentence_transformer.py
@@ -0,0 +1,54 @@
from transformers import BertModel, BertTokenizerFast, pipeline
import mlflow
import torch


sentence_transformers_architecture = "sentence-transformers/all-MiniLM-L12-v2"
task = "feature-extraction"

model = BertModel.from_pretrained(sentence_transformers_architecture)
tokenizer = BertTokenizerFast.from_pretrained(sentence_transformers_architecture)

sentence_transformer_pipeline = pipeline(task=task, model=model, tokenizer=tokenizer)

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=sentence_transformer_pipeline,
        artifact_path="sentence_transformer",
        framework="pt",
        torch_dtype=torch.bfloat16,
    )

loaded = mlflow.transformers.load_model(model_info.model_uri, return_type="components")


def pool_and_normalize_encodings(input_sentences, model, tokenizer, **kwargs):
    def pool(model_output, attention_mask):
        embeddings = model_output[0]
        expanded_mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
        return torch.sum(embeddings * expanded_mask, 1) / torch.clamp(
            expanded_mask.sum(1), min=1e-9
        )

    encoded = tokenizer(
        input_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        model_output = model(**encoded)

    pooled = pool(model_output, encoded["attention_mask"])
    return torch.nn.functional.normalize(pooled, p=2, dim=1)


sentences = [
    "He said that he's sinking all of his investment budget into coconuts.",
    "No matter how deep you dig, there's going to be a point when it just gets too hot.",
    "She said that there isn't a noticeable difference between a 10 year and a 15 year whisky.",
]

encoded_sentences = pool_and_normalize_encodings(sentences, **loaded)

print(encoded_sentences)
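The helper above mean-pools the token embeddings, weighting each token by the attention mask, and then L2-normalizes each sentence vector. Because the rows are unit length, cosine similarity between any two sentences reduces to a dot product; a small usage sketch on top of the committed example (the similarity_matrix name is ours, not part of the commit):

similarity_matrix = encoded_sentences @ encoded_sentences.T  # (3, 3) pairwise cosine similarities
print(similarity_matrix)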
1 change: 1 addition & 0 deletions tests/examples/test_examples.py
@@ -168,6 +168,7 @@ def test_mlflow_run_example(directory, params, tmp_path):
("transformers", ["python", "conversational.py"]),
("transformers", ["python", "load_components.py"]),
("transformers", ["python", "simple.py"]),
+ ("transformers", ["python", "sentence_transformer.py"]),
],
)
def test_command_example(directory, command):
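The new parametrize entry registers the sentence transformer script with the repository's example-runner test. The body of test_command_example is outside this diff; a rough sketch of what such a parametrized runner typically does (the examples-directory layout here is an assumption):

import os
import subprocess

import pytest


EXAMPLES_DIR = "examples"  # assumed path to the examples relative to the repo root


@pytest.mark.parametrize(
    ("directory", "command"),
    [("transformers", ["python", "sentence_transformer.py"])],
)
def test_command_example(directory, command):
    # Run the example script from its own directory and fail the test on a non-zero exit code.
    subprocess.run(command, cwd=os.path.join(EXAMPLES_DIR, directory), check=True)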
