diff --git a/Cargo.toml b/Cargo.toml index 0a06825e..6781123a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,6 +107,7 @@ strum_macros = "0.24" #failure = "0.1" backtrace = "0.3" linkme = "0.3" +serde = { version = "1.0", features = ["derive"] } [dev-dependencies] anyhow = "1.0.38" diff --git a/examples/call.rs b/examples/call.rs index b4148c76..4810774e 100644 --- a/examples/call.rs +++ b/examples/call.rs @@ -81,7 +81,7 @@ fn call_test(ctx: &Context, _: Vec) -> RedisResult { // test resp3 on call_ext let call_options = CallOptionsBuilder::new() .script_mode() - .resp_3(CallOptionResp::Resp3) + .resp(CallOptionResp::Resp3) .errors_as_replies() .build(); ctx.call_ext::<_, CallResult>("HSET", &call_options, &["x", "foo", "bar"]) diff --git a/examples/configuration.rs b/examples/configuration.rs index 949bd5f2..7df55839 100644 --- a/examples/configuration.rs +++ b/examples/configuration.rs @@ -9,8 +9,7 @@ use std::sync::{ use lazy_static::lazy_static; use redis_module::{ configuration::{ConfigurationContext, ConfigurationFlags}, - ConfigurationValue, Context, EnumConfigurationValue, RedisGILGuard, RedisResult, RedisString, - RedisValue, + ConfigurationValue, Context, RedisGILGuard, RedisResult, RedisString, RedisValue, }; enum_configuration! { diff --git a/examples/events.rs b/examples/events.rs index 80ee4df5..6d217ba2 100644 --- a/examples/events.rs +++ b/examples/events.rs @@ -9,15 +9,17 @@ use std::sync::atomic::{AtomicI64, Ordering}; static NUM_KEY_MISSES: AtomicI64 = AtomicI64::new(0); -fn on_event(ctx: &Context, event_type: NotifyEvent, event: &str, key: &str) { +fn on_event(ctx: &Context, event_type: NotifyEvent, event: &str, key: &[u8]) { let msg = format!( "Received event: {:?} on key: {} via event: {}", - event_type, key, event + event_type, + std::str::from_utf8(key).unwrap(), + event ); ctx.log_debug(msg.as_str()); } -fn on_stream(ctx: &Context, _event_type: NotifyEvent, _event: &str, _key: &str) { +fn on_stream(ctx: &Context, _event_type: NotifyEvent, _event: &str, _key: &[u8]) { ctx.log_debug("Stream event received!"); } @@ -34,7 +36,7 @@ fn event_send(ctx: &Context, args: Vec) -> RedisResult { } } -fn on_key_miss(_ctx: &Context, _event_type: NotifyEvent, _event: &str, _key: &str) { +fn on_key_miss(_ctx: &Context, _event_type: NotifyEvent, _event: &str, _key: &[u8]) { NUM_KEY_MISSES.fetch_add(1, Ordering::SeqCst); } @@ -56,7 +58,7 @@ redis_module! { [@EXPIRED @EVICTED: on_event], [@STREAM: on_stream], [@MISSED: on_key_miss], - ] + ], } ////////////////////////////////////////////////////// diff --git a/src/configuration.rs b/src/configuration.rs index 06174430..6111a071 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -41,6 +41,7 @@ macro_rules! enum_configuration { ($(#[$meta:meta])* $vis:vis enum $name:ident { $($(#[$vmeta:meta])* $vname:ident = $val:expr,)* }) => { + use $crate::configuration::EnumConfigurationValue; $(#[$meta])* $vis enum $name { $($(#[$vmeta])* $vname = $val,)* @@ -140,7 +141,7 @@ impl ConfigurationValue for AtomicI64 { impl ConfigurationValue for RedisGILGuard { fn get(&self, ctx: &ConfigurationContext) -> RedisString { let value = self.lock(ctx); - RedisString::create(None, &value) + RedisString::create(None, value.as_str()) } fn set(&self, ctx: &ConfigurationContext, val: RedisString) -> Result<(), RedisError> { let mut value = self.lock(ctx); @@ -152,7 +153,7 @@ impl ConfigurationValue for RedisGILGuard { impl ConfigurationValue for Mutex { fn get(&self, _ctx: &ConfigurationContext) -> RedisString { let value = self.lock().unwrap(); - RedisString::create(None, &value) + RedisString::create(None, value.as_str()) } fn set(&self, _ctx: &ConfigurationContext, val: RedisString) -> Result<(), RedisError> { let mut value = self.lock().unwrap(); @@ -184,7 +185,7 @@ impl + 'static> ConfigrationPrivateData { // we know the GIL is held so it is safe to use Context::dummy(). let configuration_ctx = ConfigurationContext::new(); if let Err(e) = self.variable.set(&configuration_ctx, val) { - let error_msg = RedisString::create(None, &e.to_string()); + let error_msg = RedisString::create(None, e.to_string().as_str()); unsafe { *err = error_msg.take() }; return raw::REDISMODULE_ERR as i32; } @@ -366,7 +367,7 @@ extern "C" fn enum_configuration_set< match val { Ok(val) => private_data.set_val(name, val, err), Err(e) => { - let error_msg = RedisString::create(None, &e.to_string()); + let error_msg = RedisString::create(None, e.to_string().as_str()); unsafe { *err = error_msg.take() }; raw::REDISMODULE_ERR as i32 } @@ -427,7 +428,7 @@ pub fn register_enum_configuration, + args: &[RedisString], ) -> Result<(), RedisError> { if args.len() == 0 { return Ok(()); @@ -437,6 +438,7 @@ pub fn apply_module_args_as_configuration( "Arguments lenght is not devided by 2 (require to be read as module configuration).", )); } + let mut args = args.to_vec(); args.insert(0, ctx.create_string("set")); ctx.call( "config", diff --git a/src/context/call_reply.rs b/src/context/call_reply.rs index 3ee2edf5..020fc421 100644 --- a/src/context/call_reply.rs +++ b/src/context/call_reply.rs @@ -42,6 +42,11 @@ pub struct ErrorCallReply<'root> { _dummy: PhantomData<&'root ()>, } +pub enum ErrorReply<'root> { + Msg(String), + RedisError(ErrorCallReply<'root>), +} + impl<'root> ErrorCallReply<'root> { /// Convert ErrorCallReply to String. /// Return None data is not a valid utf8. @@ -59,7 +64,26 @@ impl<'root> ErrorCallReply<'root> { } } -impl<'root> Debug for ErrorCallReply<'root> { +impl<'root> ErrorReply<'root> { + /// Convert ErrorCallReply to String. + /// Return None data is not a valid utf8. + pub fn to_string(&self) -> Option { + match self { + ErrorReply::Msg(s) => Some(s.clone()), + ErrorReply::RedisError(r) => r.to_string(), + } + } + + /// Return the ErrorCallReply data as &[u8] + pub fn as_bytes(&self) -> &[u8] { + match self { + ErrorReply::Msg(s) => s.as_bytes(), + ErrorReply::RedisError(r) => r.as_bytes(), + } + } +} + +impl<'root> Debug for ErrorReply<'root> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!( f, @@ -332,16 +356,13 @@ impl<'root> VerbatimStringCallReply<'root> { /// Return None if the format is not a valid utf8. pub fn as_parts(&self) -> Option<(&str, &[u8])> { let mut len: usize = 0; - let format: *const u8 = std::ptr::null(); + let mut format: *const c_char = std::ptr::null(); let reply_string: *mut u8 = unsafe { - RedisModule_CallReplyVerbatim.unwrap()( - self.reply.as_ptr(), - &mut len, - &mut (format as *const c_char), - ) as *mut u8 + RedisModule_CallReplyVerbatim.unwrap()(self.reply.as_ptr(), &mut len, &mut format) + as *mut u8 }; Some(( - std::str::from_utf8(unsafe { slice::from_raw_parts(format, 3) }) + std::str::from_utf8(unsafe { slice::from_raw_parts(format as *const u8, 3) }) .ok() .unwrap(), unsafe { slice::from_raw_parts(reply_string, len) }, @@ -381,10 +402,10 @@ fn create_call_reply<'root>(reply: NonNull) -> CallResult< reply: reply, _dummy: PhantomData, })), - ReplyType::Error => Err(ErrorCallReply { + ReplyType::Error => Err(ErrorReply::RedisError(ErrorCallReply { reply: reply, _dummy: PhantomData, - }), + })), ReplyType::Array => Ok(CallReply::Array(ArrayCallReply { reply: reply, _dummy: PhantomData, @@ -426,4 +447,4 @@ pub(crate) fn create_root_call_reply<'root>( reply.map_or(Ok(CallReply::Unknown), |v| create_call_reply(v)) } -pub type CallResult<'root> = Result, ErrorCallReply<'root>>; +pub type CallResult<'root> = Result, ErrorReply<'root>>; diff --git a/src/context/mod.rs b/src/context/mod.rs index ac187703..d411daf6 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -1,5 +1,5 @@ use bitflags::bitflags; -use std::borrow::Borrow; +use std::cell::UnsafeCell; use std::ffi::CString; use std::os::raw::{c_char, c_int, c_long, c_longlong}; use std::ptr::{self, NonNull}; @@ -37,6 +37,7 @@ pub struct CallOptionsBuilder { options: String, } +#[derive(Clone)] pub struct CallOptions { options: CString, } @@ -100,7 +101,7 @@ impl CallOptionsBuilder { } /// Allow control the protocol version in which the replies will be returned. - pub fn resp_3(mut self, resp: CallOptionResp) -> CallOptionsBuilder { + pub fn resp(mut self, resp: CallOptionResp) -> CallOptionsBuilder { match resp { CallOptionResp::Auto => self.add_flag("0"), CallOptionResp::Resp2 => (), @@ -117,12 +118,80 @@ impl CallOptionsBuilder { } } +/// This struct allows logging when the Redis GIL is not acquired. +/// It is implemented `Send` and `Sync` so it can safely be used +/// from within different threads. +pub struct DetachContext { + ctx: UnsafeCell>>, +} + +impl Default for DetachContext { + fn default() -> Self { + DetachContext { + ctx: UnsafeCell::new(None), + } + } +} + +impl DetachContext { + pub fn log(&self, level: LogLevel, message: &str) { + let c = unsafe { &*self.ctx.get() }; + crate::logging::log_internal(c.map_or(ptr::null_mut(), |v| v.as_ptr()), level, message); + } + + pub fn log_debug(&self, message: &str) { + self.log(LogLevel::Debug, message); + } + + pub fn log_notice(&self, message: &str) { + self.log(LogLevel::Notice, message); + } + + pub fn log_verbose(&self, message: &str) { + self.log(LogLevel::Verbose, message); + } + + pub fn log_warning(&self, message: &str) { + self.log(LogLevel::Warning, message); + } + + pub fn set_context(&self, ctx: &Context) { + let curr = unsafe { &mut *self.ctx.get() }; + let ctx = unsafe { raw::RedisModule_GetDetachedThreadSafeContext.unwrap()(ctx.ctx) }; + *curr = NonNull::new(ctx); + } +} + +unsafe impl Send for DetachContext {} +unsafe impl Sync for DetachContext {} + /// `Context` is a structure that's designed to give us a high-level interface to /// the Redis module API by abstracting away the raw C FFI calls. pub struct Context { pub ctx: *mut raw::RedisModuleCtx, } +/// A guerd that protected a user that has +/// been set on a context using `autenticate_user`. +/// This guerd make sure to unset the user when freed. +/// It prevent privilege escalation security issues +/// that can happened by forgeting to unset the user. +pub struct ContextUserScope<'ctx> { + ctx: &'ctx Context, +} + +impl<'ctx> Drop for ContextUserScope<'ctx> { + fn drop(&mut self) { + self.ctx.deautenticate_user(); + } +} + +impl<'ctx> ContextUserScope<'ctx> { + fn new(ctx: &'ctx Context) -> ContextUserScope<'ctx> { + ContextUserScope { ctx } + } +} + pub struct StrCallArgs<'a> { is_owner: bool, args: Vec<*mut raw::RedisModuleString>, @@ -173,7 +242,7 @@ where } impl<'a> StrCallArgs<'a> { - fn args_mut(&mut self) -> &mut [*mut raw::RedisModuleString] { + pub(crate) fn args_mut(&mut self) -> &mut [*mut raw::RedisModuleString] { &mut self.args } } @@ -422,8 +491,13 @@ impl Context { raw::replicate_verbatim(self.ctx); } + /// Replicate command to the replica and AOF. + pub fn replicate<'a, T: Into>>(&self, command: &str, args: T) { + raw::replicate(self.ctx, command, args); + } + #[must_use] - pub fn create_string(&self, s: &str) -> RedisString { + pub fn create_string>>(&self, s: T) -> RedisString { RedisString::create(NonNull::new(self.ctx), s) } @@ -534,19 +608,20 @@ impl Context { /// Attach the given user to the current context so each operation performed from /// now on using this context will be validated againts this new user. /// Return Status::Ok on success and Status::Err or failure. - pub fn autenticate_user>(&self, user_name: T) -> raw::Status { - let user_name_blob: &[u8] = user_name.borrow(); - unsafe { - raw::RedisModule_AuthenticateClientWithACLUser.unwrap()( - self.ctx, - user_name_blob.as_ptr() as *const c_char, - user_name_blob.len(), - None, - ptr::null_mut(), - ptr::null_mut(), - ) + pub fn autenticate_user( + &self, + user_name: &RedisString, + ) -> Result, RedisError> { + let user = unsafe { raw::RedisModule_GetModuleUserFromUserName.unwrap()(user_name.inner) }; + if user.is_null() { + return Err(RedisError::Str("User does not exists or disabled")); } - .into() + unsafe { raw::RedisModule_SetContextUser.unwrap()(self.ctx, user) }; + Ok(ContextUserScope::new(self)) + } + + fn deautenticate_user(&self) { + unsafe { raw::RedisModule_SetContextUser.unwrap()(self.ctx, ptr::null_mut()) }; } /// Verify the the given user has the give ACL permission on the given key. diff --git a/src/context/thread_safe.rs b/src/context/thread_safe.rs index 50836cb7..e42069ad 100644 --- a/src/context/thread_safe.rs +++ b/src/context/thread_safe.rs @@ -164,7 +164,8 @@ impl ThreadSafeContext { /// similar to `std::sync::Mutex`. pub fn lock(&self) -> ContextGuard { unsafe { raw::RedisModule_ThreadSafeContextLock.unwrap()(self.ctx) }; - let ctx = Context::new(self.ctx); + let ctx = unsafe { raw::RedisModule_GetThreadSafeContext.unwrap()(ptr::null_mut()) }; + let ctx = Context::new(ctx); ContextGuard { ctx } } } diff --git a/src/include/redismodule.h b/src/include/redismodule.h index 4d45a237..efd2fb35 100644 --- a/src/include/redismodule.h +++ b/src/include/redismodule.h @@ -1211,7 +1211,7 @@ REDISMODULE_API int (*RedisModule_LoadConfigs)(RedisModuleCtx *ctx) REDISMODULE_ /* This is included inline inside each Redis module. */ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) REDISMODULE_ATTR_UNUSED; -static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { +static void RedisModule_InitAPI(RedisModuleCtx *ctx) { void *getapifuncptr = ((void**)ctx)[0]; RedisModule_GetApi = (int (*)(const char *, void *)) (unsigned long)getapifuncptr; REDISMODULE_GET_API(Alloc); @@ -1547,7 +1547,10 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int REDISMODULE_GET_API(RegisterStringConfig); REDISMODULE_GET_API(RegisterEnumConfig); REDISMODULE_GET_API(LoadConfigs); +} +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { + RedisModule_InitAPI(ctx); if (RedisModule_IsModuleNameBusy && RedisModule_IsModuleNameBusy(name)) return REDISMODULE_ERR; RedisModule_SetModuleAttribs(ctx,name,ver,apiver); return REDISMODULE_OK; diff --git a/src/lib.rs b/src/lib.rs index 5f468218..32476663 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,18 +28,21 @@ pub use crate::context::thread_safe::{DetachedFromClient, ThreadSafeContext}; #[cfg(feature = "experimental-api")] pub use crate::raw::NotifyEvent; -pub use crate::context::call_reply::{CallReply, CallResult}; pub use crate::configuration::ConfigurationValue; pub use crate::configuration::EnumConfigurationValue; +pub use crate::context::call_reply::{CallReply, CallResult, ErrorReply}; pub use crate::context::keys_cursor::KeysCursor; pub use crate::context::server_events; +pub use crate::context::thread_safe::ContextGuard; pub use crate::context::thread_safe::RedisGILGuard; +pub use crate::context::thread_safe::RedisLockIndicator; pub use crate::context::AclPermissions; pub use crate::context::CallOptionResp; pub use crate::context::CallOptions; pub use crate::context::CallOptionsBuilder; pub use crate::context::Context; pub use crate::context::ContextFlags; +pub use crate::context::DetachContext; pub use crate::raw::*; pub use crate::redismodule::*; use backtrace::Backtrace; @@ -78,3 +81,8 @@ pub fn base_info_func( func(ctx, for_crash_report); } } + +/// Initialize RedisModuleAPI without register as a module. +pub fn init_api(ctx: &Context) { + unsafe { crate::raw::Export_RedisModule_InitAPI(ctx.ctx) }; +} diff --git a/src/macros.rs b/src/macros.rs index 39bebb7d..bb1b5efb 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -57,7 +57,7 @@ macro_rules! redis_event_handler { ) -> c_int { let context = $crate::Context::new(ctx); - let redis_key = $crate::RedisString::from_ptr(key).unwrap(); + let redis_key = $crate::RedisString::string_as_slice(key); let event_str = unsafe { CStr::from_ptr(event) }; $event_handler( &context, @@ -108,7 +108,7 @@ macro_rules! redis_module { $(@$event_type:ident) +: $event_handler:expr ]),* $(,)* - ])? + ] $(,)* )? $(configurations: [ $(i64:[$([ $i64_configuration_name:expr, @@ -196,12 +196,6 @@ macro_rules! redis_module { let context = $crate::Context::new(ctx); let args = $crate::decode_args(ctx, argv, argc); - $( - if $init_func(&context, &args) == $crate::Status::Err { - return $crate::Status::Err as c_int; - } - )* - $( if (&$data_type).create_data_type(ctx).is_err() { return raw::Status::Err as c_int; @@ -242,7 +236,7 @@ macro_rules! redis_module { raw::RedisModule_LoadConfigs.unwrap()(ctx); $( if $use_module_args { - if let Err(e) = apply_module_args_as_configuration(&context, args) { + if let Err(e) = apply_module_args_as_configuration(&context, &args) { context.log_warning(&e.to_string()); return raw::Status::Err as c_int; } @@ -257,6 +251,12 @@ macro_rules! redis_module { return raw::Status::Err as c_int; } + $( + if $init_func(&context, &args) == $crate::Status::Err { + return $crate::Status::Err as c_int; + } + )* + raw::Status::Ok as c_int } diff --git a/src/raw.rs b/src/raw.rs index 479d3929..3f21fe44 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -8,9 +8,8 @@ extern crate num_traits; use std::cmp::Ordering; use std::ffi::{CStr, CString}; -use std::os::raw::{c_char, c_double, c_int, c_long, c_longlong}; +use std::os::raw::{c_char, c_double, c_int, c_long, c_longlong, c_void}; use std::ptr; -use std::ptr::NonNull; use std::slice; use bitflags::bitflags; @@ -20,7 +19,7 @@ use num_traits::FromPrimitive; use crate::error::Error; pub use crate::redisraw::bindings::*; -use crate::{Context, RedisString}; +use crate::{context::StrCallArgs, Context, RedisString}; use crate::{RedisBuffer, RedisError}; bitflags! { @@ -178,6 +177,8 @@ extern "C" { module_version: c_int, api_version: c_int, ) -> c_int; + + pub fn Export_RedisModule_InitAPI(ctx: *mut RedisModuleCtx) -> c_void; } /////////////////////////////////////////////////////////////// @@ -674,13 +675,13 @@ pub fn load_string_buffer(rdb: *mut RedisModuleIO) -> Result } #[allow(clippy::not_unsafe_ptr_arg_deref)] -pub fn replicate(ctx: *mut RedisModuleCtx, command: &str, args: &[&str]) -> Status { - let terminated_args: Vec = args - .iter() - .map(|s| RedisString::create(NonNull::new(ctx), s)) - .collect(); - - let inner_args: Vec<*mut RedisModuleString> = terminated_args.iter().map(|s| s.inner).collect(); +pub fn replicate<'a, T: Into>>( + ctx: *mut RedisModuleCtx, + command: &str, + args: T, +) -> Status { + let mut call_args: StrCallArgs = args.into(); + let final_args = call_args.args_mut(); let cmd = CString::new(command).unwrap(); @@ -689,8 +690,8 @@ pub fn replicate(ctx: *mut RedisModuleCtx, command: &str, args: &[&str]) -> Stat ctx, cmd.as_ptr(), FMT, - inner_args.as_ptr(), - terminated_args.len(), + final_args.as_ptr(), + final_args.len(), ) .into() } @@ -711,6 +712,18 @@ pub fn save_string(rdb: *mut RedisModuleIO, buf: &str) { unsafe { RedisModule_SaveStringBuffer.unwrap()(rdb, buf.as_ptr().cast::(), buf.len()) }; } +#[allow(clippy::not_unsafe_ptr_arg_deref)] +/// Save the `RedisString` into the RDB +pub fn save_redis_string(rdb: *mut RedisModuleIO, s: &RedisString) { + unsafe { RedisModule_SaveString.unwrap()(rdb, s.inner) }; +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +/// Save the `&[u8]` into the RDB +pub fn save_slice(rdb: *mut RedisModuleIO, buf: &[u8]) { + unsafe { RedisModule_SaveStringBuffer.unwrap()(rdb, buf.as_ptr().cast::(), buf.len()) }; +} + #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn save_double(rdb: *mut RedisModuleIO, val: f64) { unsafe { RedisModule_SaveDouble.unwrap()(rdb, val) }; @@ -819,7 +832,7 @@ pub fn get_keyspace_events() -> NotifyEvent { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Version { pub major: i32, pub minor: i32, diff --git a/src/rediserror.rs b/src/rediserror.rs index ecccddb8..5caa17d2 100644 --- a/src/rediserror.rs +++ b/src/rediserror.rs @@ -1,4 +1,4 @@ -use crate::context::call_reply::ErrorCallReply; +use crate::context::call_reply::{ErrorCallReply, ErrorReply}; pub use crate::raw; use std::ffi::CStr; use std::fmt; @@ -20,6 +20,15 @@ impl<'root> From> for RedisError { } } +impl<'root> From> for RedisError { + fn from(err: ErrorReply<'root>) -> Self { + RedisError::String( + err.to_string() + .unwrap_or("can not convert error into String".into()), + ) + } +} + impl RedisError { #[must_use] pub const fn nonexistent_key() -> Self { diff --git a/src/redismodule.c b/src/redismodule.c index 5e8afc49..0f18162b 100644 --- a/src/redismodule.c +++ b/src/redismodule.c @@ -7,3 +7,7 @@ int Export_RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { return RedisModule_Init(ctx, name, ver, apiver); } + +void Export_RedisModule_InitAPI(RedisModuleCtx *ctx) { + RedisModule_InitAPI(ctx); +} diff --git a/src/redismodule.rs b/src/redismodule.rs index 4b8072df..f46db017 100644 --- a/src/redismodule.rs +++ b/src/redismodule.rs @@ -11,9 +11,12 @@ use std::str; use std::str::Utf8Error; use std::string::FromUtf8Error; +use serde::de::{Error, SeqAccess}; + pub use crate::raw; pub use crate::rediserror::RedisError; pub use crate::redisvalue::RedisValue; +use crate::Context; pub type RedisResult = Result; @@ -124,11 +127,26 @@ impl RedisString { Self { ctx, inner } } + /// Safely clone `RedisString` + /// In general `RedisModuleString` is none atomic ref counted object. + /// So it is not safe to clone it if Redis GIL is not hold. + /// `safe_clone` gets a context reference which indicating that Redis GIL is held. + pub fn safe_clone(&self, ctx: &Context) -> Self { + // RedisString are *not* atomic ref counted, so we must get a lock indicator to clone them. + raw::string_retain_string(ctx.ctx, self.inner); + Self { + ctx: ctx.ctx, + inner: self.inner, + } + } + #[allow(clippy::not_unsafe_ptr_arg_deref)] - pub fn create(ctx: Option>, s: &str) -> Self { + pub fn create>>(ctx: Option>, s: T) -> Self { let ctx = ctx.map_or(std::ptr::null_mut(), |v| v.as_ptr()); let str = CString::new(s).unwrap(); - let inner = unsafe { raw::RedisModule_CreateString.unwrap()(ctx, str.as_ptr(), s.len()) }; + let inner = unsafe { + raw::RedisModule_CreateString.unwrap()(ctx, str.as_ptr(), str.as_bytes().len()) + }; Self { ctx, inner } } @@ -182,7 +200,7 @@ impl RedisString { Self::string_as_slice(self.inner) } - fn string_as_slice<'a>(ptr: *const raw::RedisModuleString) -> &'a [u8] { + pub fn string_as_slice<'a>(ptr: *const raw::RedisModuleString) -> &'a [u8] { let mut len: libc::size_t = 0; let bytes = unsafe { raw::RedisModule_StringPtrLen.unwrap()(ptr, &mut len) }; @@ -307,6 +325,53 @@ impl From for Vec { } } +impl serde::Serialize for RedisString { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(self.as_slice()) + } +} + +struct RedisStringVisitor; + +impl<'de> serde::de::Visitor<'de> for RedisStringVisitor { + type Value = RedisString; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("A bytes buffer") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + Ok(RedisString::create(None, v)) + } + + fn visit_seq(self, mut visitor: V) -> Result + where + V: SeqAccess<'de>, + { + let mut v = Vec::new(); + while let Some(elem) = visitor.next_element()? { + v.push(elem); + } + + Ok(RedisString::create(None, v.as_slice())) + } +} + +impl<'de> serde::Deserialize<'de> for RedisString { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_bytes(RedisStringVisitor) + } +} + /////////////////////////////////////////////////// #[derive(Debug)]