Skip to content

Commit 8ce4e2f

Browse files
Add MPP ID to pending_outbound_htlcs
We'll use this to correlate MPP shards in upcoming commits
1 parent 00fa09e commit 8ce4e2f

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

lightning/src/ln/channelmanager.rs

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
490490
/// after reloading from disk while replaying blocks against ChannelMonitors.
491491
///
492492
/// Locked *after* channel_state.
493-
pending_outbound_payments: Mutex<HashSet<[u8; 32]>>,
493+
pending_outbound_payments: Mutex<HashSet<([u8; 32], Option<MppId>)>>,
494494

495495
our_network_key: SecretKey,
496496
our_network_pubkey: PublicKey,
@@ -1807,7 +1807,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
18071807
let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash);
18081808

18091809
let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
1810-
assert!(self.pending_outbound_payments.lock().unwrap().insert(session_priv_bytes));
1810+
assert!(self.pending_outbound_payments.lock().unwrap().insert((session_priv_bytes, mpp_id)));
18111811

18121812
let err: Result<(), _> = loop {
18131813
let mut channel_lock = self.channel_state.lock().unwrap();
@@ -2676,11 +2676,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
26762676
self.fail_htlc_backwards_internal(channel_state,
26772677
htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data});
26782678
},
2679-
HTLCSource::OutboundRoute { session_priv, .. } => {
2679+
HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
26802680
if {
26812681
let mut session_priv_bytes = [0; 32];
26822682
session_priv_bytes.copy_from_slice(&session_priv[..]);
2683-
self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
2683+
self.pending_outbound_payments.lock().unwrap().remove(&(session_priv_bytes, mpp_id))
26842684
} {
26852685
self.pending_events.lock().unwrap().push(
26862686
events::Event::PaymentFailed {
@@ -2716,11 +2716,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
27162716
// from block_connected which may run during initialization prior to the chain_monitor
27172717
// being fully configured. See the docs for `ChannelManagerReadArgs` for more.
27182718
match source {
2719-
HTLCSource::OutboundRoute { ref path, session_priv, .. } => {
2719+
HTLCSource::OutboundRoute { ref path, session_priv, mpp_id, .. } => {
27202720
if {
27212721
let mut session_priv_bytes = [0; 32];
27222722
session_priv_bytes.copy_from_slice(&session_priv[..]);
2723-
!self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
2723+
!self.pending_outbound_payments.lock().unwrap().remove(&(session_priv_bytes, mpp_id))
27242724
} {
27252725
log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
27262726
return;
@@ -2967,12 +2967,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
29672967

29682968
fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
29692969
match source {
2970-
HTLCSource::OutboundRoute { session_priv, .. } => {
2970+
HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
29712971
mem::drop(channel_state_lock);
29722972
if {
29732973
let mut session_priv_bytes = [0; 32];
29742974
session_priv_bytes.copy_from_slice(&session_priv[..]);
2975-
self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
2975+
self.pending_outbound_payments.lock().unwrap().remove(&(session_priv_bytes, mpp_id))
29762976
} {
29772977
let mut pending_events = self.pending_events.lock().unwrap();
29782978
pending_events.push(events::Event::PaymentSent {
@@ -4919,11 +4919,15 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
49194919

49204920
let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
49214921
(pending_outbound_payments.len() as u64).write(writer)?;
4922-
for session_priv in pending_outbound_payments.iter() {
4922+
let mut pending_outbound_mpp_ids = Vec::new();
4923+
for (session_priv, mpp_id) in pending_outbound_payments.iter() {
49234924
session_priv.write(writer)?;
4925+
pending_outbound_mpp_ids.push(mpp_id);
49244926
}
49254927

4926-
write_tlv_fields!(writer, {});
4928+
write_tlv_fields!(writer, {
4929+
// (0, pending_outbound_mpp_ids, vec_type),
4930+
});
49274931

49284932
Ok(())
49294933
}
@@ -5177,14 +5181,35 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
51775181
}
51785182

51795183
let pending_outbound_payments_count: u64 = Readable::read(reader)?;
5180-
let mut pending_outbound_payments: HashSet<[u8; 32]> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
5184+
let mut pending_outbound_payments: HashSet<([u8; 32], Option<MppId>)> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
5185+
let mut pending_outbound_session_privs = Vec::new();
5186+
51815187
for _ in 0..pending_outbound_payments_count {
5182-
if !pending_outbound_payments.insert(Readable::read(reader)?) {
5183-
return Err(DecodeError::InvalidValue);
5184-
}
5188+
pending_outbound_session_privs.push(Readable::read(reader)?);
51855189
}
51865190

5187-
read_tlv_fields!(reader, {});
5191+
let mut pending_outbound_mpp_ids = Vec::new();
5192+
read_tlv_fields!(reader, {
5193+
// TODO: how to make this line work
5194+
// (0, pending_outbound_mpp_ids, vec_type),
5195+
});
5196+
5197+
if pending_outbound_mpp_ids.len() == pending_outbound_session_privs.len() {
5198+
for (session_priv, mpp_id) in pending_outbound_session_privs.iter().zip(
5199+
pending_outbound_mpp_ids.iter()) {
5200+
if !pending_outbound_payments.insert((*session_priv, *mpp_id)) {
5201+
return Err(DecodeError::InvalidValue)
5202+
}
5203+
}
5204+
} else if pending_outbound_mpp_ids.len() == 0 {
5205+
for session_priv in pending_outbound_session_privs.iter() {
5206+
if !pending_outbound_payments.insert((*session_priv, None)) {
5207+
return Err(DecodeError::InvalidValue);
5208+
}
5209+
}
5210+
} else {
5211+
return Err(DecodeError::InvalidValue);
5212+
}
51885213

51895214
let mut secp_ctx = Secp256k1::new();
51905215
secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());
@@ -5428,7 +5453,7 @@ mod tests {
54285453
expect_payment_failed!(nodes[0], our_payment_hash, true);
54295454

54305455
// Send the second half of the original MPP payment.
5431-
nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None).unwrap();
5456+
nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, mpp_id, &None).unwrap();
54325457
check_added_monitors!(nodes[0], 1);
54335458
let mut events = nodes[0].node.get_and_clear_pending_msg_events();
54345459
assert_eq!(events.len(), 1);

0 commit comments

Comments
 (0)