diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 107f65f9b74..c61eb37d27d 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -379,7 +379,7 @@ mod tests { use lightning::util::ser::Writeable; use lightning::util::test_utils; use lightning::util::persist::KVStorePersister; - use lightning_invoice::payment::{InvoicePayer, RetryAttempts}; + use lightning_invoice::payment::{InvoicePayer, Retry}; use lightning_invoice::utils::DefaultRouter; use lightning_persister::FilesystemPersister; use std::fs; @@ -801,7 +801,7 @@ mod tests { let data_dir = nodes[0].persister.get_data_dir(); let persister = Arc::new(Persister::new(data_dir)); let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes); - let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, Arc::clone(&nodes[0].scorer), Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2))); + let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, Arc::clone(&nodes[0].scorer), Arc::clone(&nodes[0].logger), |_: &_| {}, Retry::Attempts(2))); let event_handler = Arc::clone(&invoice_payer); let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone())); assert!(bg_processor.stop().is_ok()); diff --git a/lightning-invoice/src/lib.rs b/lightning-invoice/src/lib.rs index 616ea99f0fe..9fcb4af1b62 100644 --- a/lightning-invoice/src/lib.rs +++ b/lightning-invoice/src/lib.rs @@ -25,6 +25,8 @@ compile_error!("at least one of the `std` or `no-std` features must be enabled") pub mod payment; pub mod utils; +pub(crate) mod time_utils; + extern crate bech32; extern crate bitcoin_hashes; #[macro_use] extern crate lightning; diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 6b79a0123d9..d6a3abb0ddf 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -45,7 +45,7 @@ //! # use lightning::util::logger::{Logger, Record}; //! # use lightning::util::ser::{Writeable, Writer}; //! # use lightning_invoice::Invoice; -//! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router}; +//! # use lightning_invoice::payment::{InvoicePayer, Payer, Retry, Router}; //! # use secp256k1::PublicKey; //! # use std::cell::RefCell; //! # use std::ops::Deref; @@ -113,7 +113,7 @@ //! # let router = FakeRouter {}; //! # let scorer = RefCell::new(FakeScorer {}); //! # let logger = FakeLogger {}; -//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); +//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); //! //! let invoice = "..."; //! if let Ok(invoice) = invoice.parse::() { @@ -146,10 +146,13 @@ use lightning::routing::scoring::{LockableScore, Score}; use lightning::routing::router::{PaymentParameters, Route, RouteParameters}; use lightning::util::events::{Event, EventHandler}; use lightning::util::logger::Logger; +use time_utils::Time; use crate::sync::Mutex; use secp256k1::PublicKey; +use core::fmt; +use core::fmt::{Debug, Display, Formatter}; use core::ops::Deref; use core::time::Duration; #[cfg(feature = "std")] @@ -160,7 +163,17 @@ use std::time::SystemTime; /// See [module-level documentation] for details. /// /// [module-level documentation]: crate::payment -pub struct InvoicePayer +pub type InvoicePayer = InvoicePayerUsingTime::; + +#[cfg(not(feature = "no-std"))] +type ConfiguredTime = std::time::Instant; +#[cfg(feature = "no-std")] +use time_utils; +#[cfg(feature = "no-std")] +type ConfiguredTime = time_utils::Eternity; + +/// (C-not exported) generally all users should use the [`InvoicePayer`] type alias. +pub struct InvoicePayerUsingTime where P::Target: Payer, R: for <'a> Router<<::Target as LockableScore<'a>>::Locked>, @@ -173,8 +186,42 @@ where logger: L, event_handler: E, /// Caches the overall attempts at making a payment, which is updated prior to retrying. - payment_cache: Mutex>, - retry_attempts: RetryAttempts, + payment_cache: Mutex>>, + retry: Retry, +} + +/// Storing minimal payment attempts information required for determining if a outbound payment can +/// be retried. +#[derive(Clone, Copy)] +struct PaymentAttempts { + /// This count will be incremented only after the result of the attempt is known. When it's 0, + /// it means the result of the first attempt is now known yet. + count: usize, + /// This field is only used when retry is [`Retry::Timeout`] which is only build with feature std + first_attempted_at: T +} + +impl PaymentAttempts { + fn new() -> Self { + PaymentAttempts { + count: 0, + first_attempted_at: T::now() + } + } +} + +impl Display for PaymentAttempts { + fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { + #[cfg(feature = "no-std")] + return write!( f, "attempts: {}", self.count); + #[cfg(not(feature = "no-std"))] + return write!( + f, + "attempts: {}, duration: {}s", + self.count, + T::now().duration_since(self.first_attempted_at).as_secs() + ); + } } /// A trait defining behavior of an [`Invoice`] payer. @@ -211,13 +258,33 @@ pub trait Router { ) -> Result; } -/// Number of attempts to retry payment path failures for an [`Invoice`]. +/// Strategies available to retry payment path failures for an [`Invoice`]. /// -/// Note that this is the number of *path* failures, not full payment retries. For multi-path -/// payments, if this is less than the total number of paths, we will never even retry all of the -/// payment's paths. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub struct RetryAttempts(pub usize); +pub enum Retry { + /// Max number of attempts to retry payment. + /// + /// Note that this is the number of *path* failures, not full payment retries. For multi-path + /// payments, if this is less than the total number of paths, we will never even retry all of the + /// payment's paths. + Attempts(usize), + #[cfg(feature = "std")] + /// Time elapsed before abandoning retries for a payment. + Timeout(Duration), +} + +impl Retry { + fn is_retryable_now(&self, attempts: &PaymentAttempts) -> bool { + match (self, attempts) { + (Retry::Attempts(max_retry_count), PaymentAttempts { count, .. }) => { + max_retry_count >= &count + }, + #[cfg(feature = "std")] + (Retry::Timeout(max_duration), PaymentAttempts { first_attempted_at, .. } ) => + *max_duration >= T::now().duration_since(*first_attempted_at), + } + } +} /// An error that may occur when making a payment. #[derive(Clone, Debug)] @@ -230,7 +297,7 @@ pub enum PaymentError { Sending(PaymentSendFailure), } -impl InvoicePayer +impl InvoicePayerUsingTime where P::Target: Payer, R: for <'a> Router<<::Target as LockableScore<'a>>::Locked>, @@ -240,9 +307,9 @@ where /// Creates an invoice payer that retries failed payment paths. /// /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once - /// `retry_attempts` has been exceeded for a given [`Invoice`]. + /// `retry` has been exceeded for a given [`Invoice`]. pub fn new( - payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts + payer: P, router: R, scorer: S, logger: L, event_handler: E, retry: Retry ) -> Self { Self { payer, @@ -251,7 +318,7 @@ where logger, event_handler, payment_cache: Mutex::new(HashMap::new()), - retry_attempts, + retry, } } @@ -292,7 +359,7 @@ where let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner()); match self.payment_cache.lock().unwrap().entry(payment_hash) { hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")), - hash_map::Entry::Vacant(entry) => entry.insert(0), + hash_map::Entry::Vacant(entry) => entry.insert(PaymentAttempts::new()), }; let payment_secret = Some(invoice.payment_secret().clone()); @@ -311,6 +378,7 @@ where let send_payment = |route: &Route| { self.payer.send_payment(route, payment_hash, &payment_secret) }; + self.pay_internal(&route_params, payment_hash, send_payment) .map_err(|e| { self.payment_cache.lock().unwrap().remove(&payment_hash); e }) } @@ -327,7 +395,7 @@ where let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner()); match self.payment_cache.lock().unwrap().entry(payment_hash) { hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")), - hash_map::Entry::Vacant(entry) => entry.insert(0), + hash_map::Entry::Vacant(entry) => entry.insert(PaymentAttempts::new()), }; let route_params = RouteParameters { @@ -367,13 +435,13 @@ where PaymentSendFailure::PathParameterError(_) => Err(e), PaymentSendFailure::AllFailedRetrySafe(_) => { let mut payment_cache = self.payment_cache.lock().unwrap(); - let retry_count = payment_cache.get_mut(&payment_hash).unwrap(); - if *retry_count >= self.retry_attempts.0 { - Err(e) - } else { - *retry_count += 1; + let payment_attempts = payment_cache.get_mut(&payment_hash).unwrap(); + payment_attempts.count += 1; + if self.retry.is_retryable_now(payment_attempts) { core::mem::drop(payment_cache); Ok(self.pay_internal(params, payment_hash, send_payment)?) + } else { + Err(e) } }, PaymentSendFailure::PartialFailure { failed_paths_retry, payment_id, .. } => { @@ -399,20 +467,22 @@ where fn retry_payment( &self, payment_id: PaymentId, payment_hash: PaymentHash, params: &RouteParameters ) -> Result<(), ()> { - let max_payment_attempts = self.retry_attempts.0 + 1; - let attempts = *self.payment_cache.lock().unwrap() - .entry(payment_hash) - .and_modify(|attempts| *attempts += 1) - .or_insert(1); - - if attempts >= max_payment_attempts { - log_trace!(self.logger, "Payment {} exceeded maximum attempts; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts); + let attempts = + *self.payment_cache.lock().unwrap().entry(payment_hash) + .and_modify(|attempts| attempts.count += 1) + .or_insert(PaymentAttempts { + count: 1, + first_attempted_at: T::now() + }); + + if !self.retry.is_retryable_now(&attempts) { + log_trace!(self.logger, "Payment {} exceeded maximum attempts; not retrying ({})", log_bytes!(payment_hash.0), attempts); return Err(()); } #[cfg(feature = "std")] { if has_expired(params) { - log_trace!(self.logger, "Invoice expired for payment {}; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts); + log_trace!(self.logger, "Invoice expired for payment {}; not retrying ({:})", log_bytes!(payment_hash.0), attempts); return Err(()); } } @@ -424,7 +494,7 @@ where &self.scorer.lock() ); if route.is_err() { - log_trace!(self.logger, "Failed to find a route for payment {}; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts); + log_trace!(self.logger, "Failed to find a route for payment {}; not retrying ({:})", log_bytes!(payment_hash.0), attempts); return Err(()); } @@ -468,7 +538,7 @@ fn has_expired(route_params: &RouteParameters) -> bool { } else { false } } -impl EventHandler for InvoicePayer +impl EventHandler for InvoicePayerUsingTime where P::Target: Payer, R: for <'a> Router<<::Target as LockableScore<'a>>::Locked>, @@ -511,7 +581,7 @@ where let mut payment_cache = self.payment_cache.lock().unwrap(); let attempts = payment_cache .remove(payment_hash) - .map_or(1, |attempts| attempts + 1); + .map_or(1, |attempts| attempts.count + 1); log_trace!(self.logger, "Payment {} succeeded (attempts: {})", log_bytes!(payment_hash.0), attempts); }, _ => {}, @@ -541,6 +611,7 @@ mod tests { use std::cell::RefCell; use std::collections::VecDeque; use std::time::{SystemTime, Duration}; + use time_utils::tests::SinceEpoch; use DEFAULT_EXPIRY_TIME; fn invoice(payment_preimage: PaymentPreimage) -> Invoice { @@ -624,7 +695,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -653,7 +724,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -698,7 +769,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); assert!(invoice_payer.pay_invoice(&invoice).is_ok()); } @@ -720,7 +791,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(PaymentId([1; 32])); let event = Event::PaymentPathFailed { @@ -749,7 +820,7 @@ mod tests { } #[test] - fn fails_paying_invoice_after_max_retries() { + fn fails_paying_invoice_after_max_retry_counts() { let event_handled = core::cell::RefCell::new(false); let event_handler = |_: &_| { *event_handled.borrow_mut() = true; }; @@ -765,7 +836,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -805,6 +876,52 @@ mod tests { assert_eq!(*payer.attempts.borrow(), 3); } + #[cfg(feature = "std")] + #[test] + fn fails_paying_invoice_after_max_retry_timeout() { + let event_handled = core::cell::RefCell::new(false); + let event_handler = |_: &_| { *event_handled.borrow_mut() = true; }; + + let payment_preimage = PaymentPreimage([1; 32]); + let invoice = invoice(payment_preimage); + let final_value_msat = invoice.amount_milli_satoshis().unwrap(); + + let payer = TestPayer::new() + .expect_send(Amount::ForInvoice(final_value_msat)) + .expect_send(Amount::OnRetry(final_value_msat / 2)); + + let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); + let logger = TestLogger::new(); + type InvoicePayerUsingSinceEpoch = InvoicePayerUsingTime::; + + let invoice_payer = + InvoicePayerUsingSinceEpoch::new(&payer, router, &scorer, &logger, event_handler, Retry::Timeout(Duration::from_secs(120))); + + let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); + assert_eq!(*payer.attempts.borrow(), 1); + + let event = Event::PaymentPathFailed { + payment_id, + payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()), + network_update: None, + rejected_by_dest: false, + all_paths_failed: true, + path: TestRouter::path_for_value(final_value_msat), + short_channel_id: None, + retry: Some(TestRouter::retry_for_invoice(&invoice)), + }; + invoice_payer.handle_event(&event); + assert_eq!(*event_handled.borrow(), false); + assert_eq!(*payer.attempts.borrow(), 2); + + SinceEpoch::advance(Duration::from_secs(121)); + + invoice_payer.handle_event(&event); + assert_eq!(*event_handled.borrow(), true); + assert_eq!(*payer.attempts.borrow(), 2); + } + #[test] fn fails_paying_invoice_with_missing_retry_params() { let event_handled = core::cell::RefCell::new(false); @@ -819,7 +936,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -851,7 +968,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = expired_invoice(payment_preimage); @@ -876,7 +993,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -917,7 +1034,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -951,7 +1068,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -987,7 +1104,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); @@ -1026,7 +1143,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, Retry::Attempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -1050,7 +1167,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, Retry::Attempts(0)); match invoice_payer.pay_invoice(&invoice) { Err(PaymentError::Sending(_)) => {}, @@ -1074,7 +1191,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0)); let payment_id = Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap()); @@ -1097,7 +1214,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -1128,7 +1245,7 @@ mod tests { let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_pubkey( pubkey, payment_preimage, final_value_msat, final_cltv_expiry_delta @@ -1183,7 +1300,7 @@ mod tests { })); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); let event = Event::PaymentPathFailed { @@ -1219,7 +1336,7 @@ mod tests { ); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, Retry::Attempts(2)); let payment_id = invoice_payer.pay_invoice(&invoice).unwrap(); let event = Event::PaymentPathSuccessful { @@ -1554,7 +1671,7 @@ mod tests { let event_handler = |_: &_| { panic!(); }; let scorer = RefCell::new(TestScorer::new()); - let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1)); + let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, Retry::Attempts(1)); assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch( &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string(), @@ -1600,7 +1717,7 @@ mod tests { let event_handler = |_: &_| { panic!(); }; let scorer = RefCell::new(TestScorer::new()); - let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1)); + let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, Retry::Attempts(1)); assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch( &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string(), @@ -1682,7 +1799,7 @@ mod tests { event_checker(event); }; let scorer = RefCell::new(TestScorer::new()); - let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1)); + let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, Retry::Attempts(1)); assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager_and_duration_since_epoch( &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string(), diff --git a/lightning-invoice/src/time_utils.rs b/lightning-invoice/src/time_utils.rs new file mode 120000 index 00000000000..5326cffb861 --- /dev/null +++ b/lightning-invoice/src/time_utils.rs @@ -0,0 +1 @@ +../../lightning/src/util/time.rs \ No newline at end of file diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index 4c47aac47b6..cffd2d90533 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -59,6 +59,7 @@ use routing::network_graph::{NetworkGraph, NodeId}; use routing::router::RouteHop; use util::ser::{Readable, ReadableArgs, Writeable, Writer}; use util::logger::Logger; +use util::time::Time; use prelude::*; use core::fmt; @@ -262,7 +263,9 @@ pub type Scorer = ScorerUsingTime::; #[cfg(not(feature = "no-std"))] type ConfiguredTime = std::time::Instant; #[cfg(feature = "no-std")] -type ConfiguredTime = time::Eternity; +use util::time::Eternity; +#[cfg(feature = "no-std")] +type ConfiguredTime = Eternity; // Note that ideally we'd hide ScorerUsingTime from public view by sealing it as well, but rustdoc // doesn't handle this well - instead exposing a `Scorer` which has no trait implementation(s) or @@ -1327,83 +1330,11 @@ impl Readable for ChannelLiquidity { } } -pub(crate) mod time { - use core::ops::Sub; - use core::time::Duration; - /// A measurement of time. - pub trait Time: Copy + Sub where Self: Sized { - /// Returns an instance corresponding to the current moment. - fn now() -> Self; - - /// Returns the amount of time elapsed since `self` was created. - fn elapsed(&self) -> Duration; - - /// Returns the amount of time passed between `earlier` and `self`. - fn duration_since(&self, earlier: Self) -> Duration; - - /// Returns the amount of time passed since the beginning of [`Time`]. - /// - /// Used during (de-)serialization. - fn duration_since_epoch() -> Duration; - } - - /// A state in which time has no meaning. - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - pub struct Eternity; - - #[cfg(not(feature = "no-std"))] - impl Time for std::time::Instant { - fn now() -> Self { - std::time::Instant::now() - } - - fn duration_since(&self, earlier: Self) -> Duration { - self.duration_since(earlier) - } - - fn duration_since_epoch() -> Duration { - use std::time::SystemTime; - SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap() - } - - fn elapsed(&self) -> Duration { - std::time::Instant::elapsed(self) - } - } - - impl Time for Eternity { - fn now() -> Self { - Self - } - - fn duration_since(&self, _earlier: Self) -> Duration { - Duration::from_secs(0) - } - - fn duration_since_epoch() -> Duration { - Duration::from_secs(0) - } - - fn elapsed(&self) -> Duration { - Duration::from_secs(0) - } - } - - impl Sub for Eternity { - type Output = Self; - - fn sub(self, _other: Duration) -> Self { - self - } - } -} - -pub(crate) use self::time::Time; - #[cfg(test)] mod tests { - use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorerUsingTime, ScoringParameters, ScorerUsingTime, Time}; - use super::time::Eternity; + use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorerUsingTime, ScoringParameters, ScorerUsingTime}; + use util::time::Time; + use util::time::tests::SinceEpoch; use ln::features::{ChannelFeatures, NodeFeatures}; use ln::msgs::{ChannelAnnouncement, ChannelUpdate, OptionalField, UnsignedChannelAnnouncement, UnsignedChannelUpdate}; @@ -1418,80 +1349,9 @@ mod tests { use bitcoin::hashes::sha256d::Hash as Sha256dHash; use bitcoin::network::constants::Network; use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; - use core::cell::Cell; - use core::ops::Sub; use core::time::Duration; use io; - // `Time` tests - - /// Time that can be advanced manually in tests. - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - struct SinceEpoch(Duration); - - impl SinceEpoch { - thread_local! { - static ELAPSED: Cell = core::cell::Cell::new(Duration::from_secs(0)); - } - - fn advance(duration: Duration) { - Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration)) - } - } - - impl Time for SinceEpoch { - fn now() -> Self { - Self(Self::duration_since_epoch()) - } - - fn duration_since(&self, earlier: Self) -> Duration { - self.0 - earlier.0 - } - - fn duration_since_epoch() -> Duration { - Self::ELAPSED.with(|elapsed| elapsed.get()) - } - - fn elapsed(&self) -> Duration { - Self::duration_since_epoch() - self.0 - } - } - - impl Sub for SinceEpoch { - type Output = Self; - - fn sub(self, other: Duration) -> Self { - Self(self.0 - other) - } - } - - #[test] - fn time_passes_when_advanced() { - let now = SinceEpoch::now(); - assert_eq!(now.elapsed(), Duration::from_secs(0)); - - SinceEpoch::advance(Duration::from_secs(1)); - SinceEpoch::advance(Duration::from_secs(1)); - - let elapsed = now.elapsed(); - let later = SinceEpoch::now(); - - assert_eq!(elapsed, Duration::from_secs(2)); - assert_eq!(later - elapsed, now); - } - - #[test] - fn time_never_passes_in_an_eternity() { - let now = Eternity::now(); - let elapsed = now.elapsed(); - let later = Eternity::now(); - - assert_eq!(now.elapsed(), Duration::from_secs(0)); - assert_eq!(later - elapsed, now); - } - - // `Scorer` tests - /// A scorer for testing with time that can be manually advanced. type Scorer = ScorerUsingTime::; diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index b7ee02d2c1f..c6181ab269a 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -36,6 +36,7 @@ pub(crate) mod poly1305; pub(crate) mod chacha20poly1305rfc; pub(crate) mod transaction_utils; pub(crate) mod scid_utils; +pub(crate) mod time; /// Logging macro utilities. #[macro_use] diff --git a/lightning/src/util/time.rs b/lightning/src/util/time.rs new file mode 100644 index 00000000000..d3768aa7ca6 --- /dev/null +++ b/lightning/src/util/time.rs @@ -0,0 +1,152 @@ +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! [`Time`] trait and different implementations. Currently, it's mainly used in tests so we can +//! manually advance time. +//! Other crates may symlink this file to use it while [`Time`] trait is sealed here. + +use core::ops::Sub; +use core::time::Duration; + +/// A measurement of time. +pub trait Time: Copy + Sub where Self: Sized { + /// Returns an instance corresponding to the current moment. + fn now() -> Self; + + /// Returns the amount of time elapsed since `self` was created. + fn elapsed(&self) -> Duration; + + /// Returns the amount of time passed between `earlier` and `self`. + fn duration_since(&self, earlier: Self) -> Duration; + + /// Returns the amount of time passed since the beginning of [`Time`]. + /// + /// Used during (de-)serialization. + fn duration_since_epoch() -> Duration; +} + +/// A state in which time has no meaning. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Eternity; + +impl Time for Eternity { + fn now() -> Self { + Self + } + + fn duration_since(&self, _earlier: Self) -> Duration { + Duration::from_secs(0) + } + + fn duration_since_epoch() -> Duration { + Duration::from_secs(0) + } + + fn elapsed(&self) -> Duration { + Duration::from_secs(0) + } +} + +impl Sub for Eternity { + type Output = Self; + + fn sub(self, _other: Duration) -> Self { + self + } +} + +#[cfg(not(feature = "no-std"))] +impl Time for std::time::Instant { + fn now() -> Self { + std::time::Instant::now() + } + + fn duration_since(&self, earlier: Self) -> Duration { + self.duration_since(earlier) + } + + fn duration_since_epoch() -> Duration { + use std::time::SystemTime; + SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap() + } + fn elapsed(&self) -> Duration { + std::time::Instant::elapsed(self) + } +} + +#[cfg(test)] +pub mod tests { + use super::{Time, Eternity}; + + use core::time::Duration; + use core::ops::Sub; + use core::cell::Cell; + + /// Time that can be advanced manually in tests. + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub struct SinceEpoch(Duration); + + impl SinceEpoch { + thread_local! { + static ELAPSED: Cell = core::cell::Cell::new(Duration::from_secs(0)); + } + + pub fn advance(duration: Duration) { + Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration)) + } + } + + impl Time for SinceEpoch { + fn now() -> Self { + Self(Self::duration_since_epoch()) + } + + fn duration_since(&self, earlier: Self) -> Duration { + self.0 - earlier.0 + } + + fn duration_since_epoch() -> Duration { + Self::ELAPSED.with(|elapsed| elapsed.get()) + } + + fn elapsed(&self) -> Duration { + Self::duration_since_epoch() - self.0 + } + } + + impl Sub for SinceEpoch { + type Output = Self; + + fn sub(self, other: Duration) -> Self { + Self(self.0 - other) + } + } + + #[test] + fn time_passes_when_advanced() { + let now = SinceEpoch::now(); + assert_eq!(now.elapsed(), Duration::from_secs(0)); + + SinceEpoch::advance(Duration::from_secs(1)); + SinceEpoch::advance(Duration::from_secs(1)); + + let elapsed = now.elapsed(); + let later = SinceEpoch::now(); + + assert_eq!(elapsed, Duration::from_secs(2)); + assert_eq!(later - elapsed, now); + } + + #[test] + fn time_never_passes_in_an_eternity() { + let now = Eternity::now(); + let elapsed = now.elapsed(); + let later = Eternity::now(); + + assert_eq!(now.elapsed(), Duration::from_secs(0)); + assert_eq!(later - elapsed, now); + } +}