From 9329951fc8cc1711804d2a36549c7d34b8c56f03 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Mon, 23 Sep 2024 18:09:35 -0500 Subject: [PATCH 01/16] add all db supports --- Cargo.toml | 10 +- src/app.rs | 81 +++++++--- src/components.rs | 19 ++- src/components/data.rs | 39 ++++- src/components/editor.rs | 27 +++- src/components/history.rs | 18 ++- src/components/menu.rs | 36 ++++- src/components/scroll_table.rs | 12 +- src/database.rs | 4 +- src/database/mysql.rs | 9 ++ src/database/postgresql.rs | 286 +++++++++++++++++++++++++++++++++ src/database/sqlite.rs | 7 + src/generic_database.rs | 187 +++++++++++++++++++++ src/main.rs | 157 +++++++++--------- 14 files changed, 753 insertions(+), 139 deletions(-) create mode 100644 src/database/mysql.rs create mode 100644 src/database/postgresql.rs create mode 100644 src/database/sqlite.rs create mode 100644 src/generic_database.rs diff --git a/Cargo.toml b/Cargo.toml index 0494790..6f0d13b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,15 +59,7 @@ ratatui = { version = "0.28.0", features = [ serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.107" signal-hook = "0.3.17" -sqlx = { version = "0.8.1", features = [ - "runtime-tokio", - "tls-rustls", - "postgres", - "uuid", - "chrono", - "json", - "ipnetwork", -] } +sqlx = { version = "0.8.1", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "chrono", "json", "ipnetwork", "mysql", "sqlite"] } strip-ansi-escapes = "0.2.0" strum = { version = "0.26.1", features = ["derive"] } tokio = { version = "1.32.0", features = ["full"] } diff --git a/src/app.rs b/src/app.rs index a1f12b3..283aaf8 100644 --- a/src/app.rs +++ b/src/app.rs @@ -16,7 +16,7 @@ use serde::{Deserialize, Serialize}; use sqlparser::{ast::Statement, keywords::DELETE}; use sqlx::{ postgres::{PgConnectOptions, Postgres}, - Either, Transaction, + Connection, Database, Either, Executor, Pool, Transaction, }; use tokio::{ sync::{ @@ -38,14 +38,21 @@ use crate::{ config::Config, database::{self, statement_type_string, DbError, DbPool, Rows}, focus::Focus, - tui, + generic_database, tui, ui::center, }; -pub enum DbTask<'a> { +pub enum DbTask<'a, DB> +where + DB: Database + generic_database::ValueParser, + DB::QueryResult: generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ Query(tokio::task::JoinHandle), - TxStart(tokio::task::JoinHandle<(QueryResultsWithMetadata, Transaction<'a, Postgres>)>), - TxPending(Transaction<'a, Postgres>, QueryResultsWithMetadata), + TxStart(tokio::task::JoinHandle<(QueryResultsWithMetadata, Transaction<'a, DB>)>), + TxPending(Transaction<'a, DB>, QueryResultsWithMetadata), TxCommit(tokio::task::JoinHandle), } @@ -54,20 +61,34 @@ pub struct HistoryEntry { pub timestamp: chrono::DateTime, } -pub struct AppState<'a> { - pub connection_opts: PgConnectOptions, +pub struct AppState<'a, DB: Database> +where + DB: Database + generic_database::ValueParser, + DB::QueryResult: generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ + pub connection_opts: ::Options, pub focus: Focus, - pub query_task: Option>, + pub query_task: Option>, pub history: Vec, pub last_query_start: Option>, pub last_query_end: Option>, } -pub struct Components<'a> { - pub menu: Box>, - pub editor: Box, - pub history: Box, - pub data: Box>, +pub struct Components<'a, DB: Database + 
generic_database::ValueParser> +where + DB: Database + generic_database::ValueParser, + DB::QueryResult: generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ + pub menu: Box>, + pub editor: Box>, + pub history: Box>, + pub data: Box>, } #[derive(Debug)] @@ -76,19 +97,33 @@ pub struct QueryResultsWithMetadata { pub statement_type: Statement, } -pub struct App<'a> { +pub struct App<'a, DB: Database + generic_database::ValueParser> +where + DB: Database + generic_database::ValueParser, + DB::QueryResult: generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ pub config: Config, - pub components: Components<'static>, + pub components: Components<'static, DB>, pub should_quit: bool, pub last_tick_key_events: Vec, pub last_frame_mouse_event: Option, - pub pool: Option, - pub state: AppState<'a>, + pub pool: Option>, + pub state: AppState<'a, DB>, last_focused_tab: Focus, } -impl<'a> App<'a> { - pub fn new(connection_opts: PgConnectOptions) -> Result { +impl<'a, DB> App<'a, DB> +where + DB: Database + generic_database::ValueParser, + DB::QueryResult: generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ + pub fn new(connection_opts: ::Options) -> Result { let focus = Focus::Menu; let menu = Menu::new(); let editor = Editor::new(); @@ -133,7 +168,7 @@ impl<'a> App<'a> { pub async fn run(&mut self) -> Result<()> { let (action_tx, mut action_rx) = mpsc::unbounded_channel(); let connection_opts = self.state.connection_opts.clone(); - let pool = database::init_pool(connection_opts).await?; + let pool = generic_database::init_pool::(connection_opts).await?; log::info!("{pool:?}"); self.pool = Some(pool); @@ -348,7 +383,7 @@ impl<'a> App<'a> { Action::LoadMenu => { log::info!("LoadMenu"); if let Some(pool) = &self.pool { - let results = database::query( + let results = generic_database::query( "select table_schema, table_name from information_schema.tables where table_schema != 'pg_catalog' @@ -377,7 +412,7 @@ impl<'a> App<'a> { self.components.data.set_loading(); let tx = pool.begin().await?; self.state.query_task = Some(DbTask::TxStart(tokio::spawn(async move { - let (results, tx) = database::query_with_tx(tx, query_string.clone()).await; + let (results, tx) = generic_database::query_with_tx::(tx, query_string.clone()).await; match results { Ok(Either::Left(rows_affected)) => { log::info!("{:?} rows affected", rows_affected); @@ -405,7 +440,7 @@ impl<'a> App<'a> { Ok((false, statement_type)) => { self.components.data.set_loading(); self.state.query_task = Some(DbTask::Query(tokio::spawn(async move { - let results = database::query(query_string.clone(), &pool).await; + let results = generic_database::query(query_string.clone(), &pool).await; match &results { Ok(rows) => { log::info!("{:?} rows, {:?} affected", rows.rows.len(), rows.rows_affected); diff --git a/src/components.rs b/src/components.rs index 6400fa2..3079d60 100644 --- a/src/components.rs +++ b/src/components.rs @@ -15,8 +15,15 @@ pub mod editor; pub mod history; pub mod menu; pub mod scroll_table; +use sqlx::{Database, Executor, Pool}; +pub trait Component +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, 
DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -pub trait Component { +{ /// Register an action handler that can send actions for processing if necessary. /// /// # Arguments @@ -68,7 +75,7 @@ pub trait Component { &mut self, event: Option, last_tick_key_events: Vec, - app_state: &AppState, + app_state: &AppState<'_, DB>, ) -> Result> { let r = match event { Some(Event::Key(key_event)) => self.handle_key_events(key_event, app_state)?, @@ -87,7 +94,7 @@ pub trait Component { /// /// * `Result>` - An action to be processed or none. #[allow(unused_variables)] - fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState) -> Result> { + fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState<'_, DB>) -> Result> { Ok(None) } /// Handle mouse events and produce actions if necessary. @@ -100,7 +107,7 @@ pub trait Component { /// /// * `Result>` - An action to be processed or none. #[allow(unused_variables)] - fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState) -> Result> { + fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState<'_, DB>) -> Result> { Ok(None) } /// Update the state of the component based on a received action. (REQUIRED) @@ -113,7 +120,7 @@ pub trait Component { /// /// * `Result>` - An action to be processed or none. #[allow(unused_variables)] - fn update(&mut self, action: Action, app_state: &AppState) -> Result> { + fn update(&mut self, action: Action, app_state: &AppState<'_, DB>) -> Result> { Ok(None) } /// Render the component on the screen. (REQUIRED) @@ -126,5 +133,5 @@ pub trait Component { /// # Returns /// /// * `Result<()>` - An Ok result or an error. - fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()>; + fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()>; } diff --git a/src/components/data.rs b/src/components/data.rs index 269c84e..e48b419 100644 --- a/src/components/data.rs +++ b/src/components/data.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration}; use color_eyre::eyre::Result; use crossterm::{ @@ -8,6 +8,7 @@ use crossterm::{ use ratatui::{prelude::*, symbols::scrollbar, widgets::*}; use serde::{Deserialize, Serialize}; use sqlparser::ast::Statement; +use sqlx::{Database, Executor, Pool}; use tokio::sync::{mpsc::UnboundedSender, Mutex}; use tui_textarea::{Input, Key}; @@ -51,8 +52,23 @@ pub trait SettableDataTable<'a> { fn set_cancelled(&mut self); } -pub trait DataComponent<'a>: Component + SettableDataTable<'a> {} -impl<'a, T> DataComponent<'a> for T where T: Component + SettableDataTable<'a> +pub trait DataComponent<'a, DB>: Component + SettableDataTable<'a> +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ +} +impl<'a, T, DB> DataComponent<'a, DB> for T +where + T: Component + SettableDataTable<'a>, + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + { } @@ -231,7 +247,14 @@ impl<'a> SettableDataTable<'a> for Data<'a> { } } -impl<'a> Component for Data<'a> { +impl<'a, DB> Component for Data<'a> +where + DB: Database + 
crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) @@ -245,7 +268,7 @@ impl<'a> Component for Data<'a> { fn handle_mouse_events( &mut self, mouse: crossterm::event::MouseEvent, - app_state: &AppState, + app_state: &AppState<'_, DB>, ) -> Result> { if app_state.focus != Focus::Data { return Ok(None); @@ -268,7 +291,7 @@ impl<'a> Component for Data<'a> { Ok(None) } - fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState) -> Result> { + fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState<'_, DB>) -> Result> { if app_state.focus != Focus::Data { return Ok(None); } @@ -372,14 +395,14 @@ impl<'a> Component for Data<'a> { Ok(None) } - fn update(&mut self, action: Action, app_state: &AppState) -> Result> { + fn update(&mut self, action: Action, app_state: &AppState<'_, DB>) -> Result> { if let Action::Query(query) = action { self.scrollable.reset_scroll(); } Ok(None) } - fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()> { + fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { let focused = app_state.focus == Focus::Data; let mut block = Block::default().borders(Borders::ALL).border_style(if focused { diff --git a/src/components/editor.rs b/src/components/editor.rs index 1c27439..a5e1660 100644 --- a/src/components/editor.rs +++ b/src/components/editor.rs @@ -10,6 +10,7 @@ use color_eyre::eyre::Result; use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, MouseEvent, MouseEventKind}; use ratatui::{prelude::*, widgets::*}; use serde::{Deserialize, Serialize}; +use sqlx::{Database, Executor, Pool}; use tokio::sync::mpsc::UnboundedSender; use tui_textarea::{Input, Key, Scrolling, TextArea}; @@ -66,7 +67,14 @@ impl<'a> Editor<'a> { } } - pub fn transition_vim_state(&mut self, input: Input, app_state: &AppState) -> Result<()> { + pub fn transition_vim_state(&mut self, input: Input, app_state: &AppState<'_, DB>) -> Result<()> + where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + + { match input { Input { key: Key::Enter, alt: true, .. } | Input { key: Key::Enter, ctrl: true, .. 
} => { if app_state.query_task.is_none() { @@ -108,7 +116,14 @@ impl<'a> Editor<'a> { } } -impl<'a> Component for Editor<'a> { +impl<'a, DB> Component for Editor<'a> +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) @@ -119,7 +134,7 @@ impl<'a> Component for Editor<'a> { Ok(()) } - fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState) -> Result> { + fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState<'_, DB>) -> Result> { if app_state.focus != Focus::Editor { return Ok(None); } @@ -145,7 +160,7 @@ impl<'a> Component for Editor<'a> { &mut self, event: Option, last_tick_key_events: Vec, - app_state: &AppState, + app_state: &AppState<'_, DB>, ) -> Result> { if app_state.focus != Focus::Editor { return Ok(None); @@ -161,7 +176,7 @@ impl<'a> Component for Editor<'a> { Ok(None) } - fn update(&mut self, action: Action, app_state: &AppState) -> Result> { + fn update(&mut self, action: Action, app_state: &AppState<'_, DB>) -> Result> { match action { Action::MenuPreview(preview_type, schema, table) => { if app_state.query_task.is_some() { @@ -230,7 +245,7 @@ impl<'a> Component for Editor<'a> { Ok(None) } - fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()> { + fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { let focused = app_state.focus == Focus::Editor; if let Some(query_start) = app_state.last_query_start { diff --git a/src/components/history.rs b/src/components/history.rs index 7e68995..6239976 100644 --- a/src/components/history.rs +++ b/src/components/history.rs @@ -2,6 +2,7 @@ use color_eyre::eyre::Result; use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, MouseEvent, MouseEventKind}; use ratatui::{prelude::*, symbols::scrollbar, widgets::*}; use serde::{Deserialize, Serialize}; +use sqlx::{Database, Executor, Pool}; use tokio::sync::mpsc::UnboundedSender; use tui_textarea::{Input, Key, Scrolling, TextArea}; @@ -49,7 +50,14 @@ impl History { } } -impl Component for History { +impl Component for History +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) @@ -60,7 +68,7 @@ impl Component for History { Ok(()) } - fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState) -> Result> { + fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState<'_, DB>) -> Result> { if app_state.focus != Focus::History { return Ok(None); } @@ -77,7 +85,7 @@ impl Component for History { Ok(None) } - fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState) -> Result> { + fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState<'_, DB>) -> Result> { if app_state.focus != Focus::History { return Ok(None); } @@ -112,11 +120,11 @@ impl Component for History { Ok(None) } - fn update(&mut self, action: Action, app_state: &AppState) -> Result> { + fn update(&mut self, action: Action, app_state: &AppState<'_, DB>) -> Result> { Ok(None) } - 
fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()> { + fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { let focused = app_state.focus == Focus::History; if let Some(query_start) = app_state.last_query_start { self.last_query_duration = match app_state.last_query_end { diff --git a/src/components/menu.rs b/src/components/menu.rs index e795664..013ab80 100644 --- a/src/components/menu.rs +++ b/src/components/menu.rs @@ -10,6 +10,7 @@ use crossterm::event::{KeyCode, KeyEvent, MouseEventKind}; use indexmap::IndexMap; use ratatui::{prelude::*, widgets::*}; use serde::{Deserialize, Serialize}; +use sqlx::{Database, Executor, Pool}; use symbols::scrollbar; use tokio::sync::mpsc::UnboundedSender; @@ -34,8 +35,24 @@ pub trait SettableTableList<'a> { fn set_table_list(&mut self, data: Option>); } -pub trait MenuComponent<'a>: Component + SettableTableList<'a> {} -impl<'a, T> MenuComponent<'a> for T where T: Component + SettableTableList<'a> +pub trait MenuComponent<'a, DB>: Component + SettableTableList<'a> +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ +} + +impl<'a, T, DB> MenuComponent<'a, DB> for T +where + T: Component + SettableTableList<'a>, + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + { } @@ -187,7 +204,14 @@ impl<'a> SettableTableList<'a> for Menu { } } -impl Component for Menu { +impl Component for Menu +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) @@ -201,7 +225,7 @@ impl Component for Menu { fn handle_mouse_events( &mut self, mouse: crossterm::event::MouseEvent, - app_state: &AppState, + app_state: &AppState<'_, DB>, ) -> Result> { if app_state.focus != Focus::Menu { return Ok(None); @@ -214,7 +238,7 @@ impl Component for Menu { Ok(None) } - fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState) -> Result> { + fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState<'_, DB>) -> Result> { if app_state.focus != Focus::Menu { return Ok(None); } @@ -320,7 +344,7 @@ impl Component for Menu { Ok(None) } - fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()> { + fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { let focused = app_state.focus == Focus::Menu; let parent_block = Block::default(); let stable_keys = self.table_map.keys().enumerate(); diff --git a/src/components/scroll_table.rs b/src/components/scroll_table.rs index 5da838a..8986281 100644 --- a/src/components/scroll_table.rs +++ b/src/components/scroll_table.rs @@ -9,6 +9,7 @@ use ratatui::{ Table, TableState, WidgetRef, }, }; +use sqlx::{Database, Executor, Pool}; use symbols::scrollbar; use super::Component; @@ -190,8 +191,15 @@ impl<'a> ScrollTable<'a> { } } -impl<'a> Component for ScrollTable<'a> { - fn draw(&mut self, f: &mut Frame<'_>, area: Rect, 
app_state: &AppState) -> Result<()> { +impl<'a, DB> Component for ScrollTable<'a> +where + DB: Database + crate::generic_database::ValueParser, + DB::QueryResult: crate::generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + +{ + fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { self.parent_area = area; let render_area = self.block.inner_if_some(area); self.pg_height = std::cmp::min(self.max_height, render_area.height).saturating_sub(3); diff --git a/src/database.rs b/src/database.rs index 02cd955..40ad9e1 100644 --- a/src/database.rs +++ b/src/database.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, fmt::Write, pin::Pin, string::String}; use futures::stream::{BoxStream, StreamExt}; use sqlparser::{ ast::Statement, - dialect::PostgreSqlDialect, + dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect}, keywords, parser::{Parser, ParserError}, }; @@ -16,6 +16,8 @@ use sqlx::{ Column, Database, Either, Error, Pool, Row, Transaction, ValueRef, }; +pub mod postgresql; + #[derive(Debug)] pub struct Header { pub name: String, diff --git a/src/database/mysql.rs b/src/database/mysql.rs new file mode 100644 index 0000000..fe83b4d --- /dev/null +++ b/src/database/mysql.rs @@ -0,0 +1,9 @@ +use super::Value; +use crate::generic_database::ValueParser; + +impl ValueParser for MySql { + fn parse_value(row: &Self::Row, col: &Self::Column) -> Option { + // MySQL-specific parsing + todo!() + } +} diff --git a/src/database/postgresql.rs b/src/database/postgresql.rs new file mode 100644 index 0000000..1140d1d --- /dev/null +++ b/src/database/postgresql.rs @@ -0,0 +1,286 @@ +use std::{collections::HashMap, fmt::Write, pin::Pin, string::String}; + +use futures::stream::{BoxStream, StreamExt}; +use sqlparser::{ + ast::Statement, + dialect::PostgreSqlDialect, + keywords, + parser::{Parser, ParserError}, +}; +use sqlx::{ + postgres::{ + PgColumn, PgConnectOptions, PgPool, PgPoolOptions, PgQueryResult, PgRow, PgTypeInfo, PgTypeKind, PgValueRef, + Postgres, + }, + types::Uuid, + Column, Database, Either, Error, Pool, Row, Transaction, ValueRef, +}; + +use super::{vec_to_string, Value}; +use crate::generic_database::ValueParser; + +impl ValueParser for Postgres { + fn parse_value(row: &::Row, col: &::Column) -> Option { + let col_type = col.type_info().to_string(); + let raw_value = row.try_get_raw(col.ordinal()).unwrap(); + if raw_value.is_null() { + return Some(Value { string: "NULL".to_string(), is_null: true }); + } + match col_type.to_uppercase().as_str() { + "TIMESTAMPTZ" => { + let received: chrono::DateTime = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "TIMESTAMP" => { + let received: chrono::NaiveDateTime = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "DATE" => { + let received: chrono::NaiveDate = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "TIME" => { + let received: chrono::NaiveTime = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "UUID" => { + let received: Uuid = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "INET" | "CIDR" => { + let received: std::net::IpAddr = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), 
is_null: false }) + }, + "JSON" | "JSONB" => { + let received: serde_json::Value = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "BOOL" => { + let received: bool = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "SMALLINT" | "SMALLSERIAL" | "INT2" => { + let received: i16 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "INT" | "SERIAL" | "INT4" => { + let received: i32 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "BIGINT" | "BIGSERIAL" | "INT8" => { + let received: i64 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "REAL" | "FLOAT4" => { + let received: f32 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "DOUBLE PRECISION" | "FLOAT8" => { + let received: f64 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "TEXT" | "VARCHAR" | "NAME" | "CITEXT" | "BPCHAR" | "CHAR" => { + let received: String = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received, is_null: false }) + }, + "BYTEA" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { + string: received.iter().fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }), + is_null: false, + }) + }, + "VOID" => Some(Value { string: "".to_string(), is_null: false }), + _ if col_type.to_uppercase().ends_with("[]") => { + let array_type = col_type.to_uppercase().replace("[]", ""); + match array_type.as_str() { + "TIMESTAMPTZ" => { + let received: Vec> = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "TIMESTAMP" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "DATE" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "TIME" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "UUID" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "INET" | "CIDR" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "JSON" | "JSONB" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "BOOL" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "SMALLINT" | "SMALLSERIAL" | "INT2" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "INT" | "SERIAL" | "INT4" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "BIGINT" | "BIGSERIAL" | "INT8" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "REAL" | "FLOAT4" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + 
Some(Value { string: vec_to_string(received), is_null: false }) + }, + "DOUBLE PRECISION" | "FLOAT8" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "TEXT" | "VARCHAR" | "NAME" | "CITEXT" | "BPCHAR" | "CHAR" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + "BYTEA" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + Some(Value { + string: received.iter().fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }), + is_null: false, + }) + }, + _ => { + // try to cast custom or other types to strings + let received: Vec = row.try_get_unchecked(col.ordinal()).unwrap(); + Some(Value { string: vec_to_string(received), is_null: false }) + }, + } + }, + _ => { + // try to cast custom or other types to strings + let received: String = row.try_get_unchecked(col.ordinal()).unwrap(); + Some(Value { string: received, is_null: false }) + }, + } + } +} +mod tests { + use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; + + use super::*; + use crate::generic_database::{get_first_query, DbError}; + + #[test] + fn test_get_first_query() { + type TestCase = (&'static str, Result<(String, Box bool>), DbError>); + + let test_cases: Vec = vec![ + // single query + ("SELECT * FROM users;", Ok(("SELECT * FROM users".to_string(), Box::new(|s| matches!(s, Statement::Query(_)))))), + // multiple queries + ( + "SELECT * FROM users; DELETE FROM posts;", + Err(DbError::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))), + ), + // empty query + ("", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), + // syntax error + ( + "SELEC * FORM users;", + Err(DbError::Right(ParserError::ParserError( + "Expected: an SQL statement, found: SELEC at Line: 1, Column: 1".to_owned(), + ))), + ), + // lowercase + ( + "select * from \"public\".\"users\"", + Ok(("SELECT * FROM \"public\".\"users\"".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + // newlines + ("select *\nfrom users;", Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_)))))), + // comment-only + ("-- select * from users;", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), + // commented line(s) + ( + "-- select blah;\nselect * from users", + Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + ( + "-- select blah;\nselect * from users\n-- insert blah", + Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + // update + ( + "UPDATE users SET name = 'John' WHERE id = 1", + Ok(( + "UPDATE users SET name = 'John' WHERE id = 1".to_owned(), + Box::new(|s| matches!(s, Statement::Update { .. })), + )), + ), + // delete + ( + "DELETE FROM users WHERE id = 1", + Ok(("DELETE FROM users WHERE id = 1".to_owned(), Box::new(|s| matches!(s, Statement::Delete(_))))), + ), + // drop + ("DROP TABLE users", Ok(("DROP TABLE users".to_owned(), Box::new(|s| matches!(s, Statement::Drop { .. }))))), + // explain + ( + "EXPLAIN SELECT * FROM users", + Ok(("EXPLAIN SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Explain { .. 
})))), + ), + ]; + + for (input, expected_output) in test_cases { + let result = get_first_query(input.to_string(), &PostgreSqlDialect {}); + match (result, expected_output) { + (Ok((query, statement_type)), Ok((expected_query, match_statement))) => { + assert_eq!(query, expected_query); + assert!(match_statement(statement_type)); + }, + ( + Err(Either::Right(ParserError::ParserError(msg))), + Err(Either::Right(ParserError::ParserError(expected_msg))), + ) => { + assert_eq!(msg, expected_msg) + }, + _ => panic!("Unexpected result for input: {}", input), + } + } + } + + #[test] + fn test_should_use_tx() { + let dialect = PostgreSqlDialect {}; + let test_cases = vec![ + ("DELETE FROM users WHERE id = 1", true), + ("DROP TABLE users", true), + ("UPDATE users SET name = 'John' WHERE id = 1", true), + ("SELECT * FROM users", false), + ("INSERT INTO users (name) VALUES ('John')", false), + ("EXPLAIN ANALYZE DELETE FROM users WHERE id = 1", true), + ("EXPLAIN SELECT * FROM users", false), + ("EXPLAIN ANALYZE SELECT * FROM users WHERE id = 1", false), + ]; + + for (query, expected) in test_cases { + let ast = Parser::parse_sql(&dialect, query).unwrap(); + let statement = ast[0].clone(); + // assert_eq!(should_use_tx(statement), expected, "Failed for query: {}", query); + } + } +} diff --git a/src/database/sqlite.rs b/src/database/sqlite.rs new file mode 100644 index 0000000..bcbca80 --- /dev/null +++ b/src/database/sqlite.rs @@ -0,0 +1,7 @@ + +impl ValueParser for Sqlite { + fn parse_value(row: &Self::Row, col: &Self::Column) -> Option { + // SQLite-specific parsing + todo!() + } +} \ No newline at end of file diff --git a/src/generic_database.rs b/src/generic_database.rs new file mode 100644 index 0000000..86e8ce3 --- /dev/null +++ b/src/generic_database.rs @@ -0,0 +1,187 @@ +use std::collections::HashMap; + +use futures::stream::{BoxStream, StreamExt}; +use sqlparser::{ + ast::Statement, + dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect}, + keywords, + parser::{Parser, ParserError}, +}; +use sqlx::{ + mysql::{MySql, MySqlColumn, MySqlQueryResult, MySqlRow}, + pool::PoolOptions, + postgres::{PgColumn, PgQueryResult, PgRow, Postgres}, + sqlite::{Sqlite, SqliteColumn, SqliteQueryResult, SqliteRow}, + Column, Connection, Database, Either, Error, Executor, Pool, Row, Transaction, +}; + +use crate::database::{Header, Headers, Rows, Value}; + +pub type DbPool = Pool; +pub type DbError = Either; + +pub trait HasRowsAffected { + fn rows_affected(&self) -> u64; +} + +// Implement for PostgreSQL +impl HasRowsAffected for PgQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + +// Implement for MySQL +impl HasRowsAffected for MySqlQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + +// Implement for SQLite +impl HasRowsAffected for SqliteQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + +pub async fn init_pool(opts: ::Options) -> Result, Error> { + PoolOptions::new().max_connections(5).connect_with(opts).await +} + +pub async fn query(query: String, pool: &Pool) -> Result +where + DB: Database + ValueParser, + DB::QueryResult: HasRowsAffected, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + // get the db type from + let dialect = get_dialect(DB::NAME); + let first_query = get_first_query(query, dialect.as_ref()); + match first_query { + Ok((first_query, _)) => { + let stream = sqlx::raw_sql(&first_query).fetch_many(pool); + query_stream::(stream).await + }, + Err(e) => 
Err(e), + } +} + +pub fn get_first_query(query: String, dialect: &dyn Dialect) -> Result<(String, Statement), DbError> { + let ast = Parser::parse_sql(dialect, &query); + match ast { + Ok(ast) if ast.len() > 1 => { + Err(Either::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))) + }, + Ok(ast) if ast.is_empty() => Err(Either::Right(ParserError::ParserError("Parsed query is empty".to_owned()))), + Ok(ast) => { + let statement = ast[0].clone(); + Ok((statement.to_string(), statement)) + }, + Err(e) => Err(Either::Right(e)), + } +} + +pub fn get_dialect(db_type: &str) -> Box { + match db_type { + "postgres" => Box::new(PostgreSqlDialect {}), + "mysql" => Box::new(MySqlDialect {}), + _ => Box::new(SQLiteDialect {}), + } +} + +pub trait ValueParser: Database { + fn parse_value(row: &Self::Row, col: &Self::Column) -> Option; +} + +pub fn row_to_vec(row: &DB::Row) -> Vec { + row.columns().iter().map(|col| DB::parse_value(row, col).unwrap().string).collect() +} + +pub fn row_to_json(row: &DB::Row) -> HashMap { + let mut result = HashMap::new(); + for col in row.columns() { + let value = match DB::parse_value(row, col) { + Some(v) => v.string, + _ => "[ unsupported ]".to_string(), + }; + result.insert(col.name().to_string(), value); + } + + result +} + +pub fn get_headers(row: &DB::Row) -> Headers { + row + .columns() + .iter() + .map(|col| Header { name: col.name().to_string(), type_name: col.type_info().to_string() }) + .collect() +} + +pub async fn query_stream<'a, DB>( + mut stream: BoxStream<'_, Result, Error>>, +) -> Result +where + DB: Database + ValueParser, + DB::QueryResult: HasRowsAffected, +{ + let mut query_finished = false; + let mut query_rows = vec![]; + let mut query_rows_affected: Option = None; + let mut headers: Headers = vec![]; + while !query_finished { + let next = stream.next().await; + match next { + Some(Ok(Either::Left(result))) => { + query_rows_affected = Some(result.rows_affected()); + query_finished = true; + }, + Some(Ok(Either::Right(row))) => { + query_rows.push(row_to_vec::(&row)); + if headers.is_empty() { + headers = get_headers::(&row); + } + }, + Some(Err(e)) => return Err(Either::Left(e)), + None => return Err(Either::Left(Error::Protocol("Results stream empty".to_owned()))), + }; + } + Ok(Rows { rows_affected: query_rows_affected, headers, rows: query_rows }) +} + +pub async fn query_with_tx<'a, DB>( + mut tx: Transaction<'_, DB>, + query: String, +) -> (Result, DbError>, Transaction<'_, DB>) +where + DB: Database + ValueParser, + DB::QueryResult: HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + let dialect = get_dialect(DB::NAME); + let first_query = get_first_query(query, dialect.as_ref()); + match first_query { + Ok((first_query, statement_type)) => { + match statement_type { + Statement::Explain { .. 
} => { + let stream = sqlx::raw_sql(&first_query).fetch_many(&mut *tx); + let result = query_stream::(stream).await; + match result { + Ok(result) => (Ok(Either::Right(result)), tx), + Err(e) => (Err(e), tx), + } + }, + _ => { + let result = sqlx::query(&first_query).execute(&mut *tx).await; + match result { + Ok(result) => (Ok(Either::Left(result.rows_affected())), tx), + Err(e) => (Err(DbError::Left(e)), tx), + } + }, + } + }, + Err(e) => (Err(e), tx), + } +} diff --git a/src/main.rs b/src/main.rs index 7e6fd92..fa92094 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ pub mod components; pub mod config; pub mod database; pub mod focus; +pub mod generic_database; pub mod tui; pub mod ui; pub mod utils; @@ -21,8 +22,8 @@ use std::{ use clap::Parser; use cli::Cli; -use color_eyre::eyre::Result; -use sqlx::postgres::PgConnectOptions; +use color_eyre::eyre::{self, Result}; +use sqlx::{postgres::PgConnectOptions, Connection, Database, Executor, Pool, Postgres}; use crate::{ app::App, @@ -35,10 +36,13 @@ async fn tokio_main() -> Result<()> { initialize_panic_handler()?; let args = Cli::parse(); - let connection_opts = build_connection_opts(args.clone())?; - let mut app = App::new(connection_opts)?; - app.run().await?; - + if let Some(db) = args.database.as_deref() { + if db == "postgres" { + let connection_opts = build_connection_opts::(args.clone())?; + let mut app = App::<'_, Postgres>::new(connection_opts)?; + app.run().await?; + } + } Ok(()) } @@ -53,75 +57,82 @@ async fn main() -> Result<()> { } // sqlx defaults to reading from environment variables if no inputs are provided -fn build_connection_opts(args: Cli) -> Result { +fn build_connection_opts(args: Cli) -> Result<::Options> +where + DB: Database + generic_database::ValueParser, + DB::QueryResult: generic_database::HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ match args.connection_url { - Some(url) => Ok(PgConnectOptions::from_str(&url)?), + Some(url) => Ok(::Options::from_str(&url)?), None => { - let mut opts = PgConnectOptions::new(); - - if let Some(user) = args.user { - opts = opts.username(&user); - } else { - let mut user: String = String::new(); - print!("username: "); - io::stdout().flush().unwrap(); - io::stdin().read_line(&mut user).unwrap(); - user = user.trim().to_string(); - if !user.is_empty() { - opts = opts.username(&user); - } - } - - if let Some(password) = args.password { - opts = opts.password(&password); - } else { - let mut password = rpassword::prompt_password(format!("password for user {}: ", opts.get_username())).unwrap(); - password = password.trim().to_string(); - if !password.is_empty() { - opts = opts.password(&password); - } - } - - if let Some(host) = args.host { - opts = opts.host(&host); - } else { - let mut host: String = String::new(); - print!("host (ex. localhost): "); - io::stdout().flush().unwrap(); - io::stdin().read_line(&mut host).unwrap(); - host = host.trim().to_string(); - if !host.is_empty() { - opts = opts.host(&host); - } - } - - if let Some(port) = args.port { - opts = opts.port(port); - } else { - let mut port: String = String::new(); - print!("port (ex. 
5432): "); - io::stdout().flush().unwrap(); - io::stdin().read_line(&mut port).unwrap(); - port = port.trim().to_string(); - if !port.is_empty() { - opts = opts.port(port.parse()?); - } - } - - if let Some(database) = args.database { - opts = opts.database(&database); - } else { - let mut database: String = String::new(); - print!("database (ex. postgres): "); - io::stdout().flush().unwrap(); - io::stdin().read_line(&mut database).unwrap(); - database = database.trim().to_string(); - if !database.is_empty() { - opts = opts.database(&database); - } - } - - Ok(opts) + // let mut opts = ::Options::new(); + + // if let Some(user) = args.user { + // opts = opts.username(&user); + // } else { + // let mut user: String = String::new(); + // print!("username: "); + // io::stdout().flush().unwrap(); + // io::stdin().read_line(&mut user).unwrap(); + // user = user.trim().to_string(); + // if !user.is_empty() { + // opts = opts.username(&user); + // } + // } + + // if let Some(password) = args.password { + // opts = opts.password(&password); + // } else { + // let mut password = rpassword::prompt_password(format!("password for user {}: ", opts.get_username())).unwrap(); + // password = password.trim().to_string(); + // if !password.is_empty() { + // opts = opts.password(&password); + // } + // } + + // if let Some(host) = args.host { + // opts = opts.host(&host); + // } else { + // let mut host: String = String::new(); + // print!("host (ex. localhost): "); + // io::stdout().flush().unwrap(); + // io::stdin().read_line(&mut host).unwrap(); + // host = host.trim().to_string(); + // if !host.is_empty() { + // opts = opts.host(&host); + // } + // } + + // if let Some(port) = args.port { + // opts = opts.port(port); + // } else { + // let mut port: String = String::new(); + // print!("port (ex. 5432): "); + // io::stdout().flush().unwrap(); + // io::stdin().read_line(&mut port).unwrap(); + // port = port.trim().to_string(); + // if !port.is_empty() { + // opts = opts.port(port.parse()?); + // } + // } + + // if let Some(database) = args.database { + // opts = opts.database(&database); + // } else { + // let mut database: String = String::new(); + // print!("database (ex. 
postgres): "); + // io::stdout().flush().unwrap(); + // io::stdin().read_line(&mut database).unwrap(); + // database = database.trim().to_string(); + // if !database.is_empty() { + // opts = opts.database(&database); + // } + // } + + // Ok(opts) + Err(eyre::Report::msg("Not implemented")) }, } } From 8d312d57bef0cb34529893d76db14b0114f50040 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Mon, 23 Sep 2024 18:29:06 -0500 Subject: [PATCH 02/16] clean up --- src/app.rs | 51 ++-- src/components.rs | 5 +- src/components/data.rs | 17 +- src/components/editor.rs | 10 +- src/components/history.rs | 5 +- src/components/menu.rs | 17 +- src/components/scroll_table.rs | 5 +- src/database.rs | 455 +++++++++------------------------ src/database/postgresql.rs | 5 +- src/generic_database.rs | 187 -------------- src/main.rs | 5 +- 11 files changed, 173 insertions(+), 589 deletions(-) delete mode 100644 src/generic_database.rs diff --git a/src/app.rs b/src/app.rs index 283aaf8..fdd9285 100644 --- a/src/app.rs +++ b/src/app.rs @@ -13,7 +13,10 @@ use ratatui::{ Frame, }; use serde::{Deserialize, Serialize}; -use sqlparser::{ast::Statement, keywords::DELETE}; +use sqlparser::{ + ast::Statement, + keywords::{DELETE, NAME}, +}; use sqlx::{ postgres::{PgConnectOptions, Postgres}, Connection, Database, Either, Executor, Pool, Transaction, @@ -36,19 +39,18 @@ use crate::{ Component, }, config::Config, - database::{self, statement_type_string, DbError, DbPool, Rows}, + database::{self, get_dialect, statement_type_string, DbError, DbPool, Rows}, focus::Focus, - generic_database, tui, + tui, ui::center, }; pub enum DbTask<'a, DB> where - DB: Database + generic_database::ValueParser, - DB::QueryResult: generic_database::HasRowsAffected, + DB: Database + database::ValueParser, + DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { Query(tokio::task::JoinHandle), TxStart(tokio::task::JoinHandle<(QueryResultsWithMetadata, Transaction<'a, DB>)>), @@ -63,11 +65,10 @@ pub struct HistoryEntry { pub struct AppState<'a, DB: Database> where - DB: Database + generic_database::ValueParser, - DB::QueryResult: generic_database::HasRowsAffected, + DB: Database + database::ValueParser, + DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { pub connection_opts: ::Options, pub focus: Focus, @@ -77,13 +78,12 @@ where pub last_query_end: Option>, } -pub struct Components<'a, DB: Database + generic_database::ValueParser> +pub struct Components<'a, DB: Database + database::ValueParser> where - DB: Database + generic_database::ValueParser, - DB::QueryResult: generic_database::HasRowsAffected, + DB: Database + database::ValueParser, + DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { pub menu: Box>, pub editor: Box>, @@ -97,31 +97,29 @@ pub struct QueryResultsWithMetadata { pub statement_type: Statement, } -pub struct App<'a, DB: Database + generic_database::ValueParser> +pub struct App<'a, DB: Database + database::ValueParser> where - DB: Database + generic_database::ValueParser, - DB::QueryResult: generic_database::HasRowsAffected, + DB: Database + database::ValueParser, + DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut 
DB::Connection: Executor<'c, Database = DB>, - { pub config: Config, pub components: Components<'static, DB>, pub should_quit: bool, pub last_tick_key_events: Vec, pub last_frame_mouse_event: Option, - pub pool: Option>, + pub pool: Option>, pub state: AppState<'a, DB>, last_focused_tab: Focus, } impl<'a, DB> App<'a, DB> where - DB: Database + generic_database::ValueParser, - DB::QueryResult: generic_database::HasRowsAffected, + DB: Database + database::ValueParser, + DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { pub fn new(connection_opts: ::Options) -> Result { let focus = Focus::Menu; @@ -168,7 +166,7 @@ where pub async fn run(&mut self) -> Result<()> { let (action_tx, mut action_rx) = mpsc::unbounded_channel(); let connection_opts = self.state.connection_opts.clone(); - let pool = generic_database::init_pool::(connection_opts).await?; + let pool = database::init_pool::(connection_opts).await?; log::info!("{pool:?}"); self.pool = Some(pool); @@ -383,7 +381,7 @@ where Action::LoadMenu => { log::info!("LoadMenu"); if let Some(pool) = &self.pool { - let results = generic_database::query( + let results = database::query( "select table_schema, table_name from information_schema.tables where table_schema != 'pg_catalog' @@ -401,7 +399,8 @@ where let query_string = query_lines.clone().join(" \n"); if !query_string.is_empty() { self.add_to_history(query_lines.clone()); - let first_query = database::get_first_query(query_string.clone()); + let dialect = get_dialect(::NAME); + let first_query = database::get_first_query(query_string.clone(), dialect.as_ref()); let should_use_tx = first_query .map(|(_, statement_type)| (database::should_use_tx(statement_type.clone()), statement_type)); let action_tx = action_tx.clone(); @@ -412,7 +411,7 @@ where self.components.data.set_loading(); let tx = pool.begin().await?; self.state.query_task = Some(DbTask::TxStart(tokio::spawn(async move { - let (results, tx) = generic_database::query_with_tx::(tx, query_string.clone()).await; + let (results, tx) = database::query_with_tx::(tx, query_string.clone()).await; match results { Ok(Either::Left(rows_affected)) => { log::info!("{:?} rows affected", rows_affected); @@ -440,7 +439,7 @@ where Ok((false, statement_type)) => { self.components.data.set_loading(); self.state.query_task = Some(DbTask::Query(tokio::spawn(async move { - let results = generic_database::query(query_string.clone(), &pool).await; + let results = database::query(query_string.clone(), &pool).await; match &results { Ok(rows) => { log::info!("{:?} rows, {:?} affected", rows.rows.len(), rows.rows_affected); diff --git a/src/components.rs b/src/components.rs index 3079d60..b522e4c 100644 --- a/src/components.rs +++ b/src/components.rs @@ -18,11 +18,10 @@ pub mod scroll_table; use sqlx::{Database, Executor, Pool}; pub trait Component where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { /// Register an action handler that can send actions for processing if necessary. 
/// diff --git a/src/components/data.rs b/src/components/data.rs index e48b419..1090cee 100644 --- a/src/components/data.rs +++ b/src/components/data.rs @@ -21,7 +21,7 @@ use crate::{ Component, }, config::{Config, KeyBindings}, - database::{get_headers, parse_value, row_to_json, row_to_vec, statement_type_string, DbError, Rows}, + database::{get_headers, row_to_json, row_to_vec, statement_type_string, DbError, Rows}, focus::Focus, tui::Event, }; @@ -54,21 +54,19 @@ pub trait SettableDataTable<'a> { pub trait DataComponent<'a, DB>: Component + SettableDataTable<'a> where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { } impl<'a, T, DB> DataComponent<'a, DB> for T where T: Component + SettableDataTable<'a>, - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { } @@ -249,11 +247,10 @@ impl<'a> SettableDataTable<'a> for Data<'a> { impl<'a, DB> Component for Data<'a> where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); diff --git a/src/components/editor.rs b/src/components/editor.rs index a5e1660..82c49de 100644 --- a/src/components/editor.rs +++ b/src/components/editor.rs @@ -69,11 +69,10 @@ impl<'a> Editor<'a> { pub fn transition_vim_state(&mut self, input: Input, app_state: &AppState<'_, DB>) -> Result<()> where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { match input { Input { key: Key::Enter, alt: true, .. } | Input { key: Key::Enter, ctrl: true, .. 
} => { @@ -118,11 +117,10 @@ impl<'a> Editor<'a> { impl<'a, DB> Component for Editor<'a> where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); diff --git a/src/components/history.rs b/src/components/history.rs index 6239976..a45bf97 100644 --- a/src/components/history.rs +++ b/src/components/history.rs @@ -52,11 +52,10 @@ impl History { impl Component for History where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); diff --git a/src/components/menu.rs b/src/components/menu.rs index 013ab80..e44cf57 100644 --- a/src/components/menu.rs +++ b/src/components/menu.rs @@ -19,7 +19,7 @@ use crate::{ action::{Action, MenuPreview}, app::{App, AppState}, config::{Config, KeyBindings}, - database::{get_headers, parse_value, row_to_json, row_to_vec, DbError, Rows}, + database::{get_headers, row_to_json, row_to_vec, DbError, Rows}, focus::Focus, tui::Event, }; @@ -37,22 +37,20 @@ pub trait SettableTableList<'a> { pub trait MenuComponent<'a, DB>: Component + SettableTableList<'a> where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { } impl<'a, T, DB> MenuComponent<'a, DB> for T where T: Component + SettableTableList<'a>, - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { } @@ -206,11 +204,10 @@ impl<'a> SettableTableList<'a> for Menu { impl Component for Menu where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); diff --git a/src/components/scroll_table.rs b/src/components/scroll_table.rs index 8986281..3fa4346 100644 --- a/src/components/scroll_table.rs +++ b/src/components/scroll_table.rs @@ -193,11 +193,10 @@ impl<'a> ScrollTable<'a> { impl<'a, DB> Component for ScrollTable<'a> where - DB: Database + crate::generic_database::ValueParser, - DB::QueryResult: crate::generic_database::HasRowsAffected, + DB: Database + crate::database::ValueParser, + DB::QueryResult: 
crate::database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { self.parent_area = area; diff --git a/src/database.rs b/src/database.rs index 40ad9e1..4b8c473 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,22 +1,21 @@ -use std::{collections::HashMap, fmt::Write, pin::Pin, string::String}; +use std::collections::HashMap; use futures::stream::{BoxStream, StreamExt}; use sqlparser::{ ast::Statement, - dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect}, + dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect}, keywords, parser::{Parser, ParserError}, }; use sqlx::{ - postgres::{ - PgColumn, PgConnectOptions, PgPool, PgPoolOptions, PgQueryResult, PgRow, PgTypeInfo, PgTypeKind, PgValueRef, - Postgres, - }, - types::Uuid, - Column, Database, Either, Error, Pool, Row, Transaction, ValueRef, + mysql::{MySql, MySqlColumn, MySqlQueryResult, MySqlRow}, + pool::PoolOptions, + postgres::{PgColumn, PgQueryResult, PgRow, Postgres}, + sqlite::{Sqlite, SqliteColumn, SqliteQueryResult, SqliteRow}, + Column, Connection, Database, Either, Error, Executor, Pool, Row, Transaction, }; -pub mod postgresql; +mod postgresql; #[derive(Debug)] pub struct Header { @@ -36,29 +35,115 @@ pub struct Rows { pub rows_affected: Option, } pub type Headers = Vec
; -pub type DbPool = PgPool; -pub type DbError = sqlx::Either; +pub type DbPool = Pool; +pub type DbError = Either; -pub async fn init_pool(opts: PgConnectOptions) -> Result { - PgPoolOptions::new().max_connections(5).connect_with(opts).await +pub trait HasRowsAffected { + fn rows_affected(&self) -> u64; } -// since it's possible for raw_sql to execute multiple queries in a single string, -// we only execute the first one and then drop the rest. -pub async fn query(query: String, pool: &PgPool) -> Result { - let first_query = get_first_query(query); +// Implement for PostgreSQL +impl HasRowsAffected for PgQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + +// Implement for MySQL +impl HasRowsAffected for MySqlQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + +// Implement for SQLite +impl HasRowsAffected for SqliteQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + +pub async fn init_pool(opts: ::Options) -> Result, Error> { + PoolOptions::new().max_connections(5).connect_with(opts).await +} + +pub async fn query(query: String, pool: &Pool) -> Result +where + DB: Database + ValueParser, + DB::QueryResult: HasRowsAffected, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + // get the db type from + let dialect = get_dialect(DB::NAME); + let first_query = get_first_query(query, dialect.as_ref()); match first_query { Ok((first_query, _)) => { let stream = sqlx::raw_sql(&first_query).fetch_many(pool); - query_stream(stream).await + query_stream::(stream).await }, Err(e) => Err(e), } } -pub async fn query_stream( - mut stream: BoxStream<'_, Result, Error>>, -) -> Result { +pub fn get_first_query(query: String, dialect: &dyn Dialect) -> Result<(String, Statement), DbError> { + let ast = Parser::parse_sql(dialect, &query); + match ast { + Ok(ast) if ast.len() > 1 => { + Err(Either::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))) + }, + Ok(ast) if ast.is_empty() => Err(Either::Right(ParserError::ParserError("Parsed query is empty".to_owned()))), + Ok(ast) => { + let statement = ast[0].clone(); + Ok((statement.to_string(), statement)) + }, + Err(e) => Err(Either::Right(e)), + } +} + +pub fn get_dialect(db_type: &str) -> Box { + match db_type { + "postgres" => Box::new(PostgreSqlDialect {}), + "mysql" => Box::new(MySqlDialect {}), + _ => Box::new(SQLiteDialect {}), + } +} + +pub trait ValueParser: Database { + fn parse_value(row: &Self::Row, col: &Self::Column) -> Option; +} + +pub fn row_to_vec(row: &DB::Row) -> Vec { + row.columns().iter().map(|col| DB::parse_value(row, col).unwrap().string).collect() +} + +pub fn row_to_json(row: &DB::Row) -> HashMap { + let mut result = HashMap::new(); + for col in row.columns() { + let value = match DB::parse_value(row, col) { + Some(v) => v.string, + _ => "[ unsupported ]".to_string(), + }; + result.insert(col.name().to_string(), value); + } + + result +} + +pub fn get_headers(row: &DB::Row) -> Headers { + row + .columns() + .iter() + .map(|col| Header { name: col.name().to_string(), type_name: col.type_info().to_string() }) + .collect() +} + +pub async fn query_stream<'a, DB>( + mut stream: BoxStream<'_, Result, Error>>, +) -> Result +where + DB: Database + ValueParser, + DB::QueryResult: HasRowsAffected, +{ let mut query_finished = false; let mut query_rows = vec![]; let mut query_rows_affected: Option = None; @@ -71,9 +156,9 @@ pub async fn query_stream( query_finished = true; }, Some(Ok(Either::Right(row))) => { - 
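// Illustrative sketch, not part of the patch: the reason for the HasRowsAffected
// bound introduced above is that sqlx's Database trait does not expose
// rows_affected() on its associated QueryResult type, so any helper generic over
// DB has to require the capability explicitly.
fn affected_summary<DB>(result: &DB::QueryResult) -> String
where
    DB: sqlx::Database,
    DB::QueryResult: crate::database::HasRowsAffected,
{
    format!("{} rows affected", result.rows_affected())
}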
query_rows.push(row_to_vec(&row)); + query_rows.push(row_to_vec::(&row)); if headers.is_empty() { - headers = get_headers(&row); + headers = get_headers::(&row); } }, Some(Err(e)) => return Err(Either::Left(e)), @@ -83,17 +168,24 @@ pub async fn query_stream( Ok(Rows { rows_affected: query_rows_affected, headers, rows: query_rows }) } -pub async fn query_with_tx<'a>( - mut tx: Transaction<'_, Postgres>, +pub async fn query_with_tx<'a, DB>( + mut tx: Transaction<'_, DB>, query: String, -) -> (Result, DbError>, Transaction<'_, Postgres>) { - let first_query = get_first_query(query); +) -> (Result, DbError>, Transaction<'_, DB>) +where + DB: Database + ValueParser, + DB::QueryResult: HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + let dialect = get_dialect(DB::NAME); + let first_query = get_first_query(query, dialect.as_ref()); match first_query { Ok((first_query, statement_type)) => { match statement_type { Statement::Explain { .. } => { let stream = sqlx::raw_sql(&first_query).fetch_many(&mut *tx); - let result = query_stream(stream).await; + let result = query_stream::(stream).await; match result { Ok(result) => (Ok(Either::Right(result)), tx), Err(e) => (Err(e), tx), @@ -112,22 +204,6 @@ pub async fn query_with_tx<'a>( } } -pub fn get_first_query(query: String) -> Result<(String, Statement), DbError> { - let dialect = PostgreSqlDialect {}; - let ast = Parser::parse_sql(&dialect, &query); - match ast { - Ok(ast) if ast.len() > 1 => { - Err(Either::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))) - }, - Ok(ast) if ast.is_empty() => Err(Either::Right(ParserError::ParserError("Parsed query is empty".to_owned()))), - Ok(ast) => { - let statement = ast[0].clone(); - Ok((statement.to_string(), statement)) - }, - Err(e) => Err(Either::Right(e)), - } -} - pub fn statement_type_string(statement: &Statement) -> String { format!("{:?}", statement).split('(').collect::>()[0].split('{').collect::>()[0] .split('[') @@ -150,186 +226,6 @@ pub fn should_use_tx(statement: Statement) -> bool { } } -pub fn get_headers(row: &PgRow) -> Headers { - row - .columns() - .iter() - .map(|col| Header { name: col.name().to_string(), type_name: col.type_info().to_string() }) - .collect() -} - -// parsed based on https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html -pub fn parse_value(row: &PgRow, col: &PgColumn) -> Option { - let col_type = col.type_info().to_string(); - let raw_value = row.try_get_raw(col.ordinal()).unwrap(); - if raw_value.is_null() { - return Some(Value { string: "NULL".to_string(), is_null: true }); - } - match col_type.to_uppercase().as_str() { - "TIMESTAMPTZ" => { - let received: chrono::DateTime = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "TIMESTAMP" => { - let received: chrono::NaiveDateTime = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "DATE" => { - let received: chrono::NaiveDate = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "TIME" => { - let received: chrono::NaiveTime = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "UUID" => { - let received: Uuid = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "INET" | "CIDR" => { - let received: std::net::IpAddr = 
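// Hedged usage sketch for query_with_tx as defined above (assumes an already
// connected PgPool; only calls shown elsewhere in this patch are used). A
// destructive statement reports its rows affected, but nothing is persisted
// until the caller commits the returned transaction.
async fn dry_run_delete(pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
    let tx = pool.begin().await?;
    let (result, tx) =
        crate::database::query_with_tx::<sqlx::Postgres>(tx, "DELETE FROM users WHERE id = 1".to_owned()).await;
    if let Ok(sqlx::Either::Left(rows_affected)) = result {
        log::info!("{:?} rows affected, not committed yet", rows_affected);
    }
    tx.rollback().await // swap for tx.commit() to persist the change
}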
row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "JSON" | "JSONB" => { - let received: serde_json::Value = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "BOOL" => { - let received: bool = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "SMALLINT" | "SMALLSERIAL" | "INT2" => { - let received: i16 = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "INT" | "SERIAL" | "INT4" => { - let received: i32 = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "BIGINT" | "BIGSERIAL" | "INT8" => { - let received: i64 = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "REAL" | "FLOAT4" => { - let received: f32 = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "DOUBLE PRECISION" | "FLOAT8" => { - let received: f64 = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received.to_string(), is_null: false }) - }, - "TEXT" | "VARCHAR" | "NAME" | "CITEXT" | "BPCHAR" | "CHAR" => { - let received: String = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: received, is_null: false }) - }, - "BYTEA" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { - string: received.iter().fold(String::new(), |mut output, b| { - let _ = write!(output, "{b:02X}"); - output - }), - is_null: false, - }) - }, - "VOID" => Some(Value { string: "".to_string(), is_null: false }), - _ if col_type.to_uppercase().ends_with("[]") => { - let array_type = col_type.to_uppercase().replace("[]", ""); - match array_type.as_str() { - "TIMESTAMPTZ" => { - let received: Vec> = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "TIMESTAMP" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "DATE" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "TIME" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "UUID" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "INET" | "CIDR" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "JSON" | "JSONB" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "BOOL" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "SMALLINT" | "SMALLSERIAL" | "INT2" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "INT" | "SERIAL" | "INT4" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "BIGINT" | "BIGSERIAL" | "INT8" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "REAL" | 
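// Side note on the BYTEA branch above: the fold with write! hex-encodes the raw
// bytes. A minimal standalone equivalent of that formatting (no sqlx types):
fn hex_string(bytes: &[u8]) -> String {
    use std::fmt::Write;
    bytes.iter().fold(String::new(), |mut out, b| {
        let _ = write!(out, "{b:02X}");
        out
    })
}
// hex_string(&[0x01, 0xFF]) == "01FF"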
"FLOAT4" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "DOUBLE PRECISION" | "FLOAT8" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "TEXT" | "VARCHAR" | "NAME" | "CITEXT" | "BPCHAR" | "CHAR" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - "BYTEA" => { - let received: Vec = row.try_get(col.ordinal()).unwrap(); - Some(Value { - string: received.iter().fold(String::new(), |mut output, b| { - let _ = write!(output, "{b:02X}"); - output - }), - is_null: false, - }) - }, - _ => { - // try to cast custom or other types to strings - let received: Vec = row.try_get_unchecked(col.ordinal()).unwrap(); - Some(Value { string: vec_to_string(received), is_null: false }) - }, - } - }, - _ => { - // try to cast custom or other types to strings - let received: String = row.try_get_unchecked(col.ordinal()).unwrap(); - Some(Value { string: received, is_null: false }) - }, - } -} - -pub fn row_to_json(row: &PgRow) -> HashMap { - let mut result = HashMap::new(); - for col in row.columns() { - let value = match parse_value(row, col) { - Some(v) => v.string, - _ => "[ unsupported ]".to_string(), - }; - result.insert(col.name().to_string(), value); - } - - result -} - pub fn vec_to_string(vec: Vec) -> String { let mut content = String::new(); for (i, elem) in vec.iter().enumerate() { @@ -341,117 +237,6 @@ pub fn vec_to_string(vec: Vec) -> String { "{ ".to_owned() + &*content + &*" }".to_owned() } -pub fn row_to_vec(row: &PgRow) -> Vec { - row.columns().iter().map(|col| parse_value(row, col).unwrap().string).collect() -} - pub fn get_keywords() -> Vec { keywords::ALL_KEYWORDS.iter().map(|k| k.to_string()).collect() } - -#[cfg(test)] -mod tests { - use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; - - use super::*; - - #[test] - fn test_get_first_query() { - type TestCase = (&'static str, Result<(String, Box bool>), DbError>); - - let test_cases: Vec = vec![ - // single query - ("SELECT * FROM users;", Ok(("SELECT * FROM users".to_string(), Box::new(|s| matches!(s, Statement::Query(_)))))), - // multiple queries - ( - "SELECT * FROM users; DELETE FROM posts;", - Err(DbError::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))), - ), - // empty query - ("", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), - // syntax error - ( - "SELEC * FORM users;", - Err(DbError::Right(ParserError::ParserError( - "Expected: an SQL statement, found: SELEC at Line: 1, Column: 1".to_owned(), - ))), - ), - // lowercase - ( - "select * from \"public\".\"users\"", - Ok(("SELECT * FROM \"public\".\"users\"".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), - ), - // newlines - ("select *\nfrom users;", Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_)))))), - // comment-only - ("-- select * from users;", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), - // commented line(s) - ( - "-- select blah;\nselect * from users", - Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), - ), - ( - "-- select blah;\nselect * from users\n-- insert blah", - Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), - ), - // update - ( - "UPDATE users SET 
name = 'John' WHERE id = 1", - Ok(( - "UPDATE users SET name = 'John' WHERE id = 1".to_owned(), - Box::new(|s| matches!(s, Statement::Update { .. })), - )), - ), - // delete - ( - "DELETE FROM users WHERE id = 1", - Ok(("DELETE FROM users WHERE id = 1".to_owned(), Box::new(|s| matches!(s, Statement::Delete(_))))), - ), - // drop - ("DROP TABLE users", Ok(("DROP TABLE users".to_owned(), Box::new(|s| matches!(s, Statement::Drop { .. }))))), - // explain - ( - "EXPLAIN SELECT * FROM users", - Ok(("EXPLAIN SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Explain { .. })))), - ), - ]; - - for (input, expected_output) in test_cases { - let result = get_first_query(input.to_string()); - match (result, expected_output) { - (Ok((query, statement_type)), Ok((expected_query, match_statement))) => { - assert_eq!(query, expected_query); - assert!(match_statement(statement_type)); - }, - ( - Err(Either::Right(ParserError::ParserError(msg))), - Err(Either::Right(ParserError::ParserError(expected_msg))), - ) => { - assert_eq!(msg, expected_msg) - }, - _ => panic!("Unexpected result for input: {}", input), - } - } - } - - #[test] - fn test_should_use_tx() { - let dialect = PostgreSqlDialect {}; - let test_cases = vec![ - ("DELETE FROM users WHERE id = 1", true), - ("DROP TABLE users", true), - ("UPDATE users SET name = 'John' WHERE id = 1", true), - ("SELECT * FROM users", false), - ("INSERT INTO users (name) VALUES ('John')", false), - ("EXPLAIN ANALYZE DELETE FROM users WHERE id = 1", true), - ("EXPLAIN SELECT * FROM users", false), - ("EXPLAIN ANALYZE SELECT * FROM users WHERE id = 1", false), - ]; - - for (query, expected) in test_cases { - let ast = Parser::parse_sql(&dialect, query).unwrap(); - let statement = ast[0].clone(); - assert_eq!(should_use_tx(statement), expected, "Failed for query: {}", query); - } - } -} diff --git a/src/database/postgresql.rs b/src/database/postgresql.rs index 1140d1d..de21848 100644 --- a/src/database/postgresql.rs +++ b/src/database/postgresql.rs @@ -17,9 +17,8 @@ use sqlx::{ }; use super::{vec_to_string, Value}; -use crate::generic_database::ValueParser; -impl ValueParser for Postgres { +impl super::ValueParser for Postgres { fn parse_value(row: &::Row, col: &::Column) -> Option { let col_type = col.type_info().to_string(); let raw_value = row.try_get_raw(col.ordinal()).unwrap(); @@ -182,7 +181,7 @@ mod tests { use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; use super::*; - use crate::generic_database::{get_first_query, DbError}; + use crate::database::{get_first_query, DbError}; #[test] fn test_get_first_query() { diff --git a/src/generic_database.rs b/src/generic_database.rs deleted file mode 100644 index 86e8ce3..0000000 --- a/src/generic_database.rs +++ /dev/null @@ -1,187 +0,0 @@ -use std::collections::HashMap; - -use futures::stream::{BoxStream, StreamExt}; -use sqlparser::{ - ast::Statement, - dialect::{Dialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect}, - keywords, - parser::{Parser, ParserError}, -}; -use sqlx::{ - mysql::{MySql, MySqlColumn, MySqlQueryResult, MySqlRow}, - pool::PoolOptions, - postgres::{PgColumn, PgQueryResult, PgRow, Postgres}, - sqlite::{Sqlite, SqliteColumn, SqliteQueryResult, SqliteRow}, - Column, Connection, Database, Either, Error, Executor, Pool, Row, Transaction, -}; - -use crate::database::{Header, Headers, Rows, Value}; - -pub type DbPool = Pool; -pub type DbError = Either; - -pub trait HasRowsAffected { - fn rows_affected(&self) -> u64; -} - -// Implement for 
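// A compact illustration of get_first_query's contract, matching the test table
// above: a single statement is normalized by the parser, while multiple
// statements in one string are rejected outright.
fn first_query_contract_examples() {
    use sqlparser::{ast::Statement, dialect::PostgreSqlDialect};

    use crate::database::get_first_query;

    let dialect = PostgreSqlDialect {};
    assert!(get_first_query("SELECT 1; SELECT 2;".to_owned(), &dialect).is_err());

    let (sql, stmt) = get_first_query("select *\nfrom users".to_owned(), &dialect).unwrap();
    assert_eq!(sql, "SELECT * FROM users");
    assert!(matches!(stmt, Statement::Query(_)));
}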
PostgreSQL -impl HasRowsAffected for PgQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -// Implement for MySQL -impl HasRowsAffected for MySqlQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -// Implement for SQLite -impl HasRowsAffected for SqliteQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -pub async fn init_pool(opts: ::Options) -> Result, Error> { - PoolOptions::new().max_connections(5).connect_with(opts).await -} - -pub async fn query(query: String, pool: &Pool) -> Result -where - DB: Database + ValueParser, - DB::QueryResult: HasRowsAffected, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ - // get the db type from - let dialect = get_dialect(DB::NAME); - let first_query = get_first_query(query, dialect.as_ref()); - match first_query { - Ok((first_query, _)) => { - let stream = sqlx::raw_sql(&first_query).fetch_many(pool); - query_stream::(stream).await - }, - Err(e) => Err(e), - } -} - -pub fn get_first_query(query: String, dialect: &dyn Dialect) -> Result<(String, Statement), DbError> { - let ast = Parser::parse_sql(dialect, &query); - match ast { - Ok(ast) if ast.len() > 1 => { - Err(Either::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))) - }, - Ok(ast) if ast.is_empty() => Err(Either::Right(ParserError::ParserError("Parsed query is empty".to_owned()))), - Ok(ast) => { - let statement = ast[0].clone(); - Ok((statement.to_string(), statement)) - }, - Err(e) => Err(Either::Right(e)), - } -} - -pub fn get_dialect(db_type: &str) -> Box { - match db_type { - "postgres" => Box::new(PostgreSqlDialect {}), - "mysql" => Box::new(MySqlDialect {}), - _ => Box::new(SQLiteDialect {}), - } -} - -pub trait ValueParser: Database { - fn parse_value(row: &Self::Row, col: &Self::Column) -> Option; -} - -pub fn row_to_vec(row: &DB::Row) -> Vec { - row.columns().iter().map(|col| DB::parse_value(row, col).unwrap().string).collect() -} - -pub fn row_to_json(row: &DB::Row) -> HashMap { - let mut result = HashMap::new(); - for col in row.columns() { - let value = match DB::parse_value(row, col) { - Some(v) => v.string, - _ => "[ unsupported ]".to_string(), - }; - result.insert(col.name().to_string(), value); - } - - result -} - -pub fn get_headers(row: &DB::Row) -> Headers { - row - .columns() - .iter() - .map(|col| Header { name: col.name().to_string(), type_name: col.type_info().to_string() }) - .collect() -} - -pub async fn query_stream<'a, DB>( - mut stream: BoxStream<'_, Result, Error>>, -) -> Result -where - DB: Database + ValueParser, - DB::QueryResult: HasRowsAffected, -{ - let mut query_finished = false; - let mut query_rows = vec![]; - let mut query_rows_affected: Option = None; - let mut headers: Headers = vec![]; - while !query_finished { - let next = stream.next().await; - match next { - Some(Ok(Either::Left(result))) => { - query_rows_affected = Some(result.rows_affected()); - query_finished = true; - }, - Some(Ok(Either::Right(row))) => { - query_rows.push(row_to_vec::(&row)); - if headers.is_empty() { - headers = get_headers::(&row); - } - }, - Some(Err(e)) => return Err(Either::Left(e)), - None => return Err(Either::Left(Error::Protocol("Results stream empty".to_owned()))), - }; - } - Ok(Rows { rows_affected: query_rows_affected, headers, rows: query_rows }) -} - -pub async fn query_with_tx<'a, DB>( - mut tx: Transaction<'_, DB>, - query: String, -) -> (Result, DbError>, Transaction<'_, DB>) -where - DB: Database + ValueParser, - 
DB::QueryResult: HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ - let dialect = get_dialect(DB::NAME); - let first_query = get_first_query(query, dialect.as_ref()); - match first_query { - Ok((first_query, statement_type)) => { - match statement_type { - Statement::Explain { .. } => { - let stream = sqlx::raw_sql(&first_query).fetch_many(&mut *tx); - let result = query_stream::(stream).await; - match result { - Ok(result) => (Ok(Either::Right(result)), tx), - Err(e) => (Err(e), tx), - } - }, - _ => { - let result = sqlx::query(&first_query).execute(&mut *tx).await; - match result { - Ok(result) => (Ok(Either::Left(result.rows_affected())), tx), - Err(e) => (Err(DbError::Left(e)), tx), - } - }, - } - }, - Err(e) => (Err(e), tx), - } -} diff --git a/src/main.rs b/src/main.rs index fa92094..28778ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,6 @@ pub mod components; pub mod config; pub mod database; pub mod focus; -pub mod generic_database; pub mod tui; pub mod ui; pub mod utils; @@ -59,8 +58,8 @@ async fn main() -> Result<()> { // sqlx defaults to reading from environment variables if no inputs are provided fn build_connection_opts(args: Cli) -> Result<::Options> where - DB: Database + generic_database::ValueParser, - DB::QueryResult: generic_database::HasRowsAffected, + DB: Database + database::ValueParser, + DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, { From 293217f22433e058c70ffe00b4f38a8ef6b53d54 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Mon, 23 Sep 2024 19:37:08 -0500 Subject: [PATCH 03/16] clean up --- src/app.rs | 32 ++++---------------------------- src/cli.rs | 2 +- src/components.rs | 9 +-------- src/components/data.rs | 17 ++--------------- src/components/history.rs | 8 +------- src/components/menu.rs | 10 ++-------- src/database.rs | 21 --------------------- src/database/mysql.rs | 6 ++++++ src/database/postgresql.rs | 6 ++++++ src/database/sqlite.rs | 7 ++++++- src/main.rs | 13 ++++++------- 11 files changed, 35 insertions(+), 96 deletions(-) diff --git a/src/app.rs b/src/app.rs index fdd9285..5d4621a 100644 --- a/src/app.rs +++ b/src/app.rs @@ -45,13 +45,7 @@ use crate::{ ui::center, }; -pub enum DbTask<'a, DB> -where - DB: Database + database::ValueParser, - DB::QueryResult: database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +pub enum DbTask<'a, DB: sqlx::Database> { Query(tokio::task::JoinHandle), TxStart(tokio::task::JoinHandle<(QueryResultsWithMetadata, Transaction<'a, DB>)>), TxPending(Transaction<'a, DB>, QueryResultsWithMetadata), @@ -63,13 +57,7 @@ pub struct HistoryEntry { pub timestamp: chrono::DateTime, } -pub struct AppState<'a, DB: Database> -where - DB: Database + database::ValueParser, - DB::QueryResult: database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +pub struct AppState<'a, DB: Database> { pub connection_opts: ::Options, pub focus: Focus, pub query_task: Option>, @@ -78,13 +66,7 @@ where pub last_query_end: Option>, } -pub struct Components<'a, DB: Database + database::ValueParser> -where - DB: Database + database::ValueParser, - DB::QueryResult: database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, 
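// The cleanup above follows a common Rust pattern: keep the data types generic
// over plain `DB: sqlx::Database` and push the heavy bounds onto the impls and
// functions that actually run queries. A minimal sketch of the same shape,
// using a hypothetical Holder type with the bounds copied from this patch:
struct Holder<DB: sqlx::Database> {
    pool: sqlx::Pool<DB>,
}

impl<DB: sqlx::Database> Holder<DB> {
    async fn run(&self, sql: String) -> Result<crate::database::Rows, crate::database::DbError>
    where
        DB: crate::database::ValueParser,
        DB::QueryResult: crate::database::HasRowsAffected,
        for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
    {
        crate::database::query(sql, &self.pool).await
    }
}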
- for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +pub struct Components<'a, DB: sqlx::Database> { pub menu: Box>, pub editor: Box>, pub history: Box>, @@ -97,13 +79,7 @@ pub struct QueryResultsWithMetadata { pub statement_type: Statement, } -pub struct App<'a, DB: Database + database::ValueParser> -where - DB: Database + database::ValueParser, - DB::QueryResult: database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +pub struct App<'a, DB: sqlx::Database> { pub config: Config, pub components: Components<'static, DB>, pub should_quit: bool, diff --git a/src/cli.rs b/src/cli.rs index 3eb59e1..4ec7f6d 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -28,5 +28,5 @@ pub struct Cli { pub port: Option, #[arg(long = "database", value_name = "DATABASE", help = "Name of database for connection (ex. postgres)")] - pub database: Option, + pub database: String, } diff --git a/src/components.rs b/src/components.rs index b522e4c..f90b913 100644 --- a/src/components.rs +++ b/src/components.rs @@ -15,14 +15,7 @@ pub mod editor; pub mod history; pub mod menu; pub mod scroll_table; -use sqlx::{Database, Executor, Pool}; -pub trait Component -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +pub trait Component { /// Register an action handler that can send actions for processing if necessary. /// /// # Arguments diff --git a/src/components/data.rs b/src/components/data.rs index 1090cee..d9d409b 100644 --- a/src/components/data.rs +++ b/src/components/data.rs @@ -52,21 +52,8 @@ pub trait SettableDataTable<'a> { fn set_cancelled(&mut self); } -pub trait DataComponent<'a, DB>: Component + SettableDataTable<'a> -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ -} -impl<'a, T, DB> DataComponent<'a, DB> for T -where - T: Component + SettableDataTable<'a>, - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +pub trait DataComponent<'a, DB: sqlx::Database>: Component + SettableDataTable<'a> {} +impl<'a, T, DB: sqlx::Database> DataComponent<'a, DB> for T where T: Component + SettableDataTable<'a> { } diff --git a/src/components/history.rs b/src/components/history.rs index a45bf97..9a600e7 100644 --- a/src/components/history.rs +++ b/src/components/history.rs @@ -50,13 +50,7 @@ impl History { } } -impl Component for History -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +impl Component for History { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) diff --git a/src/components/menu.rs b/src/components/menu.rs index e44cf57..e223287 100644 --- a/src/components/menu.rs +++ b/src/components/menu.rs @@ -37,20 +37,14 @@ pub trait SettableTableList<'a> { pub trait MenuComponent<'a, DB>: Component + SettableTableList<'a> where - DB: Database + 
crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + DB: sqlx::Database, { } impl<'a, T, DB> MenuComponent<'a, DB> for T where T: Component + SettableTableList<'a>, - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + DB: sqlx::Database, { } diff --git a/src/database.rs b/src/database.rs index 4b8c473..e23f181 100644 --- a/src/database.rs +++ b/src/database.rs @@ -42,27 +42,6 @@ pub trait HasRowsAffected { fn rows_affected(&self) -> u64; } -// Implement for PostgreSQL -impl HasRowsAffected for PgQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -// Implement for MySQL -impl HasRowsAffected for MySqlQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -// Implement for SQLite -impl HasRowsAffected for SqliteQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - pub async fn init_pool(opts: ::Options) -> Result, Error> { PoolOptions::new().max_connections(5).connect_with(opts).await } diff --git a/src/database/mysql.rs b/src/database/mysql.rs index fe83b4d..058f4d8 100644 --- a/src/database/mysql.rs +++ b/src/database/mysql.rs @@ -1,6 +1,12 @@ use super::Value; use crate::generic_database::ValueParser; +impl super::HasRowsAffected for MySqlQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + impl ValueParser for MySql { fn parse_value(row: &Self::Row, col: &Self::Column) -> Option { // MySQL-specific parsing diff --git a/src/database/postgresql.rs b/src/database/postgresql.rs index de21848..6e0afc0 100644 --- a/src/database/postgresql.rs +++ b/src/database/postgresql.rs @@ -18,6 +18,12 @@ use sqlx::{ use super::{vec_to_string, Value}; +impl super::HasRowsAffected for PgQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} + impl super::ValueParser for Postgres { fn parse_value(row: &::Row, col: &::Column) -> Option { let col_type = col.type_info().to_string(); diff --git a/src/database/sqlite.rs b/src/database/sqlite.rs index bcbca80..b23b6b6 100644 --- a/src/database/sqlite.rs +++ b/src/database/sqlite.rs @@ -1,7 +1,12 @@ +impl super::HasRowsAffected for SqliteQueryResult { + fn rows_affected(&self) -> u64 { + self.rows_affected() + } +} impl ValueParser for Sqlite { fn parse_value(row: &Self::Row, col: &Self::Column) -> Option { // SQLite-specific parsing todo!() } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 28778ce..cfa1d4a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,12 +35,15 @@ async fn tokio_main() -> Result<()> { initialize_panic_handler()?; let args = Cli::parse(); - if let Some(db) = args.database.as_deref() { - if db == "postgres" { + match args.database.as_str() { + "postgres" => { let connection_opts = build_connection_opts::(args.clone())?; let mut app = App::<'_, Postgres>::new(connection_opts)?; app.run().await?; - } + }, + "mysql" => todo!(), + "sqlite" => todo!(), + _ => return Err(eyre::Report::msg("Please specify a database type")), } Ok(()) } @@ -58,10 +61,6 @@ async fn main() -> Result<()> { // sqlx defaults to reading from environment variables if no inputs are provided fn build_connection_opts(args: Cli) -> Result<::Options> where - DB: Database + database::ValueParser, - 
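// Hedged sketch of how the "sqlite" arm could be filled in later, mirroring the
// "postgres" arm above. This is hypothetical wiring: it assumes Sqlite ends up
// with the same trait impls, which at this point the patch only stubs out.
async fn run_sqlite(args: crate::cli::Cli) -> color_eyre::eyre::Result<()> {
    let connection_opts = build_connection_opts::<sqlx::Sqlite>(args)?;
    let mut app = App::<'_, sqlx::Sqlite>::new(connection_opts)?;
    app.run().await
}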
DB::QueryResult: database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, { match args.connection_url { Some(url) => Ok(::Options::from_str(&url)?), From 2030d276850782614c4a083ab3d2828296d67ab9 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Mon, 23 Sep 2024 19:38:27 -0500 Subject: [PATCH 04/16] more clean up --- src/components/scroll_table.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/components/scroll_table.rs b/src/components/scroll_table.rs index 3fa4346..4333edb 100644 --- a/src/components/scroll_table.rs +++ b/src/components/scroll_table.rs @@ -191,13 +191,7 @@ impl<'a> ScrollTable<'a> { } } -impl<'a, DB> Component for ScrollTable<'a> -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +impl<'a, DB: Database> Component for ScrollTable<'a> { fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> { self.parent_area = area; let render_area = self.block.inner_if_some(area); From 8f287fc5f57d1a96d245bcb77bb1c21a32fbc2fc Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Wed, 25 Sep 2024 20:21:52 -0500 Subject: [PATCH 05/16] feat: impl parser for sqlite and musql --- src/app.rs | 29 ++--- src/cli.rs | 5 +- src/components/data.rs | 10 +- src/components/editor.rs | 49 ++----- src/components/menu.rs | 8 +- src/database.rs | 161 +++++++++++++---------- src/database/mysql.rs | 260 ++++++++++++++++++++++++++++++++++++- src/database/postgresql.rs | 146 +++++++++++++++++++-- src/database/sqlite.rs | 176 ++++++++++++++++++++++++- src/main.rs | 108 +++------------ 10 files changed, 702 insertions(+), 250 deletions(-) diff --git a/src/app.rs b/src/app.rs index 5d4621a..6fc2f95 100644 --- a/src/app.rs +++ b/src/app.rs @@ -15,6 +15,7 @@ use ratatui::{ use serde::{Deserialize, Serialize}; use sqlparser::{ ast::Statement, + dialect::Dialect, keywords::{DELETE, NAME}, }; use sqlx::{ @@ -39,7 +40,7 @@ use crate::{ Component, }, config::Config, - database::{self, get_dialect, statement_type_string, DbError, DbPool, Rows}, + database::{self, get_dialect, statement_type_string, DatabaseQueries, DbError, DbPool, Rows}, focus::Focus, tui, ui::center, @@ -59,6 +60,7 @@ pub struct HistoryEntry { pub struct AppState<'a, DB: Database> { pub connection_opts: ::Options, + pub dialect: Arc, pub focus: Focus, pub query_task: Option>, pub history: Vec, @@ -92,7 +94,7 @@ pub struct App<'a, DB: sqlx::Database> { impl<'a, DB> App<'a, DB> where - DB: Database + database::ValueParser, + DB: Database + database::ValueParser + database::DatabaseQueries, DB::QueryResult: database::HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, @@ -118,6 +120,7 @@ where pool: None, state: AppState { connection_opts, + dialect: get_dialect(DB::NAME), focus, query_task: None, history: vec![], @@ -357,17 +360,7 @@ where Action::LoadMenu => { log::info!("LoadMenu"); if let Some(pool) = &self.pool { - let results = database::query( - "select table_schema, table_name - from information_schema.tables - where table_schema != 'pg_catalog' - and table_schema != 'information_schema' - group by table_schema, table_name - order by table_schema, table_name asc" - .to_owned(), - pool, - ) - .await; + let 
results = database::query(DB::preview_tables_query(), self.state.dialect.as_ref(), pool).await; self.components.menu.set_table_list(Some(results)); } }, @@ -375,19 +368,20 @@ where let query_string = query_lines.clone().join(" \n"); if !query_string.is_empty() { self.add_to_history(query_lines.clone()); - let dialect = get_dialect(::NAME); - let first_query = database::get_first_query(query_string.clone(), dialect.as_ref()); + let first_query = database::get_first_query(query_string.clone(), self.state.dialect.as_ref()); let should_use_tx = first_query .map(|(_, statement_type)| (database::should_use_tx(statement_type.clone()), statement_type)); let action_tx = action_tx.clone(); if let Some(pool) = &self.pool { let pool = pool.clone(); + let dialect = self.state.dialect.clone(); match should_use_tx { Ok((true, statement_type)) => { self.components.data.set_loading(); let tx = pool.begin().await?; self.state.query_task = Some(DbTask::TxStart(tokio::spawn(async move { - let (results, tx) = database::query_with_tx::(tx, query_string.clone()).await; + let (results, tx) = + database::query_with_tx::(tx, dialect.as_ref(), query_string.clone()).await; match results { Ok(Either::Left(rows_affected)) => { log::info!("{:?} rows affected", rows_affected); @@ -414,8 +408,9 @@ where }, Ok((false, statement_type)) => { self.components.data.set_loading(); + let dialect = self.state.dialect.clone(); self.state.query_task = Some(DbTask::Query(tokio::spawn(async move { - let results = database::query(query_string.clone(), &pool).await; + let results = database::query(query_string.clone(), dialect.as_ref(), &pool).await; match &results { Ok(rows) => { log::info!("{:?} rows, {:?} affected", rows.rows.len(), rows.rows_affected); diff --git a/src/cli.rs b/src/cli.rs index 4ec7f6d..a1b7023 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -28,5 +28,8 @@ pub struct Cli { pub port: Option, #[arg(long = "database", value_name = "DATABASE", help = "Name of database for connection (ex. postgres)")] - pub database: String, + pub database: Option, + + #[arg(long = "driver", value_name = "DRIVER", help = "Driver for database connection (ex. 
postgres)")] + pub driver: String, } diff --git a/src/components/data.rs b/src/components/data.rs index d9d409b..43a1250 100644 --- a/src/components/data.rs +++ b/src/components/data.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use color_eyre::eyre::Result; use crossterm::{ @@ -232,13 +232,7 @@ impl<'a> SettableDataTable<'a> for Data<'a> { } } -impl<'a, DB> Component for Data<'a> -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +impl<'a, DB: Database> Component for Data<'a> { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) diff --git a/src/components/editor.rs b/src/components/editor.rs index 82c49de..6e32ac0 100644 --- a/src/components/editor.rs +++ b/src/components/editor.rs @@ -19,7 +19,7 @@ use crate::{ action::{Action, MenuPreview}, app::{App, AppState, DbTask}, config::{Config, KeyBindings}, - database::get_keywords, + database::{self, get_keywords, DatabaseQueries, HasRowsAffected, ValueParser}, focus::Focus, tui::Event, vim::{Mode, Transition, Vim}, @@ -67,13 +67,11 @@ impl<'a> Editor<'a> { } } - pub fn transition_vim_state(&mut self, input: Input, app_state: &AppState<'_, DB>) -> Result<()> - where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, - { + pub fn transition_vim_state( + &mut self, + input: Input, + app_state: &AppState<'_, DB>, + ) -> Result<()> { match input { Input { key: Key::Enter, alt: true, .. } | Input { key: Key::Enter, ctrl: true, .. 
} => { if app_state.query_task.is_none() { @@ -115,13 +113,7 @@ impl<'a> Editor<'a> { } } -impl<'a, DB> Component for Editor<'a> -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +impl<'a, DB: Database + DatabaseQueries> Component for Editor<'a> { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) @@ -181,28 +173,11 @@ where return Ok(None); } let query = match preview_type { - MenuPreview::Rows => format!("select * from \"{}\".\"{}\" limit 100", schema, table), - MenuPreview::Columns => { - format!( - "select column_name, * from information_schema.columns where table_schema = '{}' and table_name = '{}'", - schema, table - ) - }, - MenuPreview::Constraints => { - format!( - "select constraint_name, * from information_schema.table_constraints where table_schema = '{}' and table_name = '{}'", - schema, table - ) - }, - MenuPreview::Indexes => { - format!( - "select indexname, indexdef, * from pg_indexes where schemaname = '{}' and tablename = '{}'", - schema, table - ) - }, - MenuPreview::Policies => { - format!("select * from pg_policies where schemaname = '{}' and tablename = '{}'", schema, table) - }, + MenuPreview::Rows => DB::preview_rows_query(&schema, &table), + MenuPreview::Columns => DB::preview_columns_query(&schema, &table), + MenuPreview::Constraints => DB::preview_constraints_query(&schema, &table), + MenuPreview::Indexes => DB::preview_indexes_query(&schema, &table), + MenuPreview::Policies => DB::preview_policies_query(&schema, &table), }; self.textarea = TextArea::from(vec![query.clone()]); self.textarea.set_search_pattern(keyword_regex()).unwrap(); diff --git a/src/components/menu.rs b/src/components/menu.rs index e223287..ec79c7e 100644 --- a/src/components/menu.rs +++ b/src/components/menu.rs @@ -196,13 +196,7 @@ impl<'a> SettableTableList<'a> for Menu { } } -impl Component for Menu -where - DB: Database + crate::database::ValueParser, - DB::QueryResult: crate::database::HasRowsAffected, - for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, - for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, -{ +impl Component for Menu { fn register_action_handler(&mut self, tx: UnboundedSender) -> Result<()> { self.command_tx = Some(tx); Ok(()) diff --git a/src/database.rs b/src/database.rs index e23f181..edef59e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use futures::stream::{BoxStream, StreamExt}; use sqlparser::{ @@ -15,7 +15,11 @@ use sqlx::{ Column, Connection, Database, Either, Error, Executor, Pool, Row, Transaction, }; +use crate::cli::Cli; + +mod mysql; mod postgresql; +mod sqlite; #[derive(Debug)] pub struct Header { @@ -42,19 +46,36 @@ pub trait HasRowsAffected { fn rows_affected(&self) -> u64; } +pub trait DatabaseQueries { + fn preview_tables_query() -> String; + fn preview_rows_query(schema: &str, table: &str) -> String; + fn preview_columns_query(schema: &str, table: &str) -> String; + fn preview_constraints_query(schema: &str, table: &str) -> String; + fn preview_indexes_query(schema: &str, table: &str) -> String; + fn preview_policies_query(schema: &str, table: &str) -> String; +} + +pub trait ValueParser: Database { + fn parse_value(row: &Self::Row, col: &Self::Column) -> Option; +} + +pub trait BuildConnectionOptions: 
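// The MenuPreview -> SQL mapping above is now driven entirely by the new
// DatabaseQueries trait, so a generic helper only needs that bound. Sketch:
fn preview_sql<DB: sqlx::Database + crate::database::DatabaseQueries>(schema: &str, table: &str) -> String {
    DB::preview_rows_query(schema, table)
}
// preview_sql::<sqlx::Postgres>("public", "users")
//   yields: select * from "public"."users" limit 100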
Database { + fn build_connection_opts(args: Cli) -> color_eyre::eyre::Result<::Options>; +} + pub async fn init_pool(opts: ::Options) -> Result, Error> { - PoolOptions::new().max_connections(5).connect_with(opts).await + PoolOptions::new().max_connections(3).connect_with(opts).await } -pub async fn query(query: String, pool: &Pool) -> Result +// since it's possible for raw_sql to execute multiple queries in a single string, +// we only execute the first one and then drop the rest. +pub async fn query(query: String, dialect: &(dyn Dialect + Sync), pool: &Pool) -> Result where DB: Database + ValueParser, DB::QueryResult: HasRowsAffected, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, { - // get the db type from - let dialect = get_dialect(DB::NAME); - let first_query = get_first_query(query, dialect.as_ref()); + let first_query = get_first_query(query, dialect); match first_query { Ok((first_query, _)) => { let stream = sqlx::raw_sql(&first_query).fetch_many(pool); @@ -64,58 +85,7 @@ where } } -pub fn get_first_query(query: String, dialect: &dyn Dialect) -> Result<(String, Statement), DbError> { - let ast = Parser::parse_sql(dialect, &query); - match ast { - Ok(ast) if ast.len() > 1 => { - Err(Either::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))) - }, - Ok(ast) if ast.is_empty() => Err(Either::Right(ParserError::ParserError("Parsed query is empty".to_owned()))), - Ok(ast) => { - let statement = ast[0].clone(); - Ok((statement.to_string(), statement)) - }, - Err(e) => Err(Either::Right(e)), - } -} - -pub fn get_dialect(db_type: &str) -> Box { - match db_type { - "postgres" => Box::new(PostgreSqlDialect {}), - "mysql" => Box::new(MySqlDialect {}), - _ => Box::new(SQLiteDialect {}), - } -} - -pub trait ValueParser: Database { - fn parse_value(row: &Self::Row, col: &Self::Column) -> Option; -} - -pub fn row_to_vec(row: &DB::Row) -> Vec { - row.columns().iter().map(|col| DB::parse_value(row, col).unwrap().string).collect() -} - -pub fn row_to_json(row: &DB::Row) -> HashMap { - let mut result = HashMap::new(); - for col in row.columns() { - let value = match DB::parse_value(row, col) { - Some(v) => v.string, - _ => "[ unsupported ]".to_string(), - }; - result.insert(col.name().to_string(), value); - } - - result -} - -pub fn get_headers(row: &DB::Row) -> Headers { - row - .columns() - .iter() - .map(|col| Header { name: col.name().to_string(), type_name: col.type_info().to_string() }) - .collect() -} - +#[allow(clippy::type_complexity)] pub async fn query_stream<'a, DB>( mut stream: BoxStream<'_, Result, Error>>, ) -> Result @@ -123,42 +93,41 @@ where DB: Database + ValueParser, DB::QueryResult: HasRowsAffected, { - let mut query_finished = false; let mut query_rows = vec![]; let mut query_rows_affected: Option = None; let mut headers: Headers = vec![]; - while !query_finished { - let next = stream.next().await; - match next { - Some(Ok(Either::Left(result))) => { + // I change the implementation of the while loop here as the original one times out mysql connection + while let Some(item) = stream.next().await { + match item { + Ok(Either::Left(result)) => { + // For non-SELECT queries query_rows_affected = Some(result.rows_affected()); - query_finished = true; }, - Some(Ok(Either::Right(row))) => { + Ok(Either::Right(row)) => { + // For SELECT queries query_rows.push(row_to_vec::(&row)); if headers.is_empty() { headers = get_headers::(&row); } }, - Some(Err(e)) => return Err(Either::Left(e)), - None => return 
Err(Either::Left(Error::Protocol("Results stream empty".to_owned()))), - }; + Err(e) => return Err(Either::Left(e)), + } } Ok(Rows { rows_affected: query_rows_affected, headers, rows: query_rows }) } pub async fn query_with_tx<'a, DB>( - mut tx: Transaction<'_, DB>, + mut tx: Transaction<'static, DB>, + dialect: &(dyn Dialect + Sync), query: String, -) -> (Result, DbError>, Transaction<'_, DB>) +) -> (Result, DbError>, Transaction<'static, DB>) where DB: Database + ValueParser, DB::QueryResult: HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, { - let dialect = get_dialect(DB::NAME); - let first_query = get_first_query(query, dialect.as_ref()); + let first_query = get_first_query(query, dialect); match first_query { Ok((first_query, statement_type)) => { match statement_type { @@ -183,6 +152,21 @@ where } } +pub fn get_first_query(query: String, dialect: &dyn Dialect) -> Result<(String, Statement), DbError> { + let ast = Parser::parse_sql(dialect, &query); + match ast { + Ok(ast) if ast.len() > 1 => { + Err(Either::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))) + }, + Ok(ast) if ast.is_empty() => Err(Either::Right(ParserError::ParserError("Parsed query is empty".to_owned()))), + Ok(ast) => { + let statement = ast[0].clone(); + Ok((statement.to_string(), statement)) + }, + Err(e) => Err(Either::Right(e)), + } +} + pub fn statement_type_string(statement: &Statement) -> String { format!("{:?}", statement).split('(').collect::>()[0].split('{').collect::>()[0] .split('[') @@ -205,6 +189,27 @@ pub fn should_use_tx(statement: Statement) -> bool { } } +pub fn get_headers(row: &DB::Row) -> Headers { + row + .columns() + .iter() + .map(|col| Header { name: col.name().to_string(), type_name: col.type_info().to_string() }) + .collect() +} + +pub fn row_to_json(row: &DB::Row) -> HashMap { + let mut result = HashMap::new(); + for col in row.columns() { + let value = match DB::parse_value(row, col) { + Some(v) => v.string, + _ => "[ unsupported ]".to_string(), + }; + result.insert(col.name().to_string(), value); + } + + result +} + pub fn vec_to_string(vec: Vec) -> String { let mut content = String::new(); for (i, elem) in vec.iter().enumerate() { @@ -216,6 +221,18 @@ pub fn vec_to_string(vec: Vec) -> String { "{ ".to_owned() + &*content + &*" }".to_owned() } +pub fn row_to_vec(row: &DB::Row) -> Vec { + row.columns().iter().map(|col| DB::parse_value(row, col).unwrap().string).collect() +} + pub fn get_keywords() -> Vec { keywords::ALL_KEYWORDS.iter().map(|k| k.to_string()).collect() } + +pub fn get_dialect(db_type: &str) -> Arc { + match db_type { + "PostgreSQL" => Arc::new(PostgreSqlDialect {}), + "MySQL" => Arc::new(MySqlDialect {}), + _ => Arc::new(SQLiteDialect {}), + } +} diff --git a/src/database/mysql.rs b/src/database/mysql.rs index 058f4d8..e039747 100644 --- a/src/database/mysql.rs +++ b/src/database/mysql.rs @@ -1,5 +1,16 @@ -use super::Value; -use crate::generic_database::ValueParser; +use std::{ + fmt::Write, + io::{self, Write as _}, + str::FromStr, +}; + +use serde_json; +use sqlx::{ + mysql::{MySql, MySqlConnectOptions, MySqlQueryResult}, + Column, Database, Row, ValueRef, +}; + +use super::{vec_to_string, Value}; impl super::HasRowsAffected for MySqlQueryResult { fn rows_affected(&self) -> u64 { @@ -7,9 +18,246 @@ impl super::HasRowsAffected for MySqlQueryResult { } } -impl ValueParser for MySql { - fn parse_value(row: &Self::Row, col: &Self::Column) -> Option { - // 
MySQL-specific parsing - todo!() +impl super::BuildConnectionOptions for MySql { + fn build_connection_opts( + args: crate::cli::Cli, + ) -> color_eyre::eyre::Result<::Options> { + match args.connection_url { + Some(url) => Ok(MySqlConnectOptions::from_str(&url)?), + None => { + let mut opts = MySqlConnectOptions::new(); + + // Username + if let Some(user) = args.user { + opts = opts.username(&user); + } else { + let mut user = String::new(); + print!("username: "); + io::stdout().flush()?; + io::stdin().read_line(&mut user)?; + let user = user.trim(); + if !user.is_empty() { + opts = opts.username(user); + } + } + + // Password + if let Some(password) = args.password { + opts = opts.password(&password); + } else { + let password = rpassword::prompt_password(format!("password for user {}: ", opts.get_username())).unwrap(); + let password = password.trim(); + if !password.is_empty() { + opts = opts.password(password); + } + } + + // Host + if let Some(host) = args.host { + opts = opts.host(&host); + } else { + let mut host = String::new(); + print!("host (ex. localhost): "); + io::stdout().flush()?; + io::stdin().read_line(&mut host)?; + let host = host.trim(); + if !host.is_empty() { + opts = opts.host(host); + } + } + + // Port + if let Some(port) = args.port { + opts = opts.port(port); + } else { + let mut port = String::new(); + print!("port (ex. 3306): "); + io::stdout().flush()?; + io::stdin().read_line(&mut port)?; + let port = port.trim(); + if !port.is_empty() { + opts = opts.port(port.parse()?); + } + } + + // Database + if let Some(database) = args.database { + opts = opts.database(&database); + } else { + let mut database = String::new(); + print!("database (ex. mydb): "); + io::stdout().flush()?; + io::stdin().read_line(&mut database)?; + let database = database.trim(); + if !database.is_empty() { + opts = opts.database(database); + } + } + + Ok(opts) + }, + } + } +} + +impl super::DatabaseQueries for MySql { + fn preview_tables_query() -> String { + "select table_schema as table_schema, table_name as table_name + from information_schema.tables + where table_schema not in ('mysql', 'information_schema', 'performance_schema', 'sys') + order by table_schema, table_name asc" + .to_owned() + } + + fn preview_rows_query(schema: &str, table: &str) -> String { + format!("select * from `{}`.`{}` limit 100", schema, table) + } + + fn preview_columns_query(schema: &str, table: &str) -> String { + format!( + "select column_name, data_type, is_nullable, column_default, extra, column_comment + from information_schema.columns + where table_schema = '{}' and table_name = '{}' + order by ordinal_position", + schema, table + ) + } + + fn preview_constraints_query(schema: &str, table: &str) -> String { + format!( + "select constraint_name, constraint_type, enforced, + group_concat(column_name order by ordinal_position) as column_names + from information_schema.table_constraints + join information_schema.key_column_usage using (constraint_schema, constraint_name, table_schema, table_name) + where table_schema = '{}' and table_name = '{}' + group by constraint_name, constraint_type, enforced + order by constraint_type, constraint_name", + schema, table + ) + } + + fn preview_indexes_query(schema: &str, table: &str) -> String { + format!( + "select index_name, column_name, non_unique, seq_in_index, index_type + from information_schema.statistics + where table_schema = '{}' and table_name = '{}' + order by index_name, seq_in_index", + schema, table + ) + } + + fn preview_policies_query(_schema: 
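// The MySQL preview queries above quote identifiers with backticks instead of
// the double quotes used for Postgres. A direct check of the rows preview:
#[cfg(test)]
mod mysql_preview_quoting {
    use sqlx::MySql;

    use crate::database::DatabaseQueries;

    #[test]
    fn mysql_previews_quote_with_backticks() {
        assert_eq!(MySql::preview_rows_query("app", "users"), "select * from `app`.`users` limit 100");
    }
}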
&str, _table: &str) -> String { + "select 'MySQL does not support row-level security policies' as message".to_owned() + } +} + +impl super::ValueParser for MySql { + fn parse_value(row: &::Row, col: &::Column) -> Option { + let col_type = col.type_info().to_string(); + let raw_value = row.try_get_raw(col.ordinal()).ok()?; + if raw_value.is_null() { + return Some(Value { string: "NULL".to_string(), is_null: true }); + } + match col_type.to_uppercase().as_str() { + "TINYINT(1)" | "BOOLEAN" | "BOOL" => { + let received: bool = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "TINYINT" => { + let received: i8 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "SMALLINT" => { + let received: i16 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "INT" => { + let received: i32 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "BIGINT" => { + let received: i64 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "TINYINT UNSIGNED" => { + let received: u8 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "SMALLINT UNSIGNED" => { + let received: u16 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "INT UNSIGNED" => { + let received: u32 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "BIGINT UNSIGNED" => { + let received: u64 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "FLOAT" => { + let received: f32 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "DOUBLE" => { + let received: f64 = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "VARCHAR" | "CHAR" | "TEXT" | "BINARY" => { + let received = row.try_get::(col.ordinal()).ok()?; + Some(Value { string: received, is_null: false }) + }, + "VARBINARY" | "BLOB" => { + let received: Vec = row.try_get(col.ordinal()).ok()?; + if let Ok(s) = String::from_utf8(received.clone()) { + Some(Value { string: s, is_null: false }) + } else { + Some(Value { + string: received.iter().fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }), + is_null: false, + }) + } + }, + "INET4" | "INET6" => { + let received: std::net::IpAddr = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "TIME" => { + if let Ok(received) = row.try_get::(col.ordinal()) { + Some(Value { string: received.to_string(), is_null: false }) + } else { + let received: chrono::TimeDelta = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + } + }, + "DATE" => { + let received: chrono::NaiveDate = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "DATETIME" => { + let received: chrono::NaiveDateTime = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "TIMESTAMP" => { + let received: chrono::DateTime = row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "JSON" => { + let received: serde_json::Value = 
row.try_get(col.ordinal()).ok()?; + Some(Value { string: received.to_string(), is_null: false }) + }, + "GEOMETRY" => { + // TODO: would have to resort to geozero to parse WKB + Some(Value { string: "TODO".to_owned(), is_null: false }) + }, + _ => { + // Try to cast custom or other types to strings + let received: String = row.try_get_unchecked(col.ordinal()).ok()?; + Some(Value { string: received, is_null: false }) + }, + } } } diff --git a/src/database/postgresql.rs b/src/database/postgresql.rs index 6e0afc0..340cb09 100644 --- a/src/database/postgresql.rs +++ b/src/database/postgresql.rs @@ -1,33 +1,155 @@ -use std::{collections::HashMap, fmt::Write, pin::Pin, string::String}; +use std::{ + fmt::Write, + io::{self, Write as _}, + str::FromStr, + string::String, +}; use futures::stream::{BoxStream, StreamExt}; use sqlparser::{ ast::Statement, dialect::PostgreSqlDialect, - keywords, parser::{Parser, ParserError}, }; use sqlx::{ - postgres::{ - PgColumn, PgConnectOptions, PgPool, PgPoolOptions, PgQueryResult, PgRow, PgTypeInfo, PgTypeKind, PgValueRef, - Postgres, - }, + postgres::{PgConnectOptions, PgQueryResult, Postgres}, types::Uuid, - Column, Database, Either, Error, Pool, Row, Transaction, ValueRef, + Column, Database, Either, Row, ValueRef, }; use super::{vec_to_string, Value}; +impl super::BuildConnectionOptions for sqlx::Postgres { + fn build_connection_opts( + args: crate::cli::Cli, + ) -> color_eyre::eyre::Result<::Options> { + match args.connection_url { + Some(url) => Ok(PgConnectOptions::from_str(&url)?), + None => { + let mut opts = PgConnectOptions::new(); + + if let Some(user) = args.user { + opts = opts.username(&user); + } else { + let mut user: String = String::new(); + print!("username: "); + io::stdout().flush().unwrap(); + io::stdin().read_line(&mut user).unwrap(); + user = user.trim().to_string(); + if !user.is_empty() { + opts = opts.username(&user); + } + } + + if let Some(password) = args.password { + opts = opts.password(&password); + } else { + let mut password = + rpassword::prompt_password(format!("password for user {}: ", opts.get_username())).unwrap(); + password = password.trim().to_string(); + if !password.is_empty() { + opts = opts.password(&password); + } + } + + if let Some(host) = args.host { + opts = opts.host(&host); + } else { + let mut host: String = String::new(); + print!("host (ex. localhost): "); + io::stdout().flush().unwrap(); + io::stdin().read_line(&mut host).unwrap(); + host = host.trim().to_string(); + if !host.is_empty() { + opts = opts.host(&host); + } + } + + if let Some(port) = args.port { + opts = opts.port(port); + } else { + let mut port: String = String::new(); + print!("port (ex. 5432): "); + io::stdout().flush().unwrap(); + io::stdin().read_line(&mut port).unwrap(); + port = port.trim().to_string(); + if !port.is_empty() { + opts = opts.port(port.parse()?); + } + } + + if let Some(database) = args.database { + opts = opts.database(&database); + } else { + let mut database: String = String::new(); + print!("database (ex. 
postgres): "); + io::stdout().flush().unwrap(); + io::stdin().read_line(&mut database).unwrap(); + database = database.trim().to_string(); + if !database.is_empty() { + opts = opts.database(&database); + } + } + + Ok(opts) + }, + } + } +} + impl super::HasRowsAffected for PgQueryResult { fn rows_affected(&self) -> u64 { self.rows_affected() } } +impl super::DatabaseQueries for Postgres { + fn preview_tables_query() -> String { + "select table_schema, table_name + from information_schema.tables + where table_schema != 'pg_catalog' + and table_schema != 'information_schema' + group by table_schema, table_name + order by table_schema, table_name asc" + .to_owned() + } + + fn preview_rows_query(schema: &str, table: &str) -> String { + format!("select * from \"{}\".\"{}\" limit 100", schema, table) + } + + fn preview_columns_query(schema: &str, table: &str) -> String { + format!( + "select column_name, * from information_schema.columns where table_schema = '{}' and table_name = '{}'", + schema, table + ) + } + + fn preview_constraints_query(schema: &str, table: &str) -> String { + format!( + "select constraint_name, * from information_schema.table_constraints where table_schema = '{}' and table_name = '{}'", + schema, table + ) + } + + fn preview_indexes_query(schema: &str, table: &str) -> String { + format!("select indexname, indexdef, * from pg_indexes where schemaname = '{}' and tablename = '{}'", schema, table) + } + + fn preview_policies_query(schema: &str, table: &str) -> String { + format!("select * from pg_policies where schemaname = '{}' and tablename = '{}'", schema, table) + } +} + impl super::ValueParser for Postgres { + // parsed based on https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html fn parse_value(row: &::Row, col: &::Column) -> Option { let col_type = col.type_info().to_string(); let raw_value = row.try_get_raw(col.ordinal()).unwrap(); + // if col.name() == "dimensions" { + // let received: String = row.try_get_unchecked(col.ordinal()).unwrap(); + // println!("col_type: {:?}, {:?}", col_type, received); + // } if raw_value.is_null() { return Some(Value { string: "NULL".to_string(), is_null: true }); } @@ -184,10 +306,12 @@ impl super::ValueParser for Postgres { } } mod tests { + use std::sync::Arc; + use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; use super::*; - use crate::database::{get_first_query, DbError}; + use crate::database::{get_first_query, should_use_tx, DbError}; #[test] fn test_get_first_query() { @@ -250,8 +374,10 @@ mod tests { ), ]; + let dialect = Box::new(PostgreSqlDialect {}); + for (input, expected_output) in test_cases { - let result = get_first_query(input.to_string(), &PostgreSqlDialect {}); + let result = get_first_query(input.to_string(), dialect.as_ref()); match (result, expected_output) { (Ok((query, statement_type)), Ok((expected_query, match_statement))) => { assert_eq!(query, expected_query); @@ -285,7 +411,7 @@ mod tests { for (query, expected) in test_cases { let ast = Parser::parse_sql(&dialect, query).unwrap(); let statement = ast[0].clone(); - // assert_eq!(should_use_tx(statement), expected, "Failed for query: {}", query); + assert_eq!(should_use_tx(statement), expected, "Failed for query: {}", query); } } } diff --git a/src/database/sqlite.rs b/src/database/sqlite.rs index b23b6b6..796a369 100644 --- a/src/database/sqlite.rs +++ b/src/database/sqlite.rs @@ -1,12 +1,180 @@ +use std::{ + fmt::Write, + io::{self, Write as _}, + str::FromStr, + string::String, +}; + +use serde_json; +use sqlx::{ + 
sqlite::{Sqlite, SqliteConnectOptions, SqliteQueryResult}, + types::{chrono, uuid, Uuid}, + Column, Database, Row, ValueRef, +}; + +use super::{vec_to_string, Value}; +use crate::cli::Cli; + +impl super::BuildConnectionOptions for sqlx::Sqlite { + fn build_connection_opts(args: Cli) -> color_eyre::eyre::Result<::Options> { + match args.connection_url { + Some(url) => Ok(SqliteConnectOptions::from_str(&url)?), + None => { + let filename = if let Some(database) = args.database { + database + } else { + let mut database = String::new(); + print!("database file path (or ':memory:'): "); + io::stdout().flush()?; + io::stdin().read_line(&mut database)?; + let database = database.trim().to_string(); + if database.is_empty() { + return Err(color_eyre::eyre::Report::msg("Database file path is required")); + } + database + }; + + let opts = SqliteConnectOptions::new().filename(&filename); + Ok(opts) + }, + } + } +} + +impl super::DatabaseQueries for Sqlite { + fn preview_tables_query() -> String { + "select '' as table_schema, name as table_name + from sqlite_master + where type = 'table' + and name not like 'sqlite_%' + order by name asc" + .to_owned() + } + + fn preview_rows_query(_schema: &str, table: &str) -> String { + format!("select * from \"{}\" limit 100", table) + } + + fn preview_columns_query(_schema: &str, table: &str) -> String { + format!("pragma table_info(\"{}\")", table) + } + + fn preview_constraints_query(_schema: &str, table: &str) -> String { + format!("pragma foreign_key_list(\"{}\")", table) + } + + fn preview_indexes_query(_schema: &str, table: &str) -> String { + format!("pragma index_list(\"{}\")", table) + } + + fn preview_policies_query(_schema: &str, _table: &str) -> String { + "select 'SQLite does not support row-level security policies' as message".to_owned() + } +} + impl super::HasRowsAffected for SqliteQueryResult { fn rows_affected(&self) -> u64 { self.rows_affected() } } +// macro_rules! 
parse_nullable { +// ($type:ty) => { +// match row.try_get::, _>(col.ordinal()) { +// Ok(Some(value)) => Some(Value { string: value.to_string(), is_null: false }), +// Ok(None) => Some(Value { string: "NULL".to_string(), is_null: true }), +// Err(_) => None, +// } +// }; +// } -impl ValueParser for Sqlite { - fn parse_value(row: &Self::Row, col: &Self::Column) -> Option { - // SQLite-specific parsing - todo!() +impl super::ValueParser for Sqlite { + fn parse_value(row: &::Row, col: &::Column) -> Option { + let col_type = col.type_info().to_string(); + let raw_value = row.try_get_raw(col.ordinal()).unwrap(); + if raw_value.is_null() { + return Some(Value { string: "NULL".to_string(), is_null: true }); + } + match col_type.to_uppercase().as_str() { + "BOOLEAN" => { + let received: bool = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "INTEGER" | "INT4" | "INT8" | "BIGINT" => { + let received: i64 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "REAL" => { + let received: f64 = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received.to_string(), is_null: false }) + }, + "TEXT" => { + // Try parsing as different types that might be stored as TEXT + if let Ok(dt) = row.try_get::(col.ordinal()) { + Some(Value { string: dt.to_string(), is_null: false }) + } else if let Ok(dt) = row.try_get::, _>(col.ordinal()) { + Some(Value { string: dt.to_string(), is_null: false }) + } else if let Ok(date) = row.try_get::(col.ordinal()) { + Some(Value { string: date.to_string(), is_null: false }) + } else if let Ok(time) = row.try_get::(col.ordinal()) { + Some(Value { string: time.to_string(), is_null: false }) + } else if let Ok(uuid) = row.try_get::(col.ordinal()) { + Some(Value { string: uuid.to_string(), is_null: false }) + } else if let Ok(json) = row.try_get::(col.ordinal()) { + Some(Value { string: json.to_string(), is_null: false }) + } else { + let received: String = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received, is_null: false }) + } + }, + "BLOB" => { + let received: Vec = row.try_get(col.ordinal()).unwrap(); + if let Ok(s) = String::from_utf8(received.clone()) { + Some(Value { string: s, is_null: false }) + } else { + Some(Value { + string: received.iter().fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }), + is_null: false, + }) + } + }, + "DATETIME" => { + // Similar to TEXT, but we'll try timestamp first + if let Ok(dt) = row.try_get::(col.ordinal()) { + let dt = chrono::DateTime::from_timestamp(dt, 0).unwrap(); + Some(Value { string: dt.to_string(), is_null: false }) + } else if let Ok(dt) = row.try_get::(col.ordinal()) { + Some(Value { string: dt.to_string(), is_null: false }) + } else if let Ok(dt) = row.try_get::, _>(col.ordinal()) { + Some(Value { string: dt.to_string(), is_null: false }) + } else { + let received: String = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received, is_null: false }) + } + }, + "DATE" => { + if let Ok(date) = row.try_get::(col.ordinal()) { + Some(Value { string: date.to_string(), is_null: false }) + } else { + let received: String = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received, is_null: false }) + } + }, + "TIME" => { + if let Ok(time) = row.try_get::(col.ordinal()) { + Some(Value { string: time.to_string(), is_null: false }) + } else { + let received: String = row.try_get(col.ordinal()).unwrap(); + Some(Value { string: received, 
is_null: false }) + } + }, + _ => { + // For any other types, try to cast to string + let received: String = row.try_get_unchecked(col.ordinal()).unwrap(); + Some(Value { string: received, is_null: false }) + }, + } } } diff --git a/src/main.rs b/src/main.rs index cfa1d4a..df5ef9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,30 +22,39 @@ use std::{ use clap::Parser; use cli::Cli; use color_eyre::eyre::{self, Result}; -use sqlx::{postgres::PgConnectOptions, Connection, Database, Executor, Pool, Postgres}; +use database::{BuildConnectionOptions, DatabaseQueries, HasRowsAffected, ValueParser}; +use sqlx::{postgres::PgConnectOptions, Connection, Database, Executor, MySql, Pool, Postgres, Sqlite}; use crate::{ app::App, utils::{initialize_logging, initialize_panic_handler, version}, }; +async fn run_app(args: Cli) -> Result<()> +where + DB: Database + BuildConnectionOptions + ValueParser + DatabaseQueries, + DB::QueryResult: HasRowsAffected, + for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + let connection_opts = DB::build_connection_opts(args)?; + let mut app = App::<'_, DB>::new(connection_opts)?; + app.run().await?; + Ok(()) +} + async fn tokio_main() -> Result<()> { initialize_logging()?; initialize_panic_handler()?; let args = Cli::parse(); - match args.database.as_str() { - "postgres" => { - let connection_opts = build_connection_opts::(args.clone())?; - let mut app = App::<'_, Postgres>::new(connection_opts)?; - app.run().await?; - }, - "mysql" => todo!(), - "sqlite" => todo!(), - _ => return Err(eyre::Report::msg("Please specify a database type")), + match args.driver.as_str() { + "postgres" => run_app::(args).await, + "mysql" => run_app::(args).await, + "sqlite" => run_app::(args).await, + _ => Err(eyre::Report::msg("Please provide a valid a database type")), } - Ok(()) } #[tokio::main] @@ -57,80 +66,3 @@ async fn main() -> Result<()> { Ok(()) } } - -// sqlx defaults to reading from environment variables if no inputs are provided -fn build_connection_opts(args: Cli) -> Result<::Options> -where -{ - match args.connection_url { - Some(url) => Ok(::Options::from_str(&url)?), - None => { - // let mut opts = ::Options::new(); - - // if let Some(user) = args.user { - // opts = opts.username(&user); - // } else { - // let mut user: String = String::new(); - // print!("username: "); - // io::stdout().flush().unwrap(); - // io::stdin().read_line(&mut user).unwrap(); - // user = user.trim().to_string(); - // if !user.is_empty() { - // opts = opts.username(&user); - // } - // } - - // if let Some(password) = args.password { - // opts = opts.password(&password); - // } else { - // let mut password = rpassword::prompt_password(format!("password for user {}: ", opts.get_username())).unwrap(); - // password = password.trim().to_string(); - // if !password.is_empty() { - // opts = opts.password(&password); - // } - // } - - // if let Some(host) = args.host { - // opts = opts.host(&host); - // } else { - // let mut host: String = String::new(); - // print!("host (ex. localhost): "); - // io::stdout().flush().unwrap(); - // io::stdin().read_line(&mut host).unwrap(); - // host = host.trim().to_string(); - // if !host.is_empty() { - // opts = opts.host(&host); - // } - // } - - // if let Some(port) = args.port { - // opts = opts.port(port); - // } else { - // let mut port: String = String::new(); - // print!("port (ex. 
5432): "); - // io::stdout().flush().unwrap(); - // io::stdin().read_line(&mut port).unwrap(); - // port = port.trim().to_string(); - // if !port.is_empty() { - // opts = opts.port(port.parse()?); - // } - // } - - // if let Some(database) = args.database { - // opts = opts.database(&database); - // } else { - // let mut database: String = String::new(); - // print!("database (ex. postgres): "); - // io::stdout().flush().unwrap(); - // io::stdin().read_line(&mut database).unwrap(); - // database = database.trim().to_string(); - // if !database.is_empty() { - // opts = opts.database(&database); - // } - // } - - // Ok(opts) - Err(eyre::Report::msg("Not implemented")) - }, - } -} From d41ffc03c443956c9350e41a159ddd57b4db9123 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Thu, 26 Sep 2024 17:03:57 -0500 Subject: [PATCH 06/16] fix merge issue --- src/main.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 810bf70..46a4855 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,15 +32,16 @@ use crate::{ utils::{initialize_logging, initialize_panic_handler, version}, }; -async fn run_app(args: Cli) -> Result<()> +async fn run_app(mut args: Cli) -> Result<()> where DB: Database + BuildConnectionOptions + ValueParser + DatabaseQueries, DB::QueryResult: HasRowsAffected, for<'c> ::Arguments<'c>: sqlx::IntoArguments<'c, DB>, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, { + let mouse_mode = args.mouse_mode.take(); let connection_opts = DB::build_connection_opts(args)?; - let mut app = App::<'_, DB>::new(connection_opts)?; + let mut app = App::<'_, DB>::new(connection_opts, mouse_mode)?; app.run().await?; Ok(()) } From 33f15142e1ceb8ca0eb485966f2860890d46334f Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Fri, 27 Sep 2024 07:45:32 -0500 Subject: [PATCH 07/16] some cleanup --- src/app.rs | 2 +- src/components/menu.rs | 11 ++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/app.rs b/src/app.rs index 6616bf0..c131b32 100644 --- a/src/app.rs +++ b/src/app.rs @@ -68,7 +68,7 @@ pub struct AppState<'a, DB: Database> { pub last_query_end: Option>, } -pub struct Components<'a, DB: sqlx::Database> { +pub struct Components<'a, DB> { pub menu: Box>, pub editor: Box>, pub history: Box>, diff --git a/src/components/menu.rs b/src/components/menu.rs index ec79c7e..9f677e6 100644 --- a/src/components/menu.rs +++ b/src/components/menu.rs @@ -35,16 +35,9 @@ pub trait SettableTableList<'a> { fn set_table_list(&mut self, data: Option>); } -pub trait MenuComponent<'a, DB>: Component + SettableTableList<'a> -where - DB: sqlx::Database, -{ -} +pub trait MenuComponent<'a, DB: Database>: Component + SettableTableList<'a> {} -impl<'a, T, DB> MenuComponent<'a, DB> for T -where - T: Component + SettableTableList<'a>, - DB: sqlx::Database, +impl<'a, T, DB: Database> MenuComponent<'a, DB> for T where T: Component + SettableTableList<'a> { } From 4b3ad8f54f7ca689bbe9b61ff02f242cf50931a1 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Fri, 27 Sep 2024 17:16:36 -0500 Subject: [PATCH 08/16] chore: add test & cli update --- src/cli.rs | 46 ++++++++++++++++- src/database/mysql.rs | 111 +++++++++++++++++++++++++++++++++++++++++ src/database/sqlite.rs | 111 +++++++++++++++++++++++++++++++++++++++++ src/main.rs | 20 +++++--- 4 files changed, 279 insertions(+), 9 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index fde192d..71e5cda 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,6 +1,11 @@ 
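The `MenuComponent` cleanup above is the classic blanket-impl trick: a marker trait that any type earns automatically by implementing the underlying capability traits. A self-contained sketch of that shape, with simplified stand-in trait and type names rather than the real `Component`/`SettableTableList` bounds:

```rust
// Any type implementing both capability traits gets the combined marker trait
// for free via the blanket impl; no per-type boilerplate is needed.
trait Render {
    fn render(&self) -> String;
}
trait SettableList {
    fn set_list(&mut self, items: Vec<String>);
}
trait ListComponent: Render + SettableList {}
impl<T> ListComponent for T where T: Render + SettableList {}

struct Menu {
    items: Vec<String>,
}
impl Render for Menu {
    fn render(&self) -> String {
        self.items.join(", ")
    }
}
impl SettableList for Menu {
    fn set_list(&mut self, items: Vec<String>) {
        self.items = items;
    }
}

fn main() {
    let mut menu = Menu { items: vec![] };
    menu.set_list(vec!["public.users".into(), "public.posts".into()]);
    // Menu implements ListComponent without any explicit impl block.
    let component: &dyn ListComponent = &menu;
    println!("{}", component.render());
}
```

The payoff is that concrete widgets only implement the capability traits; the combined, DB-generic marker comes along automatically.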
-use std::path::PathBuf; +use std::{ + io::{self, Write}, + path::PathBuf, + str::FromStr, +}; use clap::Parser; +use color_eyre::eyre::{self, Result}; use crate::utils::version; @@ -39,5 +44,42 @@ pub struct Cli { pub database: Option, #[arg(long = "driver", value_name = "DRIVER", help = "Driver for database connection (ex. postgres)")] - pub driver: String, + pub driver: Option, +} + +#[derive(Parser, Debug, Clone)] +pub enum Driver { + Postgres, + Mysql, + Sqlite, +} + +impl FromStr for Driver { + type Err = eyre::Report; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "postgres" | "postgresql" => Ok(Driver::Postgres), + "mysql" => Ok(Driver::Mysql), + "sqlite" => Ok(Driver::Sqlite), + _ => Err(eyre::Report::msg("Invalid driver")), + } + } +} + +pub fn extract_driver_from_url(url: &str) -> Result { + let url = url.trim(); + if let Some(pos) = url.find("://") { + url[..pos].to_lowercase().parse() + } else { + Err(eyre::Report::msg("Invalid connection URL format")) + } +} + +pub fn prompt_for_driver() -> Result { + let mut driver = String::new(); + print!("Database driver (postgres, mysql, sqlite): "); + io::stdout().flush()?; + io::stdin().read_line(&mut driver)?; + driver.trim().to_lowercase().parse() } diff --git a/src/database/mysql.rs b/src/database/mysql.rs index e039747..33154ac 100644 --- a/src/database/mysql.rs +++ b/src/database/mysql.rs @@ -5,6 +5,7 @@ use std::{ }; use serde_json; +use sqlparser::ast::Statement; use sqlx::{ mysql::{MySql, MySqlConnectOptions, MySqlQueryResult}, Column, Database, Row, ValueRef, @@ -261,3 +262,113 @@ impl super::ValueParser for MySql { } } } + +mod tests { + use sqlparser::{ + ast::Statement, + dialect::MySqlDialect, + parser::{Parser, ParserError}, + }; + + use super::*; + use crate::database::{get_first_query, should_use_tx, DbError}; + + #[test] + fn test_get_first_query_mysql() { + type TestCase = (&'static str, Result<(String, Box bool>), DbError>); + + let test_cases: Vec = vec![ + // single query + ("SELECT * FROM users;", Ok(("SELECT * FROM users".to_string(), Box::new(|s| matches!(s, Statement::Query(_)))))), + // multiple queries + ( + "SELECT * FROM users; DELETE FROM posts;", + Err(DbError::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))), + ), + // empty query + ("", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), + // syntax error + ( + "SELEC * FORM users;", + Err(DbError::Right(ParserError::ParserError( + "Expected: an SQL statement, found: SELEC at Line: 1, Column: 1".to_owned(), + ))), + ), + // lowercase + ( + "select * from `users`", + Ok(("SELECT * FROM `users`".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + // newlines + ("select *\nfrom users;", Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_)))))), + // comment-only + ("-- select * from users;", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), + // commented line(s) + ( + "-- select blah;\nselect * from users", + Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + // update + ( + "UPDATE users SET name = 'John' WHERE id = 1", + Ok(( + "UPDATE users SET name = 'John' WHERE id = 1".to_owned(), + Box::new(|s| matches!(s, Statement::Update { .. })), + )), + ), + // delete + ( + "DELETE FROM users WHERE id = 1", + Ok(("DELETE FROM users WHERE id = 1".to_owned(), Box::new(|s| matches!(s, Statement::Delete { .. 
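`extract_driver_from_url` above keys entirely off the text before `://`; a dependency-free sketch of that slicing, returning `Option` instead of an `eyre` error for brevity:

```rust
// Illustrative version of the scheme-extraction idea: everything before "://",
// lowercased. The real function parses the result into the Driver enum and
// reports an error when no scheme is present.
fn scheme_of(url: &str) -> Option<String> {
    let url = url.trim();
    url.find("://").map(|pos| url[..pos].to_lowercase())
}

fn main() {
    assert_eq!(scheme_of("postgres://user:pw@localhost:5432/db").as_deref(), Some("postgres"));
    assert_eq!(scheme_of("MySQL://localhost/app").as_deref(), Some("mysql"));
    assert_eq!(scheme_of("sqlite:///rainfrog.sqlite3").as_deref(), Some("sqlite"));
    assert_eq!(scheme_of("localhost:5432/db"), None); // no scheme: caller falls back to prompting
    println!("ok");
}
```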
})))), + ), + // drop + ("DROP TABLE users", Ok(("DROP TABLE users".to_owned(), Box::new(|s| matches!(s, Statement::Drop { .. }))))), + // explain + ( + "EXPLAIN SELECT * FROM users", + Ok(("EXPLAIN SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Explain { .. })))), + ), + ]; + + let dialect = Box::new(MySqlDialect {}); + + for (input, expected_output) in test_cases { + let result = get_first_query(input.to_string(), dialect.as_ref()); + match (result, expected_output) { + (Ok((query, statement)), Ok((expected_query, match_statement))) => { + assert_eq!(query, expected_query); + assert!(match_statement(&statement)); + }, + ( + Err(DbError::Right(ParserError::ParserError(msg))), + Err(DbError::Right(ParserError::ParserError(expected_msg))), + ) => { + assert_eq!(msg, expected_msg); + }, + _ => panic!("Unexpected result for input: {}", input), + } + } + } + + #[test] + fn test_should_use_tx_mysql() { + let dialect = MySqlDialect {}; + let test_cases = vec![ + ("DELETE FROM users WHERE id = 1", true), + // TODO: fix this + ("DROP TABLE users", false), // In MySQL, DROP TABLE causes an implicit commit + ("UPDATE users SET name = 'John' WHERE id = 1", true), + ("SELECT * FROM users", false), + ("INSERT INTO users (name) VALUES ('John')", true), + // EXPLAIN statements in MySQL + ("EXPLAIN DELETE FROM users WHERE id = 1", false), + ("EXPLAIN SELECT * FROM users", false), + ]; + + for (query, expected) in test_cases { + let ast = Parser::parse_sql(&dialect, query).unwrap(); + let statement = ast[0].clone(); + assert_eq!(should_use_tx(statement), expected, "Failed for query: {}", query); + } + } +} diff --git a/src/database/sqlite.rs b/src/database/sqlite.rs index 796a369..e0bf30d 100644 --- a/src/database/sqlite.rs +++ b/src/database/sqlite.rs @@ -178,3 +178,114 @@ impl super::ValueParser for Sqlite { } } } + +mod tests { + use sqlparser::{ + ast::Statement, + dialect::SQLiteDialect, + parser::{Parser, ParserError}, + }; + + use super::*; + use crate::database::{get_first_query, should_use_tx, DbError}; + + #[test] + fn test_get_first_query_sqlite() { + type TestCase = (&'static str, Result<(String, Box bool>), DbError>); + + let test_cases: Vec = vec![ + // single query + ("SELECT * FROM users;", Ok(("SELECT * FROM users".to_string(), Box::new(|s| matches!(s, Statement::Query(_)))))), + // multiple queries + ( + "SELECT * FROM users; DELETE FROM posts;", + Err(DbError::Right(ParserError::ParserError("Only one statement allowed per query".to_owned()))), + ), + // empty query + ("", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), + // syntax error + ( + "SELEC * FORM users;", + Err(DbError::Right(ParserError::ParserError( + "Expected: an SQL statement, found: SELEC at Line: 1, Column: 1".to_owned(), + ))), + ), + // lowercase + ( + "select * from \"users\"", + Ok(("SELECT * FROM \"users\"".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + // newlines + ("select *\nfrom users;", Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_)))))), + // comment-only + ("-- select * from users;", Err(DbError::Right(ParserError::ParserError("Parsed query is empty".to_owned())))), + // commented line(s) + ( + "-- select blah;\nselect * from users", + Ok(("SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Query(_))))), + ), + // update + ( + "UPDATE users SET name = 'John' WHERE id = 1", + Ok(( + "UPDATE users SET name = 'John' WHERE id = 1".to_owned(), + Box::new(|s| matches!(s, 
Statement::Update { .. })), + )), + ), + // delete + ( + "DELETE FROM users WHERE id = 1", + Ok(("DELETE FROM users WHERE id = 1".to_owned(), Box::new(|s| matches!(s, Statement::Delete { .. })))), + ), + // drop + ("DROP TABLE users", Ok(("DROP TABLE users".to_owned(), Box::new(|s| matches!(s, Statement::Drop { .. }))))), + // explain + ( + "EXPLAIN SELECT * FROM users", + Ok(("EXPLAIN SELECT * FROM users".to_owned(), Box::new(|s| matches!(s, Statement::Explain { .. })))), + ), + ]; + + let dialect = Box::new(SQLiteDialect {}); + + for (input, expected_output) in test_cases { + let result = get_first_query(input.to_string(), dialect.as_ref()); + match (result, expected_output) { + (Ok((query, statement)), Ok((expected_query, match_statement))) => { + assert_eq!(query, expected_query); + assert!(match_statement(&statement)); + }, + ( + Err(DbError::Right(ParserError::ParserError(msg))), + Err(DbError::Right(ParserError::ParserError(expected_msg))), + ) => { + assert_eq!(msg, expected_msg); + }, + _ => panic!("Unexpected result for input: {}", input), + } + } + } + + #[test] + fn test_should_use_tx_sqlite() { + let dialect = SQLiteDialect {}; + let test_cases = vec![ + ("DELETE FROM users WHERE id = 1", true), + ("DROP TABLE users", true), + ("UPDATE users SET name = 'John' WHERE id = 1", true), + ("SELECT * FROM users", false), + ("INSERT INTO users (name) VALUES ('John')", false), + // SQLite EXPLAIN statements + ("EXPLAIN DELETE FROM users WHERE id = 1", false), + ("EXPLAIN SELECT * FROM users", false), + // TODO: why this fail to parse? + // ("EXPLAIN QUERY PLAN SELECT * FROM users", false), + ]; + + for (query, expected) in test_cases { + let ast = Parser::parse_sql(&dialect, query).unwrap(); + let statement = ast[0].clone(); + assert_eq!(should_use_tx(statement), expected, "Failed for query: {}", query); + } + } +} diff --git a/src/main.rs b/src/main.rs index 46a4855..3e4c91d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,7 @@ use std::{ }; use clap::Parser; -use cli::Cli; +use cli::{extract_driver_from_url, prompt_for_driver, Cli, Driver}; use color_eyre::eyre::{self, Result}; use database::{BuildConnectionOptions, DatabaseQueries, HasRowsAffected, ValueParser}; use sqlx::{postgres::PgConnectOptions, Connection, Database, Executor, MySql, Pool, Postgres, Sqlite}; @@ -51,12 +51,18 @@ async fn tokio_main() -> Result<()> { initialize_panic_handler()?; - let args = Cli::parse(); - match args.driver.as_str() { - "postgres" => run_app::(args).await, - "mysql" => run_app::(args).await, - "sqlite" => run_app::(args).await, - _ => Err(eyre::Report::msg("Please provide a valid a database type")), + let mut args = Cli::parse(); + let driver = if let Some(driver) = args.driver.take() { + driver + } else if let Some(ref url) = args.connection_url { + extract_driver_from_url(url)? + } else { + prompt_for_driver()? 
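The driver resolution above and the `match driver { ... }` dispatch that follows boil down to picking which monomorphized generic function to call; a dependency-free sketch of that dispatch shape (the `Backend` trait and its impls are illustrative stand-ins for the real sqlx-backed trait bounds):

```rust
// Stand-in Backend trait: the real code bounds DB by sqlx::Database plus the
// crate's BuildConnectionOptions/ValueParser/DatabaseQueries traits.
trait Backend {
    const NAME: &'static str;
    fn default_port() -> Option<u16>;
}

struct Postgres;
struct MySql;
struct Sqlite;

impl Backend for Postgres {
    const NAME: &'static str = "postgres";
    fn default_port() -> Option<u16> {
        Some(5432)
    }
}
impl Backend for MySql {
    const NAME: &'static str = "mysql";
    fn default_port() -> Option<u16> {
        Some(3306)
    }
}
impl Backend for Sqlite {
    const NAME: &'static str = "sqlite";
    fn default_port() -> Option<u16> {
        None
    }
}

enum Driver {
    Postgres,
    MySql,
    Sqlite,
}

// Monomorphized once per backend, like run_app::<DB> in main.rs.
fn run_app<DB: Backend>() {
    println!("starting {} (default port: {:?})", DB::NAME, DB::default_port());
}

fn main() {
    let driver = Driver::MySql; // resolved from --driver, the URL scheme, or a prompt
    match driver {
        Driver::Postgres => run_app::<Postgres>(),
        Driver::MySql => run_app::<MySql>(),
        Driver::Sqlite => run_app::<Sqlite>(),
    }
}
```

Because each arm names a concrete type, the compiler generates and type-checks a separate `run_app` per backend, so the trait bounds are verified for every driver independently.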
+ }; + match driver { + Driver::Postgres => run_app::(args).await, + Driver::Mysql => run_app::(args).await, + Driver::Sqlite => run_app::(args).await, } } From 8b1ce58338ca755b6cc8e49858798b912e098071 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Fri, 27 Sep 2024 19:51:22 -0500 Subject: [PATCH 09/16] skip failed mysql test --- src/database/mysql.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/database/mysql.rs b/src/database/mysql.rs index 33154ac..91db02a 100644 --- a/src/database/mysql.rs +++ b/src/database/mysql.rs @@ -356,7 +356,7 @@ mod tests { let test_cases = vec![ ("DELETE FROM users WHERE id = 1", true), // TODO: fix this - ("DROP TABLE users", false), // In MySQL, DROP TABLE causes an implicit commit + // ("DROP TABLE users", false), // In MySQL, DROP TABLE causes an implicit commit ("UPDATE users SET name = 'John' WHERE id = 1", true), ("SELECT * FROM users", false), ("INSERT INTO users (name) VALUES ('John')", true), From bab194bf49b33b68231ccc2a71ea907911cfd856 Mon Sep 17 00:00:00 2001 From: Frank Wang <1454884738@qq.com> Date: Fri, 27 Sep 2024 19:56:31 -0500 Subject: [PATCH 10/16] small mistake! --- src/database/mysql.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/database/mysql.rs b/src/database/mysql.rs index 91db02a..738a2e0 100644 --- a/src/database/mysql.rs +++ b/src/database/mysql.rs @@ -359,7 +359,7 @@ mod tests { // ("DROP TABLE users", false), // In MySQL, DROP TABLE causes an implicit commit ("UPDATE users SET name = 'John' WHERE id = 1", true), ("SELECT * FROM users", false), - ("INSERT INTO users (name) VALUES ('John')", true), + ("INSERT INTO users (name) VALUES ('John')", false), // EXPLAIN statements in MySQL ("EXPLAIN DELETE FROM users WHERE id = 1", false), ("EXPLAIN SELECT * FROM users", false), From 5ab3def81a72b9c8359b98815d68baafc307535b Mon Sep 17 00:00:00 2001 From: achristmascarl Date: Fri, 27 Sep 2024 21:11:01 -0400 Subject: [PATCH 11/16] change get dialect fn --- src/database.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/database.rs b/src/database.rs index edef59e..babaef1 100644 --- a/src/database.rs +++ b/src/database.rs @@ -233,6 +233,7 @@ pub fn get_dialect(db_type: &str) -> Arc { match db_type { "PostgreSQL" => Arc::new(PostgreSqlDialect {}), "MySQL" => Arc::new(MySqlDialect {}), - _ => Arc::new(SQLiteDialect {}), + "SQLite" => Arc::new(SQLiteDialect {}), + x => panic!("Unsupported database type: {}", x), } } From b8d88a8e17dec4e0371ede5a708e72ad579f09e5 Mon Sep 17 00:00:00 2001 From: achristmascarl Date: Fri, 27 Sep 2024 21:49:25 -0400 Subject: [PATCH 12/16] update dockerfile, readme --- Dockerfile | 4 ++-- README.md | 47 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 22f6b6a..f9164ef 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,5 +23,5 @@ USER rainfrog HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ CMD pidof rainfrog || exit 1 -# Command to construct the full connection URL using environment variables -CMD ["bash", "-c", "rainfrog --url postgres://$username:$password@$hostname:$db_port/$db_name"] +# Command to construct the full connection options using environment variables +CMD ["bash", "-c", "rainfrog --username=$username --password=$password --host=$hostname --port=$db_port --database=$db_name --driver=$driver"] diff --git a/README.md b/README.md index 8c46f00..24fbdbd 100644 --- a/README.md +++ b/README.md @@ -5,7 
+5,7 @@ a database management tui for postgres ![rainfrog demo](vhs/demo.gif) > [!WARNING] -> rainfrog is currently in beta. +> rainfrog is currently in beta; the mysql and sqlite drivers are unstable. the goal for rainfrog is to provide a lightweight, terminal-based alternative to pgadmin/dbeaver. @@ -22,6 +22,12 @@ pgadmin/dbeaver. > [frogs find refuge in elephant tracks](https://www.sciencedaily.com/releases/2019/06/190604131157.htm) +### supported databases + +rainfrog has mainly been tested with postgres, and postgres will be the primary +database targeted. **mysql and sqlite are also supported, but they are +currently unstable**; use with caution! + ## disclaimer this software is currently under active development; expect breaking changes, @@ -93,14 +99,15 @@ curl -LSsf https://raw.githubusercontent.com/achristmascarl/rainfrog/main/instal Usage: rainfrog [OPTIONS] Options: - -M, --mouse Whether to enable mouse event support. If enabled, your terminal\'s default mouse event handling will not - work. [possible values: true, false] + -M, --mouse Whether to enable mouse event support. If enabled, your terminal\'s default mouse event handling will + not work. [possible values: true, false] -u, --url Full connection URL for the database, e.g. postgres://username:password@localhost:5432/dbname --username Username for database connection --password Password for database connection --host Host for database connection (ex. localhost) --port Port for database connection (ex. 5432) --database Name of database for connection (ex. postgres) + --driver Driver for database connection (ex. postgres) -h, --help Print help -V, --version Print version ``` @@ -113,6 +120,7 @@ default to what is in your environment variables. ```sh rainfrog \ + --driver \ --username \ --host \ --port \ @@ -131,9 +139,13 @@ rainfrog --url $(connection_url) ### `docker run` +for postgres and mysql, you can run it by specifying all +of the options as environment variables: + ```sh docker run --platform linux/amd64 -it --rm --name rainfrog \ --add-host host.docker.internal:host-gateway \ + -e driver="db_driver" \ -e username="" \ -e password="" \ -e hostname="host.docker.internal" \ @@ -141,6 +153,26 @@ docker run --platform linux/amd64 -it --rm --name rainfrog \ -e db_name="" achristmascarl/rainfrog:latest ``` +if you want to provide a custom combination of +options and omit others, you can override the Dockerfile's +CMD like so: + +```sh +docker run --platform linux/amd64 -it --rm --name rainfrog \ + achristmascarl/rainfrog:latest \ + rainfrog # overrides CMD, addition options would come after +``` + +since sqlite is file-based, you may need to mount a path to +the sqlite db as a volume in order to access it: + +```sh +docker run --platform linux/amd64 -it --rm --name rainfrog \ + -v ~/code/rainfrog/dev/rainfrog.sqlite3:/rainfrog.sqlite3 \ + achristmascarl/rainfrog:latest \ + rainfrog --url sqlite:///rainfrog.sqlite3 +``` + ## customization rainfrog can be customized by placing a `rainfrog_config.toml` file in @@ -208,7 +240,8 @@ are the default keybindings. #### query editor -Keybindings may not behave exactly like Vim. The full list of active Vim keybindings in Rainfrog can be found at [vim.rs](./src/vim.rs). +Keybindings may not behave exactly like Vim. The full list of active +Vim keybindings in Rainfrog can be found at [vim.rs](./src/vim.rs). 
| Keybinding | Description | | ----------------- | -------------------------------------- | @@ -318,6 +351,9 @@ features ## known issues and limitations +- in mysql, DROP statements cannot be rolled back, even if they are part of a + transaction; see + which will address this issue - for x11 and wayland, yanking does not copy to the system clipboard, only to the query editor's buffer. see - in addition to the experience being subpar if the terminal window is too @@ -342,7 +378,8 @@ features for bug reports and feature requests, please [create an issue](https://github.com/achristmascarl/rainfrog/issues/new/choose). -please read [CONTRIBUTING.md](./CONTRIBUTING.md) before opening issues or creating PRs. +please read [CONTRIBUTING.md](./CONTRIBUTING.md) before opening issues +or creating PRs. ## acknowledgements From a84413fec28939d9e8fd1bb4f967743efa6650f7 Mon Sep 17 00:00:00 2001 From: achristmascarl Date: Fri, 27 Sep 2024 21:57:34 -0400 Subject: [PATCH 13/16] update ci --- .github/workflows/ci.yml | 52 ++++++++++++++++++++++++++++++++++++---- Dockerfile | 2 +- README.md | 2 +- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d32e86e..7d83625 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,25 +68,67 @@ jobs: run: | make db-up sleep 5 # wait for db container - - name: docker run + - name: docker run postgres run: | - docker run -dit --name rainfrog_test \ + docker run -dit --name rainfrog_test_postgres \ --add-host host.docker.internal:host-gateway \ + -e driver="postgres" \ -e username="root" \ -e password="password" \ -e hostname="host.docker.internal" \ -e db_port="5499" \ -e db_name="rainfrog" rainfrog_test sleep 5 # wait for container - - name: check container status + - name: check postgres container status run: | - container_status=$(docker ps -f name=rainfrog_test --format "{{.Status}}") + container_status=$(docker ps -f name=rainfrog_test_postgres --format "{{.Status}}") if [[ "$container_status" == "Up"* ]]; then echo "container started" else echo "container did not start" echo "logs: " - docker logs -t rainfrog_test + docker logs -t rainfrog_test_postgres + exit 1 + fi + - name: docker run mysql + run: | + docker run -dit --name rainfrog_test_mysql \ + --add-host host.docker.internal:host-gateway \ + -e driver="mysql" \ + -e username="root" \ + -e password="password" \ + -e hostname="host.docker.internal" \ + -e db_port="3317" \ + -e db_name="rainfrog" rainfrog_test + sleep 5 # wait for container + - name: check mysql container status + run: | + container_status=$(docker ps -f name=rainfrog_test_mysql --format "{{.Status}}") + if [[ "$container_status" == "Up"* ]]; then + echo "container started" + else + echo "container did not start" + echo "logs: " + docker logs -t rainfrog_test_mysql + exit 1 + fi + - name: docker run sqlite + run: | + docker run -dit --name rainfrog_test_sqlite \ + --add-host host.docker.internal:host-gateway \ + -v $(pwd)/dev/rainfrog.sqlite3:/rainfrog.sqlite3 \ + rainfrog_test \ + rainfrog --url sqlite:///rainforg.sqlite3 + sleep 5 # wait for container + - name: check sqlite container status + run: | + container_status=$(docker ps -f name=rainfrog_test_sqlite --format "{{.Status}}") + if [[ "$container_status" == "Up"* ]]; then + echo "container started" + else + echo "container did not start" + echo "logs: " + docker logs -t rainfrog_test_sqlite exit 1 fi diff --git a/Dockerfile b/Dockerfile index f9164ef..110251a 100644 --- a/Dockerfile +++ b/Dockerfile @@ 
-24,4 +24,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ CMD pidof rainfrog || exit 1 # Command to construct the full connection options using environment variables -CMD ["bash", "-c", "rainfrog --username=$username --password=$password --host=$hostname --port=$db_port --database=$db_name --driver=$driver"] +CMD ["bash", "-c", "rainfrog --username=$username --password=$password --host=$hostname --port=$db_port --database=$db_name --driver=$db_driver"] diff --git a/README.md b/README.md index 24fbdbd..cf68ac5 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ of the options as environment variables: ```sh docker run --platform linux/amd64 -it --rm --name rainfrog \ --add-host host.docker.internal:host-gateway \ - -e driver="db_driver" \ + -e db_driver="db_driver" \ -e username="" \ -e password="" \ -e hostname="host.docker.internal" \ From 2c32fb804fea48ce9b6b7078f9ae875e472fc6c9 Mon Sep 17 00:00:00 2001 From: achristmascarl Date: Fri, 27 Sep 2024 22:08:06 -0400 Subject: [PATCH 14/16] fix ci --- .github/workflows/ci.yml | 4 ++-- README.md | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d83625..db712be 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,7 +72,7 @@ jobs: run: | docker run -dit --name rainfrog_test_postgres \ --add-host host.docker.internal:host-gateway \ - -e driver="postgres" \ + -e db_driver="postgres" \ -e username="root" \ -e password="password" \ -e hostname="host.docker.internal" \ @@ -94,7 +94,7 @@ jobs: run: | docker run -dit --name rainfrog_test_mysql \ --add-host host.docker.internal:host-gateway \ - -e driver="mysql" \ + -e db_driver="mysql" \ -e username="root" \ -e password="password" \ -e hostname="host.docker.internal" \ diff --git a/README.md b/README.md index cf68ac5..89e5eb1 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,8 @@ pgadmin/dbeaver. rainfrog has mainly been tested with postgres, and postgres will be the primary database targeted. **mysql and sqlite are also supported, but they are -currently unstable**; use with caution! +currently unstable**; use with caution, and check out the +[known issues](#known-issues-and-limitations) section for things to look out for! ## disclaimer @@ -354,6 +355,8 @@ features - in mysql, DROP statements cannot be rolled back, even if they are part of a transaction; see which will address this issue +- in sqlite, `EXPLAIN QUERY PLAN` does not work due to an issue with the + sql parser; see - for x11 and wayland, yanking does not copy to the system clipboard, only to the query editor's buffer. 
see - in addition to the experience being subpar if the terminal window is too From 183aa34e8fc8c65f3c24e970f4f62a272da92744 Mon Sep 17 00:00:00 2001 From: achristmascarl Date: Fri, 27 Sep 2024 22:21:20 -0400 Subject: [PATCH 15/16] dbg ci --- .github/workflows/ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index db712be..ed4b093 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,9 +114,11 @@ jobs: fi - name: docker run sqlite run: | + pwd + ls docker run -dit --name rainfrog_test_sqlite \ --add-host host.docker.internal:host-gateway \ - -v $(pwd)/dev/rainfrog.sqlite3:/rainfrog.sqlite3 \ + -v /home/runner/work/rainfrog/rainfrog/dev/rainfrog.sqlite3:/rainfrog.sqlite3 \ rainfrog_test \ rainfrog --url sqlite:///rainforg.sqlite3 sleep 5 # wait for container From 8cc99d0fc349234f7e00a5119a4f1e9673e79a15 Mon Sep 17 00:00:00 2001 From: achristmascarl Date: Fri, 27 Sep 2024 22:32:34 -0400 Subject: [PATCH 16/16] ofc its a typo --- .github/workflows/ci.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed4b093..772b70c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,13 +114,11 @@ jobs: fi - name: docker run sqlite run: | - pwd - ls docker run -dit --name rainfrog_test_sqlite \ --add-host host.docker.internal:host-gateway \ -v /home/runner/work/rainfrog/rainfrog/dev/rainfrog.sqlite3:/rainfrog.sqlite3 \ rainfrog_test \ - rainfrog --url sqlite:///rainforg.sqlite3 + rainfrog --url sqlite:///rainfrog.sqlite3 sleep 5 # wait for container - name: check sqlite container status run: |