Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: class-based design #15

Merged
merged 20 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/CI-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
CI:
strategy:
matrix:
os-version: ["ubuntu-20.04", "macos-13", "windows-latest"]
os-version: ["ubuntu-20.04", "windows-latest", "macos-13"]
python-version: ["3.9"]
poetry-version: ["1.8.3"]

Expand All @@ -48,7 +48,7 @@ jobs:
- name: Test
run: |
pip install numpy==1.26.4
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper requests beautifulsoup4 tenacity pysubs2
pip install pre-commit pytest mypy ruff types-requests pytest-cov pytest-asyncio coverage pydantic openai openai-whisper httpx tenacity pysubs2

make lint
make test
Expand Down
41 changes: 15 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,36 +40,25 @@ yuisub -h # Displays help message
```python3
import asyncio

from yuisub import translate, bilingual, load
from yuisub.a2t import WhisperModel
from yuisub import SubtitleTranslator

# use an asynchronous environment
# Using an asynchronous environment
async def main() -> None:

# sub from audio
model = WhisperModel(name="medium", device="cuda")
sub = model.transcribe(audio="path/to/audio.mp3")

# sub from file
# sub = load("path/to/input.srt")

# generate bilingual subtitle
sub_zh = await translate(
sub=sub,
model="gpt_model_name",
api_key="your_openai_api_key",
base_url="api_url",
bangumi_url="https://bangumi.tv/subject/424883/"
)

sub_bilingual = await bilingual(
sub_origin=sub,
sub_zh=sub_zh
translator = SubtitleTranslator(
# if you want to use audio input
# torch_device='cuda',
# whisper_model='medium',

model='gpt_model_name',
api_key='your_openai_api_key',
base_url='api_url',
bangumi_url='https://bangumi.tv/subject/424883/',
bangumi_access_token='your_bangumi_token',
)

# save the ASS files
sub_zh.save("path/to/output.zh.ass")
sub_bilingual.save("path/to/output.bilingual.ass")
sub_zh, sub_bilingual = await translator.get_subtitles(sub='path/to/sub.srt') # Or audio='path/to/audio.mp3',
sub_zh.save('path/to/output_zh.ass')
sub_bilingual.save('path/to/output_bilingual.ass')

asyncio.run(main())
```
Expand Down
4 changes: 3 additions & 1 deletion tests/test_bangumi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from yuisub import bangumi

from . import util


async def test_bangumi() -> None:
url_list = [
Expand All @@ -9,6 +11,6 @@ async def test_bangumi() -> None:
]

for url in url_list:
r = await bangumi(url)
r = await bangumi(url=url, token=util.BANGUMI_ACCESS_TOKEN)
print(r.introduction)
print(r.characters)
9 changes: 5 additions & 4 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest

from tests import util
from yuisub import ORIGIN, Summarizer, Translator, bangumi

from . import util

origin = ORIGIN(
origin="何だよ…けっこう多いじゃねぇか",
)
Expand Down Expand Up @@ -65,7 +66,7 @@ async def test_llm_bangumi() -> None:
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=await bangumi(util.BANGUMI_URL),
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
)
print(t.system_prompt)
res = await t.ask(origin)
Expand All @@ -78,7 +79,7 @@ async def test_llm_bangumi_2() -> None:
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=await bangumi(util.BANGUMI_URL),
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
)
print(t.system_prompt)
s = ORIGIN(
Expand All @@ -95,7 +96,7 @@ async def test_llm_summary() -> None:
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_info=await bangumi(util.BANGUMI_URL),
bangumi_info=await bangumi(url=util.BANGUMI_URL, token=util.BANGUMI_ACCESS_TOKEN),
)
print(t.system_prompt)
res = await t.ask(summary_origin)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import pytest

from tests import util
from yuisub.a2t import WhisperModel
from yuisub.sub import bilingual, load, translate

from . import util


def test_sub() -> None:
sub = load(util.TEST_ENG_SRT)
Expand Down Expand Up @@ -34,6 +35,7 @@ async def test_bilingual_2() -> None:
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_url=util.BANGUMI_URL,
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
)
sub_bilingual = await bilingual(sub_origin=sub, sub_zh=sub_zh)

Expand Down
39 changes: 39 additions & 0 deletions tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os

import pytest

from yuisub.translator import SubtitleTranslator

from . import util


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
async def test_translator_sub() -> None:
translator = SubtitleTranslator(
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_url=util.BANGUMI_URL,
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
)

Comment on lines +10 to +19
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider using mocks for CI environment instead of skipping tests

Rather than skipping these tests in CI, consider mocking the external dependencies (Whisper model, OpenAI API) to allow these tests to run in all environments. This would provide better test coverage and catch potential issues earlier.

@pytest.mark.asyncio
@mock.patch('your_module.SubtitleTranslator.get_subtitles')
async def test_translator_sub(mock_get_subtitles) -> None:
    mock_get_subtitles.return_value = (Mock(), Mock())
    translator = SubtitleTranslator(
        model=util.OPENAI_MODEL,
        api_key="mock_key",
        base_url="mock_url",
        bangumi_url="mock_url",
        bangumi_access_token="mock_token"
    )
    await translator.get_subtitles(sub=str(util.TEST_ENG_SRT))

sub_zh, sub_bilingual = await translator.get_subtitles(sub=str(util.TEST_ENG_SRT))
Comment on lines +11 to +20
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test should verify the content of the translated subtitles

The test only checks if the files are saved but doesn't verify the actual content of the translations. Consider adding assertions to check the translated text, timing, and format of both sub_zh and sub_bilingual.

async def test_translator_sub() -> None:
    translator = SubtitleTranslator(
        model=util.OPENAI_MODEL,
        api_key=util.OPENAI_API_KEY,
        base_url=util.OPENAI_BASE_URL,
        bangumi_url=util.BANGUMI_URL,
        bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
    )
    sub_zh, sub_bilingual = await translator.get_subtitles(sub=str(util.TEST_ENG_SRT))
    assert "你好" in str(sub_zh)
    assert "Hello" in str(sub_bilingual) and "你好" in str(sub_bilingual)

sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.sub.ass")
sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.sub.ass")


@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="Skipping test when running on CI")
async def test_translator_audio() -> None:
translator = SubtitleTranslator(
torch_device=util.DEVICE,
whisper_model=util.MODEL_NAME,
model=util.OPENAI_MODEL,
api_key=util.OPENAI_API_KEY,
base_url=util.OPENAI_BASE_URL,
bangumi_url=util.BANGUMI_URL,
bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
)
Comment on lines +26 to +35
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add error case tests for the SubtitleTranslator

The tests only cover the happy path. Consider adding tests for error cases such as invalid audio files, network errors, invalid API keys, and other edge cases that could occur during translation.

async def test_translator_audio() -> None:
    translator = SubtitleTranslator(
        torch_device=util.DEVICE,
        whisper_model=util.MODEL_NAME,
        model=util.OPENAI_MODEL,
        api_key=util.OPENAI_API_KEY,
        base_url=util.OPENAI_BASE_URL,
        bangumi_url=util.BANGUMI_URL,
        bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
    )

    sub_zh, sub_bilingual = await translator.get_subtitles(audio=str(util.TEST_AUDIO))
    sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.audio.ass")
    sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.audio.ass")

    with pytest.raises(FileNotFoundError):
        await translator.get_subtitles(audio="nonexistent_file.mp3")

    with pytest.raises(Exception):
        invalid_translator = SubtitleTranslator(
            torch_device=util.DEVICE,
            whisper_model=util.MODEL_NAME,
            model=util.OPENAI_MODEL,
            api_key="invalid_key",
            base_url=util.OPENAI_BASE_URL,
            bangumi_url=util.BANGUMI_URL,
            bangumi_access_token=util.BANGUMI_ACCESS_TOKEN,
        )
        await invalid_translator.get_subtitles(audio=str(util.TEST_AUDIO))


sub_zh, sub_bilingual = await translator.get_subtitles(audio=str(util.TEST_AUDIO))
sub_zh.save(util.projectPATH / "assets" / "test.zh.translator.audio.ass")
sub_bilingual.save(util.projectPATH / "assets" / "test.bilingual.translator.audio.ass")
5 changes: 2 additions & 3 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import os
from pathlib import Path

import torch

projectPATH = Path(__file__).resolve().parent.parent.absolute()

TEST_AUDIO = projectPATH / "assets" / "test.mp3"
TEST_ENG_SRT = projectPATH / "assets" / "eng.srt"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu" if os.environ.get("GITHUB_ACTIONS") == "true" else None
MODEL_NAME = "medium" if DEVICE == "cuda" else "tiny"

BANGUMI_URL = "https://bangumi.tv/subject/424883"
BANGUMI_ACCESS_TOKEN = ""

OPENAI_MODEL = str(os.getenv("OPENAI_MODEL")) if os.getenv("OPENAI_MODEL") else "deepseek-chat"
OPENAI_BASE_URL = str(os.getenv("OPENAI_BASE_URL")) if os.getenv("OPENAI_BASE_URL") else "https://api.deepseek.com"
Expand Down
1 change: 1 addition & 0 deletions yuisub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from yuisub.llm import Summarizer, Translator # noqa: F401
from yuisub.prompt import ORIGIN, ZH # noqa: F401
from yuisub.sub import advertisement, bilingual, load, translate # noqa: F401
from yuisub.translator import SubtitleTranslator # noqa: F401
57 changes: 17 additions & 40 deletions yuisub/__main__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import argparse
import asyncio
import sys

from yuisub.sub import bilingual, load, translate
from yuisub import SubtitleTranslator

# ffmpeg -i test.mkv -c:a mp3 -map 0:a:0 test.mp3
# ffmpeg -i test.mkv -map 0:s:0 eng.srt
parser = argparse.ArgumentParser(description="Generate Bilingual Subtitle from audio or subtitle file")

parser = argparse.ArgumentParser()
parser.description = "Generate Bilingual Subtitle from audio or subtitle file"
# input
# Input
parser.add_argument("-a", "--AUDIO", type=str, help="Path to the audio file", required=False)
parser.add_argument("-s", "--SUB", type=str, help="Path to the input Subtitle file", required=False)
# subtitle output
# Output
parser.add_argument("-oz", "--OUTPUT_ZH", type=str, help="Path to save the Chinese ASS file", required=False)
parser.add_argument("-ob", "--OUTPUT_BILINGUAL", type=str, help="Path to save the bilingual ASS file", required=False)
# openai gpt
# OpenAI GPT
parser.add_argument("-om", "--OPENAI_MODEL", type=str, help="Openai model name", required=True)
parser.add_argument("-api", "--OPENAI_API_KEY", type=str, help="Openai API key", required=True)
parser.add_argument("-url", "--OPENAI_BASE_URL", type=str, help="Openai base URL", required=True)
# bangumi
# Bangumi
parser.add_argument("-bgm", "--BANGUMI_URL", type=str, help="Anime Bangumi URL", required=False)
parser.add_argument("-ac", "--BANGUMI_ACCESS_TOKEN", type=str, help="Anime Bangumi Access Token", required=False)
# whisper
# Whisper
parser.add_argument("-d", "--TORCH_DEVICE", type=str, help="Pytorch device to use", required=False)
parser.add_argument("-wm", "--WHISPER_MODEL", type=str, help="Whisper model to use", required=False)

Expand All @@ -33,47 +29,28 @@ async def main() -> None:
if args.AUDIO and args.SUB:
raise ValueError("Please provide only one input file, either audio or subtitle file")

if not args.AUDIO and not args.SUB:
raise ValueError("Please provide an input file, either audio or subtitle file")

if not args.OUTPUT_ZH and not args.OUTPUT_BILINGUAL:
raise ValueError("Please provide output paths for the subtitles.")

if args.AUDIO:
import torch

from yuisub.a2t import WhisperModel

if args.TORCH_DEVICE:
_DEVICE = args.TORCH_DEVICE
else:
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if sys.platform == "darwin":
_DEVICE = "mps"

if args.WHISPER_MODEL:
_MODEL = args.WHISPER_MODEL
else:
_MODEL = "medium" if _DEVICE == "cpu" else "large-v2"

model = WhisperModel(name=_MODEL, device=_DEVICE)

sub = model.transcribe(audio=args.AUDIO)

else:
sub = load(args.SUB)

sub_zh = await translate(
sub=sub,
translator = SubtitleTranslator(
model=args.OPENAI_MODEL,
api_key=args.OPENAI_API_KEY,
base_url=args.OPENAI_BASE_URL,
bangumi_url=args.BANGUMI_URL,
bangumi_access_token=args.BANGUMI_ACCESS_TOKEN,
torch_device=args.TORCH_DEVICE,
whisper_model=args.WHISPER_MODEL,
)

sub_bilingual = await bilingual(sub_origin=sub, sub_zh=sub_zh)

sub_zh, sub_bilingual = await translator.get_subtitles(
sub=args.SUB,
audio=args.AUDIO,
)
if args.OUTPUT_ZH:
sub_zh.save(args.OUTPUT_ZH)

if args.OUTPUT_BILINGUAL:
sub_bilingual.save(args.OUTPUT_BILINGUAL)

Expand Down
6 changes: 5 additions & 1 deletion yuisub/a2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

class WhisperModel:
def __init__(
self, name: str = "medium", device: str = "cuda", download_root: Optional[str] = None, in_memory: bool = False
self,
name: str = "medium",
device: Optional[Union[str, torch.device]] = None,
download_root: Optional[str] = None,
in_memory: bool = False,
):
self.model = whisper.load_model(name=name, device=device, download_root=download_root, in_memory=in_memory)

Expand Down
2 changes: 1 addition & 1 deletion yuisub/sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ async def translate(
base_url=base_url,
bangumi_info=bangumi_info,
)
print(summarizer.system_prompt)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Remove or replace debug print statement

Consider using a proper logging system instead of print statements if this information is important for debugging.

import logging

logging.debug(summarizer.system_prompt)


print("Summarizing...")
# get summary
summary = await summarizer.ask(ORIGIN(origin="\n".join(trans_list)))

Expand Down
Loading