Skip to content

Commit 1bb1e3a

Browse files
committed
Individually lock NetworkGraph fields
In preparation for giving NetworkGraph shared ownership, wrap individual fields in RwLock. This allows removing the outer RwLock used in NetGraphMsgHandler.
1 parent 64159b3 commit 1bb1e3a

File tree

2 files changed

+83
-55
lines changed

2 files changed

+83
-55
lines changed

lightning/src/routing/network_graph.rs

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
5151
const MAX_SCIDS_PER_REPLY: usize = 8000;
5252

5353
/// Represents the network as nodes and channels between them
54-
#[derive(Clone, PartialEq)]
5554
pub struct NetworkGraph {
5655
genesis_hash: BlockHash,
57-
channels: BTreeMap<u64, ChannelInfo>,
58-
nodes: BTreeMap<PublicKey, NodeInfo>,
56+
channels: RwLock<BTreeMap<u64, ChannelInfo>>,
57+
nodes: RwLock<BTreeMap<PublicKey, NodeInfo>>,
5958
}
6059

6160
/// A simple newtype for RwLockReadGuard<'a, NetworkGraph>.
@@ -193,7 +192,8 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
193192
fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
194193
let network_graph = self.network_graph.read().unwrap();
195194
let mut result = Vec::with_capacity(batch_amount as usize);
196-
let mut iter = network_graph.get_channels().range(starting_point..);
195+
let channels = network_graph.get_channels();
196+
let mut iter = channels.range(starting_point..);
197197
while result.len() < batch_amount as usize {
198198
if let Some((_, ref chan)) = iter.next() {
199199
if chan.announcement_message.is_some() {
@@ -221,12 +221,13 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
221221
fn get_next_node_announcements(&self, starting_point: Option<&PublicKey>, batch_amount: u8) -> Vec<NodeAnnouncement> {
222222
let network_graph = self.network_graph.read().unwrap();
223223
let mut result = Vec::with_capacity(batch_amount as usize);
224+
let nodes = network_graph.get_nodes();
224225
let mut iter = if let Some(pubkey) = starting_point {
225-
let mut iter = network_graph.get_nodes().range((*pubkey)..);
226+
let mut iter = nodes.range((*pubkey)..);
226227
iter.next();
227228
iter
228229
} else {
229-
network_graph.get_nodes().range(..)
230+
nodes.range(..)
230231
};
231232
while result.len() < batch_amount as usize {
232233
if let Some((_, ref node)) = iter.next() {
@@ -616,13 +617,15 @@ impl Writeable for NetworkGraph {
616617
write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
617618

618619
self.genesis_hash.write(writer)?;
619-
(self.channels.len() as u64).write(writer)?;
620-
for (ref chan_id, ref chan_info) in self.channels.iter() {
620+
let channels = self.channels.read().unwrap();
621+
(channels.len() as u64).write(writer)?;
622+
for (ref chan_id, ref chan_info) in channels.iter() {
621623
(*chan_id).write(writer)?;
622624
chan_info.write(writer)?;
623625
}
624-
(self.nodes.len() as u64).write(writer)?;
625-
for (ref node_id, ref node_info) in self.nodes.iter() {
626+
let nodes = self.nodes.read().unwrap();
627+
(nodes.len() as u64).write(writer)?;
628+
for (ref node_id, ref node_info) in nodes.iter() {
626629
node_id.write(writer)?;
627630
node_info.write(writer)?;
628631
}
@@ -655,45 +658,58 @@ impl Readable for NetworkGraph {
655658

656659
Ok(NetworkGraph {
657660
genesis_hash,
658-
channels,
659-
nodes,
661+
channels: RwLock::new(channels),
662+
nodes: RwLock::new(nodes),
660663
})
661664
}
662665
}
663666

664667
impl fmt::Display for NetworkGraph {
665668
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
666669
writeln!(f, "Network map\n[Channels]")?;
667-
for (key, val) in self.channels.iter() {
670+
for (key, val) in self.channels.read().unwrap().iter() {
668671
writeln!(f, " {}: {}", key, val)?;
669672
}
670673
writeln!(f, "[Nodes]")?;
671-
for (key, val) in self.nodes.iter() {
674+
for (key, val) in self.nodes.read().unwrap().iter() {
672675
writeln!(f, " {}: {}", log_pubkey!(key), val)?;
673676
}
674677
Ok(())
675678
}
676679
}
677680

681+
impl PartialEq for NetworkGraph {
682+
fn eq(&self, other: &Self) -> bool {
683+
self.genesis_hash == other.genesis_hash &&
684+
*self.channels.read().unwrap() == *other.channels.read().unwrap() &&
685+
*self.nodes.read().unwrap() == *other.nodes.read().unwrap()
686+
}
687+
}
688+
678689
impl NetworkGraph {
679690
/// Returns all known valid channels' short ids along with announced channel info.
680691
///
681692
/// (C-not exported) because we have no mapping for `BTreeMap`s
682-
pub fn get_channels<'a>(&'a self) -> &'a BTreeMap<u64, ChannelInfo> { &self.channels }
693+
pub fn get_channels(&self) -> RwLockReadGuard<'_, BTreeMap<u64, ChannelInfo>> {
694+
self.channels.read().unwrap()
695+
}
696+
683697
/// Returns all known nodes' public keys along with announced node info.
684698
///
685699
/// (C-not exported) because we have no mapping for `BTreeMap`s
686-
pub fn get_nodes<'a>(&'a self) -> &'a BTreeMap<PublicKey, NodeInfo> { &self.nodes }
700+
pub fn get_nodes(&self) -> RwLockReadGuard<'_, BTreeMap<PublicKey, NodeInfo>> {
701+
self.nodes.read().unwrap()
702+
}
687703

688704
/// Get network addresses by node id.
689705
/// Returns None if the requested node is completely unknown,
690706
/// or if node announcement for the node was never received.
691707
///
692708
/// (C-not exported) as there is no practical way to track lifetimes of returned values.
693-
pub fn get_addresses<'a>(&'a self, pubkey: &PublicKey) -> Option<&'a Vec<NetAddress>> {
694-
if let Some(node) = self.nodes.get(pubkey) {
709+
pub fn get_addresses(&self, pubkey: &PublicKey) -> Option<Vec<NetAddress>> {
710+
if let Some(node) = self.nodes.read().unwrap().get(pubkey) {
695711
if let Some(node_info) = node.announcement_info.as_ref() {
696-
return Some(&node_info.addresses)
712+
return Some(node_info.addresses.clone())
697713
}
698714
}
699715
None
@@ -703,8 +719,8 @@ impl NetworkGraph {
703719
pub fn new(genesis_hash: BlockHash) -> NetworkGraph {
704720
Self {
705721
genesis_hash,
706-
channels: BTreeMap::new(),
707-
nodes: BTreeMap::new(),
722+
channels: RwLock::new(BTreeMap::new()),
723+
nodes: RwLock::new(BTreeMap::new()),
708724
}
709725
}
710726

@@ -729,7 +745,7 @@ impl NetworkGraph {
729745
}
730746

731747
fn update_node_from_announcement_intern(&mut self, msg: &msgs::UnsignedNodeAnnouncement, full_msg: Option<&msgs::NodeAnnouncement>) -> Result<(), LightningError> {
732-
match self.nodes.get_mut(&msg.node_id) {
748+
match self.nodes.write().unwrap().get_mut(&msg.node_id) {
733749
None => Err(LightningError{err: "No existing channels for node_announcement".to_owned(), action: ErrorAction::IgnoreError}),
734750
Some(node) => {
735751
if let Some(node_info) = node.announcement_info.as_ref() {
@@ -838,7 +854,9 @@ impl NetworkGraph {
838854
{ full_msg.cloned() } else { None },
839855
};
840856

841-
match self.channels.entry(msg.short_channel_id) {
857+
let mut channels = self.channels.write().unwrap();
858+
let mut nodes = self.nodes.write().unwrap();
859+
match channels.entry(msg.short_channel_id) {
842860
BtreeEntry::Occupied(mut entry) => {
843861
//TODO: because asking the blockchain if short_channel_id is valid is only optional
844862
//in the blockchain API, we need to handle it smartly here, though it's unclear
@@ -852,7 +870,7 @@ impl NetworkGraph {
852870
// b) we don't track UTXOs of channels we know about and remove them if they
853871
// get reorg'd out.
854872
// c) it's unclear how to do so without exposing ourselves to massive DoS risk.
855-
Self::remove_channel_in_nodes(&mut self.nodes, &entry.get(), msg.short_channel_id);
873+
Self::remove_channel_in_nodes(&mut nodes, &entry.get(), msg.short_channel_id);
856874
*entry.get_mut() = chan_info;
857875
} else {
858876
return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreAndLog(Level::Trace)})
@@ -865,7 +883,7 @@ impl NetworkGraph {
865883

866884
macro_rules! add_channel_to_node {
867885
( $node_id: expr ) => {
868-
match self.nodes.entry($node_id) {
886+
match nodes.entry($node_id) {
869887
BtreeEntry::Occupied(node_entry) => {
870888
node_entry.into_mut().channels.push(msg.short_channel_id);
871889
},
@@ -891,12 +909,14 @@ impl NetworkGraph {
891909
/// May cause the removal of nodes too, if this was their last channel.
892910
/// If not permanent, makes channels unavailable for routing.
893911
pub fn close_channel_from_update(&mut self, short_channel_id: u64, is_permanent: bool) {
912+
let mut channels = self.channels.write().unwrap();
894913
if is_permanent {
895-
if let Some(chan) = self.channels.remove(&short_channel_id) {
896-
Self::remove_channel_in_nodes(&mut self.nodes, &chan, short_channel_id);
914+
if let Some(chan) = channels.remove(&short_channel_id) {
915+
let mut nodes = self.nodes.write().unwrap();
916+
Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
897917
}
898918
} else {
899-
if let Some(chan) = self.channels.get_mut(&short_channel_id) {
919+
if let Some(chan) = channels.get_mut(&short_channel_id) {
900920
if let Some(one_to_two) = chan.one_to_two.as_mut() {
901921
one_to_two.enabled = false;
902922
}
@@ -937,7 +957,8 @@ impl NetworkGraph {
937957
let chan_enabled = msg.flags & (1 << 1) != (1 << 1);
938958
let chan_was_enabled;
939959

940-
match self.channels.get_mut(&msg.short_channel_id) {
960+
let mut channels = self.channels.write().unwrap();
961+
match channels.get_mut(&msg.short_channel_id) {
941962
None => return Err(LightningError{err: "Couldn't find channel for update".to_owned(), action: ErrorAction::IgnoreError}),
942963
Some(channel) => {
943964
if let OptionalField::Present(htlc_maximum_msat) = msg.htlc_maximum_msat {
@@ -1000,8 +1021,9 @@ impl NetworkGraph {
10001021
}
10011022
}
10021023

1024+
let mut nodes = self.nodes.write().unwrap();
10031025
if chan_enabled {
1004-
let node = self.nodes.get_mut(&dest_node_id).unwrap();
1026+
let node = nodes.get_mut(&dest_node_id).unwrap();
10051027
let mut base_msat = msg.fee_base_msat;
10061028
let mut proportional_millionths = msg.fee_proportional_millionths;
10071029
if let Some(fees) = node.lowest_inbound_channel_fees {
@@ -1013,11 +1035,11 @@ impl NetworkGraph {
10131035
proportional_millionths
10141036
});
10151037
} else if chan_was_enabled {
1016-
let node = self.nodes.get_mut(&dest_node_id).unwrap();
1038+
let node = nodes.get_mut(&dest_node_id).unwrap();
10171039
let mut lowest_inbound_channel_fees = None;
10181040

10191041
for chan_id in node.channels.iter() {
1020-
let chan = self.channels.get(chan_id).unwrap();
1042+
let chan = channels.get(chan_id).unwrap();
10211043
let chan_info_opt;
10221044
if chan.node_one == dest_node_id {
10231045
chan_info_opt = chan.two_to_one.as_ref();
@@ -1268,7 +1290,7 @@ mod tests {
12681290
match network.get_channels().get(&unsigned_announcement.short_channel_id) {
12691291
None => panic!(),
12701292
Some(_) => ()
1271-
}
1293+
};
12721294
}
12731295

12741296
// If we receive announcement for the same channel (with UTXO lookups disabled),
@@ -1320,7 +1342,7 @@ mod tests {
13201342
match network.get_channels().get(&unsigned_announcement.short_channel_id) {
13211343
None => panic!(),
13221344
Some(_) => ()
1323-
}
1345+
};
13241346
}
13251347

13261348
// If we receive announcement for the same channel (but TX is not confirmed),
@@ -1353,7 +1375,7 @@ mod tests {
13531375
assert_eq!(channel_entry.features, ChannelFeatures::empty());
13541376
},
13551377
_ => panic!()
1356-
}
1378+
};
13571379
}
13581380

13591381
// Don't relay valid channels with excess data
@@ -1484,7 +1506,7 @@ mod tests {
14841506
assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144);
14851507
assert!(channel_info.two_to_one.is_none());
14861508
}
1487-
}
1509+
};
14881510
}
14891511

14901512
unsigned_channel_update.timestamp += 100;
@@ -1645,7 +1667,7 @@ mod tests {
16451667
Some(channel_info) => {
16461668
assert!(channel_info.one_to_two.is_some());
16471669
}
1648-
}
1670+
};
16491671
}
16501672

16511673
let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {
@@ -1663,7 +1685,7 @@ mod tests {
16631685
Some(channel_info) => {
16641686
assert!(!channel_info.one_to_two.as_ref().unwrap().enabled);
16651687
}
1666-
}
1688+
};
16671689
}
16681690

16691691
let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {

0 commit comments

Comments
 (0)