diff --git a/Cargo.toml b/Cargo.toml index 690b735..6acaa5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ bin = ["anyhow", "chrono", "console", "tokio-postgres", "env_logger", "clap", "t [dependencies] anyhow = { version = "1", optional = true } async-trait = "0.1" -chrono = { version = "0.4", optional = true } +chrono = { version = "0.4",optional = true } clap = { version = "3", features = ["derive", "env"], optional = true } console = { version = "0.15", optional = true } difference = "2.0" @@ -27,7 +27,9 @@ glob = "0.3" humantime = "2" itertools = "0.10" log = "0.4" +postgres-types = { version = "0.2.3", features = ["derive","with-chrono-0_4"] } quick-junit = { version = "0.2", optional = true } +rust_decimal = { version = "1.7.0", features = [ "tokio-pg" ] } tempfile = "3" thiserror = "1" tokio = { version = "1", features = [ diff --git a/src/main.rs b/src/main.rs index 303b951..b7d3b58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,9 @@ use clap::{ArgEnum, Parser}; use console::style; use futures::StreamExt; use itertools::Itertools; +use postgres_types::Type; use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite}; +use rust_decimal::Decimal; use sqllogictest::{Control, Record}; #[derive(Copy, Clone, Debug, PartialEq, ArgEnum)] @@ -116,10 +118,7 @@ async fn main() -> Result<()> { } }); - let pg = Postgres { - client: Arc::new(client), - engine_name: opt.engine.clone(), - }; + let pg = Postgres::new(Arc::new(client), opt.engine.clone()); let files = files.into_iter().try_collect::<_, Vec<_>, _>()?; @@ -311,10 +310,7 @@ async fn run_test_file_on_db( } }); - let pg = Postgres { - client: Arc::new(client), - engine_name: opt.engine.clone(), - }; + let pg = Postgres::new(Arc::new(client), opt.engine.clone()); let result = run_test_file(out, pg, filename).await?; @@ -423,6 +419,18 @@ async fn run_test_file( struct Postgres { client: Arc, engine_name: String, + extend: bool, +} + +impl Postgres { + fn new(client: Arc, engine_name: String) -> Self { + let extend = engine_name == "postgresql-extended"; + Self { + client, + engine_name, + extend, + } + } } #[async_trait] @@ -439,32 +447,100 @@ impl sqllogictest::AsyncDB for Postgres { // and we have to follow the format given by the specific database (pg). // For example, postgres will output `t` as true and `f` as false, // thus we have to write `t`/`f` in the expected results. - let rows = self.client.simple_query(sql).await?; - for row in rows { - match row { - tokio_postgres::SimpleQueryMessage::Row(row) => { - for i in 0..row.len() { - if i != 0 { + if !self.extend { + let rows = self.client.simple_query(sql).await?; + for row in rows { + match row { + tokio_postgres::SimpleQueryMessage::Row(row) => { + for i in 0..row.len() { + if i != 0 { + write!(output, " ").unwrap(); + } + match row.get(i) { + Some(v) => { + if v.is_empty() { + write!(output, "(empty)").unwrap() + } else { + write!(output, "{}", v).unwrap() + } + } + None => write!(output, "NULL").unwrap(), + } + } + } + tokio_postgres::SimpleQueryMessage::CommandComplete(_) => {} + _ => unreachable!(), + } + writeln!(output).unwrap(); + } + Ok(output) + } else { + if sql.contains("select") { + let rows = self.client.query(sql, &[]).await?; + for row in rows { + for (idx, column) in row.columns().iter().enumerate() { + if idx != 0 { write!(output, " ").unwrap(); } - match row.get(i) { - Some(v) => { - if v.is_empty() { - write!(output, "(empty)").unwrap() - } else { - write!(output, "{}", v).unwrap() - } + + match column.type_().clone() { + Type::VARCHAR | Type::TEXT => { + let value: &str = row.get(idx); + write!(output, "{}", value).unwrap(); + } + + Type::INT2 => { + let value: i16 = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::INT4 => { + let value: i32 = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::INT8 => { + let value: i64 = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::BOOL => { + let value: bool = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::FLOAT4 => { + let value: f32 = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::FLOAT8 => { + let value: f64 = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::NUMERIC => { + let value: Decimal = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::TIMESTAMP => { + let value: chrono::NaiveDateTime = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::DATE => { + let value: chrono::NaiveDate = row.get(idx); + write!(output, "{}", value).unwrap(); + } + Type::TIME => { + let value: chrono::NaiveTime = row.get(idx); + write!(output, "{}", value).unwrap(); + } + _ => { + todo!("Don't support {} type now.", column.type_().name()) } - None => write!(output, "NULL").unwrap(), } } + writeln!(output).unwrap(); } - tokio_postgres::SimpleQueryMessage::CommandComplete(_) => {} - _ => unreachable!(), + } else { + self.client.execute(sql, &[]).await?; } - writeln!(output).unwrap(); + Ok(output) } - Ok(output) } fn engine_name(&self) -> &str {