Skip to content

Commit

Permalink
feat: automatically init history store when record sync is enabled (#…
Browse files Browse the repository at this point in the history
…1634)

* add support for getting the total length of a store

* tidy up sync

* auto call init if history is ahead

* fix import order, key regen

* fix import order, key regen

* do not delete key when user deletes account

* message output

* remote init store command; this is now automatic

* should probs make that function return u64 at some point
  • Loading branch information
ellie authored Jan 29, 2024
1 parent 15bad15 commit 366b8ea
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 68 deletions.
30 changes: 29 additions & 1 deletion atuin-client/src/history/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use eyre::{bail, eyre, Result};
use rmp::decode::Bytes;

use crate::{
database::Database,
database::{self, Database},
record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store},
};
use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx};
Expand Down Expand Up @@ -255,6 +255,34 @@ impl HistoryStore {

Ok(ret)
}

pub async fn init_store(&self, context: database::Context, db: &impl Database) -> Result<()> {
println!("Importing all history.db data into records.db");

println!("Fetching history from old database");
let history = db.list(&[], &context, None, false, true).await?;

println!("Fetching history already in store");
let store_ids = self.history_ids().await?;

for i in history {
println!("loaded {}", i.id);

if store_ids.contains(&i.id) {
println!("skipping {} - already exists", i.id);
continue;
}

if i.deleted_at.is_some() {
self.push(i.clone()).await?;
self.delete(i.id).await?;
} else {
self.push(i).await?;
}
}

Ok(())
}
}

#[cfg(test)]
Expand Down
32 changes: 32 additions & 0 deletions atuin-client/src/record/sqlite_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ impl Store for SqliteStore {
self.idx(host, tag, 0).await
}

async fn len_tag(&self, tag: &str) -> Result<u64> {
let res: Result<(i64,), sqlx::Error> =
sqlx::query_as("select count(*) from store where tag=?1")
.bind(tag)
.fetch_one(&self.pool)
.await;
match res {
Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
Ok(v) => Ok(v.0 as u64),
}
}

async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
let last = self.last(host, tag).await?;

Expand Down Expand Up @@ -342,6 +354,20 @@ mod tests {
assert_eq!(len, 1, "expected length of 1 after insert");
}

#[tokio::test]
async fn len_tag() {
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
let record = test_record();
db.push(&record).await.unwrap();

let len = db
.len_tag(record.tag.as_str())
.await
.expect("failed to get store len");

assert_eq!(len, 1, "expected length of 1 after insert");
}

#[tokio::test]
async fn len_different_tags() {
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
Expand Down Expand Up @@ -379,6 +405,12 @@ mod tests {
100,
"failed to insert 100 records"
);

assert_eq!(
db.len_tag(tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
}

#[tokio::test]
Expand Down
2 changes: 2 additions & 0 deletions atuin-client/src/record/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ pub trait Store {
) -> Result<()>;

async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;

async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
async fn len_tag(&self, tag: &str) -> Result<u64>;

async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
Expand Down
49 changes: 37 additions & 12 deletions atuin-client/src/record/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ pub enum SyncError {
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
LocalAheadOtherHost,

#[error("an issue with the local database occured")]
LocalStoreError,
#[error("an issue with the local database occured: {msg:?}")]
LocalStoreError { msg: String },

#[error("something has gone wrong with the sync logic: {msg:?}")]
SyncLogicError { msg: String },

#[error("a request to the sync server failed")]
RemoteRequestError,
#[error("operational error: {msg:?}")]
OperationalError { msg: String },

#[error("a request to the sync server failed: {msg:?}")]
RemoteRequestError { msg: String },
}

#[derive(Debug, Eq, PartialEq)]
Expand All @@ -45,16 +48,27 @@ pub enum Operation {
},
}

pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
pub async fn diff(
settings: &Settings,
store: &impl Store,
) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
settings.network_connect_timeout,
settings.network_timeout,
)?;
)
.map_err(|e| SyncError::OperationalError { msg: e.to_string() })?;

let local_index = store
.status()
.await
.map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;

let local_index = store.status().await?;
let remote_index = client.record_status().await?;
let remote_index = client
.record_status()
.await
.map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;

let diff = local_index.diff(&remote_index);

Expand Down Expand Up @@ -166,13 +180,13 @@ async fn sync_upload(
.map_err(|e| {
error!("failed to read upload page: {e:?}");

SyncError::LocalStoreError
SyncError::LocalStoreError { msg: e.to_string() }
})?;

client.post_records(&page).await.map_err(|e| {
error!("failed to post records: {e:?}");

SyncError::RemoteRequestError
SyncError::RemoteRequestError { msg: e.to_string() }
})?;

println!(
Expand Down Expand Up @@ -217,12 +231,12 @@ async fn sync_download(
let page = client
.next_records(host, tag.clone(), local + progress, download_page_size)
.await
.map_err(|_| SyncError::RemoteRequestError)?;
.map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;

store
.push_batch(page.iter())
.await
.map_err(|_| SyncError::LocalStoreError)?;
.map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;

println!(
"downloaded {} records from remote, progress {}/{}",
Expand Down Expand Up @@ -283,6 +297,17 @@ pub async fn sync_remote(
Ok((uploaded, downloaded))
}

pub async fn sync(
settings: &Settings,
store: &impl Store,
) -> Result<(i64, Vec<RecordId>), SyncError> {
let (diff, _) = diff(settings, store).await?;
let operations = operations(diff, store).await?;
let (uploaded, downloaded) = sync_remote(operations, store, settings).await?;

Ok((uploaded, downloaded))
}

#[cfg(test)]
mod tests {
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
Expand Down
5 changes: 0 additions & 5 deletions atuin/src/command/client/account/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::path::PathBuf;

pub async fn run(settings: &Settings) -> Result<()> {
let session_path = settings.session_path.as_str();
let key_path = settings.key_path.as_str();

if !PathBuf::from(session_path).exists() {
bail!("You are not logged in");
Expand All @@ -25,10 +24,6 @@ pub async fn run(settings: &Settings) -> Result<()> {
remove_file(PathBuf::from(session_path))?;
}

if PathBuf::from(key_path).exists() {
remove_file(PathBuf::from(key_path))?;
}

println!("Your account is deleted");

Ok(())
Expand Down
3 changes: 1 addition & 2 deletions atuin/src/command/client/account/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ pub async fn run(
let mut file = File::create(path).await?;
file.write_all(session.session.as_bytes()).await?;

// Create a new key, and save it to disk
let _key = atuin_client::encryption::new_key(settings)?;
let _key = atuin_client::encryption::load_key(settings)?;

Ok(())
}
43 changes: 1 addition & 42 deletions atuin/src/command/client/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ pub enum Cmd {
#[arg(long, short)]
format: Option<String>,
},

/// Import all old history.db data into the record store. Do not run more than once, and do not
/// run unless you know what you're doing (or the docs ask you to)
InitStore,
}

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -321,10 +317,7 @@ impl Cmd {
#[cfg(feature = "sync")]
{
if settings.sync.records {
let (diff, _) = record::sync::diff(settings, &store).await?;
let operations = record::sync::operations(diff, &store).await?;
let (_, downloaded) =
record::sync::sync_remote(operations, &store, settings).await?;
let (_, downloaded) = record::sync::sync(settings, &store).await?;

history_store.incremental_build(db, &downloaded).await?;
} else {
Expand Down Expand Up @@ -380,38 +373,6 @@ impl Cmd {
Ok(())
}

async fn init_store(
context: atuin_client::database::Context,
db: &impl Database,
store: HistoryStore,
) -> Result<()> {
println!("Importing all history.db data into records.db");

println!("Fetching history from old database");
let history = db.list(&[], &context, None, false, true).await?;

println!("Fetching history already in store");
let store_ids = store.history_ids().await?;

for i in history {
println!("loaded {}", i.id);

if store_ids.contains(&i.id) {
println!("skipping {} - already exists", i.id);
continue;
}

if i.deleted_at.is_some() {
store.push(i.clone()).await?;
store.delete(i.id).await?;
} else {
store.push(i).await?;
}
}

Ok(())
}

pub async fn run(
self,
settings: &Settings,
Expand Down Expand Up @@ -468,8 +429,6 @@ impl Cmd {

Ok(())
}

Self::InitStore => Self::init_store(context, db, history_store).await,
}
}
}
24 changes: 18 additions & 6 deletions atuin/src/command/client/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use clap::Subcommand;
use eyre::{Result, WrapErr};

use atuin_client::{
database::Database,
database::{current_context, Database},
encryption,
history::store::HistoryStore,
record::{sqlite_store::SqliteStore, sync},
record::{sqlite_store::SqliteStore, store::Store, sync},
settings::Settings,
};

Expand Down Expand Up @@ -80,17 +80,29 @@ async fn run(
store: SqliteStore,
) -> Result<()> {
if settings.sync.records {
let (diff, _) = sync::diff(settings, &store).await?;
let operations = sync::operations(diff, &store).await?;
let (uploaded, downloaded) = sync::sync_remote(operations, &store, settings).await?;

let encryption_key: [u8; 32] = encryption::load_key(settings)
.context("could not load encryption key")?
.into();

let host_id = Settings::host_id().expect("failed to get host_id");
let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);

let history_length = db.history_count(true).await?;
let store_history_length = store.len_tag("history").await?;

#[allow(clippy::cast_sign_loss)]
if history_length as u64 > store_history_length {
println!("History DB is longer than history record store");
println!("This happens when you used Atuin pre-record-store");

let context = current_context();
history_store.init_store(context, db).await?;

println!("\n");
}

let (uploaded, downloaded) = sync::sync(settings, &store).await?;

history_store.incremental_build(db, &downloaded).await?;

println!("{uploaded}/{} up/down to record store", downloaded.len());
Expand Down

0 comments on commit 366b8ea

Please sign in to comment.