From 366b8ea97bbe36ad5e3dd8d45f1e787ee2a7f223 Mon Sep 17 00:00:00 2001 From: Ellie Huxtable Date: Mon, 29 Jan 2024 16:38:24 +0000 Subject: [PATCH] feat: automatically init history store when record sync is enabled (#1634) * add support for getting the total length of a store * tidy up sync * auto call init if history is ahead * fix import order, key regen * fix import order, key regen * do not delete key when user deletes account * message output * remote init store command; this is now automatic * should probs make that function return u64 at some point --- atuin-client/src/history/store.rs | 30 +++++++++++- atuin-client/src/record/sqlite_store.rs | 32 +++++++++++++ atuin-client/src/record/store.rs | 2 + atuin-client/src/record/sync.rs | 49 +++++++++++++++----- atuin/src/command/client/account/delete.rs | 5 -- atuin/src/command/client/account/register.rs | 3 +- atuin/src/command/client/history.rs | 43 +---------------- atuin/src/command/client/sync.rs | 24 +++++++--- 8 files changed, 120 insertions(+), 68 deletions(-) diff --git a/atuin-client/src/history/store.rs b/atuin-client/src/history/store.rs index 442da45d6a4..0a2a231230a 100644 --- a/atuin-client/src/history/store.rs +++ b/atuin-client/src/history/store.rs @@ -4,7 +4,7 @@ use eyre::{bail, eyre, Result}; use rmp::decode::Bytes; use crate::{ - database::Database, + database::{self, Database}, record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, }; use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; @@ -255,6 +255,34 @@ impl HistoryStore { Ok(ret) } + + pub async fn init_store(&self, context: database::Context, db: &impl Database) -> Result<()> { + println!("Importing all history.db data into records.db"); + + println!("Fetching history from old database"); + let history = db.list(&[], &context, None, false, true).await?; + + println!("Fetching history already in store"); + let store_ids = self.history_ids().await?; + + for i in history { + println!("loaded {}", i.id); + + if store_ids.contains(&i.id) { + println!("skipping {} - already exists", i.id); + continue; + } + + if i.deleted_at.is_some() { + self.push(i.clone()).await?; + self.delete(i.id).await?; + } else { + self.push(i).await?; + } + } + + Ok(()) + } } #[cfg(test)] diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs index 50f30d76aaf..e9d7ff59bed 100644 --- a/atuin-client/src/record/sqlite_store.rs +++ b/atuin-client/src/record/sqlite_store.rs @@ -155,6 +155,18 @@ impl Store for SqliteStore { self.idx(host, tag, 0).await } + async fn len_tag(&self, tag: &str) -> Result { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + async fn len(&self, host: HostId, tag: &str) -> Result { let last = self.last(host, tag).await?; @@ -342,6 +354,20 @@ mod tests { assert_eq!(len, 1, "expected length of 1 after insert"); } + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", 0.1).await.unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + #[tokio::test] async fn len_different_tags() { let db = SqliteStore::new(":memory:", 0.1).await.unwrap(); @@ -379,6 +405,12 @@ mod tests { 100, "failed to insert 100 records" ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); } #[tokio::test] diff --git a/atuin-client/src/record/store.rs b/atuin-client/src/record/store.rs index efe2eb4a70c..40c1224b7b4 100644 --- a/atuin-client/src/record/store.rs +++ b/atuin-client/src/record/store.rs @@ -21,7 +21,9 @@ pub trait Store { ) -> Result<()>; async fn get(&self, id: RecordId) -> Result>; + async fn len(&self, host: HostId, tag: &str) -> Result; + async fn len_tag(&self, tag: &str) -> Result; async fn last(&self, host: HostId, tag: &str) -> Result>>; async fn first(&self, host: HostId, tag: &str) -> Result>>; diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs index 97152f79ad1..eca0c930596 100644 --- a/atuin-client/src/record/sync.rs +++ b/atuin-client/src/record/sync.rs @@ -14,14 +14,17 @@ pub enum SyncError { #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] LocalAheadOtherHost, - #[error("an issue with the local database occured")] - LocalStoreError, + #[error("an issue with the local database occured: {msg:?}")] + LocalStoreError { msg: String }, #[error("something has gone wrong with the sync logic: {msg:?}")] SyncLogicError { msg: String }, - #[error("a request to the sync server failed")] - RemoteRequestError, + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, } #[derive(Debug, Eq, PartialEq)] @@ -45,16 +48,27 @@ pub enum Operation { }, } -pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec, RecordStatus)> { +pub async fn diff( + settings: &Settings, + store: &impl Store, +) -> Result<(Vec, RecordStatus), SyncError> { let client = Client::new( &settings.sync_address, &settings.session_token, settings.network_connect_timeout, settings.network_timeout, - )?; + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?; + + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; - let local_index = store.status().await?; - let remote_index = client.record_status().await?; + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; let diff = local_index.diff(&remote_index); @@ -166,13 +180,13 @@ async fn sync_upload( .map_err(|e| { error!("failed to read upload page: {e:?}"); - SyncError::LocalStoreError + SyncError::LocalStoreError { msg: e.to_string() } })?; client.post_records(&page).await.map_err(|e| { error!("failed to post records: {e:?}"); - SyncError::RemoteRequestError + SyncError::RemoteRequestError { msg: e.to_string() } })?; println!( @@ -217,12 +231,12 @@ async fn sync_download( let page = client .next_records(host, tag.clone(), local + progress, download_page_size) .await - .map_err(|_| SyncError::RemoteRequestError)?; + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; store .push_batch(page.iter()) .await - .map_err(|_| SyncError::LocalStoreError)?; + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; println!( "downloaded {} records from remote, progress {}/{}", @@ -283,6 +297,17 @@ pub async fn sync_remote( Ok((uploaded, downloaded)) } +pub async fn sync( + settings: &Settings, + store: &impl Store, +) -> Result<(i64, Vec), SyncError> { + let (diff, _) = diff(settings, store).await?; + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(operations, store, settings).await?; + + Ok((uploaded, downloaded)) +} + #[cfg(test)] mod tests { use atuin_common::record::{Diff, EncryptedData, HostId, Record}; diff --git a/atuin/src/command/client/account/delete.rs b/atuin/src/command/client/account/delete.rs index 6a4b1406251..3591c6f398c 100644 --- a/atuin/src/command/client/account/delete.rs +++ b/atuin/src/command/client/account/delete.rs @@ -5,7 +5,6 @@ use std::path::PathBuf; pub async fn run(settings: &Settings) -> Result<()> { let session_path = settings.session_path.as_str(); - let key_path = settings.key_path.as_str(); if !PathBuf::from(session_path).exists() { bail!("You are not logged in"); @@ -25,10 +24,6 @@ pub async fn run(settings: &Settings) -> Result<()> { remove_file(PathBuf::from(session_path))?; } - if PathBuf::from(key_path).exists() { - remove_file(PathBuf::from(key_path))?; - } - println!("Your account is deleted"); Ok(()) diff --git a/atuin/src/command/client/account/register.rs b/atuin/src/command/client/account/register.rs index 0523dced860..96b7d7d6c1a 100644 --- a/atuin/src/command/client/account/register.rs +++ b/atuin/src/command/client/account/register.rs @@ -49,8 +49,7 @@ pub async fn run( let mut file = File::create(path).await?; file.write_all(session.session.as_bytes()).await?; - // Create a new key, and save it to disk - let _key = atuin_client::encryption::new_key(settings)?; + let _key = atuin_client::encryption::load_key(settings)?; Ok(()) } diff --git a/atuin/src/command/client/history.rs b/atuin/src/command/client/history.rs index e983cc7b79b..18ae17cf35b 100644 --- a/atuin/src/command/client/history.rs +++ b/atuin/src/command/client/history.rs @@ -88,10 +88,6 @@ pub enum Cmd { #[arg(long, short)] format: Option, }, - - /// Import all old history.db data into the record store. Do not run more than once, and do not - /// run unless you know what you're doing (or the docs ask you to) - InitStore, } #[derive(Clone, Copy, Debug)] @@ -321,10 +317,7 @@ impl Cmd { #[cfg(feature = "sync")] { if settings.sync.records { - let (diff, _) = record::sync::diff(settings, &store).await?; - let operations = record::sync::operations(diff, &store).await?; - let (_, downloaded) = - record::sync::sync_remote(operations, &store, settings).await?; + let (_, downloaded) = record::sync::sync(settings, &store).await?; history_store.incremental_build(db, &downloaded).await?; } else { @@ -380,38 +373,6 @@ impl Cmd { Ok(()) } - async fn init_store( - context: atuin_client::database::Context, - db: &impl Database, - store: HistoryStore, - ) -> Result<()> { - println!("Importing all history.db data into records.db"); - - println!("Fetching history from old database"); - let history = db.list(&[], &context, None, false, true).await?; - - println!("Fetching history already in store"); - let store_ids = store.history_ids().await?; - - for i in history { - println!("loaded {}", i.id); - - if store_ids.contains(&i.id) { - println!("skipping {} - already exists", i.id); - continue; - } - - if i.deleted_at.is_some() { - store.push(i.clone()).await?; - store.delete(i.id).await?; - } else { - store.push(i).await?; - } - } - - Ok(()) - } - pub async fn run( self, settings: &Settings, @@ -468,8 +429,6 @@ impl Cmd { Ok(()) } - - Self::InitStore => Self::init_store(context, db, history_store).await, } } } diff --git a/atuin/src/command/client/sync.rs b/atuin/src/command/client/sync.rs index 2e58f07dba1..5b438453477 100644 --- a/atuin/src/command/client/sync.rs +++ b/atuin/src/command/client/sync.rs @@ -2,10 +2,10 @@ use clap::Subcommand; use eyre::{Result, WrapErr}; use atuin_client::{ - database::Database, + database::{current_context, Database}, encryption, history::store::HistoryStore, - record::{sqlite_store::SqliteStore, sync}, + record::{sqlite_store::SqliteStore, store::Store, sync}, settings::Settings, }; @@ -80,10 +80,6 @@ async fn run( store: SqliteStore, ) -> Result<()> { if settings.sync.records { - let (diff, _) = sync::diff(settings, &store).await?; - let operations = sync::operations(diff, &store).await?; - let (uploaded, downloaded) = sync::sync_remote(operations, &store, settings).await?; - let encryption_key: [u8; 32] = encryption::load_key(settings) .context("could not load encryption key")? .into(); @@ -91,6 +87,22 @@ async fn run( let host_id = Settings::host_id().expect("failed to get host_id"); let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + let history_length = db.history_count(true).await?; + let store_history_length = store.len_tag("history").await?; + + #[allow(clippy::cast_sign_loss)] + if history_length as u64 > store_history_length { + println!("History DB is longer than history record store"); + println!("This happens when you used Atuin pre-record-store"); + + let context = current_context(); + history_store.init_store(context, db).await?; + + println!("\n"); + } + + let (uploaded, downloaded) = sync::sync(settings, &store).await?; + history_store.incremental_build(db, &downloaded).await?; println!("{uploaded}/{} up/down to record store", downloaded.len());