Skip to content

Commit

Permalink
Fix DefaultRouter type restrained to only MutexGuard
Browse files Browse the repository at this point in the history
Type of DerefMut for DefaultRouter was specialized to only MutexGuard.
It should be generic around RefMut and MutexGuard. This commit fixes that
  • Loading branch information
henghonglee committed Jul 4, 2023
1 parent 86fd9e7 commit 54bcb6e
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 55 deletions.
17 changes: 16 additions & 1 deletion lightning-background-processor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,22 @@ mod tests {
fn disconnect_socket(&mut self) {}
}

type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
type ChannelManager =
channelmanager::ChannelManager<
Arc<ChainMonitor>,
Arc<test_utils::TestBroadcaster>,
Arc<KeysManager>,
Arc<KeysManager>,
Arc<KeysManager>,
Arc<test_utils::TestFeeEstimator>,
Arc<DefaultRouter<
Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
Arc<test_utils::TestLogger>,
Arc<Mutex<TestScorer>>,
(),
TestScorer>
>,
Arc<test_utils::TestLogger>>;

type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;

Expand Down
18 changes: 17 additions & 1 deletion lightning/src/ln/channelmanager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,23 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
/// of [`KeysManager`] and [`DefaultRouter`].
///
/// This is not exported to bindings users as Arcs don't make sense in bindings
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>;
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> =
ChannelManager<
&'a M,
&'b T,
&'c KeysManager,
&'c KeysManager,
&'c KeysManager,
&'d F,
&'e DefaultRouter<
&'f NetworkGraph<&'g L>,
&'g L,
&'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>,
ProbabilisticScoringFeeParameters,
ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>
>,
&'g L
>;

macro_rules! define_test_pub_trait { ($vis: vis) => {
/// A trivial trait which describes any [`ChannelManager`] used in testing.
Expand Down
26 changes: 13 additions & 13 deletions lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20;

use crate::io;
use crate::prelude::*;
use crate::sync::{Mutex, MutexGuard};
use crate::sync::{Mutex};
use alloc::collections::BinaryHeap;
use core::{cmp, fmt};
use core::ops::Deref;
use core::ops::{Deref, DerefMut};

/// A [`Router`] implemented using [`find_route`].
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
network_graph: G,
logger: L,
Expand All @@ -46,7 +46,7 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,

impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
/// Creates a new router.
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
Expand All @@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
}
}

impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
fn find_route(
&self,
Expand All @@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sc
};
find_route(
payer, params, &self.network_graph, first_hops, &*self.logger,
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
&self.score_params,
&random_seed_bytes
)
Expand Down Expand Up @@ -104,15 +104,15 @@ pub trait Router {
/// [`find_route`].
///
/// [`Score`]: crate::routing::scoring::Score
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
scorer: S,
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
scorer: &'a mut S,
// Maps a channel's short channel id and its direction to the liquidity used up.
inflight_htlcs: &'a InFlightHtlcs,
}

impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
/// Initialize a new `ScorerAccountingForInFlightHtlcs`.
pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
ScorerAccountingForInFlightHtlcs {
scorer,
inflight_htlcs
Expand All @@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
}

#[cfg(c_bindings)]
impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
}

impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
type ScoreParams = S::ScoreParams;
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
Expand Down
77 changes: 41 additions & 36 deletions lightning/src/routing/scoring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,11 @@ define_score!();
///
/// [`find_route`]: crate::routing::router::find_route
pub trait LockableScore<'a> {
/// The [`Score`] type.
type Score: 'a + Score;

/// The locked [`Score`] type.
type Locked: 'a + Score;
type Locked: DerefMut<Target = Self::Score> + Sized;

/// Returns the locked scorer.
fn lock(&'a self) -> Self::Locked;
Expand All @@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
/// This is not exported to bindings users
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
type Score = T;
type Locked = MutexGuard<'a, T>;

fn lock(&'a self) -> MutexGuard<'a, T> {
fn lock(&'a self) -> Self::Locked {
Mutex::lock(self).unwrap()
}
}

impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
type Score = T;
type Locked = RefMut<'a, T>;

fn lock(&'a self) -> RefMut<'a, T> {
fn lock(&'a self) -> Self::Locked {
self.borrow_mut()
}
}

#[cfg(c_bindings)]
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
pub struct MultiThreadedLockableScore<S: Score> {
score: Mutex<S>,
}
#[cfg(c_bindings)]
/// A locked `MultiThreadedLockableScore`.
pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
type ScoreParams = <T as Score>::ScoreParams;
fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
self.0.channel_penalty_msat(scid, source, target, usage, score_params)
}
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
self.0.payment_path_failed(path, short_channel_id)
}
fn payment_path_successful(&mut self, path: &Path) {
self.0.payment_path_successful(path)
}
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
self.0.probe_failed(path, short_channel_id)
}
fn probe_successful(&mut self, path: &Path) {
self.0.probe_successful(path)
}
}
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
self.0.write(writer)
}
pub struct MultiThreadedLockableScore<T: Score> {
score: Mutex<T>,
}

#[cfg(c_bindings)]
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
type Score = T;
type Locked = MultiThreadedScoreLock<'a, T>;

fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
fn lock(&'a self) -> Self::Locked {
MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
}
}
Expand All @@ -240,7 +218,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
}

#[cfg(c_bindings)]
impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}

#[cfg(c_bindings)]
impl<T: Score> MultiThreadedLockableScore<T> {
Expand All @@ -250,6 +228,33 @@ impl<T: Score> MultiThreadedLockableScore<T> {
}
}

#[cfg(c_bindings)]
/// A locked `MultiThreadedLockableScore`.
pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);

#[cfg(c_bindings)]
impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
self.0.write(writer)
}
}

#[cfg(c_bindings)]
impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut()
}
}

#[cfg(c_bindings)]
impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
self.0.deref()
}
}

#[cfg(c_bindings)]
/// This is not exported to bindings users
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
Expand Down
8 changes: 4 additions & 4 deletions lightning/src/util/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use regex;
use crate::io;
use crate::prelude::*;
use core::cell::RefCell;
use core::ops::DerefMut;
use core::time::Duration;
use crate::sync::{Mutex, Arc};
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
Expand Down Expand Up @@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
assert_eq!(find_route_query, *params);
if let Ok(ref route) = find_route_res {
let locked_scorer = self.scorer.lock().unwrap();
let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
let mut binding = self.scorer.lock().unwrap();
let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
for path in &route.paths {
let mut aggregate_msat = 0u64;
for (idx, hop) in path.hops.iter().rev().enumerate() {
Expand All @@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
return find_route_res;
}
let logger = TestLogger::new();
let scorer = self.scorer.lock().unwrap();
find_route(
payer, params, &self.network_graph, first_hops, &logger,
&ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
&[42; 32]
)
}
Expand Down

0 comments on commit 54bcb6e

Please sign in to comment.