diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 60ba3ae827de..111bf94d804c 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -76,6 +76,7 @@ pin-project-lite = "0.2" tempfile = "3.3" tokio-stream = { version = "0.1", features = ["net"] } tower = "0.4.13" +uuid = { version = "1.10.0", features = ["v4"] } [[example]] name = "flight_sql_server" diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index df5c1767689a..91790898b1cb 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -28,15 +28,19 @@ use crate::decode::FlightRecordBatchStream; use crate::encode::FlightDataEncoderBuilder; use crate::error::FlightError; use crate::flight_service_client::FlightServiceClient; -use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT}; +use crate::sql::gen::action_end_transaction_request::EndTransaction; +use crate::sql::server::{ + BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION, +}; use crate::sql::{ + ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, Any, CommandGetCatalogs, CommandGetCrossReference, - CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, - CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, - SqlInfo, + ActionCreatePreparedStatementResult, ActionEndTransactionRequest, Any, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, + ProstMessageExt, SqlInfo, }; use crate::trailers::extract_lazy_trailers; use crate::{ @@ -399,6 +403,54 @@ impl FlightSqlServiceClient { )) } + /// Request to begin a transaction. + pub async fn begin_transaction(&mut self) -> Result { + let cmd = ActionBeginTransactionRequest {}; + let action = Action { + r#type: BEGIN_TRANSACTION.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let req = self.set_request_headers(action.into_request())?; + let mut result = self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap(); + Ok(begin_result.transaction_id) + } + + /// Request to commit/rollback a transaction. + pub async fn end_transaction( + &mut self, + transaction_id: Bytes, + action: EndTransaction, + ) -> Result<(), ArrowError> { + let cmd = ActionEndTransactionRequest { + transaction_id, + action: action as i32, + }; + let action = Action { + r#type: END_TRANSACTION.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let req = self.set_request_headers(action.into_request())?; + let _ = self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + Ok(()) + } + /// Explicitly shut down and clean up the client. pub async fn close(&mut self) -> Result<(), ArrowError> { // TODO: consume self instead of &mut self to explicitly prevent reuse? diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index b3b9882ee0f2..61eb67b6933e 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -49,6 +49,7 @@ mod gen { include!("arrow.flight.protocol.sql.rs"); } +pub use gen::action_end_transaction_request::EndTransaction; pub use gen::ActionBeginSavepointRequest; pub use gen::ActionBeginSavepointResult; pub use gen::ActionBeginTransactionRequest; diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 317eb3900456..631f5cd31465 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -15,25 +15,21 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::{net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; +use arrow_flight::sql::client::FlightSqlServiceClient; +use arrow_flight::sql::EndTransaction; use arrow_flight::{ decode::FlightRecordBatchStream, flight_service_server::{FlightService, FlightServiceServer}, sql::{ server::{FlightSqlService, PeekableFlightDataStream}, - ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, - ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, - ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, - ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, - CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, - CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, - CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, - TicketStatementQuery, + ActionBeginTransactionRequest, ActionBeginTransactionResult, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, + ActionEndTransactionRequest, Any, CommandPreparedStatementQuery, CommandStatementQuery, + DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, @@ -45,8 +41,11 @@ use assert_cmd::Command; use bytes::Bytes; use futures::{Stream, TryStreamExt}; use prost::Message; +use tokio::sync::Mutex; use tokio::{net::TcpListener, task::JoinHandle}; +use tonic::transport::Endpoint; use tonic::{Request, Response, Status, Streaming}; +use uuid::Uuid; const QUERY: &str = "SELECT * FROM table;"; @@ -140,6 +139,7 @@ async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { pub async fn test_do_put_prepared_statement_stateless() { test_do_put_prepared_statement(FlightSqlServiceImpl { stateless_prepared_statements: true, + transactions: Arc::new(Mutex::new(HashMap::new())), }) .await } @@ -148,10 +148,48 @@ pub async fn test_do_put_prepared_statement_stateless() { pub async fn test_do_put_prepared_statement_stateful() { test_do_put_prepared_statement(FlightSqlServiceImpl { stateless_prepared_statements: false, + transactions: Arc::new(Mutex::new(HashMap::new())), }) .await } +#[tokio::test] +pub async fn test_begin_end_transaction() { + let test_server = FlightSqlServiceImpl { + stateless_prepared_statements: true, + transactions: Arc::new(Mutex::new(HashMap::new())), + }; + let fixture = TestFixture::new(&test_server).await; + let addr = fixture.addr; + let channel = Endpoint::from_shared(format!("http://{}:{}", addr.ip(), addr.port())) + .unwrap() + .connect() + .await + .expect("error connecting"); + let mut flight_sql_client = FlightSqlServiceClient::new(channel); + + // begin commit + let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .unwrap(); + + // begin rollback + let transaction_id = flight_sql_client.begin_transaction().await.unwrap(); + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Rollback) + .await + .unwrap(); + + // unknown transaction id + let transaction_id = "UnknownTransactionId".to_string().into(); + assert!(flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .is_err()); +} + /// All tests must complete within this many seconds or else the test server is shutdown const DEFAULT_TIMEOUT_SECONDS: u64 = 30; @@ -161,12 +199,14 @@ pub struct FlightSqlServiceImpl { /// prepared statements. stateful servers will not return an updated /// handle after executing `DoPut(CommandPreparedStatementQuery)` stateless_prepared_statements: bool, + transactions: Arc>>, } impl Default for FlightSqlServiceImpl { fn default() -> Self { Self { stateless_prepared_statements: true, + transactions: Arc::new(Mutex::new(HashMap::new())), } } } @@ -318,244 +358,6 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(resp) } - async fn get_flight_info_substrait_plan( - &self, - _query: CommandStatementSubstraitPlan, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_substrait_plan not implemented", - )) - } - - async fn get_flight_info_catalogs( - &self, - _query: CommandGetCatalogs, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_catalogs not implemented", - )) - } - - async fn get_flight_info_schemas( - &self, - _query: CommandGetDbSchemas, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_schemas not implemented", - )) - } - - async fn get_flight_info_tables( - &self, - _query: CommandGetTables, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_tables not implemented", - )) - } - - async fn get_flight_info_table_types( - &self, - _query: CommandGetTableTypes, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_table_types not implemented", - )) - } - - async fn get_flight_info_sql_info( - &self, - _query: CommandGetSqlInfo, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_sql_info not implemented", - )) - } - - async fn get_flight_info_primary_keys( - &self, - _query: CommandGetPrimaryKeys, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_primary_keys not implemented", - )) - } - - async fn get_flight_info_exported_keys( - &self, - _query: CommandGetExportedKeys, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_exported_keys not implemented", - )) - } - - async fn get_flight_info_imported_keys( - &self, - _query: CommandGetImportedKeys, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_imported_keys not implemented", - )) - } - - async fn get_flight_info_cross_reference( - &self, - _query: CommandGetCrossReference, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_imported_keys not implemented", - )) - } - - async fn get_flight_info_xdbc_type_info( - &self, - _query: CommandGetXdbcTypeInfo, - _request: Request, - ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_xdbc_type_info not implemented", - )) - } - - // do_get - async fn do_get_statement( - &self, - _ticket: TicketStatementQuery, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_statement not implemented")) - } - - async fn do_get_prepared_statement( - &self, - _query: CommandPreparedStatementQuery, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented( - "do_get_prepared_statement not implemented", - )) - } - - async fn do_get_catalogs( - &self, - _query: CommandGetCatalogs, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_catalogs not implemented")) - } - - async fn do_get_schemas( - &self, - _query: CommandGetDbSchemas, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_schemas not implemented")) - } - - async fn do_get_tables( - &self, - _query: CommandGetTables, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_tables not implemented")) - } - - async fn do_get_table_types( - &self, - _query: CommandGetTableTypes, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_table_types not implemented")) - } - - async fn do_get_sql_info( - &self, - _query: CommandGetSqlInfo, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_sql_info not implemented")) - } - - async fn do_get_primary_keys( - &self, - _query: CommandGetPrimaryKeys, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_primary_keys not implemented")) - } - - async fn do_get_exported_keys( - &self, - _query: CommandGetExportedKeys, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented( - "do_get_exported_keys not implemented", - )) - } - - async fn do_get_imported_keys( - &self, - _query: CommandGetImportedKeys, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented( - "do_get_imported_keys not implemented", - )) - } - - async fn do_get_cross_reference( - &self, - _query: CommandGetCrossReference, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented( - "do_get_cross_reference not implemented", - )) - } - - async fn do_get_xdbc_type_info( - &self, - _query: CommandGetXdbcTypeInfo, - _request: Request, - ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented( - "do_get_xdbc_type_info not implemented", - )) - } - - // do_put - async fn do_put_statement_update( - &self, - _ticket: CommandStatementUpdate, - _request: Request, - ) -> Result { - Err(Status::unimplemented( - "do_put_statement_update not implemented", - )) - } - - async fn do_put_substrait_plan( - &self, - _ticket: CommandStatementSubstraitPlan, - _request: Request, - ) -> Result { - Err(Status::unimplemented( - "do_put_substrait_plan not implemented", - )) - } - async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, @@ -590,16 +392,6 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(result) } - async fn do_put_prepared_statement_update( - &self, - _query: CommandPreparedStatementUpdate, - _request: Request, - ) -> Result { - Err(Status::unimplemented( - "do_put_prepared_statement_update not implemented", - )) - } - async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, @@ -609,60 +401,38 @@ impl FlightSqlService for FlightSqlServiceImpl { .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) } - async fn do_action_close_prepared_statement( - &self, - _query: ActionClosePreparedStatementRequest, - _request: Request, - ) -> Result<(), Status> { - unimplemented!("Implement do_action_close_prepared_statement") - } - - async fn do_action_create_prepared_substrait_plan( - &self, - _query: ActionCreatePreparedSubstraitPlanRequest, - _request: Request, - ) -> Result { - unimplemented!("Implement do_action_create_prepared_substrait_plan") - } - async fn do_action_begin_transaction( &self, _query: ActionBeginTransactionRequest, _request: Request, ) -> Result { - unimplemented!("Implement do_action_begin_transaction") + let transaction_id = Uuid::new_v4().to_string(); + self.transactions + .lock() + .await + .insert(transaction_id.clone(), ()); + Ok(ActionBeginTransactionResult { + transaction_id: transaction_id.as_bytes().to_vec().into(), + }) } async fn do_action_end_transaction( &self, - _query: ActionEndTransactionRequest, + query: ActionEndTransactionRequest, _request: Request, ) -> Result<(), Status> { - unimplemented!("Implement do_action_end_transaction") - } - - async fn do_action_begin_savepoint( - &self, - _query: ActionBeginSavepointRequest, - _request: Request, - ) -> Result { - unimplemented!("Implement do_action_begin_savepoint") - } - - async fn do_action_end_savepoint( - &self, - _query: ActionEndSavepointRequest, - _request: Request, - ) -> Result<(), Status> { - unimplemented!("Implement do_action_end_savepoint") - } - - async fn do_action_cancel_query( - &self, - _query: ActionCancelQueryRequest, - _request: Request, - ) -> Result { - unimplemented!("Implement do_action_cancel_query") + let transaction_id = String::from_utf8(query.transaction_id.to_vec()) + .map_err(|_| Status::invalid_argument("Invalid transaction id"))?; + if self + .transactions + .lock() + .await + .remove(&transaction_id) + .is_none() + { + return Err(Status::invalid_argument("Transaction id not found")); + } + Ok(()) } async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}