From 91aeb8bacc78ac0595ee779be024dc6b37a13be5 Mon Sep 17 00:00:00 2001 From: Xiaoyang Chen Date: Tue, 27 Aug 2024 10:00:00 +0800 Subject: [PATCH] Use wave stream as input for whisper transcription (#95) --- .github/workflows/cmake-windows.yml | 4 +- go/bin/transcribe.go | 6 +- go/ffmpegplugin/plugin.c | 46 +++--- go/ffmpegplugin/plugin.h | 14 +- go/ffmpegplugin/read_audio.go | 59 +++++-- go/skill/audio.go | 107 ++++++++++++- go/skill/whisper.go | 105 ++++++------- src/libllm/read_audio_ffmpeg.cc | 235 ++++++++++++++++++---------- src/libllm/read_audio_ffmpeg.h | 12 +- 9 files changed, 408 insertions(+), 180 deletions(-) diff --git a/.github/workflows/cmake-windows.yml b/.github/workflows/cmake-windows.yml index 2095901..749db7f 100644 --- a/.github/workflows/cmake-windows.yml +++ b/.github/workflows/cmake-windows.yml @@ -33,7 +33,9 @@ jobs: - name: Install cutlass run: cd third_party && bash install_cutlass.sh - name: Build ffmpeg plugin for Windows - run: gcc -shared -o llmpluginffmpeg.dll + run: g++ -shared -o llmpluginffmpeg.dll + -fno-exceptions + -fno-rtti -Isrc -Ithird_party/ffmpeg -DLIBLLM_EXPORTS diff --git a/go/bin/transcribe.go b/go/bin/transcribe.go index 85aa521..14e960b 100644 --- a/go/bin/transcribe.go +++ b/go/bin/transcribe.go @@ -206,7 +206,11 @@ func transcribeMain(args []string) { slog.Info(fmt.Sprintf("output file is %s", outputFile)) d0 := time.Now() - transcriber := skill.NewWhisperTranscriber(model, inputFile) + transcriber, err := skill.NewWhisperTranscriber(model, inputFile) + if err != nil { + log.Fatal(err) + } + transcriptions := []skill.TranscriptionResult{} for transcriber.Transcribe() { r := transcriber.Result() diff --git a/go/ffmpegplugin/plugin.c b/go/ffmpegplugin/plugin.c index 785218c..31cd775 100644 --- a/go/ffmpegplugin/plugin.c +++ b/go/ffmpegplugin/plugin.c @@ -41,12 +41,10 @@ typedef void *LLM_HMODULE; typedef HMODULE LLM_HMODULE; #endif -void *(*p_llm_ffmpeg_plugin_load_library)(const char *library_path) = NULL; -char *(*p_llm_ffmpeg_plugin_get_err)() = NULL; -int32_t (*p_llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file)( - const char *filename, - char **output_buffer, - int32_t *output_size) = NULL; +char *(*p_llm_ffmpeg_get_err)() = NULL; +void *(*p_llm_ffmpeg_audio_open)(const char *filename); +void (*p_llm_ffmpeg_audio_close)(void *reader); +int32_t (*p_llm_ffmpeg_audio_read)(void *reader, char *buf, int32_t buf_size); // load the libllm shared library. void *llm_ffmpeg_plugin_load_library(const char *libraryPath) { @@ -74,16 +72,20 @@ void *llm_ffmpeg_plugin_load_library(const char *libraryPath) { int llm_ffmpeg_plugin_load_symbols(void *pDll) { LLM_HMODULE hDll = (LLM_HMODULE)pDll; - LOAD_SYMBOL(hDll, llm_ffmpeg_plugin_get_err); - LOAD_SYMBOL(hDll, llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file); + LOAD_SYMBOL(hDll, llm_ffmpeg_get_err); + LOAD_SYMBOL(hDll, llm_ffmpeg_audio_open); + LOAD_SYMBOL(hDll, llm_ffmpeg_audio_close); + LOAD_SYMBOL(hDll, llm_ffmpeg_audio_read); return 0; } // load the libllm shared library. -void llm_ffmpeg_plugin_destroy_librray(void *handle) { - p_llm_ffmpeg_plugin_get_err = NULL; - p_llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file = NULL; +void llm_ffmpeg_plugin_destroy_library(void *handle) { + p_llm_ffmpeg_get_err = NULL; + p_llm_ffmpeg_audio_open = NULL; + p_llm_ffmpeg_audio_close = NULL; + p_llm_ffmpeg_audio_read = NULL; // first try to load the dll from same folder as current module. #if defined(LUT_PLATFORM_APPLE) || defined(LUT_PLATFORM_LINUX) @@ -100,16 +102,18 @@ void llm_ffmpeg_plugin_destroy_librray(void *handle) { #endif } -char *llm_ffmpeg_plugin_get_err() { - return p_llm_ffmpeg_plugin_get_err(); +void *llm_ffmpeg_audio_open(const char *filename) { + return p_llm_ffmpeg_audio_open(filename); } -int32_t llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file( - const char *filename, - char **output_buffer, - int32_t *output_size) { - return p_llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file( - filename, - output_buffer, - output_size); +void llm_ffmpeg_audio_close(void *reader) { + return p_llm_ffmpeg_audio_close(reader); +} + +int32_t llm_ffmpeg_audio_read(void *reader, char *buf, int32_t buf_size) { + return p_llm_ffmpeg_audio_read(reader, buf, buf_size); +} + +const char *llm_ffmpeg_get_err() { + return p_llm_ffmpeg_get_err(); } diff --git a/go/ffmpegplugin/plugin.h b/go/ffmpegplugin/plugin.h index 230445f..af7a827 100644 --- a/go/ffmpegplugin/plugin.h +++ b/go/ffmpegplugin/plugin.h @@ -21,11 +21,13 @@ #include +typedef struct llm_ffmpeg_audio_reader_t llm_ffmpeg_audio_reader_t; + void *llm_ffmpeg_plugin_load_library(const char *library_path); int llm_ffmpeg_plugin_load_symbols(void *handle); -void llm_ffmpeg_plugin_destroy_librray(void *handle); -char *llm_ffmpeg_plugin_get_err(); -int32_t llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file( - const char *filename, - char **output_buffer, - int32_t *output_size); +void llm_ffmpeg_plugin_destroy_library(void *handle); + +const char *llm_ffmpeg_get_err(); +void *llm_ffmpeg_audio_open(const char *filename); +void llm_ffmpeg_audio_close(void *reader); +int32_t llm_ffmpeg_audio_read(void *reader, char *buf, int32_t buf_size); diff --git a/go/ffmpegplugin/read_audio.go b/go/ffmpegplugin/read_audio.go index ebb80a9..65c0728 100644 --- a/go/ffmpegplugin/read_audio.go +++ b/go/ffmpegplugin/read_audio.go @@ -26,6 +26,8 @@ package ffmpegplugin import "C" import ( "errors" + "io" + "log/slog" "os" "path/filepath" "runtime" @@ -39,6 +41,10 @@ var gDll unsafe.Pointer var Init = sync.OnceValue[error](initIntrnal) +type Reader struct { + handle unsafe.Pointer +} + func initIntrnal() error { // load the shared library. binPath, err := os.Executable() @@ -74,7 +80,9 @@ func initIntrnal() error { return nil } -func Read16KHzMonoPcmFromMediaFile(filename string) (pcmdata []byte, err error) { +func NewReader(filename string) (*Reader, error) { + Init() + if !gInit.Load() { return nil, errors.New("ffmpeg plugin not initialized") } @@ -82,15 +90,46 @@ func Read16KHzMonoPcmFromMediaFile(filename string) (pcmdata []byte, err error) cName := C.CString(filename) defer C.free(unsafe.Pointer(cName)) - var outputPtr *C.char - outputLen := C.int(0) - ret := C.llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file(cName, &outputPtr, &outputLen) - if ret < 0 { - err = errors.New(C.GoString(C.llm_ffmpeg_plugin_get_err())) - return + handle := C.llm_ffmpeg_audio_open(cName) + if handle == nil { + return nil, errors.New(C.GoString(C.llm_ffmpeg_get_err())) + } + + reader := &Reader{ + unsafe.Pointer(handle), + } + runtime.SetFinalizer(reader, func(r *Reader) { + if r.handle != nil { + slog.Warn("ffmpegplugin.Reader is not closed") + r.Close() + } + }) + return reader, nil +} + +func (r *Reader) Read(b []byte) (n int, err error) { + if r.handle == nil { + return 0, errors.New("llm_ffmpeg_audio_reader_t handle is empty") } - pcmdata = make([]byte, int(outputLen)) - C.memcpy(unsafe.Pointer(&pcmdata[0]), unsafe.Pointer(outputPtr), C.size_t(outputLen)) - return + buf := (*C.char)(unsafe.Pointer(&b[0])) + bufsize := C.int32_t(len(b)) + + nb := C.llm_ffmpeg_audio_read(r.handle, buf, bufsize) + if nb == 0 { + return 0, io.EOF + } else if nb < 0 { + return 0, errors.New(C.GoString(C.llm_ffmpeg_get_err())) + } else { + return int(nb), nil + } +} + +func (r *Reader) Close() error { + if r.handle != nil { + C.llm_ffmpeg_audio_close(r.handle) + r.handle = nil + } + + return nil } diff --git a/go/skill/audio.go b/go/skill/audio.go index 56920dc..6211645 100644 --- a/go/skill/audio.go +++ b/go/skill/audio.go @@ -23,12 +23,14 @@ import ( "bytes" "errors" "fmt" + "io" "log/slog" "os" "os/exec" "path/filepath" "runtime" "sync" + "time" "github.com/ling0322/libllm/go/ffmpegplugin" ) @@ -36,6 +38,103 @@ import ( var gFfmpegBin string var gFfmpegPluginReady bool +var BlockSize = 60 * 16000 * 2 + +type WaveStream struct { + reader *ffmpegplugin.Reader + buffer []byte + bufferOffset time.Duration +} + +type WaveChunk struct { + begin time.Duration + end time.Duration + eof bool + data []byte +} + +func NewWaveStream(filename string) (*WaveStream, error) { + reader, err := ffmpegplugin.NewReader(filename) + if err != nil { + return nil, err + } + + return &WaveStream{ + reader: reader, + }, nil +} + +func durationToBytes(dur time.Duration) int { + nsPerSample := 1000000000 / SampleRate + nSamples := int(dur.Nanoseconds() / int64(nsPerSample)) + nBytes := nSamples * 2 + + return nBytes +} + +func (s *WaveStream) ensureOffset(offset time.Duration) error { + if offset < s.bufferOffset { + return errors.New("wave stream could not seek backward") + } + + length := offset - s.bufferOffset + for len(s.buffer) < durationToBytes(length) { + b := make([]byte, BlockSize) + n, err := s.reader.Read(b) + if err != nil { + return err + } + + s.buffer = append(s.buffer, b[:n]...) + } + + return nil +} + +func (s *WaveStream) Seek(offset time.Duration) error { + err := s.ensureOffset(offset) + if err != nil { + return err + } + + forwardDuration := offset - s.bufferOffset + forwardBytes := durationToBytes(forwardDuration) + + s.buffer = s.buffer[forwardBytes:] + s.bufferOffset = offset + + return nil +} + +func (s *WaveStream) Offset() time.Duration { + return s.bufferOffset +} + +func (s *WaveStream) ReadChunk(length time.Duration) (WaveChunk, error) { + err := s.ensureOffset(s.bufferOffset + length) + eof := false + if errors.Is(err, io.EOF) { + eof = true + if len(s.buffer) == 0 { + return WaveChunk{}, io.EOF + } + } else if err != nil { + return WaveChunk{}, err + } + + n := min(len(s.buffer), durationToBytes(length)) + return WaveChunk{ + begin: s.bufferOffset, + end: s.bufferOffset + length, + eof: eof, + data: s.buffer[:n], + }, nil +} + +func (s *WaveStream) Close() error { + return s.reader.Close() +} + var initAudioReader = sync.OnceFunc(func() { err := ffmpegplugin.Init() if err != nil { @@ -52,7 +151,13 @@ var initAudioReader = sync.OnceFunc(func() { // convert the input file to pcm .wav file in OS temporary directory using ffmpeg. func convertToPcmPlugin(inputFile string) ([]byte, error) { - return ffmpegplugin.Read16KHzMonoPcmFromMediaFile(inputFile) + reader, err := ffmpegplugin.NewReader(inputFile) + if err != nil { + return nil, err + } + defer reader.Close() + + return io.ReadAll(reader) } // find the path of ffmpeg. diff --git a/go/skill/whisper.go b/go/skill/whisper.go index 80fa522..c5adea3 100644 --- a/go/skill/whisper.go +++ b/go/skill/whisper.go @@ -22,6 +22,7 @@ package skill import ( "errors" "fmt" + "io" "log/slog" "math" "regexp" @@ -33,7 +34,7 @@ import ( var regexpLangToken = regexp.MustCompile(`^<\|([a-z][a-z][a-z]?)\|>$`) var ErrInvalidWhisperSequence = errors.New("invalid sequence for Whisper model") -var ErrFilenameIsEmpty = errors.New("input filename is empty") +var ErrStreamIsEmpty = errors.New("input stream is nil") var ErrWhisperModelIsNil = errors.New("whisper model is nil") var ErrNoMoreResults = errors.New("no more results") var ErrAudioEndOfStream = errors.New("audio end of stream") @@ -41,26 +42,17 @@ var ErrAudioEndOfStream = errors.New("audio end of stream") const SampleRate = 16000 const BytesPerSample = 2 -const ( - whisperStateAudio = iota - whisperStateStartOfTranscription - whisperStateLanguage - whisperStateTranscribe - whisperStateBeginTime - whisperStateText - whisperStateEndTime -) - // the transcriber with whisper model. implements the interface Transcriber. type WhisperTranscriber struct { // the whisper model. WhisperModel llm.Model // the reader for input file. - InputFile string + stream *WaveStream + + streamOffset time.Duration - // current state in whisper sequence decoding. - state int + chunk WaveChunk // if any errors occured. err error @@ -73,21 +65,19 @@ type WhisperTranscriber struct { // the predicted language. predictedLanguage string - - // the wave bytes for decoding. The format is 16khz 16bit mono-channel PCM without headfer. - wavePayload []byte - - // offset of the current segment in wavePayload. - waveOffset time.Duration } // create a new instance of WhisperTranscriber from whisper model and stream of input file. -func NewWhisperTranscriber(whisperModel llm.Model, inputFile string) *WhisperTranscriber { +func NewWhisperTranscriber(whisperModel llm.Model, inputFile string) (*WhisperTranscriber, error) { + stream, err := NewWaveStream(inputFile) + if err != nil { + return nil, err + } + return &WhisperTranscriber{ WhisperModel: whisperModel, - InputFile: inputFile, - state: whisperStateAudio, - } + stream: stream, + }, nil } // parse a whisper timestamp token like <|3.22|>. On success. return the (parsed-time, true). On @@ -138,7 +128,7 @@ func (w *WhisperTranscriber) decodeTranscription() (TranscriptionResult, error) if !ok { return TranscriptionResult{}, fmt.Errorf("%w: not a time token", ErrInvalidWhisperSequence) } - result.Begin = w.waveOffset + beginOffset + result.Begin = w.stream.Offset() + beginOffset transcriptionDone := false for w.comp.Next() { @@ -147,7 +137,7 @@ func (w *WhisperTranscriber) decodeTranscription() (TranscriptionResult, error) slog.Debug("comp.next()", "token", token, "piece", piece) offset, isTimestampToken := w.parseTimestampToken(token) if isTimestampToken { - result.End = w.waveOffset + offset + result.End = w.stream.Offset() + offset transcriptionDone = true break } @@ -177,20 +167,28 @@ func (w *WhisperTranscriber) parseLanguageToken(token string) (lang string, ok b // prefill audio and prompt when in the begining of decoding or last audio segment finished. If no // transcriotion result or <|nospeech|> got, return ErrNoMoreResults. func (w *WhisperTranscriber) prefillNextAudioSegment() error { - nsPerSample := 1000000000 / SampleRate - sampleOffset := int(w.waveOffset.Nanoseconds() / int64(nsPerSample)) - byteOffset := sampleOffset * 2 - if len(w.wavePayload)-byteOffset < SampleRate/10 { - // ignore the last segment that less than 0.1s. + if w.chunk.eof { + // if the current chunk is already the last one. + return ErrAudioEndOfStream + } + + err := w.stream.Seek(w.streamOffset) + if errors.Is(err, io.EOF) { return ErrAudioEndOfStream + } else if err != nil { + return err } - slog.Info("prefill segment", "offset", w.waveOffset, "byteOffset", byteOffset) - nBytes := min(len(w.wavePayload)-byteOffset, 30*SampleRate*2) - audio := w.wavePayload[byteOffset : byteOffset+nBytes] + slog.Info("transcribe segment", "offset", w.stream.Offset()) + w.chunk, err = w.stream.ReadChunk(30 * time.Second) + if errors.Is(err, io.EOF) { + return ErrAudioEndOfStream + } else if err != nil { + return err + } prompt := llm.NewPrompt() - prompt.AppendAudio(audio, llm.Pcm16kHz16BitMono) + prompt.AppendAudio(w.chunk.data, llm.Pcm16kHz16BitMono) prompt.AppendControlToken("<|startoftranscript|>") compConfig := llm.NewCompletionConfig() @@ -243,9 +241,6 @@ func (w *WhisperTranscriber) disposeCompAndSetToNil() { // implements interface Transcriber. func (w *WhisperTranscriber) Transcribe() bool { - if w.wavePayload == nil { - w.wavePayload, w.err = ReadAudioFromMediaFile(w.InputFile) - } if w.err != nil { return false } @@ -255,8 +250,8 @@ func (w *WhisperTranscriber) Transcribe() bool { return false } - if w.InputFile == "" { - w.err = ErrFilenameIsEmpty + if w.stream == nil { + w.err = ErrStreamIsEmpty return false } @@ -265,32 +260,30 @@ func (w *WhisperTranscriber) Transcribe() bool { beginOfSegment := false if w.comp == nil { w.err = w.prefillNextAudioSegment() + if errors.Is(w.err, ErrNoMoreResults) { + w.disposeCompAndSetToNil() + w.streamOffset += 30 * time.Second + continue + } else if errors.Is(w.err, ErrAudioEndOfStream) { + w.disposeCompAndSetToNil() + w.err = nil + return false + } else if w.err != nil { + return false + } beginOfSegment = true } - if errors.Is(w.err, ErrNoMoreResults) { - w.disposeCompAndSetToNil() - w.err = nil - w.waveOffset += 30 * time.Second - continue - } else if errors.Is(w.err, ErrAudioEndOfStream) { - w.disposeCompAndSetToNil() - w.err = nil - return false - } else if w.err != nil { - return false - } - result, err := w.decodeTranscription() if errors.Is(err, ErrNoMoreResults) && beginOfSegment { // if no result for the whole audio segment, move forward to the next 30s segment. w.disposeCompAndSetToNil() - w.waveOffset += 30 * time.Second + w.streamOffset += 30 * time.Second continue } else if errors.Is(err, ErrNoMoreResults) && !beginOfSegment { // move the wave offset to the end of last completed transcription. w.disposeCompAndSetToNil() - w.waveOffset = w.result.End + w.streamOffset = w.result.End continue } else if err != nil { w.err = err @@ -321,5 +314,5 @@ func (w *WhisperTranscriber) Dispose() { // implements interface Transcriber. func (w *WhisperTranscriber) Offset() time.Duration { - return w.waveOffset + return w.stream.Offset() } diff --git a/src/libllm/read_audio_ffmpeg.cc b/src/libllm/read_audio_ffmpeg.cc index 13aa37c..c7373ea 100644 --- a/src/libllm/read_audio_ffmpeg.cc +++ b/src/libllm/read_audio_ffmpeg.cc @@ -19,6 +19,10 @@ #include "libllm/read_audio_ffmpeg.h" +#include +#include +#include + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -33,57 +37,111 @@ extern "C" { thread_local char errmsg[256]; -char *llm_ffmpeg_plugin_get_err() { - errmsg[sizeof(errmsg) - 1] = '\0'; - return errmsg; +struct llm_ffmpeg_audio_reader_t { + AVFrame *frame; + AVFormatContext *formatCtx; + AVCodecContext *codecCtx; + SwrContext *swrCtx; + int audioStreamIndex; + uint8_t *resampleBuffer; + int resampleBufferSize; + int resampleLineSize; + bool eof; + + std::deque buffer; + + llm_ffmpeg_audio_reader_t(); + ~llm_ffmpeg_audio_reader_t(); + void ensureResampleBuffer(int nSamples); +}; + +llm_ffmpeg_audio_reader_t::llm_ffmpeg_audio_reader_t() + : frame(nullptr), + formatCtx(nullptr), + codecCtx(nullptr), + swrCtx(nullptr), + audioStreamIndex(-1), + resampleBuffer(nullptr), + resampleBufferSize(0), + resampleLineSize(0), + eof(true) { } -int32_t llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file( - const char *filename, - char **output_buffer, - int32_t *output_size) { - AVFormatContext *format_ctx = nullptr; - AVCodecContext *codec_ctx = nullptr; - AVStream *audio_stream = nullptr; - const AVCodec *codec = nullptr; - struct SwrContext *swr_ctx = nullptr; - AVPacket packet; - AVFrame *frame = nullptr; +llm_ffmpeg_audio_reader_t::~llm_ffmpeg_audio_reader_t() { + if (frame) av_frame_free(&frame); + if (codecCtx) avcodec_free_context(&codecCtx); + if (formatCtx) avformat_close_input(&formatCtx); + if (swrCtx) swr_free(&swrCtx); + if (resampleBuffer) av_free(resampleBuffer); + + formatCtx = nullptr; + codecCtx = nullptr; + swrCtx = nullptr; + audioStreamIndex = -1; + resampleBuffer = nullptr; + resampleBufferSize = 0; + resampleLineSize = 0; + eof = true; +} + +void llm_ffmpeg_audio_reader_t::ensureResampleBuffer(int nSamples) { + if (resampleBufferSize >= nSamples) { + return; + } - int audio_stream_index = -1; - int ret; + if (resampleBuffer) av_free(resampleBuffer); + int ret = av_samples_alloc(&resampleBuffer, &resampleLineSize, 1, nSamples, AV_SAMPLE_FMT_S16, 0); + if (ret < 0) { + fprintf(stderr, "failed to alloc samples, retcode=%d\n", ret); + abort(); + } + + resampleBufferSize = nSamples; +} - if ((ret = avformat_open_input(&format_ctx, filename, nullptr, nullptr)) < 0) { +llm_ffmpeg_audio_reader_t *llm_ffmpeg_audio_open(const char *filename) { + std::unique_ptr reader = std::make_unique(); + + int ret = avformat_open_input(&reader->formatCtx, filename, nullptr, nullptr); + if (ret < 0) { snprintf(errmsg, sizeof(errmsg), "Could not open input file \"%s\"", filename); - return ret; + return nullptr; } // find stream info - if ((ret = avformat_find_stream_info(format_ctx, nullptr)) < 0) { + ret = avformat_find_stream_info(reader->formatCtx, nullptr); + if (ret < 0) { snprintf(errmsg, sizeof(errmsg), "Could not find stream information"); - return ret; + return nullptr; } // find the audio stream - audio_stream_index = av_find_best_stream(format_ctx, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0); - if (audio_stream_index < 0) { + const AVCodec *codec = nullptr; + reader->audioStreamIndex = av_find_best_stream( + reader->formatCtx, + AVMEDIA_TYPE_AUDIO, + -1, + -1, + &codec, + 0); + if (reader->audioStreamIndex < 0) { snprintf(errmsg, sizeof(errmsg), "Could not find audio stream in input file"); - return audio_stream_index; + return nullptr; } - audio_stream = format_ctx->streams[audio_stream_index]; + AVStream *audioStream = reader->formatCtx->streams[reader->audioStreamIndex]; // get the codec context - codec_ctx = avcodec_alloc_context3(codec); - if (!codec_ctx) { + reader->codecCtx = avcodec_alloc_context3(codec); + if (!reader->codecCtx) { snprintf(errmsg, sizeof(errmsg), "Could not allocate codec context"); - return AVERROR(ENOMEM); + return nullptr; } - avcodec_parameters_to_context(codec_ctx, audio_stream->codecpar); + avcodec_parameters_to_context(reader->codecCtx, audioStream->codecpar); // open the codec - if ((ret = avcodec_open2(codec_ctx, codec, nullptr)) < 0) { + if ((ret = avcodec_open2(reader->codecCtx, codec, nullptr)) < 0) { snprintf(errmsg, sizeof(errmsg), "Could not open codec"); - return ret; + return nullptr; } // initialize resampler @@ -91,81 +149,100 @@ int32_t llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file( av_channel_layout_default(&ch_layout_output, 1); ret = swr_alloc_set_opts2( - &swr_ctx, - &ch_layout_output, // Output channel layout - AV_SAMPLE_FMT_S16, // Output sample format - 16000, // Output sample rate - &codec_ctx->ch_layout, // Input channel layout - codec_ctx->sample_fmt, // Input sample format - codec_ctx->sample_rate, // Input sample rate + &reader->swrCtx, + &ch_layout_output, // Output channel layout + AV_SAMPLE_FMT_S16, // Output sample format + 16000, // Output sample rate + &reader->codecCtx->ch_layout, // Input channel layout + reader->codecCtx->sample_fmt, // Input sample format + reader->codecCtx->sample_rate, // Input sample rate 0, nullptr); - if (ret < 0 || swr_init(swr_ctx) < 0) { + if (ret < 0) { snprintf(errmsg, sizeof(errmsg), "Could not allocate resample context"); - return AVERROR(ENOMEM); + return nullptr; + } + + ret = swr_init(reader->swrCtx); + if (ret < 0) { + snprintf(errmsg, sizeof(errmsg), "Could not initialize resample context"); + return nullptr; } // allocate memory for decoding - frame = av_frame_alloc(); - if (!frame) { + reader->frame = av_frame_alloc(); + if (!reader->frame) { snprintf(errmsg, sizeof(errmsg), "Could not allocate audio frame"); - return AVERROR(ENOMEM); + return nullptr; } - // Prepare output buffer - *output_buffer = nullptr; - *output_size = 0; - - // Read and decode audio frames - while (av_read_frame(format_ctx, &packet) >= 0) { - if (packet.stream_index == audio_stream_index) { - if (avcodec_send_packet(codec_ctx, &packet) == 0) { - while (avcodec_receive_frame(codec_ctx, frame) == 0) { - uint8_t *out_buf; - int out_linesize; - int out_samples = av_rescale_rnd( - swr_get_delay(swr_ctx, codec_ctx->sample_rate) + frame->nb_samples, + reader->eof = false; + return reader.release(); +} + +void llm_ffmpeg_audio_close(llm_ffmpeg_audio_reader_t *reader) { + delete reader; +} + +int32_t llm_ffmpeg_audio_read(llm_ffmpeg_audio_reader_t *reader, char *buf, int32_t buf_size) { + while ((!reader->eof) && buf_size > reader->buffer.size()) { + AVPacket packet; + int ret = av_read_frame(reader->formatCtx, &packet); + if (ret < 0) { + reader->eof = true; + break; + } + + if (packet.stream_index == reader->audioStreamIndex) { + if (avcodec_send_packet(reader->codecCtx, &packet) == 0) { + while (avcodec_receive_frame(reader->codecCtx, reader->frame) == 0) { + int outSamples = av_rescale_rnd( + swr_get_delay(reader->swrCtx, reader->codecCtx->sample_rate) + + reader->frame->nb_samples, 16000, - codec_ctx->sample_rate, + reader->codecCtx->sample_rate, AV_ROUND_UP); // Allocate memory for resampled data - av_samples_alloc(&out_buf, &out_linesize, 1, out_samples, AV_SAMPLE_FMT_S16, 0); + reader->ensureResampleBuffer(outSamples); // Resample the data - int resampled_data = swr_convert( - swr_ctx, - &out_buf, - out_samples, - (const uint8_t **)frame->data, - frame->nb_samples); + int ret = swr_convert( + reader->swrCtx, + &reader->resampleBuffer, + outSamples, + (const uint8_t **)reader->frame->data, + reader->frame->nb_samples); + if (ret < 0) { + snprintf(errmsg, sizeof(errmsg), "Could not convert samples"); + return ret; + } // Calculate the size of the resampled data - int data_size = av_samples_get_buffer_size( - &out_linesize, + int destBufferSize = av_samples_get_buffer_size( + &reader->resampleLineSize, 1, - resampled_data, + ret, AV_SAMPLE_FMT_S16, 1); // Reallocate output buffer to append new data - *output_buffer = (char *)realloc(*output_buffer, *output_size + data_size); - memcpy(*output_buffer + *output_size, out_buf, data_size); - *output_size += data_size; - - // Free the allocated buffer for resampled data - av_free(out_buf); + reader->buffer.insert( + reader->buffer.end(), + reader->resampleBuffer, + reader->resampleBuffer + destBufferSize); } } - av_packet_unref(&packet); } } -_read_pcm_from_media_file_clean_up: - av_frame_free(&frame); - avcodec_free_context(&codec_ctx); - avformat_close_input(&format_ctx); - swr_free(&swr_ctx); + int nBytesToCopy = std::min(static_cast(reader->buffer.size()), buf_size); + std::copy(reader->buffer.begin(), reader->buffer.begin() + nBytesToCopy, buf); + reader->buffer.erase(reader->buffer.begin(), reader->buffer.begin() + nBytesToCopy); - return ret; + return nBytesToCopy; +} + +const char *llm_ffmpeg_get_err() { + return errmsg; } diff --git a/src/libllm/read_audio_ffmpeg.h b/src/libllm/read_audio_ffmpeg.h index 1affa6a..864d065 100644 --- a/src/libllm/read_audio_ffmpeg.h +++ b/src/libllm/read_audio_ffmpeg.h @@ -39,11 +39,13 @@ extern "C" { #endif // __cplusplus -LLMAPI char *llm_ffmpeg_plugin_get_err(); -LLMAPI int32_t llm_ffmpeg_plugin_read_16khz_mono_pcm_from_media_file( - const char *filename, - char **output_buffer, - int32_t *output_size); +typedef struct llm_ffmpeg_audio_reader_t llm_ffmpeg_audio_reader_t; + +LLMAPI const char *llm_ffmpeg_get_err(); + +llm_ffmpeg_audio_reader_t *llm_ffmpeg_audio_open(const char *filename); +void llm_ffmpeg_audio_close(llm_ffmpeg_audio_reader_t *reader); +int32_t llm_ffmpeg_audio_read(llm_ffmpeg_audio_reader_t *reader, char *buf, int32_t buf_size); #ifdef __cplusplus } // extern "C"