From 721cb5cde0f869f97ad37508e7b7600480885530 Mon Sep 17 00:00:00 2001 From: jwijffels Date: Sat, 27 Jan 2024 22:25:53 +0100 Subject: [PATCH] #27 Pass on use_gpu in WhisperModel - default to no gpu (#33) --- R/RcppExports.R | 4 ++-- src/RcppExports.cpp | 9 +++++---- src/rcpp_whisper.cpp | 8 ++++---- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index be5cb2a5..d313acd3 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -whisper_load_model <- function(model) { - .Call('_audio_whisper_whisper_load_model', PACKAGE = 'audio.whisper', model) +whisper_load_model <- function(model, use_gpu = FALSE) { + .Call('_audio_whisper_whisper_load_model', PACKAGE = 'audio.whisper', model, use_gpu) } whisper_encode <- function(model, path, language, token_timestamps = FALSE, translate = FALSE, print_special = FALSE, duration = 0L, offset = 0L, trace = FALSE, n_threads = 1L, n_processors = 1L, entropy_thold = 2.40, logprob_thold = -1.00, beam_size = -1L, best_of = 5L, split_on_word = FALSE, max_context = -1L) { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 4d4a8b55..07b152bf 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -6,13 +6,14 @@ using namespace Rcpp; // whisper_load_model -SEXP whisper_load_model(std::string model); -RcppExport SEXP _audio_whisper_whisper_load_model(SEXP modelSEXP) { +SEXP whisper_load_model(std::string model, bool use_gpu); +RcppExport SEXP _audio_whisper_whisper_load_model(SEXP modelSEXP, SEXP use_gpuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< std::string >::type model(modelSEXP); - rcpp_result_gen = Rcpp::wrap(whisper_load_model(model)); + Rcpp::traits::input_parameter< bool >::type use_gpu(use_gpuSEXP); + rcpp_result_gen = Rcpp::wrap(whisper_load_model(model, use_gpu)); return rcpp_result_gen; END_RCPP } @@ -66,7 +67,7 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_audio_whisper_whisper_load_model", (DL_FUNC) &_audio_whisper_whisper_load_model, 1}, + {"_audio_whisper_whisper_load_model", (DL_FUNC) &_audio_whisper_whisper_load_model, 2}, {"_audio_whisper_whisper_encode", (DL_FUNC) &_audio_whisper_whisper_encode, 17}, {"_audio_whisper_whisper_print_benchmark", (DL_FUNC) &_audio_whisper_whisper_print_benchmark, 2}, {"_audio_whisper_whisper_language_info", (DL_FUNC) &_audio_whisper_whisper_language_info, 0}, diff --git a/src/rcpp_whisper.cpp b/src/rcpp_whisper.cpp index 828afc10..8ca4189b 100644 --- a/src/rcpp_whisper.cpp +++ b/src/rcpp_whisper.cpp @@ -225,9 +225,9 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper class WhisperModel { public: struct whisper_context * ctx; - WhisperModel(std::string model){ + WhisperModel(std::string model, bool use_gpu = false){ struct whisper_context_params cparams; - cparams.use_gpu = false; + cparams.use_gpu = use_gpu; ctx = whisper_init_from_file_with_params(model.c_str(), cparams); } ~WhisperModel(){ @@ -236,11 +236,11 @@ class WhisperModel { }; // [[Rcpp::export]] -SEXP whisper_load_model(std::string model) { +SEXP whisper_load_model(std::string model, bool use_gpu = false) { // Load language model and return the pointer to be used by whisper_encode //struct whisper_context * ctx = whisper_init(model.c_str()); //Rcpp::XPtr ptr(ctx, false); - WhisperModel * wp = new WhisperModel(model); + WhisperModel * wp = new WhisperModel(model, use_gpu); Rcpp::XPtr ptr(wp, false); return ptr; }