diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 065a3241a8e..b94f9784a2e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -13,7 +13,7 @@ jobs: 1.30.0, # 1.34.2 is Debian stable 1.34.2, - # 1.39.0 is MSRV for lightning-net-tokio and generates coverage + # 1.39.0 is MSRV for lightning-net-tokio and lightning-block-sync and generates coverage 1.39.0] include: - toolchain: stable @@ -48,6 +48,15 @@ jobs: - name: Build on Rust ${{ matrix.toolchain }} if: "! matrix.build-net-tokio" run: cargo build --verbose --color always -p lightning + - name: Build Block Sync Clients on Rust ${{ matrix.toolchain }} with features + if: matrix.build-net-tokio + run: | + cd lightning-block-sync + RUSTFLAGS="-C link-dead-code" cargo build --verbose --color always --features rest-client + RUSTFLAGS="-C link-dead-code" cargo build --verbose --color always --features rpc-client + RUSTFLAGS="-C link-dead-code" cargo build --verbose --color always --features rpc-client,rest-client + RUSTFLAGS="-C link-dead-code" cargo build --verbose --color always --features rpc-client,rest-client,tokio + cd .. - name: Test on Rust ${{ matrix.toolchain }} with net-tokio if: "matrix.build-net-tokio && !matrix.coverage" run: cargo test --verbose --color always diff --git a/Cargo.toml b/Cargo.toml index c43e7927581..96f4b1d1770 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "lightning", + "lightning-block-sync", "lightning-net-tokio", "lightning-persister", ] diff --git a/lightning-block-sync/Cargo.toml b/lightning-block-sync/Cargo.toml new file mode 100644 index 00000000000..07294d35606 --- /dev/null +++ b/lightning-block-sync/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "lightning-block-sync" +version = "0.0.1" +authors = ["Matt Corallo"] +license = "Apache-2.0" +edition = "2018" +description = """ +Utilities to fetch the chain from Bitcoin Core REST/RPC Interfaces and feed them into Rust Lightning. +""" + +[features] +rest-client = [ "serde", "serde_json", "serde_derive", "chunked_transfer" ] +rpc-client = [ "serde", "serde_json", "serde_derive", "base64", "chunked_transfer" ] + +[dependencies] +bitcoin = "0.24" +lightning = { version = "0.0.11", path = "../lightning" } +tokio = { version = ">=0.2.12", features = [ "tcp", "io-util", "dns" ], optional = true } +serde = { version = "1", optional = true } +serde_json = { version = "1", optional = true } +serde_derive = { version = "1", optional = true } +base64 = { version = "0.9", optional = true } +chunked_transfer = { version = "1.3.0", optional = true } +futures = { version = "0.3.8" } + +[dev-dependencies] +tokio = { version = ">=0.2.12", features = [ "macros", "rt-core" ] } diff --git a/lightning-block-sync/src/http_clients.rs b/lightning-block-sync/src/http_clients.rs new file mode 100644 index 00000000000..be2769cfdba --- /dev/null +++ b/lightning-block-sync/src/http_clients.rs @@ -0,0 +1,1301 @@ +use crate::http_endpoint::HttpEndpoint; +use crate::utils::hex_to_uint256; +use crate::{BlockHeaderData, BlockSource, BlockSourceError, AsyncBlockSourceResult}; + +use bitcoin::blockdata::block::{Block, BlockHeader}; +use bitcoin::consensus::encode; +use bitcoin::hash_types::{BlockHash, TxMerkleNode}; +use bitcoin::hashes::hex::{ToHex, FromHex}; + +use chunked_transfer; + +use serde_derive::Deserialize; + +use serde_json; + +use std::convert::TryFrom; +use std::convert::TryInto; +#[cfg(not(feature = "tokio"))] +use std::io::Write; +use std::net::ToSocketAddrs; +use std::time::Duration; + +#[cfg(feature = "rpc-client")] +use base64; +#[cfg(feature = "rpc-client")] +use std::sync::atomic::{AtomicUsize, Ordering}; + +#[cfg(feature = "tokio")] +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; +#[cfg(feature = "tokio")] +use tokio::net::TcpStream; + +#[cfg(not(feature = "tokio"))] +use std::io::BufRead; +use std::io::Read; +#[cfg(not(feature = "tokio"))] +use std::net::TcpStream; + +/// Maximum HTTP message header size in bytes. +const MAX_HTTP_MESSAGE_HEADER_SIZE: usize = 8192; + +/// Maximum HTTP message body size in bytes. +const MAX_HTTP_MESSAGE_BODY_SIZE: usize = 4_000_000; + +/// Client for making HTTP requests. +struct HttpClient { + stream: TcpStream, +} + +impl HttpClient { + /// Opens a connection to an HTTP endpoint. + fn connect(endpoint: E) -> std::io::Result { + let address = match endpoint.to_socket_addrs()?.next() { + None => { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "could not resolve to any addresses")); + }, + Some(address) => address, + }; + let stream = std::net::TcpStream::connect_timeout(&address, Duration::from_secs(1))?; + stream.set_read_timeout(Some(Duration::from_secs(2)))?; + stream.set_write_timeout(Some(Duration::from_secs(1)))?; + + #[cfg(feature = "tokio")] + let stream = { + stream.set_nonblocking(true)?; + TcpStream::from_std(stream)? + }; + + Ok(Self { stream }) + } + + /// Sends a `GET` request for a resource identified by `uri` at the `host`. + async fn get(&mut self, uri: &str, host: &str) -> std::io::Result + where F: TryFrom, Error = std::io::Error> { + let request = format!( + "GET {} HTTP/1.1\r\n\ + Host: {}\r\n\ + Connection: keep-alive\r\n\ + \r\n", uri, host); + self.write_request(request).await?; + let bytes = self.read_response().await?; + F::try_from(bytes) + } + + /// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP + /// authentication credentials. + /// + /// The request body consists of the provided JSON `content`. Returns the response body in `F` + /// format. + async fn post(&mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value) -> std::io::Result + where F: TryFrom, Error = std::io::Error> { + let content = content.to_string(); + let request = format!( + "POST {} HTTP/1.1\r\n\ + Host: {}\r\n\ + Authorization: {}\r\n\ + Connection: keep-alive\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}", uri, host, auth, content.len(), content); + self.write_request(request).await?; + let bytes = self.read_response().await?; + F::try_from(bytes) + } + + /// Writes an HTTP request message. + async fn write_request(&mut self, request: String) -> std::io::Result<()> { + #[cfg(feature = "tokio")] + { + self.stream.write_all(request.as_bytes()).await?; + self.stream.flush().await + } + #[cfg(not(feature = "tokio"))] + { + self.stream.write_all(request.as_bytes())?; + self.stream.flush() + } + } + + /// Reads an HTTP response message. + async fn read_response(&mut self) -> std::io::Result> { + #[cfg(feature = "tokio")] + let stream = self.stream.split().0; + #[cfg(not(feature = "tokio"))] + let stream = std::io::Read::by_ref(&mut self.stream); + + let limited_stream = stream.take(MAX_HTTP_MESSAGE_HEADER_SIZE as u64); + + #[cfg(feature = "tokio")] + let mut reader = tokio::io::BufReader::new(limited_stream); + #[cfg(not(feature = "tokio"))] + let mut reader = std::io::BufReader::new(limited_stream); + + macro_rules! read_line { () => { { + let mut line = String::new(); + #[cfg(feature = "tokio")] + let bytes_read = reader.read_line(&mut line).await?; + #[cfg(not(feature = "tokio"))] + let bytes_read = reader.read_line(&mut line)?; + + match bytes_read { + 0 => None, + _ => { + // Remove trailing CRLF + if line.ends_with('\n') { line.pop(); if line.ends_with('\r') { line.pop(); } } + Some(line) + }, + } + } } } + + // Read and parse status line + let status_line = read_line!() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no status line"))?; + let status = HttpStatus::parse(&status_line)?; + + // Read and parse relevant headers + let mut message_length = HttpMessageLength::Empty; + loop { + let line = read_line!() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "unexpected eof"))?; + if line.is_empty() { break; } + + let header = HttpHeader::parse(&line)?; + if header.has_name("Content-Length") { + let length = header.value.parse() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + if let HttpMessageLength::Empty = message_length { + message_length = HttpMessageLength::ContentLength(length); + } + continue; + } + + if header.has_name("Transfer-Encoding") { + message_length = HttpMessageLength::TransferEncoding(header.value.into()); + continue; + } + } + + if !status.is_ok() { + return Err(std::io::Error::new(std::io::ErrorKind::NotFound, "not found")); + } + + // Read message body + let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len(); + reader.get_mut().set_limit(read_limit as u64); + match message_length { + HttpMessageLength::Empty => { Ok(Vec::new()) }, + HttpMessageLength::ContentLength(length) => { + if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range")) + } else { + let mut content = vec![0; length]; + #[cfg(feature = "tokio")] + reader.read_exact(&mut content[..]).await?; + #[cfg(not(feature = "tokio"))] + reader.read_exact(&mut content[..])?; + Ok(content) + } + }, + HttpMessageLength::TransferEncoding(coding) => { + if !coding.eq_ignore_ascii_case("chunked") { + Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, "unsupported transfer coding")) + } else { + #[cfg(feature = "tokio")] + let reader = ReadAdapter(&mut reader); + let mut decoder = chunked_transfer::Decoder::new(reader); + let mut content = Vec::new(); + decoder.read_to_end(&mut content)?; + Ok(content) + } + }, + } + } +} + +/// HTTP response status code as defined by [RFC 7231]. +/// +/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-6 +struct HttpStatus<'a> { + code: &'a str, +} + +impl<'a> HttpStatus<'a> { + /// Parses an HTTP status line as defined by [RFC 7230]. + /// + /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.1.2 + fn parse(line: &'a String) -> std::io::Result> { + let mut tokens = line.splitn(3, ' '); + + let http_version = tokens.next() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no HTTP-Version"))?; + if !http_version.eq_ignore_ascii_case("HTTP/1.1") && + !http_version.eq_ignore_ascii_case("HTTP/1.0") { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid HTTP-Version")); + } + + let code = tokens.next() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Status-Code"))?; + if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid Status-Code")); + } + + let _reason = tokens.next() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no Reason-Phrase"))?; + + Ok(Self { code }) + } + + /// Returns whether the status is successful (i.e., 2xx status class). + fn is_ok(&self) -> bool { + self.code.starts_with('2') + } +} + +/// HTTP response header as defined by [RFC 7231]. +/// +/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-7 +struct HttpHeader<'a> { + name: &'a str, + value: &'a str, +} + +impl<'a> HttpHeader<'a> { + /// Parses an HTTP header field as defined by [RFC 7230]. + /// + /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2 + fn parse(line: &'a String) -> std::io::Result> { + let mut tokens = line.splitn(2, ':'); + let name = tokens.next() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header name"))?; + let value = tokens.next() + .ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no header value"))? + .trim_start(); + Ok(Self { name, value }) + } + + /// Returns whether or the header field has the given name. + fn has_name(&self, name: &str) -> bool { + self.name.eq_ignore_ascii_case(name) + } +} + +/// HTTP message body length as defined by [RFC 7230]. +/// +/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.3.3 +enum HttpMessageLength { + Empty, + ContentLength(usize), + TransferEncoding(String), +} + +/// An adaptor work making `tokio::io::AsyncRead` compatible with interfaces expecting +/// `std::io::Read`. This effectively makes the adapted object synchronous. +#[cfg(feature = "tokio")] +struct ReadAdapter<'a, R: tokio::io::AsyncRead + std::marker::Unpin>(&'a mut R); + +#[cfg(feature = "tokio")] +impl<'a, R: tokio::io::AsyncRead + std::marker::Unpin> std::io::Read for ReadAdapter<'a, R> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + futures::executor::block_on(self.0.read(buf)) + } +} + +#[cfg(feature = "rest-client")] +pub struct RESTClient { + endpoint: HttpEndpoint, +} + +#[cfg(feature = "rest-client")] +impl RESTClient { + pub fn new(endpoint: HttpEndpoint) -> Self { + Self { endpoint } + } + + async fn request_resource(&self, resource_path: &str) -> std::io::Result + where F: TryFrom, Error = std::io::Error> + TryInto { + let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port()); + let uri = format!("{}/{}", self.endpoint.path().trim_end_matches("/"), resource_path); + + let mut client = HttpClient::connect(&self.endpoint)?; + client.get::(&uri, &host).await?.try_into() + } +} + +#[cfg(feature = "rpc-client")] +pub struct RPCClient { + basic_auth: String, + endpoint: HttpEndpoint, + id: AtomicUsize, +} + +#[cfg(feature = "rpc-client")] +impl RPCClient { + pub fn new(user_auth: &str, endpoint: HttpEndpoint) -> Self { + Self { + basic_auth: "Basic ".to_string() + &base64::encode(user_auth), + endpoint, + id: AtomicUsize::new(0), + } + } + + async fn call_method(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result + where JsonResponse: TryFrom, Error = std::io::Error> + TryInto { + let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port()); + let uri = self.endpoint.path(); + let content = serde_json::json!({ + "method": method, + "params": params, + "id": &self.id.fetch_add(1, Ordering::AcqRel).to_string() + }); + + let mut client = HttpClient::connect(&self.endpoint)?; + let mut response = client.post::(&uri, &host, &self.basic_auth, content).await?.0; + if !response.is_object() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object")); + } + + let error = &response["error"]; + if !error.is_null() { + // TODO: Examine error code for a more precise std::io::ErrorKind. + let message = error["message"].as_str().unwrap_or("unknown error"); + return Err(std::io::Error::new(std::io::ErrorKind::Other, message)); + } + + let result = &mut response["result"]; + if result.is_null() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON result")); + } + + JsonResponse(result.take()).try_into() + } +} + +#[derive(Deserialize)] +struct GetHeaderResponse { + pub chainwork: String, + pub height: u32, + + pub version: i32, + pub merkleroot: String, + pub time: u32, + pub nonce: u32, + pub bits: String, + pub previousblockhash: String, +} + +/// Converts from `GetHeaderResponse` to `BlockHeaderData`. +impl TryFrom for BlockHeaderData { + type Error = bitcoin::hashes::hex::Error; + + fn try_from(response: GetHeaderResponse) -> Result { + Ok(BlockHeaderData { + chainwork: hex_to_uint256(&response.chainwork)?, + height: response.height, + header: BlockHeader { + version: response.version, + prev_blockhash: BlockHash::from_hex(&response.previousblockhash)?, + merkle_root: TxMerkleNode::from_hex(&response.merkleroot)?, + time: response.time, + bits: u32::from_be_bytes(<[u8; 4]>::from_hex(&response.bits)?), + nonce: response.nonce, + }, + }) + } +} + +#[cfg(feature = "rpc-client")] +impl BlockSource for RPCClient { + fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + Box::pin(async move { + let header_hash = serde_json::json!(header_hash.to_hex()); + Ok(self.call_method("getblockheader", &[header_hash]).await?) + }) + } + + fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> { + Box::pin(async move { + let header_hash = serde_json::json!(header_hash.to_hex()); + let verbosity = serde_json::json!(0); + Ok(self.call_method("getblock", &[header_hash, verbosity]).await?) + }) + } + + fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { + Box::pin(async move { + Ok(self.call_method("getblockchaininfo", &[]).await?) + }) + } +} + +#[cfg(feature = "rest-client")] +impl BlockSource for RESTClient { + fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + Box::pin(async move { + let resource_path = format!("headers/1/{}.json", header_hash.to_hex()); + Ok(self.request_resource::(&resource_path).await?) + }) + } + + fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> { + Box::pin(async move { + let resource_path = format!("block/{}.bin", header_hash.to_hex()); + Ok(self.request_resource::(&resource_path).await?) + }) + } + + fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { + Box::pin(async move { + Ok(self.request_resource::("chaininfo.json").await?) + }) + } +} + +/// An HTTP response body in binary format. +struct BinaryResponse(Vec); + +/// An HTTP response body in JSON format. +struct JsonResponse(serde_json::Value); + +/// Interprets bytes from an HTTP response body as binary data. +impl TryFrom> for BinaryResponse { + type Error = std::io::Error; + + fn try_from(bytes: Vec) -> std::io::Result { + Ok(BinaryResponse(bytes)) + } +} + +/// Parses binary data as a block. +impl TryInto for BinaryResponse { + type Error = std::io::Error; + + fn try_into(self) -> std::io::Result { + match encode::deserialize(&self.0) { + Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid block data")), + Ok(block) => Ok(block), + } + } +} + +/// Interprets bytes from an HTTP response body as a JSON value. +impl TryFrom> for JsonResponse { + type Error = std::io::Error; + + fn try_from(bytes: Vec) -> std::io::Result { + Ok(JsonResponse(serde_json::from_slice(&bytes)?)) + } +} + +/// Converts a JSON value into block header data. The JSON value may be an object representing a +/// block header or an array of such objects. In the latter case, the first object is converted. +impl TryInto for JsonResponse { + type Error = std::io::Error; + + fn try_into(self) -> std::io::Result { + let mut header = match self.0 { + serde_json::Value::Array(mut array) if !array.is_empty() => array.drain(..).next().unwrap(), + serde_json::Value::Object(_) => self.0, + _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "unexpected JSON type")), + }; + + if !header.is_object() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object")); + } + + // Add an empty previousblockhash for the genesis block. + if let None = header.get("previousblockhash") { + let hash: BlockHash = Default::default(); + header.as_object_mut().unwrap().insert("previousblockhash".to_string(), serde_json::json!(hash.to_hex())); + } + + match serde_json::from_value::(header) { + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid header response")), + Ok(response) => match response.try_into() { + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid header data")), + Ok(header) => Ok(header), + }, + } + } +} + +/// Converts a JSON value into a block. Assumes the block is hex-encoded in a JSON string. +impl TryInto for JsonResponse { + type Error = std::io::Error; + + fn try_into(self) -> std::io::Result { + match self.0.as_str() { + None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")), + Some(hex_data) => match Vec::::from_hex(hex_data) { + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")), + Ok(block_data) => match encode::deserialize(&block_data) { + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid block data")), + Ok(block) => Ok(block), + }, + }, + } + } +} + +/// Converts a JSON value into the best block hash and optional height. +impl TryInto<(BlockHash, Option)> for JsonResponse { + type Error = std::io::Error; + + fn try_into(self) -> std::io::Result<(BlockHash, Option)> { + if !self.0.is_object() { + return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object")); + } + + let hash = match &self.0["bestblockhash"] { + serde_json::Value::String(hex_data) => match BlockHash::from_hex(&hex_data) { + Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid hex data")), + Ok(block_hash) => block_hash, + }, + _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON string")), + }; + + let height = match &self.0["blocks"] { + serde_json::Value::Null => None, + serde_json::Value::Number(height) => match height.as_u64() { + None => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid height")), + Some(height) => match height.try_into() { + Err(_) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid height")), + Ok(height) => Some(height), + } + }, + _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON number")), + }; + + Ok((hash, height)) + } +} + +/// Conversion from `std::io::Error` into `BlockSourceError`. +impl From for BlockSourceError { + fn from(e: std::io::Error) -> BlockSourceError { + match e.kind() { + std::io::ErrorKind::InvalidData => BlockSourceError::Persistent, + std::io::ErrorKind::InvalidInput => BlockSourceError::Persistent, + _ => BlockSourceError::Transient, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::BufRead; + use std::io::Write; + use bitcoin::blockdata::constants::genesis_block; + use bitcoin::consensus::encode; + use bitcoin::network::constants::Network; + + /// Server for handling HTTP client requests with a stock response. + struct HttpServer { + address: std::net::SocketAddr, + handler: std::thread::JoinHandle<()>, + shutdown: std::sync::Arc, + } + + /// Body of HTTP response messages. + enum MessageBody { + Empty, + Content(T), + ChunkedContent(T), + } + + impl HttpServer { + fn responding_with_ok(body: MessageBody) -> Self { + let response = match body { + MessageBody::Empty => "HTTP/1.1 200 OK\r\n\r\n".to_string(), + MessageBody::Content(body) => { + let body = body.to_string(); + format!( + "HTTP/1.1 200 OK\r\n\ + Content-Length: {}\r\n\ + \r\n\ + {}", body.len(), body) + }, + MessageBody::ChunkedContent(body) => { + let mut chuncked_body = Vec::new(); + { + use chunked_transfer::Encoder; + let mut encoder = Encoder::with_chunks_size(&mut chuncked_body, 8); + encoder.write_all(body.to_string().as_bytes()).unwrap(); + } + format!( + "HTTP/1.1 200 OK\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + {}", String::from_utf8(chuncked_body).unwrap()) + }, + }; + HttpServer::responding_with(response) + } + + fn responding_with_not_found() -> Self { + let response = "HTTP/1.1 404 Not Found\r\n\r\n".to_string(); + HttpServer::responding_with(response) + } + + fn responding_with(response: String) -> Self { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let address = listener.local_addr().unwrap(); + + let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let shutdown_signaled = std::sync::Arc::clone(&shutdown); + let handler = std::thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + stream.set_write_timeout(Some(Duration::from_secs(1))).unwrap(); + + let lines_read = std::io::BufReader::new(&stream) + .lines() + .take_while(|line| !line.as_ref().unwrap().is_empty()) + .count(); + if lines_read == 0 { return; } + + for chunk in response.as_bytes().chunks(16) { + if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) { + break; + } else { + stream.write(chunk).unwrap(); + stream.flush().unwrap(); + } + } + }); + + Self { address, handler, shutdown } + } + + fn shutdown(self) { + self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst); + self.handler.join().unwrap(); + } + + fn endpoint(&self) -> HttpEndpoint { + HttpEndpoint::insecure_host(self.address.ip().to_string()) + .with_port(self.address.port()) + } + } + + /// Parses binary data as string-encoded u32. + impl TryInto for BinaryResponse { + type Error = std::io::Error; + + fn try_into(self) -> std::io::Result { + match std::str::from_utf8(&self.0) { + Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)), + Ok(s) => match u32::from_str_radix(s, 10) { + Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)), + Ok(n) => Ok(n), + } + } + } + } + + /// Converts a JSON value into u64. + impl TryInto for JsonResponse { + type Error = std::io::Error; + + fn try_into(self) -> std::io::Result { + match self.0.as_u64() { + None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "not a number")), + Some(n) => Ok(n), + } + } + } + + /// Converts from `BlockHeaderData` into a `GetHeaderResponse` JSON value. + impl From for serde_json::Value { + fn from(data: BlockHeaderData) -> Self { + let BlockHeaderData { chainwork, height, header } = data; + serde_json::json!({ + "chainwork": chainwork.to_string()["0x".len()..], + "height": height, + "version": header.version, + "merkleroot": header.merkle_root.to_hex(), + "time": header.time, + "nonce": header.nonce, + "bits": header.bits.to_hex(), + "previousblockhash": header.prev_blockhash.to_hex(), + }) + } + } + + #[test] + fn connect_to_unresolvable_host() { + match HttpClient::connect(("example.invalid", 80)) { + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other), + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn connect_with_no_socket_address() { + match HttpClient::connect(&vec![][..]) { + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput), + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn connect_with_unknown_server() { + match HttpClient::connect(("::", 80)) { + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::ConnectionRefused), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn connect_with_valid_endpoint() { + let server = HttpServer::responding_with_ok::(MessageBody::Empty); + + match HttpClient::connect(&server.endpoint()) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(_) => {}, + } + } + + #[tokio::test] + async fn read_empty_message() { + let server = HttpServer::responding_with("".to_string()); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + drop(server); + match client.get::("/foo", "foo.com").await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "no status line"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn read_incomplete_message() { + let server = HttpServer::responding_with("HTTP/1.1 200 OK".to_string()); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + drop(server); + match client.get::("/foo", "foo.com").await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "unexpected eof"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn read_too_large_message_headers() { + let response = format!( + "HTTP/1.1 302 Found\r\n\ + Location: {}\r\n\ + \r\n", "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE)); + let server = HttpServer::responding_with(response); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + match client.get::("/foo", "foo.com").await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "unexpected eof"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn read_too_large_message_body() { + let body = "Z".repeat(MAX_HTTP_MESSAGE_BODY_SIZE + 1); + let server = HttpServer::responding_with_ok::(MessageBody::Content(body)); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + match client.get::("/foo", "foo.com").await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "out of range"); + }, + Ok(_) => panic!("Expected error"), + } + server.shutdown(); + } + + #[tokio::test] + async fn read_message_with_unsupported_transfer_coding() { + let response = String::from( + "HTTP/1.1 200 OK\r\n\ + Transfer-Encoding: gzip\r\n\ + \r\n\ + foobar"); + let server = HttpServer::responding_with(response); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + match client.get::("/foo", "foo.com").await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput); + assert_eq!(e.get_ref().unwrap().to_string(), "unsupported transfer coding"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn read_empty_message_body() { + let server = HttpServer::responding_with_ok::(MessageBody::Empty); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + match client.get::("/foo", "foo.com").await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(bytes) => assert_eq!(bytes.0, Vec::::new()), + } + } + + #[tokio::test] + async fn read_message_body_with_length() { + let body = "foo bar baz qux".repeat(32); + let content = MessageBody::Content(body.clone()); + let server = HttpServer::responding_with_ok::(content); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + match client.get::("/foo", "foo.com").await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()), + } + } + + #[tokio::test] + async fn read_chunked_message_body() { + let body = "foo bar baz qux".repeat(32); + let chunked_content = MessageBody::ChunkedContent(body.clone()); + let server = HttpServer::responding_with_ok::(chunked_content); + + let mut client = HttpClient::connect(&server.endpoint()).unwrap(); + match client.get::("/foo", "foo.com").await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()), + } + } + + #[tokio::test] + async fn request_unknown_resource() { + let server = HttpServer::responding_with_not_found(); + let client = RESTClient::new(server.endpoint()); + + match client.request_resource::("/").await { + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn request_malformed_resource() { + let server = HttpServer::responding_with_ok(MessageBody::Content("foo")); + let client = RESTClient::new(server.endpoint()); + + match client.request_resource::("/").await { + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidData), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn request_valid_resource() { + let server = HttpServer::responding_with_ok(MessageBody::Content(42)); + let client = RESTClient::new(server.endpoint()); + + match client.request_resource::("/").await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(n) => assert_eq!(n, 42), + } + } + + #[tokio::test] + async fn call_method_returning_unknown_response() { + let server = HttpServer::responding_with_not_found(); + let client = RPCClient::new("credentials", server.endpoint()); + + match client.call_method::("getblockcount", &[]).await { + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn call_method_returning_malfomred_response() { + let response = serde_json::json!("foo"); + let server = HttpServer::responding_with_ok(MessageBody::Content(response)); + let client = RPCClient::new("credentials", server.endpoint()); + + match client.call_method::("getblockcount", &[]).await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn call_method_returning_error() { + let response = serde_json::json!({ + "error": { "code": -8, "message": "invalid parameter" }, + }); + let server = HttpServer::responding_with_ok(MessageBody::Content(response)); + let client = RPCClient::new("credentials", server.endpoint()); + + let invalid_block_hash = serde_json::json!("foo"); + match client.call_method::("getblock", &[invalid_block_hash]).await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::Other); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid parameter"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn call_method_returning_missing_result() { + let response = serde_json::json!({ "result": null }); + let server = HttpServer::responding_with_ok(MessageBody::Content(response)); + let client = RPCClient::new("credentials", server.endpoint()); + + match client.call_method::("getblockcount", &[]).await { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON result"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn call_method_returning_valid_result() { + let response = serde_json::json!({ "result": 654470 }); + let server = HttpServer::responding_with_ok(MessageBody::Content(response)); + let client = RPCClient::new("credentials", server.endpoint()); + + match client.call_method::("getblockcount", &[]).await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(count) => assert_eq!(count, 654470), + } + } + + #[test] + fn from_bytes_into_binary_response() { + let bytes = b"foo"; + match BinaryResponse::try_from(bytes.to_vec()) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(response) => assert_eq!(&response.0, bytes), + } + } + + #[test] + fn from_invalid_bytes_into_json_response() { + let json = serde_json::json!({ "result": 42 }); + match JsonResponse::try_from(json.to_string().as_bytes()[..5].to_vec()) { + Err(_) => {}, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn from_valid_bytes_into_json_response() { + let json = serde_json::json!({ "result": 42 }); + match JsonResponse::try_from(json.to_string().as_bytes().to_vec()) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(response) => assert_eq!(response.0, json), + } + } + + #[test] + fn into_block_header_from_json_response_with_unexpected_type() { + let response = JsonResponse(serde_json::json!(42)); + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "unexpected JSON type"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_header_from_json_response_with_unexpected_header_type() { + let response = JsonResponse(serde_json::json!([42])); + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_header_from_json_response_with_invalid_header_response() { + let block = genesis_block(Network::Bitcoin); + let mut response = JsonResponse(BlockHeaderData { + chainwork: block.header.work(), + height: 0, + header: block.header + }.into()); + response.0["chainwork"].take(); + + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid header response"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_header_from_json_response_with_invalid_header_data() { + let block = genesis_block(Network::Bitcoin); + let mut response = JsonResponse(BlockHeaderData { + chainwork: block.header.work(), + height: 0, + header: block.header + }.into()); + response.0["chainwork"] = serde_json::json!("foobar"); + + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid header data"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_header_from_json_response_with_valid_header() { + let block = genesis_block(Network::Bitcoin); + let response = JsonResponse(BlockHeaderData { + chainwork: block.header.work(), + height: 0, + header: block.header + }.into()); + + match TryInto::::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(data) => { + assert_eq!(data.chainwork, block.header.work()); + assert_eq!(data.height, 0); + assert_eq!(data.header, block.header); + }, + } + } + + #[test] + fn into_block_header_from_json_response_with_valid_header_array() { + let genesis_block = genesis_block(Network::Bitcoin); + let best_block_header = BlockHeader { + prev_blockhash: genesis_block.block_hash(), + ..genesis_block.header + }; + let chainwork = genesis_block.header.work() + best_block_header.work(); + let response = JsonResponse(serde_json::json!([ + serde_json::Value::from(BlockHeaderData { + chainwork, height: 1, header: best_block_header, + }), + serde_json::Value::from(BlockHeaderData { + chainwork: genesis_block.header.work(), height: 0, header: genesis_block.header, + }), + ])); + + match TryInto::::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(data) => { + assert_eq!(data.chainwork, chainwork); + assert_eq!(data.height, 1); + assert_eq!(data.header, best_block_header); + }, + } + } + + #[test] + fn into_block_header_from_json_response_without_previous_block_hash() { + let block = genesis_block(Network::Bitcoin); + let mut response = JsonResponse(BlockHeaderData { + chainwork: block.header.work(), + height: 0, + header: block.header + }.into()); + response.0.as_object_mut().unwrap().remove("previousblockhash"); + + match TryInto::::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(BlockHeaderData { chainwork: _, height: _, header }) => { + assert_eq!(header, block.header); + }, + } + } + + #[test] + fn into_block_from_invalid_binary_response() { + let response = BinaryResponse(b"foo".to_vec()); + match TryInto::::try_into(response) { + Err(_) => {}, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_from_valid_binary_response() { + let genesis_block = genesis_block(Network::Bitcoin); + let response = BinaryResponse(encode::serialize(&genesis_block)); + match TryInto::::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(block) => assert_eq!(block, genesis_block), + } + } + + #[test] + fn into_block_from_json_response_with_unexpected_type() { + let response = JsonResponse(serde_json::json!({ "result": "foo" })); + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON string"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_from_json_response_with_invalid_hex_data() { + let response = JsonResponse(serde_json::json!("foobar")); + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid hex data"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_from_json_response_with_invalid_block_data() { + let response = JsonResponse(serde_json::json!("abcd")); + match TryInto::::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid block data"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_from_json_response_with_valid_block_data() { + let genesis_block = genesis_block(Network::Bitcoin); + let response = JsonResponse(serde_json::json!(encode::serialize_hex(&genesis_block))); + match TryInto::::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(block) => assert_eq!(block, genesis_block), + } + } + + #[test] + fn into_block_hash_from_json_response_with_unexpected_type() { + let response = JsonResponse(serde_json::json!("foo")); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_hash_from_json_response_with_unexpected_bestblockhash_type() { + let response = JsonResponse(serde_json::json!({ "bestblockhash": 42 })); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON string"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_hash_from_json_response_with_invalid_hex_data() { + let response = JsonResponse(serde_json::json!({ "bestblockhash": "foobar"} )); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid hex data"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_hash_from_json_response_without_height() { + let block = genesis_block(Network::Bitcoin); + let response = JsonResponse(serde_json::json!({ + "bestblockhash": block.block_hash().to_hex(), + })); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((hash, height)) => { + assert_eq!(hash, block.block_hash()); + assert!(height.is_none()); + }, + } + } + + #[test] + fn into_block_hash_from_json_response_with_unexpected_blocks_type() { + let block = genesis_block(Network::Bitcoin); + let response = JsonResponse(serde_json::json!({ + "bestblockhash": block.block_hash().to_hex(), + "blocks": "foo", + })); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON number"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_hash_from_json_response_with_invalid_height() { + let block = genesis_block(Network::Bitcoin); + let response = JsonResponse(serde_json::json!({ + "bestblockhash": block.block_hash().to_hex(), + "blocks": u64::MAX, + })); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert_eq!(e.get_ref().unwrap().to_string(), "invalid height"); + }, + Ok(_) => panic!("Expected error"), + } + } + + #[test] + fn into_block_hash_from_json_response_with_height() { + let block = genesis_block(Network::Bitcoin); + let response = JsonResponse(serde_json::json!({ + "bestblockhash": block.block_hash().to_hex(), + "blocks": 1, + })); + match TryInto::<(BlockHash, Option)>::try_into(response) { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((hash, height)) => { + assert_eq!(hash, block.block_hash()); + assert_eq!(height.unwrap(), 1); + }, + } + } +} diff --git a/lightning-block-sync/src/http_endpoint.rs b/lightning-block-sync/src/http_endpoint.rs new file mode 100644 index 00000000000..d773b12ebeb --- /dev/null +++ b/lightning-block-sync/src/http_endpoint.rs @@ -0,0 +1,139 @@ +/// Endpoint for interacting with an HTTP-based API. +#[derive(Debug)] +pub struct HttpEndpoint { + scheme: Scheme, + host: String, + port: Option, + path: String, +} + +/// URI scheme compatible with an HTTP endpoint. +#[derive(Debug)] +pub enum Scheme { + HTTP, + HTTPS, +} + +impl HttpEndpoint { + /// Creates an endpoint using the HTTP scheme. + pub fn insecure_host(host: String) -> Self { + Self { + scheme: Scheme::HTTP, + host, + port: None, + path: String::from("/"), + } + } + + /// Creates an endpoint using the HTTPS scheme. + pub fn secure_host(host: String) -> Self { + Self { + scheme: Scheme::HTTPS, + host, + port: None, + path: String::from("/"), + } + } + + /// Specifies a port to use with the endpoint. + pub fn with_port(mut self, port: u16) -> Self { + self.port = Some(port); + self + } + + /// Specifies a path to use with the endpoint. + pub fn with_path(mut self, path: String) -> Self { + self.path = path; + self + } + + /// Returns the endpoint host. + pub fn host(&self) -> &str { + &self.host + } + + /// Returns the endpoint port. + pub fn port(&self) -> u16 { + match self.port { + None => match self.scheme { + Scheme::HTTP => 80, + Scheme::HTTPS => 443, + }, + Some(port) => port, + } + } + + /// Returns the endpoint path. + pub fn path(&self) -> &str { + &self.path + } +} + +impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint { + type Iter = <(&'a str, u16) as std::net::ToSocketAddrs>::Iter; + + fn to_socket_addrs(&self) -> std::io::Result { + (self.host(), self.port()).to_socket_addrs() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn to_insecure_host() { + let endpoint = HttpEndpoint::insecure_host("foo.com".into()); + assert_eq!(endpoint.host(), "foo.com"); + assert_eq!(endpoint.port(), 80); + } + + #[test] + fn to_secure_host() { + let endpoint = HttpEndpoint::secure_host("foo.com".into()); + assert_eq!(endpoint.host(), "foo.com"); + assert_eq!(endpoint.port(), 443); + } + + #[test] + fn with_custom_port() { + let endpoint = HttpEndpoint::insecure_host("foo.com".into()).with_port(8080); + assert_eq!(endpoint.host(), "foo.com"); + assert_eq!(endpoint.port(), 8080); + } + + #[test] + fn with_uri_path() { + let endpoint = HttpEndpoint::insecure_host("foo.com".into()).with_path("/path".into()); + assert_eq!(endpoint.host(), "foo.com"); + assert_eq!(endpoint.path(), "/path"); + } + + #[test] + fn without_uri_path() { + let endpoint = HttpEndpoint::insecure_host("foo.com".into()); + assert_eq!(endpoint.host(), "foo.com"); + assert_eq!(endpoint.path(), "/"); + } + + #[test] + fn convert_to_socket_addrs() { + let endpoint = HttpEndpoint::insecure_host("foo.com".into()); + let host = endpoint.host(); + let port = endpoint.port(); + + use std::net::ToSocketAddrs; + match (&endpoint).to_socket_addrs() { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(mut socket_addrs) => { + match socket_addrs.next() { + None => panic!("Expected socket address"), + Some(addr) => { + assert_eq!(addr, (host, port).to_socket_addrs().unwrap().next().unwrap()); + assert!(socket_addrs.next().is_none()); + } + } + } + } + } +} diff --git a/lightning-block-sync/src/lib.rs b/lightning-block-sync/src/lib.rs new file mode 100644 index 00000000000..461b9f4148f --- /dev/null +++ b/lightning-block-sync/src/lib.rs @@ -0,0 +1,1161 @@ +//! An implementation of a simple SPV client which can interrogate abstract block sources to keep +//! lightning objects on the best chain. +//! +//! With feature `rpc-client` we provide a client which can fetch blocks from Bitcoin Core's RPC +//! interface. +//! +//! With feature `rest-client` we provide a client which can fetch blocks from Bitcoin Core's REST +//! interface. +//! +//! Both provided clients support either blocking TCP reads from std::net::TcpStream or, with +//! feature `tokio`, tokio::net::TcpStream inside a Tokio runtime. + +#[cfg(any(feature = "rest-client", feature = "rpc-client"))] +pub mod http_clients; + +#[cfg(any(feature = "rest-client", feature = "rpc-client"))] +pub mod http_endpoint; + +pub mod poller; + +#[cfg(test)] +mod test_utils; + +#[cfg(any(feature = "rest-client", feature = "rpc-client"))] +mod utils; + +use bitcoin::blockdata::block::{Block, BlockHeader}; +use bitcoin::hash_types::BlockHash; +use bitcoin::hashes::hex::ToHex; +use bitcoin::network::constants::Network; +use bitcoin::util::uint::Uint256; + +use lightning::chain; +use lightning::chain::{chaininterface, keysinterface}; +use lightning::chain::channelmonitor::ChannelMonitor; +use lightning::ln::channelmanager::SimpleArcChannelManager; +use lightning::util::logger; + +use std::future::Future; +use std::pin::Pin; +use std::vec::Vec; + +#[derive(Clone, Copy, Debug, PartialEq)] +/// A block header and some associated data. This information should be available from most block +/// sources (and, notably, is available in Bitcoin Core's RPC and REST interfaces). +pub struct BlockHeaderData { + /// The total chain work, in expected number of double-SHA256 hashes required to build a chain + /// of equivalent weight + pub chainwork: Uint256, + /// The block height, with the genesis block heigh set to 0 + pub height: u32, + /// The block header itself + pub header: BlockHeader +} + +/// Result type for `BlockSource` requests. +type BlockSourceResult = Result; + +/// Result type for asynchronous `BlockSource` requests. +/// +/// TODO: Replace with BlockSourceResult once async trait functions are supported. For details, see: +/// https://areweasyncyet.rs. +type AsyncBlockSourceResult<'a, T> = Pin> + 'a + Send>>; + +/// Error type for requests made to a `BlockSource`. +/// +/// Transient errors may be resolved when re-polling, but no attempt will be made to re-poll on +/// persistent errors. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum BlockSourceError { + /// Indicates an error that won't resolve when retrying a request (e.g., invalid data). + Persistent, + /// Indicates an error that may resolve when retrying a request (e.g., unresponsive). + Transient, +} + +/// Abstract type for a source of block header and block data. +pub trait BlockSource : Sync + Send { + /// Gets the header for a given hash. The height the header should be at is provided, though + /// note that you must return either the header with the requested hash, or an Err, not a + /// different header with the same eight. + /// + /// For sources which cannot find headers based on the hash, returning Transient when + /// height_hint is None is fine, though get_best_block() should never return a None for height + /// on the same source. Such a source should never be used in init_sync_listener as it + /// doesn't have any initial height information. + fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, height_hint: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData>; + + /// Gets the block for a given hash. BlockSources may be headers-only, in which case they + /// should always return Err(BlockSourceError::Transient) here. + fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block>; + + /// Gets the best block hash and, optionally, its height. + /// Including the height doesn't impact the chain-scannling algorithm, but it is passed to + /// get_header() which may allow some BlockSources to more effeciently find the target header. + fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<(BlockHash, Option)>; +} + +/// The `Poll` trait defines behavior for polling block sources for a chain tip and retrieving +/// related chain data. It serves as an adapter for `BlockSource`. +pub trait Poll { + /// Returns a chain tip in terms of its relationship to the provided chain tip. + fn poll_chain_tip<'a>(&'a mut self, best_known_chain_tip: ValidatedBlockHeader) -> + AsyncBlockSourceResult<'a, ChainTip>; + + /// Returns the header that preceded the given header in the chain. + fn look_up_previous_header<'a>(&'a mut self, header: &'a ValidatedBlockHeader) -> + AsyncBlockSourceResult<'a, ValidatedBlockHeader>; + + /// Returns the block associated with the given header. + fn fetch_block<'a>(&'a mut self, header: &'a ValidatedBlockHeader) -> + AsyncBlockSourceResult<'a, ValidatedBlock>; +} + +/// A chain tip relative to another chain tip in terms of block hash and chainwork. +#[derive(Clone, Debug, PartialEq)] +pub enum ChainTip { + /// A chain tip with the same hash as another chain's tip. + Common, + + /// A chain tip with more chainwork than another chain's tip. + Better(ValidatedBlockHeader), + + /// A chain tip with less or equal chainwork than another chain's tip. In either case, the + /// hashes of each tip will be different. + Worse(ValidatedBlockHeader), +} + +/// The `Validate` trait defines behavior for validating chain data. +trait Validate { + /// The validated data wrapper which can be dereferenced to obtain the validated data. + type T: std::ops::Deref; + + /// Validates the chain data against the given block hash and any criteria needed to ensure that + /// it is internally consistent. + fn validate(self, block_hash: BlockHash) -> BlockSourceResult; +} + +impl Validate for BlockHeaderData { + type T = ValidatedBlockHeader; + + fn validate(self, block_hash: BlockHash) -> BlockSourceResult { + self.header + .validate_pow(&self.header.target()) + .or(Err(BlockSourceError::Persistent))?; + + if self.header.block_hash() != block_hash { + return Err(BlockSourceError::Persistent); + } + + Ok(ValidatedBlockHeader { block_hash, inner: self }) + } +} + +impl Validate for Block { + type T = ValidatedBlock; + + fn validate(self, block_hash: BlockHash) -> BlockSourceResult { + if self.block_hash() != block_hash { + return Err(BlockSourceError::Persistent); + } + + if !self.check_merkle_root() || !self.check_witness_commitment() { + return Err(BlockSourceError::Persistent); + } + + Ok(ValidatedBlock { block_hash, inner: self }) + } +} + +/// A block header with validated proof of work and corresponding block hash. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct ValidatedBlockHeader { + block_hash: BlockHash, + inner: BlockHeaderData, +} + +impl std::ops::Deref for ValidatedBlockHeader { + type Target = BlockHeaderData; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl ValidatedBlockHeader { + /// Checks that the header correctly builds on previous_header - the claimed work differential + /// matches the actual PoW in child_header and the difficulty transition is possible, ie within 4x. + fn check_builds_on(&self, previous_header: &ValidatedBlockHeader, network: Network) -> BlockSourceResult<()> { + if self.header.prev_blockhash != previous_header.block_hash { + return Err(BlockSourceError::Persistent); + } + + if self.height != previous_header.height + 1 { + return Err(BlockSourceError::Persistent); + } + + let work = self.header.work(); + if self.chainwork != previous_header.chainwork + work { + return Err(BlockSourceError::Persistent); + } + + if let Network::Bitcoin = network { + if self.height % 2016 == 0 { + let previous_work = previous_header.header.work(); + if work > previous_work << 2 || work < previous_work >> 2 { + return Err(BlockSourceError::Persistent) + } + } else if self.header.bits != previous_header.header.bits { + return Err(BlockSourceError::Persistent) + } + } + + Ok(()) + } +} + +/// A block with validated data against its transaction list and corresponding block hash. +pub struct ValidatedBlock { + block_hash: BlockHash, + inner: Block, +} + +impl std::ops::Deref for ValidatedBlock { + type Target = Block; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +/// Notifies listeners of blocks that have been connected or disconnected from the chain. +struct ChainNotifier { + header_cache: HeaderCache, +} + +enum ForkStep { + ForkPoint(ValidatedBlockHeader), + DisconnectBlock(ValidatedBlockHeader), + ConnectBlock(ValidatedBlockHeader), +} + +impl ChainNotifier { + /// Finds the fork point between new_header and old_header, disconnecting blocks from old_header + /// to get to that point and then connecting blocks until new_header. + /// + /// Validates headers along the transition path, but doesn't fetch blocks until the chain is + /// disconnected to the fork point. Thus, this may return an Err() that includes where the tip + /// ended up which may not be new_header. Note that iff the returned Err has a BlockHeaderData, + /// the header transition from old_header to new_header is valid. + async fn sync_listener(&mut self, new_header: ValidatedBlockHeader, old_header: &ValidatedBlockHeader, chain_poller: &mut P, chain_listener: &mut CL) -> + Result<(), (BlockSourceError, Option)> + { + let mut events = self.find_fork(new_header, old_header, chain_poller).await.map_err(|e| (e, None))?; + + let mut last_disconnect_tip = None; + let mut new_tip = None; + for event in events.iter() { + match &event { + &ForkStep::DisconnectBlock(ref header) => { + let block_hash = header.header.block_hash(); + println!("Disconnecting block {}", block_hash); + if let Some(cached_head) = self.header_cache.remove(&block_hash) { + assert_eq!(cached_head, *header); + } + chain_listener.block_disconnected(&header.header, header.height); + last_disconnect_tip = Some(header.header.prev_blockhash); + }, + &ForkStep::ForkPoint(ref header) => { + new_tip = Some(*header); + }, + _ => {}, + } + } + + // If we disconnected any blocks, we should have new tip data available. If we didn't disconnect + // any blocks we shouldn't have set a ForkPoint as there is no fork. + assert_eq!(last_disconnect_tip.is_some(), new_tip.is_some()); + if let &Some(ref tip_header) = &new_tip { + debug_assert_eq!(tip_header.header.block_hash(), *last_disconnect_tip.as_ref().unwrap()); + } else { + // Set new_tip to indicate that we got a valid header chain we wanted to connect to, but + // failed + new_tip = Some(*old_header); + } + + for event in events.drain(..).rev() { + if let ForkStep::ConnectBlock(header) = event { + let block = chain_poller + .fetch_block(&header).await + .or_else(|e| Err((e, new_tip)))?; + debug_assert_eq!(block.block_hash, header.block_hash); + + println!("Connecting block {}", header.block_hash.to_hex()); + chain_listener.block_connected(&block, header.height); + self.header_cache.insert(header.block_hash, header); + new_tip = Some(header); + } + } + Ok(()) + } + + /// Walks backwards from `current_header` and `prev_header`, finding the common ancestor. Returns + /// the steps needed to produce the chain with `current_header` as its tip from the chain with + /// `prev_header` as its tip. There is no ordering guarantee between different ForkStep types, but + /// `DisconnectBlock` and `ConnectBlock` are each returned in height-descending order. + async fn find_fork(&self, current_header: ValidatedBlockHeader, prev_header: &ValidatedBlockHeader, chain_poller: &mut P) -> BlockSourceResult> { + let mut steps = Vec::new(); + let mut current = current_header; + let mut previous = *prev_header; + loop { + // Found the parent block. + if current.height == previous.height + 1 && + current.header.prev_blockhash == previous.block_hash { + steps.push(ForkStep::ConnectBlock(current)); + break; + } + + // Found a chain fork. + if current.header.prev_blockhash == previous.header.prev_blockhash { + let fork_point = self.look_up_previous_header(chain_poller, &previous).await?; + steps.push(ForkStep::DisconnectBlock(previous)); + steps.push(ForkStep::ConnectBlock(current)); + steps.push(ForkStep::ForkPoint(fork_point)); + break; + } + + // Walk back the chain, finding blocks needed to connect and disconnect. Only walk back the + // header with the greater height, or both if equal heights. + let current_height = current.height; + let previous_height = previous.height; + if current_height <= previous_height { + steps.push(ForkStep::DisconnectBlock(previous)); + previous = self.look_up_previous_header(chain_poller, &previous).await?; + } + if current_height >= previous_height { + steps.push(ForkStep::ConnectBlock(current)); + current = self.look_up_previous_header(chain_poller, ¤t).await?; + } + } + + Ok(steps) + } + + async fn look_up_previous_header(&self, chain_poller: &mut P, header: &ValidatedBlockHeader) -> + BlockSourceResult + { + match self.header_cache.get(&header.header.prev_blockhash) { + Some(prev_header) => Ok(*prev_header), + None => chain_poller.look_up_previous_header(header).await, + } + } +} + +/// Adaptor used for notifying when blocks have been connected or disconnected from the chain. +/// Useful for replaying chain data upon deserialization. +pub trait ChainListener { + fn block_connected(&mut self, block: &Block, height: u32); + fn block_disconnected(&mut self, header: &BlockHeader, height: u32); +} + +impl ChainListener for &SimpleArcChannelManager + where M: chain::Watch, + B: chaininterface::BroadcasterInterface, + F: chaininterface::FeeEstimator, + L: logger::Logger { + fn block_connected(&mut self, block: &Block, height: u32) { + let txdata: Vec<_> = block.txdata.iter().enumerate().collect(); + (**self).block_connected(&block.header, &txdata, height); + } + fn block_disconnected(&mut self, header: &BlockHeader, _height: u32) { + (**self).block_disconnected(header); + } +} + +impl ChainListener for (&mut ChannelMonitor, &B, &F, &L) + where CS: keysinterface::ChannelKeys, + B: chaininterface::BroadcasterInterface, + F: chaininterface::FeeEstimator, + L: logger::Logger { + fn block_connected(&mut self, block: &Block, height: u32) { + let txdata: Vec<_> = block.txdata.iter().enumerate().collect(); + self.0.block_connected(&block.header, &txdata, height, self.1, self.2, self.3); + } + fn block_disconnected(&mut self, header: &BlockHeader, height: u32) { + self.0.block_disconnected(header, height, self.1, self.2, self.3); + } +} + +/// Do a one-time sync of a chain listener from a single *trusted* block source bringing its view +/// of the latest chain tip from old_block to new_block. This is useful on startup when you need +/// to bring each ChannelMonitor, as well as the overall ChannelManager, into sync with each other. +/// +/// Once you have them all at the same block, you should switch to using MicroSPVClient. +pub async fn init_sync_listener(new_block: BlockHash, old_block: BlockHash, block_source: &mut B, network: Network, chain_listener: &mut CL) { + if &old_block[..] == &[0; 32] { return; } + if old_block == new_block { return; } + + let new_header = block_source + .get_header(&new_block, None).await.unwrap() + .validate(new_block).unwrap(); + let old_header = block_source + .get_header(&old_block, None).await.unwrap() + .validate(old_block).unwrap(); + let mut chain_poller = poller::ChainPoller::new(block_source, network); + let mut chain_notifier = ChainNotifier { header_cache: HeaderCache::new() }; + chain_notifier.sync_listener(new_header, &old_header, &mut chain_poller, chain_listener).await.unwrap(); +} + +/// Unbounded cache of header data keyed by block hash. +pub(crate) type HeaderCache = std::collections::HashMap; + +/// A lightweight client for keeping a listener in sync with the chain, which is polled using one +/// one or more block sources. +/// +/// This implements a pretty bare-bones SPV client, checking all relevant commitments and finding +/// the heaviest chain, but not storing the full header chain, leading to some important +/// limitations. +/// +/// TODO: Update comment to reflect this is now the responsibility of chain_poller. +/// While we never check full difficulty transition logic, the mainnet option enables checking that +/// difficulty transitions only happen every two weeks and never shift difficulty more than 4x in +/// either direction, which is sufficient to prevent most minority hashrate attacks. +/// +/// TODO: Update comment as headers are removed from cache when blocks are disconnected. +/// We cache any headers which we connect until every block source is in agreement on the best tip. +/// This prevents one block source from being able to orphan us on a fork of its own creation by +/// not responding to requests for old headers on that fork. However, if one block source is +/// unreachable this may result in our memory usage growing in accordance with the chain. +pub struct MicroSPVClient { + chain_tip: ValidatedBlockHeader, + chain_poller: P, + chain_notifier: ChainNotifier, + chain_listener: CL, +} + +impl MicroSPVClient { + /// Creates a new `MicroSPVClient` with a chain poller for polling one or more block sources and + /// a chain listener for receiving updates of the new chain tip. + /// + /// At least one of the polled `BlockSource`s must provide the necessary headers to disconnect + /// from the given `chain_tip` back to its common ancestor with the best chain assuming that its + /// height, hash, and chainwork are correct. + /// + /// `backup_block_sources` are never queried unless we learned, via some `block_sources` source + /// that there exists a better, valid header chain but we failed to fetch the blocks. This is + /// useful when you have a block source which is more censorship-resistant than others but + /// which only provides headers. In this case, we can use such source(s) to learn of a censorship + /// attack without giving up privacy by querying a privacy-losing block sources. + pub fn new(chain_tip: ValidatedBlockHeader, chain_poller: P, chain_listener: CL) -> Self { + let header_cache = HeaderCache::new(); + let chain_notifier = ChainNotifier { header_cache }; + Self { chain_tip, chain_poller, chain_notifier, chain_listener } + } + + /// Polls for the best tip and updates the chain listener with any connected or disconnected + /// blocks accordingly. + /// + /// Returns the best polled chain tip relative to the previous best known tip and whether any + /// blocks were indeed connected or disconnected. + pub async fn poll_best_tip(&mut self) -> BlockSourceResult<(ChainTip, bool)> { + let chain_tip = self.chain_poller.poll_chain_tip(self.chain_tip).await?; + let blocks_connected = match chain_tip { + ChainTip::Common => false, + ChainTip::Better(chain_tip) => { + debug_assert_ne!(chain_tip.block_hash, self.chain_tip.block_hash); + debug_assert!(chain_tip.chainwork > self.chain_tip.chainwork); + self.update_chain_tip(chain_tip).await + }, + ChainTip::Worse(chain_tip) => { + debug_assert_ne!(chain_tip.block_hash, self.chain_tip.block_hash); + debug_assert!(chain_tip.chainwork <= self.chain_tip.chainwork); + false + }, + }; + Ok((chain_tip, blocks_connected)) + } + + /// Updates the chain tip, syncing the chain listener with any connected or disconnected + /// blocks. Returns whether there were any such blocks. + async fn update_chain_tip(&mut self, best_chain_tip: ValidatedBlockHeader) -> bool { + match self.chain_notifier.sync_listener(best_chain_tip, &self.chain_tip, &mut self.chain_poller, &mut self.chain_listener).await { + Ok(_) => { + self.chain_tip = best_chain_tip; + true + }, + Err((_, Some(chain_tip))) if chain_tip.block_hash != self.chain_tip.block_hash => { + self.chain_tip = chain_tip; + true + }, + Err(_) => false, + } + } +} + +#[cfg(test)] +mod spv_client_tests { + use crate::test_utils::{Blockchain, NullChainListener}; + use super::*; + + use bitcoin::network::constants::Network; + + #[tokio::test] + async fn poll_from_chain_without_headers() { + let mut chain = Blockchain::default().with_height(3).without_headers(); + let best_tip = chain.at_height(1); + + let poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + let mut client = MicroSPVClient::new(best_tip, poller, NullChainListener {}); + match client.poll_best_tip().await { + Err(e) => assert_eq!(e, BlockSourceError::Persistent), + Ok(_) => panic!("Expected error"), + } + assert_eq!(client.chain_tip, best_tip); + } + + #[tokio::test] + async fn poll_from_chain_with_common_tip() { + let mut chain = Blockchain::default().with_height(3); + let common_tip = chain.tip(); + + let poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + let mut client = MicroSPVClient::new(common_tip, poller, NullChainListener {}); + match client.poll_best_tip().await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((chain_tip, blocks_connected)) => { + assert_eq!(chain_tip, ChainTip::Common); + assert!(!blocks_connected); + }, + } + assert_eq!(client.chain_tip, common_tip); + } + + #[tokio::test] + async fn poll_from_chain_with_better_tip() { + let mut chain = Blockchain::default().with_height(3); + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + + let poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + let mut client = MicroSPVClient::new(old_tip, poller, NullChainListener {}); + match client.poll_best_tip().await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((chain_tip, blocks_connected)) => { + assert_eq!(chain_tip, ChainTip::Better(new_tip)); + assert!(blocks_connected); + }, + } + assert_eq!(client.chain_tip, new_tip); + } + + #[tokio::test] + async fn poll_from_chain_with_better_tip_and_without_any_new_blocks() { + let mut chain = Blockchain::default().with_height(3).without_blocks(2..); + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + + let poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + let mut client = MicroSPVClient::new(old_tip, poller, NullChainListener {}); + match client.poll_best_tip().await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((chain_tip, blocks_connected)) => { + assert_eq!(chain_tip, ChainTip::Better(new_tip)); + assert!(!blocks_connected); + }, + } + assert_eq!(client.chain_tip, old_tip); + } + + #[tokio::test] + async fn poll_from_chain_with_better_tip_and_without_some_new_blocks() { + let mut chain = Blockchain::default().with_height(3).without_blocks(3..); + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + + let poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + let mut client = MicroSPVClient::new(old_tip, poller, NullChainListener {}); + match client.poll_best_tip().await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((chain_tip, blocks_connected)) => { + assert_eq!(chain_tip, ChainTip::Better(new_tip)); + assert!(blocks_connected); + }, + } + assert_eq!(client.chain_tip, chain.at_height(2)); + } + + #[tokio::test] + async fn poll_from_chain_with_worse_tip() { + let mut chain = Blockchain::default().with_height(3); + let best_tip = chain.tip(); + chain.disconnect_tip(); + let worse_tip = chain.tip(); + + let poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + let mut client = MicroSPVClient::new(best_tip, poller, NullChainListener {}); + match client.poll_best_tip().await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok((chain_tip, blocks_connected)) => { + assert_eq!(chain_tip, ChainTip::Worse(worse_tip)); + assert!(!blocks_connected); + }, + } + assert_eq!(client.chain_tip, best_tip); + } +} + +#[cfg(test)] +mod chain_notifier_tests { + use crate::test_utils::{Blockchain, MockChainListener}; + use super::*; + + use bitcoin::network::constants::Network; + + #[tokio::test] + async fn sync_from_same_chain() { + let mut chain = Blockchain::default().with_height(3); + + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + let mut listener = MockChainListener::new() + .expect_block_connected(*chain.at_height(2)) + .expect_block_connected(*new_tip); + let mut notifier = ChainNotifier { header_cache: chain.header_cache(0..=1) }; + let mut poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((e, _)) => panic!("Unexpected error: {:?}", e), + Ok(_) => {}, + } + } + + #[tokio::test] + async fn sync_from_different_chains() { + let mut test_chain = Blockchain::with_network(Network::Testnet).with_height(1); + let main_chain = Blockchain::with_network(Network::Bitcoin).with_height(1); + + let new_tip = test_chain.tip(); + let old_tip = main_chain.tip(); + let mut listener = MockChainListener::new(); + let mut notifier = ChainNotifier { header_cache: main_chain.header_cache(0..=1) }; + let mut poller = poller::ChainPoller::new(&mut test_chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((e, _)) => assert_eq!(e, BlockSourceError::Persistent), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn sync_from_equal_length_fork() { + let main_chain = Blockchain::default().with_height(2); + let mut fork_chain = main_chain.fork_at_height(1); + + let new_tip = fork_chain.tip(); + let old_tip = main_chain.tip(); + let mut listener = MockChainListener::new() + .expect_block_disconnected(*old_tip) + .expect_block_connected(*new_tip); + let mut notifier = ChainNotifier { header_cache: main_chain.header_cache(0..=2) }; + let mut poller = poller::ChainPoller::new(&mut fork_chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((e, _)) => panic!("Unexpected error: {:?}", e), + Ok(_) => {}, + } + } + + #[tokio::test] + async fn sync_from_shorter_fork() { + let main_chain = Blockchain::default().with_height(3); + let mut fork_chain = main_chain.fork_at_height(1); + fork_chain.disconnect_tip(); + + let new_tip = fork_chain.tip(); + let old_tip = main_chain.tip(); + let mut listener = MockChainListener::new() + .expect_block_disconnected(*old_tip) + .expect_block_disconnected(*main_chain.at_height(2)) + .expect_block_connected(*new_tip); + let mut notifier = ChainNotifier { header_cache: main_chain.header_cache(0..=3) }; + let mut poller = poller::ChainPoller::new(&mut fork_chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((e, _)) => panic!("Unexpected error: {:?}", e), + Ok(_) => {}, + } + } + + #[tokio::test] + async fn sync_from_longer_fork() { + let mut main_chain = Blockchain::default().with_height(3); + let mut fork_chain = main_chain.fork_at_height(1); + main_chain.disconnect_tip(); + + let new_tip = fork_chain.tip(); + let old_tip = main_chain.tip(); + let mut listener = MockChainListener::new() + .expect_block_disconnected(*old_tip) + .expect_block_connected(*fork_chain.at_height(2)) + .expect_block_connected(*new_tip); + let mut notifier = ChainNotifier { header_cache: main_chain.header_cache(0..=2) }; + let mut poller = poller::ChainPoller::new(&mut fork_chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((e, _)) => panic!("Unexpected error: {:?}", e), + Ok(_) => {}, + } + } + + #[tokio::test] + async fn sync_from_chain_without_headers() { + let mut chain = Blockchain::default().with_height(3).without_headers(); + + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + let mut listener = MockChainListener::new(); + let mut notifier = ChainNotifier { header_cache: chain.header_cache(0..=1) }; + let mut poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((_, tip)) => assert_eq!(tip, None), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn sync_from_chain_without_any_new_blocks() { + let mut chain = Blockchain::default().with_height(3).without_blocks(2..); + + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + let mut listener = MockChainListener::new(); + let mut notifier = ChainNotifier { header_cache: chain.header_cache(0..=3) }; + let mut poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((_, tip)) => assert_eq!(tip, Some(old_tip)), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn sync_from_chain_without_some_new_blocks() { + let mut chain = Blockchain::default().with_height(3).without_blocks(3..); + + let new_tip = chain.tip(); + let old_tip = chain.at_height(1); + let mut listener = MockChainListener::new() + .expect_block_connected(*chain.at_height(2)); + let mut notifier = ChainNotifier { header_cache: chain.header_cache(0..=3) }; + let mut poller = poller::ChainPoller::new(&mut chain, Network::Testnet); + match notifier.sync_listener(new_tip, &old_tip, &mut poller, &mut listener).await { + Err((_, tip)) => assert_eq!(tip, Some(chain.at_height(2))), + Ok(_) => panic!("Expected error"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bitcoin::blockdata::block::{Block, BlockHeader}; + use bitcoin::util::uint::Uint256; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + struct TestChainListener { + blocks_connected: Mutex>, + blocks_disconnected: Mutex>, + } + impl ChainListener for Arc { + fn block_connected(&mut self, block: &Block, height: u32) { + self.blocks_connected.lock().unwrap().push((block.header.block_hash(), height)); + } + fn block_disconnected(&mut self, header: &BlockHeader, height: u32) { + self.blocks_disconnected.lock().unwrap().push((header.block_hash(), height)); + } + } + + #[derive(Clone)] + struct BlockData { + block: Block, + chainwork: Uint256, + height: u32, + } + struct Blockchain { + blocks: Mutex>, + best_block: Mutex<(BlockHash, Option)>, + headers_only: bool, + disallowed: Mutex, + } + impl BlockSource for &Blockchain { + fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, height_hint: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + if *self.disallowed.lock().unwrap() { unreachable!(); } + Box::pin(async move { + match self.blocks.lock().unwrap().get(header_hash) { + Some(block) => { + assert_eq!(Some(block.height), height_hint); + Ok(BlockHeaderData { + chainwork: block.chainwork, + height: block.height, + header: block.block.header.clone(), + }) + }, + None => Err(BlockSourceError::Transient), + } + }) + } + fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> { + if *self.disallowed.lock().unwrap() { unreachable!(); } + Box::pin(async move { + if self.headers_only { + Err(BlockSourceError::Transient) + } else { + match self.blocks.lock().unwrap().get(header_hash) { + Some(block) => Ok(block.block.clone()), + None => Err(BlockSourceError::Transient), + } + } + }) + } + fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { + if *self.disallowed.lock().unwrap() { unreachable!(); } + Box::pin(async move { Ok(self.best_block.lock().unwrap().clone()) }) + } + } + + #[tokio::test] + async fn simple_block_connect() { + let genesis = BlockData { + block: bitcoin::blockdata::constants::genesis_block(bitcoin::network::constants::Network::Bitcoin), + chainwork: Uint256::from_u64(0).unwrap(), + height: 0, + }; + + // Build a chain based on genesis 1a, 2a, 3a, and 4a + let block_1a = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: genesis.block.block_hash(), + merkle_root: Default::default(), time: 0, + bits: genesis.block.header.bits, + nonce: 647569994, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833).unwrap(), + height: 1 + }; + let block_1a_hash = block_1a.block.header.block_hash(); + let block_2a = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: block_1a.block.block_hash(), + merkle_root: Default::default(), time: 4, + bits: genesis.block.header.bits, + nonce: 1185103332, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833 * 2).unwrap(), + height: 2 + }; + let block_2a_hash = block_2a.block.header.block_hash(); + let block_3a = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: block_2a.block.block_hash(), + merkle_root: Default::default(), time: 6, + bits: genesis.block.header.bits, + nonce: 198739431, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833 * 3).unwrap(), + height: 3 + }; + let block_3a_hash = block_3a.block.header.block_hash(); + let block_4a = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: block_3a.block.block_hash(), + merkle_root: Default::default(), time: 0, + bits: genesis.block.header.bits, + nonce: 590371681, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833 * 4).unwrap(), + height: 4 + }; + let block_4a_hash = block_4a.block.header.block_hash(); + + // Build a second chain based on genesis 1b, 2b, and 3b + let block_1b = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: genesis.block.block_hash(), + merkle_root: Default::default(), time: 6, + bits: genesis.block.header.bits, + nonce: 1347696353, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833).unwrap(), + height: 1 + }; + let block_1b_hash = block_1b.block.header.block_hash(); + let block_2b = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: block_1b.block.block_hash(), + merkle_root: Default::default(), time: 5, + bits: genesis.block.header.bits, + nonce: 144775545, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833 * 2).unwrap(), + height: 2 + }; + let block_2b_hash = block_2b.block.header.block_hash(); + + // Build a second chain based on 3a: 4c and 5c. + let block_4c = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: block_3a.block.block_hash(), + merkle_root: Default::default(), time: 17, + bits: genesis.block.header.bits, + nonce: 316634915, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833 * 4).unwrap(), + height: 4 + }; + let block_4c_hash = block_4c.block.header.block_hash(); + let block_5c = BlockData { + block: Block { + header: BlockHeader { + version: 0, + prev_blockhash: block_4c.block.block_hash(), + merkle_root: Default::default(), time: 3, + bits: genesis.block.header.bits, + nonce: 218413871, + }, + txdata: Vec::new(), + }, + chainwork: Uint256::from_u64(4295032833 * 5).unwrap(), + height: 5 + }; + let block_5c_hash = block_5c.block.header.block_hash(); + + // Create four block sources: + // * chain_one and chain_two are general purpose block sources which we use to test reorgs, + // * headers_chain only provides headers, + // * and backup_chain is a backup which should not receive any queries (ie disallowed is + // false) until the headers_chain gets ahead of chain_one and chain_two. + let mut blocks_one = HashMap::new(); + blocks_one.insert(genesis.block.header.block_hash(), genesis.clone()); + blocks_one.insert(block_1a_hash, block_1a.clone()); + blocks_one.insert(block_1b_hash, block_1b); + blocks_one.insert(block_2b_hash, block_2b); + let chain_one = Blockchain { + blocks: Mutex::new(blocks_one), best_block: Mutex::new((block_2b_hash, Some(2))), + headers_only: false, disallowed: Mutex::new(false) + }; + + let mut blocks_two = HashMap::new(); + blocks_two.insert(genesis.block.header.block_hash(), genesis.clone()); + blocks_two.insert(block_1a_hash, block_1a.clone()); + let chain_two = Blockchain { + blocks: Mutex::new(blocks_two), best_block: Mutex::new((block_1a_hash, Some(1))), + headers_only: false, disallowed: Mutex::new(false) + }; + + let mut blocks_three = HashMap::new(); + blocks_three.insert(genesis.block.header.block_hash(), genesis.clone()); + blocks_three.insert(block_1a_hash, block_1a.clone()); + let header_chain = Blockchain { + blocks: Mutex::new(blocks_three), best_block: Mutex::new((block_1a_hash, Some(1))), + headers_only: true, disallowed: Mutex::new(false) + }; + + let mut blocks_four = HashMap::new(); + blocks_four.insert(genesis.block.header.block_hash(), genesis); + blocks_four.insert(block_1a_hash, block_1a); + blocks_four.insert(block_2a_hash, block_2a.clone()); + blocks_four.insert(block_3a_hash, block_3a.clone()); + let backup_chain = Blockchain { + blocks: Mutex::new(blocks_four), best_block: Mutex::new((block_3a_hash, Some(3))), + headers_only: false, disallowed: Mutex::new(true) + }; + + // Stand up a client at block_1a with all four sources: + let chain_listener = Arc::new(TestChainListener { + blocks_connected: Mutex::new(Vec::new()), blocks_disconnected: Mutex::new(Vec::new()) + }); + let mut source_one = &chain_one; + let mut source_two = &chain_two; + let mut source_three = &header_chain; + let mut source_four = &backup_chain; + let mut client = MicroSPVClient::new( + (&chain_one).get_header(&block_1a_hash, Some(1)).await.unwrap().validate(block_1a_hash).unwrap(), + poller::ChainMultiplexer::new( + vec![&mut source_one as &mut dyn BlockSource, &mut source_two as &mut dyn BlockSource, &mut source_three as &mut dyn BlockSource], + vec![&mut source_four as &mut dyn BlockSource], + Network::Bitcoin), + Arc::clone(&chain_listener)); + + // Test that we will reorg onto 2b because chain_one knows about 1b + 2b + match client.poll_best_tip().await { + Ok((ChainTip::Better(chain_tip), blocks_connected)) => { + assert_eq!(chain_tip.block_hash, block_2b_hash); + assert!(blocks_connected); + }, + _ => panic!("Expected better chain tip"), + } + assert_eq!(&chain_listener.blocks_disconnected.lock().unwrap()[..], &[(block_1a_hash, 1)][..]); + assert_eq!(&chain_listener.blocks_connected.lock().unwrap()[..], &[(block_1b_hash, 1), (block_2b_hash, 2)][..]); + assert_eq!(client.chain_notifier.header_cache.len(), 2); + assert!(client.chain_notifier.header_cache.contains_key(&block_1b_hash)); + assert!(client.chain_notifier.header_cache.contains_key(&block_2b_hash)); + + // Test that even if chain_one (which we just got blocks from) stops responding to block or + // header requests we can still reorg back because we never wiped our block cache as + // chain_two always considered the "a" chain to contain the tip. We do this by simply + // wiping the blocks chain_one knows about: + chain_one.blocks.lock().unwrap().clear(); + chain_listener.blocks_connected.lock().unwrap().clear(); + chain_listener.blocks_disconnected.lock().unwrap().clear(); + + // First test that nothing happens if nothing changes: + match client.poll_best_tip().await { + Ok((ChainTip::Common, blocks_connected)) => { + assert!(!blocks_connected); + }, + _ => panic!("Expected common chain tip"), + } + assert!(chain_listener.blocks_disconnected.lock().unwrap().is_empty()); + assert!(chain_listener.blocks_connected.lock().unwrap().is_empty()); + + // Now add block 2a and 3a to chain_two and test that we reorg appropriately: + chain_two.blocks.lock().unwrap().insert(block_2a_hash, block_2a.clone()); + chain_two.blocks.lock().unwrap().insert(block_3a_hash, block_3a.clone()); + *chain_two.best_block.lock().unwrap() = (block_3a_hash, Some(3)); + + match client.poll_best_tip().await { + Ok((ChainTip::Better(chain_tip), blocks_connected)) => { + assert_eq!(chain_tip.block_hash, block_3a_hash); + assert!(blocks_connected); + }, + _ => panic!("Expected better chain tip"), + } + assert_eq!(&chain_listener.blocks_disconnected.lock().unwrap()[..], &[(block_2b_hash, 2), (block_1b_hash, 1)][..]); + assert_eq!(&chain_listener.blocks_connected.lock().unwrap()[..], &[(block_1a_hash, 1), (block_2a_hash, 2), (block_3a_hash, 3)][..]); + + // Note that blocks_past_common_tip is not wiped as chain_one still returns 2a as its tip + // (though a smarter MicroSPVClient may wipe 1a and 2a from the set eventually. + assert_eq!(client.chain_notifier.header_cache.len(), 3); + assert!(client.chain_notifier.header_cache.contains_key(&block_1a_hash)); + assert!(client.chain_notifier.header_cache.contains_key(&block_2a_hash)); + assert!(client.chain_notifier.header_cache.contains_key(&block_3a_hash)); + + chain_listener.blocks_connected.lock().unwrap().clear(); + chain_listener.blocks_disconnected.lock().unwrap().clear(); + + // Test that after chain_one and header_chain consider 3a as their tip that we won't wipe + // the block header cache. + *chain_one.best_block.lock().unwrap() = (block_3a_hash, Some(3)); + *header_chain.best_block.lock().unwrap() = (block_3a_hash, Some(3)); + match client.poll_best_tip().await { + Ok((ChainTip::Common, blocks_connected)) => { + assert!(!blocks_connected); + }, + _ => panic!("Expected common chain tip"), + } + assert!(chain_listener.blocks_disconnected.lock().unwrap().is_empty()); + assert!(chain_listener.blocks_connected.lock().unwrap().is_empty()); + + assert_eq!(client.chain_notifier.header_cache.len(), 3); + + // Test that setting the header chain to 4a does...almost nothing (though backup_chain + // should now be queried) since we can't get the blocks from anywhere. + header_chain.blocks.lock().unwrap().insert(block_2a_hash, block_2a); + header_chain.blocks.lock().unwrap().insert(block_3a_hash, block_3a); + header_chain.blocks.lock().unwrap().insert(block_4a_hash, block_4a.clone()); + *header_chain.best_block.lock().unwrap() = (block_4a_hash, Some(4)); + *backup_chain.disallowed.lock().unwrap() = false; + + match client.poll_best_tip().await { + Ok((ChainTip::Better(chain_tip), blocks_connected)) => { + assert_eq!(chain_tip.block_hash, block_4a_hash); + assert!(!blocks_connected); + }, + _ => panic!("Expected better chain tip"), + } + assert!(chain_listener.blocks_disconnected.lock().unwrap().is_empty()); + assert!(chain_listener.blocks_connected.lock().unwrap().is_empty()); + assert_eq!(client.chain_notifier.header_cache.len(), 3); + + // But if backup_chain *also* has 4a, we'll fetch it from there: + backup_chain.blocks.lock().unwrap().insert(block_4a_hash, block_4a); + *backup_chain.best_block.lock().unwrap() = (block_4a_hash, Some(4)); + + match client.poll_best_tip().await { + Ok((ChainTip::Better(chain_tip), blocks_connected)) => { + assert_eq!(chain_tip.block_hash, block_4a_hash); + assert!(blocks_connected); + }, + _ => panic!("Expected better chain tip"), + } + assert!(chain_listener.blocks_disconnected.lock().unwrap().is_empty()); + assert_eq!(&chain_listener.blocks_connected.lock().unwrap()[..], &[(block_4a_hash, 4)][..]); + assert_eq!(client.chain_notifier.header_cache.len(), 4); + assert!(client.chain_notifier.header_cache.contains_key(&block_4a_hash)); + + chain_listener.blocks_connected.lock().unwrap().clear(); + chain_listener.blocks_disconnected.lock().unwrap().clear(); + + // Note that if only headers_chain has a reorg, we'll end up in a somewhat pessimal case + // where we will disconnect and reconnect at each poll. We should fix this at some point by + // making sure we can at least fetch one block before we disconnect, but short of using a + // ton more memory there isn't much we can do in the case of two disconnects. We check that + // the disconnect happens here on a one-block-disconnected reorg, even though its + // non-normative behavior, as a good test of failing to reorg and returning back to the + // best chain. + header_chain.blocks.lock().unwrap().insert(block_4c_hash, block_4c); + header_chain.blocks.lock().unwrap().insert(block_5c_hash, block_5c); + *header_chain.best_block.lock().unwrap() = (block_5c_hash, Some(5)); + // We'll check the backup chain last, so don't give it 4a, as otherwise we'll connect it: + *backup_chain.best_block.lock().unwrap() = (block_3a_hash, Some(3)); + + match client.poll_best_tip().await { + Ok((ChainTip::Better(chain_tip), blocks_disconnected)) => { + assert_eq!(chain_tip.block_hash, block_5c_hash); + assert!(blocks_disconnected); + }, + _ => panic!("Expected better chain tip"), + } + assert_eq!(&chain_listener.blocks_disconnected.lock().unwrap()[..], &[(block_4a_hash, 4)][..]); + assert!(chain_listener.blocks_connected.lock().unwrap().is_empty()); + + chain_listener.blocks_disconnected.lock().unwrap().clear(); + + // Now reset the headers chain to 4a and test that we end up back there. + *backup_chain.best_block.lock().unwrap() = (block_4a_hash, Some(4)); + *header_chain.best_block.lock().unwrap() = (block_4a_hash, Some(4)); + match client.poll_best_tip().await { + Ok((ChainTip::Better(chain_tip), blocks_connected)) => { + assert_eq!(chain_tip.block_hash, block_4a_hash); + assert!(blocks_connected); + }, + _ => panic!("Expected better chain tip"), + } + assert!(chain_listener.blocks_disconnected.lock().unwrap().is_empty()); + assert_eq!(&chain_listener.blocks_connected.lock().unwrap()[..], &[(block_4a_hash, 4)][..]); + } +} diff --git a/lightning-block-sync/src/poller.rs b/lightning-block-sync/src/poller.rs new file mode 100644 index 00000000000..3611d6206be --- /dev/null +++ b/lightning-block-sync/src/poller.rs @@ -0,0 +1,332 @@ +use crate::{AsyncBlockSourceResult, BlockHeaderData, BlockSource, BlockSourceError, ChainTip, Poll, Validate, ValidatedBlock, ValidatedBlockHeader}; + +use bitcoin::blockdata::block::Block; +use bitcoin::hash_types::BlockHash; +use bitcoin::network::constants::Network; + +use std::ops::DerefMut; + +pub struct ChainPoller + Sized + Sync + Send, T: BlockSource> { + block_source: B, + network: Network, +} + +impl + Sized + Sync + Send, T: BlockSource> ChainPoller { + pub fn new(block_source: B, network: Network) -> Self { + Self { block_source, network } + } +} + +impl + Sized + Sync + Send, T: BlockSource> Poll for ChainPoller { + fn poll_chain_tip<'a>(&'a mut self, best_known_chain_tip: ValidatedBlockHeader) -> + AsyncBlockSourceResult<'a, ChainTip> + { + Box::pin(async move { + let (block_hash, height) = self.block_source.get_best_block().await?; + if block_hash == best_known_chain_tip.header.block_hash() { + return Ok(ChainTip::Common); + } + + let chain_tip = self.block_source + .get_header(&block_hash, height).await? + .validate(block_hash)?; + if chain_tip.chainwork > best_known_chain_tip.chainwork { + Ok(ChainTip::Better(chain_tip)) + } else { + Ok(ChainTip::Worse(chain_tip)) + } + }) + } + + fn look_up_previous_header<'a>(&'a mut self, header: &'a ValidatedBlockHeader) -> + AsyncBlockSourceResult<'a, ValidatedBlockHeader> + { + Box::pin(async move { + if header.height == 0 { + return Err(BlockSourceError::Persistent); + } + + let previous_hash = &header.header.prev_blockhash; + let height = header.height - 1; + let previous_header = self.block_source + .get_header(previous_hash, Some(height)).await? + .validate(*previous_hash)?; + header.check_builds_on(&previous_header, self.network)?; + + Ok(previous_header) + }) + } + + fn fetch_block<'a>(&'a mut self, header: &'a ValidatedBlockHeader) -> + AsyncBlockSourceResult<'a, ValidatedBlock> + { + Box::pin(async move { + self.block_source + .get_block(&header.block_hash).await? + .validate(header.block_hash) + }) + } +} + +pub struct ChainMultiplexer<'b, B: DerefMut + Sized + Sync + Send> { + network: Network, + block_sources: Vec<(B, BlockSourceError)>, + backup_block_sources: Vec<(B, BlockSourceError)>, + best_block_source: usize, +} + +struct DynamicBlockSource<'b>(&'b mut dyn BlockSource); + +impl<'b> BlockSource for DynamicBlockSource<'b>{ + fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, height_hint: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + Box::pin(async move { + self.0.get_header(header_hash, height_hint).await + }) + } + fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> { + Box::pin(async move { + self.0.get_block(header_hash).await + }) + } + fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { + Box::pin(async move { + self.0.get_best_block().await + }) + } +} + +impl<'b, B: DerefMut + Sized + Sync + Send> ChainMultiplexer<'b, B> { + pub fn new(mut block_sources: Vec, mut backup_block_sources: Vec, network: Network) -> Self { + assert!(!block_sources.is_empty()); + let block_sources = block_sources.drain(..).map(|block_source| { + (block_source, BlockSourceError::Transient) + }).collect(); + + let backup_block_sources = backup_block_sources.drain(..).map(|block_source| { + (block_source, BlockSourceError::Transient) + }).collect(); + + Self { network, block_sources, backup_block_sources, best_block_source: 0 } + } + + fn best_and_backup_block_sources(&mut self) -> Vec<&mut (B, BlockSourceError)> { + let best_block_source = self.block_sources.get_mut(self.best_block_source).unwrap(); + let backup_block_sources = self.backup_block_sources.iter_mut(); + std::iter::once(best_block_source) + .chain(backup_block_sources) + .filter(|(_, e)| e == &BlockSourceError::Transient) + .collect() + } +} + +impl<'b, B: DerefMut + Sized + Sync + Send> Poll for ChainMultiplexer<'b, B> { + fn poll_chain_tip<'a>( + &'a mut self, + best_known_chain_tip: ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ChainTip> { + Box::pin(async move { + let mut heaviest_chain_tip = best_known_chain_tip; + let mut best_result = Err(BlockSourceError::Persistent); + for (i, (block_source, error)) in self.block_sources.iter_mut().enumerate() { + if let BlockSourceError::Persistent = error { + continue; + } + + let mut block_source = DynamicBlockSource(&mut **block_source); + let mut poller = ChainPoller::new(&mut block_source, self.network); + let result = poller.poll_chain_tip(heaviest_chain_tip).await; + match result { + Err(BlockSourceError::Persistent) => { + *error = BlockSourceError::Persistent; + }, + Err(BlockSourceError::Transient) => { + if best_result.is_err() { + best_result = result; + } + }, + Ok(ChainTip::Common) => { + if let Ok(ChainTip::Better(_)) = best_result {} else { + best_result = result; + } + }, + Ok(ChainTip::Better(ref chain_tip)) => { + self.best_block_source = i; + heaviest_chain_tip = *chain_tip; + best_result = result; + }, + Ok(ChainTip::Worse(_)) => { + if best_result.is_err() { + best_result = result; + } + }, + } + } + + best_result + }) + } + + fn look_up_previous_header<'a>( + &'a mut self, + header: &'a ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ValidatedBlockHeader> { + Box::pin(async move { + let network = self.network; + for (block_source, error) in self.best_and_backup_block_sources() { + let mut block_source = DynamicBlockSource(&mut **block_source); + let mut poller = ChainPoller::new(&mut block_source, network); + let result = poller.look_up_previous_header(header).await; + match result { + Err(e) => *error = e, + Ok(_) => return result, + } + } + Err(BlockSourceError::Persistent) + }) + } + + fn fetch_block<'a>( + &'a mut self, + header: &'a ValidatedBlockHeader, + ) -> AsyncBlockSourceResult<'a, ValidatedBlock> { + Box::pin(async move { + let network = self.network; + for (block_source, error) in self.best_and_backup_block_sources() { + let mut block_source = DynamicBlockSource(&mut **block_source); + let mut poller = ChainPoller::new(&mut block_source, network); + let result = poller.fetch_block(header).await; + match result { + Err(e) => *error = e, + Ok(_) => return result, + } + } + Err(BlockSourceError::Persistent) + }) + } +} + +#[cfg(test)] +mod tests { + use crate::*; + use crate::test_utils::Blockchain; + use super::*; + use bitcoin::util::uint::Uint256; + + #[tokio::test] + async fn poll_empty_chain() { + let mut chain = Blockchain::default().with_height(0); + let best_known_chain_tip = chain.tip(); + chain.disconnect_tip(); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => assert_eq!(e, BlockSourceError::Transient), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn poll_chain_without_headers() { + let mut chain = Blockchain::default().with_height(1).without_headers(); + let best_known_chain_tip = chain.at_height(0); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => assert_eq!(e, BlockSourceError::Persistent), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn poll_chain_with_invalid_pow() { + let mut chain = Blockchain::default().with_height(1); + let best_known_chain_tip = chain.at_height(0); + + // Invalidate the tip by changing its target. + chain.blocks.last_mut().unwrap().header.bits = + BlockHeader::compact_target_from_u256(&Uint256::from_be_bytes([0; 32])); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => assert_eq!(e, BlockSourceError::Persistent), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn poll_chain_with_malformed_headers() { + let mut chain = Blockchain::default().with_height(1).malformed_headers(); + let best_known_chain_tip = chain.at_height(0); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => assert_eq!(e, BlockSourceError::Persistent), + Ok(_) => panic!("Expected error"), + } + } + + #[tokio::test] + async fn poll_chain_with_common_tip() { + let mut chain = Blockchain::default().with_height(0); + let best_known_chain_tip = chain.tip(); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(tip) => assert_eq!(tip, ChainTip::Common), + } + } + + #[tokio::test] + async fn poll_chain_with_uncommon_tip_but_equal_chainwork() { + let mut chain = Blockchain::default().with_height(1); + let best_known_chain_tip = chain.tip(); + + // Change the nonce to get a different block hash with the same chainwork. + chain.blocks.last_mut().unwrap().header.nonce += 1; + + let worse_chain_tip = chain.tip(); + let worse_chain_tip_hash = worse_chain_tip.header.block_hash(); + let worse_chain_tip = worse_chain_tip.validate(worse_chain_tip_hash).unwrap(); + assert_eq!(best_known_chain_tip.chainwork, worse_chain_tip.chainwork); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(tip) => assert_eq!(tip, ChainTip::Worse(worse_chain_tip)), + } + } + + #[tokio::test] + async fn poll_chain_with_worse_tip() { + let mut chain = Blockchain::default().with_height(1); + let best_known_chain_tip = chain.tip(); + chain.disconnect_tip(); + + let worse_chain_tip = chain.tip(); + let worse_chain_tip_hash = worse_chain_tip.header.block_hash(); + let worse_chain_tip = worse_chain_tip.validate(worse_chain_tip_hash).unwrap(); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(tip) => assert_eq!(tip, ChainTip::Worse(worse_chain_tip)), + } + } + + #[tokio::test] + async fn poll_chain_with_better_tip() { + let mut chain = Blockchain::default().with_height(1); + let best_known_chain_tip = chain.at_height(0); + + let better_chain_tip = chain.tip(); + let better_chain_tip_hash = better_chain_tip.header.block_hash(); + let better_chain_tip = better_chain_tip.validate(better_chain_tip_hash).unwrap(); + + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(best_known_chain_tip).await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(tip) => assert_eq!(tip, ChainTip::Better(better_chain_tip)), + } + } +} diff --git a/lightning-block-sync/src/test_utils.rs b/lightning-block-sync/src/test_utils.rs new file mode 100644 index 00000000000..a7b237c22a1 --- /dev/null +++ b/lightning-block-sync/src/test_utils.rs @@ -0,0 +1,229 @@ +use crate::{AsyncBlockSourceResult, BlockHeaderData, BlockSource, BlockSourceError, ChainListener, HeaderCache, Validate, ValidatedBlockHeader}; +use bitcoin::blockdata::block::{Block, BlockHeader}; +use bitcoin::blockdata::constants::genesis_block; +use bitcoin::hash_types::BlockHash; +use bitcoin::network::constants::Network; +use bitcoin::util::uint::Uint256; +use std::collections::VecDeque; + +#[derive(Default)] +pub struct Blockchain { + pub blocks: Vec, + without_blocks: Option>, + without_headers: bool, + malformed_headers: bool, +} + +impl Blockchain { + pub fn default() -> Self { + Blockchain::with_network(Network::Bitcoin) + } + + pub fn with_network(network: Network) -> Self { + let blocks = vec![genesis_block(network)]; + Self { blocks, ..Default::default() } + } + + pub fn with_height(mut self, height: usize) -> Self { + self.blocks.reserve_exact(height); + let bits = BlockHeader::compact_target_from_u256(&Uint256::from_be_bytes([0xff; 32])); + for i in 1..=height { + let prev_block = &self.blocks[i - 1]; + let prev_blockhash = prev_block.block_hash(); + let time = prev_block.header.time + height as u32; + self.blocks.push(Block { + header: BlockHeader { + version: 0, + prev_blockhash, + merkle_root: Default::default(), + time, + bits, + nonce: 0, + }, + txdata: vec![], + }); + } + self + } + + pub fn without_blocks(self, range: std::ops::RangeFrom) -> Self { + Self { without_blocks: Some(range), ..self } + } + + pub fn without_headers(self) -> Self { + Self { without_headers: true, ..self } + } + + pub fn malformed_headers(self) -> Self { + Self { malformed_headers: true, ..self } + } + + pub fn fork_at_height(&self, height: usize) -> Self { + assert!(height + 1 < self.blocks.len()); + let mut blocks = self.blocks.clone(); + let mut prev_blockhash = blocks[height].block_hash(); + for block in blocks.iter_mut().skip(height + 1) { + block.header.prev_blockhash = prev_blockhash; + block.header.nonce += 1; + prev_blockhash = block.block_hash(); + } + Self { blocks, without_blocks: None, ..*self } + } + + pub fn at_height(&self, height: usize) -> ValidatedBlockHeader { + let block_header = self.at_height_unvalidated(height); + let block_hash = self.blocks[height].block_hash(); + block_header.validate(block_hash).unwrap() + } + + fn at_height_unvalidated(&self, height: usize) -> BlockHeaderData { + assert!(!self.blocks.is_empty()); + assert!(height < self.blocks.len()); + BlockHeaderData { + chainwork: self.blocks[0].header.work() + Uint256::from_u64(height as u64).unwrap(), + height: height as u32, + header: self.blocks[height].header.clone(), + } + } + + pub fn tip(&self) -> ValidatedBlockHeader { + assert!(!self.blocks.is_empty()); + self.at_height(self.blocks.len() - 1) + } + + pub fn disconnect_tip(&mut self) -> Option { + self.blocks.pop() + } + + pub fn header_cache(&self, heights: std::ops::RangeInclusive) -> HeaderCache { + let mut cache = HeaderCache::new(); + for i in heights { + let value = self.at_height(i); + let key = value.header.block_hash(); + assert!(cache.insert(key, value).is_none()); + } + cache + } +} + +impl BlockSource for Blockchain { + fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height_hint: Option) -> AsyncBlockSourceResult<'a, BlockHeaderData> { + Box::pin(async move { + if self.without_headers { + return Err(BlockSourceError::Persistent); + } + + for (height, block) in self.blocks.iter().enumerate() { + if block.header.block_hash() == *header_hash { + let mut header_data = self.at_height_unvalidated(height); + if self.malformed_headers { + header_data.header.time += 1; + } + + return Ok(header_data); + } + } + Err(BlockSourceError::Transient) + }) + } + + fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> { + Box::pin(async move { + for (height, block) in self.blocks.iter().enumerate() { + if block.header.block_hash() == *header_hash { + if let Some(without_blocks) = &self.without_blocks { + if without_blocks.contains(&height) { + return Err(BlockSourceError::Persistent); + } + } + + return Ok(block.clone()); + } + } + Err(BlockSourceError::Transient) + }) + } + + fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option)> { + Box::pin(async move { + match self.blocks.last() { + None => Err(BlockSourceError::Transient), + Some(block) => { + let height = (self.blocks.len() - 1) as u32; + Ok((block.block_hash(), Some(height))) + }, + } + }) + } +} + +pub struct NullChainListener; + +impl ChainListener for NullChainListener { + fn block_connected(&mut self, _block: &Block, _height: u32) {} + fn block_disconnected(&mut self, _header: &BlockHeader, _height: u32) {} +} + +pub struct MockChainListener { + expected_blocks_connected: VecDeque, + expected_blocks_disconnected: VecDeque, +} + +impl MockChainListener { + pub fn new() -> Self { + Self { + expected_blocks_connected: VecDeque::new(), + expected_blocks_disconnected: VecDeque::new(), + } + } + + pub fn expect_block_connected(mut self, block: BlockHeaderData) -> Self { + self.expected_blocks_connected.push_back(block); + self + } + + pub fn expect_block_disconnected(mut self, block: BlockHeaderData) -> Self { + self.expected_blocks_disconnected.push_back(block); + self + } +} + +impl ChainListener for MockChainListener { + fn block_connected(&mut self, block: &Block, height: u32) { + match self.expected_blocks_connected.pop_front() { + None => { + panic!("Unexpected block connected: {:?}", block.block_hash()); + }, + Some(expected_block) => { + assert_eq!(block.block_hash(), expected_block.header.block_hash()); + assert_eq!(height, expected_block.height); + }, + } + } + + fn block_disconnected(&mut self, header: &BlockHeader, height: u32) { + match self.expected_blocks_disconnected.pop_front() { + None => { + panic!("Unexpected block disconnected: {:?}", header.block_hash()); + }, + Some(expected_block) => { + assert_eq!(header.block_hash(), expected_block.header.block_hash()); + assert_eq!(height, expected_block.height); + }, + } + } +} + +impl Drop for MockChainListener { + fn drop(&mut self) { + if std::thread::panicking() { + return; + } + if !self.expected_blocks_connected.is_empty() { + panic!("Expected blocks connected: {:?}", self.expected_blocks_connected); + } + if !self.expected_blocks_disconnected.is_empty() { + panic!("Expected blocks disconnected: {:?}", self.expected_blocks_disconnected); + } + } +} diff --git a/lightning-block-sync/src/utils.rs b/lightning-block-sync/src/utils.rs new file mode 100644 index 00000000000..96a2e578877 --- /dev/null +++ b/lightning-block-sync/src/utils.rs @@ -0,0 +1,54 @@ +use bitcoin::hashes::hex::FromHex; +use bitcoin::util::uint::Uint256; + +pub fn hex_to_uint256(hex: &str) -> Result { + let bytes = <[u8; 32]>::from_hex(hex)?; + Ok(Uint256::from_be_bytes(bytes)) +} + +#[cfg(test)] +mod tests { + use super::*; + use bitcoin::util::uint::Uint256; + + #[test] + fn hex_to_uint256_empty_str() { + assert!(hex_to_uint256("").is_err()); + } + + #[test] + fn hex_to_uint256_too_short_str() { + let hex = String::from_utf8(vec![b'0'; 32]).unwrap(); + assert_eq!(hex_to_uint256(&hex), Err(bitcoin::hashes::hex::Error::InvalidLength(64, 32))); + } + + #[test] + fn hex_to_uint256_too_long_str() { + let hex = String::from_utf8(vec![b'0'; 128]).unwrap(); + assert_eq!(hex_to_uint256(&hex), Err(bitcoin::hashes::hex::Error::InvalidLength(64, 128))); + } + + #[test] + fn hex_to_uint256_odd_length_str() { + let hex = String::from_utf8(vec![b'0'; 65]).unwrap(); + assert_eq!(hex_to_uint256(&hex), Err(bitcoin::hashes::hex::Error::OddLengthString(65))); + } + + #[test] + fn hex_to_uint256_invalid_char() { + let hex = String::from_utf8(vec![b'G'; 64]).unwrap(); + assert_eq!(hex_to_uint256(&hex), Err(bitcoin::hashes::hex::Error::InvalidChar(b'G'))); + } + + #[test] + fn hex_to_uint256_lowercase_str() { + let hex: String = std::iter::repeat("0123456789abcdef").take(4).collect(); + assert_eq!(hex_to_uint256(&hex).unwrap(), Uint256([0x0123456789abcdefu64; 4])); + } + + #[test] + fn hex_to_uint256_uppercase_str() { + let hex: String = std::iter::repeat("0123456789ABCDEF").take(4).collect(); + assert_eq!(hex_to_uint256(&hex).unwrap(), Uint256([0x0123456789abcdefu64; 4])); + } +}