From 55449665b682b2a16b5ee126175348ecf705094f Mon Sep 17 00:00:00 2001 From: Xiaoyang Chen Date: Sat, 28 Sep 2024 00:34:08 +0800 Subject: [PATCH] Fix bugs for Whisper models (#99) --- go/bin/args.go | 4 + ...n_translator.go => subtitle_translator.go} | 57 ++++++++++++++- go/bin/transcribe.go | 34 +++------ go/bin/translate.go | 73 ++++++++++++++----- go/skill/audio.go | 9 +++ go/skill/bilibili_index.go | 1 + go/skill/translation.go | 4 +- go/skill/whisper.go | 55 ++++++++++++++ src/libllm/whisper.cc | 7 +- 9 files changed, 196 insertions(+), 48 deletions(-) rename go/bin/{transcription_translator.go => subtitle_translator.go} (79%) diff --git a/go/bin/args.go b/go/bin/args.go index e151490..72cfe2e 100644 --- a/go/bin/args.go +++ b/go/bin/args.go @@ -132,6 +132,10 @@ func (a *binArgs) getInput() string { return a.inputFile } +func (a *binArgs) tryGetInput() string { + return a.inputFile +} + func (a *binArgs) addOutputFlag() { a.fs.StringVar(&a.outputFile, "o", "", "the output file.") } diff --git a/go/bin/transcription_translator.go b/go/bin/subtitle_translator.go similarity index 79% rename from go/bin/transcription_translator.go rename to go/bin/subtitle_translator.go index 1c5ccd0..630a6ac 100644 --- a/go/bin/transcription_translator.go +++ b/go/bin/subtitle_translator.go @@ -21,9 +21,11 @@ package main import ( "fmt" + "log" "log/slog" "strings" + "github.com/asticode/go-astisub" "github.com/ling0322/libllm/go/llm" "github.com/ling0322/libllm/go/skill" ) @@ -82,7 +84,7 @@ func (t *transcriptionTranslator) translateOneWithRetry(text string) (translatio func (t *transcriptionTranslator) translateOne(text string) (translationResult, error) { tr, err := t.translateOneWithRetry(text) if err != nil { - return translationResult{}, nil + return translationResult{}, err } if tr.tgtText == "" { @@ -138,3 +140,56 @@ func newTranscripotionTranslator( tgtLang: tgtLang, }, nil } + +// read a subtitle file to +func ReadSubtitleFile(filename string) ([]TxResult, error) { + subtitles, err := astisub.OpenFile(filename) + if err != nil { + return nil, err + } + + txs := []TxResult{} + for _, item := range subtitles.Items { + lines := []string{} + for _, line := range item.Lines { + lines = append(lines, line.String()) + } + tx := TxResult{ + Begin: item.StartAt, + End: item.EndAt, + Text: strings.Join(lines, " "), + } + + txs = append(txs, tx) + } + + return txs, nil +} + +func TranslateSubtitle(config translationConfig, inputFile, outputFile string) error { + tt, err := newTranscripotionTranslator( + config.modelName, + config.device, + config.srcLang, + config.tgtLang) + if err != nil { + return err + } + + txs, err := ReadSubtitleFile(inputFile) + if err != nil { + return err + } + + txs, err = tt.translate(txs) + if err != nil { + return err + } + + err = saveTranscription(txs, outputFile) + if err != nil { + log.Fatal(err) + } + + return nil +} diff --git a/go/bin/transcribe.go b/go/bin/transcribe.go index dd86db8..379d805 100644 --- a/go/bin/transcribe.go +++ b/go/bin/transcribe.go @@ -104,29 +104,6 @@ func saveTranscription(transcriptions []TxResult, filename string) error { return nil } -func translateTranscription(txs []TxResult, config translationConfig, filename string) error { - tt, err := newTranscripotionTranslator( - config.modelName, - config.device, - config.srcLang, - config.tgtLang) - if err != nil { - return err - } - - txs, err = tt.translate(txs) - if err != nil { - return err - } - - err = saveTranscription(txs, filename) - if err != nil { - log.Fatal(err) - } - - return nil -} - func getOutputFile(ba *binArgs) string { outputFile := ba.tryGetOutput() if outputFile != "" { @@ -239,6 +216,15 @@ func transcribeMain(args []string) { srcLang := getTranscriptionLang(transcriptions) if srcLang != skill.UnknownLanguage && tgtLang != skill.UnknownLanguage && srcLang != tgtLang { + // save transcription (without translation) + fileExt := filepath.Ext(outputFile) + fileBaseName := outputFile[:len(outputFile)-len(fileExt)] + outputTxFile := fmt.Sprintf("%s.%s%s", fileBaseName, srcLang.String(), fileExt) + err = saveTranscription(transcriptions, outputTxFile) + if err != nil { + log.Fatal(err) + } + // translation config := translationConfig{ srcLang, @@ -246,7 +232,7 @@ func transcribeMain(args []string) { getTranslationModel(ba), device, } - err := translateTranscription(transcriptions, config, outputFile) + err := TranslateSubtitle(config, outputTxFile, outputFile) if err != nil { log.Fatal(err) } diff --git a/go/bin/translate.go b/go/bin/translate.go index e14c974..98fde4d 100644 --- a/go/bin/translate.go +++ b/go/bin/translate.go @@ -27,6 +27,7 @@ import ( "io" "log" "os" + "path/filepath" "strings" "time" @@ -86,24 +87,7 @@ func translate(translator skill.Translator, req skill.TranslationRequest, onToke }, nil } -func translationMain(args []string) { - fs := flag.NewFlagSet("", flag.ExitOnError) - fs.Usage = func() { - printTranslateUsage(fs) - } - - ba := newBinArgs(fs) - ba.addModelFlag() - ba.addDeviceFlag() - ba.addLangFlag() - ba.addTargetLangFlag() - _ = fs.Parse(args) - - if fs.NArg() != 0 { - fs.Usage() - os.Exit(1) - } - +func interactiveTranslation(ba *binArgs) { modelName := ba.getModel() model, err := createModelAutoDownload(modelName, ba.getDevice()) if err != nil { @@ -150,3 +134,56 @@ func translationMain(args []string) { ) } } + +func isSubtitleFile(filename string) bool { + if filepath.Ext(filename) == ".srt" { + return true + } else { + return false + } +} + +func translateSubtitleFile(ba *binArgs) { + // translation + config := translationConfig{ + ba.getLang(), + ba.getTargetLang(), + getTranslationModel(ba), + ba.getDevice(), + } + err := TranslateSubtitle(config, ba.getInput(), ba.getOutput()) + if err != nil { + log.Fatal(err) + } +} + +func translationMain(args []string) { + fs := flag.NewFlagSet("", flag.ExitOnError) + fs.Usage = func() { + printTranslateUsage(fs) + } + + ba := newBinArgs(fs) + ba.addModelFlag() + ba.addDeviceFlag() + ba.addLangFlag() + ba.addTargetLangFlag() + ba.addInputFlag() + ba.addOutputFlag() + _ = fs.Parse(args) + + if fs.NArg() != 0 { + fs.Usage() + os.Exit(1) + } + + if ba.tryGetInput() != "" { + if isSubtitleFile(ba.getInput()) { + translateSubtitleFile(ba) + } else { + log.Fatal("file not supported.") + } + } else { + interactiveTranslation(ba) + } +} diff --git a/go/skill/audio.go b/go/skill/audio.go index bd9a1c7..216fc9d 100644 --- a/go/skill/audio.go +++ b/go/skill/audio.go @@ -53,6 +53,14 @@ func NewWaveStream(filename string) (*WaveStream, error) { }, nil } +func (w *WaveChunk) Duration() time.Duration { + return w.end - w.begin +} + +func bytesToDuration(numBytes int) time.Duration { + return time.Duration(int64(numBytes) * int64(time.Second) / 2 / 16000) +} + func durationToBytes(dur time.Duration) int { nsPerSample := 1000000000 / SampleRate nSamples := int(dur.Nanoseconds() / int64(nsPerSample)) @@ -104,6 +112,7 @@ func (s *WaveStream) ReadChunk(length time.Duration) (WaveChunk, error) { eof := false if errors.Is(err, io.EOF) { eof = true + length = bytesToDuration(len(s.buffer)) if len(s.buffer) == 0 { return WaveChunk{}, io.EOF } diff --git a/go/skill/bilibili_index.go b/go/skill/bilibili_index.go index b8762df..fed9204 100644 --- a/go/skill/bilibili_index.go +++ b/go/skill/bilibili_index.go @@ -33,6 +33,7 @@ func (l *BilibiliIndex) Build(history []Message) (llm.Prompt, error) { if len(history) > 0 && history[0].Role == "system" { prompt.AppendControlToken("") prompt.AppendText(history[0].Content) + history = history[1:] } for _, message := range history { diff --git a/go/skill/translation.go b/go/skill/translation.go index 258156a..960285d 100644 --- a/go/skill/translation.go +++ b/go/skill/translation.go @@ -63,7 +63,7 @@ func NewTranslator(model llm.Model) (Translator, error) { } } -var sysPromptIndexTranslation = "翻译%s到%s,不要换行,回复请以\"翻译结果:\"开头。" +var sysPromptIndexTranslation = "翻译%s到%s,不能有换行符,回复请以\"翻译结果:\"开头。" var translationExamples = []map[Lang]string{ { @@ -110,7 +110,7 @@ func (l *chatTranslator) getLangString(lang Lang) (name string, err error) { case Chinese: return "中文", nil case English: - return "英文", nil + return "英语", nil case Japanese: return "日语", nil default: diff --git a/go/skill/whisper.go b/go/skill/whisper.go index 3b08c1f..8635e82 100644 --- a/go/skill/whisper.go +++ b/go/skill/whisper.go @@ -20,6 +20,8 @@ package skill import ( + "bytes" + "compress/zlib" "errors" "fmt" "io" @@ -27,6 +29,7 @@ import ( "math" "regexp" "strconv" + "strings" "time" "github.com/ling0322/libllm/go/llm" @@ -68,6 +71,9 @@ type WhisperTranscriber struct { // the specified whisper language. Empty means let whisper to predict. language string + + // the transcription history. + history []string } // create a new instance of WhisperTranscriber from whisper model and stream of input file. @@ -136,6 +142,10 @@ func (w *WhisperTranscriber) decodeTranscription() (TranscriptionResult, error) return TranscriptionResult{}, fmt.Errorf("%w: not a time token", ErrInvalidWhisperSequence) } result.Begin = w.stream.Offset() + beginOffset + if w.chunk.Duration()-beginOffset < 200*time.Millisecond { + slog.Info("return ErrNoMoreResults", "begin", result.Begin, "audio_len", w.chunk.end) + return TranscriptionResult{}, ErrNoMoreResults + } transcriptionDone := false for w.comp.Next() { @@ -158,6 +168,7 @@ func (w *WhisperTranscriber) decodeTranscription() (TranscriptionResult, error) return TranscriptionResult{}, ErrNoMoreResults } + result.End = min(result.End, w.chunk.end) return result, nil } @@ -190,6 +201,9 @@ func (w *WhisperTranscriber) prefillNextAudioSegment() error { w.chunk, err = w.stream.ReadChunk(30 * time.Second) if errors.Is(err, io.EOF) { return ErrAudioEndOfStream + } else if w.chunk.Duration() < 500*time.Millisecond { + slog.Info("w.chunk.Duration() < 500*time.Millisecond") + return ErrAudioEndOfStream } else if err != nil { return err } @@ -271,10 +285,12 @@ func (w *WhisperTranscriber) Transcribe() bool { if w.comp == nil { w.err = w.prefillNextAudioSegment() if errors.Is(w.err, ErrNoMoreResults) { + slog.Info("segment end", "reason", "ErrNoMoreResultsPrefill") w.disposeCompAndSetToNil() w.streamOffset += 30 * time.Second continue } else if errors.Is(w.err, ErrAudioEndOfStream) { + slog.Info("segment end", "reason", "ErrAudioEndOfStream") w.disposeCompAndSetToNil() w.err = nil return false @@ -282,16 +298,19 @@ func (w *WhisperTranscriber) Transcribe() bool { return false } beginOfSegment = true + w.resetRepetitionChecker() } result, err := w.decodeTranscription() if errors.Is(err, ErrNoMoreResults) && beginOfSegment { // if no result for the whole audio segment, move forward to the next 30s segment. + slog.Info("segment end", "reason", "ErrNoMoreResults") w.disposeCompAndSetToNil() w.streamOffset += 30 * time.Second continue } else if errors.Is(err, ErrNoMoreResults) && !beginOfSegment { // move the wave offset to the end of last completed transcription. + slog.Info("segment end", "reason", "ErrNoMoreResults") w.disposeCompAndSetToNil() w.streamOffset = w.result.End continue @@ -300,11 +319,47 @@ func (w *WhisperTranscriber) Transcribe() bool { return false } + if w.checkRepetition(result) { + // once repetition happened, stop the current decoding process. + slog.Info("segment end", "reason", "RepetitionDetected") + w.disposeCompAndSetToNil() + w.streamOffset = w.result.End + continue + } + w.result = result return true } } +func (w *WhisperTranscriber) resetRepetitionChecker() { + w.history = []string{} +} + +func (w *WhisperTranscriber) checkRepetition(r TranscriptionResult) bool { + w.history = append(w.history, r.Text) + if len(w.history) > 200 { + w.history = w.history[1:] + } + + text := strings.Join(w.history, " ") + + // compress text + var b bytes.Buffer + + writer := zlib.NewWriter(&b) + writer.Write([]byte(text)) + writer.Close() + + compressionRatio := float32(len(text)) * 1.0 / float32(len(b.Bytes())) + if compressionRatio > 2.4 { + slog.Info("check repetition", "compression_ratio", compressionRatio) + return true + } else { + return false + } +} + // implements interface Transcriber. func (w *WhisperTranscriber) Result() TranscriptionResult { return w.result diff --git a/src/libllm/whisper.cc b/src/libllm/whisper.cc index 6ad8127..80ccc61 100644 --- a/src/libllm/whisper.cc +++ b/src/libllm/whisper.cc @@ -727,15 +727,16 @@ void WhisperLogitsProcessor::processLogits(Tensor logits) { if (lastWasTimestamp) { _lastTimeTokenIdx = static_cast(_history.size()); if (penultimateWasTimestamp) { - // do not mask the <|30.00|> timestamp tag - F::fill(logits.slice(-1, {_beginTimeToken, _endTimeToken}), -Inf); + F::fill(logits.slice(-1, {_beginTimeToken, _endTimeToken + 1}), -Inf); } else { F::fill(logits.slice(-1, {0, _eotToken + 1}), -Inf); } } if (_lastTimeToken > _beginTimeToken) { - F::fill(logits.slice(-1, {_beginTimeToken, _lastTimeToken + 1}), -Inf); + // do not mask the <|30.00|> timestamp tag + int endToken = std::min(_lastTimeToken + 1, _endTimeToken); + F::fill(logits.slice(-1, {_beginTimeToken, endToken}), -Inf); } Tensor probs = F::softmax(logits);