Skip to content

Commit 54bcb6e

Browse files
committed
Fix DefaultRouter type restrained to only MutexGuard
Type of DerefMut for DefaultRouter was specialized to only MutexGuard. It should be generic around RefMut and MutexGuard. This commit fixes that
1 parent 86fd9e7 commit 54bcb6e

File tree

5 files changed

+91
-55
lines changed

5 files changed

+91
-55
lines changed

lightning-background-processor/src/lib.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,22 @@ mod tests {
885885
fn disconnect_socket(&mut self) {}
886886
}
887887

888-
type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
888+
type ChannelManager =
889+
channelmanager::ChannelManager<
890+
Arc<ChainMonitor>,
891+
Arc<test_utils::TestBroadcaster>,
892+
Arc<KeysManager>,
893+
Arc<KeysManager>,
894+
Arc<KeysManager>,
895+
Arc<test_utils::TestFeeEstimator>,
896+
Arc<DefaultRouter<
897+
Arc<NetworkGraph<Arc<test_utils::TestLogger>>>,
898+
Arc<test_utils::TestLogger>,
899+
Arc<Mutex<TestScorer>>,
900+
(),
901+
TestScorer>
902+
>,
903+
Arc<test_utils::TestLogger>>;
889904

890905
type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;
891906

lightning/src/ln/channelmanager.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,23 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
752752
/// of [`KeysManager`] and [`DefaultRouter`].
753753
///
754754
/// This is not exported to bindings users as Arcs don't make sense in bindings
755-
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>;
755+
pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> =
756+
ChannelManager<
757+
&'a M,
758+
&'b T,
759+
&'c KeysManager,
760+
&'c KeysManager,
761+
&'c KeysManager,
762+
&'d F,
763+
&'e DefaultRouter<
764+
&'f NetworkGraph<&'g L>,
765+
&'g L,
766+
&'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>,
767+
ProbabilisticScoringFeeParameters,
768+
ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>
769+
>,
770+
&'g L
771+
>;
756772

757773
macro_rules! define_test_pub_trait { ($vis: vis) => {
758774
/// A trivial trait which describes any [`ChannelManager`] used in testing.

lightning/src/routing/router.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20;
2727

2828
use crate::io;
2929
use crate::prelude::*;
30-
use crate::sync::{Mutex, MutexGuard};
30+
use crate::sync::{Mutex};
3131
use alloc::collections::BinaryHeap;
3232
use core::{cmp, fmt};
33-
use core::ops::Deref;
33+
use core::ops::{Deref, DerefMut};
3434

3535
/// A [`Router`] implemented using [`find_route`].
3636
pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
3737
L::Target: Logger,
38-
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
38+
S::Target: for <'a> LockableScore<'a, Score = Sc>,
3939
{
4040
network_graph: G,
4141
logger: L,
@@ -46,7 +46,7 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,
4646

4747
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
4848
L::Target: Logger,
49-
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
49+
S::Target: for <'a> LockableScore<'a, Score = Sc>,
5050
{
5151
/// Creates a new router.
5252
pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
@@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
5555
}
5656
}
5757

58-
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
58+
impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
5959
L::Target: Logger,
60-
S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>,
60+
S::Target: for <'a> LockableScore<'a, Score = Sc>,
6161
{
6262
fn find_route(
6363
&self,
@@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sc
7373
};
7474
find_route(
7575
payer, params, &self.network_graph, first_hops, &*self.logger,
76-
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs),
76+
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
7777
&self.score_params,
7878
&random_seed_bytes
7979
)
@@ -104,15 +104,15 @@ pub trait Router {
104104
/// [`find_route`].
105105
///
106106
/// [`Score`]: crate::routing::scoring::Score
107-
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
108-
scorer: S,
107+
pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
108+
scorer: &'a mut S,
109109
// Maps a channel's short channel id and its direction to the liquidity used up.
110110
inflight_htlcs: &'a InFlightHtlcs,
111111
}
112112

113-
impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
113+
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
114114
/// Initialize a new `ScorerAccountingForInFlightHtlcs`.
115-
pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
115+
pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
116116
ScorerAccountingForInFlightHtlcs {
117117
scorer,
118118
inflight_htlcs
@@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
121121
}
122122

123123
#[cfg(c_bindings)]
124-
impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
124+
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
125125
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
126126
}
127127

128-
impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
128+
impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
129129
type ScoreParams = S::ScoreParams;
130130
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
131131
if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(

lightning/src/routing/scoring.rs

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,11 @@ define_score!();
157157
///
158158
/// [`find_route`]: crate::routing::router::find_route
159159
pub trait LockableScore<'a> {
160+
/// The [`Score`] type.
161+
type Score: 'a + Score;
162+
160163
/// The locked [`Score`] type.
161-
type Locked: 'a + Score;
164+
type Locked: DerefMut<Target = Self::Score> + Sized;
162165

163166
/// Returns the locked scorer.
164167
fn lock(&'a self) -> Self::Locked;
@@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
174177
impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
175178
/// This is not exported to bindings users
176179
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
180+
type Score = T;
177181
type Locked = MutexGuard<'a, T>;
178182

179-
fn lock(&'a self) -> MutexGuard<'a, T> {
183+
fn lock(&'a self) -> Self::Locked {
180184
Mutex::lock(self).unwrap()
181185
}
182186
}
183187

184188
impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
189+
type Score = T;
185190
type Locked = RefMut<'a, T>;
186191

187-
fn lock(&'a self) -> RefMut<'a, T> {
192+
fn lock(&'a self) -> Self::Locked {
188193
self.borrow_mut()
189194
}
190195
}
191196

192197
#[cfg(c_bindings)]
193198
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
194-
pub struct MultiThreadedLockableScore<S: Score> {
195-
score: Mutex<S>,
196-
}
197-
#[cfg(c_bindings)]
198-
/// A locked `MultiThreadedLockableScore`.
199-
pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
200-
#[cfg(c_bindings)]
201-
impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
202-
type ScoreParams = <T as Score>::ScoreParams;
203-
fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
204-
self.0.channel_penalty_msat(scid, source, target, usage, score_params)
205-
}
206-
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
207-
self.0.payment_path_failed(path, short_channel_id)
208-
}
209-
fn payment_path_successful(&mut self, path: &Path) {
210-
self.0.payment_path_successful(path)
211-
}
212-
fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
213-
self.0.probe_failed(path, short_channel_id)
214-
}
215-
fn probe_successful(&mut self, path: &Path) {
216-
self.0.probe_successful(path)
217-
}
218-
}
219-
#[cfg(c_bindings)]
220-
impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
221-
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
222-
self.0.write(writer)
223-
}
199+
pub struct MultiThreadedLockableScore<T: Score> {
200+
score: Mutex<T>,
224201
}
225202

226203
#[cfg(c_bindings)]
227-
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
204+
impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
205+
type Score = T;
228206
type Locked = MultiThreadedScoreLock<'a, T>;
229207

230-
fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
208+
fn lock(&'a self) -> Self::Locked {
231209
MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
232210
}
233211
}
@@ -240,7 +218,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
240218
}
241219

242220
#[cfg(c_bindings)]
243-
impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
221+
impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
244222

245223
#[cfg(c_bindings)]
246224
impl<T: Score> MultiThreadedLockableScore<T> {
@@ -250,6 +228,33 @@ impl<T: Score> MultiThreadedLockableScore<T> {
250228
}
251229
}
252230

231+
#[cfg(c_bindings)]
232+
/// A locked `MultiThreadedLockableScore`.
233+
pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);
234+
235+
#[cfg(c_bindings)]
236+
impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
237+
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
238+
self.0.write(writer)
239+
}
240+
}
241+
242+
#[cfg(c_bindings)]
243+
impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
244+
fn deref_mut(&mut self) -> &mut Self::Target {
245+
self.0.deref_mut()
246+
}
247+
}
248+
249+
#[cfg(c_bindings)]
250+
impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
251+
type Target = T;
252+
253+
fn deref(&self) -> &Self::Target {
254+
self.0.deref()
255+
}
256+
}
257+
253258
#[cfg(c_bindings)]
254259
/// This is not exported to bindings users
255260
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {

lightning/src/util/test_utils.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ use regex;
5151
use crate::io;
5252
use crate::prelude::*;
5353
use core::cell::RefCell;
54+
use core::ops::DerefMut;
5455
use core::time::Duration;
5556
use crate::sync::{Mutex, Arc};
5657
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
113114
if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
114115
assert_eq!(find_route_query, *params);
115116
if let Ok(ref route) = find_route_res {
116-
let locked_scorer = self.scorer.lock().unwrap();
117-
let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
117+
let mut binding = self.scorer.lock().unwrap();
118+
let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
118119
for path in &route.paths {
119120
let mut aggregate_msat = 0u64;
120121
for (idx, hop) in path.hops.iter().rev().enumerate() {
@@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
139140
return find_route_res;
140141
}
141142
let logger = TestLogger::new();
142-
let scorer = self.scorer.lock().unwrap();
143143
find_route(
144144
payer, params, &self.network_graph, first_hops, &logger,
145-
&ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
145+
&ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
146146
&[42; 32]
147147
)
148148
}

0 commit comments

Comments
 (0)