diff --git a/src/chain/chaininterface.rs b/src/chain/chaininterface.rs index f456afda691..8fef60b3b2c 100644 --- a/src/chain/chaininterface.rs +++ b/src/chain/chaininterface.rs @@ -25,11 +25,9 @@ pub trait ChainWatchInterface: Sync + Send { //TODO: unregister } -/// An interface to send a transaction to connected Bitcoin peers. -/// This is for final settlement. An error might indicate that no peers can be reached or -/// that peers rejected the transaction. +/// An interface to send a transaction to the Bitcoin network. pub trait BroadcasterInterface: Sync + Send { - /// Sends a transaction out to (hopefully) be mined + /// Sends a transaction out to (hopefully) be mined. fn broadcast_transaction(&self, tx: &Transaction); } @@ -105,8 +103,9 @@ impl ChainWatchInterfaceUtil { } } - /// notify listener that a block was connected - /// notification will repeat if notified listener register new listeners + /// Notify listeners that a block was connected. + /// Handles re-scanning the block and calling block_connected again if listeners register new + /// watch data during the callbacks for you (see ChainListener::block_connected for more info). pub fn block_connected_with_filtering(&self, block: &Block, height: u32) { let mut reentered = true; while reentered { @@ -125,7 +124,7 @@ impl ChainWatchInterfaceUtil { } } - /// notify listener that a block was disconnected + /// Notify listeners that a block was disconnected. pub fn block_disconnected(&self, header: &BlockHeader) { let listeners = self.listeners.lock().unwrap().clone(); for listener in listeners.iter() { @@ -136,8 +135,10 @@ impl ChainWatchInterfaceUtil { } } - /// call listeners for connected blocks if they are still around. - /// returns true if notified listeners registered additional listener + /// Notify listeners that a block was connected. + /// Returns true if notified listeners registered additional watch data (implying that the + /// block must be re-scanned and this function called again prior to further block_connected + /// calls, see ChainListener::block_connected for more info). pub fn block_connected_checked(&self, header: &BlockHeader, height: u32, txn_matched: &[&Transaction], indexes_of_txn_matched: &[u32]) -> bool { let last_seen = self.reentered.load(Ordering::Relaxed); @@ -151,7 +152,7 @@ impl ChainWatchInterfaceUtil { return last_seen != self.reentered.load(Ordering::Relaxed); } - /// Checks if a given transaction matches the current filter + /// Checks if a given transaction matches the current filter. pub fn does_match_tx(&self, tx: &Transaction) -> bool { let watched = self.watched.lock().unwrap(); self.does_match_tx_unguarded (tx, &watched) diff --git a/src/ln/channel.rs b/src/ln/channel.rs index 61df7a018be..10cfb8bd925 100644 --- a/src/ln/channel.rs +++ b/src/ln/channel.rs @@ -374,7 +374,7 @@ impl Channel { if msg.htlc_minimum_msat >= (msg.funding_satoshis - msg.channel_reserve_satoshis) * 1000 { return Err(HandleError{err: "Minimum htlc value is full channel value", msg: None}); } - Channel::check_remote_fee(fee_estimator, msg.feerate_per_kw).unwrap(); + Channel::check_remote_fee(fee_estimator, msg.feerate_per_kw)?; if msg.to_self_delay > MAX_LOCAL_BREAKDOWN_TIMEOUT { return Err(HandleError{err: "They wanted our payments to be delayed by a needlessly long period", msg: None}); } diff --git a/src/ln/channelmanager.rs b/src/ln/channelmanager.rs index 64c84172c90..b1577b0a8a8 100644 --- a/src/ln/channelmanager.rs +++ b/src/ln/channelmanager.rs @@ -1211,6 +1211,7 @@ impl ChannelMessageHandler for ChannelManager { #[cfg(test)] mod tests { + use chain::chaininterface; use ln::channelmanager::{ChannelManager,OnionKeys}; use ln::router::{Route, RouteHop, Router}; use ln::msgs; @@ -1389,17 +1390,17 @@ mod tests { } static mut CHAN_COUNT: u16 = 0; - fn confirm_transaction(chain: &test_utils::TestWatchInterface, tx: &Transaction) { + fn confirm_transaction(chain: &chaininterface::ChainWatchInterfaceUtil, tx: &Transaction) { let mut header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 }; let chan_id = unsafe { CHAN_COUNT }; - chain.watch_util.block_connected_checked(&header, 1, &[tx; 1], &[chan_id as u32; 1]); + chain.block_connected_checked(&header, 1, &[tx; 1], &[chan_id as u32; 1]); for i in 2..100 { header = BlockHeader { version: 0x20000000, prev_blockhash: header.bitcoin_hash(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 }; - chain.watch_util.block_connected_checked(&header, i, &[tx; 0], &[0; 0]); + chain.block_connected_checked(&header, i, &[tx; 0], &[0; 0]); } } - fn create_chan_between_nodes(node_a: &ChannelManager, chain_a: &test_utils::TestWatchInterface, node_b: &ChannelManager, chain_b: &test_utils::TestWatchInterface) -> (msgs::ChannelAnnouncement, msgs::ChannelUpdate, msgs::ChannelUpdate) { + fn create_chan_between_nodes(node_a: &ChannelManager, chain_a: &chaininterface::ChainWatchInterfaceUtil, node_b: &ChannelManager, chain_b: &chaininterface::ChainWatchInterfaceUtil) -> (msgs::ChannelAnnouncement, msgs::ChannelUpdate, msgs::ChannelUpdate) { let open_chan = node_a.create_channel(node_b.get_our_node_id(), (1 << 24) - 1, 42).unwrap(); let accept_chan = node_b.handle_open_channel(&node_a.get_our_node_id(), &open_chan).unwrap(); node_a.handle_accept_channel(&node_b.get_our_node_id(), &accept_chan).unwrap(); @@ -1615,7 +1616,7 @@ mod tests { let secp_ctx = Secp256k1::new(); let feeest_1 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 }); - let chain_monitor_1 = Arc::new(test_utils::TestWatchInterface::new()); + let chain_monitor_1 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new()); let chan_monitor_1 = Arc::new(test_utils::TestChannelMonitor{}); let node_id_1 = { let mut key_slice = [0; 32]; @@ -1626,7 +1627,7 @@ mod tests { let router_1 = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id_1).unwrap()); let feeest_2 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 }); - let chain_monitor_2 = Arc::new(test_utils::TestWatchInterface::new()); + let chain_monitor_2 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new()); let chan_monitor_2 = Arc::new(test_utils::TestChannelMonitor{}); let node_id_2 = { let mut key_slice = [0; 32]; @@ -1637,7 +1638,7 @@ mod tests { let router_2 = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id_2).unwrap()); let feeest_3 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 }); - let chain_monitor_3 = Arc::new(test_utils::TestWatchInterface::new()); + let chain_monitor_3 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new()); let chan_monitor_3 = Arc::new(test_utils::TestChannelMonitor{}); let node_id_3 = { let mut key_slice = [0; 32]; @@ -1648,7 +1649,7 @@ mod tests { let router_3 = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id_3).unwrap()); let feeest_4 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 }); - let chain_monitor_4 = Arc::new(test_utils::TestWatchInterface::new()); + let chain_monitor_4 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new()); let chan_monitor_4 = Arc::new(test_utils::TestChannelMonitor{}); let node_id_4 = { let mut key_slice = [0; 32]; diff --git a/src/ln/channelmonitor.rs b/src/ln/channelmonitor.rs index ac08bf5a37b..d5529af9c99 100644 --- a/src/ln/channelmonitor.rs +++ b/src/ln/channelmonitor.rs @@ -456,7 +456,7 @@ impl ChannelMonitor { for txin in tx.input.iter() { if self.funding_txo.is_none() || (txin.prev_hash == self.funding_txo.unwrap().0 && txin.prev_index == self.funding_txo.unwrap().1 as u32) { for tx in self.check_spend_transaction(tx, height).iter() { - broadcaster.broadcast_transaction(tx); // TODO: use result + broadcaster.broadcast_transaction(tx); } } } diff --git a/src/ln/msgs.rs b/src/ln/msgs.rs index 259f90d7d10..3c56aa120ec 100644 --- a/src/ln/msgs.rs +++ b/src/ln/msgs.rs @@ -20,6 +20,8 @@ pub enum DecodeError { UnknownRealmByte, /// Failed to decode a public key (ie it's invalid) BadPublicKey, + /// Failed to decode a signature (ie it's invalid) + BadSignature, /// Buffer not of right length (either too short or too long) WrongLength, } @@ -408,6 +410,7 @@ impl Error for DecodeError { match *self { DecodeError::UnknownRealmByte => "Unknown realm byte in Onion packet", DecodeError::BadPublicKey => "Invalid public key in packet", + DecodeError::BadSignature => "Invalid signature in packet", DecodeError::WrongLength => "Data was wrong length for packet", } } @@ -433,11 +436,20 @@ macro_rules! secp_pubkey { }; } +macro_rules! secp_signature { + ( $ctx: expr, $slice: expr ) => { + match Signature::from_compact($ctx, $slice) { + Ok(sig) => sig, + Err(_) => return Err(DecodeError::BadSignature) + } + }; +} + impl MsgDecodable for LocalFeatures { fn decode(v: &[u8]) -> Result { if v.len() < 3 { return Err(DecodeError::WrongLength); } let len = byte_utils::slice_to_be16(&v[0..2]) as usize; - if v.len() != len + 2 { return Err(DecodeError::WrongLength); } + if v.len() < len + 2 { return Err(DecodeError::WrongLength); } let mut flags = Vec::with_capacity(len); flags.extend_from_slice(&v[2..]); Ok(Self { @@ -458,7 +470,7 @@ impl MsgDecodable for GlobalFeatures { fn decode(v: &[u8]) -> Result { if v.len() < 3 { return Err(DecodeError::WrongLength); } let len = byte_utils::slice_to_be16(&v[0..2]) as usize; - if v.len() != len + 2 { return Err(DecodeError::WrongLength); } + if v.len() < len + 2 { return Err(DecodeError::WrongLength); } let mut flags = Vec::with_capacity(len); flags.extend_from_slice(&v[2..]); Ok(Self { @@ -478,13 +490,10 @@ impl MsgEncodable for GlobalFeatures { impl MsgDecodable for Init { fn decode(v: &[u8]) -> Result { let global_features = GlobalFeatures::decode(v)?; - if global_features.flags.len() + 4 <= v.len() { + if v.len() < global_features.flags.len() + 4 { return Err(DecodeError::WrongLength); } let local_features = LocalFeatures::decode(&v[global_features.flags.len() + 2..])?; - if global_features.flags.len() + local_features.flags.len() + 4 != v.len() { - return Err(DecodeError::WrongLength); - } Ok(Self { global_features: global_features, local_features: local_features, @@ -502,24 +511,20 @@ impl MsgEncodable for Init { impl MsgDecodable for OpenChannel { fn decode(v: &[u8]) -> Result { - if v.len() != 2*32+6*8+4+2*2+6*33+1 { + if v.len() < 2*32+6*8+4+2*2+6*33+1 { return Err(DecodeError::WrongLength); } let ctx = Secp256k1::without_caps(); - let funding_pubkey = secp_pubkey!(&ctx, &v[120..153]); - let revocation_basepoint = secp_pubkey!(&ctx, &v[153..186]); - let payment_basepoint = secp_pubkey!(&ctx, &v[186..219]); - let delayed_payment_basepoint = secp_pubkey!(&ctx, &v[219..252]); - let htlc_basepoint = secp_pubkey!(&ctx, &v[252..285]); - let first_per_commitment_point = secp_pubkey!(&ctx, &v[285..318]); let mut shutdown_scriptpubkey = None; if v.len() >= 321 { let len = byte_utils::slice_to_be16(&v[319..321]) as usize; - if v.len() != 321+len { + if v.len() < 321+len { return Err(DecodeError::WrongLength); } shutdown_scriptpubkey = Some(Script::from(v[321..321+len].to_vec())); + } else if v.len() != 2*32+6*8+4+2*2+6*33+1 { // Message cant have 1 extra byte + return Err(DecodeError::WrongLength); } Ok(OpenChannel { @@ -534,12 +539,12 @@ impl MsgDecodable for OpenChannel { feerate_per_kw: byte_utils::slice_to_be32(&v[112..116]), to_self_delay: byte_utils::slice_to_be16(&v[116..118]), max_accepted_htlcs: byte_utils::slice_to_be16(&v[118..120]), - funding_pubkey: funding_pubkey, - revocation_basepoint: revocation_basepoint, - payment_basepoint: payment_basepoint, - delayed_payment_basepoint: delayed_payment_basepoint, - htlc_basepoint: htlc_basepoint, - first_per_commitment_point: first_per_commitment_point, + funding_pubkey: secp_pubkey!(&ctx, &v[120..153]), + revocation_basepoint: secp_pubkey!(&ctx, &v[153..186]), + payment_basepoint: secp_pubkey!(&ctx, &v[186..219]), + delayed_payment_basepoint: secp_pubkey!(&ctx, &v[219..252]), + htlc_basepoint: secp_pubkey!(&ctx, &v[252..285]), + first_per_commitment_point: secp_pubkey!(&ctx, &v[285..318]), channel_flags: v[318], shutdown_scriptpubkey: shutdown_scriptpubkey }) @@ -551,10 +556,41 @@ impl MsgEncodable for OpenChannel { } } - impl MsgDecodable for AcceptChannel { - fn decode(_v: &[u8]) -> Result { - unimplemented!(); + fn decode(v: &[u8]) -> Result { + if v.len() < 32+4*8+4+2*2+6*33 { + return Err(DecodeError::WrongLength); + } + let ctx = Secp256k1::without_caps(); + + let mut shutdown_scriptpubkey = None; + if v.len() >= 272 { + let len = byte_utils::slice_to_be16(&v[270..272]) as usize; + if v.len() < 272+len { + return Err(DecodeError::WrongLength); + } + shutdown_scriptpubkey = Some(Script::from(v[272..272+len].to_vec())); + } else if v.len() != 32+4*8+4+2*2+6*33 { // Message cant have 1 extra byte + return Err(DecodeError::WrongLength); + } + + Ok(Self { + temporary_channel_id: deserialize(&v[0..32]).unwrap(), + dust_limit_satoshis: byte_utils::slice_to_be64(&v[32..40]), + max_htlc_value_in_flight_msat: byte_utils::slice_to_be64(&v[40..48]), + channel_reserve_satoshis: byte_utils::slice_to_be64(&v[48..56]), + htlc_minimum_msat: byte_utils::slice_to_be64(&v[56..64]), + minimum_depth: byte_utils::slice_to_be32(&v[64..68]), + to_self_delay: byte_utils::slice_to_be16(&v[68..70]), + max_accepted_htlcs: byte_utils::slice_to_be16(&v[70..72]), + funding_pubkey: secp_pubkey!(&ctx, &v[72..105]), + revocation_basepoint: secp_pubkey!(&ctx, &v[105..138]), + payment_basepoint: secp_pubkey!(&ctx, &v[138..171]), + delayed_payment_basepoint: secp_pubkey!(&ctx, &v[171..204]), + htlc_basepoint: secp_pubkey!(&ctx, &v[204..237]), + first_per_commitment_point: secp_pubkey!(&ctx, &v[237..270]), + shutdown_scriptpubkey: shutdown_scriptpubkey + }) } } impl MsgEncodable for AcceptChannel { @@ -564,8 +600,17 @@ impl MsgEncodable for AcceptChannel { } impl MsgDecodable for FundingCreated { - fn decode(_v: &[u8]) -> Result { - unimplemented!(); + fn decode(v: &[u8]) -> Result { + if v.len() < 32+32+2+64 { + return Err(DecodeError::WrongLength); + } + let ctx = Secp256k1::without_caps(); + Ok(Self { + temporary_channel_id: deserialize(&v[0..32]).unwrap(), + funding_txid: deserialize(&v[32..64]).unwrap(), + funding_output_index: byte_utils::slice_to_be16(&v[64..66]), + signature: secp_signature!(&ctx, &v[66..130]), + }) } } impl MsgEncodable for FundingCreated { @@ -575,8 +620,15 @@ impl MsgEncodable for FundingCreated { } impl MsgDecodable for FundingSigned { - fn decode(_v: &[u8]) -> Result { - unimplemented!(); + fn decode(v: &[u8]) -> Result { + if v.len() < 32+64 { + return Err(DecodeError::WrongLength); + } + let ctx = Secp256k1::without_caps(); + Ok(Self { + channel_id: deserialize(&v[0..32]).unwrap(), + signature: secp_signature!(&ctx, &v[32..96]), + }) } } impl MsgEncodable for FundingSigned { @@ -586,8 +638,15 @@ impl MsgEncodable for FundingSigned { } impl MsgDecodable for FundingLocked { - fn decode(_v: &[u8]) -> Result { - unimplemented!(); + fn decode(v: &[u8]) -> Result { + if v.len() < 32+33 { + return Err(DecodeError::WrongLength); + } + let ctx = Secp256k1::without_caps(); + Ok(Self { + channel_id: deserialize(&v[0..32]).unwrap(), + next_per_commitment_point: secp_pubkey!(&ctx, &v[32..65]), + }) } } impl MsgEncodable for FundingLocked { @@ -839,7 +898,7 @@ impl MsgEncodable for ChannelUpdate { impl MsgDecodable for OnionRealm0HopData { fn decode(v: &[u8]) -> Result { - if v.len() != 32 { + if v.len() < 32 { return Err(DecodeError::WrongLength); } Ok(OnionRealm0HopData { @@ -862,7 +921,7 @@ impl MsgEncodable for OnionRealm0HopData { impl MsgDecodable for OnionHopData { fn decode(v: &[u8]) -> Result { - if v.len() != 65 { + if v.len() < 65 { return Err(DecodeError::WrongLength); } let realm = v[0]; diff --git a/src/ln/peer_handler.rs b/src/ln/peer_handler.rs index bddc87b99fd..74848ad77b3 100644 --- a/src/ln/peer_handler.rs +++ b/src/ln/peer_handler.rs @@ -8,13 +8,11 @@ use util::events::{EventsProvider,Event}; use std::collections::{HashMap,LinkedList}; use std::sync::{Arc, Mutex}; -use std::cmp; -use std::mem; -use std::hash; +use std::{cmp,mem,hash,fmt}; pub struct MessageHandler { - pub chan_handler: Arc, - pub route_handler: Arc, + pub chan_handler: Arc, + pub route_handler: Arc, } /// Provides an object which can be used to send data to and which uniquely identifies a connection @@ -43,6 +41,11 @@ pub trait SocketDescriptor : cmp::Eq + hash::Hash + Clone { /// disconnect_event (unless it was provided in response to a new_*_connection event, in which case /// no such disconnect_event must be generated and the socket be silently disconencted). pub struct PeerHandleError {} +impl fmt::Debug for PeerHandleError { + fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { + formatter.write_str("Peer Send Invalid Data") + } +} struct Peer { channel_encryptor: PeerChannelEncryptor, @@ -206,6 +209,16 @@ impl PeerManager { /// course of this function! /// Panics if the descriptor was not previously registered in a new_*_connection event. pub fn read_event(&self, peer_descriptor: &mut Descriptor, data: Vec) -> Result { + match self.do_read_event(peer_descriptor, data) { + Ok(res) => Ok(res), + Err(e) => { + self.disconnect_event(peer_descriptor); + Err(e) + } + } + } + + fn do_read_event(&self, peer_descriptor: &mut Descriptor, data: Vec) -> Result { let mut upstream_events = Vec::new(); let pause_read = { let mut peers = self.peers.lock().unwrap(); diff --git a/src/util/test_utils.rs b/src/util/test_utils.rs index 1626b60b2ed..df13e1c9025 100644 --- a/src/util/test_utils.rs +++ b/src/util/test_utils.rs @@ -4,10 +4,6 @@ use ln::channelmonitor; use ln::msgs::HandleError; use bitcoin::util::hash::Sha256dHash; -use bitcoin::blockdata::transaction::Transaction; -use bitcoin::blockdata::script::Script; - -use std::sync::Weak; pub struct TestFeeEstimator { pub sat_per_vbyte: u64, @@ -18,31 +14,6 @@ impl chaininterface::FeeEstimator for TestFeeEstimator { } } -pub struct TestWatchInterface { - pub watch_util: chaininterface::ChainWatchInterfaceUtil, -} -impl chaininterface::ChainWatchInterface for TestWatchInterface { - fn install_watch_script(&self, _script_pub_key: Script) { - unimplemented!(); - } - fn install_watch_outpoint(&self, _outpoint: (Sha256dHash, u32)) { - unimplemented!(); - } - fn watch_all_txn(&self) { - unimplemented!(); - } - fn register_listener(&self, listener: Weak) { - self.watch_util.register_listener(listener); - } -} -impl TestWatchInterface { - pub fn new() -> TestWatchInterface { - TestWatchInterface { - watch_util: chaininterface::ChainWatchInterfaceUtil::new(), - } - } -} - pub struct TestChannelMonitor { }