feat: save and load Prompts #1458

Merged 12 commits on Oct 10, 2024
14 changes: 13 additions & 1 deletion src/ragas/prompt/base.py
@@ -25,12 +25,18 @@ def _check_if_language_is_supported(language: str):
 
 
 class BasePrompt(ABC):
-    def __init__(self, name: t.Optional[str] = None, language: str = "english"):
+    def __init__(
+        self,
+        name: t.Optional[str] = None,
+        language: str = "english",
+        original_hash: t.Optional[str] = None,
+    ):
         if name is None:
             self.name = camel_to_snake(self.__class__.__name__)
 
         _check_if_language_is_supported(language)
         self.language = language
+        self.original_hash = original_hash
 
     @abstractmethod
     async def generate(
@@ -65,10 +71,16 @@ def generate_multiple(
 class StringIO(BaseModel):
     text: str
 
+    def __hash__(self):
+        return hash(self.text)
+
 
 class BoolIO(BaseModel):
     value: bool
 
+    def __hash__(self):
+        return hash(self.value)
+
 
 class StringPrompt(BasePrompt):
     """
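
A quick sketch, not part of the diff, of what the new __hash__ methods buy: pydantic models are not hashable by default, so value-equal IO instances can now live in sets and dict keys, consistent with pydantic's field-based __eq__.

from ragas.prompt.base import BoolIO, StringIO

# value-equal instances collapse to one set entry now that
# __hash__ agrees with pydantic's field-based __eq__
unique = {StringIO(text="hi"), StringIO(text="hi"), BoolIO(value=True)}
assert len(unique) == 2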
46 changes: 46 additions & 0 deletions src/ragas/prompt/mixin.py
@@ -1,14 +1,20 @@
 from __future__ import annotations
 
 import inspect
+import logging
+import os
 import typing as t
 
+from .base import _check_if_language_is_supported
 from .pydantic_prompt import PydanticPrompt
 
 if t.TYPE_CHECKING:
     from ragas.llms.base import BaseRagasLLM
 
 
+logger = logging.getLogger(__name__)
+
+
 class PromptMixin:
     def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
         prompts = {}
@@ -40,3 +46,43 @@ async def adapt_prompts(
             adapted_prompts[name] = adapted_prompt
 
         return adapted_prompts
+
+    def save_prompts(self, path: str):
+        """
+        Save prompts to a directory as {name}_{language}.json files.
+        """
+        # check if path is valid
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist")
+
+        prompts = self.get_prompts()
+        for prompt_name, prompt in prompts.items():
+            prompt_file_name = os.path.join(
+                path, f"{prompt_name}_{prompt.language}.json"
+            )
+            prompt.save(prompt_file_name)
+
+    def load_prompts(self, path: str, language: t.Optional[str] = None):
+        """
+        Load prompts from a directory of {name}_{language}.json files.
+        """
+        # check if path is valid
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist")
+
+        # check if language is supported, defaults to english
+        if language is None:
+            language = "english"
+            logger.info(
+                "Language not specified, loading prompts for default language: %s",
+                language,
+            )
+        _check_if_language_is_supported(language)
+
+        loaded_prompts = {}
+        for prompt_name, prompt in self.get_prompts().items():
+            prompt_file_name = os.path.join(path, f"{prompt_name}_{language}.json")
+            loaded_prompt = prompt.__class__.load(prompt_file_name)
+            loaded_prompts[prompt_name] = loaded_prompt
+        return loaded_prompts
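
A minimal usage sketch of the new mixin methods, assuming any PromptMixin subclass; the synthesizer and set_prompts call mirror the tests below, and the directory name is illustrative.

import os

from ragas.testset.synthesizers import AbstractQuerySynthesizer

synth = AbstractQuerySynthesizer(llm=llm)  # llm: any BaseRagasLLM instance

os.makedirs("prompts_dir", exist_ok=True)  # save_prompts requires an existing path
synth.save_prompts("prompts_dir")  # writes one {name}_{language}.json per prompt

loaded = synth.load_prompts("prompts_dir")  # language defaults to english
synth.set_prompts(**loaded)  # swap the loaded prompts back in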
99 changes: 99 additions & 0 deletions src/ragas/prompt/pydantic_prompt.py
@@ -1,13 +1,16 @@
 from __future__ import annotations
 
 import copy
+import json
 import logging
+import os
 import typing as t
 
 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import PydanticOutputParser
 from pydantic import BaseModel
 
+from ragas._version import __version__
 from ragas.callbacks import new_group
 from ragas.exceptions import RagasOutputParserException
 from ragas.llms.prompt import PromptValue
@@ -220,6 +223,11 @@ async def adapt(
         # throws ValueError if language is not supported
         _check_if_language_is_supported(target_language)
 
+        # set the original hash; this is used to identify
+        # the original prompt object when loading from file
+        if self.original_hash is None:
+            self.original_hash = hash(self)
+
         strings = get_all_strings(self.examples)
         translated_strings = await translate_statements_prompt.generate(
             llm=llm,
@@ -237,6 +245,97 @@
         new_prompt.language = target_language
         return new_prompt
 
+    def __hash__(self):
+        # convert examples to json string for hashing
+        examples = []
+        for example in self.examples:
+            input_model, output_model = example
+            examples.append(
+                (input_model.model_dump_json(), output_model.model_dump_json())
+            )
+
+        # not sure if input_model and output_model should be included
+        return hash(
+            (
+                self.name,
+                self.input_model,
+                self.output_model,
+                self.instruction,
+                *examples,
+                self.language,
+            )
+        )
+
+    def __eq__(self, other):
+        if not isinstance(other, PydanticPrompt):
+            return False
+        return (
+            self.name == other.name
+            and self.input_model == other.input_model
+            and self.output_model == other.output_model
+            and self.instruction == other.instruction
+            and self.examples == other.examples
+            and self.language == other.language
+        )
+
+    def save(self, file_path: str):
+        """
+        Save the prompt to a file.
+        """
+        data = {
+            "ragas_version": __version__,
+            "original_hash": (
+                hash(self) if self.original_hash is None else self.original_hash
+            ),
+            "language": self.language,
+            "instruction": self.instruction,
+            "examples": [
+                {"input": example[0].model_dump(), "output": example[1].model_dump()}
+                for example in self.examples
+            ],
+        }
+        if os.path.exists(file_path):
+            raise FileExistsError(f"The file '{file_path}' already exists.")
+        with open(file_path, "w") as f:
+            json.dump(data, f, indent=2)
+        print(f"Prompt saved to {file_path}")
+
+    @classmethod
+    def load(cls, file_path: str) -> "PydanticPrompt[InputModel, OutputModel]":
+        with open(file_path, "r") as f:
+            data = json.load(f)
+
+        # warn when the prompt was saved with a different ragas version
+        ragas_version = data.get("ragas_version")
+        if ragas_version != __version__:
+            logger.warning(
+                "Prompt was saved with Ragas v%s, but you are loading it with Ragas v%s. "
+                "There might be incompatibilities.",
+                ragas_version,
+                __version__,
+            )
+        original_hash = data.get("original_hash")
+
+        prompt = cls()
+        instruction = data["instruction"]
+        examples = [
+            (
+                prompt.input_model(**example["input"]),
+                prompt.output_model(**example["output"]),
+            )
+            for example in data["examples"]
+        ]
+
+        prompt.instruction = instruction
+        prompt.examples = examples
+        prompt.language = data.get("language", prompt.language)
+
+        # verify that the loaded prompt's hash matches the saved hash
+        if original_hash is not None and hash(prompt) != original_hash:
+            logger.warning("Loaded prompt hash does not match the saved hash.")
+
+        return prompt
 
 
 # Ragas Output Parser
 class OutputStringAndPrompt(BaseModel):
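
A sketch of a single-prompt round trip through the new save/load API, using the renamed CommonThemeFromSummariesPrompt from later in this diff; the file name is illustrative.

from ragas.testset.synthesizers.prompts import CommonThemeFromSummariesPrompt

prompt = CommonThemeFromSummariesPrompt()
prompt.save("common_theme_english.json")  # raises FileExistsError if the file exists
# the JSON payload holds ragas_version, original_hash, language,
# instruction, and the dumped input/output examples

restored = CommonThemeFromSummariesPrompt.load("common_theme_english.json")
assert restored == prompt  # __eq__ compares instruction, examples, language, ...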
4 changes: 2 additions & 2 deletions src/ragas/testset/synthesizers/abstract_query.py
@@ -17,7 +17,7 @@
     AbstractQueryFromTheme,
     CAQInput,
     CommonConceptsFromKeyphrases,
-    CommonThemeFromSummaries,
+    CommonThemeFromSummariesPrompt,
     ComparativeAbstractQuery,
     Concepts,
     KeyphrasesAndNumConcepts,
@@ -44,7 +44,7 @@ class AbstractQuerySynthesizer(QuerySynthesizer):
 
     def __post_init__(self):
         super().__post_init__()
-        self.common_theme_prompt = CommonThemeFromSummaries()
+        self.common_theme_prompt = CommonThemeFromSummariesPrompt()
 
     async def _generate_scenarios(
         self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks
2 changes: 1 addition & 1 deletion src/ragas/testset/synthesizers/prompts.py
@@ -20,7 +20,7 @@ class Themes(BaseModel):
     themes: t.List[Theme]
 
 
-class CommonThemeFromSummaries(PydanticPrompt[Summaries, Themes]):
+class CommonThemeFromSummariesPrompt(PydanticPrompt[Summaries, Themes]):
     input_model = Summaries
     output_model = Themes
     instruction = "Analyze the following summaries and identify given number of common themes. The themes should be concise, descriptive, and highlight a key aspect shared across the summaries."
32 changes: 17 additions & 15 deletions tests/conftest.py
@@ -28,22 +28,24 @@ def pytest_configure(config):
     )
 
 
-class FakeTestLLM(BaseRagasLLM):
-    def llm(self):
-        return self
-
-    def generate_text(
-        self, prompt: PromptValue, n=1, temperature=1e-8, stop=None, callbacks=[]
-    ):
-        generations = [[Generation(text=prompt.prompt_str)] * n]
-        return LLMResult(generations=generations)
-
-    async def agenerate_text(
-        self, prompt: PromptValue, n=1, temperature=1e-8, stop=None, callbacks=[]
-    ):
-        return self.generate_text(prompt, n, temperature, stop, callbacks)
+class EchoLLM(BaseRagasLLM):
+    def generate_text(  # type: ignore
+        self,
+        prompt: PromptValue,
+        *args,
+        **kwargs,
+    ) -> LLMResult:
+        return LLMResult(generations=[[Generation(text=prompt.to_string())]])
+
+    async def agenerate_text(  # type: ignore
+        self,
+        prompt: PromptValue,
+        *args,
+        **kwargs,
+    ) -> LLMResult:
+        return LLMResult(generations=[[Generation(text=prompt.to_string())]])
 
 
 @pytest.fixture
 def fake_llm():
-    return FakeTestLLM()
+    return EchoLLM()
48 changes: 48 additions & 0 deletions tests/unit/prompt/test_prompt_mixin.py
@@ -0,0 +1,48 @@
+import pytest
+
+from ragas.testset.synthesizers import AbstractQuerySynthesizer
+
+
+def test_prompt_save_load(tmp_path, fake_llm):
+    synth = AbstractQuerySynthesizer(llm=fake_llm)
+    synth_prompts = synth.get_prompts()
+    synth.save_prompts(tmp_path)
+    loaded_prompts = synth.load_prompts(tmp_path)
+    assert len(synth_prompts) == len(loaded_prompts)
+    for name, prompt in synth_prompts.items():
+        assert name in loaded_prompts
+        assert prompt == loaded_prompts[name]
+
+
+@pytest.mark.asyncio
+async def test_prompt_save_adapt_load(tmp_path, fake_llm):
+    synth = AbstractQuerySynthesizer(llm=fake_llm)
+
+    # patch adapt_prompts
+    async def adapt_prompts_patched(self, language, llm):
+        for prompt in self.get_prompts().values():
+            prompt.instruction = "test"
+            prompt.language = language
+        return self.get_prompts()
+
+    synth.adapt_prompts = adapt_prompts_patched.__get__(synth)
+
+    # adapt prompts
+    original_prompts = synth.get_prompts()
+    adapted_prompts = await synth.adapt_prompts("spanish", fake_llm)
+    synth.set_prompts(**adapted_prompts)
+
+    # save and load
+    synth.save_prompts(tmp_path)
+    loaded_prompts = synth.load_prompts(tmp_path, language="spanish")
+
+    # check conditions
+    assert len(adapted_prompts) == len(loaded_prompts)
+    for name, adapted_prompt in adapted_prompts.items():
+        assert name in loaded_prompts
+        assert name in original_prompts
+
+        loaded_prompt = loaded_prompts[name]
+        assert adapted_prompt.instruction == loaded_prompt.instruction
+        assert adapted_prompt.language == loaded_prompt.language
+        assert adapted_prompt == loaded_prompt