Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Tohrusky committed Nov 15, 2024
1 parent 6c5b0be commit cda2a1e
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 160 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
- name: Test
run: |
pip install numpy==1.26.4
pip install pre-commit pytest mypy ruff types-requests pytest-cov coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2
make lint
make test
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pre-commit = "^3.7.0"
[tool.poetry.group.test.dependencies]
coverage = "^7.2.0"
pytest = "^8.0"
pytest-asyncio = "^0.24.0"
pytest-cov = "^4.0"

[tool.poetry.group.typing.dependencies]
Expand Down
27 changes: 13 additions & 14 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os

import pytest
Expand Down Expand Up @@ -41,63 +40,63 @@
)


def test_llm_none() -> None:
async def test_llm_none() -> None:
t = Translator(model=util.OPENAI_MODEL, api_key=util.OPENAI_API_KEY, base_url=util.OPENAI_BASE_URL)
print(t.system_prompt)
res = asyncio.run(t.ask(ORIGIN(origin="")))
res = await t.ask(ORIGIN(origin=""))
assert res.zh == ""


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
def test_llm() -> None:
async def test_llm() -> None:
t = Translator(
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
)
print(t.system_prompt)
res = asyncio.run(t.ask(origin))
res = await t.ask(origin)
print(res.zh)


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
def test_llm_bangumi() -> None:
async def test_llm_bangumi() -> None:
t = Translator(
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=asyncio.run(bangumi(util.BANGUMI_URL)),
bangumi_info=await bangumi(util.BANGUMI_URL),
)
print(t.system_prompt)
res = asyncio.run(t.ask(origin))
res = await t.ask(origin)
print(res.zh)


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
def test_llm_bangumi_2() -> None:
async def test_llm_bangumi_2() -> None:
t = Translator(
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=asyncio.run(bangumi(util.BANGUMI_URL)),
bangumi_info=await bangumi(util.BANGUMI_URL),
)
print(t.system_prompt)
s = ORIGIN(
origin="♪ 星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と",
)

res = asyncio.run(t.ask(s))
res = await t.ask(s)
print(res.zh)


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
def test_llm_summary() -> None:
async def test_llm_summary() -> None:
t = Summarizer(
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=asyncio.run(bangumi(util.BANGUMI_URL)),
bangumi_info=await bangumi(util.BANGUMI_URL),
)
print(t.system_prompt)
res = asyncio.run(t.ask(summary_origin))
res = await t.ask(summary_origin)
print(res.zh)
10 changes: 5 additions & 5 deletions tests/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from yuisub.sub import bilingual, load, translate


async def test_sub() -> None:
sub = await load(util.TEST_ENG_SRT)
def test_sub() -> None:
sub = load(util.TEST_ENG_SRT)
sub.save(util.projectPATH / "assets" / "test.en.ass")


Expand All @@ -19,14 +19,14 @@ def test_audio() -> None:
sub.save(util.projectPATH / "assets" / "test.audio.ass")


def test_bilingual() -> None:
async def test_bilingual() -> None:
sub = load(util.TEST_ENG_SRT)
_ = bilingual(sub, sub)
await bilingual(sub, sub)


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
async def test_bilingual_2() -> None:
sub = await load(util.TEST_ENG_SRT)
sub = load(util.TEST_ENG_SRT)

sub_zh = await translate(
sub=sub,
Expand Down
2 changes: 1 addition & 1 deletion yuisub/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def main() -> None:
sub = model.transcribe(audio=args.AUDIO)

else:
sub = await load(args.SUB)
sub = load(args.SUB)

sub_zh = await translate(
sub=sub,
Expand Down
69 changes: 33 additions & 36 deletions yuisub/bangumi.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
# bangumi.py

import asyncio
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import httpx
from pydantic import BaseModel

# 使用信号量限制并发请求数
SEMAPHORE_LIMIT = 32


@dataclass
class Character:
class Character(BaseModel):
id: int
name: str
chinese_name: Optional[str] = None
Expand All @@ -26,21 +18,38 @@ class BGM(BaseModel):


async def extract_bangumi_id(url: str) -> Optional[str]:
"""从Bangumi URL中提取番剧ID"""
"""
Extract bangumi ID from Bangumi URL
:param url: Bangumi URL
:return: Bangumi ID
"""
pattern = r"(?:https?://)?(?:www\.)?(?:bangumi\.tv|bgm\.tv)/subject/(\d+)"
match = re.search(pattern, url)
return match.group(1) if match else None


def construct_api_url(bangumi_id: str) -> str:
"""根据番剧ID构建API URL"""
"""
Construct API URL based on bangumi ID
:param bangumi_id: Bangumi ID
:return: API URL
"""
return f"https://api.bgm.tv/v0/subjects/{bangumi_id}"


async def get_character_info(
client: httpx.AsyncClient, character: Dict[str, Any], semaphore: asyncio.Semaphore
) -> Character:
"""获取单个角色的详细信息"""
"""
Get detailed info of a character
:param client: httpx.AsyncClient
:param character: Character data
:param semaphore: asyncio.Semaphore
:return: Character object
"""
async with semaphore:
char_id = character["id"]
char_name = character["name"]
Expand All @@ -62,7 +71,13 @@ async def get_character_info(


async def fetch_bangumi_data(client: httpx.AsyncClient, url: str) -> tuple[str, List[Dict[str, Any]]]:
"""获取番剧基本信息和角色列表"""
"""
Get base info and character list asynchronously
:param client: httpx.AsyncClient
:param url: Bangumi URL
:return: Tuple of introduction and character list
"""
bangumi_id = await extract_bangumi_id(url)
if not bangumi_id:
raise ValueError("Invalid bangumi URL")
Expand All @@ -81,16 +96,15 @@ async def fetch_bangumi_data(client: httpx.AsyncClient, url: str) -> tuple[str,

async def bangumi(url: Optional[str] = None) -> BGM:
"""
异步获取番剧信息和角色列表
Get bangumi info and character list asynchronously
Args:
url: Bangumi URL
Returns:
BGM object containing introduction and characters info
:param url: Bangumi URL
:return: BGM object
"""
print("Getting bangumi info...")

SEMAPHORE_LIMIT = 32

if not url:
print("Warning: bangumi url is empty")
return BGM(introduction="", characters="")
Expand Down Expand Up @@ -134,20 +148,3 @@ async def bangumi(url: Optional[str] = None) -> BGM:
except Exception as e:
print(f"Error fetching bangumi info: {e}")
raise


async def main() -> None:
"""
Main function for testing bangumi functionality
"""
url = "https://bangumi.tv/subject/315574"
start_time = time.time()
result = await bangumi(url)
use_time = time.time() - start_time
print(f"Introduction:\n{result.introduction[:100]}...")
print(f"Characters:\n{result.characters}")
print(f"Use time: {use_time}")


if __name__ == "__main__":
asyncio.run(main())
Loading

0 comments on commit cda2a1e

Please sign in to comment.