Skip to content

Commit

Permalink
fix(ext/node): rewrite SQLite named parameter handing (#28197)
Browse files Browse the repository at this point in the history
Allow bare named params and handle invalid param name. Also adds
`StatementSync#setAllowBareNamedParameters`

Fixes #28183
  • Loading branch information
littledivy authored Feb 20, 2025
1 parent 664d50f commit c1276d8
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 130 deletions.
1 change: 1 addition & 0 deletions ext/node/ops/sqlite/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ impl DatabaseSync {
inner: raw_stmt,
db: self.conn.clone(),
use_big_ints: Cell::new(false),
allow_bare_named_params: Cell::new(true),
is_iter_finished: false,
})
}
Expand Down
236 changes: 106 additions & 130 deletions ext/node/ops/sqlite/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub struct StatementSync {
pub db: Rc<RefCell<Option<rusqlite::Connection>>>,

pub use_big_ints: Cell<bool>,
pub allow_bare_named_params: Cell<bool>,

pub is_iter_finished: bool,
}

Expand Down Expand Up @@ -208,153 +210,32 @@ impl StatementSync {
Ok(Some(result))
}

// Bind the parameters to the prepared statement.
fn bind_params(
fn bind_value(
&self,
scope: &mut v8::HandleScope,
params: Option<&v8::FunctionCallbackArguments>,
) -> Result<(), SqliteError> {
let raw = self.inner;

if let Some(params) = params {
let len = params.length();
let mut param_count = 0;
for i in 0..len {
let value = params.get(i);

if value.is_number() {
let value = value.number_value(scope).unwrap();

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_double(raw, param_count + 1, value);
}
param_count += 1
} else if value.is_string() {
let value = value.to_rust_string_lossy(scope);

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
//
// SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data.
unsafe {
ffi::sqlite3_bind_text(
raw,
param_count + 1,
value.as_ptr() as *const _,
value.len() as i32,
ffi::SQLITE_TRANSIENT(),
);
}
param_count += 1;
} else if value.is_null() {
// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_null(raw, param_count + 1);
}
param_count += 1;
} else if value.is_array_buffer_view() {
let value: v8::Local<v8::ArrayBufferView> = value.try_into().unwrap();
let data = value.data();
let size = value.byte_length();

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
//
// SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data.
unsafe {
ffi::sqlite3_bind_blob(
raw,
param_count + 1,
data,
size as i32,
ffi::SQLITE_TRANSIENT(),
);
}
param_count += 1
} else if value.is_big_int() {
let value: v8::Local<v8::BigInt> = value.try_into().unwrap();
let (as_int, lossless) = value.i64_value();
if !lossless {
return Err(SqliteError::FailedBind(
"BigInt value is too large to bind",
));
}

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_int64(raw, param_count + 1, as_int);
}
param_count += 1;
} else if value.is_object() {
let value: v8::Local<v8::Object> = value.try_into().unwrap();
let maybe_keys =
value.get_property_names(scope, GetPropertyNamesArgs::default());

if let Some(keys) = maybe_keys {
let length = keys.length();
for i in 0..length {
if let Some(key) = keys.get_index(scope, i) {
if let Some(key_str) = key.to_string(scope) {
let key_str = key_str.to_rust_string_lossy(scope);
if let Some(value) = value.get(scope, key) {
self.bind_params_object(
scope,
key_str,
value,
param_count,
)?;
param_count += 1;
}
}
}
}
}
} else {
return Err(SqliteError::FailedBind("Unsupported type"));
}
}
}

Ok(())
}

//helper function that binds the object attribute to named param
fn bind_params_object(
&self,
scope: &mut v8::HandleScope,
key: String,
value: v8::Local<v8::Value>,
count: i32,
index: i32,
) -> Result<(), SqliteError> {
let raw = self.inner;

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_parameter_index(raw, key.as_ptr() as *const _);
}

if value.is_number() {
let value = value.number_value(scope).unwrap();

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_double(raw, count + 1, value);
ffi::sqlite3_bind_double(raw, index, value);
}
} else if value.is_string() {
let value = value.to_rust_string_lossy(scope);

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
//
// SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data.
unsafe {
ffi::sqlite3_bind_text(
raw,
count + 1,
index,
value.as_ptr() as *const _,
value.len() as i32,
ffi::SQLITE_TRANSIENT(),
Expand All @@ -364,7 +245,7 @@ impl StatementSync {
// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_null(raw, count + 1);
ffi::sqlite3_bind_null(raw, index);
}
} else if value.is_array_buffer_view() {
let value: v8::Local<v8::ArrayBufferView> = value.try_into().unwrap();
Expand All @@ -378,7 +259,7 @@ impl StatementSync {
unsafe {
ffi::sqlite3_bind_blob(
raw,
count + 1,
index,
data,
size as i32,
ffi::SQLITE_TRANSIENT(),
Expand All @@ -396,11 +277,101 @@ impl StatementSync {
// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_int64(raw, count + 1, as_int);
ffi::sqlite3_bind_int64(raw, index, as_int);
}
} else {
return Err(SqliteError::FailedBind("Unsupported type"));
}

Ok(())
}

// Bind the parameters to the prepared statement.
fn bind_params(
&self,
scope: &mut v8::HandleScope,
params: Option<&v8::FunctionCallbackArguments>,
) -> Result<(), SqliteError> {
let raw = self.inner;
let mut anon_start = 0;

if let Some(params) = params {
let param0 = params.get(0);

if param0.is_object() && !param0.is_array_buffer_view() {
let obj = v8::Local::<v8::Object>::try_from(param0).unwrap();
let keys = obj
.get_property_names(scope, GetPropertyNamesArgs::default())
.unwrap();

// Allow specifying named parameters without the SQLite prefix character to improve
// ergonomics. This can be disabled with `StatementSync#setAllowBareNamedParams`.
let mut bare_named_params = std::collections::HashMap::new();
if self.allow_bare_named_params.get() {
// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
let param_count = unsafe { ffi::sqlite3_bind_parameter_count(raw) };
for i in 1..=param_count {
// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
let bare_name = unsafe {
let name = ffi::sqlite3_bind_parameter_name(raw, i);
if name.is_null() {
continue;
}
std::ffi::CStr::from_ptr(name.offset(1)).to_bytes()
};

let e = bare_named_params.insert(bare_name, i);
if e.is_some() {
return Err(SqliteError::FailedBind("Duplicate named parameter"));
}
}
}

let len = keys.length();
for j in 0..len {
let key = keys.get_index(scope, j).unwrap();
let key_str = key.to_rust_string_lossy(scope);
let key_c = std::ffi::CString::new(key_str).unwrap();

// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
let mut r = unsafe {
ffi::sqlite3_bind_parameter_index(raw, key_c.as_ptr() as *const _)
};
if r == 0 {
let lookup = bare_named_params.get(key_c.as_bytes());
if let Some(index) = lookup {
r = *index;
}

if r == 0 {
return Err(SqliteError::FailedBind("Named parameter not found"));
}
}

let value = obj.get(scope, key).unwrap();
self.bind_value(scope, value, r)?;
}

anon_start += 1;
}

let mut anon_idx = 1;
for i in anon_start..params.length() {
// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
while !unsafe { ffi::sqlite3_bind_parameter_name(raw, anon_idx) }
.is_null()
{
anon_idx += 1;
}

let value = params.get(i);

self.bind_value(scope, value, anon_idx)?;

anon_idx += 1;
}
}

Ok(())
}
}
Expand Down Expand Up @@ -621,6 +592,11 @@ impl StatementSync {
Ok(iterator)
}

#[fast]
fn set_allow_bare_named_parameters(&self, enabled: bool) {
self.allow_bare_named_params.set(enabled);
}

#[fast]
fn set_read_big_ints(&self, enabled: bool) {
self.use_big_ints.set(enabled);
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_node/sqlite_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ Deno.test("[node/sqlite] query should handle mixed positional and named paramete
variable3: 2,
}]);

const result2 = db.prepare(query).all({ var2: 1, var1: "test" });
assertEquals(result2, [{
__proto__: null,
variable1: "test",
variable2: 1,
variable3: 2,
}]);

const stmt = db.prepare(query);
stmt.setAllowBareNamedParameters(false);
assertThrows(() => {
stmt.all({ var1: "test", var2: 1 });
});

db.close();
});

Expand Down

0 comments on commit c1276d8

Please sign in to comment.