
Commit d0b8f45

Merge pull request #2009 from TheBlueMatt/2023-02-no-racey-retries
Fix (and test) threaded payment retries
2 parents a170478 + d986329 commit d0b8f45

11 files changed: +302 -29 lines changed
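
The race being fixed here is the one the new test below spells out: a retry first computes the amount it needs to resend under the payments lock, then drops that lock before actually sending, so two threads could both pick up the same pending amount and both send it, overpaying the original HTLC. The fix serializes the whole retry pass behind a dedicated `retry_lock`. A minimal standalone sketch of that shape (illustrative names only, not LDK code):

use std::sync::{Arc, Mutex};
use std::thread;

// Illustrative only: one pending "retry amount" plus the serializing lock.
struct Retrier {
    pending_retry_msat: Mutex<Option<u64>>,
    // The fix: a lock held for the *entire* retry pass, so only one thread can
    // be between "read the amount" and "send the retry" at any time.
    retry_lock: Mutex<()>,
}

impl Retrier {
    fn retry_payments(&self, sent_msat: &Mutex<u64>) {
        // Comment this line out and several threads can each read Some(amt)
        // below before any of them clears it, "sending" the retry repeatedly.
        let _single_thread = self.retry_lock.lock().unwrap();

        // Mirror the real code's shape: read under the payments lock, drop it,
        // then send, then mark the retry as handled.
        let amt = *self.pending_retry_msat.lock().unwrap();
        if let Some(amt) = amt {
            *sent_msat.lock().unwrap() += amt; // stand-in for building and sending the HTLC
            *self.pending_retry_msat.lock().unwrap() = None;
        }
    }
}

fn main() {
    let retrier = Arc::new(Retrier {
        pending_retry_msat: Mutex::new(Some(100_000)),
        retry_lock: Mutex::new(()),
    });
    let sent_msat = Arc::new(Mutex::new(0u64));
    let mut threads = Vec::new();
    for _ in 0..16 {
        let (r, s) = (Arc::clone(&retrier), Arc::clone(&sent_msat));
        threads.push(thread::spawn(move || r.retry_payments(&s)));
    }
    for t in threads { t.join().unwrap(); }
    assert_eq!(*sent_msat.lock().unwrap(), 100_000); // exactly one retry went out
}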

lightning/src/ln/channelmanager.rs

Lines changed: 12 additions & 20 deletions
@@ -70,7 +70,7 @@ use crate::prelude::*;
 use core::{cmp, mem};
 use core::cell::RefCell;
 use crate::io::Read;
-use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock};
+use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock, LockTestExt, LockHeldState};
 use core::sync::atomic::{AtomicUsize, Ordering};
 use core::time::Duration;
 use core::ops::Deref;
@@ -1218,13 +1218,10 @@ macro_rules! handle_error {
 		match $internal {
 			Ok(msg) => Ok(msg),
 			Err(MsgHandleErrInternal { err, chan_id, shutdown_finish }) => {
-				#[cfg(any(feature = "_test_utils", test))]
-				{
-					// In testing, ensure there are no deadlocks where the lock is already held upon
-					// entering the macro.
-					debug_assert!($self.pending_events.try_lock().is_ok());
-					debug_assert!($self.per_peer_state.try_write().is_ok());
-				}
+				// In testing, ensure there are no deadlocks where the lock is already held upon
+				// entering the macro.
+				debug_assert_ne!($self.pending_events.held_by_thread(), LockHeldState::HeldByThread);
+				debug_assert_ne!($self.per_peer_state.held_by_thread(), LockHeldState::HeldByThread);

 				let mut msg_events = Vec::with_capacity(2);

@@ -3722,17 +3719,12 @@ where
 	/// Fails an HTLC backwards to the sender of it to us.
 	/// Note that we do not assume that channels corresponding to failed HTLCs are still available.
 	fn fail_htlc_backwards_internal(&self, source: &HTLCSource, payment_hash: &PaymentHash, onion_error: &HTLCFailReason, destination: HTLCDestination) {
-		#[cfg(any(feature = "_test_utils", test))]
-		{
-			// Ensure that the peer state channel storage lock is not held when calling this
-			// function.
-			// This ensures that future code doesn't introduce a lock_order requirement for
-			// `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
-			// this function with any `per_peer_state` peer lock aquired would.
-			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_, peer) in per_peer_state.iter() {
-				debug_assert!(peer.try_lock().is_ok());
-			}
+		// Ensure that no peer state channel storage lock is held when calling this function.
+		// This ensures that future code doesn't introduce a lock-order requirement for
+		// `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
+		// this function with any `per_peer_state` peer lock acquired would.
+		for (_, peer) in self.per_peer_state.read().unwrap().iter() {
+			debug_assert_ne!(peer.held_by_thread(), LockHeldState::HeldByThread);
 		}

 		//TODO: There is a timing attack here where if a node fails an HTLC back to us they can
@@ -7702,7 +7694,7 @@ where

 			inbound_payment_key: expanded_inbound_key,
 			pending_inbound_payments: Mutex::new(pending_inbound_payments),
-			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()) },
+			pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()), },
 			pending_intercepted_htlcs: Mutex::new(pending_intercepted_htlcs.unwrap()),

 			forward_htlcs: Mutex::new(forward_htlcs),
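
The `try_lock().is_ok()` / `try_write().is_ok()` assertions removed above claim "nobody anywhere holds this lock", which another thread in the new multi-threaded test can legitimately violate; `held_by_thread()` only asks whether the *current* thread holds it. That reading of the change is an inference from the diff, but the underlying behaviour of `try_lock` is easy to demonstrate with plain std types:

use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

fn main() {
    let lock = Arc::new(Mutex::new(0u32));
    let held = Arc::clone(&lock);
    // Another thread takes the lock and sits on it for a while...
    let holder = thread::spawn(move || {
        let _guard = held.lock().unwrap();
        thread::sleep(Duration::from_millis(200));
    });
    thread::sleep(Duration::from_millis(50));
    // ...so a try_lock()-based "no deadlock" assertion would fire here even
    // though *this* thread holds nothing and no deadlock is possible.
    assert!(lock.try_lock().is_err());
    holder.join().unwrap();
}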

lightning/src/ln/functional_test_utils.rs

Lines changed: 13 additions & 0 deletions
@@ -351,6 +351,19 @@ impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
 	}
 }

+/// If we need an unsafe pointer to a `Node` (ie to reference it in a thread
+/// pre-std::thread::scope), this provides that with `Sync`. Note that accessing some of the fields
+/// in the `Node` are not safe to use (i.e. the ones behind an `Rc`), but that's left to the caller
+/// to figure out.
+pub struct NodePtr(pub *const Node<'static, 'static, 'static>);
+impl NodePtr {
+	pub fn from_node<'a, 'b: 'a, 'c: 'b>(node: &Node<'a, 'b, 'c>) -> Self {
+		Self((node as *const Node<'a, 'b, 'c>).cast())
+	}
+}
+unsafe impl Send for NodePtr {}
+unsafe impl Sync for NodePtr {}
+
 impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
 	fn drop(&mut self) {
 		if !panicking() {
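
As its doc comment notes, `NodePtr` exists only because these tests predate `std::thread::scope` (stable since Rust 1.63). On a new enough toolchain the same "borrow a local from several spawned threads" shape needs no raw pointers or `unsafe`; a sketch of that alternative with generic data rather than a `Node`:

use std::thread;

fn main() {
    let shared = String::from("borrowed by every worker");
    // Scoped threads may borrow `shared` because the scope guarantees they are
    // all joined before `shared` goes out of scope.
    thread::scope(|s| {
        for _ in 0..4 {
            s.spawn(|| assert_eq!(shared.len(), 24));
        }
    });
}
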

lightning/src/ln/outbound_payment.rs

Lines changed: 4 additions & 1 deletion
@@ -383,12 +383,14 @@ pub enum PaymentSendFailure {

 pub(super) struct OutboundPayments {
 	pub(super) pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
+	pub(super) retry_lock: Mutex<()>,
 }

 impl OutboundPayments {
 	pub(super) fn new() -> Self {
 		Self {
-			pending_outbound_payments: Mutex::new(HashMap::new())
+			pending_outbound_payments: Mutex::new(HashMap::new()),
+			retry_lock: Mutex::new(()),
 		}
 	}

@@ -494,6 +496,7 @@ impl OutboundPayments {
 		FH: Fn() -> Vec<ChannelDetails>,
 		L::Target: Logger,
 	{
+		let _single_thread = self.retry_lock.lock().unwrap();
 		loop {
 			let mut outbounds = self.pending_outbound_payments.lock().unwrap();
 			let mut retry_id_route_params = None;
lightning/src/ln/payment_tests.rs

Lines changed: 163 additions & 1 deletion
@@ -40,7 +40,7 @@ use crate::routing::gossip::NodeId;
 #[cfg(feature = "std")]
 use {
 	crate::util::time::tests::SinceEpoch,
-	std::time::{SystemTime, Duration}
+	std::time::{SystemTime, Instant, Duration}
 };

 #[test]
@@ -2557,3 +2557,165 @@ fn test_simple_partial_retry() {
 	expect_pending_htlcs_forwardable!(nodes[2]);
 	expect_payment_claimable!(nodes[2], payment_hash, payment_secret, amt_msat);
 }
+
+#[test]
+#[cfg(feature = "std")]
+fn test_threaded_payment_retries() {
+	// In the first version of the in-`ChannelManager` payment retries, retries weren't limited to
+	// a single thread and would happily let multiple threads run retries at the same time. Because
+	// retries are done by first calculating the amount we need to retry, then dropping the
+	// relevant lock, then actually sending, we would happily let multiple threads retry the same
+	// amount at the same time, overpaying our original HTLC!
+	let chanmon_cfgs = create_chanmon_cfgs(4);
+	let node_cfgs = create_node_cfgs(4, &chanmon_cfgs);
+	let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]);
+	let nodes = create_network(4, &node_cfgs, &node_chanmgrs);
+
+	// There is one mitigating guardrail when retrying payments - we can never over-pay by more
+	// than 10% of the original value. Thus, we want all our retries to be below that. In order to
+	// keep things simple, we route one HTLC for 0.1% of the payment over channel 1 and the rest
+	// out over channel 3+4. This will let us ignore 99% of the payment value and deal with only
+	// our channel.
+	let chan_1_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 10_000_000, 0).0.contents.short_channel_id;
+	create_announced_chan_between_nodes_with_value(&nodes, 1, 3, 10_000_000, 0);
+	let chan_3_scid = create_announced_chan_between_nodes_with_value(&nodes, 0, 2, 10_000_000, 0).0.contents.short_channel_id;
+	let chan_4_scid = create_announced_chan_between_nodes_with_value(&nodes, 2, 3, 10_000_000, 0).0.contents.short_channel_id;
+
+	let amt_msat = 100_000_000;
+	let (_, payment_hash, _, payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[2], amt_msat);
+	#[cfg(feature = "std")]
+	let payment_expiry_secs = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs() + 60 * 60;
+	#[cfg(not(feature = "std"))]
+	let payment_expiry_secs = 60 * 60;
+	let mut invoice_features = InvoiceFeatures::empty();
+	invoice_features.set_variable_length_onion_required();
+	invoice_features.set_payment_secret_required();
+	invoice_features.set_basic_mpp_optional();
+	let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), TEST_FINAL_CLTV)
+		.with_expiry_time(payment_expiry_secs as u64)
+		.with_features(invoice_features);
+	let mut route_params = RouteParameters {
+		payment_params,
+		final_value_msat: amt_msat,
+		final_cltv_expiry_delta: TEST_FINAL_CLTV,
+	};
+
+	let mut route = Route {
+		paths: vec![
+			vec![RouteHop {
+				pubkey: nodes[1].node.get_our_node_id(),
+				node_features: nodes[1].node.node_features(),
+				short_channel_id: chan_1_scid,
+				channel_features: nodes[1].node.channel_features(),
+				fee_msat: 0,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: 42, // Set a random SCID which nodes[1] will fail as unknown
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}],
+			vec![RouteHop {
+				pubkey: nodes[2].node.get_our_node_id(),
+				node_features: nodes[2].node.node_features(),
+				short_channel_id: chan_3_scid,
+				channel_features: nodes[2].node.channel_features(),
+				fee_msat: 100_000,
+				cltv_expiry_delta: 100,
+			}, RouteHop {
+				pubkey: nodes[3].node.get_our_node_id(),
+				node_features: nodes[3].node.node_features(),
+				short_channel_id: chan_4_scid,
+				channel_features: nodes[3].node.channel_features(),
+				fee_msat: amt_msat - amt_msat / 1000,
+				cltv_expiry_delta: 100,
+			}]
+		],
+		payment_params: Some(PaymentParameters::from_node_id(nodes[2].node.get_our_node_id(), TEST_FINAL_CLTV)),
+	};
+	nodes[0].router.expect_find_route(route_params.clone(), Ok(route.clone()));
+
+	nodes[0].node.send_payment_with_retry(payment_hash, &Some(payment_secret), PaymentId(payment_hash.0), route_params.clone(), Retry::Attempts(0xdeadbeef)).unwrap();
+	check_added_monitors!(nodes[0], 2);
+	let mut send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+	assert_eq!(send_msg_events.len(), 2);
+	send_msg_events.retain(|msg|
+		if let MessageSendEvent::UpdateHTLCs { node_id, .. } = msg {
+			// Drop the commitment update for nodes[2], we can just let that one sit pending
+			// forever.
+			*node_id == nodes[1].node.get_our_node_id()
+		} else { panic!(); }
+	);

+	// from here on out, the retry `RouteParameters` amount will be amt/1000
+	route_params.final_value_msat /= 1000;
+	route.paths.pop();
+
+	let end_time = Instant::now() + Duration::from_secs(1);
+	macro_rules! thread_body { () => { {
+		// We really want std::thread::scope, but its not stable until 1.63. Until then, we get unsafe.
+		let node_ref = NodePtr::from_node(&nodes[0]);
+		move || {
+			let node_a = unsafe { &*node_ref.0 };
+			while Instant::now() < end_time {
+				node_a.node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+				// Ignore if we have any pending events, just always pretend we just got a
+				// PendingHTLCsForwardable
+				node_a.node.process_pending_htlc_forwards();
+			}
+		}
+	} } }
+	let mut threads = Vec::new();
+	for _ in 0..16 { threads.push(std::thread::spawn(thread_body!())); }
+
+	// Back in the main thread, poll pending messages and make sure that we never have more than
+	// one HTLC pending at a time. Note that the commitment_signed_dance will fail horribly if
+	// there are HTLC messages shoved in while its running. This allows us to test that we never
+	// generate an additional update_add_htlc until we've fully failed the first.
+	let mut previously_failed_channels = Vec::new();
+	loop {
+		assert_eq!(send_msg_events.len(), 1);
+		let send_event = SendEvent::from_event(send_msg_events.pop().unwrap());
+		assert_eq!(send_event.msgs.len(), 1);
+
+		nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
+		commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true);
+
+		// Note that we only push one route into `expect_find_route` at a time, because that's all
+		// the retries (should) need. If the bug is reintroduced "real" routes may be selected, but
+		// we should still ultimately fail for the same reason - because we're trying to send too
+		// many HTLCs at once.
+		let mut new_route_params = route_params.clone();
+		previously_failed_channels.push(route.paths[0][1].short_channel_id);
+		new_route_params.payment_params.previously_failed_channels = previously_failed_channels.clone();
+		route.paths[0][1].short_channel_id += 1;
+		nodes[0].router.expect_find_route(new_route_params, Ok(route.clone()));
+
+		let bs_fail_updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+		nodes[0].node.handle_update_fail_htlc(&nodes[1].node.get_our_node_id(), &bs_fail_updates.update_fail_htlcs[0]);
+		// The "normal" commitment_signed_dance delivers the final RAA and then calls
+		// `check_added_monitors` to ensure only the one RAA-generated monitor update was created.
+		// This races with our other threads which may generate an add-HTLCs commitment update via
+		// `process_pending_htlc_forwards`. Instead, we defer the monitor update check until after
+		// *we've* called `process_pending_htlc_forwards` when its guaranteed to have two updates.
+		let last_raa = commitment_signed_dance!(nodes[0], nodes[1], bs_fail_updates.commitment_signed, false, true, false, true);
+		nodes[0].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &last_raa);
+
+		let cur_time = Instant::now();
+		if cur_time > end_time {
+			for thread in threads.drain(..) { thread.join().unwrap(); }
+		}
+
+		// Make sure we have some events to handle when we go around...
+		nodes[0].node.get_and_clear_pending_events(); // wipe the PendingHTLCsForwardable
+		nodes[0].node.process_pending_htlc_forwards();
+		send_msg_events = nodes[0].node.get_and_clear_pending_msg_events();
+		check_added_monitors!(nodes[0], 2);
+
+		if cur_time > end_time {
+			break;
+		}
+	}
+}
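
The guardrail comment near the top of the test relies on some quick arithmetic: the shard routed over channel 1 is 0.1% of the payment, so the retries the test hammers on stay far below the 10% over-payment cap the comment mentions and are never rejected by that guardrail. A check of those numbers (the 10% figure is taken from the test comment itself):

fn main() {
    let amt_msat: u64 = 100_000_000;
    let small_shard = amt_msat / 1000;      // 100_000 msat over channel 1
    let big_shard = amt_msat - small_shard; // 99_900_000 msat over channels 3+4
    let overpay_cap = amt_msat / 10;        // 10% of the original value

    assert_eq!(small_shard, 100_000);
    assert_eq!(big_shard, 99_900_000);
    // Even 100 duplicate sends of the small shard still fit under the cap, so
    // duplicated retries of it are not blocked for over-payment.
    assert!(100 * small_shard <= overpay_cap);
}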

lightning/src/sync/debug_sync.rs

Lines changed: 28 additions & 0 deletions
@@ -14,6 +14,8 @@ use std::sync::Condvar as StdCondvar;

 use crate::prelude::HashMap;

+use super::{LockTestExt, LockHeldState};
+
 #[cfg(feature = "backtrace")]
 use {crate::prelude::hash_map, backtrace::Backtrace, std::sync::Once};

@@ -168,6 +170,18 @@ impl LockMetadata {
 	fn pre_lock(this: &Arc<LockMetadata>) { Self::_pre_lock(this, false); }
 	fn pre_read_lock(this: &Arc<LockMetadata>) -> bool { Self::_pre_lock(this, true) }

+	fn held_by_thread(this: &Arc<LockMetadata>) -> LockHeldState {
+		let mut res = LockHeldState::NotHeldByThread;
+		LOCKS_HELD.with(|held| {
+			for (locked_idx, _locked) in held.borrow().iter() {
+				if *locked_idx == this.lock_idx {
+					res = LockHeldState::HeldByThread;
+				}
+			}
+		});
+		res
+	}
+
 	fn try_locked(this: &Arc<LockMetadata>) {
 		LOCKS_HELD.with(|held| {
 			// Since a try-lock will simply fail if the lock is held already, we do not
@@ -248,6 +262,13 @@ impl<T> Mutex<T> {
 	}
 }

+impl <T> LockTestExt for Mutex<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		LockMetadata::held_by_thread(&self.deps)
+	}
+}
+
 pub struct RwLock<T: Sized> {
 	inner: StdRwLock<T>,
 	deps: Arc<LockMetadata>,
@@ -332,4 +353,11 @@ impl<T> RwLock<T> {
 	}
 }

+impl <T> LockTestExt for RwLock<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		LockMetadata::held_by_thread(&self.deps)
+	}
+}
+
 pub type FairRwLock<T> = RwLock<T>;
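
`held_by_thread` rides on the thread-local `LOCKS_HELD` map that `debug_sync` already maintains for lock-order checking: each thread records which lock indices it currently holds, so the query never needs to touch the lock itself. A stripped-down standalone illustration of that idea (hypothetical helper names, not the LDK implementation):

use std::cell::RefCell;
use std::collections::HashSet;

thread_local! {
    // Set of lock ids currently held by *this* thread.
    static HELD_LOCKS: RefCell<HashSet<u64>> = RefCell::new(HashSet::new());
}

#[derive(Debug, PartialEq, Eq)]
enum LockHeldState { HeldByThread, NotHeldByThread }

fn note_locked(lock_idx: u64) {
    HELD_LOCKS.with(|held| { held.borrow_mut().insert(lock_idx); });
}

fn note_unlocked(lock_idx: u64) {
    HELD_LOCKS.with(|held| { held.borrow_mut().remove(&lock_idx); });
}

fn held_by_thread(lock_idx: u64) -> LockHeldState {
    HELD_LOCKS.with(|held| {
        if held.borrow().contains(&lock_idx) { LockHeldState::HeldByThread }
        else { LockHeldState::NotHeldByThread }
    })
}

fn main() {
    note_locked(1);
    assert_eq!(held_by_thread(1), LockHeldState::HeldByThread);
    // Another thread has its own (empty) thread-local set, so the same query
    // there reports NotHeldByThread.
    std::thread::spawn(|| assert_eq!(held_by_thread(1), LockHeldState::NotHeldByThread))
        .join().unwrap();
    note_unlocked(1);
    assert_eq!(held_by_thread(1), LockHeldState::NotHeldByThread);
}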

lightning/src/util/fairrwlock.rs renamed to lightning/src/sync/fairrwlock.rs

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,6 @@
 use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockResult};
 use std::sync::atomic::{AtomicUsize, Ordering};
+use super::{LockHeldState, LockTestExt};

 /// Rust libstd's RwLock does not provide any fairness guarantees (and, in fact, when used on
 /// Linux with pthreads under the hood, readers trivially and completely starve writers).
@@ -48,3 +49,11 @@ impl<T> FairRwLock<T> {
 		self.lock.try_write()
 	}
 }
+
+impl<T> LockTestExt for FairRwLock<T> {
+	#[inline]
+	fn held_by_thread(&self) -> LockHeldState {
+		// fairrwlock is only built in non-test modes, so we should never support tests.
+		LockHeldState::Unsupported
+	}
+}

lightning/src/sync/mod.rs

Lines changed: 28 additions & 2 deletions
@@ -1,3 +1,16 @@
+#[allow(dead_code)] // Depending on the compilation flags some variants are never used
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) enum LockHeldState {
+	HeldByThread,
+	NotHeldByThread,
+	#[cfg(any(feature = "_bench_unstable", not(test)))]
+	Unsupported,
+}
+
+pub(crate) trait LockTestExt {
+	fn held_by_thread(&self) -> LockHeldState;
+}
+
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
 mod debug_sync;
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
@@ -7,9 +20,22 @@ pub use debug_sync::*;
 mod test_lockorder_checks;

 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
-pub use ::std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
+pub(crate) mod fairrwlock;
+#[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
+pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}, fairrwlock::FairRwLock};
+
 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
-pub use crate::util::fairrwlock::FairRwLock;
+mod ext_impl {
+	use super::*;
+	impl<T> LockTestExt for Mutex<T> {
+		#[inline]
+		fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+	}
+	impl<T> LockTestExt for RwLock<T> {
+		#[inline]
+		fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+	}
+}

 #[cfg(not(feature = "std"))]
 mod nostd_sync;
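
In non-test builds the `LockTestExt` impls above simply return `Unsupported`, so the `debug_assert_ne!(..., LockHeldState::HeldByThread)` checks added in `channelmanager.rs` are vacuously true there and only become meaningful under the `debug_sync` test build. A minimal sketch of why that assertion shape works for both builds (standalone stand-ins, not the real trait):

#[allow(dead_code)] // mirrors the real enum; not every variant is used in this sketch
#[derive(Debug, PartialEq, Eq)]
enum LockHeldState { HeldByThread, NotHeldByThread, Unsupported }

trait LockTestExt { fn held_by_thread(&self) -> LockHeldState; }

// Stand-in for the production (non-debug_sync) impl: no tracking available.
struct PlainMutex;
impl LockTestExt for PlainMutex {
    fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
}

fn main() {
    let m = PlainMutex;
    // debug_assert_ne! only fires on HeldByThread, so Unsupported (production
    // build) and NotHeldByThread (test build, lock not held) both pass.
    debug_assert_ne!(m.held_by_thread(), LockHeldState::HeldByThread);
}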
