-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathchain_model.py
94 lines (71 loc) · 3.13 KB
/
chain_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import re
from typing import Any, Dict, List
from src.configs import ModelConfig
from src.prompts import Prompt
from src.utils.string_utils import replace_parentheses
from .model import BaseModel
# Template placeholders use the form {index.field}, where index is a step
# number (0 = the initial prompt), "prev", or "self" — e.g. {0.header},
# {prev.answer}. Resolved by fill_template_string below.
class ChainModel(BaseModel):
    """Composes several models into a pipeline: each model's output becomes
    the ``answer`` of the prompt handed to the next model.

    Per-step prompt text (header / footer / intermediate) is produced by
    filling that model's ``prompt_template`` with values from earlier steps
    via ``fill_template_string``.
    """

    def __init__(self, config: ModelConfig, models: List[BaseModel]):
        """Store the chain's own config and the ordered list of sub-models.

        Args:
            config: Configuration for the chain model itself.
            models: Sub-models, invoked in list order by :meth:`predict`.
        """
        super().__init__(config)
        self.config = config
        self.models = models

    def predict(self, input: Prompt, **kwargs) -> str:  # noqa: C901
        """Run ``input`` through every sub-model in order and return the
        final answer string.

        NOTE(review): mutates the caller's ``input`` Prompt in place
        (header/footer/intermediate/answer are assigned on it during the
        first iteration) — confirm callers do not reuse the object.
        """
        curr_prompt = input
        # Lookup table for template placeholders: integer keys are step
        # indices (0 = original input), "prev" is the most recent step,
        # "self" is the current model's own template dict.
        prompts: Dict[int | str, Any] = {0: curr_prompt, "prev": curr_prompt}
        for j, model in enumerate(self.models):
            prompts["self"] = model.config.prompt_template
            new_header = ""
            new_footer = ""
            new_interm = ""
            # Each section is only rewritten if the model's template
            # provides it; otherwise the prompt keeps its current value.
            if "header" in model.config.prompt_template:
                new_header = fill_template_string(
                    model.config.prompt_template["header"], prompts  # type: ignore
                )
                curr_prompt.header = new_header
            if "footer" in model.config.prompt_template:
                new_footer = fill_template_string(
                    model.config.prompt_template["footer"], prompts  # type: ignore
                )
                curr_prompt.footer = new_footer
            if "template" in model.config.prompt_template:
                new_interm = fill_template_string(
                    model.config.prompt_template["template"], prompts  # type: ignore
                )
                curr_prompt.intermediate = new_interm
            output = model.predict(curr_prompt)
            curr_prompt.answer = output
            # Record this step's completed prompt under its 1-based index
            # and as "prev"; both keys alias the same object.
            prompts[j + 1] = curr_prompt
            prompts["prev"] = curr_prompt
            print(f"==Model {j}\n=Input:\n {curr_prompt.get_prompt()}")
            print(f"=Output:\n {output}")
            # Copy AFTER storing, so the recorded prompt keeps its answer
            # while the next iteration mutates a fresh object.
            curr_prompt = curr_prompt.get_copy()
        assert curr_prompt.answer is not None
        return curr_prompt.answer

    def predict_multi(self, inputs: List[Prompt], **kwargs):
        """Run :meth:`predict` on each prompt and return the answers in order."""
        res = []
        for input in inputs:
            res.append(self.predict(input, **kwargs))
        return res

    def predict_string(self, input: str, **kwargs) -> str:
        """Delegate raw-string prediction to the first model in the chain."""
        return self.models[0].predict_string(input, **kwargs)
def fill_template_string(template: str, fill_dict: Dict[int | str, Any]) -> str:
    """Resolve ``{index.field}`` placeholders in *template*.

    ``replace_parentheses`` is expected to rewrite the k-th placeholder as
    the marker ``<k>`` (1-based); this function then substitutes each marker
    with the value looked up from *fill_dict*.

    Args:
        template: Template string containing ``{index.field}`` placeholders,
            where ``index`` is ``"self"``, ``"prev"``, or an integer step
            index, and ``field`` is a Prompt attribute or dict key.
        fill_dict: Mapping from index to a ``Prompt`` or a plain dict.

    Returns:
        The template with every placeholder replaced by its value.

    Raises:
        ValueError: If a placeholder is malformed or its index is neither
            a known keyword nor an integer.
        NotImplementedError: If the looked-up element is neither a Prompt
            nor a dict.
    """
    # Raw string avoids the invalid-escape warning the old '\{' form needed
    # a "# noqa: W605" for. Non-greedy so adjacent placeholders don't merge.
    extracted_locations = re.findall(r"\{.*?\}", template)
    to_insert = replace_parentheses(template)
    for i, loc in enumerate(extracted_locations, start=1):
        parts = loc[1:-1].split(".")
        if len(parts) != 2:
            # Fail with a clear message instead of an opaque unpacking error.
            raise ValueError(
                f"Placeholder {loc} is malformed; expected the form "
                "{index.field}."
            )
        index: int | str
        index, identifier = parts
        if index not in ["self", "prev"]:
            try:
                index = int(index)
            except ValueError:
                raise ValueError(f"Index {index} is not a valid index.") from None
        selected_elem = fill_dict[index]
        if isinstance(selected_elem, Prompt):
            selected_val = getattr(selected_elem, identifier)
        elif isinstance(selected_elem, dict):
            selected_val = selected_elem[identifier]
        else:
            raise NotImplementedError(
                f"Cannot extract {identifier!r} from "
                f"{type(selected_elem).__name__}"
            )
        # NOTE(review): assumes the looked-up value is a str; a None (e.g. an
        # unanswered Prompt) would raise TypeError here — confirm upstream.
        to_insert = to_insert.replace(f"<{i}>", selected_val)
    return to_insert