Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lkomali committed Feb 13, 2025
1 parent 38a23cf commit 03f1671
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@ repos:
name: Add License
entry: python tools/add_copyright.py
language: python
require_serial: true
require_serial: true
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@


from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import List, Tuple, Union

from genai_perf.inputs.retrievers.base_input_retriever import BaseInputRetriever
from genai_perf.inputs.retrievers.generic_dataset import FileData, GenericDataset
from genai_perf.inputs.retrievers.generic_dataset import (
FileData,
GenericDataset,
ImageData,
OptionalData,
TextData,
Timestamp,
)


class BaseFileInputRetriever(BaseInputRetriever):
Expand All @@ -55,8 +62,8 @@ def _verify_file(self, filename: Path) -> None:
raise FileNotFoundError(f"The file '{filename}' does not exist.")

def _get_content_from_input_file(self, filename: Path) -> Union[
Tuple[List[str], List[str]],
Tuple[List[str], List[int], List[Dict[Any, Any]]],
Tuple[TextData, ImageData],
Tuple[TextData, List[Timestamp], List[OptionalData]],
]:
"""
Reads the content from a JSONL file and returns lists of each content type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple

from genai_perf.exceptions import GenAIPerfException
from genai_perf.inputs.retrievers.base_file_input_retriever import (
BaseFileInputRetriever,
)
Expand Down Expand Up @@ -139,11 +140,11 @@ def _get_valid_timestamp(self, data: Dict[str, Any]) -> int:
"""
timestamp = data.get("timestamp")
if timestamp is None:
raise ValueError("Each data entry must have a 'timestamp' field.")
raise GenAIPerfException("Each data entry must have a 'timestamp' field.")
try:
timestamp = int(timestamp)
except ValueError:
raise ValueError(
except Exception:
raise GenAIPerfException(
f"Invalid timestamp: Expecting an integer but received '{timestamp}'."
)

Expand Down
2 changes: 2 additions & 0 deletions genai-perf/genai_perf/subcommand/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _report_output(
elif args.request_rate:
infer_mode = "request_rate"
load_level = f"{args.request_rate}"
# In fixed schedule mode the infer mode is not set, so fall back to
# default values here to avoid an error.
elif args.prompt_source == ic.PromptSource.PAYLOAD:
infer_mode = "request_rate"
load_level = "1.0"
Expand Down
35 changes: 35 additions & 0 deletions genai-perf/tests/test_retrievers/test_payload_input_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from unittest.mock import mock_open, patch

import pytest
from genai_perf.exceptions import GenAIPerfException
from genai_perf.inputs.retrievers.generic_dataset import (
DataRow,
FileData,
Expand Down Expand Up @@ -166,3 +167,37 @@ def test_conflicting_keys_error(self, mock_file, retriever):
match="Each data entry must have only one of 'text_input' or 'text' key name.",
):
retriever._get_content_from_input_file(Path("test_input.jsonl"))

# Verifies that retriever._get_valid_timestamp() rejects bad entries by
# raising GenAIPerfException. Each case pairs an input entry with the error
# text that pytest.raises(match=...) must find in the exception message.
@pytest.mark.parametrize(
"mock_data, expected_error",
[
# Case 1: the entry has no 'timestamp' key at all.
(
{"text": "What is AI?", "session_id": "abc"},
"Each data entry must have a 'timestamp' field.",
),
# Case 2: the 'timestamp' value is a string that is not an integer.
(
{"text": "What is AI?", "timestamp": "0s", "session_id": "abc"},
"Invalid timestamp: Expecting an integer but received '0s'",
),
],
)
def test_get_valid_timestamp_invalid(self, retriever, mock_data, expected_error):
# NOTE: `match` is applied as a regular-expression search, so the expected
# text does not have to cover the full message.
with pytest.raises(GenAIPerfException, match=expected_error):
retriever._get_valid_timestamp(mock_data)

# Verifies that retriever._get_valid_timestamp() returns the entry's
# timestamp as an int for well-formed input. Each case pairs an input entry
# with the integer value the method is expected to return.
@pytest.mark.parametrize(
"mock_data, expected_timestamp",
[
# Case 1: the timestamp is already an integer (zero is a valid value).
(
{"text": "What is AI?", "timestamp": 0, "session_id": "abc"},
0,
),
# Case 2: the timestamp is a numeric string and is expected back as an int.
(
{"text": "What is AI?", "timestamp": "456", "session_id": "abc"},
456,
),
],
)
def test_get_valid_timestamp_valid(self, retriever, mock_data, expected_timestamp):
timestamp = retriever._get_valid_timestamp(mock_data)
assert timestamp == expected_timestamp

0 comments on commit 03f1671

Please sign in to comment.