diff --git a/ext/node/ops/sqlite/database.rs b/ext/node/ops/sqlite/database.rs index d0ddc79a0cd4e7..7f6d40bbf3d1e2 100644 --- a/ext/node/ops/sqlite/database.rs +++ b/ext/node/ops/sqlite/database.rs @@ -205,6 +205,7 @@ impl DatabaseSync { inner: raw_stmt, db: self.conn.clone(), use_big_ints: Cell::new(false), + allow_bare_named_params: Cell::new(true), is_iter_finished: false, }) } diff --git a/ext/node/ops/sqlite/statement.rs b/ext/node/ops/sqlite/statement.rs index 3682eec822a1d3..1c0a1d677ac203 100644 --- a/ext/node/ops/sqlite/statement.rs +++ b/ext/node/ops/sqlite/statement.rs @@ -29,6 +29,8 @@ pub struct StatementSync { pub db: Rc>>, pub use_big_ints: Cell, + pub allow_bare_named_params: Cell, + pub is_iter_finished: bool, } @@ -208,153 +210,32 @@ impl StatementSync { Ok(Some(result)) } - // Bind the parameters to the prepared statement. - fn bind_params( + fn bind_value( &self, scope: &mut v8::HandleScope, - params: Option<&v8::FunctionCallbackArguments>, - ) -> Result<(), SqliteError> { - let raw = self.inner; - - if let Some(params) = params { - let len = params.length(); - let mut param_count = 0; - for i in 0..len { - let value = params.get(i); - - if value.is_number() { - let value = value.number_value(scope).unwrap(); - - // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt - // as it lives as long as the StatementSync instance. - unsafe { - ffi::sqlite3_bind_double(raw, param_count + 1, value); - } - param_count += 1 - } else if value.is_string() { - let value = value.to_rust_string_lossy(scope); - - // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt - // as it lives as long as the StatementSync instance. - // - // SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data. - unsafe { - ffi::sqlite3_bind_text( - raw, - param_count + 1, - value.as_ptr() as *const _, - value.len() as i32, - ffi::SQLITE_TRANSIENT(), - ); - } - param_count += 1; - } else if value.is_null() { - // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt - // as it lives as long as the StatementSync instance. - unsafe { - ffi::sqlite3_bind_null(raw, param_count + 1); - } - param_count += 1; - } else if value.is_array_buffer_view() { - let value: v8::Local = value.try_into().unwrap(); - let data = value.data(); - let size = value.byte_length(); - - // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt - // as it lives as long as the StatementSync instance. - // - // SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data. - unsafe { - ffi::sqlite3_bind_blob( - raw, - param_count + 1, - data, - size as i32, - ffi::SQLITE_TRANSIENT(), - ); - } - param_count += 1 - } else if value.is_big_int() { - let value: v8::Local = value.try_into().unwrap(); - let (as_int, lossless) = value.i64_value(); - if !lossless { - return Err(SqliteError::FailedBind( - "BigInt value is too large to bind", - )); - } - - // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt - // as it lives as long as the StatementSync instance. - unsafe { - ffi::sqlite3_bind_int64(raw, param_count + 1, as_int); - } - param_count += 1; - } else if value.is_object() { - let value: v8::Local = value.try_into().unwrap(); - let maybe_keys = - value.get_property_names(scope, GetPropertyNamesArgs::default()); - - if let Some(keys) = maybe_keys { - let length = keys.length(); - for i in 0..length { - if let Some(key) = keys.get_index(scope, i) { - if let Some(key_str) = key.to_string(scope) { - let key_str = key_str.to_rust_string_lossy(scope); - if let Some(value) = value.get(scope, key) { - self.bind_params_object( - scope, - key_str, - value, - param_count, - )?; - param_count += 1; - } - } - } - } - } - } else { - return Err(SqliteError::FailedBind("Unsupported type")); - } - } - } - - Ok(()) - } - - //helper function that binds the object attribute to named param - fn bind_params_object( - &self, - scope: &mut v8::HandleScope, - key: String, value: v8::Local, - count: i32, + index: i32, ) -> Result<(), SqliteError> { let raw = self.inner; - - // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt - // as it lives as long as the StatementSync instance. - unsafe { - ffi::sqlite3_bind_parameter_index(raw, key.as_ptr() as *const _); - } - if value.is_number() { let value = value.number_value(scope).unwrap(); // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt // as it lives as long as the StatementSync instance. unsafe { - ffi::sqlite3_bind_double(raw, count + 1, value); + ffi::sqlite3_bind_double(raw, index, value); } } else if value.is_string() { let value = value.to_rust_string_lossy(scope); // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt // as it lives as long as the StatementSync instance. + // + // SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data. unsafe { ffi::sqlite3_bind_text( raw, - count + 1, + index, value.as_ptr() as *const _, value.len() as i32, ffi::SQLITE_TRANSIENT(), @@ -364,7 +245,7 @@ impl StatementSync { // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt // as it lives as long as the StatementSync instance. unsafe { - ffi::sqlite3_bind_null(raw, count + 1); + ffi::sqlite3_bind_null(raw, index); } } else if value.is_array_buffer_view() { let value: v8::Local = value.try_into().unwrap(); @@ -378,7 +259,7 @@ impl StatementSync { unsafe { ffi::sqlite3_bind_blob( raw, - count + 1, + index, data, size as i32, ffi::SQLITE_TRANSIENT(), @@ -396,11 +277,101 @@ impl StatementSync { // SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt // as it lives as long as the StatementSync instance. unsafe { - ffi::sqlite3_bind_int64(raw, count + 1, as_int); + ffi::sqlite3_bind_int64(raw, index, as_int); } } else { return Err(SqliteError::FailedBind("Unsupported type")); } + + Ok(()) + } + + // Bind the parameters to the prepared statement. + fn bind_params( + &self, + scope: &mut v8::HandleScope, + params: Option<&v8::FunctionCallbackArguments>, + ) -> Result<(), SqliteError> { + let raw = self.inner; + let mut anon_start = 0; + + if let Some(params) = params { + let param0 = params.get(0); + + if param0.is_object() && !param0.is_array_buffer_view() { + let obj = v8::Local::::try_from(param0).unwrap(); + let keys = obj + .get_property_names(scope, GetPropertyNamesArgs::default()) + .unwrap(); + + // Allow specifying named parameters without the SQLite prefix character to improve + // ergonomics. This can be disabled with `StatementSync#setAllowBareNamedParams`. + let mut bare_named_params = std::collections::HashMap::new(); + if self.allow_bare_named_params.get() { + // SAFETY: `raw` is a valid pointer to a sqlite3_stmt. + let param_count = unsafe { ffi::sqlite3_bind_parameter_count(raw) }; + for i in 1..=param_count { + // SAFETY: `raw` is a valid pointer to a sqlite3_stmt. + let bare_name = unsafe { + let name = ffi::sqlite3_bind_parameter_name(raw, i); + if name.is_null() { + continue; + } + std::ffi::CStr::from_ptr(name.offset(1)).to_bytes() + }; + + let e = bare_named_params.insert(bare_name, i); + if e.is_some() { + return Err(SqliteError::FailedBind("Duplicate named parameter")); + } + } + } + + let len = keys.length(); + for j in 0..len { + let key = keys.get_index(scope, j).unwrap(); + let key_str = key.to_rust_string_lossy(scope); + let key_c = std::ffi::CString::new(key_str).unwrap(); + + // SAFETY: `raw` is a valid pointer to a sqlite3_stmt. + let mut r = unsafe { + ffi::sqlite3_bind_parameter_index(raw, key_c.as_ptr() as *const _) + }; + if r == 0 { + let lookup = bare_named_params.get(key_c.as_bytes()); + if let Some(index) = lookup { + r = *index; + } + + if r == 0 { + return Err(SqliteError::FailedBind("Named parameter not found")); + } + } + + let value = obj.get(scope, key).unwrap(); + self.bind_value(scope, value, r)?; + } + + anon_start += 1; + } + + let mut anon_idx = 1; + for i in anon_start..params.length() { + // SAFETY: `raw` is a valid pointer to a sqlite3_stmt. + while !unsafe { ffi::sqlite3_bind_parameter_name(raw, anon_idx) } + .is_null() + { + anon_idx += 1; + } + + let value = params.get(i); + + self.bind_value(scope, value, anon_idx)?; + + anon_idx += 1; + } + } + Ok(()) } } @@ -621,6 +592,11 @@ impl StatementSync { Ok(iterator) } + #[fast] + fn set_allow_bare_named_parameters(&self, enabled: bool) { + self.allow_bare_named_params.set(enabled); + } + #[fast] fn set_read_big_ints(&self, enabled: bool) { self.use_big_ints.set(enabled); diff --git a/tests/unit_node/sqlite_test.ts b/tests/unit_node/sqlite_test.ts index b14d7356c8da92..c633a6d84d1a69 100644 --- a/tests/unit_node/sqlite_test.ts +++ b/tests/unit_node/sqlite_test.ts @@ -214,6 +214,20 @@ Deno.test("[node/sqlite] query should handle mixed positional and named paramete variable3: 2, }]); + const result2 = db.prepare(query).all({ var2: 1, var1: "test" }); + assertEquals(result2, [{ + __proto__: null, + variable1: "test", + variable2: 1, + variable3: 2, + }]); + + const stmt = db.prepare(query); + stmt.setAllowBareNamedParameters(false); + assertThrows(() => { + stmt.all({ var1: "test", var2: 1 }); + }); + db.close(); });