Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ext/node): rewrite SQLite named parameter handing #28197

Merged
merged 2 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/node/ops/sqlite/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ impl DatabaseSync {
inner: raw_stmt,
db: self.conn.clone(),
use_big_ints: Cell::new(false),
allow_bare_named_params: Cell::new(true),
is_iter_finished: false,
})
}
Expand Down
236 changes: 106 additions & 130 deletions ext/node/ops/sqlite/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub struct StatementSync {
pub db: Rc<RefCell<Option<rusqlite::Connection>>>,

pub use_big_ints: Cell<bool>,
pub allow_bare_named_params: Cell<bool>,

pub is_iter_finished: bool,
}

Expand Down Expand Up @@ -208,153 +210,32 @@ impl StatementSync {
Ok(Some(result))
}

// Bind the parameters to the prepared statement.
fn bind_params(
fn bind_value(
&self,
scope: &mut v8::HandleScope,
params: Option<&v8::FunctionCallbackArguments>,
) -> Result<(), SqliteError> {
let raw = self.inner;

if let Some(params) = params {
let len = params.length();
let mut param_count = 0;
for i in 0..len {
let value = params.get(i);

if value.is_number() {
let value = value.number_value(scope).unwrap();

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_double(raw, param_count + 1, value);
}
param_count += 1
} else if value.is_string() {
let value = value.to_rust_string_lossy(scope);

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
//
// SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data.
unsafe {
ffi::sqlite3_bind_text(
raw,
param_count + 1,
value.as_ptr() as *const _,
value.len() as i32,
ffi::SQLITE_TRANSIENT(),
);
}
param_count += 1;
} else if value.is_null() {
// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_null(raw, param_count + 1);
}
param_count += 1;
} else if value.is_array_buffer_view() {
let value: v8::Local<v8::ArrayBufferView> = value.try_into().unwrap();
let data = value.data();
let size = value.byte_length();

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
//
// SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data.
unsafe {
ffi::sqlite3_bind_blob(
raw,
param_count + 1,
data,
size as i32,
ffi::SQLITE_TRANSIENT(),
);
}
param_count += 1
} else if value.is_big_int() {
let value: v8::Local<v8::BigInt> = value.try_into().unwrap();
let (as_int, lossless) = value.i64_value();
if !lossless {
return Err(SqliteError::FailedBind(
"BigInt value is too large to bind",
));
}

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_int64(raw, param_count + 1, as_int);
}
param_count += 1;
} else if value.is_object() {
let value: v8::Local<v8::Object> = value.try_into().unwrap();
let maybe_keys =
value.get_property_names(scope, GetPropertyNamesArgs::default());

if let Some(keys) = maybe_keys {
let length = keys.length();
for i in 0..length {
if let Some(key) = keys.get_index(scope, i) {
if let Some(key_str) = key.to_string(scope) {
let key_str = key_str.to_rust_string_lossy(scope);
if let Some(value) = value.get(scope, key) {
self.bind_params_object(
scope,
key_str,
value,
param_count,
)?;
param_count += 1;
}
}
}
}
}
} else {
return Err(SqliteError::FailedBind("Unsupported type"));
}
}
}

Ok(())
}

//helper function that binds the object attribute to named param
fn bind_params_object(
&self,
scope: &mut v8::HandleScope,
key: String,
value: v8::Local<v8::Value>,
count: i32,
index: i32,
) -> Result<(), SqliteError> {
let raw = self.inner;

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_parameter_index(raw, key.as_ptr() as *const _);
}

if value.is_number() {
let value = value.number_value(scope).unwrap();

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_double(raw, count + 1, value);
ffi::sqlite3_bind_double(raw, index, value);
}
} else if value.is_string() {
let value = value.to_rust_string_lossy(scope);

// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
//
// SQLITE_TRANSIENT is used to indicate that SQLite should make a copy of the data.
unsafe {
ffi::sqlite3_bind_text(
raw,
count + 1,
index,
value.as_ptr() as *const _,
value.len() as i32,
ffi::SQLITE_TRANSIENT(),
Expand All @@ -364,7 +245,7 @@ impl StatementSync {
// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_null(raw, count + 1);
ffi::sqlite3_bind_null(raw, index);
}
} else if value.is_array_buffer_view() {
let value: v8::Local<v8::ArrayBufferView> = value.try_into().unwrap();
Expand All @@ -378,7 +259,7 @@ impl StatementSync {
unsafe {
ffi::sqlite3_bind_blob(
raw,
count + 1,
index,
data,
size as i32,
ffi::SQLITE_TRANSIENT(),
Expand All @@ -396,11 +277,101 @@ impl StatementSync {
// SAFETY: `self.inner` is a valid pointer to a sqlite3_stmt
// as it lives as long as the StatementSync instance.
unsafe {
ffi::sqlite3_bind_int64(raw, count + 1, as_int);
ffi::sqlite3_bind_int64(raw, index, as_int);
}
} else {
return Err(SqliteError::FailedBind("Unsupported type"));
}

Ok(())
}

// Bind the parameters to the prepared statement.
fn bind_params(
&self,
scope: &mut v8::HandleScope,
params: Option<&v8::FunctionCallbackArguments>,
) -> Result<(), SqliteError> {
let raw = self.inner;
let mut anon_start = 0;

if let Some(params) = params {
let param0 = params.get(0);

if param0.is_object() && !param0.is_array_buffer_view() {
let obj = v8::Local::<v8::Object>::try_from(param0).unwrap();
let keys = obj
.get_property_names(scope, GetPropertyNamesArgs::default())
.unwrap();

// Allow specifying named parameters without the SQLite prefix character to improve
// ergonomics. This can be disabled with `StatementSync#setAllowBareNamedParams`.
let mut bare_named_params = std::collections::HashMap::new();
if self.allow_bare_named_params.get() {
// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
let param_count = unsafe { ffi::sqlite3_bind_parameter_count(raw) };
for i in 1..=param_count {
// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
let bare_name = unsafe {
let name = ffi::sqlite3_bind_parameter_name(raw, i);
if name.is_null() {
continue;
}
std::ffi::CStr::from_ptr(name.offset(1)).to_bytes()
};

let e = bare_named_params.insert(bare_name, i);
if e.is_some() {
return Err(SqliteError::FailedBind("Duplicate named parameter"));
}
}
}

let len = keys.length();
for j in 0..len {
let key = keys.get_index(scope, j).unwrap();
let key_str = key.to_rust_string_lossy(scope);
let key_c = std::ffi::CString::new(key_str).unwrap();

// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
let mut r = unsafe {
ffi::sqlite3_bind_parameter_index(raw, key_c.as_ptr() as *const _)
};
if r == 0 {
let lookup = bare_named_params.get(key_c.as_bytes());
if let Some(index) = lookup {
r = *index;
}

if r == 0 {
return Err(SqliteError::FailedBind("Named parameter not found"));
}
}

let value = obj.get(scope, key).unwrap();
self.bind_value(scope, value, r)?;
}

anon_start += 1;
}

let mut anon_idx = 1;
for i in anon_start..params.length() {
// SAFETY: `raw` is a valid pointer to a sqlite3_stmt.
while !unsafe { ffi::sqlite3_bind_parameter_name(raw, anon_idx) }
.is_null()
{
anon_idx += 1;
}

let value = params.get(i);

self.bind_value(scope, value, anon_idx)?;

anon_idx += 1;
}
}

Ok(())
}
}
Expand Down Expand Up @@ -621,6 +592,11 @@ impl StatementSync {
Ok(iterator)
}

#[fast]
fn set_allow_bare_named_parameters(&self, enabled: bool) {
self.allow_bare_named_params.set(enabled);
}

#[fast]
fn set_read_big_ints(&self, enabled: bool) {
self.use_big_ints.set(enabled);
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_node/sqlite_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ Deno.test("[node/sqlite] query should handle mixed positional and named paramete
variable3: 2,
}]);

const result2 = db.prepare(query).all({ var2: 1, var1: "test" });
assertEquals(result2, [{
__proto__: null,
variable1: "test",
variable2: 1,
variable3: 2,
}]);

const stmt = db.prepare(query);
stmt.setAllowBareNamedParameters(false);
assertThrows(() => {
stmt.all({ var1: "test", var2: 1 });
});

db.close();
});

Expand Down