Skip to content

Commit

Permalink
Merge pull request #97 from LlmKira/dev
Browse files Browse the repository at this point in the history
🎨 refactor(generate_voice): update speaker attributes and API usage
  • Loading branch information
sudoskys authored Jan 4, 2025
2 parents 52ab424 + e5981ba commit 904bf85
Show file tree
Hide file tree
Showing 17 changed files with 408 additions and 366 deletions.
2 changes: 1 addition & 1 deletion playground/generate_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def generate_voice(text: str):
try:
voice_gen = VoiceGenerate.build(
text=text,
voice_engine=VoiceSpeakerV1.Crina, # VoiceSpeakerV2.Ligeia,
speaker=VoiceSpeakerV2.Ligeia, # VoiceSpeakerV2.Ligeia,
)
result = await voice_gen.request(
session=credential
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.7.2"
version = "0.7.3"
description = "NovelAI Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
8 changes: 8 additions & 0 deletions src/novelai_python/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ class ConcurrentGenerationError(APIError):
pass


class DataSerializationError(APIError):
"""
DataSerializationError is raised when the API data is not serializable.
"""

pass


class AuthError(APIError):
"""
AuthError is raised when the API returns an error.
Expand Down
139 changes: 75 additions & 64 deletions src/novelai_python/sdk/ai/augment_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from io import BytesIO
from typing import Optional, Union, IO, Any
from urllib.parse import urlparse
from zipfile import ZipFile
from zipfile import ZipFile, BadZipFile

import curl_cffi
import httpx
Expand All @@ -25,13 +25,16 @@
from novelai_python.sdk.ai._enum import Model
from ._enum import ReqType, Moods
from ...schema import ApiBaseModel
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError, DataSerializationError
from ...._response.ai.generate_image import ImageGenerateResp, RequestParams
from ....credential import CredentialBase
from ....utils import try_jsonfy


class AugmentImageInfer(ApiBaseModel):
"""
https://docs.novelai.net/image/directortools.html
"""
_endpoint: str = PrivateAttr("https://image.novelai.net")

@property
Expand Down Expand Up @@ -210,96 +213,104 @@ async def request(self,
:param session: session
:return:
"""
# Data Build
# Prepare request data
request_data = self.model_dump(mode="json", exclude_none=True)
async with session if isinstance(session, AsyncSession) else await session.get_session() as sess:
# Header
sess.headers.update(await self.necessary_headers(request_data))
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)

# Log the request data (sanitize sensitive content)
try:
_log_data = deepcopy(request_data)
_log_data.update({
"image": "base64 data"
})
logger.debug(f"Request Data: {_log_data}")
del _log_data
if self.image:
_log_data["image"] = "base64 data hidden"
logger.debug(f"Request Data: {json.dumps(_log_data, indent=2)}")
except Exception as e:
logger.warning(f"Error when print log data: {e}")
logger.warning(f"Failed to log request data: {e}")

# Perform request and handle response
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
data=json.dumps(request_data).encode("utf-8")
)
if response.headers.get('Content-Type') not in ['binary/octet-stream', 'application/x-zip-compressed']:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
if (
response.headers.get('Content-Type') not in ['binary/octet-stream',
'application/x-zip-compressed']
or response.status_code >= 400
):
error_message = await self.handle_error_response(response, request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
if status_code in [409]:
# conflict error
raise APIError(message, request=request_data, code=status_code, response=_msg)
if status_code in [429]:
# concurrent error
raise AuthError(message, request=request_data, code=status_code, response=error_message)
elif status_code == 409:
raise APIError(message, request=request_data, code=status_code, response=error_message)
elif status_code == 429:
raise ConcurrentGenerationError(
message=message,
request=request_data,
code=status_code,
response=_msg
)
raise APIError(message, request=request_data, code=status_code, response=_msg)
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise APIError(
message="No file in zip",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
response=error_message,
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return ImageGenerateResp(
meta=RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content
)
else:
raise APIError(message, request=request_data, code=status_code, response=error_message)

# Unpack the ZIP response
try:
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise DataSerializationError(
message="The ZIP response contains no files.",
request=request_data,
response=try_jsonfy(response.content),
code=response.status_code,
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return ImageGenerateResp(
meta=RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content,
)
except BadZipFile as e:
# Invalid ZIP file - indicate serialization error
logger.exception("The response content is not a valid ZIP file.")
raise DataSerializationError(
message="Invalid ZIP file received from the API.",
request=request_data,
response={},
code=response.status_code,
) from e
except Exception as e:
logger.exception("Unexpected error while unpacking ZIP response.")
raise DataSerializationError(
message="An unexpected error occurred while processing ZIP data.",
request=request_data,
response={},
code=response.status_code,
) from e
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
raise SessionHttpError(
"An AsyncSession RequestsError occurred, maybe SSL error. Try again later!") from exc
raise SessionHttpError("A RequestsError occurred (e.g., SSL error). Try again later.")
except httpx.HTTPError as exc:
logger.exception(exc)
raise SessionHttpError("An HTTPError occurred, maybe SSL error. Try again later!") from exc
raise SessionHttpError("An HTTP error occurred. Try again later.")
except APIError as e:
raise e
except Exception as e:
logger.opt(exception=e).exception("An Unexpected error occurred")
raise e
logger.opt(exception=e).exception("Unexpected error occurred during the request.")
raise Exception("An unexpected error occurred.") from e
2 changes: 1 addition & 1 deletion src/novelai_python/sdk/ai/generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ async def request(self,
logger.debug(f"LLM request data: {json.dumps(request_data)}")
# Request
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
json=request_data,
Expand Down
Loading

0 comments on commit 904bf85

Please sign in to comment.