From cda2a1efb75c581ade2f26dec81aceeedab6e531 Mon Sep 17 00:00:00 2001 From: Tohrusky <65994850+Tohrusky@users.noreply.github.com> Date: Fri, 15 Nov 2024 18:23:11 +0000 Subject: [PATCH] fix --- .github/workflows/CI-test.yml | 2 +- pyproject.toml | 1 + tests/test_llm.py | 27 +++--- tests/test_sub.py | 10 +-- yuisub/__main__.py | 2 +- yuisub/bangumi.py | 69 +++++++------- yuisub/sub.py | 165 +++++++++++++--------------------- 7 files changed, 116 insertions(+), 160 deletions(-) diff --git a/.github/workflows/CI-test.yml b/.github/workflows/CI-test.yml index cebaae2..2e437fb 100644 --- a/.github/workflows/CI-test.yml +++ b/.github/workflows/CI-test.yml @@ -48,7 +48,7 @@ jobs: - name: Test run: | pip install numpy==1.26.4 - pip install pre-commit pytest mypy ruff types-requests pytest-cov coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2 + pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2 make lint make test diff --git a/pyproject.toml b/pyproject.toml index c9bad98..c755b1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ pre-commit = "^3.7.0" [tool.poetry.group.test.dependencies] coverage = "^7.2.0" pytest = "^8.0" +pytest-asyncio = "^0.24.0" pytest-cov = "^4.0" [tool.poetry.group.typing.dependencies] diff --git a/tests/test_llm.py b/tests/test_llm.py index 2e59796..d40acfd 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -1,4 +1,3 @@ -import asyncio import os import pytest @@ -41,63 +40,63 @@ ) -def test_llm_none() -> None: +async def test_llm_none() -> None: t = Translator(model=util.OPENAI_MODEL, api_key=util.OPENAI_API_KEY, base_url=util.OPENAI_BASE_URL) print(t.system_prompt) - res = asyncio.run(t.ask(ORIGIN(origin=""))) + res = await t.ask(ORIGIN(origin="")) assert res.zh == "" @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") -def test_llm() -> None: +async def test_llm() -> None: t = Translator( model=util.OPENAI_MODEL, api_key=util.OPENAI_API_KEY, base_url=util.OPENAI_BASE_URL, ) print(t.system_prompt) - res = asyncio.run(t.ask(origin)) + res = await t.ask(origin) print(res.zh) @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") -def test_llm_bangumi() -> None: +async def test_llm_bangumi() -> None: t = Translator( model=util.OPENAI_MODEL, api_key=util.OPENAI_API_KEY, base_url=util.OPENAI_BASE_URL, - bangumi_info=asyncio.run(bangumi(util.BANGUMI_URL)), + bangumi_info=await bangumi(util.BANGUMI_URL), ) print(t.system_prompt) - res = asyncio.run(t.ask(origin)) + res = await t.ask(origin) print(res.zh) @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") -def test_llm_bangumi_2() -> None: +async def test_llm_bangumi_2() -> None: t = Translator( model=util.OPENAI_MODEL, api_key=util.OPENAI_API_KEY, base_url=util.OPENAI_BASE_URL, - bangumi_info=asyncio.run(bangumi(util.BANGUMI_URL)), + bangumi_info=await bangumi(util.BANGUMI_URL), ) print(t.system_prompt) s = ORIGIN( origin="♪ 星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と星と", ) - res = asyncio.run(t.ask(s)) + res = await t.ask(s) print(res.zh) @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") -def test_llm_summary() -> None: +async def test_llm_summary() -> None: t = Summarizer( model=util.OPENAI_MODEL, api_key=util.OPENAI_API_KEY, base_url=util.OPENAI_BASE_URL, - bangumi_info=asyncio.run(bangumi(util.BANGUMI_URL)), + bangumi_info=await bangumi(util.BANGUMI_URL), ) print(t.system_prompt) - res = asyncio.run(t.ask(summary_origin)) + res = await t.ask(summary_origin) print(res.zh) diff --git a/tests/test_sub.py b/tests/test_sub.py index 0f04f50..5a3c73a 100644 --- a/tests/test_sub.py +++ b/tests/test_sub.py @@ -7,8 +7,8 @@ from yuisub.sub import bilingual, load, translate -async def test_sub() -> None: - sub = await load(util.TEST_ENG_SRT) +def test_sub() -> None: + sub = load(util.TEST_ENG_SRT) sub.save(util.projectPATH / "assets" / "test.en.ass") @@ -19,14 +19,14 @@ def test_audio() -> None: sub.save(util.projectPATH / "assets" / "test.audio.ass") -def test_bilingual() -> None: +async def test_bilingual() -> None: sub = load(util.TEST_ENG_SRT) - _ = bilingual(sub, sub) + await bilingual(sub, sub) @pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI") async def test_bilingual_2() -> None: - sub = await load(util.TEST_ENG_SRT) + sub = load(util.TEST_ENG_SRT) sub_zh = await translate( sub=sub, diff --git a/yuisub/__main__.py b/yuisub/__main__.py index 173832c..a5c4ece 100644 --- a/yuisub/__main__.py +++ b/yuisub/__main__.py @@ -57,7 +57,7 @@ async def main() -> None: sub = model.transcribe(audio=args.AUDIO) else: - sub = await load(args.SUB) + sub = load(args.SUB) sub_zh = await translate( sub=sub, diff --git a/yuisub/bangumi.py b/yuisub/bangumi.py index 509a402..3942fd3 100644 --- a/yuisub/bangumi.py +++ b/yuisub/bangumi.py @@ -1,20 +1,12 @@ -# bangumi.py - import asyncio import re -import time -from dataclasses import dataclass from typing import Any, Dict, List, Optional import httpx from pydantic import BaseModel -# 使用信号量限制并发请求数 -SEMAPHORE_LIMIT = 32 - -@dataclass -class Character: +class Character(BaseModel): id: int name: str chinese_name: Optional[str] = None @@ -26,21 +18,38 @@ class BGM(BaseModel): async def extract_bangumi_id(url: str) -> Optional[str]: - """从Bangumi URL中提取番剧ID""" + """ + Extract bangumi ID from Bangumi URL + + :param url: Bangumi URL + :return: Bangumi ID + """ pattern = r"(?:https?://)?(?:www\.)?(?:bangumi\.tv|bgm\.tv)/subject/(\d+)" match = re.search(pattern, url) return match.group(1) if match else None def construct_api_url(bangumi_id: str) -> str: - """根据番剧ID构建API URL""" + """ + Construct API URL based on bangumi ID + + :param bangumi_id: Bangumi ID + :return: API URL + """ return f"https://api.bgm.tv/v0/subjects/{bangumi_id}" async def get_character_info( client: httpx.AsyncClient, character: Dict[str, Any], semaphore: asyncio.Semaphore ) -> Character: - """获取单个角色的详细信息""" + """ + Get detailed info of a character + + :param client: httpx.AsyncClient + :param character: Character data + :param semaphore: asyncio.Semaphore + :return: Character object + """ async with semaphore: char_id = character["id"] char_name = character["name"] @@ -62,7 +71,13 @@ async def get_character_info( async def fetch_bangumi_data(client: httpx.AsyncClient, url: str) -> tuple[str, List[Dict[str, Any]]]: - """获取番剧基本信息和角色列表""" + """ + Get base info and character list asynchronously + + :param client: httpx.AsyncClient + :param url: Bangumi URL + :return: Tuple of introduction and character list + """ bangumi_id = await extract_bangumi_id(url) if not bangumi_id: raise ValueError("Invalid bangumi URL") @@ -81,16 +96,15 @@ async def fetch_bangumi_data(client: httpx.AsyncClient, url: str) -> tuple[str, async def bangumi(url: Optional[str] = None) -> BGM: """ - 异步获取番剧信息和角色列表 + Get bangumi info and character list asynchronously - Args: - url: Bangumi URL - - Returns: - BGM object containing introduction and characters info + :param url: Bangumi URL + :return: BGM object """ print("Getting bangumi info...") + SEMAPHORE_LIMIT = 32 + if not url: print("Warning: bangumi url is empty") return BGM(introduction="", characters="") @@ -134,20 +148,3 @@ async def bangumi(url: Optional[str] = None) -> BGM: except Exception as e: print(f"Error fetching bangumi info: {e}") raise - - -async def main() -> None: - """ - Main function for testing bangumi functionality - """ - url = "https://bangumi.tv/subject/315574" - start_time = time.time() - result = await bangumi(url) - use_time = time.time() - start_time - print(f"Introduction:\n{result.introduction[:100]}...") - print(f"Characters:\n{result.characters}") - print(f"Use time: {use_time}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/yuisub/sub.py b/yuisub/sub.py index aed9390..66508d4 100644 --- a/yuisub/sub.py +++ b/yuisub/sub.py @@ -63,18 +63,16 @@ def advertisement(ad: Optional[str] = None, start: int = 0, end: int = 5000) -> return sub_ad -async def load(sub_path: Union[Path, str], encoding: str = "utf-8") -> SSAFile: +def load(sub_path: Union[Path, str], encoding: str = "utf-8") -> SSAFile: """ - 异步加载字幕文件 - Load subtitle from file path, default encoding is utf-8 and remove style :param sub_path: subtitle file path :param encoding: subtitle file encoding, default is utf-8 :return: """ - # 由于pysubs2.load本身是同步的, 我们使用线程池来避免阻塞 - return await asyncio.to_thread(pysubs2.load, str(sub_path), encoding=encoding) + sub = pysubs2.load(str(sub_path), encoding=encoding) + return sub @retry(wait=wait_random(min=3, max=5), stop=stop_after_attempt(5)) @@ -88,17 +86,6 @@ async def translate( ad: Optional[SSAEvent] = advertisement(), # noqa: B008 ) -> SSAFile: """ - 异步翻译字幕文件 - - :param sub: 原始字幕 - :param model: LLM模型 - :param api_key: API密钥 - :param base_url: API基础URL - :param bangumi_url: bangumi URL - :param styles: 字幕样式 - :param ad: 广告信息 - :return: 翻译后的字幕文件 - Translate subtitle file to Chinese :param sub: origin subtitle @@ -110,69 +97,64 @@ async def translate( :param ad: add advertisement to subtitle, default is TensoRaws :return: """ - try: - # 获取待翻译的文本列表 - trans_list: List[str] = [s.text for s in sub] - - # 异步获取bangumi信息 - bangumi_info = await bangumi(bangumi_url) if bangumi_url else None - - # 初始化总结器 - summarizer = Summarizer( - model=model, - api_key=api_key, - base_url=base_url, - bangumi_info=bangumi_info, - ) - - print("Summarizing...") - # 获取总结 - summary = await summarizer.ask(ORIGIN(origin="\n".join(trans_list))) - - # 初始化翻译器 - translator = Translator( - model=model, - api_key=api_key, - base_url=base_url, - bangumi_info=bangumi_info, - summary=summary.zh, - ) - print(translator.system_prompt) - - # 创建翻译任务 - async def translate_text(index: int) -> None: - nonlocal trans_list - translated_text = await translator.ask(ORIGIN(origin=trans_list[index])) - print(f"Translated: {trans_list[index]} ---> {translated_text.zh}") - trans_list[index] = translated_text.zh - - # 并发执行翻译任务 - tasks = [translate_text(i) for i in range(len(sub))] - await asyncio.gather(*tasks) - - # 生成中文字幕 - if styles is None: - styles = PRESET_STYLES - - sub_zh = SSAFile() - sub_zh.styles = styles - - # 添加广告 - if ad: - sub_zh.append(ad) - - # 复制并更新字幕 - sub_temp = deepcopy(sub) - for i, e in enumerate(sub_temp): - e.style = "zh" - e.text = trans_list[i] - sub_zh.append(e) - - return sub_zh - - except Exception as e: - print(f"Translation error: {e}") - raise + # pending translation + trans_list: List[str] = [s.text for s in sub] + + # get bangumi info asynchronously + bangumi_info = await bangumi(bangumi_url) if bangumi_url else None + + # initialize summarizer + summarizer = Summarizer( + model=model, + api_key=api_key, + base_url=base_url, + bangumi_info=bangumi_info, + ) + + print("Summarizing...") + # get summary + summary = await summarizer.ask(ORIGIN(origin="\n".join(trans_list))) + + # initialize translator + translator = Translator( + model=model, + api_key=api_key, + base_url=base_url, + bangumi_info=bangumi_info, + summary=summary.zh, + ) + print(translator.system_prompt) + + # create translate text task + async def _translate(index: int) -> None: + nonlocal trans_list + translated_text = await translator.ask(ORIGIN(origin=trans_list[index])) + print(f"Translated: {trans_list[index]} ---> {translated_text.zh}") + trans_list[index] = translated_text.zh + + # start translation tasks + tasks = [_translate(i) for i in range(len(sub))] + await asyncio.gather(*tasks) + + # gen Chinese subtitle + if styles is None: + styles = PRESET_STYLES + + sub_zh = SSAFile() + sub_zh.styles = styles + + # add advertisement + if ad: + sub_zh.append(ad) + + # copy origin subtitle and replace text with translated text + sub_temp = deepcopy(sub) + for i, e in enumerate(sub_temp): + e.style = "zh" + e.text = trans_list[i] + sub_zh.append(e) + + return sub_zh async def bilingual( @@ -181,9 +163,7 @@ async def bilingual( styles: Optional[Dict[str, SSAStyle]] = None, ) -> SSAFile: """ - 异步生成双语字幕 - - Generate bilingual subtitle file + Generate bilingual subtitle file asynchronously :param sub_origin: Origin subtitle :param sub_zh: Chinese subtitle @@ -207,24 +187,3 @@ async def bilingual( sub_bilingual.append(e) return sub_bilingual - - -# 使用示例 -async def main() -> None: - # 加载字幕 - sub = await load("path/to/subtitle.ass") - - # 翻译字幕 - translated_sub = await translate( - sub=sub, model="your-model", api_key="your-api-key", base_url="your-base-url", bangumi_url="your-bangumi-url" - ) - - # 生成双语字幕 - bilingual_sub = await bilingual(sub, translated_sub) - - # 保存字幕 - await asyncio.to_thread(bilingual_sub.save, "output.ass") - - -if __name__ == "__main__": - asyncio.run(main())