diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f7157dc..558d65e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,6 +39,10 @@ jobs: if: matrix.check-fmt run: rustup component add rustfmt && cargo fmt --all -- --check - name: Test on Rust ${{ matrix.toolchain }} - run: cargo test - - name: Test on Rust ${{ matrix.toolchain }} with LSPS1 support - run: RUSTFLAGS="--cfg lsps1" cargo test + run: | + cargo test + RUSTFLAGS="--cfg lsps1" cargo test + - name: Test on Rust ${{ matrix.toolchain }} with no-std support + run: | + cargo test --no-default-features --features no-std + RUSTFLAGS="--cfg lsps1" cargo test --no-default-features --features no-std diff --git a/Cargo.toml b/Cargo.toml index 261762c..a71a03c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,12 +7,18 @@ description = "Types and primitives to integrate a spec-compliant LSP with an LD # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["std"] +std = ["lightning/std", "bitcoin/std"] +no-std = ["hashbrown", "lightning/no-std", "bitcoin/no-std", "core2/alloc"] + [dependencies] -lightning = { version = "0.0.118", default-features = false, features = ["max_level_trace", "std"] } +lightning = { version = "0.0.118", default-features = false, features = ["max_level_trace"] } lightning-invoice = "0.26.0" +bitcoin = { version = "0.29.0", default-features = false } +hashbrown = { version = "0.8", optional = true } +core2 = { version = "0.3.0", optional = true, default-features = false } -bitcoin = "0.29.0" - -chrono = { version = "0.4", default-features = false, features = ["std", "serde"] } +chrono = { version = "0.4", default-features = false, features = ["serde", "alloc"] } serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] } serde_json = "1.0" diff --git a/src/events.rs b/src/events.rs index 9038c7a..b3a3a8b 100644 --- a/src/events.rs +++ b/src/events.rs @@ -18,25 +18,42 @@ #[cfg(lsps1)] use crate::lsps1; use crate::lsps2; -use std::collections::VecDeque; -use std::sync::{Condvar, Mutex}; +use crate::prelude::{Vec, VecDeque}; +use crate::sync::Mutex; -#[derive(Default)] pub(crate) struct EventQueue { queue: Mutex>, - condvar: Condvar, + #[cfg(feature = "std")] + condvar: std::sync::Condvar, } impl EventQueue { + pub fn new() -> Self { + let queue = Mutex::new(VecDeque::new()); + #[cfg(feature = "std")] + { + let condvar = std::sync::Condvar::new(); + Self { queue, condvar } + } + #[cfg(not(feature = "std"))] + Self { queue } + } + pub fn enqueue(&self, event: Event) { { let mut queue = self.queue.lock().unwrap(); queue.push_back(event); } + #[cfg(feature = "std")] self.condvar.notify_one(); } + pub fn next_event(&self) -> Option { + self.queue.lock().unwrap().pop_front() + } + + #[cfg(feature = "std")] pub fn wait_next_event(&self) -> Event { let mut queue = self.condvar.wait_while(self.queue.lock().unwrap(), |queue| queue.is_empty()).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 9ee57ca..11a9d09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,12 +17,33 @@ #![allow(ellipsis_inclusive_range_patterns)] #![allow(clippy::drop_non_drop)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(not(feature = "std"), no_std)] +#[cfg(not(any(feature = "std", feature = "no-std")))] +compile_error!("at least one of the `std` or `no-std` features must be enabled"); + +#[macro_use] +extern crate alloc; + +mod prelude { + #[cfg(feature = "hashbrown")] + extern crate hashbrown; + + #[cfg(feature = "hashbrown")] + pub use self::hashbrown::{hash_map, HashMap, HashSet}; + pub use alloc::{boxed::Box, collections::VecDeque, string::String, vec, vec::Vec}; + #[cfg(not(feature = "hashbrown"))] + pub use std::collections::{hash_map, HashMap, HashSet}; + + pub use alloc::borrow::ToOwned; + pub use alloc::string::ToString; +} pub mod events; mod lsps0; #[cfg(lsps1)] mod lsps1; pub mod lsps2; +mod sync; mod utils; pub use lsps0::message_handler::{JITChannelsConfig, LiquidityManager, LiquidityProviderConfig}; diff --git a/src/lsps0/message_handler.rs b/src/lsps0/message_handler.rs index efd75d9..43c198b 100644 --- a/src/lsps0/message_handler.rs +++ b/src/lsps0/message_handler.rs @@ -12,6 +12,8 @@ use crate::lsps0::msgs::{LSPSMessage, RawLSPSMessage, LSPS_MESSAGE_TYPE_ID}; use crate::lsps0::protocol::LSPS0MessageHandler; use crate::lsps2::channel_manager::JITChannelManager; use crate::lsps2::msgs::{OpeningFeeParams, RawOpeningFeeParams}; +use crate::prelude::{HashMap, String, ToString, Vec}; +use crate::sync::{Arc, Mutex, RwLock}; use lightning::chain::{self, BestBlock, Confirm, Filter, Listen}; use lightning::ln::channelmanager::{AChannelManager, ChainParameters, InterceptId}; @@ -32,10 +34,8 @@ use bitcoin::BlockHash; #[cfg(lsps1)] use chrono::Utc; -use std::collections::HashMap; -use std::convert::TryFrom; -use std::ops::Deref; -use std::sync::{Arc, Mutex, RwLock}; +use core::convert::TryFrom; +use core::ops::Deref; const LSPS_FEATURE_BIT: usize = 729; @@ -146,7 +146,7 @@ where ) -> Self where { let pending_messages = Arc::new(Mutex::new(vec![])); - let pending_events = Arc::new(EventQueue::default()); + let pending_events = Arc::new(EventQueue::new()); let lsps0_message_handler = LSPS0MessageHandler::new( entropy_source.clone().clone(), @@ -197,13 +197,21 @@ where { } } - /// Blocks until next event is ready and returns it. + /// Blocks the current thread until next event is ready and returns it. /// /// Typically you would spawn a thread or task that calls this in a loop. + #[cfg(feature = "std")] pub fn wait_next_event(&self) -> Event { self.pending_events.wait_next_event() } + /// Returns `Some` if an event is ready. + /// + /// Typically you would spawn a thread or task that calls this in a loop. + pub fn next_event(&self) -> Option { + self.pending_events.next_event() + } + /// Returns and clears all events without blocking. /// /// Typically you would spawn a thread or task that calls this in a loop. diff --git a/src/lsps0/msgs.rs b/src/lsps0/msgs.rs index baddeb0..6630111 100644 --- a/src/lsps0/msgs.rs +++ b/src/lsps0/msgs.rs @@ -7,16 +7,19 @@ use crate::lsps2::msgs::{ LSPS2Message, LSPS2Request, LSPS2Response, LSPS2_BUY_METHOD_NAME, LSPS2_GET_INFO_METHOD_NAME, LSPS2_GET_VERSIONS_METHOD_NAME, }; +use crate::prelude::{HashMap, String, ToString, Vec}; + use lightning::impl_writeable_msg; use lightning::ln::wire; + use serde::de; use serde::de::{MapAccess, Visitor}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::json; -use std::collections::HashMap; -use std::convert::TryFrom; -use std::fmt; + +use core::convert::TryFrom; +use core::fmt; const LSPS_MESSAGE_SERIALIZED_STRUCT_NAME: &str = "LSPSMessage"; const JSONRPC_FIELD_KEY: &str = "jsonrpc"; diff --git a/src/lsps0/protocol.rs b/src/lsps0/protocol.rs index 3aa4feb..2a3e650 100644 --- a/src/lsps0/protocol.rs +++ b/src/lsps0/protocol.rs @@ -1,17 +1,20 @@ -use bitcoin::secp256k1::PublicKey; -use lightning::ln::msgs::{ErrorAction, LightningError}; -use lightning::sign::EntropySource; -use lightning::util::logger::Level; -use std::ops::Deref; -use std::sync::{Arc, Mutex}; - use crate::lsps0::message_handler::ProtocolMessageHandler; use crate::lsps0::msgs::{ LSPS0Message, LSPS0Request, LSPS0Response, LSPSMessage, ListProtocolsRequest, ListProtocolsResponse, RequestId, ResponseError, }; +use crate::prelude::Vec; +use crate::sync::{Arc, Mutex}; use crate::utils; +use lightning::ln::msgs::{ErrorAction, LightningError}; +use lightning::sign::EntropySource; +use lightning::util::logger::Level; + +use bitcoin::secp256k1::PublicKey; + +use core::ops::Deref; + pub struct LSPS0MessageHandler where ES::Target: EntropySource, @@ -104,7 +107,8 @@ where #[cfg(test)] mod tests { - use std::sync::Arc; + use alloc::string::ToString; + use alloc::sync::Arc; use super::*; diff --git a/src/lsps1/channel_manager.rs b/src/lsps1/channel_manager.rs index 2d0d09b..007075a 100644 --- a/src/lsps1/channel_manager.rs +++ b/src/lsps1/channel_manager.rs @@ -7,12 +7,22 @@ // You may not use this file except in accordance with one or both of these // licenses. -use chrono::Utc; -use std::collections::HashMap; -use std::ops::Deref; -use std::sync::{Arc, Mutex, RwLock}; +use super::msgs::{ + ChannelInfo, CreateOrderRequest, CreateOrderResponse, GetInfoRequest, GetInfoResponse, + GetOrderRequest, GetOrderResponse, LSPS1Message, LSPS1Request, LSPS1Response, OptionsSupported, + Order, OrderId, OrderState, Payment, LSPS1_CREATE_ORDER_REQUEST_INVALID_VERSION_ERROR_CODE, + LSPS1_CREATE_ORDER_REQUEST_ORDER_MISMATCH_ERROR_CODE, +}; +use super::utils::is_valid; + +use crate::events::EventQueue; +use crate::lsps0::message_handler::{CRChannelConfig, ProtocolMessageHandler}; +use crate::lsps0::msgs::{LSPSMessage, RequestId}; +use crate::prelude::{HashMap, String, ToString, Vec}; +use crate::sync::{Arc, Mutex, RwLock}; +use crate::utils; +use crate::{events::Event, lsps0::msgs::ResponseError}; -use bitcoin::secp256k1::PublicKey; use lightning::chain::Filter; use lightning::ln::channelmanager::AChannelManager; use lightning::ln::msgs::{ErrorAction, LightningError}; @@ -21,19 +31,10 @@ use lightning::sign::EntropySource; use lightning::util::errors::APIError; use lightning::util::logger::Level; -use crate::events::EventQueue; -use crate::lsps0::message_handler::{CRChannelConfig, ProtocolMessageHandler}; -use crate::lsps0::msgs::{LSPSMessage, RequestId}; -use crate::utils; -use crate::{events::Event, lsps0::msgs::ResponseError}; +use bitcoin::secp256k1::PublicKey; -use super::msgs::{ - ChannelInfo, CreateOrderRequest, CreateOrderResponse, GetInfoRequest, GetInfoResponse, - GetOrderRequest, GetOrderResponse, LSPS1Message, LSPS1Request, LSPS1Response, OptionsSupported, - Order, OrderId, OrderState, Payment, LSPS1_CREATE_ORDER_REQUEST_INVALID_VERSION_ERROR_CODE, - LSPS1_CREATE_ORDER_REQUEST_ORDER_MISMATCH_ERROR_CODE, -}; -use super::utils::is_valid; +use chrono::Utc; +use core::ops::Deref; const SUPPORTED_SPEC_VERSIONS: [u16; 1] = [1]; @@ -329,11 +330,11 @@ where let inner_state_lock = outer_state_lock .entry(counterparty_node_id) .or_insert(Mutex::new(PeerState::default())); - let peer_state = inner_state_lock.get_mut().unwrap(); - peer_state.insert_inbound_channel(channel_id, channel); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); + peer_state_lock.insert_inbound_channel(channel_id, channel); let request_id = self.generate_request_id(); - peer_state.insert_request(request_id.clone(), channel_id); + peer_state_lock.insert_request(request_id.clone(), channel_id); { let mut pending_messages = self.pending_messages.lock().unwrap(); @@ -375,10 +376,10 @@ where match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); let channel_id = - peer_state.request_to_cid.remove(&request_id).ok_or(LightningError { + peer_state_lock.request_to_cid.remove(&request_id).ok_or(LightningError { err: format!( "Received get_info response for an unknown request: {:?}", request_id @@ -386,23 +387,23 @@ where action: ErrorAction::IgnoreAndLog(Level::Info), })?; - let inbound_channel = peer_state + let inbound_channel = peer_state_lock .inbound_channels_by_id .get_mut(&channel_id) .ok_or(LightningError { - err: format!( - "Received get_info response for an unknown channel: {:?}", - channel_id - ), - action: ErrorAction::IgnoreAndLog(Level::Info), - })?; + err: format!( + "Received get_info response for an unknown channel: {:?}", + channel_id + ), + action: ErrorAction::IgnoreAndLog(Level::Info), + })?; let version = match inbound_channel .info_received(result.supported_versions, result.options.clone()) { Ok(version) => version, Err(e) => { - peer_state.remove_inbound_channel(channel_id); + peer_state_lock.remove_inbound_channel(channel_id); return Err(e); } }; @@ -436,25 +437,25 @@ where match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); - let inbound_channel = peer_state + let inbound_channel = peer_state_lock .inbound_channels_by_id .get_mut(&channel_id) .ok_or(APIError::APIMisuseError { - err: format!("Channel with id {} not found", channel_id), - })?; + err: format!("Channel with id {} not found", channel_id), + })?; let version = match inbound_channel.order_requested(order.clone()) { Ok(version) => version, Err(e) => { - peer_state.remove_inbound_channel(channel_id); + peer_state_lock.remove_inbound_channel(channel_id); return Err(APIError::APIMisuseError { err: e.err }); } }; let request_id = self.generate_request_id(); - peer_state.insert_request(request_id.clone(), channel_id); + peer_state_lock.insert_request(request_id.clone(), channel_id); { let mut pending_messages = self.pending_messages.lock().unwrap(); @@ -523,9 +524,9 @@ where let inner_state_lock = outer_state_lock .entry(*counterparty_node_id) .or_insert(Mutex::new(PeerState::default())); - let peer_state = inner_state_lock.get_mut().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); - peer_state + peer_state_lock .pending_requests .insert(request_id.clone(), LSPS1Request::CreateOrder(params.clone())); @@ -546,9 +547,9 @@ where match outer_state_lock.get(counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); - match peer_state.pending_requests.remove(&request_id) { + match peer_state_lock.pending_requests.remove(&request_id) { Some(LSPS1Request::CreateOrder(params)) => { let order_id = self.generate_order_id(); let channel = OutboundCRChannel::new( @@ -559,7 +560,7 @@ where payment.clone(), ); - peer_state.insert_outbound_channel(order_id.clone(), channel); + peer_state_lock.insert_outbound_channel(order_id.clone(), channel); self.enqueue_response( *counterparty_node_id, @@ -603,10 +604,10 @@ where let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); let channel_id = - peer_state.request_to_cid.remove(&request_id).ok_or(LightningError { + peer_state_lock.request_to_cid.remove(&request_id).ok_or(LightningError { err: format!( "Received create_order response for an unknown request: {:?}", request_id @@ -614,21 +615,21 @@ where action: ErrorAction::IgnoreAndLog(Level::Info), })?; - let inbound_channel = peer_state + let inbound_channel = peer_state_lock .inbound_channels_by_id .get_mut(&channel_id) .ok_or(LightningError { - err: format!( - "Received create_order response for an unknown channel: {:?}", - channel_id - ), - action: ErrorAction::IgnoreAndLog(Level::Info), - })?; + err: format!( + "Received create_order response for an unknown channel: {:?}", + channel_id + ), + action: ErrorAction::IgnoreAndLog(Level::Info), + })?; if let Err(e) = inbound_channel.order_received(&response.order, response.order_id.clone()) { - peer_state.remove_inbound_channel(channel_id); + peer_state_lock.remove_inbound_channel(channel_id); return Err(e); } @@ -644,7 +645,7 @@ where channel: response.channel, })); } else { - peer_state.remove_inbound_channel(channel_id); + peer_state_lock.remove_inbound_channel(channel_id); return Err(LightningError { err: format!("Fees are too high : {:?}", total_fees), action: ErrorAction::IgnoreAndLog(Level::Info), @@ -671,10 +672,10 @@ where let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); let channel_id = - peer_state.request_to_cid.remove(&request_id).ok_or(LightningError { + peer_state_lock.request_to_cid.remove(&request_id).ok_or(LightningError { err: format!( "Received create order error for an unknown request: {:?}", request_id @@ -682,16 +683,16 @@ where action: ErrorAction::IgnoreAndLog(Level::Info), })?; - let inbound_channel = peer_state + let inbound_channel = peer_state_lock .inbound_channels_by_id .get_mut(&channel_id) .ok_or(LightningError { - err: format!( - "Received create order error for an unknown channel: {:?}", - channel_id - ), - action: ErrorAction::IgnoreAndLog(Level::Info), - })?; + err: format!( + "Received create order error for an unknown channel: {:?}", + channel_id + ), + action: ErrorAction::IgnoreAndLog(Level::Info), + })?; Ok(()) } None => { @@ -706,18 +707,18 @@ where let outer_state_lock = self.per_peer_state.write().unwrap(); match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); if let Some(inbound_channel) = - peer_state.inbound_channels_by_id.get_mut(&channel_id) + peer_state_lock.inbound_channels_by_id.get_mut(&channel_id) { if let Err(e) = inbound_channel.pay_for_channel(channel_id) { - peer_state.remove_inbound_channel(channel_id); + peer_state_lock.remove_inbound_channel(channel_id); return Err(APIError::APIMisuseError { err: e.err }); } let request_id = self.generate_request_id(); - peer_state.insert_request(request_id.clone(), channel_id); + peer_state_lock.insert_request(request_id.clone(), channel_id); { let mut pending_messages = self.pending_messages.lock().unwrap(); @@ -757,9 +758,9 @@ where let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); - let outbound_channel = peer_state + let outbound_channel = peer_state_lock .outbound_channels_by_order_id .get_mut(¶ms.order_id) .ok_or(LightningError { @@ -771,7 +772,7 @@ where })?; if let Err(e) = outbound_channel.create_payment_invoice() { - peer_state.outbound_channels_by_order_id.remove(¶ms.order_id); + peer_state_lock.outbound_channels_by_order_id.remove(¶ms.order_id); self.enqueue_event(Event::LSPS1(super::event::Event::Refund { request_id, counterparty_node_id: *counterparty_node_id, @@ -780,7 +781,7 @@ where return Err(e); } - peer_state + peer_state_lock .pending_requests .insert(request_id.clone(), LSPS1Request::GetOrder(params.clone())); @@ -809,10 +810,10 @@ where match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); if let Some(outbound_channel) = - peer_state.outbound_channels_by_order_id.get_mut(&order_id) + peer_state_lock.outbound_channels_by_order_id.get_mut(&order_id) { let config = &outbound_channel.config; @@ -852,10 +853,10 @@ where let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); let channel_id = - peer_state.request_to_cid.remove(&request_id).ok_or(LightningError { + peer_state_lock.request_to_cid.remove(&request_id).ok_or(LightningError { err: format!( "Received get_versions response for an unknown request: {:?}", request_id @@ -863,16 +864,16 @@ where action: ErrorAction::IgnoreAndLog(Level::Info), })?; - let inbound_channel = peer_state + let inbound_channel = peer_state_lock .inbound_channels_by_id .get_mut(&channel_id) .ok_or(LightningError { - err: format!( - "Received get_versions response for an unknown channel: {:?}", - channel_id - ), - action: ErrorAction::IgnoreAndLog(Level::Info), - })?; + err: format!( + "Received get_versions response for an unknown channel: {:?}", + channel_id + ), + action: ErrorAction::IgnoreAndLog(Level::Info), + })?; } None => { return Err(LightningError { @@ -894,10 +895,10 @@ where let outer_state_lock = self.per_peer_state.read().unwrap(); match outer_state_lock.get(&counterparty_node_id) { Some(inner_state_lock) => { - let mut peer_state = inner_state_lock.lock().unwrap(); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); let channel_id = - peer_state.request_to_cid.remove(&request_id).ok_or(LightningError { + peer_state_lock.request_to_cid.remove(&request_id).ok_or(LightningError { err: format!( "Received get_order error for an unknown request: {:?}", request_id @@ -905,7 +906,7 @@ where action: ErrorAction::IgnoreAndLog(Level::Info), })?; - let _inbound_channel = peer_state + let _inbound_channel = peer_state_lock .inbound_channels_by_id .get_mut(&channel_id) .ok_or(LightningError { diff --git a/src/lsps1/event.rs b/src/lsps1/event.rs index 946979f..8bfe9d9 100644 --- a/src/lsps1/event.rs +++ b/src/lsps1/event.rs @@ -1,9 +1,11 @@ #![allow(missing_docs)] -use bitcoin::secp256k1::PublicKey; - use super::msgs::{ChannelInfo, OptionsSupported, Order, OrderId, Payment}; + use crate::lsps0::msgs::RequestId; +use crate::prelude::String; + +use bitcoin::secp256k1::PublicKey; /// An "Event" which you should probably take some action in response to. #[derive(Clone, Debug, PartialEq, Eq)] diff --git a/src/lsps1/msgs.rs b/src/lsps1/msgs.rs index 3248789..535727e 100644 --- a/src/lsps1/msgs.rs +++ b/src/lsps1/msgs.rs @@ -1,9 +1,11 @@ -use chrono::Utc; -use std::convert::TryFrom; +use crate::lsps0::msgs::{LSPSMessage, RequestId, ResponseError}; +use crate::prelude::{String, Vec}; use serde::{Deserialize, Serialize}; -use crate::lsps0::msgs::{LSPSMessage, RequestId, ResponseError}; +use chrono::Utc; + +use core::convert::TryFrom; pub(crate) const LSPS1_GET_INFO_METHOD_NAME: &str = "lsps1.get_info"; pub(crate) const LSPS1_CREATE_ORDER_METHOD_NAME: &str = "lsps1.create_order"; diff --git a/src/lsps2/channel_manager.rs b/src/lsps2/channel_manager.rs index 059bdce..9dac867 100644 --- a/src/lsps2/channel_manager.rs +++ b/src/lsps2/channel_manager.rs @@ -7,12 +7,16 @@ // You may not use this file except in accordance with one or both of these // licenses. -use std::collections::HashMap; -use std::convert::TryInto; -use std::ops::Deref; -use std::sync::{Arc, Mutex, RwLock}; +use crate::events::EventQueue; +use crate::lsps0::message_handler::ProtocolMessageHandler; +use crate::lsps0::msgs::{LSPSMessage, RequestId}; +use crate::lsps2::utils::{compute_opening_fee, is_valid_opening_fee_params}; +use crate::lsps2::LSPS2Event; +use crate::prelude::{HashMap, String, ToString, Vec}; +use crate::sync::{Arc, Mutex, RwLock}; +use crate::{events::Event, lsps0::msgs::ResponseError}; +use crate::{utils, JITChannelsConfig}; -use bitcoin::secp256k1::PublicKey; use lightning::ln::channelmanager::{AChannelManager, InterceptId}; use lightning::ln::msgs::{ErrorAction, LightningError}; use lightning::ln::peer_handler::APeerManager; @@ -21,13 +25,10 @@ use lightning::sign::EntropySource; use lightning::util::errors::APIError; use lightning::util::logger::Level; -use crate::events::EventQueue; -use crate::lsps0::message_handler::ProtocolMessageHandler; -use crate::lsps0::msgs::{LSPSMessage, RequestId}; -use crate::lsps2::utils::{compute_opening_fee, is_valid_opening_fee_params}; -use crate::lsps2::LSPS2Event; -use crate::{events::Event, lsps0::msgs::ResponseError}; -use crate::{utils, JITChannelsConfig}; +use bitcoin::secp256k1::PublicKey; + +use core::convert::TryInto; +use core::ops::Deref; use crate::lsps2::msgs::{ BuyRequest, BuyResponse, GetInfoRequest, GetInfoResponse, GetVersionsRequest, @@ -356,7 +357,6 @@ impl OutboundJITChannel { } } -#[derive(Default)] struct PeerState { inbound_channels_by_id: HashMap, outbound_channels_by_scid: HashMap, @@ -365,6 +365,14 @@ struct PeerState { } impl PeerState { + pub fn new() -> Self { + let inbound_channels_by_id = HashMap::new(); + let outbound_channels_by_scid = HashMap::new(); + let request_to_cid = HashMap::new(); + let pending_requests = HashMap::new(); + Self { inbound_channels_by_id, outbound_channels_by_scid, request_to_cid, pending_requests } + } + pub fn insert_inbound_channel(&mut self, jit_channel_id: u128, channel: InboundJITChannel) { self.inbound_channels_by_id.insert(jit_channel_id, channel); } @@ -442,14 +450,13 @@ where InboundJITChannel::new(jit_channel_id, user_channel_id, payment_size_msat, token); let mut outer_state_lock = self.per_peer_state.write().unwrap(); - let inner_state_lock = outer_state_lock - .entry(counterparty_node_id) - .or_insert(Mutex::new(PeerState::default())); - let peer_state = inner_state_lock.get_mut().unwrap(); - peer_state.insert_inbound_channel(jit_channel_id, channel); + let inner_state_lock = + outer_state_lock.entry(counterparty_node_id).or_insert(Mutex::new(PeerState::new())); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); + peer_state_lock.insert_inbound_channel(jit_channel_id, channel); let request_id = self.generate_request_id(); - peer_state.insert_request(request_id.clone(), jit_channel_id); + peer_state_lock.insert_request(request_id.clone(), jit_channel_id); { let mut pending_messages = self.pending_messages.lock().unwrap(); @@ -876,11 +883,10 @@ where } let mut outer_state_lock = self.per_peer_state.write().unwrap(); - let inner_state_lock: &mut Mutex = outer_state_lock - .entry(*counterparty_node_id) - .or_insert(Mutex::new(PeerState::default())); - let peer_state = inner_state_lock.get_mut().unwrap(); - peer_state + let inner_state_lock: &mut Mutex = + outer_state_lock.entry(*counterparty_node_id).or_insert(Mutex::new(PeerState::new())); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); + peer_state_lock .pending_requests .insert(request_id.clone(), LSPS2Request::GetInfo(params.clone())); @@ -1097,11 +1103,12 @@ where } let mut outer_state_lock = self.per_peer_state.write().unwrap(); - let inner_state_lock = outer_state_lock - .entry(*counterparty_node_id) - .or_insert(Mutex::new(PeerState::default())); - let peer_state = inner_state_lock.get_mut().unwrap(); - peer_state.pending_requests.insert(request_id.clone(), LSPS2Request::Buy(params.clone())); + let inner_state_lock = + outer_state_lock.entry(*counterparty_node_id).or_insert(Mutex::new(PeerState::new())); + let mut peer_state_lock = inner_state_lock.lock().unwrap(); + peer_state_lock + .pending_requests + .insert(request_id.clone(), LSPS2Request::Buy(params.clone())); self.enqueue_event(Event::LSPS2(LSPS2Event::BuyRequest { request_id, @@ -1276,7 +1283,7 @@ fn calculate_amount_to_forward_per_htlc( let proportional_fee_amt_msat = total_fee_msat * htlc.expected_outbound_amount_msat / total_received_msat; - let mut actual_fee_amt_msat = std::cmp::min(fee_remaining_msat, proportional_fee_amt_msat); + let mut actual_fee_amt_msat = core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat); fee_remaining_msat -= actual_fee_amt_msat; if index == htlcs.len() - 1 { diff --git a/src/lsps2/event.rs b/src/lsps2/event.rs index 40082ea..6f094ce 100644 --- a/src/lsps2/event.rs +++ b/src/lsps2/event.rs @@ -7,10 +7,11 @@ // You may not use this file except in accordance with one or both of these // licenses. -use bitcoin::secp256k1::PublicKey; - use super::msgs::OpeningFeeParams; use crate::lsps0::msgs::RequestId; +use crate::prelude::{String, Vec}; + +use bitcoin::secp256k1::PublicKey; /// An event which you should probably take some action in response to. #[derive(Clone, Debug, PartialEq, Eq)] diff --git a/src/lsps2/msgs.rs b/src/lsps2/msgs.rs index 5867df2..d1b9205 100644 --- a/src/lsps2/msgs.rs +++ b/src/lsps2/msgs.rs @@ -1,4 +1,4 @@ -use std::convert::TryFrom; +use core::convert::TryFrom; use bitcoin::hashes::hmac::{Hmac, HmacEngine}; use bitcoin::hashes::sha256::Hash as Sha256; @@ -7,6 +7,7 @@ use chrono::Utc; use serde::{Deserialize, Serialize}; use crate::lsps0::msgs::{LSPSMessage, RequestId, ResponseError}; +use crate::prelude::{String, Vec}; use crate::utils; pub(crate) const LSPS2_GET_VERSIONS_METHOD_NAME: &str = "lsps2.get_versions"; @@ -304,6 +305,8 @@ mod tests { } #[test] + #[cfg(feature = "std")] + // TODO: We need to find a way to check expiry times in no-std builds. fn expired_params_produces_invalid_params() { let min_fee_msat = 100; let proportional = 21; diff --git a/src/lsps2/utils.rs b/src/lsps2/utils.rs index b411745..6721a5b 100644 --- a/src/lsps2/utils.rs +++ b/src/lsps2/utils.rs @@ -1,28 +1,32 @@ +use crate::lsps2::msgs::OpeningFeeParams; +use crate::utils; + use bitcoin::hashes::hmac::{Hmac, HmacEngine}; use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::hashes::{Hash, HashEngine}; -use std::convert::TryInto; +#[cfg(feature = "std")] use std::time::{SystemTime, UNIX_EPOCH}; -use crate::lsps2::msgs::OpeningFeeParams; -use crate::utils; - /// Determines if the given parameters are valid given the secret used to generate the promise. pub fn is_valid_opening_fee_params( fee_params: &OpeningFeeParams, promise_secret: &[u8; 32], ) -> bool { - let seconds_since_epoch = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system clock to be ahead of the unix epoch") - .as_secs(); - let valid_until_seconds_since_epoch = fee_params - .valid_until - .timestamp() - .try_into() - .expect("expiration to be ahead of unix epoch"); - if seconds_since_epoch > valid_until_seconds_since_epoch { - return false; + #[cfg(feature = "std")] + { + // TODO: We need to find a way to check expiry times in no-std builds. + let seconds_since_epoch = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock to be ahead of the unix epoch") + .as_secs(); + let valid_until_seconds_since_epoch = fee_params + .valid_until + .timestamp() + .try_into() + .expect("expiration to be ahead of unix epoch"); + if seconds_since_epoch > valid_until_seconds_since_epoch { + return false; + } } let mut hmac = HmacEngine::::new(promise_secret); @@ -48,5 +52,5 @@ pub fn compute_opening_fee( .checked_mul(opening_fee_proportional) .and_then(|f| f.checked_add(999999)) .and_then(|f| f.checked_div(1000000)) - .map(|f| std::cmp::max(f, opening_fee_min_fee_msat)) + .map(|f| core::cmp::max(f, opening_fee_min_fee_msat)) } diff --git a/src/sync/mod.rs b/src/sync/mod.rs new file mode 100644 index 0000000..1371de2 --- /dev/null +++ b/src/sync/mod.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "std")] +pub use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +#[cfg(not(feature = "std"))] +mod nostd_sync; +#[cfg(not(feature = "std"))] +pub use nostd_sync::*; diff --git a/src/sync/nostd_sync.rs b/src/sync/nostd_sync.rs new file mode 100644 index 0000000..b0d2630 --- /dev/null +++ b/src/sync/nostd_sync.rs @@ -0,0 +1,102 @@ +//! This file was copied from `rust-lightning`. +pub use ::alloc::sync::Arc; +use core::cell::{Ref, RefCell, RefMut}; +use core::ops::{Deref, DerefMut}; + +pub type LockResult = Result; + +pub struct Mutex { + inner: RefCell, +} + +#[must_use = "if unused the Mutex will immediately unlock"] +pub struct MutexGuard<'a, T: ?Sized + 'a> { + lock: RefMut<'a, T>, +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + &self.lock.deref() + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + self.lock.deref_mut() + } +} + +impl Mutex { + pub fn new(inner: T) -> Mutex { + Mutex { inner: RefCell::new(inner) } + } + + pub fn lock<'a>(&'a self) -> LockResult> { + Ok(MutexGuard { lock: self.inner.borrow_mut() }) + } + + pub fn try_lock<'a>(&'a self) -> LockResult> { + Ok(MutexGuard { lock: self.inner.borrow_mut() }) + } + + pub fn into_inner(self) -> LockResult { + Ok(self.inner.into_inner()) + } +} + +pub struct RwLock { + inner: RefCell, +} + +pub struct RwLockReadGuard<'a, T: ?Sized + 'a> { + lock: Ref<'a, T>, +} + +pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> { + lock: RefMut<'a, T>, +} + +impl Deref for RwLockReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + &self.lock.deref() + } +} + +impl Deref for RwLockWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + &self.lock.deref() + } +} + +impl DerefMut for RwLockWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + self.lock.deref_mut() + } +} + +impl RwLock { + pub fn new(inner: T) -> RwLock { + RwLock { inner: RefCell::new(inner) } + } + + pub fn read<'a>(&'a self) -> LockResult> { + Ok(RwLockReadGuard { lock: self.inner.borrow() }) + } + + pub fn write<'a>(&'a self) -> LockResult> { + Ok(RwLockWriteGuard { lock: self.inner.borrow_mut() }) + } + + pub fn try_write<'a>(&'a self) -> LockResult> { + match self.inner.try_borrow_mut() { + Ok(lock) => Ok(RwLockWriteGuard { lock }), + Err(_) => Err(()), + } + } +} diff --git a/src/utils.rs b/src/utils.rs index d80d8de..d035fde 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,8 +1,10 @@ use bitcoin::secp256k1::PublicKey; +use core::{fmt::Write, ops::Deref}; +use lightning::io; use lightning::sign::EntropySource; -use std::{fmt::Write, ops::Deref}; use crate::lsps0::msgs::RequestId; +use crate::prelude::{String, Vec}; /// Maximum transaction index that can be used in a `short_channel_id`. /// This value is based on the 3-bytes available for tx index. @@ -89,11 +91,11 @@ pub fn to_compressed_pubkey(hex: &str) -> Option { } } -pub fn parse_pubkey(pubkey_str: &str) -> Result { +pub fn parse_pubkey(pubkey_str: &str) -> Result { let pubkey = to_compressed_pubkey(pubkey_str); if pubkey.is_none() { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, + return Err(io::Error::new( + io::ErrorKind::Other, "ERROR: unable to parse given pubkey for node", )); }