Skip to content

Commit

Permalink
feat: Add support for LlamaIndex in evaluation (#1619)
Browse files Browse the repository at this point in the history
Added type checks for LlamaIndex LLMs and embeddings in the evaluate
function.
  • Loading branch information
suekou authored Nov 6, 2024
1 parent 2a4a5ad commit 6ff35f7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ all = [
"rapidfuzz",
"pandas",
"datacompy",
"llama_index",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
14 changes: 11 additions & 3 deletions src/ragas/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from langchain_core.embeddings import Embeddings as LangchainEmbeddings
from langchain_core.language_models import BaseLanguageModel as LangchainLLM

from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM
from llama_index.core.base.embeddings.base import BaseEmbedding as LlamaIndexEmbedding

from ragas._analytics import EvaluationEvent, track, track_was_completed
from ragas.callbacks import ChainType, RagasTracer, new_group
from ragas.dataset_schema import (
Expand All @@ -19,13 +22,14 @@
from ragas.embeddings.base import (
BaseRagasEmbeddings,
LangchainEmbeddingsWrapper,
LlamaIndexEmbeddingsWrapper,
embedding_factory,
)
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
from ragas.integrations.helicone import helicone_config
from ragas.llms import llm_factory
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
from ragas.metrics import AspectCritic
from ragas.metrics._answer_correctness import AnswerCorrectness
from ragas.metrics.base import (
Expand Down Expand Up @@ -56,8 +60,8 @@
def evaluate(
dataset: t.Union[Dataset, EvaluationDataset],
metrics: t.Optional[t.Sequence[Metric]] = None,
llm: t.Optional[BaseRagasLLM | LangchainLLM] = None,
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
llm: t.Optional[BaseRagasLLM | LangchainLLM | LlamaIndexLLM] = None,
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings | LlamaIndexEmbedding] = None,
callbacks: Callbacks = None,
in_ci: bool = False,
run_config: RunConfig = RunConfig(),
Expand Down Expand Up @@ -182,8 +186,12 @@ def evaluate(
# set the llm and embeddings
if isinstance(llm, LangchainLLM):
llm = LangchainLLMWrapper(llm, run_config=run_config)
elif isinstance(llm, LlamaIndexLLM):
llm = LlamaIndexLLMWrapper(llm, run_config=run_config)
if isinstance(embeddings, LangchainEmbeddings):
embeddings = LangchainEmbeddingsWrapper(embeddings)
elif isinstance(embeddings, LlamaIndexEmbedding):
embeddings = LlamaIndexEmbeddingsWrapper(embeddings)

# init llms and embeddings
binary_metrics = []
Expand Down

0 comments on commit 6ff35f7

Please sign in to comment.