feat(py): Increase concurrency for pytest evals, fix output rendering nits (#1489)
jacoblee93 authored Feb 5, 2025
1 parent 652fd09 commit c442158
Showing 3 changed files with 93 additions and 84 deletions.
24 changes: 6 additions & 18 deletions python/langsmith/pytest_plugin.py
@@ -41,8 +41,8 @@ def _handle_output_args(args):
"""Handle output arguments."""
if any(opt in args for opt in ["--langsmith-output"]):
# Only add --quiet if it's not already there
if not any(a in args for a in ["-q", "--quiet"]):
args.insert(0, "--quiet")
if not any(a in args for a in ["-qq"]):
args.insert(0, "-qq")
# Disable built-in output capturing
if not any(a in args for a in ["-s", "--capture=no"]):
args.insert(0, "-s")
@@ -181,7 +181,7 @@ def _generate_table(self, suite_name: str):
process_ids = self.test_suites[suite_name]

title = f"""Test Suite: [bold]{suite_name}[/bold]
LangSmith link: [bright_cyan][link={self.test_suite_urls[suite_name]}]⌘ + click here[/link][/bright_cyan]""" # noqa: E501
LangSmith URL: [bright_cyan]{self.test_suite_urls[suite_name]}[/bright_cyan]""" # noqa: E501
table = Table(title=title, title_justify="left")
table.add_column("Test")
table.add_column("Inputs")
@@ -190,7 +190,6 @@ def _generate_table(self, suite_name: str):
table.add_column("Status")
table.add_column("Feedback")
table.add_column("Duration")
table.add_column("Logged")

# Test, inputs, ref outputs, outputs col width
max_status = len("status")
@@ -231,9 +230,7 @@ def _generate_table(self, suite_name: str):
aggregate_feedback = "--"

max_duration = max(max_duration, len(aggregate_duration))
max_dynamic_col_width = (
self.console.width - (max_status + max_duration + len("Logged"))
) // 5
max_dynamic_col_width = (self.console.width - (max_status + max_duration)) // 5
max_dynamic_col_width = max(max_dynamic_col_width, 8)

for pid, status in suite_statuses.items():
@@ -262,29 +259,19 @@ def _generate_table(self, suite_name: str):
f"[{status_color}]{status.get('status', 'queued')}[/{status_color}]",
feedback,
f"{duration:.2f}s",
"x" if status.get("logged") else "",
)

if suite_statuses:
logged = sum(s.get("logged", False) for s in suite_statuses.values()) / len(
suite_statuses
)
aggregate_logged = f"{logged:.0%}"
else:
aggregate_logged = "--"

# Add a blank row or a section separator if you like:
table.add_row("", "", "", "", "", "", "")
# Finally, our “footer” row:
table.add_row(
"[bold]Summary[/bold]",
"[bold]Averages[/bold]",
"",
"",
"",
aggregate_status,
aggregate_feedback,
aggregate_duration,
aggregate_logged,
)

return table
@@ -302,6 +289,7 @@ def pytest_configure(self, config):
def pytest_sessionfinish(self, session):
"""Stop Rich Live rendering at the end of the session."""
self.live.stop()
self.live.console.print("\nFinishing up...")


def pytest_configure(config):
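
(Aside: a quick worked example of the simplified column-width arithmetic above, using made-up values for the console width and measured column widths.)

console_width = 120                 # hypothetical terminal width
max_status = len("status")          # 6
max_duration = len("duration")      # 8, assuming no duration string is longer
max_dynamic_col_width = (console_width - (max_status + max_duration)) // 5
max_dynamic_col_width = max(max_dynamic_col_width, 8)
print(max_dynamic_col_width)        # (120 - 14) // 5 == 21
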
119 changes: 53 additions & 66 deletions python/langsmith/testing/_internal.py
@@ -13,7 +13,6 @@
import time
import uuid
import warnings
from collections import defaultdict
from concurrent.futures import Future
from pathlib import Path
from typing import (
@@ -450,7 +449,7 @@ def _get_example_id(
def _end_tests(test_suite: _LangSmithTestSuite):
git_info = ls_env.get_git_info() or {}
test_suite.shutdown()
dataset_version = test_suite.get_version()
dataset_version = test_suite.get_dataset_version()
dataset_id = test_suite._dataset.id
test_suite.client.update_project(
test_suite.experiment_id,
@@ -498,10 +497,8 @@ def __init__(
self.client = client or rt.get_cached_client()
self._experiment = experiment
self._dataset = dataset
self._version: Optional[datetime.datetime] = None
self._executor = ls_utils.ContextThreadPoolExecutor(max_workers=1)
self._example_futures: dict[ID_TYPE, list[Future]] = defaultdict(list)
self._example_modified_at: dict[ID_TYPE, datetime.datetime] = {}
self._dataset_version: Optional[datetime.datetime] = dataset.modified_at
self._executor = ls_utils.ContextThreadPoolExecutor()
atexit.register(_end_tests, self)

@property
@@ -538,20 +535,8 @@ def from_test(
def name(self):
return self._experiment.name

def update_version(self, version: datetime.datetime, example_id: uuid.UUID) -> None:
with self._lock:
if self._version is None or version > self._version:
self._version = version
self._example_modified_at[example_id] = version

def get_version(
self, example_id: Optional[uuid.UUID] = None
) -> Optional[datetime.datetime]:
with self._lock:
if not example_id:
return self._version
else:
return self._example_modified_at.get(example_id)
def get_dataset_version(self):
return self._dataset_version

def submit_result(
self,
@@ -591,23 +576,7 @@ def sync_example(
update = {"inputs": inputs, "reference_outputs": outputs}
update = {k: v for k, v in update.items() if v is not None}
pytest_plugin.update_process_status(pytest_nodeid, update)
future = self._executor.submit(
self._sync_example,
example_id,
inputs,
outputs,
metadata.copy() if metadata else metadata,
)
with self._lock:
self._example_futures[example_id].append(future)

def _sync_example(
self,
example_id: uuid.UUID,
inputs: Optional[dict],
outputs: Optional[dict],
metadata: Optional[dict],
) -> None:
metadata = metadata.copy() if metadata else metadata
inputs = _serde_example_values(inputs)
outputs = _serde_example_values(outputs)
try:
@@ -636,8 +605,14 @@ def _sync_example(
dataset_id=self.id,
)
example = self.client.read_example(example_id=example.id)
if example.modified_at:
self.update_version(example.modified_at, example_id=example_id)
if self._dataset_version is None:
self._dataset_version = example.modified_at
elif (
example.modified_at
and self._dataset_version
and example.modified_at > self._dataset_version
):
self._dataset_version = example.modified_at

def _submit_feedback(
self,
@@ -664,18 +639,12 @@ def _create_feedback(self, run_id: ID_TYPE, feedback: dict, **kwargs: Any) -> None:
def shutdown(self):
self._executor.shutdown()

def wait_example_updates(self, example_id: ID_TYPE):
"""Wait for all example updates to complete."""
with self._lock:
while self._example_futures[example_id]:
self._example_futures[example_id].pop().result()

def end_run(
self,
run_tree,
example_id,
outputs,
end_time,
reference_outputs,
pytest_plugin=None,
pytest_nodeid=None,
) -> Future:
@@ -684,23 +653,25 @@ def end_run(
run_tree=run_tree,
example_id=example_id,
outputs=outputs,
end_time=end_time,
reference_outputs=reference_outputs,
pytest_plugin=pytest_plugin,
pytest_nodeid=pytest_nodeid,
)

def _end_run(
self, run_tree, example_id, outputs, end_time, pytest_plugin, pytest_nodeid
self,
run_tree,
example_id,
outputs,
reference_outputs,
pytest_plugin,
pytest_nodeid,
) -> None:
# TODO: remove this hack so that run durations are correct
# Ensure example is fully updated
self.wait_example_updates(example_id)
# Ensure that run end time is after example modified at.
example_modified_at = self.get_version(example_id=example_id) or end_time
end_time = max(example_modified_at, end_time)
run_tree.end(outputs=outputs, end_time=end_time)
self.sync_example(example_id, inputs=run_tree.inputs, outputs=reference_outputs)
run_tree.end(outputs=outputs)
run_tree.patch()
pytest_plugin.update_process_status(pytest_nodeid, {"logged": True})


class _TestCase:
@@ -711,17 +682,24 @@ def __init__(
run_id: uuid.UUID,
pytest_plugin: Any = None,
pytest_nodeid: Any = None,
inputs: Optional[dict] = None,
reference_outputs: Optional[dict] = None,
) -> None:
self.test_suite = test_suite
self.example_id = example_id
self.run_id = run_id
self.pytest_plugin = pytest_plugin
self.pytest_nodeid = pytest_nodeid
self._logged_reference_outputs: Optional[dict] = None

if pytest_plugin and pytest_nodeid:
pytest_plugin.add_process_to_test_suite(
test_suite._dataset.name, pytest_nodeid
)
if inputs:
self.log_inputs(inputs)
if reference_outputs:
self.log_reference_outputs(reference_outputs)

def sync_example(
self, *, inputs: Optional[dict] = None, outputs: Optional[dict] = None
@@ -746,12 +724,25 @@ def submit_feedback(self, *args, **kwargs: Any):
},
)

def log_inputs(self, inputs: dict) -> None:
if self.pytest_plugin and self.pytest_nodeid:
self.pytest_plugin.update_process_status(
self.pytest_nodeid, {"inputs": inputs}
)

def log_outputs(self, outputs: dict) -> None:
if self.pytest_plugin and self.pytest_nodeid:
self.pytest_plugin.update_process_status(
self.pytest_nodeid, {"outputs": outputs}
)

def log_reference_outputs(self, reference_outputs: dict) -> None:
self._logged_reference_outputs = reference_outputs
if self.pytest_plugin and self.pytest_nodeid:
self.pytest_plugin.update_process_status(
self.pytest_nodeid, {"reference_outputs": reference_outputs}
)

def submit_test_result(
self,
error: Optional[str] = None,
@@ -780,12 +771,11 @@ def end_time(self) -> None:
def end_run(self, run_tree, outputs: Any) -> None:
if not (outputs is None or isinstance(outputs, dict)):
outputs = {"output": outputs}
end_time = datetime.datetime.now(datetime.timezone.utc)
self.test_suite.end_run(
run_tree,
self.example_id,
outputs,
end_time=end_time,
reference_outputs=self._logged_reference_outputs,
pytest_plugin=self.pytest_plugin,
pytest_nodeid=self.pytest_nodeid,
)
@@ -837,12 +827,6 @@ def _create_test_case(
)
example_id, example_name = _get_example_id(func, inputs, test_suite.id)
example_id = langtest_extra["id"] or example_id
test_suite.sync_example(
example_id,
inputs=inputs,
outputs=outputs,
metadata={"signature": _get_test_repr(func, signature), "name": example_name},
)
pytest_plugin = (
pytest_request.config.pluginmanager.get_plugin("langsmith_output_plugin")
if pytest_request
@@ -855,13 +839,16 @@
+ "/compare?selectedSessions="
+ str(test_suite.experiment_id)
)
return _TestCase(
test_case = _TestCase(
test_suite,
example_id,
run_id=uuid.uuid4(),
inputs=inputs,
reference_outputs=outputs,
pytest_plugin=pytest_plugin,
pytest_nodeid=pytest_nodeid,
)
return test_case


def _run_test(
Expand Down Expand Up @@ -1048,7 +1035,7 @@ def test_foo() -> None:
)
raise ValueError(msg)
run_tree.add_inputs(inputs)
test_case.sync_example(inputs=inputs)
test_case.log_inputs(inputs)


def log_outputs(outputs: dict, /) -> None:
@@ -1095,7 +1082,7 @@ def test_foo() -> None:
test_case.log_outputs(outputs)


def log_reference_outputs(outputs: dict, /) -> None:
def log_reference_outputs(reference_outputs: dict, /) -> None:
"""Log example reference outputs from within a pytest test run.
.. warning::
@@ -1134,7 +1121,7 @@ def test_foo() -> None:
"decorated with @pytest.mark.langsmith."
)
raise ValueError(msg)
test_case.sync_example(outputs=outputs)
test_case.log_reference_outputs(reference_outputs)


def log_feedback(
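
(Aside: the per-example version map is gone; the suite now tracks a single dataset version. A minimal sketch of that bookkeeping, with an illustrative helper name, is shown below.)

import datetime
from typing import Optional

def bump_dataset_version(
    current: Optional[datetime.datetime],
    modified_at: Optional[datetime.datetime],
) -> Optional[datetime.datetime]:
    # Keep the most recent example modified_at seen so far, mirroring the
    # branching added to the example-sync code above.
    if current is None:
        return modified_at
    if modified_at and modified_at > current:
        return modified_at
    return current
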
34 changes: 34 additions & 0 deletions python/tests/evaluation/test_decorator.py
@@ -1,4 +1,5 @@
import os
import time

import pytest

@@ -81,3 +82,36 @@ def test_param(a, b):
t.log_outputs({"sum": a + b})
t.log_reference_outputs({"sum": a + b})
assert a + b == a + b


@pytest.fixture
def inputs() -> int:
return 5


@pytest.fixture
def reference_outputs() -> int:
return 10


@pytest.mark.skipif(
not os.getenv("LANGSMITH_TRACING"),
reason="LANGSMITH_TRACING environment variable not set",
)
@pytest.mark.langsmith(output_keys=["reference_outputs"])
def test_fixture(inputs: int, reference_outputs: int):
result = 2 * inputs
t.log_outputs({"d": result})
assert result == reference_outputs


@pytest.mark.skipif(
not os.getenv("LANGSMITH_TRACING"),
reason="LANGSMITH_TRACING environment variable not set",
)
@pytest.mark.langsmith
def test_slow_test():
t.log_inputs({"slow": "I am slow"})
time.sleep(5)
t.log_outputs({"slow_result": "I am slow"})
t.log_reference_outputs({"slow_result": "I am not fast"})
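
(Aside: one way to run these new evaluation tests locally; the --langsmith-output flag and the LANGSMITH_TRACING variable come from this diff, while the file path and the programmatic pytest invocation are illustrative assumptions.)

import os
import pytest

# The skipif guards above require LANGSMITH_TRACING to be set.
os.environ.setdefault("LANGSMITH_TRACING", "true")

# --langsmith-output turns on the plugin's live Rich table; the plugin then
# inserts -qq and -s itself via _handle_output_args.
pytest.main(["--langsmith-output", "python/tests/evaluation/test_decorator.py"])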
