Skip to content

Commit

Permalink
Support specifying language for whisper model (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 authored Aug 27, 2024
1 parent 91aeb8b commit 5558b33
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 221 deletions.
5 changes: 5 additions & 0 deletions go/bin/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ func (a *binArgs) getLang() skill.Lang {
return lang
}

// getRawLang returns the language string stored in a.lang exactly as held,
// without converting it to a skill.Lang. An empty string is returned as-is.
// NOTE(review): a.lang is presumably filled by the language flag — confirm.
func (a *binArgs) getRawLang() string {
return a.lang
}

// addTargetLangFlag registers the string flag "targetlang" on a.fs. The parsed
// value is stored in a.targetLang; the default is the empty string.
func (a *binArgs) addTargetLangFlag() {
a.fs.StringVar(&a.targetLang, "targetlang", "", "the target language.")
}
Expand All @@ -180,6 +184,7 @@ func (a *binArgs) getTargetLang() skill.Lang {

lang, err := skill.ParseLang(a.targetLang)
if err != nil {
slog.Error("unsupported target language (-targetlang).")
return skill.UnknownLanguage
}

Expand Down
5 changes: 5 additions & 0 deletions go/bin/transcribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ func transcribeMain(args []string) {
ba.addModelFlag()
ba.addInputFlag()
ba.addOutputFlag()
ba.addLangFlag()
ba.addTargetLangFlag()
_ = fs.Parse(args)

Expand Down Expand Up @@ -211,6 +212,10 @@ func transcribeMain(args []string) {
log.Fatal(err)
}

if ba.getRawLang() != "" {
transcriber.SetLanguage(ba.getRawLang())
}

transcriptions := []skill.TranscriptionResult{}
for transcriber.Transcribe() {
r := transcriber.Result()
Expand Down
108 changes: 0 additions & 108 deletions go/skill/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,13 @@
package skill

import (
"bytes"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"runtime"
"sync"
"time"

"github.com/ling0322/libllm/go/ffmpegplugin"
)

// gFfmpegBin holds the ffmpeg executable name or path chosen by
// initAudioReader; it stays empty when no working ffmpeg was found.
var gFfmpegBin string

// gFfmpegPluginReady is set to true by initAudioReader when
// ffmpegplugin.Init() succeeds.
var gFfmpegPluginReady bool

// BlockSize is 60 * 16000 * 2 = 1,920,000.
// NOTE(review): presumably the byte count of 60 seconds of 16 kHz, 16-bit,
// mono PCM — confirm against the code that reads it.
var BlockSize = 60 * 16000 * 2

type WaveStream struct {
Expand Down Expand Up @@ -134,100 +123,3 @@ func (s *WaveStream) ReadChunk(length time.Duration) (WaveChunk, error) {
// Close closes the stream's underlying reader and returns whatever error that
// reader's Close reports.
func (s *WaveStream) Close() error {
return s.reader.Close()
}

// initAudioReader performs the one-time discovery of audio decoding backends.
// It is wrapped in sync.OnceFunc, so repeated calls run the body only once.
//
// It first tries to initialise the ffmpeg plugin (recording success in
// gFfmpegPluginReady), then always looks for an ffmpeg executable as well
// (recording the result in gFfmpegBin). Failures are logged as warnings and
// never abort start-up.
var initAudioReader = sync.OnceFunc(func() {
	if err := ffmpegplugin.Init(); err == nil {
		gFfmpegPluginReady = true
	} else {
		slog.Warn(fmt.Sprintf("load ffmpeg plugin failed: %s", err))
	}

	// The binary lookup runs regardless of whether the plugin loaded.
	if gFfmpegBin = getFfmpegBinInternal(); gFfmpegBin == "" {
		slog.Warn("unable to find ffmpeg")
	}
})

// convertToPcmPlugin opens inputFile through the ffmpeg plugin reader and
// returns everything that reader yields, held fully in memory. The reader is
// closed before returning. Any error from opening or reading is returned
// unchanged.
func convertToPcmPlugin(inputFile string) ([]byte, error) {
	pluginReader, openErr := ffmpegplugin.NewReader(inputFile)
	if openErr != nil {
		return nil, openErr
	}
	defer pluginReader.Close()

	pcm, readErr := io.ReadAll(pluginReader)
	return pcm, readErr
}

// getFfmpegBinInternal locates a working ffmpeg executable.
//
// Lookup order:
//  1. "ffmpeg" ("ffmpeg.exe" on Windows) invoked by bare name; if running it
//     with "-version" succeeds, the bare name is returned.
//  2. A file with that name located next to the current executable; it must
//     exist and also run "-version" successfully, in which case its full path
//     is returned.
//
// An empty string is returned when neither candidate works.
func getFfmpegBinInternal() string {
	name := "ffmpeg"
	if runtime.GOOS == "windows" {
		name += ".exe"
	}

	// runs reports whether `<bin> -version` exits without error.
	runs := func(bin string) bool {
		return exec.Command(bin, "-version").Run() == nil
	}

	if runs(name) {
		return name
	}

	self, err := os.Executable()
	if err != nil {
		return ""
	}

	// Fall back to a copy shipped in the same directory as this program.
	candidate := filepath.Join(filepath.Dir(self), name)
	if _, statErr := os.Stat(candidate); statErr != nil {
		return ""
	}
	if !runs(candidate) {
		return ""
	}

	return candidate
}

// convertToPcmBin decodes inputFile by running the ffmpeg executable recorded
// in gFfmpegBin and returns the bytes ffmpeg writes to stdout.
//
// ffmpeg is asked for raw s16le output, one channel, 16000 Hz, with video
// disabled. An error is returned when gFfmpegBin is empty or when the ffmpeg
// process fails; in the latter case its stderr is logged first.
func convertToPcmBin(inputFile string) ([]byte, error) {
	if gFfmpegBin == "" {
		return nil, errors.New("unable to find ffmpeg")
	}

	args := []string{
		"-hide_banner", "-nostdin", "-vn", "-threads", "0", "-i", inputFile,
		"-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", "16000", "-",
	}

	// Capture decoded audio from stdout and diagnostics from stderr.
	var stdout, stderr bytes.Buffer
	cmd := exec.Command(gFfmpegBin, args...)
	cmd.Stdout, cmd.Stderr = &stdout, &stderr

	if err := cmd.Run(); err != nil {
		slog.Error("ffmpeg failed", "stderr", stderr.String())
		return nil, err
	}

	return stdout.Bytes(), nil
}

// ReadAudioFromMediaFile reads the audio of inputFile and returns it as bytes
// of 16KHz, 16bit, mono-channel PCM.
//
// Backends are initialised on first use. The ffmpeg plugin is preferred when
// it loaded successfully; otherwise the ffmpeg executable is used. An error is
// returned when neither backend is available.
func ReadAudioFromMediaFile(inputFile string) ([]byte, error) {
	initAudioReader()

	switch {
	case gFfmpegPluginReady:
		return convertToPcmPlugin(inputFile)
	case gFfmpegBin != "":
		return convertToPcmBin(inputFile)
	default:
		return nil, errors.New("unable to read media file since neither ffmpeg binary nor plugin found")
	}
}
113 changes: 4 additions & 109 deletions go/skill/bilibili_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,14 @@
package skill

import (
"fmt"
"math/rand"
"errors"

"github.com/ling0322/libllm/go/llm"
)

// BilibiliIndex is a field-less type; its Build method turns a message history
// into an llm.Prompt.
type BilibiliIndex struct {
}

// indexTranslator is a translator implemented on top of the bilibili index
// model. It holds the llm.Model that Translate hands to NewChat.
type indexTranslator struct {
model llm.Model
}

// sysPromptIndexTranslation is the fmt template for the translation system
// prompt. Its two %s verbs are filled with the source and target language
// names, in that order. In English it reads roughly: "Translate %s to %s, do
// not break lines, and begin the reply with "翻译结果:"."
var sysPromptIndexTranslation = "翻译%s到%s,不要换行,回复请以\"翻译结果:\"开头。"

// translationExamples is a list of sample sentences, each given in English,
// Chinese and Japanese and keyed by Lang. Translate picks one entry at random
// and uses it as the default left context when the request supplies none.
var translationExamples = []map[Lang]string{
{
English: "Today is Sunday.",
Chinese: "今天是星期日。",
Japanese: "今日は日曜日です。",
},
{
English: "Hello.",
Chinese: "你好。",
Japanese: "ごきげんよう。",
},
{
English: "Hello, World.",
Chinese: "你好,世界。",
Japanese: "こんにちは世界。",
},
}

func (l *BilibiliIndex) Build(history []Message) (llm.Prompt, error) {
prompt := llm.NewPrompt()
if len(history) > 0 && history[0].Role == "system" {
Expand All @@ -66,91 +40,12 @@ func (l *BilibiliIndex) Build(history []Message) (llm.Prompt, error) {
prompt.AppendControlToken("<|reserved_0|>")
prompt.AppendText(message.Content)
prompt.AppendControlToken("<|reserved_1|>")
} else if message.Role == "assistent" {
} else if message.Role == "assistant" {
prompt.AppendText(message.Content)
} else {
return nil, errors.New("unexpected role")
}
}

return prompt, nil
}

// IsSupport reports whether this translator can translate from source to
// target. Supported sources are English and Japanese; the only supported
// target is Chinese.
//
// The previous implementation assigned the result of the target check to
// sourceOk, leaving targetOk permanently false, so the method returned false
// for every language pair. Each check now writes to its own variable.
func (l *indexTranslator) IsSupport(source, target Lang) bool {
	var sourceOk, targetOk bool

	switch source {
	case English, Japanese:
		sourceOk = true
	}

	switch target {
	case Chinese:
		targetOk = true
	}

	return sourceOk && targetOk
}

// getLangString maps a Lang to the Chinese name used inside the system prompt:
// Chinese -> "中文", English -> "英文", Japanese -> "日语". Any other language
// yields an empty name together with ErrUnexpectedLanguage.
func (l *indexTranslator) getLangString(lang Lang) (name string, err error) {
	switch lang {
	case Chinese:
		name = "中文"
	case English:
		name = "英文"
	case Japanese:
		name = "日语"
	default:
		err = ErrUnexpectedLanguage
	}

	return name, err
}

// getSysPrompt builds the system prompt for translating source into target by
// filling sysPromptIndexTranslation with the two language names. If either
// language has no name mapping, an empty prompt and that lookup's error are
// returned; the source language is checked first.
func (l *indexTranslator) getSysPrompt(source, target Lang) (prompt string, err error) {
	sourceName, err := l.getLangString(source)
	if err != nil {
		return "", err
	}

	targetName, err := l.getLangString(target)
	if err != nil {
		return "", err
	}

	return fmt.Sprintf(sysPromptIndexTranslation, sourceName, targetName), nil
}

// Translate starts a translation of request.Text and returns the resulting
// llm.Completion.
//
// The conversation sent to the model has three messages: a system prompt
// naming the source and target languages, a user message made of the left
// context followed by the text, and a pre-filled reply starting with
// "翻译结果:" plus the translated left context. When the request carries no
// left context, a randomly chosen entry of translationExamples is used.
// request.Temperature is applied only when it is greater than zero.
//
// Errors from creating the chat or from building the system prompt are
// returned unchanged.
func (l *indexTranslator) Translate(request TranslationRequest) (llm.Completion, error) {
	chat, err := NewChat(l.model)
	if err != nil {
		return nil, err
	}

	systemPrompt, err := l.getSysPrompt(request.SourceLang, request.TargetLang)
	if err != nil {
		return nil, err
	}

	// A random example is always drawn, even if the request's own context
	// ends up replacing it below.
	example := translationExamples[rand.Intn(len(translationExamples))]
	srcContext, tgtContext := example[request.SourceLang], example[request.TargetLang]
	if request.LeftContextSource != "" {
		srcContext, tgtContext = request.LeftContextSource, request.LeftContextTarget
	}

	// NOTE(review): the role below is spelled "assistent"; confirm it matches
	// the role string the prompt builder actually recognises.
	conversation := []Message{
		{"system", systemPrompt},
		{"user", srcContext + request.Text},
		{"assistent", "翻译结果:" + tgtContext},
	}

	if request.Temperature > 0 {
		chat.SetTemperature(request.Temperature)
	}

	return chat.Chat(conversation)
}
Loading

0 comments on commit 5558b33

Please sign in to comment.