Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for mysql and sqlite #100

Merged
merged 18 commits into from
Sep 28, 2024
Merged
10 changes: 1 addition & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,7 @@ ratatui = { version = "0.28.0", features = [
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
signal-hook = "0.3.17"
sqlx = { version = "0.8.1", features = [
"runtime-tokio",
"tls-rustls",
"postgres",
"uuid",
"chrono",
"json",
"ipnetwork",
] }
sqlx = { version = "0.8.1", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "chrono", "json", "ipnetwork", "mysql", "sqlite"] }
strip-ansi-escapes = "0.2.0"
strum = { version = "0.26.1", features = ["derive"] }
tokio = { version = "1.32.0", features = ["full"] }
Expand Down
78 changes: 43 additions & 35 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ use ratatui::{
Frame,
};
use serde::{Deserialize, Serialize};
use sqlparser::{ast::Statement, keywords::DELETE};
use sqlparser::{
ast::Statement,
dialect::Dialect,
keywords::{DELETE, NAME},
};
use sqlx::{
postgres::{PgConnectOptions, Postgres},
Either, Transaction,
Connection, Database, Either, Executor, Pool, Transaction,
};
use tokio::{
sync::{
Expand All @@ -36,16 +40,16 @@ use crate::{
Component,
},
config::Config,
database::{self, statement_type_string, DbError, DbPool, Rows},
database::{self, get_dialect, statement_type_string, DatabaseQueries, DbError, DbPool, Rows},
focus::Focus,
tui,
ui::center,
};

pub enum DbTask<'a> {
pub enum DbTask<'a, DB: sqlx::Database> {
Query(tokio::task::JoinHandle<QueryResultsWithMetadata>),
TxStart(tokio::task::JoinHandle<(QueryResultsWithMetadata, Transaction<'a, Postgres>)>),
TxPending(Transaction<'a, Postgres>, QueryResultsWithMetadata),
TxStart(tokio::task::JoinHandle<(QueryResultsWithMetadata, Transaction<'a, DB>)>),
TxPending(Transaction<'a, DB>, QueryResultsWithMetadata),
TxCommit(tokio::task::JoinHandle<QueryResultsWithMetadata>),
}

Expand All @@ -54,20 +58,21 @@ pub struct HistoryEntry {
pub timestamp: chrono::DateTime<chrono::Local>,
}

pub struct AppState<'a> {
pub connection_opts: PgConnectOptions,
pub struct AppState<'a, DB: Database> {
pub connection_opts: <DB::Connection as Connection>::Options,
pub dialect: Arc<dyn Dialect + Send + Sync>,
pub focus: Focus,
pub query_task: Option<DbTask<'a>>,
pub query_task: Option<DbTask<'a, DB>>,
pub history: Vec<HistoryEntry>,
pub last_query_start: Option<chrono::DateTime<chrono::Utc>>,
pub last_query_end: Option<chrono::DateTime<chrono::Utc>>,
}

pub struct Components<'a> {
pub menu: Box<dyn MenuComponent<'a>>,
pub editor: Box<dyn Component>,
pub history: Box<dyn Component>,
pub data: Box<dyn DataComponent<'a>>,
pub struct Components<'a, DB> {
pub menu: Box<dyn MenuComponent<'a, DB>>,
pub editor: Box<dyn Component<DB>>,
pub history: Box<dyn Component<DB>>,
pub data: Box<dyn DataComponent<'a, DB>>,
}

#[derive(Debug)]
Expand All @@ -76,20 +81,29 @@ pub struct QueryResultsWithMetadata {
pub statement_type: Statement,
}

pub struct App<'a> {
pub struct App<'a, DB: sqlx::Database> {
pub mouse_mode_override: Option<bool>,
pub config: Config,
pub components: Components<'static>,
pub components: Components<'static, DB>,
pub should_quit: bool,
pub last_tick_key_events: Vec<KeyEvent>,
pub last_frame_mouse_event: Option<MouseEvent>,
pub pool: Option<DbPool>,
pub state: AppState<'a>,
pub pool: Option<database::DbPool<DB>>,
pub state: AppState<'a, DB>,
last_focused_tab: Focus,
}

impl<'a> App<'a> {
pub fn new(connection_opts: PgConnectOptions, mouse_mode_override: Option<bool>) -> Result<Self> {
impl<'a, DB> App<'a, DB>
where
DB: Database + database::ValueParser + database::DatabaseQueries,
DB::QueryResult: database::HasRowsAffected,
for<'c> <DB as sqlx::Database>::Arguments<'c>: sqlx::IntoArguments<'c, DB>,
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
{
pub fn new(
connection_opts: <DB::Connection as Connection>::Options,
mouse_mode_override: Option<bool>,
) -> Result<Self> {
let focus = Focus::Menu;
let menu = Menu::new();
let editor = Editor::new();
Expand All @@ -111,6 +125,7 @@ impl<'a> App<'a> {
pool: None,
state: AppState {
connection_opts,
dialect: get_dialect(DB::NAME),
focus,
query_task: None,
history: vec![],
Expand All @@ -135,7 +150,7 @@ impl<'a> App<'a> {
pub async fn run(&mut self) -> Result<()> {
let (action_tx, mut action_rx) = mpsc::unbounded_channel();
let connection_opts = self.state.connection_opts.clone();
let pool = database::init_pool(connection_opts).await?;
let pool = database::init_pool::<DB>(connection_opts).await?;
log::info!("{pool:?}");
self.pool = Some(pool);

Expand Down Expand Up @@ -350,36 +365,28 @@ impl<'a> App<'a> {
Action::LoadMenu => {
log::info!("LoadMenu");
if let Some(pool) = &self.pool {
let results = database::query(
"select table_schema, table_name
from information_schema.tables
where table_schema != 'pg_catalog'
and table_schema != 'information_schema'
group by table_schema, table_name
order by table_schema, table_name asc"
.to_owned(),
pool,
)
.await;
let results = database::query(DB::preview_tables_query(), self.state.dialect.as_ref(), pool).await;
self.components.menu.set_table_list(Some(results));
}
},
Action::Query(query_lines) => {
let query_string = query_lines.clone().join(" \n");
if !query_string.is_empty() {
self.add_to_history(query_lines.clone());
let first_query = database::get_first_query(query_string.clone());
let first_query = database::get_first_query(query_string.clone(), self.state.dialect.as_ref());
let should_use_tx = first_query
.map(|(_, statement_type)| (database::should_use_tx(statement_type.clone()), statement_type));
let action_tx = action_tx.clone();
if let Some(pool) = &self.pool {
let pool = pool.clone();
let dialect = self.state.dialect.clone();
match should_use_tx {
Ok((true, statement_type)) => {
self.components.data.set_loading();
let tx = pool.begin().await?;
self.state.query_task = Some(DbTask::TxStart(tokio::spawn(async move {
let (results, tx) = database::query_with_tx(tx, query_string.clone()).await;
let (results, tx) =
database::query_with_tx::<DB>(tx, dialect.as_ref(), query_string.clone()).await;
match results {
Ok(Either::Left(rows_affected)) => {
log::info!("{:?} rows affected", rows_affected);
Expand All @@ -406,8 +413,9 @@ impl<'a> App<'a> {
},
Ok((false, statement_type)) => {
self.components.data.set_loading();
let dialect = self.state.dialect.clone();
self.state.query_task = Some(DbTask::Query(tokio::spawn(async move {
let results = database::query(query_string.clone(), &pool).await;
let results = database::query(query_string.clone(), dialect.as_ref(), &pool).await;
match &results {
Ok(rows) => {
log::info!("{:?} rows, {:?} affected", rows.rows.len(), rows.rows_affected);
Expand Down
47 changes: 46 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use std::path::PathBuf;
use std::{
io::{self, Write},
path::PathBuf,
str::FromStr,
};

use clap::Parser;
use color_eyre::eyre::{self, Result};

use crate::utils::version;

Expand Down Expand Up @@ -37,4 +42,44 @@ pub struct Cli {

#[arg(long = "database", value_name = "DATABASE", help = "Name of database for connection (ex. postgres)")]
pub database: Option<String>,

#[arg(long = "driver", value_name = "DRIVER", help = "Driver for database connection (ex. postgres)")]
pub driver: Option<Driver>,
}

#[derive(Parser, Debug, Clone)]
pub enum Driver {
Postgres,
Mysql,
Sqlite,
}
achristmascarl marked this conversation as resolved.
Show resolved Hide resolved

impl FromStr for Driver {
type Err = eyre::Report;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"postgres" | "postgresql" => Ok(Driver::Postgres),
"mysql" => Ok(Driver::Mysql),
"sqlite" => Ok(Driver::Sqlite),
_ => Err(eyre::Report::msg("Invalid driver")),
}
}
}

pub fn extract_driver_from_url(url: &str) -> Result<Driver> {
let url = url.trim();
if let Some(pos) = url.find("://") {
url[..pos].to_lowercase().parse()
} else {
Err(eyre::Report::msg("Invalid connection URL format"))
}
}

pub fn prompt_for_driver() -> Result<Driver> {
let mut driver = String::new();
print!("Database driver (postgres, mysql, sqlite): ");
io::stdout().flush()?;
io::stdin().read_line(&mut driver)?;
driver.trim().to_lowercase().parse()
}
13 changes: 6 additions & 7 deletions src/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ pub mod editor;
pub mod history;
pub mod menu;
pub mod scroll_table;

pub trait Component {
pub trait Component<DB: sqlx::Database> {
/// Register an action handler that can send actions for processing if necessary.
///
/// # Arguments
Expand Down Expand Up @@ -68,7 +67,7 @@ pub trait Component {
&mut self,
event: Option<Event>,
last_tick_key_events: Vec<KeyEvent>,
app_state: &AppState,
app_state: &AppState<'_, DB>,
) -> Result<Option<Action>> {
let r = match event {
Some(Event::Key(key_event)) => self.handle_key_events(key_event, app_state)?,
Expand All @@ -87,7 +86,7 @@ pub trait Component {
///
/// * `Result<Option<Action>>` - An action to be processed or none.
#[allow(unused_variables)]
fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState) -> Result<Option<Action>> {
fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState<'_, DB>) -> Result<Option<Action>> {
Ok(None)
}
/// Handle mouse events and produce actions if necessary.
Expand All @@ -100,7 +99,7 @@ pub trait Component {
///
/// * `Result<Option<Action>>` - An action to be processed or none.
#[allow(unused_variables)]
fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState) -> Result<Option<Action>> {
fn handle_mouse_events(&mut self, mouse: MouseEvent, app_state: &AppState<'_, DB>) -> Result<Option<Action>> {
Ok(None)
}
/// Update the state of the component based on a received action. (REQUIRED)
Expand All @@ -113,7 +112,7 @@ pub trait Component {
///
/// * `Result<Option<Action>>` - An action to be processed or none.
#[allow(unused_variables)]
fn update(&mut self, action: Action, app_state: &AppState) -> Result<Option<Action>> {
fn update(&mut self, action: Action, app_state: &AppState<'_, DB>) -> Result<Option<Action>> {
Ok(None)
}
/// Render the component on the screen. (REQUIRED)
Expand All @@ -126,5 +125,5 @@ pub trait Component {
/// # Returns
///
/// * `Result<()>` - An Ok result or an error.
fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()>;
fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()>;
}
17 changes: 9 additions & 8 deletions src/components/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crossterm::{
use ratatui::{prelude::*, symbols::scrollbar, widgets::*};
use serde::{Deserialize, Serialize};
use sqlparser::ast::Statement;
use sqlx::{Database, Executor, Pool};
use tokio::sync::{mpsc::UnboundedSender, Mutex};
use tui_textarea::{Input, Key};

Expand All @@ -20,7 +21,7 @@ use crate::{
Component,
},
config::{Config, KeyBindings},
database::{get_headers, parse_value, row_to_json, row_to_vec, statement_type_string, DbError, Rows},
database::{get_headers, row_to_json, row_to_vec, statement_type_string, DbError, Rows},
focus::Focus,
tui::Event,
};
Expand Down Expand Up @@ -51,8 +52,8 @@ pub trait SettableDataTable<'a> {
fn set_cancelled(&mut self);
}

pub trait DataComponent<'a>: Component + SettableDataTable<'a> {}
impl<'a, T> DataComponent<'a> for T where T: Component + SettableDataTable<'a>
pub trait DataComponent<'a, DB: sqlx::Database>: Component<DB> + SettableDataTable<'a> {}
impl<'a, T, DB: sqlx::Database> DataComponent<'a, DB> for T where T: Component<DB> + SettableDataTable<'a>
{
}

Expand Down Expand Up @@ -231,7 +232,7 @@ impl<'a> SettableDataTable<'a> for Data<'a> {
}
}

impl<'a> Component for Data<'a> {
impl<'a, DB: Database> Component<DB> for Data<'a> {
fn register_action_handler(&mut self, tx: UnboundedSender<Action>) -> Result<()> {
self.command_tx = Some(tx);
Ok(())
Expand All @@ -245,7 +246,7 @@ impl<'a> Component for Data<'a> {
fn handle_mouse_events(
&mut self,
mouse: crossterm::event::MouseEvent,
app_state: &AppState,
app_state: &AppState<'_, DB>,
) -> Result<Option<Action>> {
if app_state.focus != Focus::Data {
return Ok(None);
Expand All @@ -268,7 +269,7 @@ impl<'a> Component for Data<'a> {
Ok(None)
}

fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState) -> Result<Option<Action>> {
fn handle_key_events(&mut self, key: KeyEvent, app_state: &AppState<'_, DB>) -> Result<Option<Action>> {
if app_state.focus != Focus::Data {
return Ok(None);
}
Expand Down Expand Up @@ -372,14 +373,14 @@ impl<'a> Component for Data<'a> {
Ok(None)
}

fn update(&mut self, action: Action, app_state: &AppState) -> Result<Option<Action>> {
fn update(&mut self, action: Action, app_state: &AppState<'_, DB>) -> Result<Option<Action>> {
if let Action::Query(query) = action {
self.scrollable.reset_scroll();
}
Ok(None)
}

fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState) -> Result<()> {
fn draw(&mut self, f: &mut Frame<'_>, area: Rect, app_state: &AppState<'_, DB>) -> Result<()> {
let focused = app_state.focus == Focus::Data;

let mut block = Block::default().borders(Borders::ALL).border_style(if focused {
Expand Down
Loading