Add system prompt and vary temp and new tokens #545

Merged 3 commits on Jan 18, 2025
22 changes: 19 additions & 3 deletions prompting/tasks/inference.py
@@ -30,21 +30,36 @@ class InferenceRewardConfig(BaseRewardConfig):
 
 Ask a question about the text and nothing else:"""
 
+SYSTEM_PROMPTS = [
+    "",
+    "You are a helpful AI assistant. Provide concise, accurate answers to any questions asked.",
+    "You are a friendly and patient assistant. Communicate your responses in a clear, easy-to-understand way, ensuring the user feels supported.",
+    "You are a creative helper. Offer engaging, imaginative responses that keep the user interested, while maintaining accuracy and clarity.",
+]
+
 
 class InferenceTask(BaseTextTask):
     name: ClassVar[str] = "inference"
     # TODO: Once we want to enable the 'actual' inference task with exact models
     query: str | None = None
     reference: str | None = None
+    system_prompt: str | None = None
     llm_model: ModelConfig | None = None
     llm_model_id: ModelConfig | None = random.choice(ModelZoo.models_configs).llm_model_id
     seed: int = Field(default_factory=lambda: random.randint(0, 1_000_000), allow_mutation=False)
-    sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS
+    sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS.copy()
 
     @model_validator(mode="after")
     def random_llm_model_id(self):
         if self.query:  # If we are already defining query, as in the case of organics, we also specify model.
             return self
+        # Choose system prompt and randomize inference settings
+        self.system_prompt = random.choice(SYSTEM_PROMPTS)
+        self.messages = []
+        if self.system_prompt:
+            self.messages.append({"role": "system", "content": self.system_prompt})
+        self.sampling_params["temperature"] = random.randint(0, 10) / 10
+        self.sampling_params["max_new_tokens"] = random.choice([256, 512, 1024, 2048])
 
         if np.random.rand() < 0.2:
             self.llm_model_id = None

@@ -55,8 +70,9 @@ def random_llm_model_id(self):
     def make_query(self, dataset_entry: ChatEntry) -> str:
         if self.query:
             return self.query
-        self.query = dataset_entry.messages
-        self.messages = dataset_entry.messages
+        self.messages.extend(dataset_entry.messages)
+        self.query = self.messages
 
         return self.query
 
     def make_reference(self, dataset_entry: ChatEntry) -> str:
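The move to shared_settings.SAMPLING_PARAMS.copy() matters because the validator now mutates sampling_params per task (random temperature and max_new_tokens); without the copy, every InferenceTask instance would write into the same shared dict. A minimal sketch of the per-task randomization, using a stand-in SAMPLING_PARAMS dict in place of the project's shared_settings (the default values shown are illustrative, not the real ones):

    import random

    # Stand-in for shared_settings.SAMPLING_PARAMS (illustrative defaults, not the project's values).
    SAMPLING_PARAMS = {"temperature": 0.7, "max_new_tokens": 512}

    # Abbreviated version of the SYSTEM_PROMPTS list added in this PR.
    SYSTEM_PROMPTS = ["", "You are a helpful AI assistant."]

    def build_task_params() -> tuple[str, dict]:
        """Mimic the validator logic: pick a system prompt and randomize sampling settings."""
        params = SAMPLING_PARAMS.copy()  # copy so the shared defaults are never mutated
        system_prompt = random.choice(SYSTEM_PROMPTS)
        params["temperature"] = random.randint(0, 10) / 10  # one of 0.0, 0.1, ..., 1.0
        params["max_new_tokens"] = random.choice([256, 512, 1024, 2048])
        return system_prompt, params

    prompt, params = build_task_params()
    print(prompt or "<no system prompt>", params)

The make_query change follows from the same setup: because the validator may already have seeded self.messages with a system message, the dataset messages are appended with extend() rather than overwriting the list, so the system prompt stays at the front of the conversation.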
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ max-line-length = 120
 extend-ignore = "D203,E203,E251,E266,E302,E305,E401,E402,E501,F401,F403,W503"
 exclude = ".git,__pycache__,dist,.venv,venv,*/lib/python*/site-packages,*/lib64/python*/site-packages"
 # TODO: Decrease to at least 10 (prompting/weight_setting/weight_setter.py).
-max-complexity = 13
+max-complexity = 14
 
 [tool.isort]
 atomic = true