Skip to content

Fix DefaultRouter type restrained to only MutexGuard #2383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion lightning-background-processor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,22 @@ mod tests {
fn disconnect_socket(&mut self) {}
}

type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
type ChannelManager =
channelmanager::ChannelManager<
Arc<ChainMonitor>,
Arc<test_utils::TestBroadcaster>,
Arc<KeysManager>,
Arc<KeysManager>,
Arc<KeysManager>,
Arc<test_utils::TestFeeEstimator>,
Arc<DefaultRouter<
Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
Arc<test_utils::TestLogger>,
Arc<Mutex<TestScorer>>,
(),
TestScorer>
>,
Arc<test_utils::TestLogger>>;

type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;

Expand Down
18 changes: 17 additions & 1 deletion lightning/src/ln/channelmanager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,23 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
/// of [`KeysManager`] and [`DefaultRouter`].
///
/// This is not exported to bindings users as Arcs don't make sense in bindings
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>;
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> =
ChannelManager<
&'a M,
&'b T,
&'c KeysManager,
&'c KeysManager,
&'c KeysManager,
&'d F,
&'e DefaultRouter<
&'f NetworkGraph<&'g L>,
&'g L,
&'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>,
ProbabilisticScoringFeeParameters,
ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>
>,
&'g L
>;

macro_rules! define_test_pub_trait { ($vis: vis) => {
/// A trivial trait which describes any [`ChannelManager`] used in testing.
Expand Down
26 changes: 13 additions & 13 deletions lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20;

use crate::io;
use crate::prelude::*;
use crate::sync::{Mutex, MutexGuard};
use crate::sync::{Mutex};
use alloc::collections::BinaryHeap;
use core::{cmp, fmt};
use core::ops::Deref;
use core::ops::{Deref, DerefMut};

/// A [`Router`] implemented using [`find_route`].
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
network_graph: G,
logger: L,
Expand All @@ -46,7 +46,7 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,

impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
/// Creates a new router.
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
Expand All @@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
}
}

impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
L::Target: Logger,
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
S::Target: for <'a> LockableScore<'a, Score = Sc>,
{
fn find_route(
&self,
Expand All @@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sc
};
find_route(
payer, params, &self.network_graph, first_hops, &*self.logger,
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
&self.score_params,
&random_seed_bytes
)
Expand Down Expand Up @@ -104,15 +104,15 @@ pub trait Router {
/// [`find_route`].
///
/// [`Score`]: crate::routing::scoring::Score
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
scorer: S,
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
scorer: &'a mut S,
// Maps a channel's short channel id and its direction to the liquidity used up.
inflight_htlcs: &'a InFlightHtlcs,
}

impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
/// Initialize a new `ScorerAccountingForInFlightHtlcs`.
pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
ScorerAccountingForInFlightHtlcs {
scorer,
inflight_htlcs
Expand All @@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
}

#[cfg(c_bindings)]
impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
}

impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
type ScoreParams = S::ScoreParams;
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
Expand Down
77 changes: 41 additions & 36 deletions lightning/src/routing/scoring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,11 @@ define_score!();
///
/// [`find_route`]: crate::routing::router::find_route
pub trait LockableScore<'a> {
/// The [`Score`] type.
type Score: 'a + Score;

/// The locked [`Score`] type.
type Locked: 'a + Score;
type Locked: DerefMut<Target = Self::Score> + Sized;

/// Returns the locked scorer.
fn lock(&'a self) -> Self::Locked;
Expand All @@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
/// This is not exported to bindings users
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
type Score = T;
type Locked = MutexGuard<'a, T>;

fn lock(&'a self) -> MutexGuard<'a, T> {
fn lock(&'a self) -> Self::Locked {
Mutex::lock(self).unwrap()
}
}

impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
type Score = T;
type Locked = RefMut<'a, T>;

fn lock(&'a self) -> RefMut<'a, T> {
fn lock(&'a self) -> Self::Locked {
self.borrow_mut()
}
}

#[cfg(c_bindings)]
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
pub struct MultiThreadedLockableScore<S: Score> {
score: Mutex<S>,
}
#[cfg(c_bindings)]
/// A locked `MultiThreadedLockableScore`.
pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
#[cfg(c_bindings)]
impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
type ScoreParams = <T as Score>::ScoreParams;
fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
self.0.channel_penalty_msat(scid, source, target, usage, score_params)
}
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
self.0.payment_path_failed(path, short_channel_id)
}
fn payment_path_successful(&mut self, path: &Path) {
self.0.payment_path_successful(path)
}
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
self.0.probe_failed(path, short_channel_id)
}
fn probe_successful(&mut self, path: &Path) {
self.0.probe_successful(path)
}
}
Comment on lines -200 to -218
Copy link
Contributor Author

@henghonglee henghonglee Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now that MultiThreadedScoreLock<'a, T> implements DerefMut, this implementation is in conflict with scoring.rs:120-142 and can be removed.

#[cfg(c_bindings)]
impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
self.0.write(writer)
}
pub struct MultiThreadedLockableScore<T: Score> {
score: Mutex<T>,
}

#[cfg(c_bindings)]
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
type Score = T;
type Locked = MultiThreadedScoreLock<'a, T>;

fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
fn lock(&'a self) -> Self::Locked {
MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
}
}
Expand All @@ -240,7 +218,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
}

#[cfg(c_bindings)]
impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}

#[cfg(c_bindings)]
impl<T: Score> MultiThreadedLockableScore<T> {
Expand All @@ -250,6 +228,33 @@ impl<T: Score> MultiThreadedLockableScore<T> {
}
}

#[cfg(c_bindings)]
/// A locked `MultiThreadedLockableScore`.
pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);

#[cfg(c_bindings)]
impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
self.0.write(writer)
}
}

#[cfg(c_bindings)]
impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut()
}
}

#[cfg(c_bindings)]
impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
self.0.deref()
}
}

Comment on lines +231 to +257
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rearranged MultiThreadedScoreLock and MultiThreadedLockableScore methods to be colocated since their names are too similar and methods were interweaved

#[cfg(c_bindings)]
/// This is not exported to bindings users
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
Expand Down
8 changes: 4 additions & 4 deletions lightning/src/util/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use regex;
use crate::io;
use crate::prelude::*;
use core::cell::RefCell;
use core::ops::DerefMut;
use core::time::Duration;
use crate::sync::{Mutex, Arc};
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
Expand Down Expand Up @@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
assert_eq!(find_route_query, *params);
if let Ok(ref route) = find_route_res {
let locked_scorer = self.scorer.lock().unwrap();
let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
let mut binding = self.scorer.lock().unwrap();
let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
for path in &route.paths {
let mut aggregate_msat = 0u64;
for (idx, hop) in path.hops.iter().rev().enumerate() {
Expand All @@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
return find_route_res;
}
let logger = TestLogger::new();
let scorer = self.scorer.lock().unwrap();
find_route(
payer, params, &self.network_graph, first_hops, &logger,
&ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
&[42; 32]
)
}
Expand Down