From 1492bc2c440d829cbfd2055e89048591523432a1 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 4 Feb 2023 13:28:17 -0800 Subject: [PATCH] wip: write rust driver_manager --- .github/workflows/rust.yml | 100 ++++ CONTRIBUTING.md | 10 + rust/.gitignore | 18 + rust/Cargo.toml | 36 ++ rust/src/driver_manager.rs | 892 ++++++++++++++++++++++++++++++ rust/src/error.rs | 396 +++++++++++++ rust/src/ffi.rs | 825 +++++++++++++++++++++++++++ rust/src/interface.rs | 324 +++++++++++ rust/src/lib.rs | 219 ++++++++ rust/tests/test_driver_manager.rs | 222 ++++++++ 10 files changed, 3042 insertions(+) create mode 100644 .github/workflows/rust.yml create mode 100644 rust/.gitignore create mode 100644 rust/Cargo.toml create mode 100644 rust/src/driver_manager.rs create mode 100644 rust/src/error.rs create mode 100644 rust/src/ffi.rs create mode 100644 rust/src/interface.rs create mode 100644 rust/src/lib.rs create mode 100644 rust/tests/test_driver_manager.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000000..0ce2b301ac --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: Rust + +on: + pull_request: + branches: + - main + paths: + - "rust/**" + - ".github/workflows/rust.yml" + push: + paths: + - "rust/**" + - ".github/workflows/rust.yml" + +concurrency: + group: ${{ github.repository }}-${{ github.ref }}-${{ github.workflow }} + cancel-in-progress: true + +permissions: + contents: read + +defaults: + run: + working-directory: rust + +jobs: + rust: + strategy: + matrix: + os: [windows-latest, macos-latest, ubuntu-latest] + name: "Rust ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + persist-credentials: false + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: default + toolchain: stable + override: true + - uses: Swatinem/rust-cache@v2 + - name: Check format + run: cargo fmt -- --check + - name: Clippy + run: cargo clippy --tests + - name: Install sqlite (Windows) + if: matrix.os == 'windows-latest' + shell: cmd + run: | + choco install sqlite + cd /D C:\ProgramData\chocolatey\lib\SQLite\tools + call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + lib /machine:x64 /def:sqlite3.def /out:sqlite3.lib + echo "C:\ProgramData\chocolatey\lib\SQLite\tools" >> $GITHUB_PATH + - name: Build Driver SQLite (Windows) + if: matrix.os == 'windows-latest' + working-directory: . + shell: pwsh + env: + BUILD_ALL: "0" + BUILD_DRIVER_SQLITE: "1" + ADBC_BUILD_TESTS: OFF + run: | + .\ci\scripts\cpp_build.ps1 $pwd $pwd\build $pwd\rust\build + - name: Build Driver SQLite (Unix) + if: matrix.os != 'windows-latest' + working-directory: . 
+ shell: bash -l {0} + env: + BUILD_ALL: "0" + BUILD_DRIVER_SQLITE: "1" + run: | + ./ci/scripts/cpp_build.sh "$(pwd)" "$(pwd)/build" "$(pwd)/rust/build" + - name: Test + run: | + export LD_LIBRARY_PATH=$(pwd)/build/lib + export DYLD_LIBRARY_PATH=$(pwd)/build/lib + cargo test + - name: Check docs + run: cargo doc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0e80e1d816..c4c960dfd9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -194,6 +194,16 @@ $ pytest -vvx The Ruby libraries are bindings around the GLib libraries. +### Rust + +The Rust components are a standard Rust project. + +```shell +$ cd rust +# Build and run tests +$ cargo test +``` + ## Opening a Pull Request Before opening a pull request, please run the static checks, which are diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 0000000000..8abe0fcc17 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +Cargo.lock diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000000..c54307d887 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-adbc" +version = "0.1.0" +edition = "2021" +rust-version = "1.62" +description = "Rust implementation of Arrow Database Connectivity (ADBC)" +homepage = "https://arrow.apache.org/adbc/" +repository = "https://github.com/apache/arrow-adbc" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = ["arrow", "database", "sql"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { version = "32.0.0", features = ["ffi"], default_features = false} +libloading = "0.7" + +# TODO: support arrow2 with a non-default feature. diff --git a/rust/src/driver_manager.rs b/rust/src/driver_manager.rs new file mode 100644 index 0000000000..90cfa3a670 --- /dev/null +++ b/rust/src/driver_manager.rs @@ -0,0 +1,892 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Load and use ADBC drivers. +//! +//! ## Loading a driver +//! +//! Drivers are initialized using a function provided by the driver as a main +//! entrypoint, canonically called `AdbcDriverInit`. (Although many will use a +//! different name to support statically linking multiple drivers within the +//! same program.) +//! +//! To load from a function, use [AdbcDriver::load_from_init]. +//! +//! To load from a dynamic library, use [AdbcDriver::load]. +//! +//! ## Using across threads +//! +//! [AdbcDriver] and [AdbcDatabase] can be used across multiple threads. They +//! hold their inner implementations within [std::sync::Arc], so they are +//! cheaply copy-able. +//! +//! [AdbcConnection] should not be used across multiple threads. Driver +//! implementations do not guarantee connection APIs are safe to call from +//! multiple threads, unless calls are carefully sequenced. So instead of using +//! the same connection across multiple threads, create a connection for each +//! thread. [AdbcConnectionBuilder] is [core::marker::Send], so it can be moved +//! to a new thread before being initialized into an [AdbcConnection]. [AdbcConnection] +//! holds its inner data in a [std::rc::Rc], so it is also cheaply copyable. 
+ +use std::{ + cell::RefCell, + ffi::{c_void, CStr, CString}, + ops::{Deref, DerefMut}, + ptr::{null, null_mut}, + rc::Rc, + sync::{Arc, RwLock}, +}; + +use arrow::{ + array::{export_array_into_raw, StringArray, StructArray}, + datatypes::{DataType, Field, Schema}, + error::ArrowError, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, + ffi_stream::{export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream}, + record_batch::{RecordBatch, RecordBatchReader}, +}; + +use crate::{ + error::{AdbcError, AdbcStatusCode, FFI_AdbcError}, + ffi::{ + driver_function_stubs, FFI_AdbcConnection, FFI_AdbcDatabase, FFI_AdbcDriver, + FFI_AdbcPartitions, FFI_AdbcStatement, + }, + interface::{ConnectionApi, PartitionedStatementResult, StatementApi, StatementResult}, +}; + +/// An error from an ADBC driver. +#[derive(Debug, Clone)] +pub struct AdbcDriverManagerError { + pub message: String, + pub vendor_code: i32, + pub sqlstate: [i8; 5usize], + pub status_code: AdbcStatusCode, +} + +impl std::fmt::Display for AdbcDriverManagerError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {} (sqlstate: {:?}, vendor_code: {})", + self.status_code, self.message, self.sqlstate, self.vendor_code + ) + } +} + +impl std::error::Error for AdbcDriverManagerError { + fn description(&self) -> &str { + &self.message + } +} + +pub type Result = std::result::Result; + +impl From for AdbcDriverManagerError { + fn from(value: T) -> Self { + Self { + message: value.message().to_string(), + vendor_code: value.vendor_code(), + sqlstate: value.sqlstate(), + status_code: value.status_code(), + } + } +} + +impl From for AdbcDriverManagerError { + fn from(value: libloading::Error) -> Self { + match value { + // Error from UNIX + libloading::Error::DlOpen { desc } => Self { + message: format!("{desc:?}"), + vendor_code: -1, + sqlstate: [0; 5], + status_code: AdbcStatusCode::Internal, + }, + // Error from Windows + libloading::Error::LoadLibraryExW { source } => Self 
{ + message: format!("{source:?}"), + vendor_code: -1, + sqlstate: [0; 5], + status_code: AdbcStatusCode::Internal, + }, + // The remaining errors either shouldn't be relevant or are unknown and + // have no additional context on them. + _ => Self { + message: "Unknown error while loading shared library".to_string(), + vendor_code: -1, + sqlstate: [0; 5], + status_code: AdbcStatusCode::Unknown, + }, + } + } +} + +/// Convert ADBC-style status & error into our Result type. +fn check_status(status: AdbcStatusCode, error: &FFI_AdbcError) -> Result<()> { + if status == AdbcStatusCode::Ok { + Ok(()) + } else { + let message = unsafe { CStr::from_ptr(error.message) } + .to_string_lossy() + .to_string(); + + Err(AdbcDriverManagerError { + message, + vendor_code: error.vendor_code, + sqlstate: error.sqlstate, + status_code: status, + }) + } +} + +/// An internal safe wrapper around the driver +/// +/// If applicable, keeps the loaded dynamic library in scope as long as the +/// FFI_AdbcDriver so that all it's function pointers remain valid. +struct AdbcDriverInner { + driver: FFI_AdbcDriver, + _library: Option, +} + +/// Call a ADBC driver method on a [AdbcDriverInner], or else a stub function. +macro_rules! driver_method { + ($driver_inner:expr, $func_name:ident) => { + $driver_inner + .driver + .$func_name + .unwrap_or(driver_function_stubs::$func_name) + }; +} + +/// Signature of an ADBC driver init function. +pub type AdbcDriverInitFunc = unsafe extern "C" fn( + version: ::std::os::raw::c_int, + driver: *mut ::std::os::raw::c_void, + error: *mut crate::error::FFI_AdbcError, +) -> crate::error::AdbcStatusCode; + +/// A handle to an ADBC driver. +/// +/// The internal data is held behind a [std::sync::Arc], so it is cheaply-copyable. +#[derive(Clone)] +pub struct AdbcDriver { + inner: Arc, +} + +impl AdbcDriver { + /// Load a driver from a dynamic library. 
+ /// + /// Will attempt to load the dynamic library with the given `name`, find the + /// symbol with name `entrypoint` (defaults to "AdbcDriverInit"), and then + /// call create the driver using the resolved function. + /// + /// `name` should **not** include any platform-specific prefixes or suffixes. + /// For example, use `"adbc_driver_sqlite"` rather than `"libadbc_driver_sqlite.so"`. + pub fn load(name: &str, entrypoint: Option<&[u8]>, version: i32) -> Result { + // Safety: because the driver we are loading contains pointers to functions + // in this library, we must keep it loaded as long as the driver is alive. + let library = unsafe { libloading::Library::new(libloading::library_filename(name))? }; + + let entrypoint = entrypoint.unwrap_or(b"AdbcDriverInit"); + let init_func: libloading::Symbol = unsafe { library.get(entrypoint)? }; + + let driver = Self::load_impl(&init_func, version)?; + Ok(Self { + inner: Arc::new(AdbcDriverInner { + driver, + _library: Some(library), + }), + }) + } + + /// Load a driver from an initialization function. + pub fn load_from_init(init_func: &AdbcDriverInitFunc, version: i32) -> Result { + let driver = Self::load_impl(init_func, version)?; + Ok(Self { + inner: Arc::new(AdbcDriverInner { + driver, + _library: None, + }), + }) + } + + fn load_impl(init_func: &AdbcDriverInitFunc, version: i32) -> Result { + let mut error = FFI_AdbcError::empty(); + let mut driver = FFI_AdbcDriver::empty(); + + let status = unsafe { + init_func( + version, + &mut driver as *mut FFI_AdbcDriver as *mut c_void, + &mut error, + ) + }; + check_status(status, &error)?; + + Ok(driver) + } + + /// Create a new database builder to initialize a database. 
+ pub fn new_database(&self) -> Result { + let mut inner = FFI_AdbcDatabase::empty(); + + let mut error = FFI_AdbcError::empty(); + let database_new = driver_method!(self.inner, database_new); + let status = unsafe { database_new(&mut inner, &mut error) }; + + check_status(status, &error)?; + + Ok(AdbcDatabaseBuilder { + inner, + driver: self.inner.clone(), + }) + } +} + +fn str_to_cstring(value: &str) -> Result { + match CString::new(value) { + Ok(out) => Ok(out), + Err(err) => Err(AdbcDriverManagerError { + message: format!( + "Null character in string at position {}", + err.nul_position() + ), + vendor_code: -1, + sqlstate: [0; 5], + status_code: AdbcStatusCode::InvalidArguments, + }), + } +} + +/// Builder for an [AdbcDatabase]. +/// +/// Use this to set options on a database. While some databases may allow setting +/// options after initialization, many do not. +pub struct AdbcDatabaseBuilder { + inner: FFI_AdbcDatabase, + driver: Arc, +} + +impl AdbcDatabaseBuilder { + pub fn set_option(mut self, key: &str, value: &str) -> Result { + let mut error = FFI_AdbcError::empty(); + let key = str_to_cstring(key)?; + let value = str_to_cstring(value)?; + let set_option = driver_method!(self.driver, database_set_option); + let status = + unsafe { set_option(&mut self.inner, key.as_ptr(), value.as_ptr(), &mut error) }; + + check_status(status, &error)?; + + Ok(self) + } + + pub fn init(self) -> Result { + Ok(AdbcDatabase { + inner: Arc::new(RwLock::new(self.inner)), + driver: self.driver, + }) + } +} + +// Safety: the only thing in the builder that isn't Send + Sync is the +// FFI_AdbcDatabase within the AdbcDriverInner. But the builder has exclusive +// access to that value, since it was created when the builder was constructed +// and there is no public access to it. +unsafe impl Send for AdbcDatabaseBuilder {} +unsafe impl Sync for AdbcDatabaseBuilder {} + +/// An ADBC database handle. +/// +/// See [crate::interface::DatabaseApi] for more details. 
+#[derive(Clone)] +pub struct AdbcDatabase { + // In general, ADBC objects allow serialized access from multiple threads, + // but not concurrent access. Specific implementations may permit + // multiple threads. To support safe access to all drivers, we wrap them in + // RwLock. + inner: Arc>, + driver: Arc, +} + +impl AdbcDatabase { + /// Set an option on the database. + /// + /// Some drivers may not support setting options after initialization and + /// instead return an error. So when possible prefer setting options on the + /// builder. + pub fn set_option(self, key: &str, value: &str) -> Result<()> { + let mut error = FFI_AdbcError::empty(); + let key = str_to_cstring(key)?; + let value = str_to_cstring(value)?; + + let mut inner_mut = self + .inner + .write() + .expect("Read-write lock of AdbcDatabase was poisoned."); + let status = unsafe { + let set_option = driver_method!(self.driver, database_set_option); + set_option( + inner_mut.deref_mut(), + key.as_ptr(), + value.as_ptr(), + &mut error, + ) + }; + + check_status(status, &error)?; + + Ok(()) + } + + /// Get a connection builder to create a [AdbcConnection]. + pub fn new_connection(&self) -> Result { + let mut inner = FFI_AdbcConnection::empty(); + + let mut error = FFI_AdbcError::empty(); + let status = unsafe { + let connection_new = driver_method!(self.driver, connection_new); + connection_new(&mut inner, &mut error) + }; + + check_status(status, &error)?; + + Ok(AdbcConnectionBuilder { + inner, + database: self.inner.clone(), + driver: self.driver.clone(), + }) + } +} + +// Safety: the only thing in the builder that isn't Send + Sync is the +// FFI_AdbcDatabase within the AdbcDriverInner. The builder ensures it doesn't +// have multiple references to that before this struct is created. And within +// this struct, the value is wrapped in a RwLock to manage access. +unsafe impl Send for AdbcDatabase {} +unsafe impl Sync for AdbcDatabase {} + +/// Builder for an [AdbcConnection]. 
+pub struct AdbcConnectionBuilder { + inner: FFI_AdbcConnection, + database: Arc>, + driver: Arc, +} + +impl AdbcConnectionBuilder { + /// Set an option on a connection. + pub fn set_option(mut self, key: &str, value: &str) -> Result { + let mut error = FFI_AdbcError::empty(); + let key = str_to_cstring(key)?; + let value = str_to_cstring(value)?; + let status = unsafe { + let set_option = driver_method!(self.driver, connection_set_option); + set_option(&mut self.inner, key.as_ptr(), value.as_ptr(), &mut error) + }; + + check_status(status, &error)?; + + Ok(self) + } + + /// Initialize the connection. + /// + /// [AdbcConnection] is not [core::marker::Send], so move the builder to + /// the destination thread before initializing. + pub fn init(mut self) -> Result { + let mut error = FFI_AdbcError::empty(); + + let mut database_mut = self + .database + .write() + .expect("Read-write lock of AdbcDatabase was poisoned."); + + let connection_init = driver_method!(self.driver, connection_init); + let status = + unsafe { connection_init(&mut self.inner, database_mut.deref_mut(), &mut error) }; + + check_status(status, &error)?; + + Ok(AdbcConnection { + inner: Rc::new(RefCell::new(self.inner)), + driver: self.driver, + }) + } +} + +// Safety: There are only two things within the struct that are not Sync + Send. +// FFI_AdbcDatabase is not thread-safe, but as wrapped is Sync + Send in the same +// way as described for AdbcDatabase. And the inner FFI_AdbcConnection is +// guaranteed to have no references to it outside the builder, since it was +// created at the same time as the builder and there is no way to get a reference +// to it from this struct. +unsafe impl Send for AdbcConnectionBuilder {} +unsafe impl Sync for AdbcConnectionBuilder {} + +/// An ADBC Connection associated with the driver. +/// +/// Connections should be used on a single thread. To use a driver from multiple +/// threads, create a connection for each thread. 
+/// +/// See [ConnectionApi] for details of the methods. +pub struct AdbcConnection { + inner: Rc>, + driver: Arc, +} + +impl ConnectionApi for AdbcConnection { + type Error = AdbcDriverManagerError; + + fn set_option(&self, key: &str, value: &str) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + let key = str_to_cstring(key)?; + let value = str_to_cstring(value)?; + + let set_option = driver_method!(self.driver, connection_set_option); + let status = unsafe { + set_option( + self.inner.borrow_mut().deref_mut(), + key.as_ptr(), + value.as_ptr(), + &mut error, + ) + }; + + check_status(status, &error)?; + + Ok(()) + } + + /// Get the valid table types for the database. + /// + /// For example, in sqlite the table types are "view" and "table". + /// + /// This can error if not implemented by the driver. + fn get_table_types(&self) -> std::result::Result, Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let mut reader = FFI_ArrowArrayStream::empty(); + + let get_table_types = driver_method!(self.driver, connection_get_table_types); + let status = unsafe { + get_table_types(self.inner.borrow_mut().deref_mut(), &mut reader, &mut error) + }; + check_status(status, &error)?; + + let reader = unsafe { ArrowArrayStreamReader::from_raw(&mut reader)? 
}; + + let expected_schema = Schema::new(vec![Field::new("table_type", DataType::Utf8, false)]); + let schema_mismatch_error = |found_schema| AdbcDriverManagerError { + message: format!("Driver returned unexpected schema: {found_schema:?}"), + vendor_code: -1, + sqlstate: [0; 5], + status_code: AdbcStatusCode::Internal, + }; + if reader.schema().deref() != &expected_schema { + return Err(schema_mismatch_error(reader.schema())); + } + + let batches: Vec = reader.collect::>()?; + + let mut out: Vec = + Vec::with_capacity(batches.iter().map(|batch| batch.num_rows()).sum()); + for batch in &batches { + if batch.schema().deref() != &expected_schema { + return Err(schema_mismatch_error(batch.schema())); + } + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); // We just asserted above this is a StringArray. + for value in column.into_iter().flatten() { + out.push(value.to_string()); + } + } + + Ok(out) + } + + fn get_info( + &self, + info_codes: &[u32], + ) -> std::result::Result, Self::Error> { + // TODO: Get this to return a more usable type + let mut error = FFI_AdbcError::empty(); + + let mut reader = FFI_ArrowArrayStream::empty(); + + let get_info = driver_method!(self.driver, connection_get_info); + let status = unsafe { + get_info( + self.inner.borrow_mut().deref_mut(), + info_codes.as_ptr(), + info_codes.len(), + &mut reader, + &mut error, + ) + }; + check_status(status, &error)?; + + let reader = unsafe { ArrowArrayStreamReader::from_raw(&mut reader)? 
}; + + Ok(Box::new(reader)) + } + + fn get_objects( + &self, + depth: crate::ffi::AdbcObjectDepth, + catalog: Option<&str>, + db_schema: Option<&str>, + table_name: Option<&str>, + table_type: Option<&[&str]>, + column_name: Option<&str>, + ) -> std::result::Result, Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let mut reader = FFI_ArrowArrayStream::empty(); + + let catalog = catalog.map(str_to_cstring).transpose()?; + let catalog_ptr = catalog.map(|s| s.as_ptr()).unwrap_or(null()); + + let db_schema = db_schema.map(str_to_cstring).transpose()?; + let db_schema_ptr = db_schema.map(|s| s.as_ptr()).unwrap_or(null()); + + let table_name = table_name.map(str_to_cstring).transpose()?; + let table_name_ptr = table_name.map(|s| s.as_ptr()).unwrap_or(null()); + + let column_name = column_name.map(str_to_cstring).transpose()?; + let column_name_ptr = column_name.map(|s| s.as_ptr()).unwrap_or(null()); + + let table_type: Vec = match table_type { + Some(table_type) => table_type + .iter() + .map(|&s| str_to_cstring(s)) + .collect::>()?, + None => Vec::new(), + }; + let mut table_type_ptrs: Vec<_> = table_type.iter().map(|s| s.as_ptr()).collect(); + // Make sure the array is null-terminated + table_type_ptrs.push(null()); + + let get_objects = driver_method!(self.driver, connection_get_objects); + let status = unsafe { + get_objects( + self.inner.borrow_mut().deref_mut(), + depth, + catalog_ptr, + db_schema_ptr, + table_name_ptr, + if table_type_ptrs.len() == 1 { + // Just null + null() + } else { + table_type_ptrs.as_ptr() + }, + column_name_ptr, + &mut reader, + &mut error, + ) + }; + check_status(status, &error)?; + + let reader = unsafe { ArrowArrayStreamReader::from_raw(&mut reader)? 
}; + + Ok(Box::new(reader)) + } + + fn get_table_schema( + &self, + catalog: Option<&str>, + db_schema: Option<&str>, + table_name: &str, + ) -> std::result::Result { + let mut error = FFI_AdbcError::empty(); + + let catalog = catalog.map(str_to_cstring).transpose()?; + let catalog_ptr = catalog.map(|s| s.as_ptr()).unwrap_or(null()); + + let db_schema = db_schema.map(str_to_cstring).transpose()?; + let db_schema_ptr = db_schema.map(|s| s.as_ptr()).unwrap_or(null()); + + let table_name = str_to_cstring(table_name)?; + + let mut schema = FFI_ArrowSchema::empty(); + + let get_table_schema = driver_method!(self.driver, connection_get_table_schema); + let status = unsafe { + get_table_schema( + self.inner.borrow_mut().deref_mut(), + catalog_ptr, + db_schema_ptr, + table_name.as_ptr(), + &mut schema, + &mut error, + ) + }; + check_status(status, &error)?; + + Ok(Schema::try_from(&schema)?) + } + + fn read_partition( + &self, + partition: &[u8], + ) -> std::result::Result, Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let mut reader = FFI_ArrowArrayStream::empty(); + + let read_partition = driver_method!(self.driver, connection_read_partition); + let status = unsafe { + read_partition( + self.inner.borrow_mut().deref_mut(), + partition.as_ptr(), + partition.len(), + &mut reader, + &mut error, + ) + }; + check_status(status, &error)?; + + let reader = unsafe { ArrowArrayStreamReader::from_raw(&mut reader)? 
}; + + Ok(Box::new(reader)) + } + + fn commit(&self) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let commit = driver_method!(self.driver, connection_commit); + let status = unsafe { commit(self.inner.borrow_mut().deref_mut(), &mut error) }; + check_status(status, &error)?; + Ok(()) + } + + fn rollback(&self) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let rollback = driver_method!(self.driver, connection_rollback); + let status = unsafe { rollback(self.inner.borrow_mut().deref_mut(), &mut error) }; + check_status(status, &error)?; + Ok(()) + } +} + +impl AdbcConnection { + /// Create a new statement. + pub fn new_statement(&self) -> Result { + let mut inner = FFI_AdbcStatement::empty(); + let mut error = FFI_AdbcError::empty(); + + let statement_new = driver_method!(self.driver, statement_new); + let status = + unsafe { statement_new(self.inner.borrow_mut().deref_mut(), &mut inner, &mut error) }; + check_status(status, &error)?; + + Ok(AdbcStatement { + inner, + _connection: self.inner.clone(), + driver: self.driver.clone(), + }) + } +} + +/// A handle to an ADBC statement. +/// +/// See [StatementApi] for details. +pub struct AdbcStatement { + inner: FFI_AdbcStatement, + // We hold onto the connection to make sure it is kept alive (and keep + // lifetime semantics simple). 
+ _connection: Rc>, + driver: Arc, +} + +impl StatementApi for AdbcStatement { + type Error = AdbcDriverManagerError; + + fn prepare(&mut self) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let statement_prepare = driver_method!(self.driver, statement_prepare); + let status = unsafe { statement_prepare(&mut self.inner, &mut error) }; + check_status(status, &error)?; + Ok(()) + } + + fn set_option(&mut self, key: &str, value: &str) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let key = str_to_cstring(key)?; + let value = str_to_cstring(value)?; + + let set_option = driver_method!(self.driver, statement_set_option); + let status = + unsafe { set_option(&mut self.inner, key.as_ptr(), value.as_ptr(), &mut error) }; + check_status(status, &error)?; + Ok(()) + } + + fn set_sql_query(&mut self, query: &str) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let query = str_to_cstring(query)?; + + let set_sql_query = driver_method!(self.driver, statement_set_sql_query); + let status = unsafe { set_sql_query(&mut self.inner, query.as_ptr(), &mut error) }; + check_status(status, &error)?; + Ok(()) + } + + fn set_substrait_plan(&mut self, plan: &[u8]) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let set_substrait_plan = driver_method!(self.driver, statement_set_substrait_plan); + let status = + unsafe { set_substrait_plan(&mut self.inner, plan.as_ptr(), plan.len(), &mut error) }; + check_status(status, &error)?; + Ok(()) + } + + fn get_param_schema(&mut self) -> std::result::Result { + let mut error = FFI_AdbcError::empty(); + + let mut schema = FFI_ArrowSchema::empty(); + + let get_parameter_schema = driver_method!(self.driver, statement_get_parameter_schema); + let status = unsafe { get_parameter_schema(&mut self.inner, &mut schema, &mut error) }; + check_status(status, &error)?; + + 
Ok(Schema::try_from(&schema)?) + } + + fn bind_data(&mut self, batch: RecordBatch) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let struct_arr = Arc::new(StructArray::from(batch)); + + let mut schema = FFI_ArrowSchema::empty(); + let mut array = FFI_ArrowArray::empty(); + + unsafe { export_array_into_raw(struct_arr, &mut array, &mut schema)? }; + + let statement_bind = driver_method!(self.driver, statement_bind); + let status = + unsafe { statement_bind(&mut self.inner, &mut array, &mut schema, &mut error) }; + check_status(status, &error)?; + + Ok(()) + } + + fn bind_stream( + &mut self, + reader: Box, + ) -> std::result::Result<(), Self::Error> { + let mut error = FFI_AdbcError::empty(); + + let mut stream = FFI_ArrowArrayStream::empty(); + + unsafe { export_reader_into_raw(reader, &mut stream) }; + + let statement_bind_stream = driver_method!(self.driver, statement_bind_stream); + let status = unsafe { statement_bind_stream(&mut self.inner, &mut stream, &mut error) }; + check_status(status, &error)?; + + Ok(()) + } + + fn execute(&mut self) -> std::result::Result { + let mut error = FFI_AdbcError::empty(); + + let mut stream = FFI_ArrowArrayStream::empty(); + let mut rows_affected: i64 = -1; + + let execute_query = driver_method!(self.driver, statement_execute_query); + let status = + unsafe { execute_query(&mut self.inner, &mut stream, &mut rows_affected, &mut error) }; + check_status(status, &error)?; + + let result: Option> = if stream.release.is_none() { + // There was no result + None + } else { + unsafe { Some(Box::new(ArrowArrayStreamReader::from_raw(&mut stream)?)) } + }; + + Ok(StatementResult { + result, + rows_affected, + }) + } + + fn execute_update(&mut self) -> std::result::Result { + let mut error = FFI_AdbcError::empty(); + + let stream = null_mut(); + let mut rows_affected: i64 = -1; + + let execute_query = driver_method!(self.driver, statement_execute_query); + let status = + unsafe { 
execute_query(&mut self.inner, stream, &mut rows_affected, &mut error) }; + check_status(status, &error)?; + + Ok(rows_affected) + } + + fn execute_partitioned( + &mut self, + ) -> std::result::Result { + let mut error = FFI_AdbcError::empty(); + + let mut schema = FFI_ArrowSchema::empty(); + let mut partitions = FFI_AdbcPartitions::empty(); + let mut rows_affected: i64 = -1; + + let execute_partitions = driver_method!(self.driver, statement_execute_partitions); + let status = unsafe { + execute_partitions( + &mut self.inner, + &mut schema, + &mut partitions, + &mut rows_affected, + &mut error, + ) + }; + check_status(status, &error)?; + + let schema = Schema::try_from(&schema)?; + + let partition_lengths = unsafe { + std::slice::from_raw_parts(partitions.partition_lengths, partitions.num_partitions) + }; + let partition_ptrs = + unsafe { std::slice::from_raw_parts(partitions.partitions, partitions.num_partitions) }; + let partition_ids = partition_ptrs + .iter() + .zip(partition_lengths.iter()) + .map(|(&part_ptr, &len)| unsafe { std::slice::from_raw_parts(part_ptr, len).to_vec() }) + .collect(); + + Ok(PartitionedStatementResult { + schema, + partition_ids, + rows_affected, + }) + } +} diff --git a/rust/src/error.rs b/rust/src/error.rs new file mode 100644 index 0000000000..ca6856ddd1 --- /dev/null +++ b/rust/src/error.rs @@ -0,0 +1,396 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! ADBC error handling utilities +//! +//! ADBC functions report errors in two ways at the same time: first, they +//! return a status code, [AdbcStatusCode], and second, they fill in an out pointer +//! to [FFI_AdbcError]. To easily convert between a Rust error enum and these +//! two types, implement the [AdbcError] trait. With that trait defined, you can +//! use the [check_err] macro to handle Rust errors within ADBC functions. +//! +//! # Examples +//! +//! In simple cases, you can use [FFI_AdbcError::set_message] and return an error +//! status code early. To handle error enums that implement [AdbcError], use [check_err]. +//! +//! ``` +//! use std::ffi::CStr; +//! use std::os::raw::c_char; +//! use arrow_adbc::error::{FFI_AdbcError, AdbcStatusCode, check_err, AdbcError}; +//! +//! unsafe fn adbc_str_utf8_len( +//! key: *const c_char, +//! out: *mut usize, +//! error: *mut FFI_AdbcError) -> AdbcStatusCode { +//! if key.is_null() { +//! FFI_AdbcError::set_message(error, "Passed a null pointer."); +//! return AdbcStatusCode::InvalidArguments; +//! } else { +//! // AdbcError is implemented for Utf8Error +//! let key: &str = check_err!(CStr::from_ptr(key).to_str(), error); +//! let len: usize = key.chars().count(); +//! std::ptr::write_unaligned(out, len); +//! } +//! AdbcStatusCode::Ok +//! } +//! +//! +//! let msg: &[u8] = &[0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x0]; // "hello" +//! let mut out: usize = 0; +//! +//! let mut error = FFI_AdbcError::empty(); +//! +//! let status_code = unsafe { adbc_str_utf8_len( +//! 
msg.as_ptr() as *const c_char, +//! &mut out as *mut usize, +//! &mut error +//! ) }; +//! +//! assert_eq!(status_code, AdbcStatusCode::Ok); +//! assert_eq!(out, 5); +//! +//! let mut error = FFI_AdbcError::empty(); +//! let mut msg: &[u8] = &[0xff, 0x0]; +//! let status_code = unsafe { adbc_str_utf8_len( +//! msg.as_ptr() as *const c_char, +//! &mut out as *mut usize, +//! &mut error +//! ) }; +//! +//! assert_eq!(status_code, AdbcStatusCode::InvalidArguments); +//! let error_msg = unsafe { CStr::from_ptr(error.message).to_str().unwrap() }; +//! assert_eq!(error_msg, "Invalid UTF-8 character"); +//! assert_eq!(error.sqlstate, [2, 2, 0, 2, 1]); +//! ``` +//! + +use std::{ + ffi::{c_char, CString}, + ptr::null_mut, +}; + +#[derive(Debug, PartialEq, Copy, Clone)] +#[repr(u8)] +pub enum AdbcStatusCode { + /// No error. + Ok = 0, + /// An unknown error occurred. + /// + /// May indicate a driver-side or database-side error. + Unknown = 1, + /// The operation is not implemented or supported. + /// + /// May indicate a driver-side or database-side error. + NotImplemented = 2, + /// A requested resource was not found. + /// + /// May indicate a driver-side or database-side error. + NotFound = 3, + /// A requested resource already exists. + /// + /// May indicate a driver-side or database-side error. + AlreadyExists = 4, + /// The arguments are invalid, likely a programming error. + /// + /// May indicate a driver-side or database-side error. + InvalidArguments = 5, + /// The preconditions for the operation are not met, likely a + /// programming error. + /// + /// For instance, the object may be uninitialized, or may have not + /// been fully configured. + /// + /// May indicate a driver-side or database-side error. + InvalidState = 6, + /// Invalid data was processed (not a programming error). + /// + /// For instance, a division by zero may have occurred during query + /// execution. + /// + /// May indicate a database-side error only. 
+ InvalidData = 7, + /// The database's integrity was affected. + /// + /// For instance, a foreign key check may have failed, or a uniqueness + /// constraint may have been violated. + /// + /// May indicate a database-side error only. + Integrity = 8, + /// An error internal to the driver or database occurred. + /// + /// May indicate a driver-side or database-side error. + Internal = 9, + /// An I/O error occurred. + /// + /// For instance, a remote service may be unavailable. + /// + /// May indicate a driver-side or database-side error. + IO = 10, + /// The operation was cancelled, not due to a timeout. + /// + /// May indicate a driver-side or database-side error. + Cancelled = 11, + /// The operation was cancelled due to a timeout. + /// + /// May indicate a driver-side or database-side error. + Timeout = 12, + /// Authentication failed. + /// + /// May indicate a database-side error only. + Unauthenticated = 13, + /// The client is not authorized to perform the given operation. + /// + /// May indicate a database-side error only. 
+ Unauthorized = 14, +} + +impl std::fmt::Display for AdbcStatusCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AdbcStatusCode::Ok => write!(f, "Ok"), + AdbcStatusCode::Unknown => write!(f, "Unknown"), + AdbcStatusCode::NotImplemented => write!(f, "Not Implemented"), + AdbcStatusCode::NotFound => write!(f, "Not Found"), + AdbcStatusCode::AlreadyExists => write!(f, "Already Exists"), + AdbcStatusCode::InvalidArguments => write!(f, "Invalid Arguments"), + AdbcStatusCode::InvalidState => write!(f, "Invalid State"), + AdbcStatusCode::InvalidData => write!(f, "Invalid Data"), + AdbcStatusCode::Integrity => write!(f, "Integrity"), + AdbcStatusCode::Internal => write!(f, "Internal Error"), + AdbcStatusCode::IO => write!(f, "IO Error"), + AdbcStatusCode::Cancelled => write!(f, "Cancelled"), + AdbcStatusCode::Timeout => write!(f, "Timeout"), + AdbcStatusCode::Unauthenticated => write!(f, "Unauthenticated"), + AdbcStatusCode::Unauthorized => write!(f, "Unauthorized"), + } + } +} + +/// A detailed error message for an operation. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct FFI_AdbcError { + /// The error message. + pub message: *mut c_char, + /// A vendor-specific error code, if applicable. + pub vendor_code: i32, + /// A SQLSTATE error code, if provided, as defined by the + /// SQL:2003 standard. If not set, it should be set to + /// "\\0\\0\\0\\0\\0". + pub sqlstate: [c_char; 5usize], + /// Release the contained error. + /// + /// Unlike other structures, this is an embedded callback to make it + /// easier for the driver manager and driver to cooperate. + pub release: Option, +} + +impl FFI_AdbcError { + /// Create an empty error + pub fn empty() -> Self { + Self { + message: null_mut(), + vendor_code: -1, + sqlstate: ['\0' as c_char; 5], + release: None, + } + } + + /// Create a new FFI_AdbcError. + /// + /// `vendor_code` defaults to -1 and `sql_state` defaults to zeros. 
+ pub fn new(message: &str, vendor_code: Option, sqlstate: Option<[c_char; 5]>) -> Self { + Self { + message: CString::new(message).unwrap().into_raw(), + vendor_code: vendor_code.unwrap_or(-1), + sqlstate: sqlstate.unwrap_or(['\0' as c_char; 5]), + release: Some(drop_adbc_error), + } + } + + /// Set an error message. + /// + /// # Safety + /// + /// If `dest` is null, no error is written. If `dest` is non-null, it must + /// be valid for writes. + pub unsafe fn set_message(dest: *mut Self, message: &str) { + if !dest.is_null() { + let error = Self::new(message, None, None); + unsafe { std::ptr::write_unaligned(dest, error) } + } + } +} + +/// An error that can be converted into [FFI_AdbcError] and [AdbcStatusCode]. +/// +/// Can be used in combination with [check_err] when implementing ADBC FFI +/// functions. +pub trait AdbcError { + /// The status code this error corresponds to. + fn status_code(&self) -> AdbcStatusCode; + + /// The message associated with the error. + fn message(&self) -> &str; + + /// A vendor-specific error code. Defaults to always returning `-1`. + fn vendor_code(&self) -> i32 { + -1 + } + + /// A SQLSTATE error code, if provided, as defined by the + /// SQL:2003 standard. By default, it is set to + /// `"\0\0\0\0\0"`. + fn sqlstate(&self) -> [i8; 5] { + [0, 0, 0, 0, 0] + } +} + +impl From<&T> for FFI_AdbcError { + fn from(err: &T) -> Self { + let message: *mut i8 = CString::new(err.message()).unwrap().into_raw(); + Self { + message, + vendor_code: err.vendor_code(), + sqlstate: err.sqlstate(), + release: Some(drop_adbc_error), + } + } +} + +impl AdbcError for std::str::Utf8Error { + fn message(&self) -> &str { + "Invalid UTF-8 character" + } + + fn sqlstate(&self) -> [i8; 5] { + // A character is not in the coded character set or the conversion is not supported. 
+ [2, 2, 0, 2, 1] + } + + fn status_code(&self) -> AdbcStatusCode { + AdbcStatusCode::InvalidArguments + } + + fn vendor_code(&self) -> i32 { + -1 + } +} + +impl AdbcError for ArrowError { + fn message(&self) -> &str { + match self { + ArrowError::CDataInterface(msg) => msg, + ArrowError::SchemaError(msg) => msg, + _ => "Arrow error", // TODO: Fill in remainder + } + } + + fn status_code(&self) -> AdbcStatusCode { + AdbcStatusCode::Internal + } +} + +unsafe extern "C" fn drop_adbc_error(error: *mut FFI_AdbcError) { + if let Some(error) = error.as_mut() { + // Retake pointer so it will drop once out of scope. + if !error.message.is_null() { + let _ = CString::from_raw(error.message); + } + error.message = null_mut(); + } +} + +/// Given a Result, either unwrap the value or handle the error in ADBC function. +/// +/// This macro is for use when implementing ADBC methods that have an out +/// parameter for [FFI_AdbcError] and return [AdbcStatusCode]. If the result is +/// `Ok`, the expression resolves to the value. Otherwise, it will return early, +/// setting the error and status code appropriately. In order for this to work, +/// the error must implement [AdbcError]. +#[macro_export] +macro_rules! 
check_err { + ($res:expr, $err_out:expr) => { + match $res { + Ok(x) => x, + Err(err) => { + let error = FFI_AdbcError::from(&err); + unsafe { std::ptr::write_unaligned($err_out, error) }; + return err.status_code(); + } + } + }; +} + +use arrow::error::ArrowError; +pub use check_err; + +#[cfg(test)] +mod tests { + use std::ffi::CStr; + + use super::*; + + #[test] + fn test_adbcerror() { + let cases = vec![ + ("hello", None, None), + ("", None, None), + ("unicode 😅", None, None), + ("msg", Some(20), None), + ("msg", None, Some([3, 4, 5, 6, 7])), + ]; + + for (msg, vendor_code, sqlstate) in cases { + let mut err = FFI_AdbcError::new(msg, vendor_code, sqlstate); + assert_eq!( + unsafe { CStr::from_ptr(err.message).to_str().unwrap() }, + msg + ); + assert_eq!(err.vendor_code, vendor_code.unwrap_or(-1)); + assert_eq!(err.sqlstate, sqlstate.unwrap_or([0, 0, 0, 0, 0])); + + assert!(err.release.is_some()); + let release_func = err.release.unwrap(); + unsafe { release_func(&mut err) }; + + assert!(err.message.is_null()); + } + } + + #[test] + fn test_adbcerror_set_message() { + let mut error = FFI_AdbcError::empty(); + + let msg = "Hello world!"; + unsafe { FFI_AdbcError::set_message(&mut error, msg) }; + + assert_eq!( + unsafe { CStr::from_ptr(error.message).to_str().unwrap() }, + msg + ); + assert_eq!(error.vendor_code, -1); + assert_eq!(error.sqlstate, [0, 0, 0, 0, 0]); + + assert!(error.release.is_some()); + let release_func = error.release.unwrap(); + unsafe { release_func(&mut error) }; + + assert!(error.message.is_null()); + } +} diff --git a/rust/src/ffi.rs b/rust/src/ffi.rs new file mode 100644 index 0000000000..2c5de7ed7b --- /dev/null +++ b/rust/src/ffi.rs @@ -0,0 +1,825 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! ADBC FFI structs, as defined in [adbc.h](https://github.com/apache/arrow-adbc/blob/main/adbc.h). +use std::ffi::{c_char, c_void, CStr}; +use std::ptr::{null, null_mut}; + +use crate::error::{AdbcStatusCode, FFI_AdbcError}; +use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi_stream::FFI_ArrowArrayStream; + +/// An instance of a database. +/// +/// Must be kept alive as long as any connections exist. +#[repr(C)] +#[derive(Debug, Clone)] +pub struct FFI_AdbcDatabase { + /// Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + pub private_data: *mut c_void, + /// The associated driver (used by the driver manager to help + /// track state). + pub private_driver: *const FFI_AdbcDriver, +} + +impl FFI_AdbcDatabase { + pub fn empty() -> Self { + Self { + private_data: null_mut(), + private_driver: null_mut(), + } + } +} + +impl Drop for FFI_AdbcDatabase { + fn drop(&mut self) { + if let Some(private_driver) = unsafe { self.private_driver.as_ref() } { + if let Some(release) = private_driver.database_release { + let mut error = FFI_AdbcError::empty(); + let status = unsafe { release(self, &mut error) }; + if status != AdbcStatusCode::Ok { + panic!("Failed to cleanup database: {}", unsafe { + CStr::from_ptr(error.message).to_string_lossy() + }); + } + } + } + } +} + +/// An active database connection. 
+/// +/// Provides methods for query execution, managing prepared +/// statements, using transactions, and so on. +/// +/// Connections are not required to be thread-safe, but they can be +/// used from multiple threads so long as clients take care to +/// serialize accesses to a connection. Because of this, they do not implement +/// [core::marker::Send] + [core::marker::Sync] on their own. Instead wrap them +/// in the appropriate types to manage access safely and implement those marker +/// traits on the wrapper. +#[repr(C)] +#[derive(Debug, Clone)] +pub struct FFI_AdbcConnection { + /// Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + pub private_data: *mut c_void, + /// The associated driver (used by the driver manager to help + /// track state). + pub private_driver: *mut FFI_AdbcDriver, +} + +impl FFI_AdbcConnection { + pub fn empty() -> Self { + Self { + private_data: null_mut(), + private_driver: null_mut(), + } + } +} + +impl Drop for FFI_AdbcConnection { + fn drop(&mut self) { + if let Some(private_driver) = unsafe { self.private_driver.as_ref() } { + if let Some(release) = private_driver.connection_release.as_ref() { + let mut error = FFI_AdbcError::empty(); + let status = unsafe { release(self, &mut error) }; + if status != AdbcStatusCode::Ok { + panic!("Failed to cleanup connection: {}", unsafe { + CStr::from_ptr(error.message).to_string_lossy() + }); + } + } + } + } +} + +/// A container for all state needed to execute a database +/// query, such as the query itself, parameters for prepared +/// statements, driver parameters, etc. +/// +/// Statements may represent queries or prepared statements. +/// +/// Statements may be used multiple times and can be reconfigured +/// (e.g. they can be reused to execute multiple different queries). +/// However, executing a statement (and changing certain other state) +/// will invalidate result sets obtained prior to that execution. 
+/// +/// Multiple statements may be created from a single connection. +/// However, the driver may block or error if they are used +/// concurrently (whether from a single thread or multiple threads). +/// +/// Statements are not required to be thread-safe, but they can be +/// used from multiple threads so long as clients take care to +/// serialize accesses to a statement. +#[repr(C)] +#[derive(Debug, Clone)] +pub struct FFI_AdbcStatement { + /// Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + pub private_data: *mut c_void, + /// The associated driver (used by the driver manager to help + /// track state). + pub private_driver: *mut FFI_AdbcDriver, +} + +impl FFI_AdbcStatement { + pub fn empty() -> Self { + Self { + private_data: null_mut(), + private_driver: null_mut(), + } + } +} + +impl Drop for FFI_AdbcStatement { + fn drop(&mut self) { + if let Some(private_driver) = unsafe { self.private_driver.as_ref() } { + if let Some(release) = private_driver.statement_release { + let mut error = FFI_AdbcError::empty(); + let status = unsafe { release(self, &mut error) }; + if status != AdbcStatusCode::Ok { + panic!("Failed to cleanup statement: {}", unsafe { + CStr::from_ptr(error.message).to_string_lossy() + }); + } + } + } + } +} + +/// The partitions of a distributed/partitioned result set. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct FFI_AdbcPartitions { + /// The number of partitions. + pub num_partitions: usize, + /// The partitions of the result set, where each entry (up to + /// num_partitions entries) is an opaque identifier that can be + /// passed to FFI_AdbcConnectionReadPartition. + pub partitions: *mut *const u8, + /// The length of each corresponding entry in partitions. + pub partition_lengths: *const usize, + /// Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. 
+ pub private_data: *mut c_void, + /// Release the contained partitions. + /// + /// Unlike other structures, this is an embedded callback to make it + /// easier for the driver manager and driver to cooperate. + pub release: ::std::option::Option, +} + +impl FFI_AdbcPartitions { + pub fn empty() -> Self { + Self { + num_partitions: 0, + partitions: null_mut(), + partition_lengths: null(), + private_data: null_mut(), + release: None, + } + } +} + +impl From>> for FFI_AdbcPartitions { + fn from(mut value: Vec>) -> Self { + // Make sure capacity and length are the same, so it's easier to reconstruct them. + value.shrink_to_fit(); + + let num_partitions = value.len(); + let mut lengths: Vec = value.iter().map(|v| v.len()).collect(); + let partition_lengths = lengths.as_mut_ptr(); + std::mem::forget(lengths); + + let mut partitions_vec: Vec<*const u8> = value + .into_iter() + .map(|mut p| { + p.shrink_to_fit(); + let ptr = p.as_ptr(); + std::mem::forget(p); + ptr + }) + .collect(); + partitions_vec.shrink_to_fit(); + let partitions = partitions_vec.as_mut_ptr(); + std::mem::forget(partitions_vec); + + Self { + num_partitions, + partitions, + partition_lengths, + private_data: 42 as *mut c_void, // Arbitrary non-null pointer + release: Some(drop_adbc_partitions), + } + } +} + +unsafe extern "C" fn drop_adbc_partitions(partitions: *mut FFI_AdbcPartitions) { + if let Some(partitions) = partitions.as_mut() { + // This must reconstruct every Vec that we called mem::forget on when + // constructing the FFI struct. 
+ let partition_lengths: Vec = Vec::from_raw_parts( + partitions.partition_lengths as *mut usize, + partitions.num_partitions, + partitions.num_partitions, + ); + + let partitions_vec = Vec::from_raw_parts( + partitions.partitions, + partitions.num_partitions, + partitions.num_partitions, + ); + + let _each_partition: Vec> = partitions_vec + .into_iter() + .zip(partition_lengths) + .map(|(ptr, size)| Vec::from_raw_parts(ptr as *mut u8, size, size)) + .collect(); + + partitions.partitions = null_mut(); + partitions.partition_lengths = null_mut(); + partitions.private_data = null_mut(); + partitions.release = None; + } +} + +/// An instance of an initialized database driver. +/// +/// This provides a common interface for vendor-specific driver +/// initialization routines. Drivers should populate this struct, and +/// applications can call ADBC functions through this struct, without +/// worrying about multiple definitions of the same symbol. +#[repr(C)] +#[derive(Debug, Clone)] +pub struct FFI_AdbcDriver { + /// Opaque driver-defined state. + /// This field is NULL if the driver is unintialized/freed (but + /// it need not have a value even if the driver is initialized). + pub private_data: *mut c_void, + /// Opaque driver manager-defined state. + /// This field is NULL if the driver is unintialized/freed (but + /// it need not have a value even if the driver is initialized). + pub private_manager: *mut c_void, + /// Release the driver and perform any cleanup. + /// + /// This is an embedded callback to make it easier for the driver + /// manager and driver to cooperate. 
+ pub release: ::std::option::Option< + unsafe extern "C" fn( + driver: *mut FFI_AdbcDriver, + error: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub database_init: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcDatabase, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub database_new: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcDatabase, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub database_set_option: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcDatabase, + arg2: *const c_char, + arg3: *const c_char, + arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub database_release: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcDatabase, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_commit: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_get_info: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *const u32, + arg3: usize, + arg4: *mut FFI_ArrowArrayStream, + arg5: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_get_objects: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: AdbcObjectDepth, + arg3: *const c_char, + arg4: *const c_char, + arg5: *const c_char, + arg6: *const *const c_char, + arg7: *const c_char, + arg8: *mut FFI_ArrowArrayStream, + arg9: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_get_table_schema: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *const c_char, + arg3: *const c_char, + arg4: *const c_char, + arg5: *mut FFI_ArrowSchema, + arg6: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_get_table_types: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_ArrowArrayStream, + arg3: *mut 
FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_init: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_AdbcDatabase, + arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_new: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_set_option: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *const c_char, + arg3: *const c_char, + arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_read_partition: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *const u8, + arg3: usize, + arg4: *mut FFI_ArrowArrayStream, + arg5: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_release: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub connection_rollback: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_bind: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_ArrowArray, + arg3: *mut FFI_ArrowSchema, + arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_bind_stream: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_ArrowArrayStream, + arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_execute_query: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_ArrowArrayStream, + arg3: *mut i64, + arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_execute_partitions: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_ArrowSchema, + arg3: *mut FFI_AdbcPartitions, + arg4: *mut i64, + 
arg5: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_get_parameter_schema: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_ArrowSchema, + arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_new: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcConnection, + arg2: *mut FFI_AdbcStatement, + arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_prepare: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_release: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_set_option: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *const c_char, + arg3: *const c_char, + arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_set_sql_query: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *const c_char, + arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, + pub statement_set_substrait_plan: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut FFI_AdbcStatement, + arg2: *const u8, + arg3: usize, + arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode, + >, +} + +macro_rules! empty_driver { + ($( $func_name:ident ),+) => { + Self { + private_data: null_mut(), + private_manager: null_mut(), + release: None, + $( + $func_name: Some(driver_function_stubs::$func_name), + )+ + } + }; +} + +impl FFI_AdbcDriver { + /// Get an empty [Self], but with all functions filled in with stubs. + /// + /// Any of the stub functions will simply return [AdbcStatusCode::NotImplemented]. 
+ pub fn empty() -> Self { + empty_driver!( + database_init, + database_new, + database_set_option, + database_release, + connection_commit, + connection_get_info, + connection_get_objects, + connection_get_table_schema, + connection_get_table_types, + connection_init, + connection_new, + connection_read_partition, + connection_release, + connection_rollback, + connection_set_option, + statement_bind, + statement_bind_stream, + statement_execute_partitions, + statement_execute_query, + statement_get_parameter_schema, + statement_new, + statement_prepare, + statement_release, + statement_set_option, + statement_set_sql_query, + statement_set_substrait_plan + ) + } +} + +impl Drop for FFI_AdbcDriver { + fn drop(&mut self) { + if let Some(release) = self.release { + let mut error = FFI_AdbcError::empty(); + let status = unsafe { release(self, &mut error) }; + if status != AdbcStatusCode::Ok { + panic!("Failed to cleanup driver: {}", unsafe { + CStr::from_ptr(error.message).to_string_lossy() + }); + } + } + } +} + +unsafe impl Send for FFI_AdbcDriver {} +unsafe impl Sync for FFI_AdbcDriver {} + +pub(crate) mod driver_function_stubs { + use super::*; + + pub(crate) unsafe extern "C" fn database_init( + _arg1: *mut FFI_AdbcDatabase, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn database_new( + _arg1: *mut FFI_AdbcDatabase, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn database_set_option( + _arg1: *mut FFI_AdbcDatabase, + _arg2: *const c_char, + _arg3: *const c_char, + _arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn database_release( + _arg1: *mut FFI_AdbcDatabase, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_commit( + _arg1: *mut 
FFI_AdbcConnection, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_get_info( + _arg1: *mut FFI_AdbcConnection, + _arg2: *const u32, + _arg3: usize, + _arg4: *mut FFI_ArrowArrayStream, + _arg5: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_get_objects( + _arg1: *mut FFI_AdbcConnection, + _arg2: AdbcObjectDepth, + _arg3: *const c_char, + _arg4: *const c_char, + _arg5: *const c_char, + _arg6: *const *const c_char, + _arg7: *const c_char, + _arg8: *mut FFI_ArrowArrayStream, + _arg9: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_get_table_schema( + _arg1: *mut FFI_AdbcConnection, + _arg2: *const c_char, + _arg3: *const c_char, + _arg4: *const c_char, + _arg5: *mut FFI_ArrowSchema, + _arg6: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_get_table_types( + _arg1: *mut FFI_AdbcConnection, + _arg2: *mut FFI_ArrowArrayStream, + _arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_init( + _arg1: *mut FFI_AdbcConnection, + _arg2: *mut FFI_AdbcDatabase, + _arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_new( + _arg1: *mut FFI_AdbcConnection, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_set_option( + _arg1: *mut FFI_AdbcConnection, + _arg2: *const c_char, + _arg3: *const c_char, + _arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_read_partition( + _arg1: *mut FFI_AdbcConnection, + _arg2: *const u8, + _arg3: 
usize, + _arg4: *mut FFI_ArrowArrayStream, + _arg5: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_release( + _arg1: *mut FFI_AdbcConnection, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn connection_rollback( + _arg1: *mut FFI_AdbcConnection, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_bind( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut FFI_ArrowArray, + _arg3: *mut FFI_ArrowSchema, + _arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_bind_stream( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut FFI_ArrowArrayStream, + _arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_execute_query( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut FFI_ArrowArrayStream, + _arg3: *mut i64, + _arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + pub(crate) unsafe extern "C" fn statement_execute_partitions( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut FFI_ArrowSchema, + _arg3: *mut FFI_AdbcPartitions, + _arg4: *mut i64, + _arg5: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_get_parameter_schema( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut FFI_ArrowSchema, + _arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_new( + _arg1: *mut FFI_AdbcConnection, + _arg2: *mut FFI_AdbcStatement, + _arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_prepare( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut 
FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_release( + _arg1: *mut FFI_AdbcStatement, + _arg2: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_set_option( + _arg1: *mut FFI_AdbcStatement, + _arg2: *const c_char, + _arg3: *const c_char, + _arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_set_sql_query( + _arg1: *mut FFI_AdbcStatement, + _arg2: *const c_char, + _arg3: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } + + pub(crate) unsafe extern "C" fn statement_set_substrait_plan( + _arg1: *mut FFI_AdbcStatement, + _arg2: *const u8, + _arg3: usize, + _arg4: *mut FFI_AdbcError, + ) -> AdbcStatusCode { + AdbcStatusCode::NotImplemented + } +} + +/// Depth parameter for GetObjects method. +#[derive(Debug)] +#[repr(i32)] +pub enum AdbcObjectDepth { + /// Metadata on catalogs, schemas, tables, and columns. + All = 0, + /// Metadata on catalogs only. + Catalogs = 1, + /// Metadata on catalogs and schemas. + DBSchemas = 2, + /// Metadata on catalogs, schemas, and tables. 
+ Tables = 3, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_adbc_partitions() { + let cases: Vec>> = + vec![vec![], vec![vec![]], vec![vec![0, 1, 2, 3], vec![4, 5, 6]]]; + + for case in cases { + let num_partitions = case.len(); + let expected_partitions = case.clone(); + + let mut partitions: FFI_AdbcPartitions = case.into(); + + assert_eq!(partitions.num_partitions, num_partitions); + assert!(!partitions.private_data.is_null()); + + for (i, expected_part) in expected_partitions.into_iter().enumerate() { + let part_length = unsafe { *partitions.partition_lengths.add(i) }; + let part = unsafe { + std::slice::from_raw_parts(*partitions.partitions.add(i), part_length) + }; + assert_eq!(part, &expected_part); + } + + assert!(partitions.release.is_some()); + let release_func = partitions.release.unwrap(); + unsafe { + release_func(&mut partitions); + } + + assert!(partitions.partitions.is_null()); + assert!(partitions.partition_lengths.is_null()); + assert!(partitions.private_data.is_null()); + } + } +} diff --git a/rust/src/interface.rs b/rust/src/interface.rs new file mode 100644 index 0000000000..2c99aa3b78 --- /dev/null +++ b/rust/src/interface.rs @@ -0,0 +1,324 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! API traits for ADBC structs +//! +//! These are the interfaces to ADBC structs made more ergonomic for Rust +//! developers. They are implemented by the structs in [crate::driver_manager]. +use arrow::{datatypes::Schema, record_batch::RecordBatch, record_batch::RecordBatchReader}; + +use crate::ffi::AdbcObjectDepth; + +/// Databases hold state shared by multiple connections. This typically means +/// configuration and caches. For in-memory databases, it provides a place to +/// hold ownership of the in-memory database. +pub trait DatabaseApi { + type Error; + + /// Set an option on the database. + /// + /// Some databases may not allow setting options after it has been initialized. + fn set_option(&self, key: &str, value: &str) -> Result<(), Self::Error>; +} + +/// A connection is a single connection to a database. +/// +/// It is never accessed concurrently from multiple threads. +/// +/// # Autocommit +/// +/// Connections should start in autocommit mode. They can be moved out by +/// setting `"adbc.connection.autocommit"` to `"false"` (using +/// [ConnectionApi::set_option]). Turning off autocommit allows customizing +/// the isolation level. Read more in [adbc.h](https://github.com/apache/arrow-adbc/blob/main/adbc.h). +pub trait ConnectionApi { + type Error; + + /// Set an option on the connection. + /// + /// Some connections may not allow setting options after it has been initialized. + fn set_option(&self, key: &str, value: &str) -> Result<(), Self::Error>; + + /// Get metadata about the database/driver. 
+ /// + /// The result is an Arrow dataset with the following schema: + /// + /// Field Name | Field Type + /// ----------------------------|------------------------ + /// `info_name` | `uint32 not null` + /// `info_value` | `INFO_SCHEMA` + /// + /// `INFO_SCHEMA` is a dense union with members: + /// + /// Field Name (Type Code) | Field Type + /// ------------------------------|------------------------ + /// `string_value` (0) | `utf8` + /// `bool_value` (1) | `bool` + /// `int64_value` (2) | `int64` + /// `int32_bitmask` (3) | `int32` + /// `string_list` (4) | `list<utf8>` + /// `int32_to_int32_list_map` (5) | `map<int32, list<int32>>` + /// + /// Each metadatum is identified by an integer code. The recognized + /// codes are defined as constants. Codes [0, 10_000) are reserved + /// for ADBC usage. Drivers/vendors will ignore requests for + /// unrecognized codes (the row will be omitted from the result). + /// + /// For definitions of known ADBC codes, see [crate::info::codes]. + fn get_info(&self, info_codes: &[u32]) -> Result, Self::Error>; + + /// Get a hierarchical view of all catalogs, database schemas, tables, and columns. 
+ /// + /// # Schema + /// + /// The result is an Arrow dataset with the following schema: + /// + /// | Field Name | Field Type | + /// |----------------------------|---------------------------| + /// | `catalog_name` | `utf8` | + /// | `catalog_db_schemas` | `list<DB_SCHEMA_SCHEMA>` | + /// + /// `DB_SCHEMA_SCHEMA` is a Struct with fields: + /// + /// | Field Name | Field Type | + /// |----------------------------|---------------------------| + /// | `db_schema_name` | `utf8` | + /// | `db_schema_tables` | `list<TABLE_SCHEMA>` | + /// + /// `TABLE_SCHEMA` is a Struct with fields: + /// + /// | Field Name | Field Type | + /// |----------------------------|---------------------------| + /// | `table_name` | `utf8 not null` | + /// | `table_type` | `utf8 not null` | + /// | `table_columns` | `list<COLUMN_SCHEMA>` | + /// | `table_constraints` | `list<CONSTRAINT_SCHEMA>` | + /// + /// `COLUMN_SCHEMA` is a Struct with fields: + /// + /// | Field Name | Field Type | Comments | + /// |----------------------------|---------------------------|----------| + /// | `column_name` | `utf8 not null` | | + /// | `ordinal_position` | `int32` | (1) | + /// | `remarks` | `utf8` | (2) | + /// | `xdbc_data_type` | `int16` | (3) | + /// | `xdbc_type_name` | `utf8` | (3) | + /// | `xdbc_column_size` | `int32` | (3) | + /// | `xdbc_decimal_digits` | `int16` | (3) | + /// | `xdbc_num_prec_radix` | `int16` | (3) | + /// | `xdbc_nullable` | `int16` | (3) | + /// | `xdbc_column_def` | `utf8` | (3) | + /// | `xdbc_sql_data_type` | `int16` | (3) | + /// | `xdbc_datetime_sub` | `int16` | (3) | + /// | `xdbc_char_octet_length` | `int32` | (3) | + /// | `xdbc_is_nullable` | `utf8` | (3) | + /// | `xdbc_scope_catalog` | `utf8` | (3) | + /// | `xdbc_scope_schema` | `utf8` | (3) | + /// | `xdbc_scope_table` | `utf8` | (3) | + /// | `xdbc_is_autoincrement` | `bool` | (3) | + /// | `xdbc_is_generatedcolumn` | `bool` | (3) | + /// + /// 1. The column's ordinal position in the table (starting from 1). + /// 2. Database-specific description of the column. + /// 3. 
Optional value. Should be null if not supported by the driver. + /// xdbc_ values are meant to provide JDBC/ODBC-compatible metadata + /// in an agnostic manner. + /// + /// `CONSTRAINT_SCHEMA` is a Struct with fields: + /// + /// | Field Name | Field Type | Comments | + /// |----------------------------|---------------------------|----------| + /// | `constraint_name` | `utf8` | | + /// | `constraint_type` | `utf8 not null` | (1) | + /// | `constraint_column_names` | `list<utf8> not null` | (2) | + /// | `constraint_column_usage` | `list<USAGE_SCHEMA>` | (3) | + /// + /// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'. + /// 2. The columns on the current table that are constrained, in + /// order. + /// 3. For FOREIGN KEY only, the referenced table and columns. + /// + /// `USAGE_SCHEMA` is a Struct with fields: + /// + /// | Field Name | Field Type | + /// |----------------------------|-------------------------| + /// | `fk_catalog` | `utf8` | + /// | `fk_db_schema` | `utf8` | + /// | `fk_table` | `utf8 not null` | + /// | `fk_column_name` | `utf8 not null` | + /// + /// # Parameters + /// + /// * **depth**: The level of nesting to display. If [AdbcObjectDepth::All], display + /// all levels. If [AdbcObjectDepth::Catalogs], display only catalogs (i.e. `catalog_db_schemas` + /// will be null). If [AdbcObjectDepth::DBSchemas], display only catalogs and schemas + /// (i.e. `db_schema_tables` will be null), and so on. + /// * **catalog**: Only show tables in the given catalog. If None, + /// do not filter by catalog. If an empty string, only show tables + /// without a catalog. May be a search pattern (see next section). + /// * **db_schema**: Only show tables in the given database schema. If + /// None, do not filter by database schema. If an empty string, only show + /// tables without a database schema. May be a search pattern (see next section). + /// * **table_name**: Only show tables with the given name. If None, do not + /// filter by name. 
May be a search pattern (see next section). + /// * **table_type**: Only show tables matching one of the given table + /// types. If None, show tables of any type. Valid table types should + /// match those returned by [ConnectionApi::get_table_types]. + /// * **column_name**: Only show columns with the given name. If + /// None, do not filter by name. May be a search pattern (see next section). + /// + /// # Search patterns + /// + /// Some parameters accept "search patterns", which are + /// strings that can contain the special character `"%"` to match zero + /// or more characters, or `"_"` to match exactly one character. (See + /// the documentation of DatabaseMetaData in JDBC or "Pattern Value + /// Arguments" in the ODBC documentation.) + fn get_objects( + &self, + depth: AdbcObjectDepth, + catalog: Option<&str>, + db_schema: Option<&str>, + table_name: Option<&str>, + table_type: Option<&[&str]>, + column_name: Option<&str>, + ) -> Result, Self::Error>; + + /// Get the Arrow schema of a table. + /// + /// `catalog` or `db_schema` may be `None` when not applicable. + fn get_table_schema( + &self, + catalog: Option<&str>, + db_schema: Option<&str>, + table_name: &str, + ) -> Result; + + /// Get a list of table types in the database. + /// + /// The result is an Arrow dataset with the following schema: + /// + /// Field Name | Field Type + /// -----------------|-------------- + /// `table_type` | `utf8 not null` + fn get_table_types(&self) -> Result, Self::Error>; + + /// Read part of a partitioned result set. + fn read_partition(&self, partition: &[u8]) -> Result, Self::Error>; + + /// Commit any pending transactions. Only used if autocommit is disabled. + fn commit(&self) -> Result<(), Self::Error>; + + /// Roll back any pending transactions. Only used if autocommit is disabled. 
+ fn rollback(&self) -> Result<(), Self::Error>; +} + +/// A container for all state needed to execute a database query, such as the +/// query itself, parameters for prepared statements, driver parameters, etc. +/// +/// Statements may represent queries or prepared statements. +/// +/// Statements may be used multiple times and can be reconfigured +/// (e.g. they can be reused to execute multiple different queries). +/// However, executing a statement (and changing certain other state) +/// will invalidate result sets obtained prior to that execution. +/// +/// Multiple statements may be created from a single connection. +/// However, the driver may block or error if they are used +/// concurrently (whether from a single thread or multiple threads). +pub trait StatementApi { + type Error; + + /// Turn this statement into a prepared statement to be executed multiple times. + /// + /// This should return an error if called before [StatementApi::set_sql_query]. + fn prepare(&mut self) -> Result<(), Self::Error>; + + /// Set a string option on a statement. + fn set_option(&mut self, key: &str, value: &str) -> Result<(), Self::Error>; + + /// Set the SQL query to execute. + fn set_sql_query(&mut self, query: &str) -> Result<(), Self::Error>; + + /// Set the Substrait plan to execute. + fn set_substrait_plan(&mut self, plan: &[u8]) -> Result<(), Self::Error>; + + /// Get the schema for bound parameters. + /// + /// This retrieves an Arrow schema describing the number, names, and + /// types of the parameters in a parameterized statement. The fields + /// of the schema should be in order of the ordinal position of the + /// parameters; named parameters should appear only once. + /// + /// If the parameter does not have a name, or the name cannot be + /// determined, the name of the corresponding field in the schema will + /// be an empty string. If the type cannot be determined, the type of + /// the corresponding field will be NA (NullType). 
+ /// + /// This should return an error if this was called before [StatementApi::prepare]. + fn get_param_schema(&mut self) -> Result; + + /// Bind Arrow data, either for bulk inserts or prepared statements. + fn bind_data(&mut self, batch: RecordBatch) -> Result<(), Self::Error>; + + /// Bind a stream of Arrow data, either for bulk inserts or prepared statements. + fn bind_stream(&mut self, stream: Box) -> Result<(), Self::Error>; + + /// Execute a statement and get the results. + /// + /// See [StatementResult]. + fn execute(&mut self) -> Result; + + /// Execute a query that doesn't have a result set. + /// + /// Will return the number of rows affected, or -1 if unknown or unsupported. + fn execute_update(&mut self) -> Result; + + /// Execute a statement with a partitioned result set. + /// + /// This is not required to be implemented, as it only applies to backends + /// that internally partition results. These backends can use this method + /// to support threaded or distributed clients. + /// + /// See [PartitionedStatementResult]. + fn execute_partitioned(&mut self) -> Result; +} + +/// Result of calling [StatementApi::execute]. +/// +/// `result` may be None if there is no meaningful result. +/// `rows_affected` may be -1 if not applicable or if it is not supported. +pub struct StatementResult { + pub result: Option>, + pub rows_affected: i64, +} + +/// Partitioned results +/// +/// [ConnectionApi::read_partition] will be called to get the output stream +/// for each partition. +/// +/// These may be used by a multi-threaded or a distributed client. Each partition +/// will be retrieved by a separate connection. For in-memory databases, these +/// may be connections on different threads that all reference the same database. +/// For remote databases, these may be connections in different processes. 
+#[derive(Debug, Clone)] +pub struct PartitionedStatementResult { + pub schema: Schema, + pub partition_ids: Vec>, + pub rows_affected: i64, +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 0000000000..f3e938b0d9 --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rust structs and utilities for using and building Arrow Database Connectivity (ADBC) drivers. +//! +//! ADBC drivers provide an ABI-stable interface for interacting with databases, +//! that: +//! +//! * Uses the Arrow [C Data interface](https://arrow.apache.org/docs/format/CDataInterface.html) +//! and [C Stream Interface](https://arrow.apache.org/docs/format/CStreamInterface.html) +//! for efficient data interchange. +//! * Supports partitioned result sets for multi-threaded or distributed +//! applications. +//! * Supports [Substrait](https://substrait.io/) plans in addition to SQL queries. +//! +//! When implemented for remote databases, [Flight SQL](https://arrow.apache.org/docs/format/FlightSql.html) +//! can be used as the communication protocol. This means data can be in Arrow +//! 
format through the whole connection, minimizing serialization and deserialization +//! overhead. +//! +//! Read more about ADBC at <https://arrow.apache.org/adbc/> +//! +//! ## Using ADBC drivers +//! +//! The [driver_manager] module allows loading drivers, either from an initialization +//! function or by dynamically finding such a function in a dynamic library. +//! +//! ``` +//! use arrow::datatypes::Int64Type; +//! use arrow::array::as_primitive_array; +//! use arrow::record_batch::RecordBatchReader; +//! +//! use arrow_adbc::driver_manager::AdbcDriver; +//! use arrow_adbc::ADBC_VERSION_1_0_0; +//! use arrow_adbc::interface::StatementApi; +//! +//! # fn main() -> arrow_adbc::driver_manager::Result<()> { +//! let sqlite_driver = AdbcDriver::load("adbc_driver_sqlite", None, ADBC_VERSION_1_0_0)?; +//! let sqlite_database = sqlite_driver.new_database()?.init()?; +//! let sqlite_conn = sqlite_database.new_connection()?.init()?; +//! let mut sqlite_statement = sqlite_conn.new_statement()?; +//! +//! sqlite_statement.set_sql_query("SELECT 1"); +//! let mut results: Box<dyn RecordBatchReader> = sqlite_statement.execute()? +//! .result.expect("Query did not have a result"); +//! +//! let batch = results.next().expect("Result did not have at least one batch")?; +//! let result = as_primitive_array::<Int64Type>(batch.column(0)); +//! +//! assert_eq!(result.value(0), 1); +//! # Ok(()) +//! # } +//! ``` +pub mod driver_manager; +pub mod error; +pub mod ffi; +pub mod interface; + +pub const ADBC_VERSION_1_0_0: i32 = 1000000; + +/// Known options that can be set on databases, connections, and statements. +/// +/// For use with [crate::interface::DatabaseApi::set_option], +/// [crate::interface::ConnectionApi::set_option], +/// and [crate::interface::StatementApi::set_option]. 
+pub mod options { + pub const INGEST_OPTION_TARGET_TABLE: &str = "adbc.ingest.target_table"; + pub const ADBC_INGEST_OPTION_MODE: &str = "adbc.ingest.mode"; + pub const ADBC_INGEST_OPTION_MODE_CREATE: &str = "adbc.ingest.mode.create"; + pub const ADBC_INGEST_OPTION_MODE_APPEND: &str = "adbc.ingest.mode.append"; + + /// The name of the canonical option for whether autocommit is enabled. + pub const ADBC_CONNECTION_OPTION_AUTOCOMMIT: &str = "adbc.connection.autocommit"; + /// The name of the canonical option for whether the current connection should + /// be restricted to being read-only. + pub const ADBC_CONNECTION_OPTION_READ_ONLY: &str = "adbc.connection.readonly"; + /// The name of the canonical option for setting the isolation level of a + /// transaction. + /// + /// Should only be used in conjunction with autocommit disabled and + /// AdbcConnectionCommit / AdbcConnectionRollback. If the desired + /// isolation level is not supported by a driver, it should return an + /// appropriate error. + pub const ADBC_CONNECTION_OPTION_ISOLATION_LEVEL: &str = + "adbc.connection.transaction.isolation_level"; + /// Use database or driver default isolation level + pub const ADBC_OPTION_ISOLATION_LEVEL_DEFAULT: &str = + "adbc.connection.transaction.isolation.default"; + /// The lowest isolation level. Dirty reads are allowed, so one transaction + /// may see not-yet-committed changes made by others. + pub const ADBC_OPTION_ISOLATION_LEVEL_READ_UNCOMMITTED: &str = + "adbc.connection.transaction.isolation.read_uncommitted"; + /// Lock-based concurrency control keeps write locks until the + /// end of the transaction, but read locks are released as soon as a + /// SELECT is performed. Non-repeatable reads can occur in this + /// isolation level. + /// + /// More simply put, Read Committed is an isolation level that guarantees + /// that any data read is committed at the moment it is read. 
It simply + /// restricts the reader from seeing any intermediate, uncommitted, + /// 'dirty' reads. It makes no promise whatsoever that if the transaction + /// re-issues the read, it will find the same data; data is free to change + /// after it is read. + pub const ADBC_OPTION_ISOLATION_LEVEL_READ_COMMITTED: &str = + "adbc.connection.transaction.isolation.read_committed"; + /// Lock-based concurrency control keeps read AND write locks + /// (acquired on selection data) until the end of the transaction. + /// + /// However, range-locks are not managed, so phantom reads can occur. + /// Write skew is possible at this isolation level in some systems. + pub const ADBC_OPTION_ISOLATION_LEVEL_REPEATABLE_READ: &str = + "adbc.connection.transaction.isolation.repeatable_read"; + /// This isolation guarantees that all reads in the transaction + /// will see a consistent snapshot of the database and the transaction + /// should only successfully commit if no updates conflict with any + /// concurrent updates made since that snapshot. + pub const ADBC_OPTION_ISOLATION_LEVEL_SNAPSHOT: &str = + "adbc.connection.transaction.isolation.snapshot"; + /// Serializability requires read and write locks to be released + /// only at the end of the transaction. This includes acquiring range- + /// locks when a select query uses a ranged WHERE clause to avoid + /// phantom reads. + pub const ADBC_OPTION_ISOLATION_LEVEL_SERIALIZABLE: &str = + "adbc.connection.transaction.isolation.serializable"; + /// The central distinction between serializability and linearizability + /// is that serializability is a global property; a property of an entire + /// history of operations and transactions. Linearizability is a local + /// property; a property of a single operation/transaction. + /// + /// Linearizability can be viewed as a special case of strict serializability + /// where transactions are restricted to consist of a single operation applied + /// to a single object. 
+ pub const ADBC_OPTION_ISOLATION_LEVEL_LINEARIZABLE: &str = + "adbc.connection.transaction.isolation.linearizable"; +} + +/// Utilities for driver info +/// +/// For use with [crate::interface::ConnectionApi::get_info]. +pub mod info { + use arrow::datatypes::{DataType, Field, Schema, UnionMode}; + + /// Contains known info codes defined by ADBC. + pub mod codes { + /// The database vendor/product name (type: utf8). + pub const VENDOR_NAME: u32 = 0; + /// The database vendor/product version (type: utf8). + pub const VENDOR_VERSION: u32 = 1; + /// The database vendor/product Arrow library version (type: utf8). + pub const VENDOR_ARROW_VERSION: u32 = 2; + /// The driver name (type: utf8). + pub const DRIVER_NAME: u32 = 100; + /// The driver version (type: utf8). + pub const DRIVER_VERSION: u32 = 101; + /// The driver Arrow library version (type: utf8). + pub const DRIVER_ARROW_VERSION: u32 = 102; + } + pub fn info_schema() -> Schema { + Schema::new(vec![ + Field::new("info_name", DataType::UInt32, false), + Field::new( + "info_value", + DataType::Union( + vec![ + Field::new("string_value", DataType::Utf8, true), + Field::new("bool_value", DataType::Boolean, true), + Field::new("int64_value", DataType::Int64, true), + Field::new("int32_bitmask", DataType::Int32, true), + Field::new( + "string_list", + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), + true, + ), + Field::new( + "int32_to_int32_list_map", + DataType::Map( + Box::new(Field::new( + "entries", + DataType::Struct(vec![ + Field::new("key", DataType::Int32, false), + Field::new( + "value", + DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + ))), + true, + ), + ]), + true, + )), + false, + ), + true, + ), + ], + vec![0, 1, 2, 3, 4, 5], + UnionMode::Dense, + ), + true, + ), + ]) + } +} diff --git a/rust/tests/test_driver_manager.rs b/rust/tests/test_driver_manager.rs new file mode 100644 index 0000000000..48773b2089 --- /dev/null +++ 
b/rust/tests/test_driver_manager.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test driver manager against SQLite implementations. + +use std::collections::{HashMap, HashSet}; +use std::ops::Deref; +use std::sync::Arc; + +use arrow::array::{ + as_list_array, as_primitive_array, as_string_array, as_struct_array, as_union_array, Array, + Int64Array, StringArray, +}; +use arrow::compute::concat_batches; +use arrow::datatypes::{DataType, Field, Schema, UInt32Type}; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use arrow_adbc::driver_manager::{AdbcDatabase, AdbcDriver, AdbcStatement, Result}; +use arrow_adbc::ffi::AdbcObjectDepth; +use arrow_adbc::info::{codes, info_schema}; +use arrow_adbc::interface::{ConnectionApi, StatementApi}; +use arrow_adbc::ADBC_VERSION_1_0_0; + +fn get_driver() -> Result { + AdbcDriver::load("adbc_driver_sqlite", None, ADBC_VERSION_1_0_0) +} + +fn get_database() -> Result { + let driver = get_driver()?; + // By passing in "" for uri, we create a distinct temporary database for each + // test, preventing noisy neighbor issues on tests. 
+ driver.new_database()?.set_option("uri", "")?.init() +} + +#[test] +fn test_database() { + let driver = get_driver().unwrap(); + + let builder = driver + .new_database() + .unwrap() + .set_option("uri", "test.db") + .unwrap() + .init() + .unwrap(); + + builder.set_option("uri", "test2.db").unwrap(); +} + +#[test] +fn test_connection_info() { + let database = get_database().unwrap(); + let connection = database.new_connection().unwrap().init().unwrap(); + + let table_types = connection.get_table_types().unwrap(); + assert_eq!(table_types, vec!["table", "view"]); + + let info = connection + .get_info(&[codes::DRIVER_NAME, codes::VENDOR_NAME]) + .unwrap(); + assert_eq!(info.schema().deref(), &info_schema()); + let info: HashMap = info + .flat_map(|maybe_batch| { + let batch = maybe_batch.unwrap(); + let id = as_primitive_array::(batch.column(0)); + let values = as_union_array(batch.column(1)); + let string_values = as_string_array(values.child(0)); + let mut out = vec![]; + for i in 0..batch.num_rows() { + assert_eq!(values.type_id(i), 0); + out.push((id.value(i), string_values.value(i).to_string())); + } + out + }) + .collect(); + assert_eq!(info.len(), 2); + assert_eq!( + info.get(&codes::DRIVER_NAME), + Some(&"ADBC SQLite Driver".to_string()) + ); + assert_eq!(info.get(&codes::VENDOR_NAME), Some(&"SQLite".to_string())); +} + +fn get_example_data() -> RecordBatch { + let ints_arr = Arc::new(Int64Array::from(vec![1, 2, 3, 4])); + let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let schema1 = Schema::new(vec![ + Field::new("ints", DataType::Int64, true), + Field::new("strs", DataType::Utf8, true), + ]); + RecordBatch::try_new(Arc::new(schema1), vec![ints_arr, str_arr]).unwrap() +} + +fn upload_data(statement: &mut AdbcStatement, data: RecordBatch, name: &str) { + statement + .set_option(arrow_adbc::options::INGEST_OPTION_TARGET_TABLE, name) + .unwrap(); + statement.bind_data(data).unwrap(); + statement.execute_update().unwrap(); +} + +#[test] 
+fn test_connection_get_objects() { + let database = get_database().unwrap(); + let connection = database.new_connection().unwrap().init().unwrap(); + + let record_batch = get_example_data(); + let mut statement = connection.new_statement().unwrap(); + upload_data(&mut statement, record_batch, "foo"); + + let objects: Vec = connection + .get_objects(AdbcObjectDepth::All, None, None, None, None, None) + .unwrap() + .collect::>() + .unwrap(); + + assert_eq!(objects.len(), 1); + let batch = &objects[0]; + // There is only 1 database + assert_eq!(batch.num_rows(), 1); + + let db_schemas = as_struct_array(as_list_array(batch.column(1)).values()); + // There is only 1 db_schema + assert_eq!(db_schemas.len(), 1); + + let tables = as_struct_array(as_list_array(db_schemas.column(1)).values()); + // There is only 1 table + assert_eq!(tables.len(), 1); + let table_names = as_string_array(tables.column(0)); + assert_eq!(table_names.value(0), "foo"); + + let columns = as_struct_array(as_list_array(tables.column(2)).values()); + // There are two columns + assert_eq!(columns.len(), 2); + let column_names: HashSet<&str> = as_string_array(columns.column(0)) + .into_iter() + .flatten() + .collect(); + assert_eq!(column_names, HashSet::from_iter(vec!["strs", "ints"])); +} + +#[test] +fn test_connection_get_table_schema() { + let database = get_database().unwrap(); + let connection = database.new_connection().unwrap().init().unwrap(); + + let record_batch = get_example_data(); + let mut statement = connection.new_statement().unwrap(); + upload_data(&mut statement, record_batch.clone(), "bar"); + + let schema = connection.get_table_schema(None, None, "bar").unwrap(); + + assert_eq!(&schema, record_batch.schema().as_ref()); +} + +#[test] +fn test_prepared() { + let database = get_database().unwrap(); + let connection = database.new_connection().unwrap().init().unwrap(); + + let array = Arc::new(Int64Array::from_iter(vec![1, 2, 3, 4])); + + let mut statement = 
connection.new_statement().unwrap(); + statement.set_sql_query("SELECT ?").unwrap(); + statement.prepare().unwrap(); + let param_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("1", DataType::Int64, false)])), + vec![array.clone()], + ) + .unwrap(); + statement.bind_data(param_batch).unwrap(); + + let result = statement.execute().unwrap(); + + let expected_schema = Schema::new(vec![Field::new("?", DataType::Int64, true)]); + let result_schema = result.result.as_ref().unwrap().schema(); + assert_eq!(result_schema.as_ref(), &expected_schema); + + let data: Vec = result + .result + .unwrap() + .collect::>() + .unwrap(); + let data: RecordBatch = concat_batches(&result_schema, data.iter()).unwrap(); + let expected = RecordBatch::try_new(Arc::new(expected_schema), vec![array]).unwrap(); + assert_eq!(data, expected); +} + +#[test] +fn test_ingest() { + let database = get_database().unwrap(); + let connection = database.new_connection().unwrap().init().unwrap(); + + let record_batch = get_example_data(); + let mut statement = connection.new_statement().unwrap(); + upload_data(&mut statement, record_batch.clone(), "baz"); + + statement.set_sql_query("SELECT * FROM baz").unwrap(); + let result = statement.execute().unwrap(); + let data: Vec = result + .result + .unwrap() + .collect::>() + .unwrap(); + let data: RecordBatch = concat_batches(&data[0].schema(), data.iter()).unwrap(); + + assert_eq!(data, record_batch); +}