Skip to content

Commit

Permalink
Merge pull request #186 from ikatson/postgres
Browse files Browse the repository at this point in the history
[Feature] postgres backend for session persistence
  • Loading branch information
ikatson authored Aug 15, 2024
2 parents 9d39411 + 2871c35 commit 4c0bb08
Show file tree
Hide file tree
Showing 10 changed files with 805 additions and 30 deletions.
595 changes: 576 additions & 19 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ devserver:
--log-file-rust-log=debug,librqbit=trace \
server start /tmp/scratch/

@PHONY: devserver
devserver-postgres:
echo -n '' > /tmp/rqbit-log && cargo run -- \
--log-file /tmp/rqbit-log \
--log-file-rust-log=debug,librqbit=trace \
server start --persistence-config postgres:///rqbit /tmp/scratch/

@PHONY: clean
clean:
rm -rf target
Expand Down
5 changes: 5 additions & 0 deletions crates/librqbit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ default-tls = ["reqwest/default-tls"]
rust-tls = ["reqwest/rustls-tls"]
storage_middleware = ["lru"]
storage_examples = []
postgres = ["sqlx"]

[dependencies]
sqlx = { version = "0.7", features = [
"runtime-tokio",
"postgres",
], optional = true }
bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "2.2.3" }
tracker_comms = { path = "../tracker_comms", default-features = false, package = "librqbit-tracker-comms", version = "1.0.3" }
buffers = { path = "../buffers", package = "librqbit-buffers", version = "3.0.1" }
Expand Down
8 changes: 8 additions & 0 deletions crates/librqbit/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ impl<'a> AddTorrent<'a> {
pub enum SessionPersistenceConfig {
/// The filename for persistence. By default uses an OS-specific folder.
Json { folder: Option<PathBuf> },
#[cfg(feature = "postgres")]
Postgres { connection_string: String },
}

impl SessionPersistenceConfig {
Expand Down Expand Up @@ -494,6 +496,12 @@ impl Session {
.await
.context("error initializing JsonSessionPersistenceStore")?,
)))
},
#[cfg(feature = "postgres")]
Some(SessionPersistenceConfig::Postgres { connection_string }) => {
use crate::session_persistence::postgres::PostgresSessionStorage;
let p = PostgresSessionStorage::new(connection_string).await?;
Ok(Some(Box::new(p)))
}
None => Ok(None),
}
Expand Down
2 changes: 2 additions & 0 deletions crates/librqbit/src/session_persistence/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod json;
#[cfg(feature = "postgres")]
pub mod postgres;

use std::{collections::HashSet, path::PathBuf};

Expand Down
169 changes: 169 additions & 0 deletions crates/librqbit/src/session_persistence/postgres.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use std::path::PathBuf;

use crate::{session::TorrentId, torrent_state::ManagedTorrentHandle};
use anyhow::Context;
use futures::{stream::BoxStream, StreamExt};
use librqbit_core::Id20;
use sqlx::{Pool, Postgres};

use super::{SerializedTorrent, SessionPersistenceStore};

#[derive(Debug)]
pub struct PostgresSessionStorage {
pool: Pool<Postgres>,
}

#[derive(sqlx::FromRow)]
struct TorrentsTableRecord {
id: i32,
info_hash: Vec<u8>,
torrent_bytes: Vec<u8>,
trackers: Vec<String>,
output_folder: String,
only_files: Option<Vec<i32>>,
is_paused: bool,
}

impl TorrentsTableRecord {
fn into_serialized_torrent(self) -> Option<(TorrentId, SerializedTorrent)> {
Some((
self.id as TorrentId,
SerializedTorrent {
info_hash: Id20::from_bytes(&self.info_hash).ok()?,
torrent_bytes: self.torrent_bytes.into(),
trackers: self.trackers.into_iter().collect(),
output_folder: PathBuf::from(self.output_folder),
only_files: self
.only_files
.map(|v| v.into_iter().map(|v| v as usize).collect()),
is_paused: self.is_paused,
},
))
}
}

impl PostgresSessionStorage {
pub async fn new(connection_string: &str) -> anyhow::Result<Self> {
use sqlx::postgres::PgPoolOptions;

let pool = PgPoolOptions::new()
.max_connections(1)
.connect(connection_string)
.await?;

sqlx::query("CREATE SEQUENCE IF NOT EXISTS torrents_id AS integer;")
.execute(&pool)
.await
.context("error executing CREATE SEQUENCE")?;

let create_q = "CREATE TABLE IF NOT EXISTS torrents (
id INTEGER PRIMARY KEY DEFAULT nextval('torrents_id'),
info_hash BYTEA NOT NULL,
torrent_bytes BYTEA NOT NULL,
trackers TEXT[] NOT NULL,
output_folder TEXT NOT NULL,
only_files INTEGER[],
is_paused BOOLEAN NOT NULL
)";
sqlx::query(create_q)
.execute(&pool)
.await
.context("error executing CREATE TABLE")?;

Ok(Self { pool })
}
}

#[async_trait::async_trait]
impl SessionPersistenceStore for PostgresSessionStorage {
async fn next_id(&self) -> anyhow::Result<TorrentId> {
let (id,): (i32,) = sqlx::query_as("SELECT nextval('torrents_id')::int")
.fetch_one(&self.pool)
.await
.context("error executing SELECT nextval")?;
Ok(id as usize)
}

async fn store(&self, id: TorrentId, torrent: &ManagedTorrentHandle) -> anyhow::Result<()> {
let torrent_bytes: &[u8] = &torrent.info().torrent_bytes;
let q = "INSERT INTO torrents (id, info_hash, torrent_bytes, trackers, output_folder, only_files, is_paused)
VALUES($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT(id) DO NOTHING";
sqlx::query(q)
.bind::<i32>(id.try_into()?)
.bind(&torrent.info_hash().0[..])
.bind(torrent_bytes)
.bind(torrent.info().trackers.iter().cloned().collect::<Vec<_>>())
.bind(
torrent
.info()
.options
.output_folder
.to_str()
.context("output_folder")?
.to_owned(),
)
.bind(torrent.only_files().map(|o| {
o.into_iter()
.filter_map(|o| o.try_into().ok())
.collect::<Vec<i32>>()
}))
.bind(torrent.is_paused())
.execute(&self.pool)
.await
.context("error executing INSERT INTO torrents")?;
Ok(())
}

async fn delete(&self, id: TorrentId) -> anyhow::Result<()> {
sqlx::query("DELETE FROM torrents WHERE id = $1")
.bind::<i32>(id.try_into()?)
.execute(&self.pool)
.await
.context("error executing DELETE FROM torrents")?;
Ok(())
}

async fn get(&self, id: TorrentId) -> anyhow::Result<SerializedTorrent> {
let row = sqlx::query_as::<_, TorrentsTableRecord>("SELECT * FROM torrents WHERE id = ?")
.bind::<i32>(id.try_into()?)
.fetch_one(&self.pool)
.await
.context("error executing SELECT * FROM torrents")?;
row.into_serialized_torrent()
.context("bug")
.map(|(_, st)| st)
}

async fn update_metadata(
&self,
id: TorrentId,
torrent: &ManagedTorrentHandle,
) -> anyhow::Result<()> {
sqlx::query("UPDATE torrents SET only_files = $1, is_paused = $2 WHERE id = $3")
.bind(torrent.only_files().map(|v| {
v.into_iter()
.filter_map(|f| f.try_into().ok())
.collect::<Vec<i32>>()
}))
.bind(torrent.is_paused())
.bind::<i32>(id.try_into()?)
.execute(&self.pool)
.await
.context("error executing UPDATE torrents")?;
Ok(())
}

async fn stream_all(
&self,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<(TorrentId, SerializedTorrent)>>> {
let torrents = sqlx::query_as::<_, TorrentsTableRecord>("SELECT * FROM torrents")
.fetch_all(&self.pool)
.await
.context("error executing SELECT * FROM torrents")?
.into_iter()
.filter_map(TorrentsTableRecord::into_serialized_torrent)
.map(Ok);
Ok(futures::stream::iter(torrents).boxed())
}
}
4 changes: 4 additions & 0 deletions crates/librqbit/src/torrent_state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@ impl ManagedTorrent {
}
}

pub fn is_paused(&self) -> bool {
self.with_state(|s| matches!(s, ManagedTorrentState::Paused(..)))
}

/// Pause the torrent if it's live.
pub(crate) fn pause(&self) -> anyhow::Result<()> {
let mut g = self.locked.write();
Expand Down
15 changes: 10 additions & 5 deletions crates/librqbit_core/src/hash_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ impl<const N: usize> Id<N> {
hex::encode(self.0)
}

pub fn from_bytes(b: &[u8]) -> anyhow::Result<Self> {
let mut v = [0u8; N];
if b.len() != N {
anyhow::bail!("buffer length must be {}, but it's {}", N, b.len());
}
v.copy_from_slice(b);
Ok(Id(v))
}

pub fn distance(&self, other: &Id<N>) -> Id<N> {
let mut xor = [0u8; N];
for (idx, (s, o)) in self
Expand Down Expand Up @@ -81,11 +90,7 @@ impl<const N: usize> FromStr for Id<N> {
Ok(Id(out))
}
Err(err) => {
anyhow::bail!(
"fail to decode base32 string {}: {}",
s,
err
)
anyhow::bail!("fail to decode base32 string {}: {}", s, err)
}
}
} else {
Expand Down
3 changes: 2 additions & 1 deletion crates/rqbit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["default-tls", "webui"]
default = ["default-tls", "webui", "postgres"]
openssl-vendored = ["openssl/vendored"]
tokio-console = ["console-subscriber", "tokio/tracing"]
webui = ["librqbit/webui"]
timed_existence = ["librqbit/timed_existence"]
default-tls = ["librqbit/default-tls"]
rust-tls = ["librqbit/rust-tls"]
debug_slow_disk = ["librqbit/storage_middleware"]
postgres = ["librqbit/postgres"]

[dependencies]
librqbit = { path = "../librqbit", default-features = false, version = "6.0.0" }
Expand Down
27 changes: 22 additions & 5 deletions crates/rqbit/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ struct ServerStartOptions {
disable_persistence: bool,

/// The folder to store session data in. By default uses OS specific folder.
#[arg(long = "persistence-folder")]
persistence_folder: Option<String>,
#[arg(long = "persistence-config")]
persistence_config: Option<String>,
}

#[derive(Parser)]
Expand Down Expand Up @@ -393,9 +393,26 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
SubCommand::Server(server_opts) => match &server_opts.subcommand {
ServerSubcommand::Start(start_opts) => {
if !start_opts.disable_persistence {
sopts.persistence = Some(SessionPersistenceConfig::Json {
folder: start_opts.persistence_folder.clone().map(PathBuf::from),
})
if let Some(p) = start_opts.persistence_config.as_ref() {
if p.starts_with("postgres://") {
#[cfg(feature = "postgres")]
{
sopts.persistence = Some(SessionPersistenceConfig::Postgres {
connection_string: p.clone(),
})
}
#[cfg(not(feature = "postgres"))]
{
anyhow::bail!("rqbit was compiled without postgres support")
}
} else {
sopts.persistence = Some(SessionPersistenceConfig::Json {
folder: Some(p.into()),
})
}
} else {
sopts.persistence = Some(SessionPersistenceConfig::Json { folder: None })
}
}

let session =
Expand Down

0 comments on commit 4c0bb08

Please sign in to comment.