Remove peers_needing_send set from peer_handling #456

Closed

88 changes: 44 additions & 44 deletions lightning-net-tokio/src/lib.rs
@@ -44,13 +44,13 @@ pub struct Connection {
}
impl Connection {
fn schedule_read<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor<CMH>, Arc<CMH>>>, us: Arc<Mutex<Self>>, reader: futures::stream::SplitStream<tokio_codec::Framed<TcpStream, tokio_codec::BytesCodec>>) {
let us_ref = us.clone();
let us_close_ref = us.clone();
let peer_manager_ref = peer_manager.clone();
let connection = us.clone();
let connection_close = us.clone();
let peer_manager_close = peer_manager.clone();
tokio::spawn(reader.for_each(move |b| {
let pending_read = b.to_vec();
{
let mut lock = us_ref.lock().unwrap();
let mut lock = connection.lock().unwrap();
assert!(lock.pending_read.is_empty());
if lock.read_paused {
lock.pending_read = pending_read;
@@ -60,31 +60,31 @@ impl Connection {
}
}
//TODO: There's a race where we don't meet the requirements of disconnect_socket if it's
//called right here, after we release the us_ref lock in the scope above, but before we
//called right here, after we release the connection lock in the scope above, but before we
//call read_event!
match peer_manager.read_event(&mut SocketDescriptor::new(us_ref.clone(), peer_manager.clone()), pending_read) {
match peer_manager.read_event(&mut SocketDescriptor::new(connection.clone(), peer_manager.clone()), pending_read) {
Ok(pause_read) => {
if pause_read {
let mut lock = us_ref.lock().unwrap();
let mut lock = connection.lock().unwrap();
lock.read_paused = true;
}
},
Err(e) => {
us_ref.lock().unwrap().need_disconnect = false;
connection.lock().unwrap().need_disconnect = false;
return future::Either::B(future::result(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))));
}
}

if let Err(e) = us_ref.lock().unwrap().event_notify.try_send(()) {
if let Err(e) = connection.lock().unwrap().event_notify.try_send(()) {
// Ignore full errors as we just need them to poll after this point, so if the user
// hasn't received the last send yet, it doesn't matter.
assert!(e.is_full());
}

future::Either::B(future::result(Ok(())))
}).then(move |_| {
if us_close_ref.lock().unwrap().need_disconnect {
peer_manager_ref.disconnect_event(&SocketDescriptor::new(us_close_ref, peer_manager_ref.clone()));
if connection_close.lock().unwrap().need_disconnect {
peer_manager_close.disconnect_event(&SocketDescriptor::new(connection_close, peer_manager_close.clone()));
println!("Peer disconnected!");
} else {
println!("We disconnected peer!");
@@ -101,9 +101,9 @@ impl Connection {
})).then(|_| {
future::result(Ok(()))
}));
let us = Arc::new(Mutex::new(Self { writer: Some(send_sink), event_notify, pending_read: Vec::new(), read_blocker: None, read_paused: false, need_disconnect: true, id: ID_COUNTER.fetch_add(1, Ordering::AcqRel) }));
let connection = Arc::new(Mutex::new(Self { writer: Some(send_sink), event_notify, pending_read: Vec::new(), read_blocker: None, read_paused: false, need_disconnect: true, id: ID_COUNTER.fetch_add(1, Ordering::AcqRel) }));

(reader, us)
(reader, connection)
}

/// Process incoming messages and feed outgoing messages on the provided socket generated by
@@ -112,10 +112,10 @@ impl Connection {
/// You should poll the Receive end of event_notify and call get_and_clear_pending_events() on
/// ChannelManager and ChannelMonitor objects.
pub fn setup_inbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor<CMH>, Arc<CMH>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) {
let (reader, us) = Self::new(event_notify, stream);
let (reader, connection) = Self::new(event_notify, stream);

if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(us.clone(), peer_manager.clone())) {
Self::schedule_read(peer_manager, us, reader);
if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(connection.clone(), peer_manager.clone())) {
Self::schedule_read(peer_manager, connection, reader);
}
}

@@ -126,11 +126,11 @@ impl Connection {
/// You should poll the Receive end of event_notify and call get_and_clear_pending_events() on
/// ChannelManager and ChannelMonitor objects.
pub fn setup_outbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor<CMH>, Arc<CMH>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) {
let (reader, us) = Self::new(event_notify, stream);
let (reader, connection) = Self::new(event_notify, stream);

if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone(), peer_manager.clone())) {
if SocketDescriptor::new(us.clone(), peer_manager.clone()).send_data(&initial_send, true) == initial_send.len() {
Self::schedule_read(peer_manager, us, reader);
if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(connection.clone(), peer_manager.clone())) {
if SocketDescriptor::new(connection.clone(), peer_manager.clone()).send_data(&initial_send, true) == initial_send.len() {
Self::schedule_read(peer_manager, connection, reader);
} else {
println!("Failed to write first full message to socket!");
}
@@ -172,16 +172,16 @@ impl<CMH: ChannelMessageHandler> SocketDescriptor<CMH> {
impl<CMH: ChannelMessageHandler> peer_handler::SocketDescriptor for SocketDescriptor<CMH> {
fn send_data(&mut self, data: &[u8], resume_read: bool) -> usize {
macro_rules! schedule_read {
($us_ref: expr) => {
($descriptor: expr) => {
tokio::spawn(future::lazy(move || -> Result<(), ()> {
let mut read_data = Vec::new();
{
let mut us = $us_ref.conn.lock().unwrap();
mem::swap(&mut read_data, &mut us.pending_read);
let mut connection = $descriptor.conn.lock().unwrap();
mem::swap(&mut read_data, &mut connection.pending_read);
}
if !read_data.is_empty() {
let mut us_clone = $us_ref.clone();
match $us_ref.peer_manager.read_event(&mut us_clone, read_data) {
//let mut us_clone = $descriptor.clone();
match $descriptor.peer_manager.read_event(&mut $descriptor.clone(), read_data) {
Ok(pause_read) => {
if pause_read { return Ok(()); }
},
@@ -191,12 +191,12 @@ impl<CMH: ChannelMessageHandler> peer_handler::SocketDescriptor for SocketDescriptor<CMH> {
}
}
}
let mut us = $us_ref.conn.lock().unwrap();
if let Some(sender) = us.read_blocker.take() {
let mut connection = $descriptor.conn.lock().unwrap();
if let Some(sender) = connection.read_blocker.take() {
sender.send(Ok(())).unwrap();
}
us.read_paused = false;
if let Err(e) = us.event_notify.try_send(()) {
connection.read_paused = false;
if let Err(e) = connection.event_notify.try_send(()) {
// Ignore full errors as we just need them to poll after this point, so if the user
// hasn't received the last send yet, it doesn't matter.
assert!(e.is_full());
@@ -206,36 +206,36 @@ impl<CMH: ChannelMessageHandler> peer_handler::SocketDescriptor for SocketDescriptor<CMH> {
}
}

let mut us = self.conn.lock().unwrap();
let mut connection = self.conn.lock().unwrap();
if resume_read {
let us_ref = self.clone();
schedule_read!(us_ref);
let descriptor = self.clone();
schedule_read!(descriptor);
}
if data.is_empty() { return 0; }
if us.writer.is_none() {
us.read_paused = true;
if connection.writer.is_none() {
connection.read_paused = true;
return 0;
}

let mut bytes = bytes::BytesMut::with_capacity(data.len());
bytes.put(data);
let write_res = us.writer.as_mut().unwrap().start_send(bytes.freeze());
let write_res = connection.writer.as_mut().unwrap().start_send(bytes.freeze());
match write_res {
Ok(res) => {
match res {
AsyncSink::Ready => {
data.len()
},
AsyncSink::NotReady(_) => {
us.read_paused = true;
let us_ref = self.clone();
tokio::spawn(us.writer.take().unwrap().flush().then(move |writer_res| -> Result<(), ()> {
connection.read_paused = true;
let descriptor = self.clone();
tokio::spawn(connection.writer.take().unwrap().flush().then(move |writer_res| -> Result<(), ()> {
if let Ok(writer) = writer_res {
{
let mut us = us_ref.conn.lock().unwrap();
us.writer = Some(writer);
let mut connection = descriptor.conn.lock().unwrap();
connection.writer = Some(writer);
}
schedule_read!(us_ref);
schedule_read!(descriptor);
} // we'll fire the disconnect event on the socket reader end
Ok(())
}));
@@ -251,9 +251,9 @@ impl<CMH: ChannelMessageHandler> peer_handler::SocketDescriptor for SocketDescriptor<CMH> {
}

fn disconnect_socket(&mut self) {
let mut us = self.conn.lock().unwrap();
us.need_disconnect = true;
us.read_paused = true;
let mut connection = self.conn.lock().unwrap();
connection.need_disconnect = true;
connection.read_paused = true;
}
}
impl<CMH: ChannelMessageHandler> Clone for SocketDescriptor<CMH> {
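Since the lightning-net-tokio changes above are pure renames (us/us_ref → connection/descriptor), the API surface is unchanged. For orientation, here is a minimal usage sketch of setup_inbound as described by its doc comment, assuming the tokio 0.1 / futures 0.1 stack this crate targets; the function name `listen`, the bind address, and the already-constructed `peer_manager` / `event_notify` are illustrative, not code from this PR.

```rust
// Usage sketch only (not from this PR): accept TCP connections and hand them
// to Connection::setup_inbound. Assumes `peer_manager` and `event_notify`
// were built elsewhere; the bind address is an arbitrary example.
use std::sync::Arc;
use futures::{Future, Stream};
use futures::sync::mpsc;
use lightning::ln::msgs::ChannelMessageHandler;
use lightning::ln::peer_handler::PeerManager;
use lightning_net_tokio::{Connection, SocketDescriptor};

fn listen<CMH: ChannelMessageHandler + 'static>(
    peer_manager: Arc<PeerManager<SocketDescriptor<CMH>, Arc<CMH>>>,
    event_notify: mpsc::Sender<()>,
) -> impl Future<Item = (), Error = ()> {
    let listener = tokio::net::TcpListener::bind(&"0.0.0.0:9735".parse().unwrap())
        .expect("failed to bind");
    listener.incoming()
        .map_err(|_| ()) // drop accept errors in this sketch
        .for_each(move |stream| {
            // The library spawns its own read task per connection and signals
            // `event_notify` whenever pending events should be processed.
            Connection::setup_inbound(peer_manager.clone(), event_notify.clone(), stream);
            Ok(())
        })
}
```

setup_outbound follows the same shape, additionally writing the initial handshake bytes before scheduling the read loop, as the diff above shows.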
32 changes: 14 additions & 18 deletions lightning/src/ln/peer_handler.rs
@@ -18,7 +18,7 @@ use util::byte_utils;
use util::events::{MessageSendEvent, MessageSendEventsProvider};
use util::logger::Logger;

use std::collections::{HashMap,hash_map,HashSet,LinkedList};
use std::collections::{HashMap,hash_map,LinkedList};
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{cmp,error,hash,fmt};
@@ -120,6 +120,10 @@ struct Peer {
sync_status: InitSyncTracker,

awaiting_pong: bool,

/// Indicates do_read_event() pushed a message into pending_outbound_buffer but didn't call
/// do_attempt_write_data() to avoid reentrancy. Cleared in process_events().
needing_send: bool,
}

impl Peer {
@@ -140,9 +144,6 @@ impl Peer {

struct PeerHolder<Descriptor: SocketDescriptor> {
peers: HashMap<Descriptor, Peer>,
/// Added to by do_read_event for cases where we pushed a message onto the send buffer but
/// didn't call do_attempt_write_data to avoid reentrancy. Cleared in process_events()
peers_needing_send: HashSet<Descriptor>,
/// Only add to this set when noise completes:
node_id_to_descriptor: HashMap<PublicKey, Descriptor>,
}
@@ -228,7 +229,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
message_handler: message_handler,
peers: Mutex::new(PeerHolder {
peers: HashMap::new(),
peers_needing_send: HashSet::new(),
node_id_to_descriptor: HashMap::new()
}),
our_node_secret: our_node_secret,
@@ -299,6 +299,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
sync_status: InitSyncTracker::NoSyncRequested,

awaiting_pong: false,
needing_send: false,
}).is_some() {
panic!("PeerManager driver duplicated descriptors!");
};
@@ -336,6 +337,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
sync_status: InitSyncTracker::NoSyncRequested,

awaiting_pong: false,
needing_send: false,
}).is_some() {
panic!("PeerManager driver duplicated descriptors!");
};
@@ -485,7 +487,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
{
log_trace!(self, "Encoding and sending message of type {} to {}", $msg_code, log_pubkey!(peer.their_node_id.unwrap()));
peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encode_msg!($msg, $msg_code)[..]));
peers.peers_needing_send.insert(peer_descriptor.clone());
peer.needing_send = true;
}
}
}
@@ -644,7 +646,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where

if msg.features.initial_routing_sync() {
peer.sync_status = InitSyncTracker::ChannelsSyncing(0);
peers.peers_needing_send.insert(peer_descriptor.clone());
peer.needing_send = true;
Comment on lines -647 to +649

Contributor Author:

@TheBlueMatt Could you explain why this needs to be set here?

Collaborator:

If we need to stream out our routing db, we (obviously) don't copy the whole thing into our outbound send buffer (otherwise we'd just OOM ourselves...), so the logic in do_attempt_write_data has to be hit to add messages to the send buffer if the buffer is empty.

}

if !peer.outbound {
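To illustrate the collaborator's reply above: the peer's sync state is only a cursor, and the write path tops up a bounded outbound buffer from the routing table each time it runs, which is why `needing_send` must be set even though no message was queued yet. Below is a rough sketch of that refill loop with invented names (`OUTBOUND_BUFFER_LIMIT`, `SyncState`, `next_channel_batch`); it is not the crate's actual do_attempt_write_data.

```rust
// Sketch: incremental routing-table streaming driven from the write path.
use std::collections::VecDeque;

const OUTBOUND_BUFFER_LIMIT: usize = 10; // invented value for the sketch

enum SyncState { ChannelsSyncing(u64), Done }

struct PeerSketch {
    pending_outbound_buffer: VecDeque<Vec<u8>>,
    sync_status: SyncState,
}

fn attempt_write_sketch<F>(peer: &mut PeerSketch, next_channel_batch: F)
where F: Fn(u64) -> Vec<(Vec<u8>, u64)> /* (encoded msg, short_channel_id) */ {
    // Only refill while the buffer has room; process_events() gets us here via
    // the needing_send flag even when nothing was queued synchronously.
    while peer.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT {
        let start = match peer.sync_status {
            SyncState::ChannelsSyncing(c) => c,
            SyncState::Done => break,
        };
        let batch = next_channel_batch(start);
        if batch.is_empty() {
            peer.sync_status = SyncState::Done;
            break;
        }
        for (encoded_msg, short_channel_id) in batch {
            peer.pending_outbound_buffer.push_back(encoded_msg);
            peer.sync_status = SyncState::ChannelsSyncing(short_channel_id + 1);
        }
    }
    // ...then the real code flushes pending_outbound_buffer to the socket.
}

fn main() {
    let mut peer = PeerSketch {
        pending_outbound_buffer: VecDeque::new(),
        sync_status: SyncState::ChannelsSyncing(0),
    };
    // Pretend the graph has channels 0..=24; serve them three at a time.
    attempt_write_sketch(&mut peer, |start| {
        (start..(start + 3).min(25)).map(|id| (vec![id as u8], id)).collect()
    });
    // Refill stops once the buffer reaches the limit (at batch granularity).
    assert_eq!(peer.pending_outbound_buffer.len(), 12);
}
```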
@@ -1029,7 +1031,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
match *action {
msgs::ErrorAction::DisconnectPeer { ref msg } => {
if let Some(mut descriptor) = peers.node_id_to_descriptor.remove(node_id) {
peers.peers_needing_send.remove(&descriptor);
if let Some(mut peer) = peers.peers.remove(&descriptor) {
if let Some(ref msg) = *msg {
log_trace!(self, "Handling DisconnectPeer HandleError event in peer_handler for node {} with message {}",
@@ -1063,11 +1064,10 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
}
}

for mut descriptor in peers.peers_needing_send.drain() {
match peers.peers.get_mut(&descriptor) {
Some(peer) => self.do_attempt_write_data(&mut descriptor, peer),
None => panic!("Inconsistent peers set state!"),
}
let peers_needing_send = peers.peers.iter_mut().filter(|(_, peer)| peer.needing_send);
for (descriptor, peer) in peers_needing_send {
peer.needing_send = false;
self.do_attempt_write_data(&mut descriptor.clone(), peer)
}
}
}
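The rewritten loop above is the heart of the PR: because the flag now lives on the Peer itself, process_events() can never see a descriptor without a matching Peer entry, so the old `None => panic!("Inconsistent peers set state!")` arm becomes unrepresentable. A stripped-down, standalone sketch of the pattern with invented minimal types:

```rust
// Standalone sketch of the data-structure change: a per-peer flag instead of a
// separate HashSet of descriptors that had to be kept consistent by hand.
use std::collections::HashMap;

#[derive(Default)]
struct Peer {
    needing_send: bool,
    // ...buffers, handshake state, etc. in the real struct
}

struct PeerHolder {
    peers: HashMap<u64, Peer>, // keyed by socket descriptor in the real code
}

impl PeerHolder {
    fn process_events(&mut self) {
        // Every peer flagged by the read path gets a write attempt, and the flag
        // is cleared in the same pass; no second collection can drift out of sync.
        for (descriptor, peer) in self.peers.iter_mut().filter(|(_, p)| p.needing_send) {
            peer.needing_send = false;
            println!("would call do_attempt_write_data for descriptor {}", descriptor);
        }
    }
}

fn main() {
    let mut holder = PeerHolder { peers: HashMap::new() };
    holder.peers.insert(1, Peer { needing_send: true });
    holder.peers.insert(2, Peer::default());
    holder.process_events(); // only peer 1 gets a write attempt
}
```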
@@ -1084,9 +1084,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where

fn disconnect_event_internal(&self, descriptor: &Descriptor, no_connection_possible: bool) {
let mut peers = self.peers.lock().unwrap();
peers.peers_needing_send.remove(descriptor);
let peer_option = peers.peers.remove(descriptor);
match peer_option {
match peers.peers.remove(descriptor) {
None => panic!("Descriptor for disconnect_event is not already known to PeerManager"),
Some(peer) => {
match peer.their_node_id {
@@ -1108,13 +1106,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
let mut peers_lock = self.peers.lock().unwrap();
{
let peers = &mut *peers_lock;
let peers_needing_send = &mut peers.peers_needing_send;
let node_id_to_descriptor = &mut peers.node_id_to_descriptor;
let peers = &mut peers.peers;

peers.retain(|descriptor, peer| {
if peer.awaiting_pong == true {
peers_needing_send.remove(descriptor);
match peer.their_node_id {
Some(node_id) => {
node_id_to_descriptor.remove(&node_id);