Skip to content

Align get_route's interface with ChannelManager and Invoice #946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions fuzz/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use lightning::chain::transaction::OutPoint;
use lightning::ln::channelmanager::ChannelDetails;
use lightning::ln::features::InitFeatures;
use lightning::ln::msgs;
use lightning::routing::router::{get_route, RouteHintHop};
use lightning::routing::router::{get_route, RouteHint, RouteHintHop};
use lightning::util::logger::Logger;
use lightning::util::ser::Readable;
use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
Expand Down Expand Up @@ -225,13 +225,13 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
Some(&first_hops_vec[..])
},
};
let mut last_hops_vec = Vec::new();
let mut last_hops = Vec::new();
{
let count = get_slice!(1)[0];
for _ in 0..count {
scid += 1;
let rnid = node_pks.iter().skip(slice_to_be16(get_slice!(2))as usize % node_pks.len()).next().unwrap();
last_hops_vec.push(RouteHintHop {
last_hops.push(RouteHint(vec![RouteHintHop {
src_node_id: *rnid,
short_channel_id: scid,
fees: RoutingFees {
Expand All @@ -241,10 +241,9 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
cltv_expiry_delta: slice_to_be16(get_slice!(2)),
htlc_minimum_msat: Some(slice_to_be64(get_slice!(8))),
htlc_maximum_msat: None,
});
}]));
}
}
let last_hops = &last_hops_vec[..];
for target in node_pks.iter() {
let _ = get_route(&our_pubkey, &net_graph, target, None,
first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
Expand Down
22 changes: 11 additions & 11 deletions lightning-invoice/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use bitcoin_hashes::Hash;
use bitcoin_hashes::sha256;
use lightning::ln::PaymentSecret;
use lightning::routing::network_graph::RoutingFees;
use lightning::routing::router::RouteHintHop;
use lightning::routing::router::{RouteHint, RouteHintHop};

use num_traits::{CheckedAdd, CheckedMul};

Expand All @@ -21,7 +21,7 @@ use secp256k1::recovery::{RecoveryId, RecoverableSignature};
use secp256k1::key::PublicKey;

use super::{Invoice, Sha256, TaggedField, ExpiryTime, MinFinalCltvExpiry, Fallback, PayeePubKey, InvoiceSignature, PositiveTimestamp,
SemanticError, RouteHint, Description, RawTaggedField, Currency, RawHrp, SiPrefix, RawInvoice, constants, SignedRawInvoice,
SemanticError, PrivateRoute, Description, RawTaggedField, Currency, RawHrp, SiPrefix, RawInvoice, constants, SignedRawInvoice,
RawDataPart, CreationError, InvoiceFeatures};

use self::hrp_sm::parse_hrp;
Expand Down Expand Up @@ -433,8 +433,8 @@ impl FromBase32 for TaggedField {
Ok(TaggedField::MinFinalCltvExpiry(MinFinalCltvExpiry::from_base32(field_data)?)),
constants::TAG_FALLBACK =>
Ok(TaggedField::Fallback(Fallback::from_base32(field_data)?)),
constants::TAG_ROUTE =>
Ok(TaggedField::Route(RouteHint::from_base32(field_data)?)),
constants::TAG_PRIVATE_ROUTE =>
Ok(TaggedField::PrivateRoute(PrivateRoute::from_base32(field_data)?)),
constants::TAG_PAYMENT_SECRET =>
Ok(TaggedField::PaymentSecret(PaymentSecret::from_base32(field_data)?)),
constants::TAG_FEATURES =>
Expand Down Expand Up @@ -558,10 +558,10 @@ impl FromBase32 for Fallback {
}
}

impl FromBase32 for RouteHint {
impl FromBase32 for PrivateRoute {
type Err = ParseError;

fn from_base32(field_data: &[u5]) -> Result<RouteHint, ParseError> {
fn from_base32(field_data: &[u5]) -> Result<PrivateRoute, ParseError> {
let bytes = Vec::<u8>::from_base32(field_data)?;

if bytes.len() % 51 != 0 {
Expand Down Expand Up @@ -593,7 +593,7 @@ impl FromBase32 for RouteHint {
route_hops.push(hop);
}

Ok(RouteHint(route_hops))
Ok(PrivateRoute(RouteHint(route_hops)))
}
}

Expand Down Expand Up @@ -930,8 +930,8 @@ mod test {
#[test]
fn test_parse_route() {
use lightning::routing::network_graph::RoutingFees;
use lightning::routing::router::RouteHintHop;
use ::RouteHint;
use lightning::routing::router::{RouteHint, RouteHintHop};
use ::PrivateRoute;
use bech32::FromBase32;
use de::parse_int_be;

Expand Down Expand Up @@ -976,10 +976,10 @@ mod test {
htlc_maximum_msat: None
});

assert_eq!(RouteHint::from_base32(&input), Ok(RouteHint(expected)));
assert_eq!(PrivateRoute::from_base32(&input), Ok(PrivateRoute(RouteHint(expected))));

assert_eq!(
RouteHint::from_base32(&[u5::try_from_u8(0).unwrap(); 40][..]),
PrivateRoute::from_base32(&[u5::try_from_u8(0).unwrap(); 40][..]),
Err(ParseError::UnexpectedEndOfTaggedFields)
);
}
Expand Down
115 changes: 71 additions & 44 deletions lightning-invoice/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use lightning::ln::PaymentSecret;
use lightning::ln::features::InvoiceFeatures;
#[cfg(any(doc, test))]
use lightning::routing::network_graph::RoutingFees;
use lightning::routing::router::RouteHintHop;
use lightning::routing::router::RouteHint;

use secp256k1::key::PublicKey;
use secp256k1::{Message, Secp256k1};
Expand Down Expand Up @@ -362,7 +362,7 @@ pub enum TaggedField {
ExpiryTime(ExpiryTime),
MinFinalCltvExpiry(MinFinalCltvExpiry),
Fallback(Fallback),
Route(RouteHint),
PrivateRoute(PrivateRoute),
PaymentSecret(PaymentSecret),
Features(InvoiceFeatures),
}
Expand Down Expand Up @@ -419,7 +419,7 @@ pub struct InvoiceSignature(pub RecoverableSignature);
/// The encoded route has to be <1024 5bit characters long (<=639 bytes or <=12 hops)
///
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct RouteHint(Vec<RouteHintHop>);
pub struct PrivateRoute(RouteHint);

/// Tag constants as specified in BOLT11
#[allow(missing_docs)]
Expand All @@ -431,7 +431,7 @@ pub mod constants {
pub const TAG_EXPIRY_TIME: u8 = 6;
pub const TAG_MIN_FINAL_CLTV_EXPIRY: u8 = 24;
pub const TAG_FALLBACK: u8 = 9;
pub const TAG_ROUTE: u8 = 3;
pub const TAG_PRIVATE_ROUTE: u8 = 3;
pub const TAG_PAYMENT_SECRET: u8 = 16;
pub const TAG_FEATURES: u8 = 5;
}
Expand Down Expand Up @@ -509,9 +509,9 @@ impl<D: tb::Bool, H: tb::Bool, T: tb::Bool, C: tb::Bool, S: tb::Bool> InvoiceBui
}

/// Adds a private route.
pub fn route(mut self, route: Vec<RouteHintHop>) -> Self {
match RouteHint::new(route) {
Ok(r) => self.tagged_fields.push(TaggedField::Route(r)),
pub fn private_route(mut self, hint: RouteHint) -> Self {
match PrivateRoute::new(hint) {
Ok(r) => self.tagged_fields.push(TaggedField::PrivateRoute(r)),
Err(e) => self.error = Some(e),
}
self
Expand Down Expand Up @@ -747,7 +747,7 @@ impl SignedRawInvoice {
/// Finds the first element of an enum stream of a given variant and extracts one member of the
/// variant. If no element was found `None` gets returned.
///
/// The following example would extract the first
/// The following example would extract the first B.
/// ```
/// use Enum::*
///
Expand All @@ -761,11 +761,35 @@ impl SignedRawInvoice {
/// assert_eq!(find_extract!(elements.iter(), Enum::B(ref x), x), Some(3u16))
/// ```
macro_rules! find_extract {
($iter:expr, $enm:pat, $enm_var:ident) => {
($iter:expr, $enm:pat, $enm_var:ident) => {
find_all_extract!($iter, $enm, $enm_var).next()
};
}

/// Finds the all elements of an enum stream of a given variant and extracts one member of the
/// variant through an iterator.
///
/// The following example would extract all A.
/// ```
/// use Enum::*
///
/// enum Enum {
/// A(u8),
/// B(u16)
/// }
///
/// let elements = vec![A(1), A(2), B(3), A(4)]
///
/// assert_eq!(
/// find_all_extract!(elements.iter(), Enum::A(ref x), x).collect::<Vec<u8>>(),
/// vec![1u8, 2u8, 4u8])
/// ```
macro_rules! find_all_extract {
($iter:expr, $enm:pat, $enm_var:ident) => {
$iter.filter_map(|tf| match *tf {
$enm => Some($enm_var),
_ => None,
}).next()
})
};
}

Expand Down Expand Up @@ -886,17 +910,11 @@ impl RawInvoice {

/// (C-not exported) as we don't support Vec<&NonOpaqueType>
pub fn fallbacks(&self) -> Vec<&Fallback> {
self.known_tagged_fields().filter_map(|tf| match tf {
&TaggedField::Fallback(ref f) => Some(f),
_ => None,
}).collect::<Vec<&Fallback>>()
find_all_extract!(self.known_tagged_fields(), TaggedField::Fallback(ref x), x).collect()
}

pub fn routes(&self) -> Vec<&RouteHint> {
self.known_tagged_fields().filter_map(|tf| match tf {
&TaggedField::Route(ref r) => Some(r),
_ => None,
}).collect::<Vec<&RouteHint>>()
pub fn private_routes(&self) -> Vec<&PrivateRoute> {
find_all_extract!(self.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x).collect()
}

pub fn amount_pico_btc(&self) -> Option<u64> {
Expand Down Expand Up @@ -1048,7 +1066,7 @@ impl Invoice {
Ok(())
}

/// Constructs an `Invoice` from a `SignedInvoice` by checking all its invariants.
/// Constructs an `Invoice` from a `SignedRawInvoice` by checking all its invariants.
/// ```
/// use lightning_invoice::*;
///
Expand Down Expand Up @@ -1145,8 +1163,15 @@ impl Invoice {
}

/// Returns a list of all routes included in the invoice
pub fn routes(&self) -> Vec<&RouteHint> {
self.signed_invoice.routes()
pub fn private_routes(&self) -> Vec<&PrivateRoute> {
self.signed_invoice.private_routes()
}

/// Returns a list of all routes included in the invoice as the underlying hints
pub fn route_hints(&self) -> Vec<&RouteHint> {
find_all_extract!(
self.signed_invoice.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x
).map(|route| &**route).collect()
}

/// Returns the currency for which the invoice was issued
Expand Down Expand Up @@ -1177,7 +1202,7 @@ impl TaggedField {
TaggedField::ExpiryTime(_) => constants::TAG_EXPIRY_TIME,
TaggedField::MinFinalCltvExpiry(_) => constants::TAG_MIN_FINAL_CLTV_EXPIRY,
TaggedField::Fallback(_) => constants::TAG_FALLBACK,
TaggedField::Route(_) => constants::TAG_ROUTE,
TaggedField::PrivateRoute(_) => constants::TAG_PRIVATE_ROUTE,
TaggedField::PaymentSecret(_) => constants::TAG_PAYMENT_SECRET,
TaggedField::Features(_) => constants::TAG_FEATURES,
};
Expand Down Expand Up @@ -1268,32 +1293,32 @@ impl ExpiryTime {
}
}

impl RouteHint {
/// Create a new (partial) route from a list of hops
pub fn new(hops: Vec<RouteHintHop>) -> Result<RouteHint, CreationError> {
if hops.len() <= 12 {
Ok(RouteHint(hops))
impl PrivateRoute {
/// Creates a new (partial) route from a list of hops
pub fn new(hops: RouteHint) -> Result<PrivateRoute, CreationError> {
if hops.0.len() <= 12 {
Ok(PrivateRoute(hops))
} else {
Err(CreationError::RouteTooLong)
}
}

/// Returrn the underlying vector of hops
pub fn into_inner(self) -> Vec<RouteHintHop> {
/// Returns the underlying list of hops
pub fn into_inner(self) -> RouteHint {
self.0
}
}

impl Into<Vec<RouteHintHop>> for RouteHint {
fn into(self) -> Vec<RouteHintHop> {
impl Into<RouteHint> for PrivateRoute {
fn into(self) -> RouteHint {
self.into_inner()
}
}

impl Deref for RouteHint {
type Target = Vec<RouteHintHop>;
impl Deref for PrivateRoute {
type Target = RouteHint;

fn deref(&self) -> &Vec<RouteHintHop> {
fn deref(&self) -> &RouteHint {
&self.0
}
}
Expand Down Expand Up @@ -1652,6 +1677,7 @@ mod test {
#[test]
fn test_builder_fail() {
use ::*;
use lightning::routing::router::RouteHintHop;
use std::iter::FromIterator;
use secp256k1::key::PublicKey;

Expand Down Expand Up @@ -1686,10 +1712,10 @@ mod test {
htlc_minimum_msat: None,
htlc_maximum_msat: None,
};
let too_long_route = vec![route_hop; 13];
let too_long_route = RouteHint(vec![route_hop; 13]);
let long_route_res = builder.clone()
.description("Test".into())
.route(too_long_route)
.private_route(too_long_route)
.build_raw();
assert_eq!(long_route_res, Err(CreationError::RouteTooLong));

Expand All @@ -1704,6 +1730,7 @@ mod test {
#[test]
fn test_builder_ok() {
use ::*;
use lightning::routing::router::RouteHintHop;
use secp256k1::Secp256k1;
use secp256k1::key::{SecretKey, PublicKey};
use std::time::{UNIX_EPOCH, Duration};
Expand All @@ -1719,7 +1746,7 @@ mod test {
).unwrap();
let public_key = PublicKey::from_secret_key(&secp_ctx, &private_key);

let route_1 = vec![
let route_1 = RouteHint(vec![
RouteHintHop {
src_node_id: public_key.clone(),
short_channel_id: de::parse_int_be(&[123; 8], 256).expect("short chan ID slice too big?"),
Expand All @@ -1742,9 +1769,9 @@ mod test {
htlc_minimum_msat: None,
htlc_maximum_msat: None,
}
];
]);

let route_2 = vec![
let route_2 = RouteHint(vec![
RouteHintHop {
src_node_id: public_key.clone(),
short_channel_id: 0,
Expand All @@ -1767,7 +1794,7 @@ mod test {
htlc_minimum_msat: None,
htlc_maximum_msat: None,
}
];
]);

let builder = InvoiceBuilder::new(Currency::BitcoinTestnet)
.amount_pico_btc(123)
Expand All @@ -1776,8 +1803,8 @@ mod test {
.expiry_time(Duration::from_secs(54321))
.min_final_cltv_expiry(144)
.fallback(Fallback::PubKeyHash([0;20]))
.route(route_1.clone())
.route(route_2.clone())
.private_route(route_1.clone())
.private_route(route_2.clone())
.description_hash(sha256::Hash::from_slice(&[3;32][..]).unwrap())
.payment_hash(sha256::Hash::from_slice(&[21;32][..]).unwrap())
.payment_secret(PaymentSecret([42; 32]))
Expand All @@ -1800,7 +1827,7 @@ mod test {
assert_eq!(invoice.expiry_time(), Duration::from_secs(54321));
assert_eq!(invoice.min_final_cltv_expiry(), 144);
assert_eq!(invoice.fallbacks(), vec![&Fallback::PubKeyHash([0;20])]);
assert_eq!(invoice.routes(), vec![&RouteHint(route_1), &RouteHint(route_2)]);
assert_eq!(invoice.private_routes(), vec![&PrivateRoute(route_1), &PrivateRoute(route_2)]);
assert_eq!(
invoice.description(),
InvoiceDescription::Hash(&Sha256(sha256::Hash::from_slice(&[3;32][..]).unwrap()))
Expand Down
Loading