Skip to content

Commit

Permalink
Rewrite the Go code of the Whisper decoder using C++ (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 authored Nov 19, 2024
1 parent ff9b942 commit 004ea5c
Show file tree
Hide file tree
Showing 31 changed files with 26,346 additions and 911 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/cmake-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ env:

jobs:
build:
runs-on: windows-2019
runs-on: windows-2022
steps:
- uses: Jimver/cuda-toolkit@v0.2.16
- uses: Jimver/cuda-toolkit@v0.2.19
id: cuda-toolkit
with:
cuda: '11.8.0'
cuda: '12.5.1'
method: 'network'
sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]'
- name: Install Go
Expand Down
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ if(WITH_MKL)
endif()
endif()


set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

#add_compile_options(-fsanitize=address)
Expand Down Expand Up @@ -119,6 +119,7 @@ set(benchmark_LIBADD
${libllm_LIBADD})

add_library(catch2 STATIC "../third_party/catch2/catch_amalgamated.cpp")

add_executable(unittest "src/libllm/test_main.cc")
target_include_directories(unittest PRIVATE "src")

Expand Down
4 changes: 4 additions & 0 deletions go/bin/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (a *binArgs) getDevice() llm.Device {
return device
}

// getRawDevice returns the device setting exactly as stored in a.device, as a
// plain string. Unlike getDevice, which returns an llm.Device, no conversion
// is applied here.
// NOTE(review): presumably a.device holds the value of the device
// command-line flag; confirm against the flag registration.
func (a *binArgs) getRawDevice() string {
return a.device
}

func (a *binArgs) addModelFlag() {
a.fs.Var(&a.models, "m", "the libllm model, it could be model name or model file,"+
" model files are with suffix \".llmpkg\". "+
Expand Down
33 changes: 24 additions & 9 deletions go/bin/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,28 +224,43 @@ func getOrDownloadModel(name string) (modelPath string, err error) {
return downloadModel(name)
}

func createModelAutoDownload(nameOrPath string, device llm.Device) (llm.Model, error) {
var modelPath string
var err error

func autoDownloadModel(nameOrPath string) (filename string, err error) {
if filepath.Ext(nameOrPath) == ".llmpkg" {
modelPath = nameOrPath
filename = nameOrPath
} else {
modelPath, err = getOrDownloadModel(nameOrPath)
filename, err = getOrDownloadModel(nameOrPath)
}

if err != nil {
return nil, err
return
}

_, err = os.Stat(modelPath)
_, err = os.Stat(filename)
if err != nil {
return
}

return
}

func createModelAutoDownload(nameOrPath string, device llm.Device) (llm.Model, error) {
modelPath, err := autoDownloadModel(nameOrPath)
if err != nil {
return nil, fmt.Errorf("model not exist: %s", modelPath)
return nil, err
}

return llm.NewModel(modelPath, device)
}

// createASRModelAutoDownload resolves nameOrPath to a local model file through
// autoDownloadModel and then opens that file as an ASR model on the given
// device.
//
// device is forwarded to llm.NewASRModel unchanged. If the model file cannot
// be resolved, the error from autoDownloadModel is returned as-is together
// with a nil model.
func createASRModelAutoDownload(nameOrPath, device string) (*llm.ASRModel, error) {
	pkgFile, resolveErr := autoDownloadModel(nameOrPath)
	if resolveErr == nil {
		// The model file is available locally; hand it to the ASR loader.
		return llm.NewASRModel(pkgFile, device)
	}

	return nil, resolveErr
}

func printDownloadUsage(fs *flag.FlagSet) {
fmt.Fprintln(os.Stderr, "Usage: llm download [OPTIONS]")
fmt.Fprintln(os.Stderr, "")
Expand Down
10 changes: 5 additions & 5 deletions go/bin/subtitle_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ func (t *transcriptionTranslator) translateOne(text string) (translationResult,
}

func (t *transcriptionTranslator) translate(
transcriptions []skill.TranscriptionResult) ([]skill.TranscriptionResult, error) {
translatedTrxn := []skill.TranscriptionResult{}
transcriptions []llm.RecognitionResult) ([]llm.RecognitionResult, error) {
translatedTrxn := []llm.RecognitionResult{}
for _, transcription := range transcriptions {
tr, err := t.translateOne(transcription.Text)
if err != nil {
Expand Down Expand Up @@ -142,19 +142,19 @@ func newTranscripotionTranslator(
}

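// ReadSubtitleFile opens the subtitle file at filename and converts each of its items into a result entry, joining the item's lines with spaces.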
func ReadSubtitleFile(filename string) ([]TxResult, error) {
func ReadSubtitleFile(filename string) ([]llm.RecognitionResult, error) {
subtitles, err := astisub.OpenFile(filename)
if err != nil {
return nil, err
}

txs := []TxResult{}
txs := []llm.RecognitionResult{}
for _, item := range subtitles.Items {
lines := []string{}
for _, line := range item.Lines {
lines = append(lines, line.String())
}
tx := TxResult{
tx := llm.RecognitionResult{
Begin: item.StartAt,
End: item.EndAt,
Text: strings.Join(lines, " "),
Expand Down
45 changes: 15 additions & 30 deletions go/bin/transcribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/asticode/go-astisub"
"github.com/ling0322/libllm/go/llm"
"github.com/ling0322/libllm/go/skill"
)

type TxResult = skill.TranscriptionResult

type translationConfig struct {
srcLang skill.Lang
tgtLang skill.Lang
Expand All @@ -51,9 +48,9 @@ func printTranscribeUsage(fs *flag.FlagSet) {
fmt.Fprintln(os.Stderr, "")
}

func TranscriptionToSubtitle(transcriptions []TxResult) *astisub.Subtitles {
func TranscriptionToSubtitle(results []llm.RecognitionResult) *astisub.Subtitles {
s := astisub.NewSubtitles()
for index, t := range transcriptions {
for index, t := range results {
s.Items = append(s.Items, &astisub.Item{
Index: index,
EndAt: t.End,
Expand All @@ -65,13 +62,13 @@ func TranscriptionToSubtitle(transcriptions []TxResult) *astisub.Subtitles {
return s
}

func getTranscriptionLang(transcriptions []TxResult) skill.Lang {
if len(transcriptions) == 0 {
func getTranscriptionLang(results []llm.RecognitionResult) skill.Lang {
if len(results) == 0 {
return skill.UnknownLanguage
}

langCount := map[string]int{}
for _, tx := range transcriptions {
for _, tx := range results {
langCount[tx.Language] += 1
}

Expand All @@ -93,7 +90,7 @@ func getTranscriptionLang(transcriptions []TxResult) skill.Lang {
}
}

func saveTranscription(transcriptions []TxResult, filename string) error {
func saveTranscription(transcriptions []llm.RecognitionResult, filename string) error {
slog.Info(fmt.Sprintf("save transcription to %s", filename))
subtitle := TranscriptionToSubtitle(transcriptions)
err := subtitle.Write(filename)
Expand Down Expand Up @@ -167,7 +164,7 @@ func transcribeMain(args []string) {
var tgtLang skill.Lang
modelFile := getTranscriotionModel(ba)
tgtLang = ba.getTargetLang()
device := ba.getDevice()
device := ba.getRawDevice()
inputFile := ba.getInput()
outputFile := getOutputFile(ba)

Expand All @@ -176,42 +173,30 @@ func transcribeMain(args []string) {
os.Exit(1)
}

model, err := createModelAutoDownload(modelFile, device)
model, err := createASRModelAutoDownload(modelFile, device)
if err != nil {
log.Fatal(err)
}

slog.Info(fmt.Sprintf("output file is %s", outputFile))

d0 := time.Now()
transcriber, err := skill.NewWhisperTranscriber(model, inputFile)
recognition, err := model.Recognize(inputFile)
if err != nil {
log.Fatal(err)
}

if ba.getRawLang() != "" {
transcriber.SetLanguage(ba.getRawLang())
}

transcriptions := []skill.TranscriptionResult{}
for transcriber.Transcribe() {
r := transcriber.Result()
transcriptions := []llm.RecognitionResult{}
for recognition.Next() {
r := recognition.Result()
transcriptions = append(transcriptions, r)
slog.Info(r.String())
}

if err = transcriber.Err(); err != nil {
if err = recognition.Err(); err != nil {
log.Fatal(err)
}

processingTime := time.Since(d0)
slog.Info(
fmt.Sprintf("processed %s audio in %s, rtf=%.3f",
transcriber.Offset(),
processingTime.Round(time.Millisecond),
processingTime.Seconds()/transcriber.Offset().Seconds()))

transcriber.Dispose()
recognition.Dispose()
model.Dispose()

srcLang := getTranscriptionLang(transcriptions)
Expand All @@ -230,7 +215,7 @@ func transcribeMain(args []string) {
srcLang,
tgtLang,
getTranslationModel(ba),
device,
ba.getDevice(),
}
err := TranslateSubtitle(config, outputTxFile, outputFile)
if err != nil {
Expand Down
Loading

0 comments on commit 004ea5c

Please sign in to comment.