Merge pull request #645 from piodul/rework-query-all
Replace Connection::query_all with Connection::query_iter
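
In broad strokes, the commit swaps an eager API that merged every page into one QueryResult for a lazy, stream-based one: query_iter returns a RowIterator that implements Stream. A minimal sketch of the new driver-internal call pattern, with the old one shown as a comment — the helper function, its name, and the import paths are illustrative assumptions, not code from this commit:

use std::sync::Arc;

use futures::TryStreamExt;

// Paths mirror imports used elsewhere in connection.rs (assumed here):
use crate::frame::response::result::Row;
use crate::query::Query;
use crate::transport::connection::Connection;
use crate::transport::errors::QueryError;

// Illustrative helper, not part of this commit.
async fn fetch_all_rows(
    connection: Arc<Connection>,
    query: Query,
) -> Result<Vec<Row>, QueryError> {
    // Before this commit, the equivalent internal call was roughly:
    //     let merged: QueryResult = connection.query_all(&query, &[]).await?;
    // which required a page size and buffered every page into one result.

    // After this commit, the connection is taken by Arc and rows arrive as a
    // Stream; further pages are fetched as the stream is consumed.
    let rows: Vec<Row> = connection
        .query_iter(query, &[])
        .await?
        .try_collect()
        .await?;
    Ok(rows)
}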
cvybhu authored Mar 2, 2023
2 parents 485d6c4 + 012caa6 commit 739c4f8
Showing 4 changed files with 442 additions and 449 deletions.
237 changes: 52 additions & 185 deletions scylla/src/transport/connection.rs
@@ -29,7 +29,8 @@ use std::{
     net::{Ipv4Addr, Ipv6Addr},
 };

-use super::errors::{BadKeyspaceName, BadQuery, DbError, QueryError};
+use super::errors::{BadKeyspaceName, DbError, QueryError};
+use super::iterator::RowIterator;

 use crate::batch::{Batch, BatchStatement};
 use crate::frame::protocol_features::ProtocolFeatures;
@@ -457,98 +458,6 @@ impl Connection {
         .await
     }

-    /// Performs query_single_page multiple times to query all available pages
-    pub async fn query_all(
-        &self,
-        query: &Query,
-        values: impl ValueList,
-    ) -> Result<QueryResult, QueryError> {
-        // This method is used only for driver internal queries, so no need to consult execution profile here.
-        self.query_all_with_consistency(
-            query,
-            values,
-            query
-                .config
-                .determine_consistency(self.config.default_consistency),
-            query.get_serial_consistency(),
-        )
-        .await
-    }
-
-    pub async fn query_all_with_consistency(
-        &self,
-        query: &Query,
-        values: impl ValueList,
-        consistency: Consistency,
-        serial_consistency: Option<SerialConsistency>,
-    ) -> Result<QueryResult, QueryError> {
-        if query.get_page_size().is_none() {
-            // Page size should be set when someone wants to use paging
-            return Err(QueryError::BadQuery(BadQuery::Other(
-                "Called Connection::query_all without page size set!".to_string(),
-            )));
-        }
-
-        let mut final_result = QueryResult::default();
-
-        let serialized_values = values.serialized()?;
-        let mut paging_state: Option<Bytes> = None;
-
-        loop {
-            // Send next paged query
-            let mut cur_result: QueryResult = self
-                .query_with_consistency(
-                    query,
-                    &serialized_values,
-                    consistency,
-                    serial_consistency,
-                    paging_state,
-                )
-                .await?
-                .into_query_result()?;
-
-            // Set paging_state for the next query
-            paging_state = cur_result.paging_state.take();
-
-            // Add current query results to the final_result
-            final_result.merge_with_next_page_res(cur_result);
-
-            if paging_state.is_none() {
-                // No more pages to query, we can return the final result
-                return Ok(final_result);
-            }
-        }
-    }
-
-    pub async fn execute_single_page(
-        &self,
-        prepared_statement: &PreparedStatement,
-        values: impl ValueList,
-        paging_state: Option<Bytes>,
-    ) -> Result<QueryResult, QueryError> {
-        self.execute(prepared_statement, values, paging_state)
-            .await?
-            .into_query_result()
-    }
-
-    pub async fn execute(
-        &self,
-        prepared_statement: &PreparedStatement,
-        values: impl ValueList,
-        paging_state: Option<Bytes>,
-    ) -> Result<QueryResponse, QueryError> {
-        self.execute_with_consistency(
-            prepared_statement,
-            values,
-            prepared_statement
-                .config
-                .determine_consistency(self.config.default_consistency),
-            prepared_statement.config.serial_consistency.flatten(),
-            paging_state,
-        )
-        .await
-    }
-
     pub async fn execute_with_consistency(
         &self,
         prepared_statement: &PreparedStatement,
@@ -591,41 +500,28 @@ impl Connection {
         }
     }

-    /// Performs execute_single_page multiple times to fetch all available pages
-    #[allow(dead_code)]
-    pub async fn execute_all(
-        &self,
-        prepared_statement: &PreparedStatement,
+    /// Executes a query and fetches its results over multiple pages, using
+    /// the asynchronous iterator interface.
+    pub(crate) async fn query_iter(
+        self: Arc<Self>,
+        query: Query,
         values: impl ValueList,
-    ) -> Result<QueryResult, QueryError> {
-        if prepared_statement.get_page_size().is_none() {
-            return Err(QueryError::BadQuery(BadQuery::Other(
-                "Called Connection::execute_all without page size set!".to_string(),
-            )));
-        }
-
-        let mut final_result = QueryResult::default();
-
-        let serialized_values = values.serialized()?;
-        let mut paging_state: Option<Bytes> = None;
-
-        loop {
-            // Send next paged query
-            let mut cur_result: QueryResult = self
-                .execute_single_page(prepared_statement, &serialized_values, paging_state)
-                .await?;
-
-            // Set paging_state for the next query
-            paging_state = cur_result.paging_state.take();
+    ) -> Result<RowIterator, QueryError> {
+        let serialized_values = values.serialized()?.into_owned();

-            // Add current query results to the final_result
-            final_result.merge_with_next_page_res(cur_result);
+        let consistency = query
+            .config
+            .determine_consistency(self.config.default_consistency);
+        let serial_consistency = query.config.serial_consistency.flatten();

-            if paging_state.is_none() {
-                // No more pages to query, we can return the final result
-                return Ok(final_result);
-            }
-        }
+        RowIterator::new_for_connection_query_iter(
+            query,
+            self,
+            serialized_values,
+            consistency,
+            serial_consistency,
+        )
+        .await
     }

     #[allow(dead_code)]
@@ -1559,7 +1455,6 @@ impl VerifiedKeyspaceName {

 #[cfg(test)]
 mod tests {
-    use scylla_cql::errors::BadQuery;
     use scylla_cql::frame::protocol_features::{
         LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
     };
@@ -1572,12 +1467,12 @@ mod tests {
     use tokio::select;
     use tokio::sync::mpsc;

-    use super::super::errors::QueryError;
     use super::ConnectionConfig;
     use crate::query::Query;
     use crate::transport::connection::open_connection;
     use crate::utils::test_utils::unique_keyspace_name;
-    use crate::{IntoTypedRows, SessionBuilder};
+    use crate::SessionBuilder;
+    use futures::{StreamExt, TryStreamExt};
     use std::collections::HashMap;
     use std::net::SocketAddr;
     use std::sync::Arc;
@@ -1596,20 +1491,20 @@ mod tests {
         }
     }

-    /// Tests for Connection::query_all and Connection::execute_all
+    /// Tests for Connection::query_iter
     /// 1. SELECT from an empty table.
     /// 2. Create table and insert ints 0..100.
-    /// Then use query_all and execute_all with page_size set to 7 to select all 100 rows.
-    /// 3. INSERT query_all should have None in result rows.
-    /// 4. Calling query_all with a Query that doesn't have page_size set should result in an error.
+    /// Then use query_iter with page_size set to 7 to select all 100 rows.
+    /// 3. INSERT query_iter should work and not return any rows.
     #[tokio::test]
-    async fn connection_query_all_execute_all_test() {
+    async fn connection_query_iter_test() {
         let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
         let addr: SocketAddr = resolve_hostname(&uri).await;

         let (connection, _) = super::open_connection(addr, None, ConnectionConfig::default())
             .await
             .unwrap();
+        let connection = Arc::new(connection);

         let ks = unique_keyspace_name();

@@ -1623,12 +1518,12 @@ mod tests {
         session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'SimpleStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap();
         session.use_keyspace(ks.clone(), false).await.unwrap();
         session
-            .query("DROP TABLE IF EXISTS connection_query_all_tab", &[])
+            .query("DROP TABLE IF EXISTS connection_query_iter_tab", &[])
             .await
             .unwrap();
         session
             .query(
-                "CREATE TABLE IF NOT EXISTS connection_query_all_tab (p int primary key)",
+                "CREATE TABLE IF NOT EXISTS connection_query_iter_tab (p int primary key)",
                 &[],
             )
             .await
@@ -1641,77 +1536,49 @@ mod tests {
             .unwrap();

         // 1. SELECT from an empty table returns query result where rows are Some(Vec::new())
-        let select_query = Query::new("SELECT p FROM connection_query_all_tab").with_page_size(7);
-        let empty_res = connection.query_all(&select_query, &[]).await.unwrap();
-        assert!(empty_res.rows.unwrap().is_empty());
-
-        let mut prepared_select = connection.prepare(&select_query).await.unwrap();
-        prepared_select.set_page_size(7);
-        let empty_res_prepared = connection.execute_all(&prepared_select, &[]).await.unwrap();
-        assert!(empty_res_prepared.rows.unwrap().is_empty());
+        let select_query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
+        let empty_res = connection
+            .clone()
+            .query_iter(select_query.clone(), &[])
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        assert!(empty_res.is_empty());

-        // 2. Insert 100 and select using query_all with page_size 7
+        // 2. Insert 100 and select using query_iter with page_size 7
         let values: Vec<i32> = (0..100).collect();
         let mut insert_futures = Vec::new();
         let insert_query =
-            Query::new("INSERT INTO connection_query_all_tab (p) VALUES (?)").with_page_size(7);
+            Query::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)").with_page_size(7);
         for v in &values {
             insert_futures.push(connection.query_single_page(insert_query.clone(), (v,)));
         }

         futures::future::try_join_all(insert_futures).await.unwrap();

         let mut results: Vec<i32> = connection
-            .query_all(&select_query, &[])
+            .clone()
+            .query_iter(select_query.clone(), &[])
             .await
             .unwrap()
-            .rows
-            .unwrap()
             .into_typed::<(i32,)>()
-            .map(|r| r.unwrap().0)
-            .collect();
+            .map(|ret| ret.unwrap().0)
+            .collect::<Vec<_>>()
+            .await;
         results.sort_unstable(); // Clippy recommended to use sort_unstable instead of sort()
         assert_eq!(results, values);

-        let mut results2: Vec<i32> = connection
-            .execute_all(&prepared_select, &[])
+        // 3. INSERT query_iter should work and not return any rows.
+        let insert_res1 = connection
+            .query_iter(insert_query, (0,))
             .await
             .unwrap()
-            .rows
-            .unwrap()
-            .into_typed::<(i32,)>()
-            .map(|r| r.unwrap().0)
-            .collect();
-        results2.sort_unstable();
-        assert_eq!(results2, values);
-
-        // 3. INSERT query_all should have None in result rows.
-        let insert_res1 = connection.query_all(&insert_query, (0,)).await.unwrap();
-        assert!(insert_res1.rows.is_none());
-
-        let prepared_insert = connection.prepare(&insert_query).await.unwrap();
-        let insert_res2 = connection
-            .execute_all(&prepared_insert, (0,))
+            .try_collect::<Vec<_>>()
             .await
             .unwrap();
-        assert!(insert_res2.rows.is_none(),);
-
-        // 4. Calling query_all with a Query that doesn't have page_size set should result in an error.
-        let no_page_size_query = Query::new("SELECT p FROM connection_query_all_tab");
-        let no_page_res = connection.query_all(&no_page_size_query, &[]).await;
-        assert!(matches!(
-            no_page_res,
-            Err(QueryError::BadQuery(BadQuery::Other(_)))
-        ));
-
-        let prepared_no_page_size_query = connection.prepare(&no_page_size_query).await.unwrap();
-        let prepared_no_page_res = connection
-            .execute_all(&prepared_no_page_size_query, &[])
-            .await;
-        assert!(matches!(
-            prepared_no_page_res,
-            Err(QueryError::BadQuery(BadQuery::Other(_)))
-        ));
+        assert!(insert_res1.is_empty());
     }

     #[tokio::test]

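For the typed-row flavour exercised by the test above, a comparable sketch: into_typed turns the RowIterator into a stream of tuples, which is then drained with the futures StreamExt combinators. The helper and its import paths are assumptions for illustration; only query_iter, into_typed, and the combinators come from this diff.

use std::sync::Arc;

use futures::StreamExt;

// Paths mirror the test imports above (assumed here):
use crate::query::Query;
use crate::transport::connection::Connection;
use crate::transport::errors::QueryError;

// Illustrative helper, not part of this commit: read every `p` from the
// single int-column test table, however many pages that takes.
async fn read_all_ints(connection: Arc<Connection>) -> Result<Vec<i32>, QueryError> {
    // The page size only bounds each round trip; the iterator keeps
    // requesting pages until the result set is exhausted.
    let query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);

    let ints: Vec<i32> = connection
        .query_iter(query, &[])
        .await?
        .into_typed::<(i32,)>()
        .map(|row| row.expect("typed row conversion failed").0)
        .collect::<Vec<i32>>()
        .await;

    Ok(ints)
}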