Skip to content

Commit 0bef91a

Browse files
committed
Move banned_nodes to params struct
1 parent cc13d2b commit 0bef91a

File tree

2 files changed

+44
-30
lines changed

2 files changed

+44
-30
lines changed

lightning/src/routing/router.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5721,21 +5721,23 @@ mod tests {
57215721
let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
57225722
let random_seed_bytes = keys_manager.get_secure_random_bytes();
57235723

5724-
let scorer_params = ProbabilisticScoringParameters::default();
5725-
let mut scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
5724+
let payment_params = PaymentParameters::from_node_id(nodes[10]);
5725+
let mut scorer_params = ProbabilisticScoringParameters::default();
57265726

57275727
// First check we can get a route.
5728-
let payment_params = PaymentParameters::from_node_id(nodes[10]);
5728+
let scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
57295729
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57305730
assert!(route.is_ok());
57315731

57325732
// Then check that we can't get a route if we ban an intermediate node.
5733-
scorer.add_banned(&NodeId::from_pubkey(&nodes[3]));
5733+
scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3]));
5734+
let scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
57345735
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57355736
assert!(route.is_err());
57365737

57375738
// Finally make sure we can route again, when we remove the ban.
5738-
scorer.remove_banned(&NodeId::from_pubkey(&nodes[3]));
5739+
scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3]));
5740+
let scorer = ProbabilisticScorer::new(scorer_params.clone(), Arc::clone(&network_graph), Arc::clone(&logger));
57395741
let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes);
57405742
assert!(route.is_ok());
57415743
}

lightning/src/routing/scoring.rs

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,13 @@ where L::Target: Logger {
300300
logger: L,
301301
// TODO: Remove entries of closed channels.
302302
channel_liquidities: HashMap<u64, ChannelLiquidity<T>>,
303-
banned_nodes: HashSet<NodeId>,
304303
}
305304

306305
/// Parameters for configuring [`ProbabilisticScorer`].
307306
///
308307
/// Used to configure base, liquidity, and amount penalties, the sum of which comprises the channel
309308
/// penalty (i.e., the amount in msats willing to be paid to avoid routing through the channel).
310-
#[derive(Clone, Copy)]
309+
#[derive(Clone)]
311310
pub struct ProbabilisticScoringParameters {
312311
/// A fixed penalty in msats to apply to each channel.
313312
///
@@ -362,6 +361,11 @@ pub struct ProbabilisticScoringParameters {
362361
///
363362
/// Default value: 256 msat
364363
pub amount_penalty_multiplier_msat: u64,
364+
365+
/// A list of nodes that won't be considered during path finding.
366+
///
367+
/// (C-not exported)
368+
pub banned_nodes: HashSet<NodeId>,
365369
}
366370

367371
/// Accounting for channel liquidity balance uncertainty.
@@ -400,7 +404,6 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
400404
network_graph,
401405
logger,
402406
channel_liquidities: HashMap::new(),
403-
banned_nodes: HashSet::new(),
404407
}
405408
}
406409

@@ -410,22 +413,6 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
410413
self
411414
}
412415

413-
/// Marks the node with the given `node_id` as banned, i.e.,
414-
/// it will be avoided during path finding.
415-
pub fn add_banned(&mut self, node_id: &NodeId) {
416-
self.banned_nodes.insert(*node_id);
417-
}
418-
419-
/// Removes the node with the given `node_id` from the list of nodes to avoid.
420-
pub fn remove_banned(&mut self, node_id: &NodeId) {
421-
self.banned_nodes.remove(node_id);
422-
}
423-
424-
/// Clears the list of nodes that are avoided during path finding.
425-
pub fn clear_banned(&mut self) {
426-
self.banned_nodes = HashSet::new();
427-
}
428-
429416
/// Dump the contents of this scorer into the configured logger.
430417
///
431418
/// Note that this writes roughly one line per channel for which we have a liquidity estimate,
@@ -462,8 +449,33 @@ impl ProbabilisticScoringParameters {
462449
liquidity_penalty_multiplier_msat: 0,
463450
liquidity_offset_half_life: Duration::from_secs(3600),
464451
amount_penalty_multiplier_msat: 0,
452+
banned_nodes: HashSet::new(),
453+
}
454+
}
455+
456+
/// Marks the node with the given `node_id` as banned, i.e.,
457+
/// it will be avoided during path finding.
458+
pub fn add_banned(&mut self, node_id: &NodeId) {
459+
self.banned_nodes.insert(*node_id);
460+
}
461+
462+
/// Marks all nodes in the given list as banned, i.e.,
463+
/// they will be avoided during path finding.
464+
pub fn add_banned_from_list(&mut self, node_ids: Vec<NodeId>) {
465+
for id in node_ids {
466+
self.banned_nodes.insert(id);
465467
}
466468
}
469+
470+
/// Removes the node with the given `node_id` from the list of nodes to avoid.
471+
pub fn remove_banned(&mut self, node_id: &NodeId) {
472+
self.banned_nodes.remove(node_id);
473+
}
474+
475+
/// Clears the list of nodes that are avoided during path finding.
476+
pub fn clear_banned(&mut self) {
477+
self.banned_nodes = HashSet::new();
478+
}
467479
}
468480

469481
impl Default for ProbabilisticScoringParameters {
@@ -473,6 +485,7 @@ impl Default for ProbabilisticScoringParameters {
473485
liquidity_penalty_multiplier_msat: 40_000,
474486
liquidity_offset_half_life: Duration::from_secs(3600),
475487
amount_penalty_multiplier_msat: 256,
488+
banned_nodes: HashSet::new(),
476489
}
477490
}
478491
}
@@ -673,7 +686,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
673686
fn channel_penalty_msat(
674687
&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage
675688
) -> u64 {
676-
if self.banned_nodes.contains(source) || self.banned_nodes.contains(target) {
689+
if self.params.banned_nodes.contains(source) || self.params.banned_nodes.contains(target) {
677690
return u64::max_value();
678691
}
679692

@@ -693,7 +706,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
693706
.get(&short_channel_id)
694707
.unwrap_or(&ChannelLiquidity::new())
695708
.as_directed(source, target, capacity_msat, liquidity_offset_half_life)
696-
.penalty_msat(amount_msat, self.params)
709+
.penalty_msat(amount_msat, self.params.clone())
697710
}
698711

699712
fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
@@ -1099,7 +1112,6 @@ ReadableArgs<(ProbabilisticScoringParameters, G, L)> for ProbabilisticScorerUsin
10991112
network_graph,
11001113
logger,
11011114
channel_liquidities,
1102-
banned_nodes: HashSet::new(),
11031115
})
11041116
}
11051117
}
@@ -1868,7 +1880,7 @@ mod tests {
18681880
liquidity_offset_half_life: Duration::from_secs(10),
18691881
..ProbabilisticScoringParameters::zero_penalty()
18701882
};
1871-
let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
1883+
let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
18721884
let source = source_node_id();
18731885
let target = target_node_id();
18741886
let usage = ChannelUsage {
@@ -1904,7 +1916,7 @@ mod tests {
19041916
liquidity_offset_half_life: Duration::from_secs(10),
19051917
..ProbabilisticScoringParameters::zero_penalty()
19061918
};
1907-
let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
1919+
let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
19081920
let source = source_node_id();
19091921
let target = target_node_id();
19101922
let usage = ChannelUsage {
@@ -2092,7 +2104,7 @@ mod tests {
20922104
let logger = TestLogger::new();
20932105
let network_graph = network_graph(&logger);
20942106
let params = ProbabilisticScoringParameters::default();
2095-
let scorer = ProbabilisticScorer::new(params, &network_graph, &logger);
2107+
let scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger);
20962108
let source = source_node_id();
20972109
let target = target_node_id();
20982110

0 commit comments

Comments
 (0)