Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite the Go code of the Whisper decoder with C++ #103

Merged
merged 9 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/cmake-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ env:

jobs:
build:
runs-on: windows-2019
runs-on: windows-2022
steps:
- uses: Jimver/cuda-toolkit@v0.2.16
- uses: Jimver/cuda-toolkit@v0.2.19
id: cuda-toolkit
with:
cuda: '11.8.0'
cuda: '12.5.1'
method: 'network'
sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]'
- name: Install Go
Expand Down
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ if(WITH_MKL)
endif()
endif()


set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

#add_compile_options(-fsanitize=address)
Expand Down Expand Up @@ -119,6 +119,7 @@ set(benchmark_LIBADD
${libllm_LIBADD})

add_library(catch2 STATIC "../third_party/catch2/catch_amalgamated.cpp")

add_executable(unittest "src/libllm/test_main.cc")
target_include_directories(unittest PRIVATE "src")

Expand Down
4 changes: 4 additions & 0 deletions go/bin/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (a *binArgs) getDevice() llm.Device {
return device
}

// getRawDevice returns the device setting exactly as stored in binArgs, as a
// plain string. Unlike getDevice, it does not produce an llm.Device value, so
// it suits callers that take the device by name.
func (a *binArgs) getRawDevice() string {
return a.device
}

func (a *binArgs) addModelFlag() {
a.fs.Var(&a.models, "m", "the libllm model, it could be model name or model file,"+
" model files are with suffix \".llmpkg\". "+
Expand Down
33 changes: 24 additions & 9 deletions go/bin/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,28 +224,43 @@ func getOrDownloadModel(name string) (modelPath string, err error) {
return downloadModel(name)
}

func createModelAutoDownload(nameOrPath string, device llm.Device) (llm.Model, error) {
var modelPath string
var err error

func autoDownloadModel(nameOrPath string) (filename string, err error) {
if filepath.Ext(nameOrPath) == ".llmpkg" {
modelPath = nameOrPath
filename = nameOrPath
} else {
modelPath, err = getOrDownloadModel(nameOrPath)
filename, err = getOrDownloadModel(nameOrPath)
}

if err != nil {
return nil, err
return
}

_, err = os.Stat(modelPath)
_, err = os.Stat(filename)
if err != nil {
return
}

return
}

func createModelAutoDownload(nameOrPath string, device llm.Device) (llm.Model, error) {
modelPath, err := autoDownloadModel(nameOrPath)
if err != nil {
return nil, fmt.Errorf("model not exist: %s", modelPath)
return nil, err
}

return llm.NewModel(modelPath, device)
}

// createASRModelAutoDownload resolves nameOrPath to a model file through
// autoDownloadModel and then builds an ASR model from that file with
// llm.NewASRModel.
//
// device is passed to llm.NewASRModel unchanged as a string. If resolving the
// model file fails, the error from autoDownloadModel is returned as-is and no
// model is created.
func createASRModelAutoDownload(nameOrPath, device string) (*llm.ASRModel, error) {
modelPath, err := autoDownloadModel(nameOrPath)
if err != nil {
return nil, err
}

return llm.NewASRModel(modelPath, device)
}

func printDownloadUsage(fs *flag.FlagSet) {
fmt.Fprintln(os.Stderr, "Usage: llm download [OPTIONS]")
fmt.Fprintln(os.Stderr, "")
Expand Down
10 changes: 5 additions & 5 deletions go/bin/subtitle_translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ func (t *transcriptionTranslator) translateOne(text string) (translationResult,
}

func (t *transcriptionTranslator) translate(
transcriptions []skill.TranscriptionResult) ([]skill.TranscriptionResult, error) {
translatedTrxn := []skill.TranscriptionResult{}
transcriptions []llm.RecognitionResult) ([]llm.RecognitionResult, error) {
translatedTrxn := []llm.RecognitionResult{}
for _, transcription := range transcriptions {
tr, err := t.translateOne(transcription.Text)
if err != nil {
Expand Down Expand Up @@ -142,19 +142,19 @@ func newTranscripotionTranslator(
}

// ReadSubtitleFile reads a subtitle file and converts each subtitle item into a timed text result.
func ReadSubtitleFile(filename string) ([]TxResult, error) {
func ReadSubtitleFile(filename string) ([]llm.RecognitionResult, error) {
subtitles, err := astisub.OpenFile(filename)
if err != nil {
return nil, err
}

txs := []TxResult{}
txs := []llm.RecognitionResult{}
for _, item := range subtitles.Items {
lines := []string{}
for _, line := range item.Lines {
lines = append(lines, line.String())
}
tx := TxResult{
tx := llm.RecognitionResult{
Begin: item.StartAt,
End: item.EndAt,
Text: strings.Join(lines, " "),
Expand Down
45 changes: 15 additions & 30 deletions go/bin/transcribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/asticode/go-astisub"
"github.com/ling0322/libllm/go/llm"
"github.com/ling0322/libllm/go/skill"
)

type TxResult = skill.TranscriptionResult

type translationConfig struct {
srcLang skill.Lang
tgtLang skill.Lang
Expand All @@ -51,9 +48,9 @@ func printTranscribeUsage(fs *flag.FlagSet) {
fmt.Fprintln(os.Stderr, "")
}

func TranscriptionToSubtitle(transcriptions []TxResult) *astisub.Subtitles {
func TranscriptionToSubtitle(results []llm.RecognitionResult) *astisub.Subtitles {
s := astisub.NewSubtitles()
for index, t := range transcriptions {
for index, t := range results {
s.Items = append(s.Items, &astisub.Item{
Index: index,
EndAt: t.End,
Expand All @@ -65,13 +62,13 @@ func TranscriptionToSubtitle(transcriptions []TxResult) *astisub.Subtitles {
return s
}

func getTranscriptionLang(transcriptions []TxResult) skill.Lang {
if len(transcriptions) == 0 {
func getTranscriptionLang(results []llm.RecognitionResult) skill.Lang {
if len(results) == 0 {
return skill.UnknownLanguage
}

langCount := map[string]int{}
for _, tx := range transcriptions {
for _, tx := range results {
langCount[tx.Language] += 1
}

Expand All @@ -93,7 +90,7 @@ func getTranscriptionLang(transcriptions []TxResult) skill.Lang {
}
}

func saveTranscription(transcriptions []TxResult, filename string) error {
func saveTranscription(transcriptions []llm.RecognitionResult, filename string) error {
slog.Info(fmt.Sprintf("save transcription to %s", filename))
subtitle := TranscriptionToSubtitle(transcriptions)
err := subtitle.Write(filename)
Expand Down Expand Up @@ -167,7 +164,7 @@ func transcribeMain(args []string) {
var tgtLang skill.Lang
modelFile := getTranscriotionModel(ba)
tgtLang = ba.getTargetLang()
device := ba.getDevice()
device := ba.getRawDevice()
inputFile := ba.getInput()
outputFile := getOutputFile(ba)

Expand All @@ -176,42 +173,30 @@ func transcribeMain(args []string) {
os.Exit(1)
}

model, err := createModelAutoDownload(modelFile, device)
model, err := createASRModelAutoDownload(modelFile, device)
if err != nil {
log.Fatal(err)
}

slog.Info(fmt.Sprintf("output file is %s", outputFile))

d0 := time.Now()
transcriber, err := skill.NewWhisperTranscriber(model, inputFile)
recognition, err := model.Recognize(inputFile)
if err != nil {
log.Fatal(err)
}

if ba.getRawLang() != "" {
transcriber.SetLanguage(ba.getRawLang())
}

transcriptions := []skill.TranscriptionResult{}
for transcriber.Transcribe() {
r := transcriber.Result()
transcriptions := []llm.RecognitionResult{}
for recognition.Next() {
r := recognition.Result()
transcriptions = append(transcriptions, r)
slog.Info(r.String())
}

if err = transcriber.Err(); err != nil {
if err = recognition.Err(); err != nil {
log.Fatal(err)
}

processingTime := time.Since(d0)
slog.Info(
fmt.Sprintf("processed %s audio in %s, rtf=%.3f",
transcriber.Offset(),
processingTime.Round(time.Millisecond),
processingTime.Seconds()/transcriber.Offset().Seconds()))

transcriber.Dispose()
recognition.Dispose()
model.Dispose()

srcLang := getTranscriptionLang(transcriptions)
Expand All @@ -230,7 +215,7 @@ func transcribeMain(args []string) {
srcLang,
tgtLang,
getTranslationModel(ba),
device,
ba.getDevice(),
}
err := TranslateSubtitle(config, outputTxFile, outputFile)
if err != nil {
Expand Down
Loading
Loading