Skip to content

Few random fixes #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/chain/chaininterface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@ pub trait ChainWatchInterface: Sync + Send {
//TODO: unregister
}

/// An interface to send a transaction to connected Bitcoin peers.
/// This is for final settlement. An error might indicate that no peers can be reached or
/// that peers rejected the transaction.
/// An interface to send a transaction to the Bitcoin network.
pub trait BroadcasterInterface: Sync + Send {
/// Sends a transaction out to (hopefully) be mined
/// Sends a transaction out to (hopefully) be mined.
fn broadcast_transaction(&self, tx: &Transaction);
}

Expand Down Expand Up @@ -105,8 +103,9 @@ impl ChainWatchInterfaceUtil {
}
}

/// notify listener that a block was connected
/// notification will repeat if notified listener register new listeners
/// Notify listeners that a block was connected.
/// Handles re-scanning the block and calling block_connected again if listeners register new
/// watch data during the callbacks for you (see ChainListener::block_connected for more info).
pub fn block_connected_with_filtering(&self, block: &Block, height: u32) {
let mut reentered = true;
while reentered {
Expand All @@ -125,7 +124,7 @@ impl ChainWatchInterfaceUtil {
}
}

/// notify listener that a block was disconnected
/// Notify listeners that a block was disconnected.
pub fn block_disconnected(&self, header: &BlockHeader) {
let listeners = self.listeners.lock().unwrap().clone();
for listener in listeners.iter() {
Expand All @@ -136,8 +135,10 @@ impl ChainWatchInterfaceUtil {
}
}

/// call listeners for connected blocks if they are still around.
/// returns true if notified listeners registered additional listener
/// Notify listeners that a block was connected.
/// Returns true if notified listeners registered additional watch data (implying that the
/// block must be re-scanned and this function called again prior to further block_connected
/// calls, see ChainListener::block_connected for more info).
pub fn block_connected_checked(&self, header: &BlockHeader, height: u32, txn_matched: &[&Transaction], indexes_of_txn_matched: &[u32]) -> bool {
let last_seen = self.reentered.load(Ordering::Relaxed);

Expand All @@ -151,7 +152,7 @@ impl ChainWatchInterfaceUtil {
return last_seen != self.reentered.load(Ordering::Relaxed);
}

/// Checks if a given transaction matches the current filter
/// Checks if a given transaction matches the current filter.
pub fn does_match_tx(&self, tx: &Transaction) -> bool {
let watched = self.watched.lock().unwrap();
self.does_match_tx_unguarded (tx, &watched)
Expand Down
2 changes: 1 addition & 1 deletion src/ln/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ impl Channel {
if msg.htlc_minimum_msat >= (msg.funding_satoshis - msg.channel_reserve_satoshis) * 1000 {
return Err(HandleError{err: "Minimum htlc value is full channel value", msg: None});
}
Channel::check_remote_fee(fee_estimator, msg.feerate_per_kw).unwrap();
Channel::check_remote_fee(fee_estimator, msg.feerate_per_kw)?;
if msg.to_self_delay > MAX_LOCAL_BREAKDOWN_TIMEOUT {
return Err(HandleError{err: "They wanted our payments to be delayed by a needlessly long period", msg: None});
}
Expand Down
17 changes: 9 additions & 8 deletions src/ln/channelmanager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,7 @@ impl ChannelMessageHandler for ChannelManager {

#[cfg(test)]
mod tests {
use chain::chaininterface;
use ln::channelmanager::{ChannelManager,OnionKeys};
use ln::router::{Route, RouteHop, Router};
use ln::msgs;
Expand Down Expand Up @@ -1389,17 +1390,17 @@ mod tests {
}

static mut CHAN_COUNT: u16 = 0;
fn confirm_transaction(chain: &test_utils::TestWatchInterface, tx: &Transaction) {
fn confirm_transaction(chain: &chaininterface::ChainWatchInterfaceUtil, tx: &Transaction) {
let mut header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
let chan_id = unsafe { CHAN_COUNT };
chain.watch_util.block_connected_checked(&header, 1, &[tx; 1], &[chan_id as u32; 1]);
chain.block_connected_checked(&header, 1, &[tx; 1], &[chan_id as u32; 1]);
for i in 2..100 {
header = BlockHeader { version: 0x20000000, prev_blockhash: header.bitcoin_hash(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
chain.watch_util.block_connected_checked(&header, i, &[tx; 0], &[0; 0]);
chain.block_connected_checked(&header, i, &[tx; 0], &[0; 0]);
}
}

fn create_chan_between_nodes(node_a: &ChannelManager, chain_a: &test_utils::TestWatchInterface, node_b: &ChannelManager, chain_b: &test_utils::TestWatchInterface) -> (msgs::ChannelAnnouncement, msgs::ChannelUpdate, msgs::ChannelUpdate) {
fn create_chan_between_nodes(node_a: &ChannelManager, chain_a: &chaininterface::ChainWatchInterfaceUtil, node_b: &ChannelManager, chain_b: &chaininterface::ChainWatchInterfaceUtil) -> (msgs::ChannelAnnouncement, msgs::ChannelUpdate, msgs::ChannelUpdate) {
let open_chan = node_a.create_channel(node_b.get_our_node_id(), (1 << 24) - 1, 42).unwrap();
let accept_chan = node_b.handle_open_channel(&node_a.get_our_node_id(), &open_chan).unwrap();
node_a.handle_accept_channel(&node_b.get_our_node_id(), &accept_chan).unwrap();
Expand Down Expand Up @@ -1615,7 +1616,7 @@ mod tests {
let secp_ctx = Secp256k1::new();

let feeest_1 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 });
let chain_monitor_1 = Arc::new(test_utils::TestWatchInterface::new());
let chain_monitor_1 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new());
let chan_monitor_1 = Arc::new(test_utils::TestChannelMonitor{});
let node_id_1 = {
let mut key_slice = [0; 32];
Expand All @@ -1626,7 +1627,7 @@ mod tests {
let router_1 = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id_1).unwrap());

let feeest_2 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 });
let chain_monitor_2 = Arc::new(test_utils::TestWatchInterface::new());
let chain_monitor_2 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new());
let chan_monitor_2 = Arc::new(test_utils::TestChannelMonitor{});
let node_id_2 = {
let mut key_slice = [0; 32];
Expand All @@ -1637,7 +1638,7 @@ mod tests {
let router_2 = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id_2).unwrap());

let feeest_3 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 });
let chain_monitor_3 = Arc::new(test_utils::TestWatchInterface::new());
let chain_monitor_3 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new());
let chan_monitor_3 = Arc::new(test_utils::TestChannelMonitor{});
let node_id_3 = {
let mut key_slice = [0; 32];
Expand All @@ -1648,7 +1649,7 @@ mod tests {
let router_3 = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id_3).unwrap());

let feeest_4 = Arc::new(test_utils::TestFeeEstimator { sat_per_vbyte: 1 });
let chain_monitor_4 = Arc::new(test_utils::TestWatchInterface::new());
let chain_monitor_4 = Arc::new(chaininterface::ChainWatchInterfaceUtil::new());
let chan_monitor_4 = Arc::new(test_utils::TestChannelMonitor{});
let node_id_4 = {
let mut key_slice = [0; 32];
Expand Down
2 changes: 1 addition & 1 deletion src/ln/channelmonitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ impl ChannelMonitor {
for txin in tx.input.iter() {
if self.funding_txo.is_none() || (txin.prev_hash == self.funding_txo.unwrap().0 && txin.prev_index == self.funding_txo.unwrap().1 as u32) {
for tx in self.check_spend_transaction(tx, height).iter() {
broadcaster.broadcast_transaction(tx); // TODO: use result
broadcaster.broadcast_transaction(tx);
}
}
}
Expand Down
121 changes: 90 additions & 31 deletions src/ln/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ pub enum DecodeError {
UnknownRealmByte,
/// Failed to decode a public key (ie it's invalid)
BadPublicKey,
/// Failed to decode a signature (ie it's invalid)
BadSignature,
/// Buffer not of right length (either too short or too long)
WrongLength,
}
Expand Down Expand Up @@ -408,6 +410,7 @@ impl Error for DecodeError {
match *self {
DecodeError::UnknownRealmByte => "Unknown realm byte in Onion packet",
DecodeError::BadPublicKey => "Invalid public key in packet",
DecodeError::BadSignature => "Invalid signature in packet",
DecodeError::WrongLength => "Data was wrong length for packet",
}
}
Expand All @@ -433,11 +436,20 @@ macro_rules! secp_pubkey {
};
}

macro_rules! secp_signature {
( $ctx: expr, $slice: expr ) => {
match Signature::from_compact($ctx, $slice) {
Ok(sig) => sig,
Err(_) => return Err(DecodeError::BadSignature)
}
};
}

impl MsgDecodable for LocalFeatures {
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() < 3 { return Err(DecodeError::WrongLength); }
let len = byte_utils::slice_to_be16(&v[0..2]) as usize;
if v.len() != len + 2 { return Err(DecodeError::WrongLength); }
if v.len() < len + 2 { return Err(DecodeError::WrongLength); }
let mut flags = Vec::with_capacity(len);
flags.extend_from_slice(&v[2..]);
Ok(Self {
Expand All @@ -458,7 +470,7 @@ impl MsgDecodable for GlobalFeatures {
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() < 3 { return Err(DecodeError::WrongLength); }
let len = byte_utils::slice_to_be16(&v[0..2]) as usize;
if v.len() != len + 2 { return Err(DecodeError::WrongLength); }
if v.len() < len + 2 { return Err(DecodeError::WrongLength); }
let mut flags = Vec::with_capacity(len);
flags.extend_from_slice(&v[2..]);
Ok(Self {
Expand All @@ -478,13 +490,10 @@ impl MsgEncodable for GlobalFeatures {
impl MsgDecodable for Init {
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
let global_features = GlobalFeatures::decode(v)?;
if global_features.flags.len() + 4 <= v.len() {
if v.len() < global_features.flags.len() + 4 {
return Err(DecodeError::WrongLength);
}
let local_features = LocalFeatures::decode(&v[global_features.flags.len() + 2..])?;
if global_features.flags.len() + local_features.flags.len() + 4 != v.len() {
return Err(DecodeError::WrongLength);
}
Ok(Self {
global_features: global_features,
local_features: local_features,
Expand All @@ -502,24 +511,20 @@ impl MsgEncodable for Init {

impl MsgDecodable for OpenChannel {
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() != 2*32+6*8+4+2*2+6*33+1 {
if v.len() < 2*32+6*8+4+2*2+6*33+1 {
return Err(DecodeError::WrongLength);
}
let ctx = Secp256k1::without_caps();
let funding_pubkey = secp_pubkey!(&ctx, &v[120..153]);
let revocation_basepoint = secp_pubkey!(&ctx, &v[153..186]);
let payment_basepoint = secp_pubkey!(&ctx, &v[186..219]);
let delayed_payment_basepoint = secp_pubkey!(&ctx, &v[219..252]);
let htlc_basepoint = secp_pubkey!(&ctx, &v[252..285]);
let first_per_commitment_point = secp_pubkey!(&ctx, &v[285..318]);

let mut shutdown_scriptpubkey = None;
if v.len() >= 321 {
let len = byte_utils::slice_to_be16(&v[319..321]) as usize;
if v.len() != 321+len {
if v.len() < 321+len {
return Err(DecodeError::WrongLength);
}
shutdown_scriptpubkey = Some(Script::from(v[321..321+len].to_vec()));
} else if v.len() != 2*32+6*8+4+2*2+6*33+1 { // Message cant have 1 extra byte
return Err(DecodeError::WrongLength);
}

Ok(OpenChannel {
Expand All @@ -534,12 +539,12 @@ impl MsgDecodable for OpenChannel {
feerate_per_kw: byte_utils::slice_to_be32(&v[112..116]),
to_self_delay: byte_utils::slice_to_be16(&v[116..118]),
max_accepted_htlcs: byte_utils::slice_to_be16(&v[118..120]),
funding_pubkey: funding_pubkey,
revocation_basepoint: revocation_basepoint,
payment_basepoint: payment_basepoint,
delayed_payment_basepoint: delayed_payment_basepoint,
htlc_basepoint: htlc_basepoint,
first_per_commitment_point: first_per_commitment_point,
funding_pubkey: secp_pubkey!(&ctx, &v[120..153]),
revocation_basepoint: secp_pubkey!(&ctx, &v[153..186]),
payment_basepoint: secp_pubkey!(&ctx, &v[186..219]),
delayed_payment_basepoint: secp_pubkey!(&ctx, &v[219..252]),
htlc_basepoint: secp_pubkey!(&ctx, &v[252..285]),
first_per_commitment_point: secp_pubkey!(&ctx, &v[285..318]),
channel_flags: v[318],
shutdown_scriptpubkey: shutdown_scriptpubkey
})
Expand All @@ -551,10 +556,41 @@ impl MsgEncodable for OpenChannel {
}
}


impl MsgDecodable for AcceptChannel {
fn decode(_v: &[u8]) -> Result<Self, DecodeError> {
unimplemented!();
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() < 32+4*8+4+2*2+6*33 {
return Err(DecodeError::WrongLength);
}
let ctx = Secp256k1::without_caps();

let mut shutdown_scriptpubkey = None;
if v.len() >= 272 {
let len = byte_utils::slice_to_be16(&v[270..272]) as usize;
if v.len() < 272+len {
return Err(DecodeError::WrongLength);
}
shutdown_scriptpubkey = Some(Script::from(v[272..272+len].to_vec()));
} else if v.len() != 32+4*8+4+2*2+6*33 { // Message cant have 1 extra byte
return Err(DecodeError::WrongLength);
}

Ok(Self {
temporary_channel_id: deserialize(&v[0..32]).unwrap(),
dust_limit_satoshis: byte_utils::slice_to_be64(&v[32..40]),
max_htlc_value_in_flight_msat: byte_utils::slice_to_be64(&v[40..48]),
channel_reserve_satoshis: byte_utils::slice_to_be64(&v[48..56]),
htlc_minimum_msat: byte_utils::slice_to_be64(&v[56..64]),
minimum_depth: byte_utils::slice_to_be32(&v[64..68]),
to_self_delay: byte_utils::slice_to_be16(&v[68..70]),
max_accepted_htlcs: byte_utils::slice_to_be16(&v[70..72]),
funding_pubkey: secp_pubkey!(&ctx, &v[72..105]),
revocation_basepoint: secp_pubkey!(&ctx, &v[105..138]),
payment_basepoint: secp_pubkey!(&ctx, &v[138..171]),
delayed_payment_basepoint: secp_pubkey!(&ctx, &v[171..204]),
htlc_basepoint: secp_pubkey!(&ctx, &v[204..237]),
first_per_commitment_point: secp_pubkey!(&ctx, &v[237..270]),
shutdown_scriptpubkey: shutdown_scriptpubkey
})
}
}
impl MsgEncodable for AcceptChannel {
Expand All @@ -564,8 +600,17 @@ impl MsgEncodable for AcceptChannel {
}

impl MsgDecodable for FundingCreated {
fn decode(_v: &[u8]) -> Result<Self, DecodeError> {
unimplemented!();
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() < 32+32+2+64 {
return Err(DecodeError::WrongLength);
}
let ctx = Secp256k1::without_caps();
Ok(Self {
temporary_channel_id: deserialize(&v[0..32]).unwrap(),
funding_txid: deserialize(&v[32..64]).unwrap(),
funding_output_index: byte_utils::slice_to_be16(&v[64..66]),
signature: secp_signature!(&ctx, &v[66..130]),
})
}
}
impl MsgEncodable for FundingCreated {
Expand All @@ -575,8 +620,15 @@ impl MsgEncodable for FundingCreated {
}

impl MsgDecodable for FundingSigned {
fn decode(_v: &[u8]) -> Result<Self, DecodeError> {
unimplemented!();
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() < 32+64 {
return Err(DecodeError::WrongLength);
}
let ctx = Secp256k1::without_caps();
Ok(Self {
channel_id: deserialize(&v[0..32]).unwrap(),
signature: secp_signature!(&ctx, &v[32..96]),
})
}
}
impl MsgEncodable for FundingSigned {
Expand All @@ -586,8 +638,15 @@ impl MsgEncodable for FundingSigned {
}

impl MsgDecodable for FundingLocked {
fn decode(_v: &[u8]) -> Result<Self, DecodeError> {
unimplemented!();
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() < 32+33 {
return Err(DecodeError::WrongLength);
}
let ctx = Secp256k1::without_caps();
Ok(Self {
channel_id: deserialize(&v[0..32]).unwrap(),
next_per_commitment_point: secp_pubkey!(&ctx, &v[32..65]),
})
}
}
impl MsgEncodable for FundingLocked {
Expand Down Expand Up @@ -839,7 +898,7 @@ impl MsgEncodable for ChannelUpdate {

impl MsgDecodable for OnionRealm0HopData {
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() != 32 {
if v.len() < 32 {
return Err(DecodeError::WrongLength);
}
Ok(OnionRealm0HopData {
Expand All @@ -862,7 +921,7 @@ impl MsgEncodable for OnionRealm0HopData {

impl MsgDecodable for OnionHopData {
fn decode(v: &[u8]) -> Result<Self, DecodeError> {
if v.len() != 65 {
if v.len() < 65 {
return Err(DecodeError::WrongLength);
}
let realm = v[0];
Expand Down
Loading