Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert full_get_*_* methods to use internal helper instead of duplicating code #202

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 77 additions & 60 deletions src/whisper_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,25 +346,43 @@ impl WhisperState {
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) })
}

fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> {
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
unsafe { Ok(CStr::from_ptr(ret)) }
}

/// Get the raw bytes of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// `Ok(Vec<u8>)` on success, with the returned bytes or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec())
}

/// Get the text of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
/// `Ok(String)` on success, with the UTF-8 validated string, or
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string())
}

/// Get the text of the specified segment.
Expand All @@ -376,53 +394,69 @@ impl WhisperState {
/// * segment: Segment index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
/// `Ok(String)` on success, or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result<String, WhisperError> {
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
Ok(c_str.to_string_lossy().to_string())
Ok(self
.full_get_segment_raw(segment)?
.to_string_lossy()
.to_string())
}

/// Get the bytes of the specified segment.
/// Get number of tokens in the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// `Ok(Vec<u8>)` on success, `Err(WhisperError)` on failure.
/// c_int
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
#[inline]
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
}

fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> {
let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx.ctx,
self.ptr,
segment,
token,
)
};
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
Ok(c_str.to_bytes().to_vec())
unsafe { Ok(CStr::from_ptr(ret)) }
}

/// Get number of tokens in the specified segment.
/// Get the raw token bytes of the specified token in the specified segment.
///
/// Useful if you're using a language for which whisper is known to split tokens
/// away from UTF-8 character boundaries.
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// c_int
/// `Ok(Vec<u8>)` on success, with the returned bytes or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
///
/// # C++ equivalent
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
#[inline]
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_bytes(
&self,
segment: c_int,
token: c_int,
) -> Result<Vec<u8>, WhisperError> {
Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec())
}

/// Get the token text of the specified token in the specified segment.
Expand All @@ -432,7 +466,8 @@ impl WhisperState {
/// * token: Token index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
/// `Ok(String)` on success, with the UTF-8 validated string, or
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
///
/// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
Expand All @@ -441,20 +476,10 @@ impl WhisperState {
segment: c_int,
token: c_int,
) -> Result<String, WhisperError> {
let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx.ctx,
self.ptr,
segment,
token,
)
};
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
Ok(self
.full_get_token_raw(segment, token)?
.to_str()?
.to_string())
}

/// Get the token text of the specified token in the specified segment.
Expand All @@ -467,7 +492,8 @@ impl WhisperState {
/// * token: Token index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
/// `Ok(String)` on success, or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
///
/// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
Expand All @@ -476,19 +502,10 @@ impl WhisperState {
segment: c_int,
token: c_int,
) -> Result<String, WhisperError> {
let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx.ctx,
self.ptr,
segment,
token,
)
};
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
Ok(c_str.to_string_lossy().to_string())
Ok(self
.full_get_token_raw(segment, token)?
.to_string_lossy()
.to_string())
}

/// Get the token ID of the specified token in the specified segment.
Expand Down