Skip to content

Commit

Permalink
Fix bugs for Whisper models (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 authored Sep 27, 2024
1 parent 4870780 commit 5544966
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 48 deletions.
4 changes: 4 additions & 0 deletions go/bin/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ func (a *binArgs) getInput() string {
return a.inputFile
}

func (a *binArgs) tryGetInput() string {
return a.inputFile
}

func (a *binArgs) addOutputFlag() {
a.fs.StringVar(&a.outputFile, "o", "", "the output file.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ package main

import (
"fmt"
"log"
"log/slog"
"strings"

"github.com/asticode/go-astisub"
"github.com/ling0322/libllm/go/llm"
"github.com/ling0322/libllm/go/skill"
)
Expand Down Expand Up @@ -82,7 +84,7 @@ func (t *transcriptionTranslator) translateOneWithRetry(text string) (translatio
func (t *transcriptionTranslator) translateOne(text string) (translationResult, error) {
tr, err := t.translateOneWithRetry(text)
if err != nil {
return translationResult{}, nil
return translationResult{}, err
}

if tr.tgtText == "" {
Expand Down Expand Up @@ -138,3 +140,56 @@ func newTranscripotionTranslator(
tgtLang: tgtLang,
}, nil
}

// read a subtitle file to
func ReadSubtitleFile(filename string) ([]TxResult, error) {
subtitles, err := astisub.OpenFile(filename)
if err != nil {
return nil, err
}

txs := []TxResult{}
for _, item := range subtitles.Items {
lines := []string{}
for _, line := range item.Lines {
lines = append(lines, line.String())
}
tx := TxResult{
Begin: item.StartAt,
End: item.EndAt,
Text: strings.Join(lines, " "),
}

txs = append(txs, tx)
}

return txs, nil
}

func TranslateSubtitle(config translationConfig, inputFile, outputFile string) error {
tt, err := newTranscripotionTranslator(
config.modelName,
config.device,
config.srcLang,
config.tgtLang)
if err != nil {
return err
}

txs, err := ReadSubtitleFile(inputFile)
if err != nil {
return err
}

txs, err = tt.translate(txs)
if err != nil {
return err
}

err = saveTranscription(txs, outputFile)
if err != nil {
log.Fatal(err)
}

return nil
}
34 changes: 10 additions & 24 deletions go/bin/transcribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,29 +104,6 @@ func saveTranscription(transcriptions []TxResult, filename string) error {
return nil
}

func translateTranscription(txs []TxResult, config translationConfig, filename string) error {
tt, err := newTranscripotionTranslator(
config.modelName,
config.device,
config.srcLang,
config.tgtLang)
if err != nil {
return err
}

txs, err = tt.translate(txs)
if err != nil {
return err
}

err = saveTranscription(txs, filename)
if err != nil {
log.Fatal(err)
}

return nil
}

func getOutputFile(ba *binArgs) string {
outputFile := ba.tryGetOutput()
if outputFile != "" {
Expand Down Expand Up @@ -239,14 +216,23 @@ func transcribeMain(args []string) {

srcLang := getTranscriptionLang(transcriptions)
if srcLang != skill.UnknownLanguage && tgtLang != skill.UnknownLanguage && srcLang != tgtLang {
// save transcription (without translation)
fileExt := filepath.Ext(outputFile)
fileBaseName := outputFile[:len(outputFile)-len(fileExt)]
outputTxFile := fmt.Sprintf("%s.%s%s", fileBaseName, srcLang.String(), fileExt)
err = saveTranscription(transcriptions, outputTxFile)
if err != nil {
log.Fatal(err)
}

// translation
config := translationConfig{
srcLang,
tgtLang,
getTranslationModel(ba),
device,
}
err := translateTranscription(transcriptions, config, outputFile)
err := TranslateSubtitle(config, outputTxFile, outputFile)
if err != nil {
log.Fatal(err)
}
Expand Down
73 changes: 55 additions & 18 deletions go/bin/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"io"
"log"
"os"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -86,24 +87,7 @@ func translate(translator skill.Translator, req skill.TranslationRequest, onToke
}, nil
}

func translationMain(args []string) {
fs := flag.NewFlagSet("", flag.ExitOnError)
fs.Usage = func() {
printTranslateUsage(fs)
}

ba := newBinArgs(fs)
ba.addModelFlag()
ba.addDeviceFlag()
ba.addLangFlag()
ba.addTargetLangFlag()
_ = fs.Parse(args)

if fs.NArg() != 0 {
fs.Usage()
os.Exit(1)
}

func interactiveTranslation(ba *binArgs) {
modelName := ba.getModel()
model, err := createModelAutoDownload(modelName, ba.getDevice())
if err != nil {
Expand Down Expand Up @@ -150,3 +134,56 @@ func translationMain(args []string) {
)
}
}

func isSubtitleFile(filename string) bool {
if filepath.Ext(filename) == ".srt" {
return true
} else {
return false
}
}

func translateSubtitleFile(ba *binArgs) {
// translation
config := translationConfig{
ba.getLang(),
ba.getTargetLang(),
getTranslationModel(ba),
ba.getDevice(),
}
err := TranslateSubtitle(config, ba.getInput(), ba.getOutput())
if err != nil {
log.Fatal(err)
}
}

func translationMain(args []string) {
fs := flag.NewFlagSet("", flag.ExitOnError)
fs.Usage = func() {
printTranslateUsage(fs)
}

ba := newBinArgs(fs)
ba.addModelFlag()
ba.addDeviceFlag()
ba.addLangFlag()
ba.addTargetLangFlag()
ba.addInputFlag()
ba.addOutputFlag()
_ = fs.Parse(args)

if fs.NArg() != 0 {
fs.Usage()
os.Exit(1)
}

if ba.tryGetInput() != "" {
if isSubtitleFile(ba.getInput()) {
translateSubtitleFile(ba)
} else {
log.Fatal("file not supported.")
}
} else {
interactiveTranslation(ba)
}
}
9 changes: 9 additions & 0 deletions go/skill/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ func NewWaveStream(filename string) (*WaveStream, error) {
}, nil
}

func (w *WaveChunk) Duration() time.Duration {
return w.end - w.begin
}

func bytesToDuration(numBytes int) time.Duration {
return time.Duration(int64(numBytes) * int64(time.Second) / 2 / 16000)
}

func durationToBytes(dur time.Duration) int {
nsPerSample := 1000000000 / SampleRate
nSamples := int(dur.Nanoseconds() / int64(nsPerSample))
Expand Down Expand Up @@ -104,6 +112,7 @@ func (s *WaveStream) ReadChunk(length time.Duration) (WaveChunk, error) {
eof := false
if errors.Is(err, io.EOF) {
eof = true
length = bytesToDuration(len(s.buffer))
if len(s.buffer) == 0 {
return WaveChunk{}, io.EOF
}
Expand Down
1 change: 1 addition & 0 deletions go/skill/bilibili_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func (l *BilibiliIndex) Build(history []Message) (llm.Prompt, error) {
if len(history) > 0 && history[0].Role == "system" {
prompt.AppendControlToken("<unk>")
prompt.AppendText(history[0].Content)
history = history[1:]
}

for _, message := range history {
Expand Down
4 changes: 2 additions & 2 deletions go/skill/translation.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func NewTranslator(model llm.Model) (Translator, error) {
}
}

var sysPromptIndexTranslation = "翻译%s到%s,不要换行,回复请以\"翻译结果:\"开头。"
var sysPromptIndexTranslation = "翻译%s到%s,不能有换行符,回复请以\"翻译结果:\"开头。"

var translationExamples = []map[Lang]string{
{
Expand Down Expand Up @@ -110,7 +110,7 @@ func (l *chatTranslator) getLangString(lang Lang) (name string, err error) {
case Chinese:
return "中文", nil
case English:
return "英文", nil
return "英语", nil
case Japanese:
return "日语", nil
default:
Expand Down
Loading

0 comments on commit 5544966

Please sign in to comment.