Skip to content

Commit

Permalink
Add MPP ID to pending_outbound_htlcs
Browse files Browse the repository at this point in the history
We'll use this to correlate MPP shards in upcoming commits
  • Loading branch information
valentinewallace committed Aug 20, 2021
1 parent 00fa09e commit 8ce4e2f
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions lightning/src/ln/channelmanager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
/// after reloading from disk while replaying blocks against ChannelMonitors.
///
/// Locked *after* channel_state.
pending_outbound_payments: Mutex<HashSet<[u8; 32]>>,
pending_outbound_payments: Mutex<HashSet<([u8; 32], Option<MppId>)>>,

our_network_key: SecretKey,
our_network_pubkey: PublicKey,
Expand Down Expand Up @@ -1807,7 +1807,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash);

let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
assert!(self.pending_outbound_payments.lock().unwrap().insert(session_priv_bytes));
assert!(self.pending_outbound_payments.lock().unwrap().insert((session_priv_bytes, mpp_id)));

let err: Result<(), _> = loop {
let mut channel_lock = self.channel_state.lock().unwrap();
Expand Down Expand Up @@ -2676,11 +2676,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
self.fail_htlc_backwards_internal(channel_state,
htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data});
},
HTLCSource::OutboundRoute { session_priv, .. } => {
HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
if {
let mut session_priv_bytes = [0; 32];
session_priv_bytes.copy_from_slice(&session_priv[..]);
self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
self.pending_outbound_payments.lock().unwrap().remove(&(session_priv_bytes, mpp_id))
} {
self.pending_events.lock().unwrap().push(
events::Event::PaymentFailed {
Expand Down Expand Up @@ -2716,11 +2716,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
// from block_connected which may run during initialization prior to the chain_monitor
// being fully configured. See the docs for `ChannelManagerReadArgs` for more.
match source {
HTLCSource::OutboundRoute { ref path, session_priv, .. } => {
HTLCSource::OutboundRoute { ref path, session_priv, mpp_id, .. } => {
if {
let mut session_priv_bytes = [0; 32];
session_priv_bytes.copy_from_slice(&session_priv[..]);
!self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
!self.pending_outbound_payments.lock().unwrap().remove(&(session_priv_bytes, mpp_id))
} {
log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
return;
Expand Down Expand Up @@ -2967,12 +2967,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana

fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
match source {
HTLCSource::OutboundRoute { session_priv, .. } => {
HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
mem::drop(channel_state_lock);
if {
let mut session_priv_bytes = [0; 32];
session_priv_bytes.copy_from_slice(&session_priv[..]);
self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
self.pending_outbound_payments.lock().unwrap().remove(&(session_priv_bytes, mpp_id))
} {
let mut pending_events = self.pending_events.lock().unwrap();
pending_events.push(events::Event::PaymentSent {
Expand Down Expand Up @@ -4919,11 +4919,15 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f

let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
(pending_outbound_payments.len() as u64).write(writer)?;
for session_priv in pending_outbound_payments.iter() {
let mut pending_outbound_mpp_ids = Vec::new();
for (session_priv, mpp_id) in pending_outbound_payments.iter() {
session_priv.write(writer)?;
pending_outbound_mpp_ids.push(mpp_id);
}

write_tlv_fields!(writer, {});
write_tlv_fields!(writer, {
// (0, pending_outbound_mpp_ids, vec_type),
});

Ok(())
}
Expand Down Expand Up @@ -5177,14 +5181,35 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
}

let pending_outbound_payments_count: u64 = Readable::read(reader)?;
let mut pending_outbound_payments: HashSet<[u8; 32]> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
let mut pending_outbound_payments: HashSet<([u8; 32], Option<MppId>)> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
let mut pending_outbound_session_privs = Vec::new();

for _ in 0..pending_outbound_payments_count {
if !pending_outbound_payments.insert(Readable::read(reader)?) {
return Err(DecodeError::InvalidValue);
}
pending_outbound_session_privs.push(Readable::read(reader)?);
}

read_tlv_fields!(reader, {});
let mut pending_outbound_mpp_ids = Vec::new();
read_tlv_fields!(reader, {
// TODO: how to make this line work
// (0, pending_outbound_mpp_ids, vec_type),
});

if pending_outbound_mpp_ids.len() == pending_outbound_session_privs.len() {
for (session_priv, mpp_id) in pending_outbound_session_privs.iter().zip(
pending_outbound_mpp_ids.iter()) {
if !pending_outbound_payments.insert((*session_priv, *mpp_id)) {
return Err(DecodeError::InvalidValue)
}
}
} else if pending_outbound_mpp_ids.len() == 0 {
for session_priv in pending_outbound_session_privs.iter() {
if !pending_outbound_payments.insert((*session_priv, None)) {
return Err(DecodeError::InvalidValue);
}
}
} else {
return Err(DecodeError::InvalidValue);
}

let mut secp_ctx = Secp256k1::new();
secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());
Expand Down Expand Up @@ -5428,7 +5453,7 @@ mod tests {
expect_payment_failed!(nodes[0], our_payment_hash, true);

// Send the second half of the original MPP payment.
nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None).unwrap();
nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, mpp_id, &None).unwrap();
check_added_monitors!(nodes[0], 1);
let mut events = nodes[0].node.get_and_clear_pending_msg_events();
assert_eq!(events.len(), 1);
Expand Down

0 comments on commit 8ce4e2f

Please sign in to comment.