Skip to content

Commit

Permalink
fixes for rebasing on main
Browse files Browse the repository at this point in the history
  • Loading branch information
markurtz committed Aug 2, 2024
1 parent 56f728b commit 33b7de0
Show file tree
Hide file tree
Showing 9 changed files with 10 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/guidellm/backend/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import openai
from typing import Any, Dict, Generator, List, Optional

import openai
from config import settings
from loguru import logger
from openai import OpenAI, Stream
Expand Down
2 changes: 1 addition & 1 deletion src/guidellm/core/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ def load_file(cls, path: Union[str, Path]):
else:
raise ValueError(f"Unsupported file format: {type_}")

return obj
return obj
1 change: 0 additions & 1 deletion src/guidellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def main(
max_requests,
output_path,
):

# Create backend
_backend = Backend.create(
backend_type=backend,
Expand Down
58 changes: 0 additions & 58 deletions src/guidellm/request/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,71 +58,13 @@ def __init__(
# function requires attributes above
super().__init__(tokenizer, mode, async_queue_size)

def _load_dataset(self):
"""
Load the dataset based on the options given either as a dataset name or
a local path.
If no split or column is provided, attempt to infer the best options.
:return: The loaded dataset.
"""

# first load the initial dataset
if self._dataset.endswith(".csv") or self._dataset.endswith(".json"):
logger.debug(f"Loading dataset from local path: {self._dataset}")
extension = self._dataset.split(".")[-1]
dataset = load_dataset(extension, data_files=self._dataset, **self._kwargs)
elif self._dataset.endswith(".py"):
logger.debug(f"Loading dataset from local script: {self._dataset}")
dataset = load_dataset(self._dataset, **self._kwargs)
else:
logger.debug(f"Loading dataset: {self._dataset}")
dataset = load_dataset(self._dataset, **self._kwargs)

# Infer split if not provided
if self._split is None:
for split in PREFERRED_DATA_SPLITS:
if split in dataset.keys():
self._split = split
break
if self._split is None:
self._split = list(dataset.keys())[0]
logger.info(f"Inferred split to use: {self._split}")

# Infer column if not provided
if self._column is None:
for col in PREFERRED_DATA_COLUMNS:
if col in dataset[self._split].column_names:
self._column = col
break
if self._column is None:
self._column = dataset[self._split].column_names[0]
logger.info(f"Inferred column to use for prompts: {self._column}")

dataset = dataset[self._split]
logger.info(
f"Loaded dataset {self._dataset} with split: {self._split} "
f"and column: {self._column}"
)

return dataset

def create_item(self) -> TextGenerationRequest:
"""
Create a new result request item from the dataset.
:return: A new result request.
:rtype: TextGenerationRequest
"""
try:
getattr(self, "_hf_dataset")
except AttributeError:
self._hf_dataset = self._load_dataset()

try:
getattr(self, "_iterator")
except (StopIteration, AttributeError):
self._iterator = iter(self._hf_dataset)

data = next(self._iterator)

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/cli/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from click.testing import CliRunner


@pytest.fixture
@pytest.fixture()
def cli_runner():
return CliRunner()


@pytest.fixture
@pytest.fixture()
def patch_main(mocker) -> MagicMock:
return mocker.patch("guidellm.main.main.callback")


@pytest.fixture
@pytest.fixture()
def default_main_kwargs() -> Dict[str, Any]:
"""
All the defaults come from the `guidellm.main` function.
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/cli/test_application_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import pytest
from click.testing import CliRunner

from guidellm.main import main


Expand Down Expand Up @@ -32,7 +31,7 @@ def test_main_cli_overrided(


@pytest.mark.parametrize(
"args,expected_stdout",
("args", "expected_stdout"),
[
(
["--backend", "invalid", "--rate-type", "sweep"],
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/cli/test_main_validation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pytest

from guidellm.main import main


def test_task_without_data(mocker, default_main_kwargs):
patch = mocker.patch("guidellm.backend.Backend.create")
default_main_kwargs.update({"task": "can't be used without data"})
with pytest.raises(NotImplementedError):
getattr(main, "callback")(**default_main_kwargs)
main.callback(**default_main_kwargs) # type: ignore

assert patch.call_count == 1

Expand All @@ -17,7 +16,7 @@ def test_invalid_data_type(mocker, default_main_kwargs):
default_main_kwargs.update({"data_type": "invalid"})

with pytest.raises(ValueError):
getattr(main, "callback")(**default_main_kwargs)
main.callback(**default_main_kwargs) # type: ignore

assert patch.call_count == 1

Expand All @@ -31,7 +30,7 @@ def test_invalid_rate_type(mocker, default_main_kwargs):
default_main_kwargs.update({"rate_type": "invalid", "data_type": "file"})

with pytest.raises(ValueError):
getattr(main, "callback")(**default_main_kwargs)
main.callback(**default_main_kwargs) # type: ignore

assert patch.call_count == 1
assert file_request_generator_initialization_patch.call_count == 1
2 changes: 1 addition & 1 deletion tests/unit/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from loguru import logger
from pathlib import Path

import pytest
from config import LoggingSettings
from guidellm.logger import configure_logger
from loguru import logger


@pytest.fixture(autouse=True)
Expand Down
1 change: 0 additions & 1 deletion utils/inject_build_props.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import re
from datetime import datetime

from pathlib import Path

import toml
Expand Down

0 comments on commit 33b7de0

Please sign in to comment.