Updates for pydantic serialization / deserialization to rebase on current main, address review comments, and finalize last pieces as well as test fixes
markurtz committed Jul 19, 2024
1 parent ff910a0 commit 19399b7
Showing 13 changed files with 226 additions and 136 deletions.
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
@@ -28,6 +28,8 @@ repos:
     datasets,
     loguru,
     numpy,
+    pydantic,
+    pyyaml,
     openai,
     requests,
     transformers,
@@ -38,4 +40,5 @@ repos:
     # types
     types-click,
     types-requests,
+    types-PyYAML,
 ]
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -30,6 +30,8 @@ dependencies = [
     "loguru",
     "numpy",
     "openai",
+    "pydantic>=2.0.0",
+    "pyyaml>=6.0.0",
     "requests",
     "transformers",
 ]
@@ -48,7 +50,12 @@ dev = [
     "pytest-mock~=3.14.0",
     "ruff~=0.5.2",
     "tox~=4.16.0",
-    "types-requests~=2.32.0"
+    "types-requests~=2.32.0",
+
+    # type-checking
+    "types-click",
+    "types-requests",
+    "types-PyYAML",
 ]


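A note on the new pins: `pydantic>=2.0.0` commits the codebase to the pydantic v2 API, whose serialization entry points differ from v1's. A minimal sketch of the v2 surface the changes below rely on (the `Example` model is hypothetical, purely illustrative):

from pydantic import BaseModel

class Example(BaseModel):
    name: str
    count: int = 0

obj = Example(name="demo")
obj.model_dump()         # v2 replacement for v1's .dict()  -> {'name': 'demo', 'count': 0}
obj.model_dump_json()    # v2 replacement for v1's .json()
Example.model_validate({"name": "demo"})  # v2 replacement for v1's .parse_obj()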
8 changes: 5 additions & 3 deletions src/guidellm/backend/base.py
@@ -91,16 +91,18 @@ def submit(self, request: TextGenerationRequest) -> TextGenerationResult:

         logger.info(f"Submitting request with prompt: {request.prompt}")

-        result = TextGenerationResult(TextGenerationRequest(prompt=request.prompt))
+        result = TextGenerationResult(
+            request=TextGenerationRequest(prompt=request.prompt)
+        )
         result.start(request.prompt)

         for response in self.make_request(request):  # GenerativeResponse
             if response.type_ == "token_iter" and response.add_token:
                 result.output_token(response.add_token)
             elif response.type_ == "final":
                 result.end(
-                    response.prompt_token_count,
-                    response.output_token_count,
+                    prompt_token_count=response.prompt_token_count,
+                    output_token_count=response.output_token_count,
                 )

         logger.info(f"Request completed with output: {result.output}")
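Worth noting why the keyword-argument change above is more than style: once `TextGenerationResult` and `TextGenerationRequest` are pydantic models, positional construction stops working, since `BaseModel.__init__` accepts keyword arguments only. A standalone illustration with a toy model (not the repo's actual class):

from pydantic import BaseModel

class Result(BaseModel):
    request: str

Result(request="hello")   # fine
Result("hello")           # TypeError: BaseModel does not accept positional arguments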
33 changes: 11 additions & 22 deletions src/guidellm/core/distribution.py
@@ -1,7 +1,8 @@
-from typing import List, Optional, Union
+from typing import List, Sequence

 import numpy as np
 from loguru import logger
+from pydantic import Field

 from guidellm.core.serializable import Serializable

@@ -14,24 +15,12 @@ class Distribution(Serializable):
     statistical analyses.
     """

-    def __init__(self, **data):
-        super().__init__(**data)
-        logger.debug(f"Initialized Distribution with data: {self.data}")
+    data: Sequence[float] = Field(
+        default_factory=list, description="The data points of the distribution."
+    )

-    def __str__(self) -> str:
-        """
-        Return a string representation of the Distribution.
-        """
-        return (
-            f"Distribution(mean={self.mean:.2f}, median={self.median:.2f}, "
-            f"min={self.min}, max={self.max}, count={len(self.data)})"
-        )
-
-    def __repr__(self) -> str:
-        """
-        Return an unambiguous string representation of the Distribution for debugging.
-        """
-        return f"Distribution(data={self.data})"
+    def __str__(self):
+        return f"Distribution({self.describe()})"

     @property
     def mean(self) -> float:
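With the declarative `data` field, construction, defaults, and coercion now come from pydantic instead of the deleted `__init__`. Roughly how the new field behaves, assuming `Serializable` subclasses pydantic's `BaseModel` (a sketch, not verified against the repo):

dist = Distribution()                # default_factory supplies a fresh empty list
dist = Distribution(data=[1, 2, 3])  # ints are coerced to floats per the annotation
dist.model_dump()                    # {'data': [1.0, 2.0, 3.0]}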
@@ -99,7 +88,7 @@ def percentile(self, percentile: float) -> float:
             logger.warning("No data points available to calculate percentile.")
             return 0.0

-        percentile_value = np.percentile(self._data, percentile).item()
+        percentile_value = np.percentile(self.data, percentile).item()
         logger.debug(f"Calculated {percentile}th percentile: {percentile_value}")
         return percentile_value

@@ -180,15 +169,15 @@ def describe(self) -> dict:
         logger.debug(f"Generated description: {description}")
         return description

-    def add_data(self, new_data: Union[List[int], List[float]]):
+    def add_data(self, new_data: Sequence[float]):
         """
         Add new data points to the distribution.
         :param new_data: A list of new numerical data points to add.
         """
-        self.data.extend(new_data)
+        self.data = list(self.data) + list(new_data)
         logger.debug(f"Added new data: {new_data}")

-    def remove_data(self, remove_data: Union[List[int], List[float]]):
+    def remove_data(self, remove_data: Sequence[float]):
         """
         Remove specified data points from the distribution.
         :param remove_data: A list of numerical data points to remove.
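A plausible reason `add_data` now rebuilds the list rather than calling `.extend()`: the `Sequence[float]` annotation exposes no `.extend()` to the type checker, and assigning a fresh list lets pydantic re-validate the field when `validate_assignment` is enabled. Call sites are unaffected (sketch, assuming the class as shown in this diff):

dist = Distribution(data=[1.0, 2.0])
dist.add_data([3.0, 4.0])    # internally: self.data = list(self.data) + list(new_data)
dist.mean                    # statistics now reflect all four points
dist.remove_data([1.0])      # per its docstring, filters the given points back out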
40 changes: 18 additions & 22 deletions src/guidellm/core/request.py
@@ -1,5 +1,7 @@
 import uuid
-from typing import Dict, Optional, Any
+from typing import Any, Dict, Optional

+from pydantic import Field
+
 from guidellm.core.serializable import Serializable

@@ -9,24 +11,18 @@ class TextGenerationRequest(Serializable):
     A class to represent a text generation request for generative AI workloads.
     """

-    id: str
-    prompt: str
-    prompt_token_count: Optional[int]
-    generated_token_count: Optional[int]
-    params: Dict[str, Any]
-
-    def __init__(
-        self,
-        prompt: str,
-        prompt_token_count: Optional[int] = None,
-        generated_token_count: Optional[int] = None,
-        params: Optional[Dict[str, Any]] = None,
-        id: Optional[str] = None,
-    ):
-        super().__init__(
-            id=str(uuid.uuid4()) if id is None else id,
-            prompt=prompt,
-            prompt_token_count=prompt_token_count,
-            generated_token_count=generated_token_count,
-            params=params or {},
-        )
+    id: str = Field(
+        default_factory=lambda: str(uuid.uuid4()),
+        description="The unique identifier for the request.",
+    )
+    prompt: str = Field(description="The input prompt for the text generation.")
+    prompt_token_count: Optional[int] = Field(
+        default=None, description="The number of tokens in the input prompt."
+    )
+    generate_token_count: Optional[int] = Field(
+        default=None, description="The number of tokens to generate."
+    )
+    params: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="The parameters for the text generation request.",
+    )
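Taken together, the declarative fields reproduce what the deleted `__init__` did by hand: `default_factory=lambda: str(uuid.uuid4())` gives each request a fresh id, and `default_factory=dict` avoids the shared-mutable-default pitfall that `params=params or {}` worked around. (Note the rename from `generated_token_count` to `generate_token_count` in this hunk.) A sketch of the resulting behavior, assuming the model as shown:

req = TextGenerationRequest(prompt="Hello, world")
req.id            # a fresh uuid4 string per instance
req.params        # {} -- each instance gets its own dict
req.model_dump()  # serializes all five fields, including the None counts
TextGenerationRequest.model_validate(req.model_dump())  # round-trips cleanly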
(Diff truncated; 8 of the 13 changed files are not shown above.)
