diff --git a/Cargo.toml b/Cargo.toml index 56f4ac32f3f..c43e7927581 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "lightning", "lightning-net-tokio", + "lightning-persister", ] # Our tests do actual crypo and lots of work, the tradeoff for -O1 is well worth it. diff --git a/fuzz/src/chanmon_consistency.rs b/fuzz/src/chanmon_consistency.rs index 1650e2e25f7..d88cc71fbf5 100644 --- a/fuzz/src/chanmon_consistency.rs +++ b/fuzz/src/chanmon_consistency.rs @@ -48,6 +48,7 @@ use lightning::routing::router::{Route, RouteHop}; use utils::test_logger; +use utils::test_persister::TestPersister; use bitcoin::secp256k1::key::{PublicKey,SecretKey}; use bitcoin::secp256k1::Secp256k1; @@ -84,7 +85,7 @@ impl Writer for VecWriter { struct TestChainMonitor { pub logger: Arc, - pub chain_monitor: Arc, Arc, Arc, Arc>>, + pub chain_monitor: Arc, Arc, Arc, Arc, Arc>>, pub update_ret: Mutex>, // If we reload a node with an old copy of ChannelMonitors, the ChannelManager deserialization // logic will automatically force-close our channels for us (as we don't have an up-to-date @@ -95,9 +96,9 @@ struct TestChainMonitor { pub should_update_manager: atomic::AtomicBool, } impl TestChainMonitor { - pub fn new(broadcaster: Arc, logger: Arc, feeest: Arc) -> Self { + pub fn new(broadcaster: Arc, logger: Arc, feeest: Arc, persister: Arc) -> Self { Self { - chain_monitor: Arc::new(chainmonitor::ChainMonitor::new(None, broadcaster, logger.clone(), feeest)), + chain_monitor: Arc::new(chainmonitor::ChainMonitor::new(None, broadcaster, logger.clone(), feeest, persister)), logger, update_ret: Mutex::new(Ok(())), latest_monitors: Mutex::new(HashMap::new()), @@ -110,7 +111,7 @@ impl chain::Watch for TestChainMonitor { fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { let mut ser = VecWriter(Vec::new()); - monitor.write_for_disk(&mut ser).unwrap(); + monitor.serialize_for_disk(&mut ser).unwrap(); if let Some(_) = self.latest_monitors.lock().unwrap().insert(funding_txo, (monitor.get_latest_update_id(), ser.0)) { panic!("Already had monitor pre-watch_channel"); } @@ -127,9 +128,9 @@ impl chain::Watch for TestChainMonitor { }; let mut deserialized_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>:: read(&mut Cursor::new(&map_entry.get().1)).unwrap().1; - deserialized_monitor.update_monitor(update.clone(), &&TestBroadcaster {}, &self.logger).unwrap(); + deserialized_monitor.update_monitor(&update, &&TestBroadcaster {}, &self.logger).unwrap(); let mut ser = VecWriter(Vec::new()); - deserialized_monitor.write_for_disk(&mut ser).unwrap(); + deserialized_monitor.serialize_for_disk(&mut ser).unwrap(); map_entry.insert((update.update_id, ser.0)); self.should_update_manager.store(true, atomic::Ordering::Relaxed); self.update_ret.lock().unwrap().clone() @@ -192,7 +193,7 @@ pub fn do_test(data: &[u8], out: Out) { macro_rules! make_node { ($node_id: expr) => { { let logger: Arc = Arc::new(test_logger::TestLogger::new($node_id.to_string(), out.clone())); - let monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone())); + let monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone(), Arc::new(TestPersister{}))); let keys_manager = Arc::new(KeyProvider { node_id: $node_id, rand_bytes_id: atomic::AtomicU8::new(0) }); let mut config = UserConfig::default(); @@ -207,7 +208,7 @@ pub fn do_test(data: &[u8], out: Out) { macro_rules! reload_node { ($ser: expr, $node_id: expr, $old_monitors: expr) => { { let logger: Arc = Arc::new(test_logger::TestLogger::new($node_id.to_string(), out.clone())); - let chain_monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone())); + let chain_monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone(), Arc::new(TestPersister{}))); let keys_manager = Arc::new(KeyProvider { node_id: $node_id, rand_bytes_id: atomic::AtomicU8::new(0) }); let mut config = UserConfig::default(); diff --git a/fuzz/src/chanmon_deser.rs b/fuzz/src/chanmon_deser.rs index 5a76340ff30..fd326cc2ec8 100644 --- a/fuzz/src/chanmon_deser.rs +++ b/fuzz/src/chanmon_deser.rs @@ -26,7 +26,7 @@ impl Writer for VecWriter { pub fn do_test(data: &[u8], _out: Out) { if let Ok((latest_block_hash, monitor)) = <(BlockHash, channelmonitor::ChannelMonitor)>::read(&mut Cursor::new(data)) { let mut w = VecWriter(Vec::new()); - monitor.write_for_disk(&mut w).unwrap(); + monitor.serialize_for_disk(&mut w).unwrap(); let deserialized_copy = <(BlockHash, channelmonitor::ChannelMonitor)>::read(&mut Cursor::new(&w.0)).unwrap(); assert!(latest_block_hash == deserialized_copy.0); assert!(monitor == deserialized_copy.1); diff --git a/fuzz/src/full_stack.rs b/fuzz/src/full_stack.rs index 1ed17b9ea3f..3aeb3d233b0 100644 --- a/fuzz/src/full_stack.rs +++ b/fuzz/src/full_stack.rs @@ -40,6 +40,7 @@ use lightning::util::logger::Logger; use lightning::util::config::UserConfig; use utils::test_logger; +use utils::test_persister::TestPersister; use bitcoin::secp256k1::key::{PublicKey,SecretKey}; use bitcoin::secp256k1::Secp256k1; @@ -145,13 +146,13 @@ impl<'a> std::hash::Hash for Peer<'a> { type ChannelMan = ChannelManager< EnforcingChannelKeys, - Arc, Arc, Arc, Arc>>, + Arc, Arc, Arc, Arc, Arc>>, Arc, Arc, Arc, Arc>; type PeerMan<'a> = PeerManager, Arc, Arc, Arc>>, Arc>; struct MoneyLossDetector<'a> { manager: Arc, - monitor: Arc, Arc, Arc, Arc>>, + monitor: Arc, Arc, Arc, Arc, Arc>>, handler: PeerMan<'a>, peers: &'a RefCell<[bool; 256]>, @@ -165,7 +166,7 @@ struct MoneyLossDetector<'a> { impl<'a> MoneyLossDetector<'a> { pub fn new(peers: &'a RefCell<[bool; 256]>, manager: Arc, - monitor: Arc, Arc, Arc, Arc>>, + monitor: Arc, Arc, Arc, Arc, Arc>>, handler: PeerMan<'a>) -> Self { MoneyLossDetector { manager, @@ -333,7 +334,7 @@ pub fn do_test(data: &[u8], logger: &Arc) { }; let broadcast = Arc::new(TestBroadcaster{}); - let monitor = Arc::new(chainmonitor::ChainMonitor::new(None, broadcast.clone(), Arc::clone(&logger), fee_est.clone())); + let monitor = Arc::new(chainmonitor::ChainMonitor::new(None, broadcast.clone(), Arc::clone(&logger), fee_est.clone(), Arc::new(TestPersister{}))); let keys_manager = Arc::new(KeyProvider { node_secret: our_network_key.clone(), counter: AtomicU64::new(0) }); let mut config = UserConfig::default(); diff --git a/fuzz/src/utils/mod.rs b/fuzz/src/utils/mod.rs index bb5b00a5b54..937eee6b6b2 100644 --- a/fuzz/src/utils/mod.rs +++ b/fuzz/src/utils/mod.rs @@ -8,3 +8,4 @@ // licenses. pub mod test_logger; +pub mod test_persister; diff --git a/fuzz/src/utils/test_persister.rs b/fuzz/src/utils/test_persister.rs new file mode 100644 index 00000000000..0bd60911efe --- /dev/null +++ b/fuzz/src/utils/test_persister.rs @@ -0,0 +1,14 @@ +use lightning::chain::channelmonitor; +use lightning::chain::transaction::OutPoint; +use lightning::util::enforcing_trait_impls::EnforcingChannelKeys; + +pub struct TestPersister {} +impl channelmonitor::Persist for TestPersister { + fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { + Ok(()) + } + + fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { + Ok(()) + } +} diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index e84ee76229f..36384380fb1 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -36,7 +36,8 @@ //! type Logger = dyn lightning::util::logger::Logger; //! type ChainAccess = dyn lightning::chain::Access; //! type ChainFilter = dyn lightning::chain::Filter; -//! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor, Arc, Arc, Arc>; +//! type DataPersister = dyn lightning::chain::channelmonitor::Persist; +//! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor, Arc, Arc, Arc, Arc>; //! type ChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager; //! type PeerManager = lightning::ln::peer_handler::SimpleArcPeerManager; //! diff --git a/lightning-persister/Cargo.toml b/lightning-persister/Cargo.toml new file mode 100644 index 00000000000..44bc952a7d9 --- /dev/null +++ b/lightning-persister/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "lightning-persister" +version = "0.0.1" +authors = ["Valentine Wallace", "Matt Corallo"] +license = "Apache-2.0" +description = """ +Utilities to manage channel data persistence and retrieval. +""" + +[dependencies] +bitcoin = "0.24" +lightning = { version = "0.0.11", path = "../lightning" } +libc = "0.2" + +[dev-dependencies.bitcoin] +version = "0.24" +features = ["bitcoinconsensus"] + +[dev-dependencies] +lightning = { version = "0.0.11", path = "../lightning", features = ["_test_utils"] } diff --git a/lightning-persister/src/lib.rs b/lightning-persister/src/lib.rs new file mode 100644 index 00000000000..48d4d0aa810 --- /dev/null +++ b/lightning-persister/src/lib.rs @@ -0,0 +1,413 @@ +extern crate lightning; +extern crate bitcoin; +extern crate libc; + +use bitcoin::hashes::hex::ToHex; +use lightning::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateErr}; +use lightning::chain::channelmonitor; +use lightning::chain::keysinterface::ChannelKeys; +use lightning::chain::transaction::OutPoint; +use lightning::util::ser::{Writeable, Readable}; +use std::fs; +use std::io::Error; +use std::path::{Path, PathBuf}; + +#[cfg(test)] +use { + bitcoin::{BlockHash, Txid}, + bitcoin::hashes::hex::FromHex, + std::collections::HashMap, + std::io::Cursor +}; + +#[cfg(not(target_os = "windows"))] +use std::os::unix::io::AsRawFd; + +/// FilesystemPersister persists channel data on disk, where each channel's +/// data is stored in a file named after its funding outpoint. +/// +/// Warning: this module does the best it can with calls to persist data, but it +/// can only guarantee that the data is passed to the drive. It is up to the +/// drive manufacturers to do the actual persistence properly, which they often +/// don't (especially on consumer-grade hardware). Therefore, it is up to the +/// user to validate their entire storage stack, to ensure the writes are +/// persistent. +/// Corollary: especially when dealing with larger amounts of money, it is best +/// practice to have multiple channel data backups and not rely only on one +/// FilesystemPersister. +pub struct FilesystemPersister { + path_to_channel_data: String, +} + +trait DiskWriteable { + fn write(&self, writer: &mut fs::File) -> Result<(), Error>; +} + +impl DiskWriteable for ChannelMonitor { + fn write(&self, writer: &mut fs::File) -> Result<(), Error> { + self.serialize_for_disk(writer) + } +} + +impl FilesystemPersister { + /// Initialize a new FilesystemPersister and set the path to the individual channels' + /// files. + pub fn new(path_to_channel_data: String) -> Self { + return Self { + path_to_channel_data, + } + } + + fn get_full_filepath(&self, funding_txo: OutPoint) -> String { + let mut path = PathBuf::from(&self.path_to_channel_data); + path.push(format!("{}_{}", funding_txo.txid.to_hex(), funding_txo.index)); + path.to_str().unwrap().to_string() + } + + // Utility to write a file to disk. + fn write_channel_data(&self, funding_txo: OutPoint, monitor: &dyn DiskWriteable) -> std::io::Result<()> { + fs::create_dir_all(&self.path_to_channel_data)?; + // Do a crazy dance with lots of fsync()s to be overly cautious here... + // We never want to end up in a state where we've lost the old data, or end up using the + // old data on power loss after we've returned. + // The way to atomically write a file on Unix platforms is: + // open(tmpname), write(tmpfile), fsync(tmpfile), close(tmpfile), rename(), fsync(dir) + let filename = self.get_full_filepath(funding_txo); + let tmp_filename = format!("{}.tmp", filename.clone()); + + { + // Note that going by rust-lang/rust@d602a6b, on MacOS it is only safe to use + // rust stdlib 1.36 or higher. + let mut f = fs::File::create(&tmp_filename)?; + monitor.write(&mut f)?; + f.sync_all()?; + } + fs::rename(&tmp_filename, &filename)?; + // Fsync the parent directory on Unix. + #[cfg(not(target_os = "windows"))] + { + let path = Path::new(&filename).parent().unwrap(); + let dir_file = fs::OpenOptions::new().read(true).open(path)?; + unsafe { libc::fsync(dir_file.as_raw_fd()); } + } + Ok(()) + } + + #[cfg(test)] + fn load_channel_data(&self) -> + Result>, ChannelMonitorUpdateErr> { + if let Err(_) = fs::create_dir_all(&self.path_to_channel_data) { + return Err(ChannelMonitorUpdateErr::PermanentFailure); + } + let mut res = HashMap::new(); + for file_option in fs::read_dir(&self.path_to_channel_data).unwrap() { + let file = file_option.unwrap(); + let owned_file_name = file.file_name(); + let filename = owned_file_name.to_str(); + if !filename.is_some() || !filename.unwrap().is_ascii() || filename.unwrap().len() < 65 { + return Err(ChannelMonitorUpdateErr::PermanentFailure); + } + + let txid = Txid::from_hex(filename.unwrap().split_at(64).0); + if txid.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); } + + let index = filename.unwrap().split_at(65).1.split('.').next().unwrap().parse(); + if index.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); } + + let contents = fs::read(&file.path()); + if contents.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); } + + if let Ok((_, loaded_monitor)) = + <(BlockHash, ChannelMonitor)>::read(&mut Cursor::new(&contents.unwrap())) { + res.insert(OutPoint { txid: txid.unwrap(), index: index.unwrap() }, loaded_monitor); + } else { + return Err(ChannelMonitorUpdateErr::PermanentFailure); + } + } + Ok(res) + } +} + +impl channelmonitor::Persist for FilesystemPersister { + fn persist_new_channel(&self, funding_txo: OutPoint, monitor: &ChannelMonitor) -> Result<(), ChannelMonitorUpdateErr> { + self.write_channel_data(funding_txo, monitor) + .map_err(|_| ChannelMonitorUpdateErr::PermanentFailure) + } + + fn update_persisted_channel(&self, funding_txo: OutPoint, _update: &ChannelMonitorUpdate, monitor: &ChannelMonitor) -> Result<(), ChannelMonitorUpdateErr> { + self.write_channel_data(funding_txo, monitor) + .map_err(|_| ChannelMonitorUpdateErr::PermanentFailure) + } +} + +#[cfg(test)] +impl Drop for FilesystemPersister { + fn drop(&mut self) { + // We test for invalid directory names, so it's OK if directory removal + // fails. + match fs::remove_dir_all(&self.path_to_channel_data) { + Err(e) => println!("Failed to remove test persister directory: {}", e), + _ => {} + } + } +} + +#[cfg(test)] +mod tests { + extern crate lightning; + extern crate bitcoin; + use crate::FilesystemPersister; + use bitcoin::blockdata::block::{Block, BlockHeader}; + use bitcoin::hashes::hex::FromHex; + use bitcoin::Txid; + use DiskWriteable; + use Error; + use lightning::chain::channelmonitor::{Persist, ChannelMonitorUpdateErr}; + use lightning::chain::transaction::OutPoint; + use lightning::{check_closed_broadcast, check_added_monitors}; + use lightning::ln::features::InitFeatures; + use lightning::ln::functional_test_utils::*; + use lightning::ln::msgs::ErrorAction; + use lightning::util::enforcing_trait_impls::EnforcingChannelKeys; + use lightning::util::events::{MessageSendEventsProvider, MessageSendEvent}; + use lightning::util::ser::Writer; + use lightning::util::test_utils; + use std::fs; + use std::io; + #[cfg(target_os = "windows")] + use { + lightning::get_event_msg, + lightning::ln::msgs::ChannelMessageHandler, + }; + + struct TestWriteable{} + impl DiskWriteable for TestWriteable { + fn write(&self, writer: &mut fs::File) -> Result<(), Error> { + writer.write_all(&[42; 1]) + } + } + + // Integration-test the FilesystemPersister. Test relaying a few payments + // and check that the persisted data is updated the appropriate number of + // times. + #[test] + fn test_filesystem_persister() { + // Create the nodes, giving them FilesystemPersisters for data persisters. + let persister_0 = FilesystemPersister::new("test_filesystem_persister_0".to_string()); + let persister_1 = FilesystemPersister::new("test_filesystem_persister_1".to_string()); + let chanmon_cfgs = create_chanmon_cfgs(2); + let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); + let chain_mon_0 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &persister_0); + let chain_mon_1 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[1].chain_source), &chanmon_cfgs[1].tx_broadcaster, &chanmon_cfgs[1].logger, &chanmon_cfgs[1].fee_estimator, &persister_1); + node_cfgs[0].chain_monitor = chain_mon_0; + node_cfgs[1].chain_monitor = chain_mon_1; + let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); + let nodes = create_network(2, &node_cfgs, &node_chanmgrs); + + // Check that the persisted channel data is empty before any channels are + // open. + let mut persisted_chan_data_0 = persister_0.load_channel_data::().unwrap(); + assert_eq!(persisted_chan_data_0.keys().len(), 0); + let mut persisted_chan_data_1 = persister_1.load_channel_data::().unwrap(); + assert_eq!(persisted_chan_data_1.keys().len(), 0); + + // Helper to make sure the channel is on the expected update ID. + macro_rules! check_persisted_data { + ($expected_update_id: expr) => { + persisted_chan_data_0 = persister_0.load_channel_data::().unwrap(); + assert_eq!(persisted_chan_data_0.keys().len(), 1); + for mon in persisted_chan_data_0.values() { + assert_eq!(mon.get_latest_update_id(), $expected_update_id); + } + persisted_chan_data_1 = persister_1.load_channel_data::().unwrap(); + assert_eq!(persisted_chan_data_1.keys().len(), 1); + for mon in persisted_chan_data_1.values() { + assert_eq!(mon.get_latest_update_id(), $expected_update_id); + } + } + } + + // Create some initial channel and check that a channel was persisted. + let _ = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); + check_persisted_data!(0); + + // Send a few payments and make sure the monitors are updated to the latest. + send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000); + check_persisted_data!(5); + send_payment(&nodes[1], &vec!(&nodes[0])[..], 4000000, 4_000_000); + check_persisted_data!(10); + + // Force close because cooperative close doesn't result in any persisted + // updates. + nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id); + check_closed_broadcast!(nodes[0], false); + check_added_monitors!(nodes[0], 1); + + let node_txn = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap(); + assert_eq!(node_txn.len(), 1); + + let header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 }; + connect_block(&nodes[1], &Block { header, txdata: vec![node_txn[0].clone(), node_txn[0].clone()]}, 1); + check_closed_broadcast!(nodes[1], false); + check_added_monitors!(nodes[1], 1); + + // Make sure everything is persisted as expected after close. + check_persisted_data!(11); + } + + // Test that if the persister's path to channel data is read-only, writing + // data to it fails. Windows ignores the read-only flag for folders, so this + // test is Unix-only. + #[cfg(not(target_os = "windows"))] + #[test] + fn test_readonly_dir() { + let persister = FilesystemPersister::new("test_readonly_dir_persister".to_string()); + let test_writeable = TestWriteable{}; + let test_txo = OutPoint { + txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(), + index: 0 + }; + // Create the persister's directory and set it to read-only. + let path = &persister.path_to_channel_data; + fs::create_dir_all(path).unwrap(); + let mut perms = fs::metadata(path).unwrap().permissions(); + perms.set_readonly(true); + fs::set_permissions(path, perms).unwrap(); + match persister.write_channel_data(test_txo, &test_writeable) { + Err(e) => assert_eq!(e.kind(), io::ErrorKind::PermissionDenied), + _ => panic!("Unexpected error message") + } + } + + // Test failure to rename in the process of atomically creating a channel + // monitor's file. We induce this failure by making the `tmp` file a + // directory. + // Explanation: given "from" = the file being renamed, "to" = the + // renamee that already exists: Windows should fail because it'll fail + // whenever "to" is a directory, and Unix should fail because if "from" is a + // file, then "to" is also required to be a file. + #[test] + fn test_rename_failure() { + let persister = FilesystemPersister::new("test_rename_failure".to_string()); + let test_writeable = TestWriteable{}; + let txid_hex = "8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be"; + let outp_idx = 0; + let test_txo = OutPoint { + txid: Txid::from_hex(txid_hex).unwrap(), + index: outp_idx, + }; + // Create the channel data file and make it a directory. + let path = &persister.path_to_channel_data; + fs::create_dir_all(format!("{}/{}_{}", path, txid_hex, outp_idx)).unwrap(); + match persister.write_channel_data(test_txo, &test_writeable) { + Err(e) => { + #[cfg(not(target_os = "windows"))] + assert_eq!(e.kind(), io::ErrorKind::Other); + #[cfg(target_os = "windows")] + assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + } + _ => panic!("Unexpected error message") + } + } + + // Test failure to create the temporary file in the persistence process. + // We induce this failure by having the temp file already exist and be a + // directory. + #[test] + fn test_tmp_file_creation_failure() { + let persister = FilesystemPersister::new("test_tmp_file_creation_failure".to_string()); + let test_writeable = TestWriteable{}; + let txid_hex = "8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be"; + let outp_idx = 0; + let test_txo = OutPoint { + txid: Txid::from_hex(txid_hex).unwrap(), + index: outp_idx, + }; + // Create the tmp file and make it a directory. + let path = &persister.path_to_channel_data; + fs::create_dir_all(format!("{}/{}_{}.tmp", path, txid_hex, outp_idx)).unwrap(); + match persister.write_channel_data(test_txo, &test_writeable) { + Err(e) => { + #[cfg(not(target_os = "windows"))] + assert_eq!(e.kind(), io::ErrorKind::Other); + #[cfg(target_os = "windows")] + assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + } + _ => panic!("Unexpected error message") + } + } + + // Test that if the persister's path to channel data is read-only, writing a + // monitor to it results in the persister returning a PermanentFailure. + // Windows ignores the read-only flag for folders, so this test is Unix-only. + #[cfg(not(target_os = "windows"))] + #[test] + fn test_readonly_dir_perm_failure() { + let persister = FilesystemPersister::new("test_readonly_dir_perm_failure".to_string()); + fs::create_dir_all(&persister.path_to_channel_data).unwrap(); + + // Set up a dummy channel and force close. This will produce a monitor + // that we can then use to test persistence. + let chanmon_cfgs = create_chanmon_cfgs(2); + let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); + let nodes = create_network(2, &node_cfgs, &node_chanmgrs); + let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); + nodes[1].node.force_close_channel(&chan.2); + let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap(); + + // Set the persister's directory to read-only, which should result in + // returning a permanent failure when we then attempt to persist a + // channel update. + let path = &persister.path_to_channel_data; + let mut perms = fs::metadata(path).unwrap().permissions(); + perms.set_readonly(true); + fs::set_permissions(path, perms).unwrap(); + + let test_txo = OutPoint { + txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(), + index: 0 + }; + match persister.persist_new_channel(test_txo, &added_monitors[0].1) { + Err(ChannelMonitorUpdateErr::PermanentFailure) => {}, + _ => panic!("unexpected result from persisting new channel") + } + + nodes[1].node.get_and_clear_pending_msg_events(); + added_monitors.clear(); + } + + // Test that if a persister's directory name is invalid, monitor persistence + // will fail. + #[cfg(target_os = "windows")] + #[test] + fn test_fail_on_open() { + // Set up a dummy channel and force close. This will produce a monitor + // that we can then use to test persistence. + let chanmon_cfgs = create_chanmon_cfgs(2); + let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); + let nodes = create_network(2, &node_cfgs, &node_chanmgrs); + let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); + nodes[1].node.force_close_channel(&chan.2); + let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap(); + + // Create the persister with an invalid directory name and test that the + // channel fails to open because the directories fail to be created. There + // don't seem to be invalid filename characters on Unix that Rust doesn't + // handle, hence why the test is Windows-only. + let persister = FilesystemPersister::new(":<>/".to_string()); + + let test_txo = OutPoint { + txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(), + index: 0 + }; + match persister.persist_new_channel(test_txo, &added_monitors[0].1) { + Err(ChannelMonitorUpdateErr::PermanentFailure) => {}, + _ => panic!("unexpected result from persisting new channel") + } + + nodes[1].node.get_and_clear_pending_msg_events(); + added_monitors.clear(); + } +} diff --git a/lightning/Cargo.toml b/lightning/Cargo.toml index 3b12468e9c4..b5ec64d99b0 100644 --- a/lightning/Cargo.toml +++ b/lightning/Cargo.toml @@ -12,6 +12,7 @@ Still missing tons of error-handling. See GitHub issues for suggested projects i [features] fuzztarget = ["bitcoin/fuzztarget"] +_test_utils = ["hex", "regex"] # Unlog messages superior at targeted level. max_level_off = [] max_level_error = [] @@ -25,6 +26,9 @@ unsafe_revoked_tx_signing = [] [dependencies] bitcoin = "0.24" +hex = { version = "0.3", optional = true } +regex = { version = "0.1.80", optional = true } + [dev-dependencies.bitcoin] version = "0.24" features = ["bitcoinconsensus"] diff --git a/lightning/src/chain/chainmonitor.rs b/lightning/src/chain/chainmonitor.rs index 6c67f54f11a..469837f07ed 100644 --- a/lightning/src/chain/chainmonitor.rs +++ b/lightning/src/chain/chainmonitor.rs @@ -34,7 +34,8 @@ use bitcoin::blockdata::block::BlockHeader; use chain; use chain::Filter; use chain::chaininterface::{BroadcasterInterface, FeeEstimator}; -use chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateErr, MonitorEvent, MonitorUpdateError}; +use chain::channelmonitor; +use chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateErr, MonitorEvent, Persist}; use chain::transaction::{OutPoint, TransactionData}; use chain::keysinterface::ChannelKeys; use util::logger::Logger; @@ -55,25 +56,28 @@ use std::ops::Deref; /// [`chain::Watch`]: ../trait.Watch.html /// [`ChannelManager`]: ../../ln/channelmanager/struct.ChannelManager.html /// [module-level documentation]: index.html -pub struct ChainMonitor +pub struct ChainMonitor where C::Target: chain::Filter, T::Target: BroadcasterInterface, F::Target: FeeEstimator, L::Target: Logger, + P::Target: channelmonitor::Persist, { /// The monitors pub monitors: Mutex>>, chain_source: Option, broadcaster: T, logger: L, - fee_estimator: F + fee_estimator: F, + persister: P, } -impl ChainMonitor - where C::Target: chain::Filter, - T::Target: BroadcasterInterface, - F::Target: FeeEstimator, - L::Target: Logger, +impl ChainMonitor +where C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + P::Target: channelmonitor::Persist, { /// Dispatches to per-channel monitors, which are responsible for updating their on-chain view /// of a channel and reacting accordingly based on transactions in the connected block. See @@ -124,27 +128,47 @@ impl ChainMonit /// transactions relevant to the watched channels. /// /// [`chain::Filter`]: ../trait.Filter.html - pub fn new(chain_source: Option, broadcaster: T, logger: L, feeest: F) -> Self { + pub fn new(chain_source: Option, broadcaster: T, logger: L, feeest: F, persister: P) -> Self { Self { monitors: Mutex::new(HashMap::new()), chain_source, broadcaster, logger, fee_estimator: feeest, + persister, } } +} + +impl chain::Watch for ChainMonitor +where C::Target: chain::Filter, + T::Target: BroadcasterInterface, + F::Target: FeeEstimator, + L::Target: Logger, + P::Target: channelmonitor::Persist, +{ + type Keys = ChanSigner; /// Adds the monitor that watches the channel referred to by the given outpoint. /// /// Calls back to [`chain::Filter`] with the funding transaction and outputs to watch. /// + /// Note that we persist the given `ChannelMonitor` while holding the `ChainMonitor` + /// monitors lock. + /// /// [`chain::Filter`]: ../trait.Filter.html - fn add_monitor(&self, outpoint: OutPoint, monitor: ChannelMonitor) -> Result<(), MonitorUpdateError> { + fn watch_channel(&self, funding_outpoint: OutPoint, monitor: ChannelMonitor) -> Result<(), ChannelMonitorUpdateErr> { let mut monitors = self.monitors.lock().unwrap(); - let entry = match monitors.entry(outpoint) { - hash_map::Entry::Occupied(_) => return Err(MonitorUpdateError("Channel monitor for given outpoint is already present")), + let entry = match monitors.entry(funding_outpoint) { + hash_map::Entry::Occupied(_) => { + log_error!(self.logger, "Failed to add new channel data: channel monitor for given outpoint is already present"); + return Err(ChannelMonitorUpdateErr::PermanentFailure)}, hash_map::Entry::Vacant(e) => e, }; + if let Err(e) = self.persister.persist_new_channel(funding_outpoint, &monitor) { + log_error!(self.logger, "Failed to persist new channel data"); + return Err(e); + } { let funding_txo = monitor.get_funding_txo(); log_trace!(self.logger, "Got new Channel Monitor for channel {}", log_bytes!(funding_txo.0.to_channel_id()[..])); @@ -162,38 +186,34 @@ impl ChainMonit Ok(()) } - /// Updates the monitor that watches the channel referred to by the given outpoint. - fn update_monitor(&self, outpoint: OutPoint, update: ChannelMonitorUpdate) -> Result<(), MonitorUpdateError> { + /// Note that we persist the given `ChannelMonitor` update while holding the + /// `ChainMonitor` monitors lock. + fn update_channel(&self, funding_txo: OutPoint, update: ChannelMonitorUpdate) -> Result<(), ChannelMonitorUpdateErr> { + // Update the monitor that watches the channel referred to by the given outpoint. let mut monitors = self.monitors.lock().unwrap(); - match monitors.get_mut(&outpoint) { + match monitors.get_mut(&funding_txo) { + None => { + log_error!(self.logger, "Failed to update channel monitor: no such monitor registered"); + Err(ChannelMonitorUpdateErr::PermanentFailure) + }, Some(orig_monitor) => { log_trace!(self.logger, "Updating Channel Monitor for channel {}", log_funding_info!(orig_monitor)); - orig_monitor.update_monitor(update, &self.broadcaster, &self.logger) - }, - None => Err(MonitorUpdateError("No such monitor registered")) - } - } -} - -impl chain::Watch for ChainMonitor - where C::Target: chain::Filter, - T::Target: BroadcasterInterface, - F::Target: FeeEstimator, - L::Target: Logger, -{ - type Keys = ChanSigner; - - fn watch_channel(&self, funding_txo: OutPoint, monitor: ChannelMonitor) -> Result<(), ChannelMonitorUpdateErr> { - match self.add_monitor(funding_txo, monitor) { - Ok(_) => Ok(()), - Err(_) => Err(ChannelMonitorUpdateErr::PermanentFailure), - } - } - - fn update_channel(&self, funding_txo: OutPoint, update: ChannelMonitorUpdate) -> Result<(), ChannelMonitorUpdateErr> { - match self.update_monitor(funding_txo, update) { - Ok(_) => Ok(()), - Err(_) => Err(ChannelMonitorUpdateErr::PermanentFailure), + let update_res = orig_monitor.update_monitor(&update, &self.broadcaster, &self.logger); + if let Err(e) = &update_res { + log_error!(self.logger, "Failed to update channel monitor: {:?}", e); + } + // Even if updating the monitor returns an error, the monitor's state will + // still be changed. So, persist the updated monitor despite the error. + let persist_res = self.persister.update_persisted_channel(funding_txo, &update, orig_monitor); + if let Err(ref e) = persist_res { + log_error!(self.logger, "Failed to persist channel monitor update: {:?}", e); + } + if update_res.is_err() { + Err(ChannelMonitorUpdateErr::PermanentFailure) + } else { + persist_res + } + } } } @@ -206,11 +226,12 @@ impl events::EventsProvider for ChainMonitor +impl events::EventsProvider for ChainMonitor where C::Target: chain::Filter, T::Target: BroadcasterInterface, F::Target: FeeEstimator, L::Target: Logger, + P::Target: channelmonitor::Persist, { fn get_and_clear_pending_events(&self) -> Vec { let mut pending_events = Vec::new(); diff --git a/lightning/src/chain/channelmonitor.rs b/lightning/src/chain/channelmonitor.rs index d8433ef5ffa..889dfa211ee 100644 --- a/lightning/src/chain/channelmonitor.rs +++ b/lightning/src/chain/channelmonitor.rs @@ -57,7 +57,7 @@ use std::io::Error; /// An update generated by the underlying Channel itself which contains some new information the /// ChannelMonitor should be made aware of. -#[cfg_attr(test, derive(PartialEq))] +#[cfg_attr(any(test, feature = "_test_utils"), derive(PartialEq))] #[derive(Clone)] #[must_use] pub struct ChannelMonitorUpdate { @@ -95,7 +95,7 @@ impl Readable for ChannelMonitorUpdate { } /// An error enum representing a failure to persist a channel monitor update. -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum ChannelMonitorUpdateErr { /// Used to indicate a temporary failure (eg connection to a watchtower or remote backup of /// our state failed, but is expected to succeed at some point in the future). @@ -159,7 +159,7 @@ pub enum ChannelMonitorUpdateErr { /// inconsistent with the ChannelMonitor being called. eg for ChannelMonitor::update_monitor this /// means you tried to update a monitor for a different channel or the ChannelMonitorUpdate was /// corrupted. -/// Contains a human-readable error message. +/// Contains a developer-readable error message. #[derive(Debug)] pub struct MonitorUpdateError(pub &'static str); @@ -470,7 +470,7 @@ enum OnchainEvent { const SERIALIZATION_VERSION: u8 = 1; const MIN_SERIALIZATION_VERSION: u8 = 1; -#[cfg_attr(test, derive(PartialEq))] +#[cfg_attr(any(test, feature = "_test_utils"), derive(PartialEq))] #[derive(Clone)] pub(crate) enum ChannelMonitorUpdateStep { LatestHolderCommitmentTXInfo { @@ -696,7 +696,7 @@ pub struct ChannelMonitor { secp_ctx: Secp256k1, //TODO: dedup this a bit... } -#[cfg(any(test, feature = "fuzztarget"))] +#[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))] /// Used only in testing and fuzztarget to check serialization roundtrips don't change the /// underlying object impl PartialEq for ChannelMonitor { @@ -746,7 +746,7 @@ impl ChannelMonitor { /// the "reorg path" (ie disconnecting blocks until you find a common ancestor from both the /// returned block hash and the the current chain and then reconnecting blocks to get to the /// best chain) upon deserializing the object! - pub fn write_for_disk(&self, writer: &mut W) -> Result<(), Error> { + pub fn serialize_for_disk(&self, writer: &mut W) -> Result<(), Error> { //TODO: We still write out all the serialization here manually instead of using the fancy //serialization framework we have, we should migrate things over to it. writer.write_all(&[SERIALIZATION_VERSION; 1])?; @@ -1162,28 +1162,28 @@ impl ChannelMonitor { /// itself. /// /// panics if the given update is not the next update by update_id. - pub fn update_monitor(&mut self, mut updates: ChannelMonitorUpdate, broadcaster: &B, logger: &L) -> Result<(), MonitorUpdateError> + pub fn update_monitor(&mut self, updates: &ChannelMonitorUpdate, broadcaster: &B, logger: &L) -> Result<(), MonitorUpdateError> where B::Target: BroadcasterInterface, L::Target: Logger, { if self.latest_update_id + 1 != updates.update_id { panic!("Attempted to apply ChannelMonitorUpdates out of order, check the update_id before passing an update to update_monitor!"); } - for update in updates.updates.drain(..) { + for update in updates.updates.iter() { match update { ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs } => { if self.lockdown_from_offchain { panic!(); } - self.provide_latest_holder_commitment_tx_info(commitment_tx, htlc_outputs)? + self.provide_latest_holder_commitment_tx_info(commitment_tx.clone(), htlc_outputs.clone())? }, ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { unsigned_commitment_tx, htlc_outputs, commitment_number, their_revocation_point } => - self.provide_latest_counterparty_commitment_tx_info(&unsigned_commitment_tx, htlc_outputs, commitment_number, their_revocation_point, logger), + self.provide_latest_counterparty_commitment_tx_info(&unsigned_commitment_tx, htlc_outputs.clone(), *commitment_number, *their_revocation_point, logger), ChannelMonitorUpdateStep::PaymentPreimage { payment_preimage } => self.provide_payment_preimage(&PaymentHash(Sha256::hash(&payment_preimage.0[..]).into_inner()), &payment_preimage), ChannelMonitorUpdateStep::CommitmentSecret { idx, secret } => - self.provide_secret(idx, secret)?, + self.provide_secret(*idx, *secret)?, ChannelMonitorUpdateStep::ChannelForceClosed { should_broadcast } => { self.lockdown_from_offchain = true; - if should_broadcast { + if *should_broadcast { self.broadcast_latest_holder_commitment_txn(broadcaster, logger); } else { log_error!(logger, "You have a toxic holder commitment transaction avaible in channel monitor, read comment in ChannelMonitor::get_latest_holder_commitment_txn to be informed of manual action to take"); @@ -2123,6 +2123,61 @@ impl ChannelMonitor { } } +/// `Persist` defines behavior for persisting channel monitors: this could mean +/// writing once to disk, and/or uploading to one or more backup services. +/// +/// Note that for every new monitor, you **must** persist the new `ChannelMonitor` +/// to disk/backups. And, on every update, you **must** persist either the +/// `ChannelMonitorUpdate` or the updated monitor itself. Otherwise, there is risk +/// of situations such as revoking a transaction, then crashing before this +/// revocation can be persisted, then unintentionally broadcasting a revoked +/// transaction and losing money. This is a risk because previous channel states +/// are toxic, so it's important that whatever channel state is persisted is +/// kept up-to-date. +pub trait Persist: Send + Sync { + /// Persist a new channel's data. The data can be stored any way you want, but + /// the identifier provided by Rust-Lightning is the channel's outpoint (and + /// it is up to you to maintain a correct mapping between the outpoint and the + /// stored channel data). Note that you **must** persist every new monitor to + /// disk. See the `Persist` trait documentation for more details. + /// + /// See [`ChannelMonitor::serialize_for_disk`] for writing out a `ChannelMonitor`, + /// and [`ChannelMonitorUpdateErr`] for requirements when returning errors. + /// + /// [`ChannelMonitor::serialize_for_disk`]: struct.ChannelMonitor.html#method.serialize_for_disk + /// [`ChannelMonitorUpdateErr`]: enum.ChannelMonitorUpdateErr.html + fn persist_new_channel(&self, id: OutPoint, data: &ChannelMonitor) -> Result<(), ChannelMonitorUpdateErr>; + + /// Update one channel's data. The provided `ChannelMonitor` has already + /// applied the given update. + /// + /// Note that on every update, you **must** persist either the + /// `ChannelMonitorUpdate` or the updated monitor itself to disk/backups. See + /// the `Persist` trait documentation for more details. + /// + /// If an implementer chooses to persist the updates only, they need to make + /// sure that all the updates are applied to the `ChannelMonitors` *before* + /// the set of channel monitors is given to the `ChannelManager` + /// deserialization routine. See [`ChannelMonitor::update_monitor`] for + /// applying a monitor update to a monitor. If full `ChannelMonitors` are + /// persisted, then there is no need to persist individual updates. + /// + /// Note that there could be a performance tradeoff between persisting complete + /// channel monitors on every update vs. persisting only updates and applying + /// them in batches. The size of each monitor grows `O(number of state updates)` + /// whereas updates are small and `O(1)`. + /// + /// See [`ChannelMonitor::serialize_for_disk`] for writing out a `ChannelMonitor`, + /// [`ChannelMonitorUpdate::write`] for writing out an update, and + /// [`ChannelMonitorUpdateErr`] for requirements when returning errors. + /// + /// [`ChannelMonitor::update_monitor`]: struct.ChannelMonitor.html#impl-1 + /// [`ChannelMonitor::serialize_for_disk`]: struct.ChannelMonitor.html#method.serialize_for_disk + /// [`ChannelMonitorUpdate::write`]: struct.ChannelMonitorUpdate.html#method.write + /// [`ChannelMonitorUpdateErr`]: enum.ChannelMonitorUpdateErr.html + fn update_persisted_channel(&self, id: OutPoint, update: &ChannelMonitorUpdate, data: &ChannelMonitor) -> Result<(), ChannelMonitorUpdateErr>; +} + const MAX_ALLOC_SIZE: usize = 64*1024; impl Readable for (BlockHash, ChannelMonitor) { diff --git a/lightning/src/lib.rs b/lightning/src/lib.rs index dbd89a2ccdf..c5d369268cc 100644 --- a/lightning/src/lib.rs +++ b/lightning/src/lib.rs @@ -18,7 +18,7 @@ //! generated/etc. This makes it a good candidate for tight integration into an existing wallet //! instead of having a rather-separate lightning appendage to a wallet. -#![cfg_attr(not(feature = "fuzztarget"), deny(missing_docs))] +#![cfg_attr(not(any(feature = "fuzztarget", feature = "_test_utils")), deny(missing_docs))] #![forbid(unsafe_code)] // In general, rust is absolutely horrid at supporting users doing things like, @@ -28,8 +28,8 @@ #![allow(ellipsis_inclusive_range_patterns)] extern crate bitcoin; -#[cfg(test)] extern crate hex; -#[cfg(test)] extern crate regex; +#[cfg(any(test, feature = "_test_utils"))] extern crate hex; +#[cfg(any(test, feature = "_test_utils"))] extern crate regex; #[macro_use] pub mod util; diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs index e6eb9e6a24a..689c3496de1 100644 --- a/lightning/src/ln/chanmon_update_fail_tests.rs +++ b/lightning/src/ln/chanmon_update_fail_tests.rs @@ -12,15 +12,21 @@ //! There are a bunch of these as their handling is relatively error-prone so they are split out //! here. See also the chanmon_fail_consistency fuzz test. -use chain::channelmonitor::ChannelMonitorUpdateErr; +use bitcoin::blockdata::block::BlockHeader; +use bitcoin::hash_types::BlockHash; +use bitcoin::network::constants::Network; +use chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdateErr}; use chain::transaction::OutPoint; +use chain::Watch; use ln::channelmanager::{RAACommitmentOrder, PaymentPreimage, PaymentHash, PaymentSecret, PaymentSendFailure}; use ln::features::InitFeatures; use ln::msgs; use ln::msgs::{ChannelMessageHandler, ErrorAction, RoutingMessageHandler}; use routing::router::get_route; +use util::enforcing_trait_impls::EnforcingChannelKeys; use util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider}; use util::errors::APIError; +use util::ser::Readable; use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::hashes::Hash; @@ -29,10 +35,11 @@ use ln::functional_test_utils::*; use util::test_utils; -#[test] -fn test_simple_monitor_permanent_update_fail() { +// If persister_fail is true, we have the persister return a PermanentFailure +// instead of the higher-level ChainMonitor. +fn do_test_simple_monitor_permanent_update_fail(persister_fail: bool) { // Test that we handle a simple permanent monitor update failure - let chanmon_cfgs = create_chanmon_cfgs(2); + let mut chanmon_cfgs = create_chanmon_cfgs(2); let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs); @@ -41,7 +48,10 @@ fn test_simple_monitor_permanent_update_fail() { let (_, payment_hash_1) = get_payment_preimage_hash!(&nodes[0]); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::PermanentFailure); + match persister_fail { + true => chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::PermanentFailure)), + false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::PermanentFailure)) + } let net_graph_msg_handler = &nodes[0].net_graph_msg_handler; let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[1].node.get_our_node_id(), None, &Vec::new(), 1000000, TEST_FINAL_CLTV, &logger).unwrap(); unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash_1, &None), true, APIError::ChannelUnavailable {..}, {}); @@ -64,10 +74,87 @@ fn test_simple_monitor_permanent_update_fail() { assert_eq!(nodes[0].node.list_channels().len(), 0); } -fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) { +#[test] +fn test_monitor_and_persister_update_fail() { + // Test that if both updating the `ChannelMonitor` and persisting the updated + // `ChannelMonitor` fail, then the failure from updating the `ChannelMonitor` + // one that gets returned. + let chanmon_cfgs = create_chanmon_cfgs(2); + let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); + let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs); + + // Create some initial channel + let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); + let outpoint = OutPoint { txid: chan.3.txid(), index: 0 }; + + // Rebalance the network to generate htlc in the two directions + send_payment(&nodes[0], &vec!(&nodes[1])[..], 10_000_000, 10_000_000); + + // Route an HTLC from node 0 to node 1 (but don't settle) + let preimage = route_payment(&nodes[0], &vec!(&nodes[1])[..], 9_000_000).0; + + // Make a copy of the ChainMonitor so we can capture the error it returns on a + // bogus update. Note that if instead we updated the nodes[0]'s ChainMonitor + // directly, the node would fail to be `Drop`'d at the end because its + // ChannelManager and ChainMonitor would be out of sync. + let chain_source = test_utils::TestChainSource::new(Network::Testnet); + let logger = test_utils::TestLogger::with_id(format!("node {}", 0)); + let persister = test_utils::TestPersister::new(); + let chain_mon = { + let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap(); + let monitor = monitors.get(&outpoint).unwrap(); + let mut w = test_utils::TestVecWriter(Vec::new()); + monitor.serialize_for_disk(&mut w).unwrap(); + let new_monitor = <(BlockHash, ChannelMonitor)>::read( + &mut ::std::io::Cursor::new(&w.0)).unwrap().1; + assert!(new_monitor == *monitor); + let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); + assert!(chain_mon.watch_channel(outpoint, new_monitor).is_ok()); + chain_mon + }; + let header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 }; + chain_mon.chain_monitor.block_connected(&header, &[], 200); + + // Set the persister's return value to be a TemporaryFailure. + persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure)); + + // Try to update ChannelMonitor + assert!(nodes[1].node.claim_funds(preimage, &None, 9_000_000)); + check_added_monitors!(nodes[1], 1); + let updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id()); + assert_eq!(updates.update_fulfill_htlcs.len(), 1); + nodes[0].node.handle_update_fulfill_htlc(&nodes[1].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]); + if let Some(ref mut channel) = nodes[0].node.channel_state.lock().unwrap().by_id.get_mut(&chan.2) { + if let Ok((_, _, _, update)) = channel.commitment_signed(&updates.commitment_signed, &node_cfgs[0].fee_estimator, &node_cfgs[0].logger) { + // Check that even though the persister is returning a TemporaryFailure, + // because the update is bogus, ultimately the error that's returned + // should be a PermanentFailure. + if let Err(ChannelMonitorUpdateErr::PermanentFailure) = chain_mon.chain_monitor.update_channel(outpoint, update.clone()) {} else { panic!("Expected monitor error to be permanent"); } + logger.assert_log_contains("lightning::chain::chainmonitor".to_string(), "Failed to persist channel monitor update: TemporaryFailure".to_string(), 1); + if let Ok(_) = nodes[0].chain_monitor.update_channel(outpoint, update) {} else { assert!(false); } + } else { assert!(false); } + } else { assert!(false); }; + + check_added_monitors!(nodes[0], 1); + let events = nodes[0].node.get_and_clear_pending_events(); + assert_eq!(events.len(), 1); +} + +#[test] +fn test_simple_monitor_permanent_update_fail() { + do_test_simple_monitor_permanent_update_fail(false); + + // Test behavior when the persister returns a PermanentFailure. + do_test_simple_monitor_permanent_update_fail(true); +} + +// If persister_fail is true, we have the persister return a TemporaryFailure instead of the +// higher-level ChainMonitor. +fn do_test_simple_monitor_temporary_update_fail(disconnect: bool, persister_fail: bool) { // Test that we can recover from a simple temporary monitor update failure optionally with // a disconnect in between - let chanmon_cfgs = create_chanmon_cfgs(2); + let mut chanmon_cfgs = create_chanmon_cfgs(2); let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs); @@ -76,7 +163,10 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) { let (payment_preimage_1, payment_hash_1) = get_payment_preimage_hash!(&nodes[0]); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + match persister_fail { + true => chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure)), + false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)) + } { let net_graph_msg_handler = &nodes[0].net_graph_msg_handler; @@ -95,7 +185,10 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) { reconnect_nodes(&nodes[0], &nodes[1], (true, true), (0, 0), (0, 0), (0, 0), (0, 0), (false, false)); } - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + match persister_fail { + true => chanmon_cfgs[0].persister.set_update_ret(Ok(())), + false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())) + } let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[0].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[0], 0); @@ -125,7 +218,10 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) { // Now set it to failed again... let (_, payment_hash_2) = get_payment_preimage_hash!(&nodes[0]); { - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + match persister_fail { + true => chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure)), + false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)) + } let net_graph_msg_handler = &nodes[0].net_graph_msg_handler; let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[1].node.get_our_node_id(), None, &Vec::new(), 1000000, TEST_FINAL_CLTV, &logger).unwrap(); unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash_2, &None), false, APIError::MonitorUpdateFailed, {}); @@ -155,8 +251,12 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) { #[test] fn test_simple_monitor_temporary_update_fail() { - do_test_simple_monitor_temporary_update_fail(false); - do_test_simple_monitor_temporary_update_fail(true); + do_test_simple_monitor_temporary_update_fail(false, false); + do_test_simple_monitor_temporary_update_fail(true, false); + + // Test behavior when the persister returns a TemporaryFailure. + do_test_simple_monitor_temporary_update_fail(false, true); + do_test_simple_monitor_temporary_update_fail(true, true); } fn do_test_monitor_temporary_update_fail(disconnect_count: usize) { @@ -191,7 +291,7 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) { // Now try to send a second payment which will fail to send let (payment_preimage_2, payment_hash_2) = get_payment_preimage_hash!(nodes[0]); { - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); let net_graph_msg_handler = &nodes[0].net_graph_msg_handler; let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[1].node.get_our_node_id(), None, &Vec::new(), 1000000, TEST_FINAL_CLTV, &logger).unwrap(); unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash_2, &None), false, APIError::MonitorUpdateFailed, {}); @@ -245,7 +345,7 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) { } // Now fix monitor updating... - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[0].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[0], 0); @@ -532,14 +632,14 @@ fn test_monitor_update_fail_cs() { let send_event = SendEvent::from_event(nodes[0].node.get_and_clear_pending_msg_events().remove(0)); nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[1].node.handle_commitment_signed(&nodes[0].node.get_our_node_id(), &send_event.commitment_msg); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1); check_added_monitors!(nodes[1], 1); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -563,7 +663,7 @@ fn test_monitor_update_fail_cs() { assert!(updates.update_fee.is_none()); assert_eq!(*node_id, nodes[0].node.get_our_node_id()); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[0].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &updates.commitment_signed); assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty()); nodes[0].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1); @@ -573,7 +673,7 @@ fn test_monitor_update_fail_cs() { _ => panic!("Unexpected event"), } - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[0].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[0], 0); @@ -622,7 +722,7 @@ fn test_monitor_update_fail_no_rebroadcast() { nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]); let bs_raa = commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true, false, true); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[1].node.handle_revoke_and_ack(&nodes[0].node.get_our_node_id(), &bs_raa); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1); @@ -630,7 +730,7 @@ fn test_monitor_update_fail_no_rebroadcast() { assert!(nodes[1].node.get_and_clear_pending_events().is_empty()); check_added_monitors!(nodes[1], 1); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); @@ -684,7 +784,7 @@ fn test_monitor_update_raa_while_paused() { check_added_monitors!(nodes[1], 1); let bs_raa = get_event_msg!(nodes[1], MessageSendEvent::SendRevokeAndACK, nodes[0].node.get_our_node_id()); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[0].node.handle_update_add_htlc(&nodes[1].node.get_our_node_id(), &send_event_2.msgs[0]); nodes[0].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &send_event_2.commitment_msg); assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty()); @@ -696,7 +796,7 @@ fn test_monitor_update_raa_while_paused() { nodes[0].logger.assert_log("lightning::ln::channelmanager".to_string(), "Previous monitor update failure prevented responses to RAA".to_string(), 1); check_added_monitors!(nodes[0], 1); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[0].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[0], 0); @@ -779,7 +879,7 @@ fn do_test_monitor_update_fail_raa(test_ignore_second_cs: bool) { assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); // Now fail monitor updating. - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[1].node.handle_revoke_and_ack(&nodes[2].node.get_our_node_id(), &bs_revoke_and_ack); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1); @@ -797,7 +897,7 @@ fn do_test_monitor_update_fail_raa(test_ignore_second_cs: bool) { check_added_monitors!(nodes[0], 1); } - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); // We succeed in updating the monitor for the first channel + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); // We succeed in updating the monitor for the first channel send_event = SendEvent::from_event(nodes[0].node.get_and_clear_pending_msg_events().remove(0)); nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]); commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true); @@ -858,7 +958,7 @@ fn do_test_monitor_update_fail_raa(test_ignore_second_cs: bool) { // Restore monitor updating, ensuring we immediately get a fail-back update and a // update_add update. - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&chan_2.2).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1020,7 +1120,7 @@ fn test_monitor_update_fail_reestablish() { assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); commitment_signed_dance!(nodes[1], nodes[2], updates.commitment_signed, false); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: InitFeatures::empty() }); nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: InitFeatures::empty() }); @@ -1049,7 +1149,7 @@ fn test_monitor_update_fail_reestablish() { check_added_monitors!(nodes[1], 0); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&chan_1.2).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1123,7 +1223,7 @@ fn raa_no_response_awaiting_raa_state() { // Now we have a CS queued up which adds a new HTLC (which will need a RAA/CS response from // nodes[1]) followed by an RAA. Fail the monitor updating prior to the CS, deliver the RAA, // then restore channel monitor updates. - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]); nodes[1].node.handle_commitment_signed(&nodes[0].node.get_our_node_id(), &payment_event.commitment_msg); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); @@ -1135,7 +1235,7 @@ fn raa_no_response_awaiting_raa_state() { nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Previous monitor update failure prevented responses to RAA".to_string(), 1); check_added_monitors!(nodes[1], 1); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); // nodes[1] should be AwaitingRAA here! @@ -1228,7 +1328,7 @@ fn claim_while_disconnected_monitor_update_fail() { // Now deliver a's reestablish, freeing the claim from the holding cell, but fail the monitor // update. - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_reconnect); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); @@ -1257,7 +1357,7 @@ fn claim_while_disconnected_monitor_update_fail() { // Now un-fail the monitor, which will result in B sending its original commitment update, // receiving the commitment update from A, and the resulting commitment dances. - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1342,7 +1442,7 @@ fn monitor_failed_no_reestablish_response() { check_added_monitors!(nodes[0], 1); } - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); let payment_event = SendEvent::from_event(events.pop().unwrap()); @@ -1366,7 +1466,7 @@ fn monitor_failed_no_reestablish_response() { nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_reconnect); nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &bs_reconnect); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1445,7 +1545,7 @@ fn first_message_on_recv_ordering() { let payment_event = SendEvent::from_event(events.pop().unwrap()); assert_eq!(payment_event.node_id, nodes[1].node.get_our_node_id()); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); // Deliver the final RAA for the first payment, which does not require a response. RAAs // generally require a commitment_signed, so the fact that we're expecting an opposite response @@ -1464,7 +1564,7 @@ fn first_message_on_recv_ordering() { assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Previous monitor update failure prevented generation of RAA".to_string(), 1); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1509,7 +1609,7 @@ fn test_monitor_update_fail_claim() { let (payment_preimage_1, _) = route_payment(&nodes[0], &[&nodes[1]], 1000000); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); assert!(nodes[1].node.claim_funds(payment_preimage_1, &None, 1_000_000)); check_added_monitors!(nodes[1], 1); @@ -1523,7 +1623,7 @@ fn test_monitor_update_fail_claim() { // Successfully update the monitor on the 1<->2 channel, but the 0<->1 channel should still be // paused, so forward shouldn't succeed until we call channel_monitor_updated(). - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let mut events = nodes[2].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -1612,13 +1712,13 @@ fn test_monitor_update_on_pending_forwards() { nodes[1].node.handle_update_add_htlc(&nodes[2].node.get_our_node_id(), &payment_event.msgs[0]); commitment_signed_dance!(nodes[1], nodes[2], payment_event.commitment_msg, false); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); expect_pending_htlcs_forwardable!(nodes[1]); check_added_monitors!(nodes[1], 1); assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty()); nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&chan_1.2).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1675,14 +1775,14 @@ fn monitor_update_claim_fail_no_response() { nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]); let as_raa = commitment_signed_dance!(nodes[1], nodes[0], payment_event.commitment_msg, false, true, false, true); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); assert!(nodes[1].node.claim_funds(payment_preimage_1, &None, 1_000_000)); check_added_monitors!(nodes[1], 1); let events = nodes[1].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 0); nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Temporary failure claiming HTLC, treating as success: Failed to update ChannelMonitor".to_string(), 1); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1728,19 +1828,19 @@ fn do_during_funding_monitor_fail(confirm_a_first: bool, restore_b_before_conf: nodes[0].node.funding_transaction_generated(&temporary_channel_id, funding_output); check_added_monitors!(nodes[0], 0); - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id()); let channel_id = OutPoint { txid: funding_created_msg.funding_txid, index: funding_created_msg.funding_output_index }.to_channel_id(); nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &funding_created_msg); check_added_monitors!(nodes[1], 1); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); nodes[0].node.handle_funding_signed(&nodes[1].node.get_our_node_id(), &get_event_msg!(nodes[1], MessageSendEvent::SendFundingSigned, nodes[0].node.get_our_node_id())); assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty()); nodes[0].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1); check_added_monitors!(nodes[0], 1); assert!(nodes[0].node.get_and_clear_pending_events().is_empty()); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[0].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[0], 0); @@ -1777,7 +1877,7 @@ fn do_during_funding_monitor_fail(confirm_a_first: bool, restore_b_before_conf: assert!(nodes[1].node.get_and_clear_pending_events().is_empty()); } - *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone(); nodes[1].node.channel_monitor_updated(&outpoint, latest_update); check_added_monitors!(nodes[1], 0); @@ -1843,7 +1943,7 @@ fn test_path_paused_mpp() { // Set it so that the first monitor update (for the path 0 -> 1 -> 3) succeeds, but the second // (for the path 0 -> 2 -> 3) fails. - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); *nodes[0].chain_monitor.next_update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure)); // Now check that we get the right return value, indicating that the first path succeeded but @@ -1855,7 +1955,7 @@ fn test_path_paused_mpp() { if let Err(APIError::MonitorUpdateFailed) = results[1] {} else { panic!(); } } else { panic!(); } check_added_monitors!(nodes[0], 2); - *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(()); + *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); // Pass the first HTLC of the payment along to nodes[3]. let mut events = nodes[0].node.get_and_clear_pending_msg_events(); diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 379fe193009..b43c98c8403 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -405,9 +405,9 @@ pub struct ChannelManager, secp_ctx: Secp256k1, - #[cfg(test)] + #[cfg(any(test, feature = "_test_utils"))] pub(super) channel_state: Mutex>, - #[cfg(not(test))] + #[cfg(not(any(test, feature = "_test_utils")))] channel_state: Mutex>, our_network_key: SecretKey, diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 7d54b6a9216..bc8351e4da0 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -94,6 +94,7 @@ pub struct TestChanMonCfg { pub tx_broadcaster: test_utils::TestBroadcaster, pub fee_estimator: test_utils::TestFeeEstimator, pub chain_source: test_utils::TestChainSource, + pub persister: test_utils::TestPersister, pub logger: test_utils::TestLogger, } @@ -169,7 +170,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { let old_monitors = self.chain_monitor.chain_monitor.monitors.lock().unwrap(); for (_, old_monitor) in old_monitors.iter() { let mut w = test_utils::TestVecWriter(Vec::new()); - old_monitor.write_for_disk(&mut w).unwrap(); + old_monitor.serialize_for_disk(&mut w).unwrap(); let (_, deserialized_monitor) = <(BlockHash, ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0)).unwrap(); deserialized_monitors.push(deserialized_monitor); @@ -191,14 +192,20 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { keys_manager: self.keys_manager, fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: 253 }, chain_monitor: self.chain_monitor, - tx_broadcaster: self.tx_broadcaster.clone(), + tx_broadcaster: &test_utils::TestBroadcaster { + txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone()) + }, logger: &test_utils::TestLogger::new(), channel_monitors, }).unwrap(); } + let persister = test_utils::TestPersister::new(); + let broadcaster = test_utils::TestBroadcaster { + txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone()) + }; let chain_source = test_utils::TestChainSource::new(Network::Testnet); - let chain_monitor = test_utils::TestChainMonitor::new(Some(&chain_source), self.tx_broadcaster.clone(), &self.logger, &feeest); + let chain_monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &broadcaster, &self.logger, &feeest, &persister); for deserialized_monitor in deserialized_monitors.drain(..) { if let Err(_) = chain_monitor.watch_channel(deserialized_monitor.get_funding_txo().0, deserialized_monitor) { panic!(); @@ -247,6 +254,8 @@ macro_rules! get_revoke_commit_msgs { } } +/// Get an specific event message from the pending events queue. +#[macro_export] macro_rules! get_event_msg { ($node: expr, $event_type: path, $node_id: expr) => { { @@ -263,6 +272,7 @@ macro_rules! get_event_msg { } } +#[cfg(test)] macro_rules! get_htlc_update_msgs { ($node: expr, $node_id: expr) => { { @@ -279,6 +289,7 @@ macro_rules! get_htlc_update_msgs { } } +#[cfg(test)] macro_rules! get_feerate { ($node: expr, $channel_id: expr) => { { @@ -289,6 +300,7 @@ macro_rules! get_feerate { } } +#[cfg(test)] macro_rules! get_local_commitment_txn { ($node: expr, $channel_id: expr) => { { @@ -305,6 +317,8 @@ macro_rules! get_local_commitment_txn { } } +/// Check the error from attempting a payment. +#[macro_export] macro_rules! unwrap_send_err { ($res: expr, $all_failed: expr, $type: pat, $check: expr) => { match &$res { @@ -327,6 +341,8 @@ macro_rules! unwrap_send_err { } } +/// Check whether N channel monitor(s) have been added. +#[macro_export] macro_rules! check_added_monitors { ($node: expr, $count: expr) => { { @@ -553,6 +569,9 @@ macro_rules! get_closing_signed_broadcast { } } +/// Check that a channel's closing channel update has been broadcasted, and optionally +/// check whether an error message event has occurred. +#[macro_export] macro_rules! check_closed_broadcast { ($node: expr, $with_error_msg: expr) => {{ let events = $node.node.get_and_clear_pending_msg_events(); @@ -750,6 +769,8 @@ macro_rules! commitment_signed_dance { } } +/// Get a payment preimage and hash. +#[macro_export] macro_rules! get_payment_preimage_hash { ($node: expr) => { { @@ -779,6 +800,7 @@ macro_rules! expect_pending_htlcs_forwardable { }} } +#[cfg(test)] macro_rules! expect_payment_received { ($node: expr, $expected_payment_hash: expr, $expected_recv_value: expr) => { let events = $node.node.get_and_clear_pending_events(); @@ -807,6 +829,7 @@ macro_rules! expect_payment_sent { } } +#[cfg(test)] macro_rules! expect_payment_failed { ($node: expr, $expected_payment_hash: expr, $rejected_by_dest: expr $(, $expected_error_code: expr, $expected_error_data: expr)*) => { let events = $node.node.get_and_clear_pending_events(); @@ -1105,7 +1128,8 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec { let fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; let chain_source = test_utils::TestChainSource::new(Network::Testnet); let logger = test_utils::TestLogger::with_id(format!("node {}", i)); - chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger }); + let persister = test_utils::TestPersister::new(); + chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister }); } chan_mon_cfgs @@ -1117,7 +1141,7 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec(node_count: usize, cfgs: &'a Vec default_config.channel_options.announced_channel = true; default_config.peer_channel_config_limits.force_announced_channel_preference = false; default_config.own_channel_config.our_htlc_minimum_msat = 1000; // sanitization being done by the sender, to exerce receiver logic we need to lift of limit - let node = ChannelManager::new(Network::Testnet, cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger.clone(), &cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, 0); + let node = ChannelManager::new(Network::Testnet, cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger, &cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, 0); chanmgrs.push(node); } @@ -1284,6 +1308,7 @@ pub fn get_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec {{ let chan_lock = $node.node.channel_state.lock().unwrap(); diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index bf18a2446cc..a2e12504031 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -4307,6 +4307,7 @@ fn test_no_txn_manager_serialize_deserialize() { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let logger: test_utils::TestLogger; let fee_estimator: test_utils::TestFeeEstimator; + let persister: test_utils::TestPersister; let new_chain_monitor: test_utils::TestChainMonitor; let keys_manager: test_utils::TestKeysInterface; let nodes_0_deserialized: ChannelManager; @@ -4318,11 +4319,12 @@ fn test_no_txn_manager_serialize_deserialize() { let nodes_0_serialized = nodes[0].node.encode(); let mut chan_0_monitor_serialized = test_utils::TestVecWriter(Vec::new()); - nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut chan_0_monitor_serialized).unwrap(); + nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut chan_0_monitor_serialized).unwrap(); logger = test_utils::TestLogger::new(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator); + persister = test_utils::TestPersister::new(); + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); nodes[0].chain_monitor = &new_chain_monitor; let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..]; let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor)>::read(&mut chan_0_monitor_read).unwrap(); @@ -4380,6 +4382,7 @@ fn test_manager_serialize_deserialize_events() { let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let fee_estimator: test_utils::TestFeeEstimator; + let persister: test_utils::TestPersister; let logger: test_utils::TestLogger; let new_chain_monitor: test_utils::TestChainMonitor; let keys_manager: test_utils::TestKeysInterface; @@ -4425,11 +4428,12 @@ fn test_manager_serialize_deserialize_events() { // Start the de/seriailization process mid-channel creation to check that the channel manager will hold onto events that are serialized let nodes_0_serialized = nodes[0].node.encode(); let mut chan_0_monitor_serialized = test_utils::TestVecWriter(Vec::new()); - nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut chan_0_monitor_serialized).unwrap(); + nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut chan_0_monitor_serialized).unwrap(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; logger = test_utils::TestLogger::new(); - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator); + persister = test_utils::TestPersister::new(); + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); nodes[0].chain_monitor = &new_chain_monitor; let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..]; let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor)>::read(&mut chan_0_monitor_read).unwrap(); @@ -4502,6 +4506,7 @@ fn test_simple_manager_serialize_deserialize() { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let logger: test_utils::TestLogger; let fee_estimator: test_utils::TestFeeEstimator; + let persister: test_utils::TestPersister; let new_chain_monitor: test_utils::TestChainMonitor; let keys_manager: test_utils::TestKeysInterface; let nodes_0_deserialized: ChannelManager; @@ -4515,11 +4520,12 @@ fn test_simple_manager_serialize_deserialize() { let nodes_0_serialized = nodes[0].node.encode(); let mut chan_0_monitor_serialized = test_utils::TestVecWriter(Vec::new()); - nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut chan_0_monitor_serialized).unwrap(); + nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut chan_0_monitor_serialized).unwrap(); logger = test_utils::TestLogger::new(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator); + persister = test_utils::TestPersister::new(); + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); nodes[0].chain_monitor = &new_chain_monitor; let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..]; let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor)>::read(&mut chan_0_monitor_read).unwrap(); @@ -4561,6 +4567,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]); let logger: test_utils::TestLogger; let fee_estimator: test_utils::TestFeeEstimator; + let persister: test_utils::TestPersister; let new_chain_monitor: test_utils::TestChainMonitor; let keys_manager: test_utils::TestKeysInterface; let nodes_0_deserialized: ChannelManager; @@ -4572,7 +4579,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let mut node_0_stale_monitors_serialized = Vec::new(); for monitor in nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter() { let mut writer = test_utils::TestVecWriter(Vec::new()); - monitor.1.write_for_disk(&mut writer).unwrap(); + monitor.1.serialize_for_disk(&mut writer).unwrap(); node_0_stale_monitors_serialized.push(writer.0); } @@ -4591,13 +4598,14 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let mut node_0_monitors_serialized = Vec::new(); for monitor in nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter() { let mut writer = test_utils::TestVecWriter(Vec::new()); - monitor.1.write_for_disk(&mut writer).unwrap(); + monitor.1.serialize_for_disk(&mut writer).unwrap(); node_0_monitors_serialized.push(writer.0); } logger = test_utils::TestLogger::new(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator); + persister = test_utils::TestPersister::new(); + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); nodes[0].chain_monitor = &new_chain_monitor; let mut node_0_stale_monitors = Vec::new(); @@ -5742,7 +5750,7 @@ fn test_key_derivation_params() { // We manually create the node configuration to backup the seed. let seed = [42; 32]; let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet); - let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator); + let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &chanmon_cfgs[0].persister); let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, chain_monitor, keys_manager, node_seed: seed }; let mut node_cfgs = create_node_cfgs(3, &chanmon_cfgs); node_cfgs.remove(0); @@ -7407,6 +7415,7 @@ fn test_data_loss_protect() { // * we close channel in case of detecting other being fallen behind // * we are able to claim our own outputs thanks to to_remote being static let keys_manager; + let persister; let logger; let fee_estimator; let tx_broadcaster; @@ -7423,7 +7432,7 @@ fn test_data_loss_protect() { // Cache node A state before any channel update let previous_node_state = nodes[0].node.encode(); let mut previous_chain_monitor_state = test_utils::TestVecWriter(Vec::new()); - nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut previous_chain_monitor_state).unwrap(); + nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut previous_chain_monitor_state).unwrap(); send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000); send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000); @@ -7438,7 +7447,8 @@ fn test_data_loss_protect() { tx_broadcaster = test_utils::TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new())}; fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet); - monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &fee_estimator); + persister = test_utils::TestPersister::new(); + monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &fee_estimator, &persister); node_state_0 = { let mut channel_monitors = HashMap::new(); channel_monitors.insert(OutPoint { txid: chan.3.txid(), index: 0 }, &mut chain_monitor); @@ -8299,15 +8309,16 @@ fn test_update_err_monitor_lockdown() { // Copy ChainMonitor to simulate a watchtower and update block height of node 0 until its ChannelMonitor timeout HTLC onchain let chain_source = test_utils::TestChainSource::new(Network::Testnet); let logger = test_utils::TestLogger::with_id(format!("node {}", 0)); + let persister = test_utils::TestPersister::new(); let watchtower = { let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap(); let monitor = monitors.get(&outpoint).unwrap(); let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write_for_disk(&mut w).unwrap(); + monitor.serialize_for_disk(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0)).unwrap().1; assert!(new_monitor == *monitor); - let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator); + let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok()); watchtower }; @@ -8357,15 +8368,16 @@ fn test_concurrent_monitor_claim() { // Copy ChainMonitor to simulate watchtower Alice and update block height her ChannelMonitor timeout HTLC onchain let chain_source = test_utils::TestChainSource::new(Network::Testnet); let logger = test_utils::TestLogger::with_id(format!("node {}", "Alice")); + let persister = test_utils::TestPersister::new(); let watchtower_alice = { let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap(); let monitor = monitors.get(&outpoint).unwrap(); let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write_for_disk(&mut w).unwrap(); + monitor.serialize_for_disk(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0)).unwrap().1; assert!(new_monitor == *monitor); - let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator); + let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok()); watchtower }; @@ -8382,15 +8394,16 @@ fn test_concurrent_monitor_claim() { // Copy ChainMonitor to simulate watchtower Bob and make it receive a commitment update first. let chain_source = test_utils::TestChainSource::new(Network::Testnet); let logger = test_utils::TestLogger::with_id(format!("node {}", "Bob")); + let persister = test_utils::TestPersister::new(); let watchtower_bob = { let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap(); let monitor = monitors.get(&outpoint).unwrap(); let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write_for_disk(&mut w).unwrap(); + monitor.serialize_for_disk(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0)).unwrap().1; assert!(new_monitor == *monitor); - let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator); + let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok()); watchtower }; diff --git a/lightning/src/ln/mod.rs b/lightning/src/ln/mod.rs index cd959a74dc5..1afcb3530fe 100644 --- a/lightning/src/ln/mod.rs +++ b/lightning/src/ln/mod.rs @@ -38,9 +38,9 @@ mod wire; // without the node parameter being mut. This is incorrect, and thus newer rustcs will complain // about an unnecessary mut. Thus, we silence the unused_mut warning in two test modules below. -#[cfg(test)] +#[cfg(any(test, feature = "_test_utils"))] #[macro_use] -pub(crate) mod functional_test_utils; +pub mod functional_test_utils; #[cfg(test)] #[allow(unused_mut)] mod functional_tests; diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index 84c8e6c14da..a3cfaa9804e 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -32,10 +32,10 @@ pub(crate) mod macro_logger; pub mod logger; pub mod config; -#[cfg(test)] -pub(crate) mod test_utils; +#[cfg(any(test, feature = "_test_utils"))] +pub mod test_utils; /// impls of traits that add exra enforcement on the way they're called. Useful for detecting state /// machine errors and used in fuzz targets and tests. -#[cfg(any(test, feature = "fuzztarget"))] +#[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))] pub mod enforcing_trait_impls; diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 0370c0e1a40..6c3552d5df0 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -63,19 +63,19 @@ impl chaininterface::FeeEstimator for TestFeeEstimator { pub struct TestChainMonitor<'a> { pub added_monitors: Mutex)>>, pub latest_monitor_update_id: Mutex>, - pub chain_monitor: chainmonitor::ChainMonitor, - pub update_ret: Mutex>, + pub chain_monitor: chainmonitor::ChainMonitor>, + pub update_ret: Mutex>>, // If this is set to Some(), after the next return, we'll always return this until update_ret // is changed: pub next_update_ret: Mutex>>, } impl<'a> TestChainMonitor<'a> { - pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator) -> Self { + pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, persister: &'a channelmonitor::Persist) -> Self { Self { added_monitors: Mutex::new(Vec::new()), latest_monitor_update_id: Mutex::new(HashMap::new()), - chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator), - update_ret: Mutex::new(Ok(())), + chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister), + update_ret: Mutex::new(None), next_update_ret: Mutex::new(None), } } @@ -87,19 +87,23 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... let mut w = TestVecWriter(Vec::new()); - monitor.write_for_disk(&mut w).unwrap(); + monitor.serialize_for_disk(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0)).unwrap().1; assert!(new_monitor == monitor); self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id())); self.added_monitors.lock().unwrap().push((funding_txo, monitor)); - assert!(self.chain_monitor.watch_channel(funding_txo, new_monitor).is_ok()); + let watch_res = self.chain_monitor.watch_channel(funding_txo, new_monitor); let ret = self.update_ret.lock().unwrap().clone(); if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { - *self.update_ret.lock().unwrap() = next_ret; + *self.update_ret.lock().unwrap() = Some(next_ret); } - ret + if ret.is_some() { + assert!(watch_res.is_ok()); + return ret.unwrap(); + } + watch_res } fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { @@ -110,23 +114,27 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { &mut ::std::io::Cursor::new(&w.0)).unwrap() == update); self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, update.update_id)); - assert!(self.chain_monitor.update_channel(funding_txo, update).is_ok()); + let update_res = self.chain_monitor.update_channel(funding_txo, update); // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... let monitors = self.chain_monitor.monitors.lock().unwrap(); let monitor = monitors.get(&funding_txo).unwrap(); w.0.clear(); - monitor.write_for_disk(&mut w).unwrap(); + monitor.serialize_for_disk(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut ::std::io::Cursor::new(&w.0)).unwrap().1; + &mut ::std::io::Cursor::new(&w.0)).unwrap().1; assert!(new_monitor == *monitor); self.added_monitors.lock().unwrap().push((funding_txo, new_monitor)); let ret = self.update_ret.lock().unwrap().clone(); if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { - *self.update_ret.lock().unwrap() = next_ret; + *self.update_ret.lock().unwrap() = Some(next_ret); } - ret + if ret.is_some() { + assert!(update_res.is_ok()); + return ret.unwrap(); + } + update_res } fn release_pending_monitor_events(&self) -> Vec { @@ -134,6 +142,30 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { } } +pub struct TestPersister { + pub update_ret: Mutex> +} +impl TestPersister { + pub fn new() -> Self { + Self { + update_ret: Mutex::new(Ok(())) + } + } + + pub fn set_update_ret(&self, ret: Result<(), channelmonitor::ChannelMonitorUpdateErr>) { + *self.update_ret.lock().unwrap() = ret; + } +} +impl channelmonitor::Persist for TestPersister { + fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { + self.update_ret.lock().unwrap().clone() + } + + fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { + self.update_ret.lock().unwrap().clone() + } +} + pub struct TestBroadcaster { pub txn_broadcasted: Mutex>, }