diff --git a/questdb-rs-ffi/Cargo.lock b/questdb-rs-ffi/Cargo.lock index a3773d0d..c6667056 100644 --- a/questdb-rs-ffi/Cargo.lock +++ b/questdb-rs-ffi/Cargo.lock @@ -201,7 +201,9 @@ name = "questdb-rs" version = "5.0.0" dependencies = [ "base64ct", + "bytes", "dns-lookup", + "http", "indoc", "itoa", "libc", diff --git a/questdb-rs/Cargo.toml b/questdb-rs/Cargo.toml index 84455a67..c844cdd8 100644 --- a/questdb-rs/Cargo.toml +++ b/questdb-rs/Cargo.toml @@ -19,12 +19,14 @@ crate-type = ["lib"] [dependencies] libc = "0.2" -socket2 = { version = "0.5.5", optional = true } dns-lookup = "2.0.4" base64ct = { version = "1.7", features = ["alloc"] } rustls-pemfile = "2.0.0" ryu = { version = "1.0" } itoa = "1.0" +bytes = "1.10.1" + +socket2 = { version = "0.5.5", optional = true } aws-lc-rs = { version = "1.13", optional = true } ring = { version = "0.17.14", optional = true } rustls-pki-types = "1.0.1" @@ -33,9 +35,21 @@ rustls-native-certs = { version = "0.8.1", optional = true } webpki-roots = { version = "1.0.1", default-features = false, optional = true } chrono = { version = "0.4.40", optional = true } +http = { version = "1.3.1", optional = true } + # We need to limit the `ureq` version to 3.0.x since we use # the `ureq::unversioned` module which does not respect semantic versioning. ureq = { version = "3.0.10, <3.1.0", default-features = false, features = ["_tls"], optional = true } + +tokio = { version = "1.45.1", default-features = false, features = ["net"], optional = true } +tokio-rustls = { version = "0.26.2", default-features = false, optional = true } +#hyper = { version = "1.6.0", default-features = false, optional = true } +#http-body-util = { version = "0.1.3", optional = true } +#hyper-util = { version = "0.1.14", optional = true, features = ["client", "client-legacy", "http1"] } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"], optional = true } +lasso = { version = "0.7.3", features = ["multi-threaded"], optional = true } +crossbeam-queue = { version = "0.3.12", optional = true } + serde_json = { version = "1", optional = true } questdb-confstr = "0.1.1" rand = { version = "0.9.0", optional = true } @@ -56,10 +70,10 @@ mio = { version = "1", features = ["os-poll", "net"] } chrono = "0.4.31" tempfile = "3" webpki-roots = "1.0.1" -rstest = "0.25.0" +tokio = { version = "1.45.1", features = ["macros", "rt-multi-thread"]} [features] -default = ["sync-sender", "tls-webpki-certs", "ring-crypto"] +default = ["sync-sender", "async-sender-http", "tls-webpki-certs", "ring-crypto"] ## Sync ILP/TCP + ILP/HTTP Sender sync-sender = ["sync-sender-tcp", "sync-sender-http"] @@ -68,7 +82,23 @@ sync-sender = ["sync-sender-tcp", "sync-sender-http"] sync-sender-tcp = ["_sync-sender", "_sender-tcp", "dep:socket2"] ## Sync ILP/HTTP -sync-sender-http = ["_sync-sender", "_sender-http", "dep:ureq", "dep:serde_json", "dep:rand"] +sync-sender-http = [ + "_sync-sender", + "_sender-http", + "dep:ureq"] + +## Async ILP/HTTP Sender +async-sender-http = [ + "_async-sender", + "_sender-http", + "dep:tokio", + "dep:tokio-rustls", + "dep:reqwest", + "dep:lasso", + "dep:crossbeam-queue"] + +## Compatiblity alias. +ilp-over-http = ["sync-sender-tcp"] ## Allow use OS-provided root TLS certificates tls-native-certs = ["dep:rustls-native-certs"] @@ -92,9 +122,10 @@ json_tests = [] chrono_timestamp = ["chrono"] # Hidden derived features, used in code to enable-disable code sections. Don't use directly. -_sender-tcp = [] -_sender-http = [] -_sync-sender = [] +_sender-tcp = [] # enabled for sync-sender-tcp +_sender-http = ["dep:http", "dep:serde_json", "dep:rand"] # enabled for any(sync-sender-http, async-sender-http) +_sync-sender = [] # enabled for any(sync-sender-tcp, sync-sender-http) +_async-sender = [] # enabled for async-sender-http) ## Enable all cross-compatible features. ## The `aws-lc-crypto` and `ring-crypto` features are mutually exclusive, @@ -112,6 +143,14 @@ almost-all-features = [ "ndarray" ] +[[example]] +name = "from_conf" +required-features = ["sync-sender-tcp"] + +[[example]] +name = "from_env" +required-features = ["sync-sender-tcp"] + [[example]] name = "basic" required-features = ["chrono_timestamp", "ndarray"] diff --git a/questdb-rs/build.rs b/questdb-rs/build.rs index df2d3000..c419de83 100644 --- a/questdb-rs/build.rs +++ b/questdb-rs/build.rs @@ -125,7 +125,6 @@ pub mod json_tests { use crate::tests::{TestResult}; use base64ct::Base64; use base64ct::Encoding; - use rstest::rstest; fn matches_any_line(line: &[u8], expected: &[&str]) -> bool { for &exp in expected { @@ -144,12 +143,37 @@ pub mod json_tests { // for line in serde_json::to_string_pretty(&spec).unwrap().split("\n") { // writeln!(output, "/// {}", line)?; // } - writeln!(output, "#[rstest]")?; + let test_name_slug = slugify!(&spec.test_name, separator = "_"); + writeln!(output, "#[test]")?; writeln!( output, - "fn test_{:03}_{}(\n #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion,\n) -> TestResult {{", - index, - slugify!(&spec.test_name, separator = "_") + "fn test_{:03}_{}_v1() -> TestResult {{", + index, test_name_slug + )?; + writeln!( + output, + " _test_{:03}_{}(ProtocolVersion::V1)\n", + index, test_name_slug + )?; + writeln!(output, "}}"); + + writeln!(output, "#[test]")?; + writeln!( + output, + "fn test_{:03}_{}_v2() -> TestResult {{", + index, test_name_slug + )?; + writeln!( + output, + " _test_{:03}_{}(ProtocolVersion::V2)\n", + index, test_name_slug + )?; + writeln!(output, "}}"); + + writeln!( + output, + "fn _test_{:03}_{}(version: ProtocolVersion) -> TestResult {{", + index, test_name_slug )?; writeln!(output, " let mut buffer = Buffer::new(version);")?; @@ -274,11 +298,11 @@ fn main() -> Result<(), Box> { #[cfg(not(any(feature = "_sender-tcp", feature = "_sender-http")))] compile_error!( - "At least one of `sync-sender-tcp` or `sync-sender-http` features must be enabled" + "At least one of `sync-sender-tcp`, `sync-sender-http`, or `async-sender-http` features must be enabled" ); #[cfg(not(any(feature = "aws-lc-crypto", feature = "ring-crypto")))] - compile_error!("You must enable exactly one of the `aws-lc-crypto` or `ring-crypto` features, but none are enabled."); + compile_error!("You must enable exactly one of the `aws-lc-crypto` or `ring-crypto` features, but neither are enabled."); #[cfg(all(feature = "aws-lc-crypto", feature = "ring-crypto"))] compile_error!("You must enable exactly one of the `aws-lc-crypto` or `ring-crypto` features, but both are enabled."); diff --git a/questdb-rs/src/ingress/async_sender/http.rs b/questdb-rs/src/ingress/async_sender/http.rs new file mode 100644 index 00000000..9d2998f7 --- /dev/null +++ b/questdb-rs/src/ingress/async_sender/http.rs @@ -0,0 +1,199 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2025 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ +use std::future::Future; +use std::time::Duration; + +use crate::error::{fmt, Result}; +use crate::ingress::conf::SETTINGS_RETRY_TIMEOUT; +use crate::ingress::http_common::{ + is_retriable_status_code, process_settings_response, ParsedResponseHeaders, +}; +use crate::ingress::tls::TlsSettings; +use crate::ingress::ProtocolVersion; +use bytes::Bytes; +use rand::Rng; +use reqwest::{Client, RequestBuilder, StatusCode, Url}; +use tokio::time::{sleep, Instant}; + +// TODO: +// * Implement Auth. +// * Implement TLS. + +pub(super) struct HttpClient { + tls: Option, + auth: Option, + client: Client, +} + +impl HttpClient { + pub fn new(tls: Option, auth: Option, user_agent: &str) -> Result { + let builder = Client::builder().user_agent(user_agent); + let client = match builder.build() { + Ok(client) => client, + Err(e) => return Err(fmt!(ConfigError, "Could not create http client: {}", e)), + }; + Ok(Self { tls, auth, client }) + } + + pub async fn get( + &self, + url: &Url, + request_timeout: Duration, + ) -> (bool, Result<(StatusCode, ParsedResponseHeaders, Bytes)>) { + let builder = self.client.get(url.clone()).timeout(request_timeout); + perform_request(builder).await + } + + pub async fn get_with_retries( + &self, + url: &Url, + request_timeout: Duration, + retry_timeout: Duration, + ) -> Result<(StatusCode, ParsedResponseHeaders, Bytes)> { + request_with_retries(|| self.get(url, request_timeout), retry_timeout).await + } + + pub async fn post( + &self, + url: &Url, + body: Bytes, + request_timeout: Duration, + ) -> (bool, Result<(StatusCode, ParsedResponseHeaders, Bytes)>) { + let builder = self + .client + .post(url.clone()) + .timeout(request_timeout) + .body(body); + perform_request(builder).await + } + + pub async fn post_with_retries( + &self, + url: &Url, + body: Bytes, + request_timeout: Duration, + retry_timeout: Duration, + ) -> Result<(StatusCode, ParsedResponseHeaders, Bytes)> { + request_with_retries( + || self.post(url, body.clone(), request_timeout), + retry_timeout, + ) + .await + } +} + +pub(super) fn build_url(tls: bool, host: &str, port: &str, path: &str) -> Result { + let schema = if tls { "https" } else { "http" }; + let url_string = format!("{schema}://{host}:{port}/{path}"); + let map_url_err = |url, e| fmt!(CouldNotResolveAddr, "could not parse url {url:?}: {e}"); + Url::parse(&url_string).map_err(|e| map_url_err(&url_string, e)) +} + +fn map_reqwest_err( + err: reqwest::Error, +) -> (bool, Result<(StatusCode, ParsedResponseHeaders, Bytes)>) { + let mut need_retry = false; + if err.is_timeout() || err.is_connect() || err.is_redirect() { + need_retry = true; + } + if let Some(status) = err.status() { + if is_retriable_status_code(status) { + need_retry = true; + } + } + ( + need_retry, + Err(fmt!(SocketError, "Error receiving HTTP response: {err}")), + ) +} + +async fn perform_request( + builder: RequestBuilder, +) -> (bool, Result<(StatusCode, ParsedResponseHeaders, Bytes)>) { + let response = match builder.send().await { + Ok(response) => response, + Err(err) => return map_reqwest_err(err), + }; + let status = response.status(); + let header_data = ParsedResponseHeaders::parse(response.headers()); + match response.bytes().await { + Ok(bytes) => ( + is_retriable_status_code(status), + Ok((status, header_data, bytes)), + ), + Err(err) => map_reqwest_err(err), + } +} + +async fn request_with_retries( + mut do_request: F, + retry_timeout: Duration, +) -> Result<(StatusCode, ParsedResponseHeaders, Bytes)> +where + F: FnMut() -> Fut, + Fut: Future)>, +{ + let (need_retry, last_response) = do_request().await; + if !need_retry || retry_timeout.is_zero() { + return last_response; + } + + let mut rng = rand::rng(); + let retry_end = Instant::now() + retry_timeout; + let mut retry_interval_ms = 10; + loop { + let jitter_ms = rng.random_range(-5i32..5); + let to_sleep_ms = retry_interval_ms + jitter_ms; + let to_sleep = Duration::from_millis(to_sleep_ms as u64); + if (Instant::now() + to_sleep) > retry_end { + return last_response; + } + sleep(to_sleep).await; + let (need_retry, last_response) = do_request().await; + if !need_retry { + return last_response; + } + retry_interval_ms = (retry_interval_ms * 2).min(1000); + } +} + +pub(super) async fn read_server_settings( + client: &HttpClient, + settings_url: &Url, + default_max_name_len: usize, + request_timeout: Duration, +) -> Result<(Vec, usize)> { + let default_protocol_version = ProtocolVersion::V1; + + let response = client + .get_with_retries(settings_url, request_timeout, SETTINGS_RETRY_TIMEOUT) + .await; + + process_settings_response( + response, + settings_url.as_str(), + default_protocol_version, + default_max_name_len, + ) +} diff --git a/questdb-rs/src/ingress/async_sender/mod.rs b/questdb-rs/src/ingress/async_sender/mod.rs new file mode 100644 index 00000000..c3d1d8a3 --- /dev/null +++ b/questdb-rs/src/ingress/async_sender/mod.rs @@ -0,0 +1,167 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2025 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ +use crate::error; +use crate::error::Result; +use crate::ingress::async_sender::http::{build_url, read_server_settings, HttpClient}; +use crate::ingress::conf::{AuthParams, HttpConfig}; +use crate::ingress::http_common::{parse_http_error, pick_protocol_version}; +use crate::ingress::tls::TlsSettings; +use crate::ingress::{ + check_protocol_version, Buffer, FrozenBuffer, NdArrayView, ProtocolVersion, SenderBuilder, +}; +use reqwest::Url; +use std::ops::Deref; +use std::sync::Arc; + +mod http; + +pub(crate) struct AsyncSenderSettings { + max_name_len: usize, + max_buf_size: usize, + protocol_version: ProtocolVersion, + http_config: HttpConfig, +} + +pub struct AsyncSender { + descr: String, + settings: AsyncSenderSettings, + client: HttpClient, + write_url: Url, +} + +impl AsyncSender { + pub async fn from_conf>(conf: T) -> Result> { + SenderBuilder::from_conf(conf)?.build_async().await + } + + pub async fn from_env() -> Result> { + SenderBuilder::from_env()?.build_async().await + } + + pub(crate) async fn new( + descr: String, + host: &str, + port: &str, + tls: Option, + auth: Option, + max_name_len: usize, + max_buf_size: usize, + protocol_version: Option, + http_config: HttpConfig, + ) -> Result> { + let mut settings = AsyncSenderSettings { + max_name_len, // sniffed and overwritten, unless endpoint is old and does not support /settings + max_buf_size, + protocol_version: protocol_version.unwrap_or(ProtocolVersion::V2), // TODO: sniff! + http_config, + }; + + let settings_url = build_url(tls.is_some(), host, port, "settings")?; + let write_url = build_url(tls.is_some(), host, port, "write")?; // TODO: fixme! + let client = HttpClient::new(tls, auth, &settings.http_config.user_agent)?; + let (protocol_versions, max_name_len) = read_server_settings( + &client, + &settings_url, + max_name_len, + *settings.http_config.request_timeout.deref(), + ) + .await?; + + settings.protocol_version = pick_protocol_version(&protocol_versions[..])?; + settings.max_name_len = max_name_len; + + Ok(Arc::new(Self { + descr, + settings, + client, + write_url, + })) + } + + pub fn new_buffer(self: &Arc) -> Buffer { + Buffer::with_max_name_len(self.settings.protocol_version, self.settings.max_name_len) + } + + pub async fn flush(&self, buf: impl Into) -> Result<()> { + self.flush_impl(buf, false).await + } + + pub async fn flush_transactional(&self, buf: impl Into) -> Result<()> { + self.flush_impl(buf, true).await + } + + async fn flush_impl(&self, buf: impl Into, transactional: bool) -> Result<()> { + // TODO: refactor more of the impl into the `http_common` module. + + let buf = buf.into(); + buf.check_can_flush()?; + + if transactional && !buf.transactional() { + return Err(error::fmt!( + InvalidApiCall, + "Buffer contains lines for multiple tables. \ + Transactional flushes are only supported for buffers containing lines for a single table." + )); + } + + if buf.len() > self.settings.max_buf_size { + let buf_len = buf.len(); + return Err(error::fmt!( + InvalidApiCall, + "Could not flush buffer: Buffer size of {} exceeds maximum configured allowed size of {} bytes.", + buf_len, + self.settings.max_buf_size + )); + } + + check_protocol_version(self.settings.protocol_version, buf.protocol_version())?; + + if buf.is_empty() { + return Ok(()); + } + + // Which we can freeze as something that we can send. + let body = buf.bytes(); + + let request_min_throughput = *self.settings.http_config.request_min_throughput; + let extra_time = if request_min_throughput > 0 { + (buf.len() as f64) / (request_min_throughput as f64) + } else { + 0.0f64 + }; + + let (status, header_data, response) = self.client.post_with_retries( + &self.write_url, + body, + *self.settings.http_config.request_timeout + std::time::Duration::from_secs_f64(extra_time), + *self.settings.http_config.retry_timeout.deref() + ).await?; + + if status.is_client_error() || status.is_server_error() { + return Err(parse_http_error(status, header_data, response)); + } + + Ok(()) + } +} diff --git a/questdb-rs/src/ingress/buffer.rs b/questdb-rs/src/ingress/buffer.rs index bc5ce77d..8c97c6d4 100644 --- a/questdb-rs/src/ingress/buffer.rs +++ b/questdb-rs/src/ingress/buffer.rs @@ -28,14 +28,16 @@ use crate::ingress::{ MAX_NAME_LEN_DEFAULT, }; use crate::{error, Error}; +use bytes::{BufMut, Bytes, BytesMut}; use std::fmt::{Debug, Formatter}; +use std::mem; use std::num::NonZeroUsize; use std::slice::from_raw_parts_mut; -fn write_escaped_impl(check_escape_fn: C, quoting_fn: Q, output: &mut Vec, s: &str) +fn write_escaped_impl(check_escape_fn: C, quoting_fn: Q, output: &mut BytesMut, s: &str) where C: Fn(u8) -> bool, - Q: Fn(&mut Vec), + Q: Fn(&mut BytesMut), { let mut to_escape = 0usize; for b in s.bytes() { @@ -80,12 +82,12 @@ fn must_escape_quoted(c: u8) -> bool { matches!(c, b'\n' | b'\r' | b'"' | b'\\') } -fn write_escaped_unquoted(output: &mut Vec, s: &str) { +fn write_escaped_unquoted(output: &mut BytesMut, s: &str) { write_escaped_impl(must_escape_unquoted, |_output| (), output, s); } -fn write_escaped_quoted(output: &mut Vec, s: &str) { - write_escaped_impl(must_escape_quoted, |output| output.push(b'"'), output, s) +fn write_escaped_quoted(output: &mut BytesMut, s: &str) { + write_escaped_impl(must_escape_quoted, |output| output.put_u8(b'"'), output, s) } pub(crate) struct F64Serializer { @@ -170,7 +172,7 @@ impl OpCase { // IMPORTANT: This struct MUST remain `Copy` to ensure that // there are no heap allocations when performing marker operations. #[derive(Debug, Clone, Copy)] -struct BufferState { +pub(crate) struct BufferState { op_case: OpCase, row_count: usize, first_table_len: Option, @@ -374,6 +376,42 @@ impl<'a> AsRef for ColumnName<'a> { } } +#[derive(Clone)] +struct BufferInner +where + T: Clone, +{ + output: T, + state: BufferState, + marker: Option<(usize, BufferState)>, + max_name_len: usize, + protocol_version: ProtocolVersion, +} + +impl BufferInner +where + T: Clone, +{ + /// Check if the next API operation is allowed as per the OP case state machine. + #[inline(always)] + fn check_op(&self, op: Op) -> crate::Result<()> { + if (self.state.op_case as isize & op as isize) > 0 { + Ok(()) + } else { + Err(error::fmt!( + InvalidApiCall, + "State error: Bad call to `{}`, {}.", + op.descr(), + self.state.op_case.next_op_descr() + )) + } + } + + pub(crate) fn check_can_flush(&self) -> crate::Result<()> { + self.check_op(Op::Flush) + } +} + /// A reusable buffer to prepare a batch of ILP messages. /// /// # Example @@ -466,11 +504,71 @@ impl<'a> AsRef for ColumnName<'a> { /// #[derive(Clone)] pub struct Buffer { - output: Vec, - state: BufferState, - marker: Option<(usize, BufferState)>, - max_name_len: usize, - protocol_version: ProtocolVersion, + inner: BufferInner, +} + +/// A buffer that can be sent asynchronously. +/// +/// ```norun +/// let frozen: FrozenBuffer = buffer.into(); +/// ``` +/// +/// A frozen buffer is cheap to clone. +/// This makes it possible to send the same data rows to multiple databases in parallel. +#[derive(Clone)] +pub struct FrozenBuffer { + inner: BufferInner, +} + +// TODO: +// * Document APIs. +// * Implement FrozenBuffer -> Buffer "try" logic. +// * Document example with buffer pool. +impl FrozenBuffer { + pub fn transactional(&self) -> bool { + self.inner.state.transactional + } + + pub fn check_can_flush(&self) -> crate::Result<()> { + self.inner.check_can_flush() + } + + pub fn len(&self) -> usize { + self.inner.output.len() + } + + pub fn is_empty(&self) -> bool { + self.inner.output.is_empty() + } + + pub fn bytes(&self) -> Bytes { + self.inner.output.clone() + } + + pub fn protocol_version(&self) -> ProtocolVersion { + self.inner.protocol_version + } +} + +impl From for FrozenBuffer { + fn from(buf: Buffer) -> Self { + let BufferInner { + output, + state, + marker, + max_name_len, + protocol_version, + } = buf.inner; + let output = output.freeze(); + let inner = BufferInner { + output, + state, + marker, + max_name_len, + protocol_version, + }; + Self { inner } + } } impl Buffer { @@ -499,16 +597,18 @@ impl Buffer { /// For the default max name length limit (127), use [`Self::new`]. pub fn with_max_name_len(protocol_version: ProtocolVersion, max_name_len: usize) -> Self { Self { - output: Vec::new(), - state: BufferState::new(), - marker: None, - max_name_len, - protocol_version, + inner: BufferInner { + output: BytesMut::new(), + state: BufferState::new(), + marker: None, + max_name_len, + protocol_version, + }, } } pub fn protocol_version(&self) -> ProtocolVersion { - self.protocol_version + self.inner.protocol_version } /// Pre-allocate to ensure the buffer has enough capacity for at least the @@ -516,37 +616,37 @@ impl Buffer { /// This does not allocate if such additional capacity is already satisfied. /// See: `capacity`. pub fn reserve(&mut self, additional: usize) { - self.output.reserve(additional); + self.inner.output.reserve(additional); } /// The number of bytes accumulated in the buffer. pub fn len(&self) -> usize { - self.output.len() + self.inner.output.len() } /// The number of rows accumulated in the buffer. pub fn row_count(&self) -> usize { - self.state.row_count + self.inner.state.row_count } /// Tells whether the buffer is transactional. It is transactional iff it contains /// data for at most one table. Additionally, you must send the buffer over HTTP to /// get transactional behavior. pub fn transactional(&self) -> bool { - self.state.transactional + self.inner.state.transactional } pub fn is_empty(&self) -> bool { - self.output.is_empty() + self.inner.output.is_empty() } /// The total number of bytes the buffer can hold before it needs to resize. pub fn capacity(&self) -> usize { - self.output.capacity() + self.inner.output.capacity() } pub fn as_bytes(&self) -> &[u8] { - &self.output + &self.inner.output } /// Mark a rewind point. @@ -556,7 +656,7 @@ impl Buffer { /// Once the marker is no longer needed, call /// [`clear_marker`](Buffer::clear_marker). pub fn set_marker(&mut self) -> crate::Result<()> { - if (self.state.op_case as isize & Op::Table as isize) == 0 { + if (self.inner.state.op_case as isize & Op::Table as isize) == 0 { return Err(error::fmt!( InvalidApiCall, concat!( @@ -566,7 +666,7 @@ impl Buffer { ) )); } - self.marker = Some((self.output.len(), self.state)); + self.inner.marker = Some((self.inner.output.len(), self.inner.state)); Ok(()) } @@ -575,9 +675,9 @@ impl Buffer { /// /// As a side effect, this also clears the marker. pub fn rewind_to_marker(&mut self) -> crate::Result<()> { - if let Some((position, state)) = self.marker.take() { - self.output.truncate(position); - self.state = state; + if let Some((position, state)) = self.inner.marker.take() { + self.inner.output.truncate(position); + self.inner.state = state; Ok(()) } else { Err(error::fmt!( @@ -592,30 +692,21 @@ impl Buffer { /// /// Idempotent. pub fn clear_marker(&mut self) { - self.marker = None; + self.inner.marker = None; } /// Reset the buffer and clear contents whilst retaining /// [`capacity`](Buffer::capacity). pub fn clear(&mut self) { - self.output.clear(); - self.state = BufferState::new(); - self.marker = None; + self.inner.output.clear(); + self.inner.state = BufferState::new(); + self.inner.marker = None; } /// Check if the next API operation is allowed as per the OP case state machine. #[inline(always)] fn check_op(&self, op: Op) -> crate::Result<()> { - if (self.state.op_case as isize & op as isize) > 0 { - Ok(()) - } else { - Err(error::fmt!( - InvalidApiCall, - "State error: Bad call to `{}`, {}.", - op.descr(), - self.state.op_case.next_op_descr() - )) - } + self.inner.check_op(op) } /// Checks if this buffer is ready to be flushed to a sender via one of the @@ -624,17 +715,17 @@ impl Buffer { /// message indicating why this [`Buffer`] cannot be flushed at the moment. #[inline(always)] pub fn check_can_flush(&self) -> crate::Result<()> { - self.check_op(Op::Flush) + self.inner.check_can_flush() } #[inline(always)] fn validate_max_name_len(&self, name: &str) -> crate::Result<()> { - if name.len() > self.max_name_len { + if name.len() > self.inner.max_name_len { return Err(error::fmt!( InvalidName, "Bad name: {:?}: Too long (max {} characters)", name, - self.max_name_len + self.inner.max_name_len )); } Ok(()) @@ -676,17 +767,17 @@ impl Buffer { let name: TableName<'a> = name.try_into()?; self.validate_max_name_len(name.name)?; self.check_op(Op::Table)?; - let table_begin = self.output.len(); - write_escaped_unquoted(&mut self.output, name.name); - let table_end = self.output.len(); - self.state.op_case = OpCase::TableWritten; + let table_begin = self.inner.output.len(); + write_escaped_unquoted(&mut self.inner.output, name.name); + let table_end = self.inner.output.len(); + self.inner.state.op_case = OpCase::TableWritten; // A buffer stops being transactional if it targets multiple tables. - if let Some(first_table_len) = &self.state.first_table_len { - let first_table = &self.output[0..first_table_len.get()]; - let this_table = &self.output[table_begin..table_end]; + if let Some(first_table_len) = &self.inner.state.first_table_len { + let first_table = &self.inner.output[0..first_table_len.get()]; + let this_table = &self.inner.output[table_begin..table_end]; if first_table != this_table { - self.state.transactional = false; + self.inner.state.transactional = false; } } else { debug_assert!(table_begin == 0); @@ -700,7 +791,7 @@ impl Buffer { // Instead we just assert that it's `Some`. debug_assert!(first_table_len.is_some()); - self.state.first_table_len = first_table_len; + self.inner.state.first_table_len = first_table_len; } Ok(self) } @@ -761,11 +852,11 @@ impl Buffer { let name: ColumnName<'a> = name.try_into()?; self.validate_max_name_len(name.name)?; self.check_op(Op::Symbol)?; - self.output.push(b','); - write_escaped_unquoted(&mut self.output, name.name); - self.output.push(b'='); - write_escaped_unquoted(&mut self.output, value.as_ref()); - self.state.op_case = OpCase::SymbolWritten; + self.inner.output.put_u8(b','); + write_escaped_unquoted(&mut self.inner.output, name.name); + self.inner.output.put_u8(b'='); + write_escaped_unquoted(&mut self.inner.output, value.as_ref()); + self.inner.state.op_case = OpCase::SymbolWritten; Ok(self) } @@ -777,15 +868,16 @@ impl Buffer { let name: ColumnName<'a> = name.try_into()?; self.validate_max_name_len(name.name)?; self.check_op(Op::Column)?; - self.output - .push(if (self.state.op_case as isize & Op::Symbol as isize) > 0 { + self.inner.output.put_u8( + if (self.inner.state.op_case as isize & Op::Symbol as isize) > 0 { b' ' } else { b',' - }); - write_escaped_unquoted(&mut self.output, name.name); - self.output.push(b'='); - self.state.op_case = OpCase::ColumnWritten; + }, + ); + write_escaped_unquoted(&mut self.inner.output, name.name); + self.inner.output.put_u8(b'='); + self.inner.state.op_case = OpCase::ColumnWritten; Ok(self) } @@ -825,7 +917,7 @@ impl Buffer { Error: From, { self.write_column_key(name)?; - self.output.push(if value { b't' } else { b'f' }); + self.inner.output.put_u8(if value { b't' } else { b'f' }); Ok(self) } @@ -867,8 +959,8 @@ impl Buffer { self.write_column_key(name)?; let mut buf = itoa::Buffer::new(); let printed = buf.format(value); - self.output.extend_from_slice(printed.as_bytes()); - self.output.push(b'i'); + self.inner.output.extend_from_slice(printed.as_bytes()); + self.inner.output.put_u8(b'i'); Ok(self) } @@ -908,13 +1000,13 @@ impl Buffer { Error: From, { self.write_column_key(name)?; - if !matches!(self.protocol_version, ProtocolVersion::V1) { - self.output.push(b'='); - self.output.push(DOUBLE_BINARY_FORMAT_TYPE); - self.output.extend_from_slice(&value.to_le_bytes()) + if !matches!(self.inner.protocol_version, ProtocolVersion::V1) { + self.inner.output.put_u8(b'='); + self.inner.output.put_u8(DOUBLE_BINARY_FORMAT_TYPE); + self.inner.output.extend_from_slice(&value.to_le_bytes()) } else { let mut ser = F64Serializer::new(value); - self.output.extend_from_slice(ser.as_str().as_bytes()) + self.inner.output.extend_from_slice(ser.as_str().as_bytes()) } Ok(self) } @@ -971,7 +1063,7 @@ impl Buffer { Error: From, { self.write_column_key(name)?; - write_escaped_quoted(&mut self.output, value.as_ref()); + write_escaped_quoted(&mut self.inner.output, value.as_ref()); Ok(self) } @@ -1030,7 +1122,7 @@ impl Buffer { D: ArrayElement + ArrayElementSealed, Error: From, { - if self.protocol_version == ProtocolVersion::V1 { + if self.inner.protocol_version == ProtocolVersion::V1 { return Err(error::fmt!( ProtocolVersionError, "Protocol version v1 does not support array datatype", @@ -1057,30 +1149,32 @@ impl Buffer { let array_buf_size = check_and_get_array_bytes_size(view)?; self.write_column_key(name)?; // binary format flag '=' - self.output.push(b'='); + self.inner.output.put_u8(b'='); // binary format entity type - self.output.push(ARRAY_BINARY_FORMAT_TYPE); + self.inner.output.put_u8(ARRAY_BINARY_FORMAT_TYPE); // ndarr datatype - self.output.push(D::type_tag()); + self.inner.output.put_u8(D::type_tag()); // ndarr dims - self.output.push(ndim as u8); + self.inner.output.put_u8(ndim as u8); let dim_header_size = size_of::() * ndim; - self.output.reserve(dim_header_size + array_buf_size); + self.inner.output.reserve(dim_header_size + array_buf_size); for i in 0..ndim { // ndarr shape - self.output + self.inner + .output .extend_from_slice((view.dim(i)? as u32).to_le_bytes().as_slice()); } - let index = self.output.len(); - let writeable = - unsafe { from_raw_parts_mut(self.output.as_mut_ptr().add(index), array_buf_size) }; + let index = self.inner.output.len(); + let writeable = unsafe { + from_raw_parts_mut(self.inner.output.as_mut_ptr().add(index), array_buf_size) + }; // ndarr data ndarr::write_array_data(view, writeable, array_buf_size)?; - unsafe { self.output.set_len(array_buf_size + index) } + unsafe { self.inner.output.set_len(array_buf_size + index) } Ok(self) } @@ -1151,8 +1245,8 @@ impl Buffer { let timestamp: TimestampMicros = timestamp.try_into()?; let mut buf = itoa::Buffer::new(); let printed = buf.format(timestamp.as_i64()); - self.output.extend_from_slice(printed.as_bytes()); - self.output.push(b't'); + self.inner.output.extend_from_slice(printed.as_bytes()); + self.inner.output.put_u8(b't'); Ok(self) } @@ -1216,11 +1310,11 @@ impl Buffer { } let mut buf = itoa::Buffer::new(); let printed = buf.format(epoch_nanos); - self.output.push(b' '); - self.output.extend_from_slice(printed.as_bytes()); - self.output.push(b'\n'); - self.state.op_case = OpCase::MayFlushOrTable; - self.state.row_count += 1; + self.inner.output.put_u8(b' '); + self.inner.output.extend_from_slice(printed.as_bytes()); + self.inner.output.put_u8(b'\n'); + self.inner.state.op_case = OpCase::MayFlushOrTable; + self.inner.state.row_count += 1; Ok(()) } @@ -1255,9 +1349,9 @@ impl Buffer { /// ``` pub fn at_now(&mut self) -> crate::Result<()> { self.check_op(Op::At)?; - self.output.push(b'\n'); - self.state.op_case = OpCase::MayFlushOrTable; - self.state.row_count += 1; + self.inner.output.put_u8(b'\n'); + self.inner.state.op_case = OpCase::MayFlushOrTable; + self.inner.state.row_count += 1; Ok(()) } } @@ -1265,11 +1359,11 @@ impl Buffer { impl Debug for Buffer { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("Buffer") - .field("output", &DebugBytes(&self.output)) - .field("state", &self.state) - .field("marker", &self.marker) - .field("max_name_len", &self.max_name_len) - .field("protocol_version", &self.protocol_version) + .field("output", &DebugBytes(&self.inner.output)) + .field("state", &self.inner.state) + .field("marker", &self.inner.marker) + .field("max_name_len", &self.inner.max_name_len) + .field("protocol_version", &self.inner.protocol_version) .finish() } } diff --git a/questdb-rs/src/ingress/conf.rs b/questdb-rs/src/ingress/conf.rs index 8e3ac4a4..21dfdce3 100644 --- a/questdb-rs/src/ingress/conf.rs +++ b/questdb-rs/src/ingress/conf.rs @@ -22,9 +22,12 @@ * ******************************************************************************/ -use crate::{Error, ErrorCode, Result}; +use crate::error::{fmt, Error, ErrorCode, Result}; use std::ops::Deref; +#[cfg(feature = "_sender-http")] +pub(crate) const SETTINGS_RETRY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1); + /// Wraps a SenderBuilder config setting with the intent of tracking /// whether the value was user-specified or defaulted. /// This helps the builder API ensure that a user-specified value can't @@ -150,3 +153,21 @@ pub(crate) enum AuthParams { #[cfg(feature = "_sender-http")] Token(TokenAuthParams), } + +#[cfg(feature = "_sender-http")] +pub fn auth_params_to_header_string(auth: &Option) -> Result> { + Ok(match auth { + Some(AuthParams::Basic(ref auth)) => Some(auth.to_header_string()), + Some(AuthParams::Token(ref auth)) => Some(auth.to_header_string()?), + + #[cfg(feature = "sync-sender-tcp")] + Some(AuthParams::Ecdsa(_)) => { + return Err(fmt!( + AuthError, + "ECDSA authentication is not supported for ILP over HTTP. \ + Please use basic or token authentication instead." + )); + } + None => None, + }) +} diff --git a/questdb-rs/src/ingress/http_common.rs b/questdb-rs/src/ingress/http_common.rs new file mode 100644 index 00000000..a250a68b --- /dev/null +++ b/questdb-rs/src/ingress/http_common.rs @@ -0,0 +1,285 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2025 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ + +use crate::error::Result; +use crate::ingress::DebugBytes; +use crate::{fmt, ingress::ProtocolVersion, Error}; +use http::{HeaderMap, StatusCode}; +use std::fmt::Write; + +pub(crate) fn is_retriable_status_code(status: StatusCode) -> bool { + status.is_server_error() + && matches!( + status.as_u16(), + // Official HTTP codes + 500 | // Internal Server Error + 503 | // Service Unavailable + 504 | // Gateway Timeout + + // Unofficial extensions + 507 | // Insufficient Storage + 509 | // Bandwidth Limit Exceeded + 523 | // Origin is Unreachable + 524 | // A Timeout Occurred + 529 | // Site is overloaded + 599 // Network Connect Timeout Error + ) +} + +pub(crate) fn check_status_code(status: StatusCode, url: &str) -> Result<()> { + let code = status.as_u16(); + match status.as_u16() { + 404 => Err(fmt!( + HttpNotSupported, + "Could not flush buffer: HTTP endpoint does not support ILP." + )), + 401 | 403 => Err(fmt!( + AuthError, + "Could not flush buffer: HTTP endpoint authentication error [code: {code}]", + )), + _ if status.is_client_error() || status.is_server_error() => Err(fmt!( + SocketError, + "Could not flush buffer: {}: {}", + url, + status.as_str() + )), + _ => Ok(()), + } +} + +fn parse_server_settings( + response: &str, + settings_url: &str, + default_protocol_version: crate::ingress::ProtocolVersion, + default_max_name_len: usize, +) -> Result<(Vec, usize)> { + let json: serde_json::Value = serde_json::from_str(response).map_err(|_| { + crate::error::fmt!( + ProtocolVersionError, + "Malformed server response, settings url: {}, err: response is not valid JSON.", + settings_url, + ) + })?; + + let mut support_versions: Vec = vec![]; + if let Some(serde_json::Value::Array(ref values)) = json + .get("config") + .and_then(|v| v.get("line.proto.support.versions")) + { + for value in values.iter() { + if let Some(v) = value.as_u64() { + match v { + 1 => support_versions.push(ProtocolVersion::V1), + 2 => support_versions.push(ProtocolVersion::V2), + _ => {} + } + } + } + } else { + support_versions.push(default_protocol_version); + } + + let max_name_length = json + .get("config") + .and_then(|v| v.get("cairo.max.file.name.length")) + .and_then(|v| v.as_u64()) + .unwrap_or(default_max_name_len as u64) as usize; + Ok((support_versions, max_name_length)) +} + +pub(crate) fn pick_protocol_version( + server_versions: &[ProtocolVersion], +) -> Result { + [ProtocolVersion::V2, ProtocolVersion::V1] + .into_iter() + .find(|v| server_versions.contains(v)) + .ok_or_else(|| { + fmt!( + ProtocolVersionError, + "Server does not support current client" + ) + }) +} + +pub(crate) fn process_settings_response>( + response: Result<(StatusCode, ParsedResponseHeaders, P)>, + settings_url: &str, + default_protocol_version: ProtocolVersion, + default_max_name_len: usize, +) -> Result<(Vec, usize)> { + let body = match &response { + Ok((status, _header_data, body)) => { + if status.is_client_error() || status.is_server_error() { + if status.as_u16() == 404 { + return Ok((vec![default_protocol_version], default_max_name_len)); + } + return Err(fmt!( + ProtocolVersionError, + "Could not detect server's line protocol version, settings url: {settings_url}, status code: {status}." + )); + } + body.as_ref() + } + Err(e) => { + return Err(fmt!( + ProtocolVersionError, + "Could not read the server's protocol version from the server: {e}", + )) + } + }; + + let body_str = std::str::from_utf8(body).map_err(|utf8_error| { + fmt!( + ProtocolVersionError, + "Could not read the server's /settings response as a string: {:?}: {utf8_error}", + DebugBytes(body) + ) + })?; + + parse_server_settings( + body_str, + settings_url, + default_protocol_version, + default_max_name_len, + ) +} + +fn parse_json_error(json: &serde_json::Value, msg: &str) -> Error { + let mut description = msg.to_string(); + fmt!(ServerFlushError, "Could not flush buffer: {}", msg); + + let error_id = json.get("errorId").and_then(|v| v.as_str()); + let code = json.get("code").and_then(|v| v.as_str()); + let line = json.get("line").and_then(|v| v.as_i64()); + + let mut printed_detail = false; + if error_id.is_some() || code.is_some() || line.is_some() { + description.push_str(" ["); + + if let Some(error_id) = error_id { + description.push_str("id: "); + description.push_str(error_id); + printed_detail = true; + } + + if let Some(code) = code { + if printed_detail { + description.push_str(", "); + } + description.push_str("code: "); + description.push_str(code); + printed_detail = true; + } + + if let Some(line) = line { + if printed_detail { + description.push_str(", "); + } + description.push_str("line: "); + write!(description, "{line}").unwrap(); + } + + description.push(']'); + } + + fmt!(ServerFlushError, "Could not flush buffer: {}", description) +} + +/// Pre-parsed header data fields. +/// Preparsing avoids copying/allocating a heavier `http::header::map::HeaderMap` object. +#[derive(Debug, Default)] +pub(crate) struct ParsedResponseHeaders { + /// "Content-Type" was "application/json" + json_content_type: bool, +} + +impl ParsedResponseHeaders { + pub fn parse(headers: &HeaderMap) -> Self { + let json_content_type = headers + .get("Content-Type") + .and_then(|ct| ct.to_str().ok()) + .is_some_and(|ct| ct.eq_ignore_ascii_case("application/json")); + Self { json_content_type } + } +} + +pub(crate) fn parse_http_error>( + status: StatusCode, + header: ParsedResponseHeaders, + body: P, +) -> Error { + let body = body.as_ref(); + let msg = match std::str::from_utf8(body) { + Ok(body_str) => body_str, + Err(utf8_error) => { + return fmt!( + ServerFlushError, + "Could not read the server's flush response as a string: {:?}: {utf8_error}", + DebugBytes(body) + ); + } + }; + + let code = status.as_u16(); + match (status.as_u16(), msg) { + (404, _) => { + return fmt!( + HttpNotSupported, + "Could not flush buffer: HTTP endpoint does not support ILP." + ); + } + (401, "") | (403, "") => { + return fmt!( + AuthError, + "Could not flush buffer: HTTP endpoint authentication error [code: {code}]" + ); + } + (401, msg) | (403, msg) => { + return fmt!( + AuthError, + "Could not flush buffer: HTTP endpoint authentication error: {msg} [code: {code}]" + ); + } + _ => (), + } + + let string_err = || fmt!(ServerFlushError, "Could not flush buffer: {}", msg); + + if !header.json_content_type { + return string_err(); + } + + let json: serde_json::Value = match serde_json::from_str(&msg) { + Ok(json) => json, + Err(_) => { + return string_err(); + } + }; + + if let Some(serde_json::Value::String(ref msg)) = json.get("message") { + parse_json_error(&json, msg) + } else { + string_err() + } +} diff --git a/questdb-rs/src/ingress/mod.rs b/questdb-rs/src/ingress/mod.rs index 473e6e13..427b077b 100644 --- a/questdb-rs/src/ingress/mod.rs +++ b/questdb-rs/src/ingress/mod.rs @@ -59,8 +59,20 @@ mod timestamp; mod buffer; pub use buffer::*; -mod sender; -pub use sender::*; +#[cfg(feature = "_sender-http")] +mod http_common; + +#[cfg(feature = "_sync-sender")] +mod sync_sender; + +#[cfg(feature = "_sync-sender")] +pub use sync_sender::*; + +#[cfg(feature = "_async-sender")] +mod async_sender; + +#[cfg(feature = "_async-sender")] +pub use async_sender::*; const MAX_NAME_LEN_DEFAULT: usize = 127; @@ -609,7 +621,7 @@ impl SenderBuilder { tls_ca: ConfigSetting::new_default(tls_ca), tls_roots: ConfigSetting::new_default(None), - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] http: if protocol.is_httpx() { Some(conf::HttpConfig::default()) } else { @@ -1008,6 +1020,53 @@ impl SenderBuilder { } } + #[cfg(feature = "_async-sender")] + pub async fn build_async(self) -> Result> { + if !self.protocol.is_httpx() { + return Err(fmt!( + ConfigError, + "Only the HTTP and HTTPS protocols are supported by the AsyncSender." + )); + } + + let mut descr = format!("Sender[host={:?},port={:?},", self.host, self.port); + + if self.protocol.tls_enabled() { + write!(descr, "tls=enabled,").unwrap(); + } else { + write!(descr, "tls=disabled,").unwrap(); + } + + #[cfg(feature = "insecure-skip-verify")] + let tls_verify = *self.tls_verify; + + let tls_settings = tls::TlsSettings::build( + self.protocol.tls_enabled(), + #[cfg(feature = "insecure-skip-verify")] + tls_verify, + *self.tls_ca, + self.tls_roots.deref().as_deref(), + )?; + + let auth = self.build_auth()?; + let auth = conf::auth_params_to_header_string(&auth)?; + + let http_config = self.http.unwrap(); + + AsyncSender::new( + descr, + self.host.deref(), + self.port.deref(), + tls_settings, + auth, + *self.max_name_len.deref(), + *self.max_buf_size.deref(), + *self.protocol_version.deref(), + http_config, + ) + .await + } + #[cfg(feature = "_sync-sender")] /// Build the sender. /// @@ -1074,20 +1133,8 @@ impl SenderBuilder { let connector = connector.chain(TlsConnector::new(tls_config)); - let auth = match auth { - Some(conf::AuthParams::Basic(ref auth)) => Some(auth.to_header_string()), - Some(conf::AuthParams::Token(ref auth)) => Some(auth.to_header_string()?), - - #[cfg(feature = "sync-sender-tcp")] - Some(conf::AuthParams::Ecdsa(_)) => { - return Err(fmt!( - AuthError, - "ECDSA authentication is not supported for ILP over HTTP. \ - Please use basic or token authentication instead." - )); - } - None => None, - }; + let auth = conf::auth_params_to_header_string(&auth)?; + let agent_builder = agent_builder .timeout_connect(Some(*http_config.request_timeout.deref())) .http_status_as_error(false); @@ -1133,16 +1180,7 @@ impl SenderBuilder { let (protocol_versions, server_max_name_len) = read_server_settings(http_state, settings_url, max_name_len)?; max_name_len = server_max_name_len; - if protocol_versions.contains(&ProtocolVersion::V2) { - ProtocolVersion::V2 - } else if protocol_versions.contains(&ProtocolVersion::V1) { - ProtocolVersion::V1 - } else { - return Err(fmt!( - ProtocolVersionError, - "Server does not support current client" - )); - } + http_common::pick_protocol_version(&protocol_versions[..])? } else { unreachable!("HTTP handler should be used for HTTP protocol"); } @@ -1285,6 +1323,23 @@ fn parse_key_pair(auth: &conf::EcdsaAuthParams) -> Result { }) } +#[inline(always)] +fn check_protocol_version( + sender_version: ProtocolVersion, + buffer_version: ProtocolVersion, +) -> Result<()> { + if sender_version != buffer_version { + return Err(fmt!( + ProtocolVersionError, + "Attempting to send with protocol version {} \ + but the sender is configured to use protocol version {}", + buffer_version, + sender_version + )); + } + Ok(()) +} + struct DebugBytes<'a>(pub &'a [u8]); impl<'a> Debug for DebugBytes<'a> { diff --git a/questdb-rs/src/ingress/sender/http.rs b/questdb-rs/src/ingress/sync_sender/http.rs similarity index 53% rename from questdb-rs/src/ingress/sender/http.rs rename to questdb-rs/src/ingress/sync_sender/http.rs index 3a99be8e..fe599d00 100644 --- a/questdb-rs/src/ingress/sender/http.rs +++ b/questdb-rs/src/ingress/sync_sender/http.rs @@ -23,7 +23,11 @@ ******************************************************************************/ use crate::error::fmt; +use crate::ingress::http_common::{ + is_retriable_status_code, process_settings_response, ParsedResponseHeaders, +}; use crate::{error, Error}; +use http::{HeaderMap, StatusCode}; use rand::Rng; use rustls::{ClientConnection, StreamOwned}; use rustls_pki_types::ServerName; @@ -38,12 +42,11 @@ use ureq::unversioned::transport::{ Buffers, Connector, LazyBuffers, NextTimeout, Transport, TransportAdapter, }; -use crate::ingress::conf::HttpConfig; -use crate::ingress::ProtocolVersion; +use crate::ingress::conf::{HttpConfig, SETTINGS_RETRY_TIMEOUT}; +use crate::ingress::{DebugBytes, ProtocolVersion}; use ureq::unversioned::*; use ureq::{http, Body}; -#[cfg(feature = "sync-sender-http")] pub(crate) struct SyncHttpHandlerState { /// Maintains a pool of open HTTP connections to the endpoint. pub(crate) agent: ureq::Agent, @@ -58,13 +61,15 @@ pub(crate) struct SyncHttpHandlerState { pub(crate) config: HttpConfig, } -#[cfg(feature = "sync-sender-http")] impl SyncHttpHandlerState { fn send_request( &self, buf: &[u8], request_timeout: Duration, - ) -> (bool, Result, ureq::Error>) { + ) -> ( + bool, + error::Result<(StatusCode, ParsedResponseHeaders, Vec)>, + ) { let request = self .agent .post(&self.url) @@ -78,10 +83,52 @@ impl SyncHttpHandlerState { Some(auth) => request.header("Authorization", auth), None => request, }; + let response = request.send(buf); - match &response { - Ok(res) => (need_retry(Ok(res.status())), response), - Err(err) => (need_retry(Err(err)), response), + match response { + Ok(response_body) => { + let status = response_body.status(); + let (parts, mut body) = response_body.into_parts(); + let headers = parts.headers; + let header_data = ParsedResponseHeaders::parse(&headers); + let need_retry = is_retriable_status_code(status); + match body.read_to_vec() { + Ok(body) => (need_retry, Ok((status, header_data, body))), + Err(err) => ( + need_retry, + Err(fmt!( + SocketError, + "Could not flush buffer: {}: {err}", + &self.url + )), + ), + } + } + Err(ureq::Error::StatusCode(code)) => { + let status = + StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let need_retry = is_retriable_status_code(status); + ( + need_retry, + Ok((status, ParsedResponseHeaders::default(), Vec::new())), + ) + } + Err(err) => { + let need_retry = matches!( + err, + ureq::Error::Timeout(_) + | ureq::Error::ConnectionFailed + | ureq::Error::TooManyRedirects + ); + ( + need_retry, + Err(fmt!( + SocketError, + "Could not flush buffer: {}: {err}", + &self.url + )), + ) + } } } @@ -209,26 +256,9 @@ impl Transport for TlsTransport { } } -fn need_retry(res: Result) -> bool { +fn need_retry(res: Result) -> bool { match res { - Ok(status) => { - status.is_server_error() - && matches!( - status.as_u16(), - // Official HTTP codes - 500 | // Internal Server Error - 503 | // Service Unavailable - 504 | // Gateway Timeout - - // Unofficial extensions - 507 | // Insufficient Storage - 509 | // Bandwidth Limit Exceeded - 523 | // Origin is Unreachable - 524 | // A Timeout Occurred - 529 | // Site is overloaded - 599 // Network Connect Timeout Error - ) - } + Ok(status) => is_retriable_status_code(status), Err(err) => matches!( err, ureq::Error::Timeout(_) | ureq::Error::ConnectionFailed | ureq::Error::TooManyRedirects @@ -236,110 +266,13 @@ fn need_retry(res: Result) -> bool { } } -fn parse_json_error(json: &serde_json::Value, msg: &str) -> Error { - let mut description = msg.to_string(); - error::fmt!(ServerFlushError, "Could not flush buffer: {}", msg); - - let error_id = json.get("errorId").and_then(|v| v.as_str()); - let code = json.get("code").and_then(|v| v.as_str()); - let line = json.get("line").and_then(|v| v.as_i64()); - - let mut printed_detail = false; - if error_id.is_some() || code.is_some() || line.is_some() { - description.push_str(" ["); - - if let Some(error_id) = error_id { - description.push_str("id: "); - description.push_str(error_id); - printed_detail = true; - } - - if let Some(code) = code { - if printed_detail { - description.push_str(", "); - } - description.push_str("code: "); - description.push_str(code); - printed_detail = true; - } - - if let Some(line) = line { - if printed_detail { - description.push_str(", "); - } - description.push_str("line: "); - write!(description, "{line}").unwrap(); - } - - description.push(']'); - } - - error::fmt!(ServerFlushError, "Could not flush buffer: {}", description) -} - -pub(super) fn parse_http_error(http_status_code: u16, response: Response) -> Error { - let (head, body) = response.into_parts(); - let body_content = body.into_with_config().lossy_utf8(true).read_to_string(); - if http_status_code == 404 { - return error::fmt!( - HttpNotSupported, - "Could not flush buffer: HTTP endpoint does not support ILP." - ); - } else if [401, 403].contains(&http_status_code) { - let description = match body_content { - Ok(msg) if !msg.is_empty() => format!(": {msg}"), - _ => "".to_string(), - }; - return error::fmt!( - AuthError, - "Could not flush buffer: HTTP endpoint authentication error{} [code: {}]", - description, - http_status_code - ); - } - - let is_json = match head.headers.get("Content-Type") { - Some(header_value) => match header_value.to_str() { - Ok(s) => s.eq_ignore_ascii_case("application/json"), - Err(_) => false, - }, - None => false, - }; - match body_content { - Ok(msg) => { - let string_err = || error::fmt!(ServerFlushError, "Could not flush buffer: {}", msg); - - if !is_json { - return string_err(); - } - - let json: serde_json::Value = match serde_json::from_str(&msg) { - Ok(json) => json, - Err(_) => { - return string_err(); - } - }; - - if let Some(serde_json::Value::String(ref msg)) = json.get("message") { - parse_json_error(&json, msg) - } else { - string_err() - } - } - Err(err) => { - error::fmt!(SocketError, "Could not flush buffer: {}", err) - } - } -} - -#[allow(clippy::result_large_err)] // `ureq::Error` is large enough to cause this warning. fn retry_http_send( state: &SyncHttpHandlerState, buf: &[u8], request_timeout: Duration, retry_timeout: Duration, - mut last_rep: Result, ureq::Error>, -) -> Result, ureq::Error> { + mut last_response: error::Result<(StatusCode, ParsedResponseHeaders, Vec)>, +) -> error::Result<(StatusCode, ParsedResponseHeaders, Vec)> { let mut rng = rand::rng(); let retry_end = std::time::Instant::now() + retry_timeout; let mut retry_interval_ms = 10; @@ -349,17 +282,12 @@ fn retry_http_send( let to_sleep_ms = retry_interval_ms + jitter_ms; let to_sleep = Duration::from_millis(to_sleep_ms as u64); if (std::time::Instant::now() + to_sleep) > retry_end { - return last_rep; + return last_response; } sleep(to_sleep); - if let Ok(last_rep) = last_rep { - // Actively consume the reader to return the connection to the connection pool. - // see https://github.com/algesten/ureq/issues/94 - _ = last_rep.into_body().read_to_vec(); - } - (need_retry, last_rep) = state.send_request(buf, request_timeout); + (need_retry, last_response) = state.send_request(buf, request_timeout); if !need_retry { - return last_rep; + return last_response; } retry_interval_ms = (retry_interval_ms * 2).min(1000); } @@ -371,13 +299,13 @@ pub(super) fn http_send_with_retries( buf: &[u8], request_timeout: Duration, retry_timeout: Duration, -) -> Result, ureq::Error> { - let (need_retry, last_rep) = state.send_request(buf, request_timeout); +) -> error::Result<(StatusCode, ParsedResponseHeaders, Vec)> { + let (need_retry, last_response) = state.send_request(buf, request_timeout); if !need_retry || retry_timeout.is_zero() { - return last_rep; + return last_response; } - retry_http_send(state, buf, request_timeout, retry_timeout, last_rep) + retry_http_send(state, buf, request_timeout, retry_timeout, last_response) } /// Read the server settings from the `/settings` endpoint. @@ -391,101 +319,53 @@ pub(crate) fn read_server_settings( state: &SyncHttpHandlerState, settings_url: &str, default_max_name_len: usize, -) -> Result<(Vec, usize), Error> { +) -> error::Result<(Vec, usize)> { let default_protocol_version = ProtocolVersion::V1; - let response = match http_get_with_retries( + let response = http_get_with_retries( state, settings_url, *state.config.request_timeout, - Duration::from_secs(1), - ) { - Ok(res) => { - if res.status().is_client_error() || res.status().is_server_error() { - let status = res.status(); - _ = res.into_body().read_to_vec(); - if status.as_u16() == 404 { - return Ok((vec![default_protocol_version], default_max_name_len)); - } - return Err(fmt!( + SETTINGS_RETRY_TIMEOUT, + ); + + // Fully read the response. + let response = match response { + Ok(response) => { + let status = response.status(); + let header_data = ParsedResponseHeaders::parse(response.headers()); + match response.into_body().read_to_vec() { + Ok(body) => Ok((status, header_data, body)), + Err(_) if status.as_u16() == 404 => + Ok((status, header_data, Vec::new())), + Err(err) => Err(fmt!( ProtocolVersionError, - "Could not detect server's line protocol version, settings url: {}, status code: {}.", + "Could not detect server's line protocol version, settings url: {}, status code: {}, err: {}", settings_url, - status - )); - } else { - res + status, + err + )) } } - Err(err) => { - let e = match err { - ureq::Error::StatusCode(code) => { - if code == 404 { - return Ok((vec![default_protocol_version], default_max_name_len)); - } else { - fmt!( - ProtocolVersionError, - "Could not detect server's line protocol version, settings url: {}, err: {}.", - settings_url, - err - ) - } - } - e => { - fmt!( - ProtocolVersionError, - "Could not detect server's line protocol version, settings url: {}, err: {}.", - settings_url, - e - ) - } - }; - return Err(e); - } + Err(ureq::Error::StatusCode(code)) => Ok(( + StatusCode::from_u16(code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + ParsedResponseHeaders::default(), + Vec::new(), + )), + Err(err) => Err(fmt!( + ProtocolVersionError, + "Could not detect server's line protocol version, settings url: {}, err: {}", + settings_url, + err + )), }; - let (_, body) = response.into_parts(); - let body_content = body.into_with_config().read_to_string(); - - if let Ok(msg) = body_content { - let json: serde_json::Value = serde_json::from_str(&msg).map_err(|_| { - error::fmt!( - ProtocolVersionError, - "Malformed server response, settings url: {}, err: response is not valid JSON.", - settings_url, - ) - })?; - - let mut support_versions: Vec = vec![]; - if let Some(serde_json::Value::Array(ref values)) = json - .get("config") - .and_then(|v| v.get("line.proto.support.versions")) - { - for value in values.iter() { - if let Some(v) = value.as_u64() { - match v { - 1 => support_versions.push(ProtocolVersion::V1), - 2 => support_versions.push(ProtocolVersion::V2), - _ => {} - } - } - } - } else { - support_versions.push(default_protocol_version); - } - - let max_name_length = json - .get("config") - .and_then(|v| v.get("cairo.max.file.name.length")) - .and_then(|v| v.as_u64()) - .unwrap_or(default_max_name_len as u64) as usize; - Ok((support_versions, max_name_length)) - } else { - Err(error::fmt!( - ProtocolVersionError, - "Malformed server response, settings url: {}, err: failed to read response body as UTF-8", settings_url - )) - } + process_settings_response( + response, + settings_url, + default_protocol_version, + default_max_name_len, + ) } #[allow(clippy::result_large_err)] // `ureq::Error` is large enough to cause this warning. @@ -494,7 +374,7 @@ fn retry_http_get( url: &str, request_timeout: Duration, retry_timeout: Duration, - mut last_rep: Result, ureq::Error>, + mut last_response: Result, ureq::Error>, ) -> Result, ureq::Error> { let mut rng = rand::rng(); let retry_end = std::time::Instant::now() + retry_timeout; @@ -505,17 +385,17 @@ fn retry_http_get( let to_sleep_ms = retry_interval_ms + jitter_ms; let to_sleep = Duration::from_millis(to_sleep_ms as u64); if (std::time::Instant::now() + to_sleep) > retry_end { - return last_rep; + return last_response; } sleep(to_sleep); - if let Ok(last_rep) = last_rep { + if let Ok(last_response) = last_response { // Actively consume the reader to return the connection to the connection pool. // see https://github.com/algesten/ureq/issues/94 - _ = last_rep.into_body().read_to_vec(); + _ = last_response.into_body().read_to_vec(); } - (need_retry, last_rep) = state.get_request(url, request_timeout); + (need_retry, last_response) = state.get_request(url, request_timeout); if !need_retry { - return last_rep; + return last_response; } retry_interval_ms = (retry_interval_ms * 2).min(1000); } @@ -528,10 +408,10 @@ fn http_get_with_retries( request_timeout: Duration, retry_timeout: Duration, ) -> Result, ureq::Error> { - let (need_retry, last_rep) = state.get_request(url, request_timeout); + let (need_retry, last_response) = state.get_request(url, request_timeout); if !need_retry || retry_timeout.is_zero() { - return last_rep; + return last_response; } - retry_http_get(state, url, request_timeout, retry_timeout, last_rep) + retry_http_get(state, url, request_timeout, retry_timeout, last_response) } diff --git a/questdb-rs/src/ingress/sender/mod.rs b/questdb-rs/src/ingress/sync_sender/mod.rs similarity index 90% rename from questdb-rs/src/ingress/sender/mod.rs rename to questdb-rs/src/ingress/sync_sender/mod.rs index 6b65cce2..24f6cc77 100644 --- a/questdb-rs/src/ingress/sender/mod.rs +++ b/questdb-rs/src/ingress/sync_sender/mod.rs @@ -23,7 +23,7 @@ ******************************************************************************/ use crate::error::{self, Result}; -use crate::ingress::{Buffer, ProtocolVersion, SenderBuilder}; +use crate::ingress::{check_protocol_version, Buffer, ProtocolVersion, SenderBuilder}; use std::fmt::{Debug, Formatter}; #[cfg(feature = "sync-sender-tcp")] @@ -41,6 +41,7 @@ use crate::ingress::map_io_to_socket_err; #[cfg(feature = "sync-sender-http")] mod http; +use crate::ingress::http_common::parse_http_error; #[cfg(feature = "sync-sender-http")] pub(crate) use http::*; @@ -60,6 +61,8 @@ pub(crate) enum SyncProtocolHandler { pub struct Sender { descr: String, handler: SyncProtocolHandler, + + #[cfg(feature = "sync-sender-tcp")] connected: bool, max_buf_size: usize, protocol_version: ProtocolVersion, @@ -83,7 +86,10 @@ impl Sender { Self { descr, handler, + + #[cfg(feature = "sync-sender-tcp")] connected: true, + max_buf_size, protocol_version, max_name_len, @@ -135,6 +141,7 @@ impl Sender { #[allow(unused_variables)] fn flush_impl(&mut self, buf: &Buffer, transactional: bool) -> Result<()> { + #[cfg(feature = "sync-sender-tcp")] if !self.connected { return Err(error::fmt!( SocketError, @@ -152,7 +159,7 @@ impl Sender { )); } - self.check_protocol_version(buf.protocol_version())?; + check_protocol_version(self.protocol_version, buf.protocol_version())?; let bytes = buf.as_bytes(); if bytes.is_empty() { @@ -193,22 +200,18 @@ impl Sender { 0.0f64 }; - match http_send_with_retries( + let (status, header_data, response) = http_send_with_retries( state, bytes, *state.config.request_timeout + std::time::Duration::from_secs_f64(extra_time), *state.config.retry_timeout, - ) { - Ok(res) => { - if res.status().is_client_error() || res.status().is_server_error() { - Err(parse_http_error(res.status().as_u16(), res)) - } else { - res.into_body(); - Ok(()) - } - } - Err(err) => Err(crate::error::Error::from_ureq_error(err, &state.url)), + )?; + + if status.is_client_error() || status.is_server_error() { + return Err(parse_http_error(status, header_data, response)); } + + Ok(()) } } } @@ -297,18 +300,4 @@ impl Sender { pub fn max_name_len(&self) -> usize { self.max_name_len } - - #[inline(always)] - fn check_protocol_version(&self, version: ProtocolVersion) -> Result<()> { - if self.protocol_version != version { - return Err(error::fmt!( - ProtocolVersionError, - "Attempting to send with protocol version {} \ - but the sender is configured to use protocol version {}", - version, - self.protocol_version - )); - } - Ok(()) - } } diff --git a/questdb-rs/src/ingress/sender/tcp.rs b/questdb-rs/src/ingress/sync_sender/tcp.rs similarity index 100% rename from questdb-rs/src/ingress/sender/tcp.rs rename to questdb-rs/src/ingress/sync_sender/tcp.rs diff --git a/questdb-rs/src/ingress/tls.rs b/questdb-rs/src/ingress/tls.rs index 85a84fd9..9c1fdda2 100644 --- a/questdb-rs/src/ingress/tls.rs +++ b/questdb-rs/src/ingress/tls.rs @@ -103,7 +103,6 @@ fn add_os_roots(root_store: &mut RootCertStore) -> crate::Result<()> { Ok(()) } -#[derive(Debug)] pub(crate) enum TlsSettings { #[cfg(feature = "insecure-skip-verify")] SkipVerify, diff --git a/questdb-rs/src/tests/async_http.rs b/questdb-rs/src/tests/async_http.rs new file mode 100644 index 00000000..43a86999 --- /dev/null +++ b/questdb-rs/src/tests/async_http.rs @@ -0,0 +1,149 @@ +/******************************************************************************* + * ___ _ ____ ____ + * / _ \ _ _ ___ ___| |_| _ \| __ ) + * | | | | | | |/ _ \/ __| __| | | | _ \ + * | |_| | |_| | __/\__ \ |_| |_| | |_) | + * \__\_\\__,_|\___||___/\__|____/|____/ + * + * Copyright (c) 2014-2019 Appsicle + * Copyright (c) 2019-2025 QuestDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + ******************************************************************************/ +use crate::ingress::ProtocolVersion; +use crate::tests::mock::HttpResponse; +use crate::{ + ingress::{SenderBuilder, TimestampNanos}, + tests::{mock::MockServer, TestResult}, +}; +use std::io; + +async fn _test_sender_auto_detect_protocol_version( + supported_versions: Option>, + expect_version: ProtocolVersion, + max_name_len: usize, + expect_max_name_len: usize, +) -> TestResult { + let supported_versions1 = supported_versions.clone(); + let mut server = MockServer::new()? + .configure_settings_response(supported_versions.as_deref().unwrap_or(&[]), max_name_len); + let sender_builder = server.lsb_http(); + + let server_thread = std::thread::spawn(move || -> io::Result { + server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + match supported_versions1 { + None => server.send_http_response_q( + HttpResponse::empty() + .with_status(404, "Not Found") + .with_header("content-type", "text/plain") + .with_body_str("Not Found"), + )?, + Some(_) => server.send_settings_response()?, + } + let exp = &[ + b"test,t1=v1 ", + crate::tests::f64_to_bytes("f1", 0.5, expect_version).as_slice(), + b" 10000000\n", + ] + .concat(); + let req = server.recv_http_q()?; + assert_eq!(req.body(), exp); + server.send_http_response_q(HttpResponse::empty())?; + Ok(server) + }); + + let mut sender = sender_builder.build_async().await?; + // assert_eq!(sender.protocol_version(), expect_version); + // assert_eq!(sender.max_name_len(), expect_max_name_len); + let mut buf = sender.new_buffer(); + buf.table("test")? + .symbol("t1", "v1")? + .column_f64("f1", 0.5)? + .at(TimestampNanos::new(10000000))?; + sender.flush(buf).await?; + _ = server_thread.join().unwrap()?; + Ok(()) +} + +#[tokio::test] +async fn test_sender_auto_protocol_version_basic() -> TestResult { + _test_sender_auto_detect_protocol_version(Some(vec![1, 2]), ProtocolVersion::V2, 130, 130).await +} + +#[tokio::test] +async fn test_sender_auto_protocol_version_old_server1() -> TestResult { + _test_sender_auto_detect_protocol_version(Some(vec![]), ProtocolVersion::V1, 0, 127).await +} + +#[tokio::test] +async fn test_sender_auto_protocol_version_old_server2() -> TestResult { + _test_sender_auto_detect_protocol_version(None, ProtocolVersion::V1, 0, 127).await +} + +#[tokio::test] +async fn test_sender_auto_protocol_version_only_v1() -> TestResult { + _test_sender_auto_detect_protocol_version(Some(vec![1]), ProtocolVersion::V1, 127, 127).await +} + +#[tokio::test] +async fn test_sender_auto_protocol_version_only_v2() -> TestResult { + _test_sender_auto_detect_protocol_version(Some(vec![2]), ProtocolVersion::V2, 127, 127).await +} + +// #[tokio::test] +// async fn test_two_lines() -> TestResult { +// let mut server = MockServer::new()?; +// let sender_builder = server.lsb_http(); +// +// let server_thread = std::thread::spawn(move || -> io::Result { +// server.accept()?; +// let req = server.recv_http_q()?; +// assert_eq!(req.method(), "GET"); +// assert_eq!(req.path(), "/settings"); +// // match supported_versions1 { +// // None => server.send_http_response_q( +// // HttpResponse::empty() +// // .with_status(404, "Not Found") +// // .with_header("content-type", "text/plain") +// // .with_body_str("Not Found"), +// // )?, +// // Some(_) => server.send_settings_response()?, +// // } +// server.send_settings_response()?; +// // let exp = &[ +// // b"test,t1=v1 ", +// // crate::tests::sync_sender::f64_to_bytes("f1", 0.5, expect_version).as_slice(), +// // b" 10000000\n", +// // ] +// // .concat(); +// // let req = server.recv_http_q()?; +// // assert_eq!(req.body(), exp); +// // server.send_http_response_q(HttpResponse::empty())?; +// Ok(server) +// }); +// +// let mut sender = sender_builder +// .build_async() +// .await?; +// let mut txn = sender.new_transaction("table1")?; +// txn.row()? +// .symbol("a", "B")? +// .column_f64("b", 10.25)? +// .at(TimestampNanos::now())?; +// txn.commit().await?; +// Ok(()) +// } diff --git a/questdb-rs/src/tests/mock.rs b/questdb-rs/src/tests/mock.rs index 91880fb4..877eacc7 100644 --- a/questdb-rs/src/tests/mock.rs +++ b/questdb-rs/src/tests/mock.rs @@ -39,7 +39,7 @@ use std::time::Instant; #[cfg(feature = "sync-sender-tcp")] use crate::tests::ndarr::ArrayColumnTypeTag; -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] use std::io::Write; const CLIENT: Token = Token(0); @@ -57,7 +57,7 @@ pub struct MockServer { #[cfg(feature = "sync-sender-tcp")] pub msgs: Vec>, - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] settings_response: serde_json::Value, } @@ -87,7 +87,7 @@ fn tls_config() -> Arc { Arc::new(config) } -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] pub struct HttpRequest { method: String, path: String, @@ -95,7 +95,7 @@ pub struct HttpRequest { body: Vec, } -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] impl HttpRequest { pub fn method(&self) -> &str { &self.method @@ -114,7 +114,7 @@ impl HttpRequest { } } -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] pub struct HttpResponse { status_code: u16, status_text: String, @@ -122,7 +122,7 @@ pub struct HttpResponse { body: Vec, } -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] impl HttpResponse { pub fn empty() -> Self { HttpResponse { @@ -184,14 +184,14 @@ impl HttpResponse { } } -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] fn contains(haystack: &[u8], needle: &[u8]) -> bool { haystack .windows(needle.len()) .any(|window| window == needle) } -#[cfg(feature = "sync-sender-http")] +#[cfg(feature = "_sender-http")] fn position(haystack: &[u8], needle: &[u8]) -> Option { haystack .windows(needle.len()) @@ -217,7 +217,7 @@ impl MockServer { #[cfg(feature = "sync-sender-tcp")] msgs: Vec::new(), - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] settings_response: serde_json::Value::Null, }) } @@ -301,7 +301,7 @@ impl MockServer { self.wait_for(timeout, |event| event.is_readable()) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn wait_for_send(&mut self, duration: Option) -> io::Result { self.wait_for(duration, |event| event.is_writable()) } @@ -316,7 +316,7 @@ impl MockServer { } } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn configure_settings_response( mut self, supported_versions: &[u16], @@ -339,7 +339,7 @@ impl MockServer { self } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] fn do_write(&mut self, buf: &[u8]) -> io::Result { let client = self.client.as_mut().unwrap(); if let Some(tls_conn) = self.tls_conn.as_mut() { @@ -350,7 +350,7 @@ impl MockServer { } } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] fn do_write_all(&mut self, buf: &[u8], timeout_sec: Option) -> io::Result<()> { let deadline = timeout_sec.map(|sec| Instant::now() + Duration::from_secs_f64(sec)); let mut pos = 0usize; @@ -385,7 +385,7 @@ impl MockServer { Ok(()) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] fn read_more(&mut self, accum: &mut Vec, deadline: Instant, stage: &str) -> io::Result<()> { let mut chunk = [0u8; 1024]; let count = match self.do_read(&mut chunk[..]) { @@ -422,7 +422,7 @@ impl MockServer { Ok(()) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] fn recv_http_method( &mut self, accum: &mut Vec, @@ -451,7 +451,7 @@ impl MockServer { Ok((body_start, method, path)) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] fn recv_http_headers( &mut self, pos: usize, @@ -480,7 +480,7 @@ impl MockServer { Ok((body_start, headers)) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn send_http_response( &mut self, response: HttpResponse, @@ -490,7 +490,7 @@ impl MockServer { Ok(()) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn send_settings_response(&mut self) -> io::Result<()> { let response = HttpResponse::empty() .with_status(200, "OK") @@ -499,12 +499,12 @@ impl MockServer { Ok(()) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn send_http_response_q(&mut self, response: HttpResponse) -> io::Result<()> { self.send_http_response(response, Some(5.0)) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn recv_http(&mut self, wait_timeout_sec: f64) -> io::Result { let mut accum = Vec::::new(); let deadline = Instant::now() + Duration::from_secs_f64(wait_timeout_sec); @@ -537,7 +537,7 @@ impl MockServer { }) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn recv_http_q(&mut self) -> io::Result { self.recv_http(5.0) } @@ -636,12 +636,12 @@ impl MockServer { SenderBuilder::new(Protocol::Tcps, self.host, self.port) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn lsb_http(&self) -> SenderBuilder { SenderBuilder::new(Protocol::Http, self.host, self.port) } - #[cfg(feature = "sync-sender-http")] + #[cfg(feature = "_sender-http")] pub fn lsb_https(&self) -> SenderBuilder { SenderBuilder::new(Protocol::Https, self.host, self.port) } diff --git a/questdb-rs/src/tests/mod.rs b/questdb-rs/src/tests/mod.rs index 5611c74f..aa0d84c9 100644 --- a/questdb-rs/src/tests/mod.rs +++ b/questdb-rs/src/tests/mod.rs @@ -22,13 +22,20 @@ * ******************************************************************************/ +use crate::ingress::{F64Serializer, ProtocolVersion, DOUBLE_BINARY_FORMAT_TYPE}; + mod f64_serializer; #[cfg(feature = "sync-sender-http")] -mod http; +mod sync_http; + +#[cfg(feature = "async-sender-http")] +mod async_http; mod mock; -mod sender; + +#[cfg(feature = "_sync-sender")] +mod sync_sender; mod ndarr; @@ -66,3 +73,22 @@ pub fn assert_err_contains( } } } + +pub(crate) fn f64_to_bytes(name: &str, value: f64, version: ProtocolVersion) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(name.as_bytes()); + buf.push(b'='); + + match version { + ProtocolVersion::V1 => { + let mut ser = F64Serializer::new(value); + buf.extend_from_slice(ser.as_str().as_bytes()); + } + ProtocolVersion::V2 => { + buf.push(b'='); + buf.push(DOUBLE_BINARY_FORMAT_TYPE); + buf.extend_from_slice(&value.to_le_bytes()); + } + } + buf +} diff --git a/questdb-rs/src/tests/http.rs b/questdb-rs/src/tests/sync_http.rs similarity index 85% rename from questdb-rs/src/tests/http.rs rename to questdb-rs/src/tests/sync_http.rs index d66242db..ff82845d 100644 --- a/questdb-rs/src/tests/http.rs +++ b/questdb-rs/src/tests/sync_http.rs @@ -24,17 +24,23 @@ use crate::ingress::{Buffer, Protocol, ProtocolVersion, SenderBuilder, TimestampNanos}; use crate::tests::mock::{certs_dir, HttpResponse, MockServer}; -use crate::tests::{assert_err_contains, TestResult}; +use crate::tests::{assert_err_contains, f64_to_bytes, TestResult}; use crate::ErrorCode; -use rstest::rstest; use std::io; use std::io::ErrorKind; use std::time::Duration; -#[rstest] -fn test_two_lines( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_two_lines_v1() -> TestResult { + _test_two_lines(ProtocolVersion::V1) +} + +#[test] +fn test_two_lines_v2() -> TestResult { + _test_two_lines(ProtocolVersion::V2) +} + +fn _test_two_lines(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; let mut buffer = sender.new_buffer(); @@ -77,10 +83,17 @@ fn test_two_lines( Ok(()) } -#[rstest] -fn test_text_plain_error( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_text_plain_error_v1() -> TestResult { + _test_text_plain_error(ProtocolVersion::V1) +} + +#[test] +fn test_text_plain_error_v2() -> TestResult { + _test_text_plain_error(ProtocolVersion::V2) +} + +fn _test_text_plain_error(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; let mut buffer = sender.new_buffer(); @@ -120,10 +133,17 @@ fn test_text_plain_error( Ok(()) } -#[rstest] -fn test_bad_json_error( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_bad_json_error_v1() -> TestResult { + _test_bad_json_error(ProtocolVersion::V1) +} + +#[test] +fn test_bad_json_error_v2() -> TestResult { + _test_bad_json_error(ProtocolVersion::V2) +} + +fn _test_bad_json_error(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; let mut buffer = sender.new_buffer(); @@ -168,10 +188,17 @@ fn test_bad_json_error( Ok(()) } -#[rstest] -fn test_json_error( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_json_error_v1() -> TestResult { + _test_json_error(ProtocolVersion::V1) +} + +#[test] +fn test_json_error_v2() -> TestResult { + _test_json_error(ProtocolVersion::V2) +} + +fn _test_json_error(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; let mut buffer = sender.new_buffer(); @@ -214,10 +241,17 @@ fn test_json_error( Ok(()) } -#[rstest] -fn test_no_connection( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_no_connection_v1() -> TestResult { + _test_no_connection(ProtocolVersion::V1) +} + +#[test] +fn test_no_connection_v2() -> TestResult { + _test_no_connection(ProtocolVersion::V2) +} + +fn _test_no_connection(version: ProtocolVersion) -> TestResult { let mut sender = SenderBuilder::new(Protocol::Http, "127.0.0.1", 1) .protocol_version(version)? .build()?; @@ -237,10 +271,17 @@ fn test_no_connection( Ok(()) } -#[rstest] -fn test_old_server_without_ilp_http_support( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_old_server_without_ilp_http_support_v1() -> TestResult { + _test_old_server_without_ilp_http_support(ProtocolVersion::V1) +} + +#[test] +fn test_old_server_without_ilp_http_support_v2() -> TestResult { + _test_old_server_without_ilp_http_support(ProtocolVersion::V2) +} + +fn _test_old_server_without_ilp_http_support(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; let mut buffer = sender.new_buffer(); @@ -278,10 +319,17 @@ fn test_old_server_without_ilp_http_support( Ok(()) } -#[rstest] -fn test_http_basic_auth( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_http_basic_auth_v1() -> TestResult { + _test_http_basic_auth(ProtocolVersion::V1) +} + +#[test] +fn test_http_basic_auth_v2() -> TestResult { + _test_http_basic_auth(ProtocolVersion::V2) +} + +fn _test_http_basic_auth(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server .lsb_http() @@ -324,10 +372,17 @@ fn test_http_basic_auth( Ok(()) } -#[rstest] -fn test_unauthenticated( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_unauthenticated_v1() -> TestResult { + _test_unauthenticated(ProtocolVersion::V1) +} + +#[test] +fn test_unauthenticated_v2() -> TestResult { + _test_unauthenticated(ProtocolVersion::V2) +} + +fn _test_unauthenticated(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; let mut buffer = sender.new_buffer(); @@ -366,10 +421,17 @@ fn test_unauthenticated( Ok(()) } -#[rstest] -fn test_token_auth( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_token_auth_v1() -> TestResult { + _test_token_auth(ProtocolVersion::V1) +} + +#[test] +fn test_token_auth_v2() -> TestResult { + _test_token_auth(ProtocolVersion::V2) +} + +fn _test_token_auth(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server .lsb_http() @@ -406,10 +468,17 @@ fn test_token_auth( Ok(()) } -#[rstest] -fn test_request_timeout( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_request_timeout_v1() -> TestResult { + _test_request_timeout(ProtocolVersion::V1) +} + +#[test] +fn test_request_timeout_v2() -> TestResult { + _test_request_timeout(ProtocolVersion::V2) +} + +fn _test_request_timeout(version: ProtocolVersion) -> TestResult { let server = MockServer::new()?; let request_timeout = Duration::from_millis(50); let mut sender = server @@ -433,10 +502,17 @@ fn test_request_timeout( Ok(()) } -#[rstest] -fn test_tls( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_tls_v1() -> TestResult { + _test_tls(ProtocolVersion::V1) +} + +#[test] +fn test_tls_v2() -> TestResult { + _test_tls(ProtocolVersion::V2) +} + +fn _test_tls(version: ProtocolVersion) -> TestResult { let mut ca_path = certs_dir(); ca_path.push("server_rootCA.pem"); let mut server = MockServer::new()?; @@ -475,10 +551,17 @@ fn test_tls( Ok(()) } -#[rstest] -fn test_user_agent( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_user_agent_v1() -> TestResult { + _test_user_agent(ProtocolVersion::V1) +} + +#[test] +fn test_user_agent_v2() -> TestResult { + _test_user_agent(ProtocolVersion::V2) +} + +fn _test_user_agent(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server .lsb_http() @@ -513,10 +596,17 @@ fn test_user_agent( Ok(()) } -#[rstest] -fn test_two_retries( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_two_retries_v1() -> TestResult { + _test_two_retries(ProtocolVersion::V1) +} + +#[test] +fn test_two_retries_v2() -> TestResult { + _test_two_retries(ProtocolVersion::V2) +} + +fn _test_two_retries(version: ProtocolVersion) -> TestResult { // Note: This also tests that the _same_ connection is being reused, i.e. tests keepalive. let mut server = MockServer::new()?; let mut sender = server @@ -577,10 +667,17 @@ fn test_two_retries( Ok(()) } -#[rstest] -fn test_one_retry( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_one_retry_v1() -> TestResult { + _test_one_retry(ProtocolVersion::V1) +} + +#[test] +fn test_one_retry_v2() -> TestResult { + _test_one_retry(ProtocolVersion::V2) +} + +fn _test_one_retry(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server .lsb_http() @@ -641,10 +738,17 @@ fn test_one_retry( Ok(()) } -#[rstest] -fn test_transactional( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_transactional_v1() -> TestResult { + _test_transactional(ProtocolVersion::V1) +} + +#[test] +fn test_transactional_v2() -> TestResult { + _test_transactional(ProtocolVersion::V2) +} + +fn _test_transactional(version: ProtocolVersion) -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_http().protocol_version(version)?.build()?; // A buffer with a two tables. @@ -697,7 +801,6 @@ fn test_transactional( Ok(()) } - fn _test_sender_auto_detect_protocol_version( supported_versions: Option>, expect_version: ProtocolVersion, @@ -725,7 +828,7 @@ fn _test_sender_auto_detect_protocol_version( } let exp = &[ b"test,t1=v1 ", - crate::tests::sender::f64_to_bytes("f1", 0.5, expect_version).as_slice(), + f64_to_bytes("f1", 0.5, expect_version).as_slice(), b" 10000000\n", ] .concat(); diff --git a/questdb-rs/src/tests/sender.rs b/questdb-rs/src/tests/sync_sender.rs similarity index 94% rename from questdb-rs/src/tests/sender.rs rename to questdb-rs/src/tests/sync_sender.rs index 888b2dae..fab17382 100644 --- a/questdb-rs/src/tests/sender.rs +++ b/questdb-rs/src/tests/sync_sender.rs @@ -42,22 +42,28 @@ use ndarray::{arr2, ArrayD}; #[cfg(feature = "sync-sender-tcp")] use crate::tests::{ - assert_err_contains, + assert_err_contains, f64_to_bytes, mock::{certs_dir, MockServer}, ndarr::ArrayColumnTypeTag, }; #[cfg(feature = "sync-sender-tcp")] -use rstest::rstest; +use crate::ingress::{CertificateAuthority, ARRAY_BINARY_FORMAT_TYPE}; #[cfg(feature = "sync-sender-tcp")] -use crate::ingress::{CertificateAuthority, ARRAY_BINARY_FORMAT_TYPE}; +#[test] +fn test_basics_v1() -> TestResult { + _test_basics(ProtocolVersion::V1) +} #[cfg(feature = "sync-sender-tcp")] -#[rstest] -fn test_basics( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_basics_v2() -> TestResult { + _test_basics(ProtocolVersion::V2) +} + +#[cfg(feature = "sync-sender-tcp")] +fn _test_basics(version: ProtocolVersion) -> TestResult { use std::time::SystemTime; let mut server = MockServer::new()?; @@ -234,10 +240,19 @@ fn test_array_f64_for_ndarray() -> TestResult { } #[cfg(feature = "sync-sender-tcp")] -#[rstest] -fn test_max_buf_size( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +#[test] +fn test_max_buf_size_v1() -> TestResult { + _test_max_buf_size(ProtocolVersion::V1) +} + +#[cfg(feature = "sync-sender-tcp")] +#[test] +fn test_max_buf_size_v2() -> TestResult { + _test_max_buf_size(ProtocolVersion::V2) +} + +#[cfg(feature = "sync-sender-tcp")] +fn _test_max_buf_size(version: ProtocolVersion) -> TestResult { let max = 1024; let mut server = MockServer::new()?; let mut sender = server @@ -612,10 +627,17 @@ fn test_arr_column_name_too_long() -> TestResult { } #[cfg(feature = "sync-sender-tcp")] -#[rstest] -fn test_tls_with_file_ca( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +fn test_tls_with_file_ca_v1() -> TestResult { + _test_tls_with_file_ca(ProtocolVersion::V1) +} + +#[cfg(feature = "sync-sender-tcp")] +fn test_tls_with_file_ca_v2() -> TestResult { + _test_tls_with_file_ca(ProtocolVersion::V2) +} + +#[cfg(feature = "sync-sender-tcp")] +fn _test_tls_with_file_ca(version: ProtocolVersion) -> TestResult { let mut ca_path = certs_dir(); ca_path.push("server_rootCA.pem"); @@ -720,11 +742,20 @@ fn test_plain_to_tls_server() -> TestResult { Ok(()) } +#[test] #[cfg(feature = "insecure-skip-verify")] -#[rstest] -fn test_tls_insecure_skip_verify( - #[values(ProtocolVersion::V1, ProtocolVersion::V2)] version: ProtocolVersion, -) -> TestResult { +fn test_tls_insecure_skip_verify_v1() -> TestResult { + _test_tls_insecure_skip_verify(ProtocolVersion::V1) +} + +#[test] +#[cfg(feature = "insecure-skip-verify")] +fn test_tls_insecure_skip_verify_v2() -> TestResult { + _test_tls_insecure_skip_verify(ProtocolVersion::V2) +} + +#[cfg(feature = "insecure-skip-verify")] +fn _test_tls_insecure_skip_verify(version: ProtocolVersion) -> TestResult { let server = MockServer::new()?; let lsb = server .lsb_tcps() @@ -793,22 +824,3 @@ fn tcp_mismatched_buffer_and_sender_version() -> TestResult { ); Ok(()) } - -pub(crate) fn f64_to_bytes(name: &str, value: f64, version: ProtocolVersion) -> Vec { - let mut buf = Vec::new(); - buf.extend_from_slice(name.as_bytes()); - buf.push(b'='); - - match version { - ProtocolVersion::V1 => { - let mut ser = F64Serializer::new(value); - buf.extend_from_slice(ser.as_str().as_bytes()); - } - ProtocolVersion::V2 => { - buf.push(b'='); - buf.push(DOUBLE_BINARY_FORMAT_TYPE); - buf.extend_from_slice(&value.to_le_bytes()); - } - } - buf -}