Skip to content

Commit

Permalink
Fix PromptLab templating (mlflow#10341)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <[email protected]>
  • Loading branch information
daniellok-db authored Nov 10, 2023
1 parent 4194122 commit 3e850ef
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
11 changes: 8 additions & 3 deletions mlflow/_promptlab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from string import Template
import re
from typing import List

import yaml
Expand All @@ -14,7 +14,7 @@ def __init__(self, prompt_template, prompt_parameters, model_parameters, model_r
self.prompt_parameters = prompt_parameters
self.model_parameters = model_parameters
self.model_route = model_route
self.prompt_template = Template(prompt_template)
self.prompt_template = prompt_template

def predict(self, inputs: pd.DataFrame) -> List[str]:
from mlflow.gateway import query
Expand All @@ -24,7 +24,12 @@ def predict(self, inputs: pd.DataFrame) -> List[str]:
prompt_parameters_as_dict = {
param.key: inputs[param.key][idx] for param in self.prompt_parameters
}
prompt = self.prompt_template.substitute(prompt_parameters_as_dict)

# copy replacement logic from PromptEngineering.utils.ts for consistency
prompt = self.prompt_template
for key, value in prompt_parameters_as_dict.items():
prompt = re.sub(r"\{\{\s*" + key + r"\s*\}\}", value, prompt)

model_parameters_as_dict = {param.key: param.value for param in self.model_parameters}
result = query(
route=self.model_route, data={"prompt": prompt, **model_parameters_as_dict}
Expand Down
39 changes: 39 additions & 0 deletions tests/promptlab/test_promptlab_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest import mock

import pandas as pd

from mlflow._promptlab import _PromptlabModel
from mlflow.entities.param import Param


def test_promptlab_prompt_replacement():
    """Each {{ thing }} placeholder in the template is filled from the matching
    input-row value before the gateway is queried, once per row."""
    things = ("books", "coffee", "nothing")
    inputs = pd.DataFrame([{"thing": t} for t in things])

    model = _PromptlabModel(
        "Write me a story about {{ thing }}.",
        [Param(key="thing", value="books")],
        [Param(key="temperature", value=0.5), Param(key="max_tokens", value=10)],
        "completions",
    )

    # predict() imports `query` from mlflow.gateway at call time, so patching
    # the attribute on the gateway module intercepts every call.
    with mock.patch("mlflow.gateway.query") as mock_query:
        model.predict(inputs)

    expected_calls = [
        mock.call(
            route="completions",
            data={
                "prompt": "Write me a story about {}.".format(thing),
                "temperature": 0.5,
                "max_tokens": 10,
            },
        )
        for thing in things
    ]

    mock_query.assert_has_calls(expected_calls, any_order=True)

0 comments on commit 3e850ef

Please sign in to comment.