Skip to content

Commit

Permalink
support extended query mode
Browse files Browse the repository at this point in the history
Signed-off-by: Zejiong Dong <[email protected]>
  • Loading branch information
ZENOTME committed Jul 29, 2022
1 parent d486e87 commit 1dded66
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 27 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ bin = ["anyhow", "chrono", "console", "tokio-postgres", "env_logger", "clap", "t
[dependencies]
anyhow = { version = "1", optional = true }
async-trait = "0.1"
chrono = { version = "0.4", optional = true }
chrono = { version = "0.4",optional = true }
clap = { version = "3", features = ["derive", "env"], optional = true }
console = { version = "0.15", optional = true }
difference = "2.0"
Expand All @@ -27,7 +27,9 @@ glob = "0.3"
humantime = "2"
itertools = "0.10"
log = "0.4"
postgres-types = { version = "0.2.3", features = ["derive","with-chrono-0_4"] }
quick-junit = { version = "0.2", optional = true }
rust_decimal = { version = "1.7.0", features = [ "tokio-pg" ] }
tempfile = "3"
thiserror = "1"
tokio = { version = "1", features = [
Expand Down
128 changes: 102 additions & 26 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use clap::{ArgEnum, Parser};
use console::style;
use futures::StreamExt;
use itertools::Itertools;
use postgres_types::Type;
use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
use rust_decimal::Decimal;
use sqllogictest::{Control, Record};

#[derive(Copy, Clone, Debug, PartialEq, ArgEnum)]
Expand Down Expand Up @@ -116,10 +118,7 @@ async fn main() -> Result<()> {
}
});

let pg = Postgres {
client: Arc::new(client),
engine_name: opt.engine.clone(),
};
let pg = Postgres::new(Arc::new(client), opt.engine.clone());

let files = files.into_iter().try_collect::<_, Vec<_>, _>()?;

Expand Down Expand Up @@ -311,10 +310,7 @@ async fn run_test_file_on_db(
}
});

let pg = Postgres {
client: Arc::new(client),
engine_name: opt.engine.clone(),
};
let pg = Postgres::new(Arc::new(client), opt.engine.clone());

let result = run_test_file(out, pg, filename).await?;

Expand Down Expand Up @@ -423,6 +419,18 @@ async fn run_test_file<T: std::io::Write>(
struct Postgres {
client: Arc<tokio_postgres::Client>,
engine_name: String,
extend: bool,
}

impl Postgres {
fn new(client: Arc<tokio_postgres::Client>, engine_name: String) -> Self {
let extend = engine_name == "postgresql-extended";
Self {
client,
engine_name,
extend,
}
}
}

#[async_trait]
Expand All @@ -439,32 +447,100 @@ impl sqllogictest::AsyncDB for Postgres {
// and we have to follow the format given by the specific database (pg).
// For example, postgres will output `t` as true and `f` as false,
// thus we have to write `t`/`f` in the expected results.
let rows = self.client.simple_query(sql).await?;
for row in rows {
match row {
tokio_postgres::SimpleQueryMessage::Row(row) => {
for i in 0..row.len() {
if i != 0 {
if !self.extend {
let rows = self.client.simple_query(sql).await?;
for row in rows {
match row {
tokio_postgres::SimpleQueryMessage::Row(row) => {
for i in 0..row.len() {
if i != 0 {
write!(output, " ").unwrap();
}
match row.get(i) {
Some(v) => {
if v.is_empty() {
write!(output, "(empty)").unwrap()
} else {
write!(output, "{}", v).unwrap()
}
}
None => write!(output, "NULL").unwrap(),
}
}
}
tokio_postgres::SimpleQueryMessage::CommandComplete(_) => {}
_ => unreachable!(),
}
writeln!(output).unwrap();
}
Ok(output)
} else {
if sql.contains("select") {
let rows = self.client.query(sql, &[]).await?;
for row in rows {
for (idx, column) in row.columns().iter().enumerate() {
if idx != 0 {
write!(output, " ").unwrap();
}
match row.get(i) {
Some(v) => {
if v.is_empty() {
write!(output, "(empty)").unwrap()
} else {
write!(output, "{}", v).unwrap()
}

match column.type_().clone() {
Type::VARCHAR | Type::TEXT => {
let value: &str = row.get(idx);
write!(output, "{}", value).unwrap();
}

Type::INT2 => {
let value: i16 = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::INT4 => {
let value: i32 = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::INT8 => {
let value: i64 = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::BOOL => {
let value: bool = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::FLOAT4 => {
let value: f32 = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::FLOAT8 => {
let value: f64 = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::NUMERIC => {
let value: Decimal = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::TIMESTAMP => {
let value: chrono::NaiveDateTime = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::DATE => {
let value: chrono::NaiveDate = row.get(idx);
write!(output, "{}", value).unwrap();
}
Type::TIME => {
let value: chrono::NaiveTime = row.get(idx);
write!(output, "{}", value).unwrap();
}
_ => {
todo!("Don't support {} type now.", column.type_().name())
}
None => write!(output, "NULL").unwrap(),
}
}
writeln!(output).unwrap();
}
tokio_postgres::SimpleQueryMessage::CommandComplete(_) => {}
_ => unreachable!(),
} else {
self.client.execute(sql, &[]).await?;
}
writeln!(output).unwrap();
Ok(output)
}
Ok(output)
}

fn engine_name(&self) -> &str {
Expand Down

0 comments on commit 1dded66

Please sign in to comment.