@@ -51,11 +51,10 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
51
51
const MAX_SCIDS_PER_REPLY : usize = 8000 ;
52
52
53
53
/// Represents the network as nodes and channels between them
54
- #[ derive( Clone , PartialEq ) ]
55
54
pub struct NetworkGraph {
56
55
genesis_hash : BlockHash ,
57
- channels : BTreeMap < u64 , ChannelInfo > ,
58
- nodes : BTreeMap < PublicKey , NodeInfo > ,
56
+ channels : RwLock < BTreeMap < u64 , ChannelInfo > > ,
57
+ nodes : RwLock < BTreeMap < PublicKey , NodeInfo > > ,
59
58
}
60
59
61
60
/// A simple newtype for RwLockReadGuard<'a, NetworkGraph>.
@@ -193,7 +192,8 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
193
192
fn get_next_channel_announcements ( & self , starting_point : u64 , batch_amount : u8 ) -> Vec < ( ChannelAnnouncement , Option < ChannelUpdate > , Option < ChannelUpdate > ) > {
194
193
let network_graph = self . network_graph . read ( ) . unwrap ( ) ;
195
194
let mut result = Vec :: with_capacity ( batch_amount as usize ) ;
196
- let mut iter = network_graph. get_channels ( ) . range ( starting_point..) ;
195
+ let channels = network_graph. get_channels ( ) ;
196
+ let mut iter = channels. range ( starting_point..) ;
197
197
while result. len ( ) < batch_amount as usize {
198
198
if let Some ( ( _, ref chan) ) = iter. next ( ) {
199
199
if chan. announcement_message . is_some ( ) {
@@ -221,12 +221,13 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
221
221
fn get_next_node_announcements ( & self , starting_point : Option < & PublicKey > , batch_amount : u8 ) -> Vec < NodeAnnouncement > {
222
222
let network_graph = self . network_graph . read ( ) . unwrap ( ) ;
223
223
let mut result = Vec :: with_capacity ( batch_amount as usize ) ;
224
+ let nodes = network_graph. get_nodes ( ) ;
224
225
let mut iter = if let Some ( pubkey) = starting_point {
225
- let mut iter = network_graph . get_nodes ( ) . range ( ( * pubkey) ..) ;
226
+ let mut iter = nodes . range ( ( * pubkey) ..) ;
226
227
iter. next ( ) ;
227
228
iter
228
229
} else {
229
- network_graph . get_nodes ( ) . range ( ..)
230
+ nodes . range ( ..)
230
231
} ;
231
232
while result. len ( ) < batch_amount as usize {
232
233
if let Some ( ( _, ref node) ) = iter. next ( ) {
@@ -616,13 +617,15 @@ impl Writeable for NetworkGraph {
616
617
write_ver_prefix ! ( writer, SERIALIZATION_VERSION , MIN_SERIALIZATION_VERSION ) ;
617
618
618
619
self . genesis_hash . write ( writer) ?;
619
- ( self . channels . len ( ) as u64 ) . write ( writer) ?;
620
- for ( ref chan_id, ref chan_info) in self . channels . iter ( ) {
620
+ let channels = self . channels . read ( ) . unwrap ( ) ;
621
+ ( channels. len ( ) as u64 ) . write ( writer) ?;
622
+ for ( ref chan_id, ref chan_info) in channels. iter ( ) {
621
623
( * chan_id) . write ( writer) ?;
622
624
chan_info. write ( writer) ?;
623
625
}
624
- ( self . nodes . len ( ) as u64 ) . write ( writer) ?;
625
- for ( ref node_id, ref node_info) in self . nodes . iter ( ) {
626
+ let nodes = self . nodes . read ( ) . unwrap ( ) ;
627
+ ( nodes. len ( ) as u64 ) . write ( writer) ?;
628
+ for ( ref node_id, ref node_info) in nodes. iter ( ) {
626
629
node_id. write ( writer) ?;
627
630
node_info. write ( writer) ?;
628
631
}
@@ -655,45 +658,58 @@ impl Readable for NetworkGraph {
655
658
656
659
Ok ( NetworkGraph {
657
660
genesis_hash,
658
- channels,
659
- nodes,
661
+ channels : RwLock :: new ( channels ) ,
662
+ nodes : RwLock :: new ( nodes ) ,
660
663
} )
661
664
}
662
665
}
663
666
664
667
impl fmt:: Display for NetworkGraph {
665
668
fn fmt ( & self , f : & mut fmt:: Formatter ) -> Result < ( ) , fmt:: Error > {
666
669
writeln ! ( f, "Network map\n [Channels]" ) ?;
667
- for ( key, val) in self . channels . iter ( ) {
670
+ for ( key, val) in self . channels . read ( ) . unwrap ( ) . iter ( ) {
668
671
writeln ! ( f, " {}: {}" , key, val) ?;
669
672
}
670
673
writeln ! ( f, "[Nodes]" ) ?;
671
- for ( key, val) in self . nodes . iter ( ) {
674
+ for ( key, val) in self . nodes . read ( ) . unwrap ( ) . iter ( ) {
672
675
writeln ! ( f, " {}: {}" , log_pubkey!( key) , val) ?;
673
676
}
674
677
Ok ( ( ) )
675
678
}
676
679
}
677
680
681
+ impl PartialEq for NetworkGraph {
682
+ fn eq ( & self , other : & Self ) -> bool {
683
+ self . genesis_hash == other. genesis_hash &&
684
+ * self . channels . read ( ) . unwrap ( ) == * other. channels . read ( ) . unwrap ( ) &&
685
+ * self . nodes . read ( ) . unwrap ( ) == * other. nodes . read ( ) . unwrap ( )
686
+ }
687
+ }
688
+
678
689
impl NetworkGraph {
679
690
/// Returns all known valid channels' short ids along with announced channel info.
680
691
///
681
692
/// (C-not exported) because we have no mapping for `BTreeMap`s
682
- pub fn get_channels < ' a > ( & ' a self ) -> & ' a BTreeMap < u64 , ChannelInfo > { & self . channels }
693
+ pub fn get_channels ( & self ) -> RwLockReadGuard < ' _ , BTreeMap < u64 , ChannelInfo > > {
694
+ self . channels . read ( ) . unwrap ( )
695
+ }
696
+
683
697
/// Returns all known nodes' public keys along with announced node info.
684
698
///
685
699
/// (C-not exported) because we have no mapping for `BTreeMap`s
686
- pub fn get_nodes < ' a > ( & ' a self ) -> & ' a BTreeMap < PublicKey , NodeInfo > { & self . nodes }
700
+ pub fn get_nodes ( & self ) -> RwLockReadGuard < ' _ , BTreeMap < PublicKey , NodeInfo > > {
701
+ self . nodes . read ( ) . unwrap ( )
702
+ }
687
703
688
704
/// Get network addresses by node id.
689
705
/// Returns None if the requested node is completely unknown,
690
706
/// or if node announcement for the node was never received.
691
707
///
692
708
/// (C-not exported) as there is no practical way to track lifetimes of returned values.
693
- pub fn get_addresses < ' a > ( & ' a self , pubkey : & PublicKey ) -> Option < & ' a Vec < NetAddress > > {
694
- if let Some ( node) = self . nodes . get ( pubkey) {
709
+ pub fn get_addresses ( & self , pubkey : & PublicKey ) -> Option < Vec < NetAddress > > {
710
+ if let Some ( node) = self . nodes . read ( ) . unwrap ( ) . get ( pubkey) {
695
711
if let Some ( node_info) = node. announcement_info . as_ref ( ) {
696
- return Some ( & node_info. addresses )
712
+ return Some ( node_info. addresses . clone ( ) )
697
713
}
698
714
}
699
715
None
@@ -703,8 +719,8 @@ impl NetworkGraph {
703
719
pub fn new ( genesis_hash : BlockHash ) -> NetworkGraph {
704
720
Self {
705
721
genesis_hash,
706
- channels : BTreeMap :: new ( ) ,
707
- nodes : BTreeMap :: new ( ) ,
722
+ channels : RwLock :: new ( BTreeMap :: new ( ) ) ,
723
+ nodes : RwLock :: new ( BTreeMap :: new ( ) ) ,
708
724
}
709
725
}
710
726
@@ -729,7 +745,7 @@ impl NetworkGraph {
729
745
}
730
746
731
747
fn update_node_from_announcement_intern ( & mut self , msg : & msgs:: UnsignedNodeAnnouncement , full_msg : Option < & msgs:: NodeAnnouncement > ) -> Result < ( ) , LightningError > {
732
- match self . nodes . get_mut ( & msg. node_id ) {
748
+ match self . nodes . write ( ) . unwrap ( ) . get_mut ( & msg. node_id ) {
733
749
None => Err ( LightningError { err : "No existing channels for node_announcement" . to_owned ( ) , action : ErrorAction :: IgnoreError } ) ,
734
750
Some ( node) => {
735
751
if let Some ( node_info) = node. announcement_info . as_ref ( ) {
@@ -838,7 +854,9 @@ impl NetworkGraph {
838
854
{ full_msg. cloned ( ) } else { None } ,
839
855
} ;
840
856
841
- match self . channels . entry ( msg. short_channel_id ) {
857
+ let mut channels = self . channels . write ( ) . unwrap ( ) ;
858
+ let mut nodes = self . nodes . write ( ) . unwrap ( ) ;
859
+ match channels. entry ( msg. short_channel_id ) {
842
860
BtreeEntry :: Occupied ( mut entry) => {
843
861
//TODO: because asking the blockchain if short_channel_id is valid is only optional
844
862
//in the blockchain API, we need to handle it smartly here, though it's unclear
@@ -852,7 +870,7 @@ impl NetworkGraph {
852
870
// b) we don't track UTXOs of channels we know about and remove them if they
853
871
// get reorg'd out.
854
872
// c) it's unclear how to do so without exposing ourselves to massive DoS risk.
855
- Self :: remove_channel_in_nodes ( & mut self . nodes , & entry. get ( ) , msg. short_channel_id ) ;
873
+ Self :: remove_channel_in_nodes ( & mut nodes, & entry. get ( ) , msg. short_channel_id ) ;
856
874
* entry. get_mut ( ) = chan_info;
857
875
} else {
858
876
return Err ( LightningError { err : "Already have knowledge of channel" . to_owned ( ) , action : ErrorAction :: IgnoreAndLog ( Level :: Trace ) } )
@@ -865,7 +883,7 @@ impl NetworkGraph {
865
883
866
884
macro_rules! add_channel_to_node {
867
885
( $node_id: expr ) => {
868
- match self . nodes. entry( $node_id) {
886
+ match nodes. entry( $node_id) {
869
887
BtreeEntry :: Occupied ( node_entry) => {
870
888
node_entry. into_mut( ) . channels. push( msg. short_channel_id) ;
871
889
} ,
@@ -891,12 +909,14 @@ impl NetworkGraph {
891
909
/// May cause the removal of nodes too, if this was their last channel.
892
910
/// If not permanent, makes channels unavailable for routing.
893
911
pub fn close_channel_from_update ( & mut self , short_channel_id : u64 , is_permanent : bool ) {
912
+ let mut channels = self . channels . write ( ) . unwrap ( ) ;
894
913
if is_permanent {
895
- if let Some ( chan) = self . channels . remove ( & short_channel_id) {
896
- Self :: remove_channel_in_nodes ( & mut self . nodes , & chan, short_channel_id) ;
914
+ if let Some ( chan) = channels. remove ( & short_channel_id) {
915
+ let mut nodes = self . nodes . write ( ) . unwrap ( ) ;
916
+ Self :: remove_channel_in_nodes ( & mut nodes, & chan, short_channel_id) ;
897
917
}
898
918
} else {
899
- if let Some ( chan) = self . channels . get_mut ( & short_channel_id) {
919
+ if let Some ( chan) = channels. get_mut ( & short_channel_id) {
900
920
if let Some ( one_to_two) = chan. one_to_two . as_mut ( ) {
901
921
one_to_two. enabled = false ;
902
922
}
@@ -937,7 +957,8 @@ impl NetworkGraph {
937
957
let chan_enabled = msg. flags & ( 1 << 1 ) != ( 1 << 1 ) ;
938
958
let chan_was_enabled;
939
959
940
- match self . channels . get_mut ( & msg. short_channel_id ) {
960
+ let mut channels = self . channels . write ( ) . unwrap ( ) ;
961
+ match channels. get_mut ( & msg. short_channel_id ) {
941
962
None => return Err ( LightningError { err : "Couldn't find channel for update" . to_owned ( ) , action : ErrorAction :: IgnoreError } ) ,
942
963
Some ( channel) => {
943
964
if let OptionalField :: Present ( htlc_maximum_msat) = msg. htlc_maximum_msat {
@@ -1000,8 +1021,9 @@ impl NetworkGraph {
1000
1021
}
1001
1022
}
1002
1023
1024
+ let mut nodes = self . nodes . write ( ) . unwrap ( ) ;
1003
1025
if chan_enabled {
1004
- let node = self . nodes . get_mut ( & dest_node_id) . unwrap ( ) ;
1026
+ let node = nodes. get_mut ( & dest_node_id) . unwrap ( ) ;
1005
1027
let mut base_msat = msg. fee_base_msat ;
1006
1028
let mut proportional_millionths = msg. fee_proportional_millionths ;
1007
1029
if let Some ( fees) = node. lowest_inbound_channel_fees {
@@ -1013,11 +1035,11 @@ impl NetworkGraph {
1013
1035
proportional_millionths
1014
1036
} ) ;
1015
1037
} else if chan_was_enabled {
1016
- let node = self . nodes . get_mut ( & dest_node_id) . unwrap ( ) ;
1038
+ let node = nodes. get_mut ( & dest_node_id) . unwrap ( ) ;
1017
1039
let mut lowest_inbound_channel_fees = None ;
1018
1040
1019
1041
for chan_id in node. channels . iter ( ) {
1020
- let chan = self . channels . get ( chan_id) . unwrap ( ) ;
1042
+ let chan = channels. get ( chan_id) . unwrap ( ) ;
1021
1043
let chan_info_opt;
1022
1044
if chan. node_one == dest_node_id {
1023
1045
chan_info_opt = chan. two_to_one . as_ref ( ) ;
@@ -1268,7 +1290,7 @@ mod tests {
1268
1290
match network. get_channels ( ) . get ( & unsigned_announcement. short_channel_id ) {
1269
1291
None => panic ! ( ) ,
1270
1292
Some ( _) => ( )
1271
- }
1293
+ } ;
1272
1294
}
1273
1295
1274
1296
// If we receive announcement for the same channel (with UTXO lookups disabled),
@@ -1320,7 +1342,7 @@ mod tests {
1320
1342
match network. get_channels ( ) . get ( & unsigned_announcement. short_channel_id ) {
1321
1343
None => panic ! ( ) ,
1322
1344
Some ( _) => ( )
1323
- }
1345
+ } ;
1324
1346
}
1325
1347
1326
1348
// If we receive announcement for the same channel (but TX is not confirmed),
@@ -1353,7 +1375,7 @@ mod tests {
1353
1375
assert_eq ! ( channel_entry. features, ChannelFeatures :: empty( ) ) ;
1354
1376
} ,
1355
1377
_ => panic ! ( )
1356
- }
1378
+ } ;
1357
1379
}
1358
1380
1359
1381
// Don't relay valid channels with excess data
@@ -1484,7 +1506,7 @@ mod tests {
1484
1506
assert_eq ! ( channel_info. one_to_two. as_ref( ) . unwrap( ) . cltv_expiry_delta, 144 ) ;
1485
1507
assert ! ( channel_info. two_to_one. is_none( ) ) ;
1486
1508
}
1487
- }
1509
+ } ;
1488
1510
}
1489
1511
1490
1512
unsigned_channel_update. timestamp += 100 ;
@@ -1645,7 +1667,7 @@ mod tests {
1645
1667
Some ( channel_info) => {
1646
1668
assert ! ( channel_info. one_to_two. is_some( ) ) ;
1647
1669
}
1648
- }
1670
+ } ;
1649
1671
}
1650
1672
1651
1673
let channel_close_msg = HTLCFailChannelUpdate :: ChannelClosed {
@@ -1663,7 +1685,7 @@ mod tests {
1663
1685
Some ( channel_info) => {
1664
1686
assert ! ( !channel_info. one_to_two. as_ref( ) . unwrap( ) . enabled) ;
1665
1687
}
1666
- }
1688
+ } ;
1667
1689
}
1668
1690
1669
1691
let channel_close_msg = HTLCFailChannelUpdate :: ChannelClosed {
0 commit comments