diff --git a/Cargo.lock b/Cargo.lock
index 2a1e2aed5c3d3..4f817ab09fde5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3419,7 +3419,7 @@ dependencies = [
  "databend-common-ast",
  "geos",
  "geozero 0.13.0",
- "http 1.1.0",
+ "hyper 0.14.30",
  "opendal 0.49.0",
  "parquet",
  "paste",
diff --git a/Cargo.toml b/Cargo.toml
index 54c251b197579..1e0c7239eb723 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -244,6 +244,7 @@ geos = { version = "8.3", features = ["static", "geo", "geo-types"] }
 geozero = { version = "0.13.0", features = ["default", "with-wkb", "with-geos", "with-geojson"] }
 hashbrown = { version = "0.14.3", default-features = false }
 http = "1"
+hyper = "0.14.20"
 iceberg = { version = "0.3.0", git = "https://github.com/apache/iceberg-rust/", rev = "c3549836796f93aa3ad22276af788aa3d92533a1" }
 iceberg-catalog-hms = { version = "0.3.0", git = "https://github.com/apache/iceberg-rust/", rev = "c3549836796f93aa3ad22276af788aa3d92533a1" }
 iceberg-catalog-rest = { version = "0.3.0", git = "https://github.com/apache/iceberg-rust/", rev = "c3549836796f93aa3ad22276af788aa3d92533a1" }
diff --git a/scripts/ci/ci-run-stateful-tests-cluster-minio.sh b/scripts/ci/ci-run-stateful-tests-cluster-minio.sh
index b27da3cb6d9ff..a122870e3e555 100755
--- a/scripts/ci/ci-run-stateful-tests-cluster-minio.sh
+++ b/scripts/ci/ci-run-stateful-tests-cluster-minio.sh
@@ -20,6 +20,8 @@ export STORAGE_ALLOW_INSECURE=true
 echo "Install dependence"
 python3 -m pip install --quiet mysql-connector-python requests
+sudo apt-get update -yq
+sudo apt-get install -yq dsniff net-tools
 echo "calling test suite"
 echo "Starting Cluster databend-query"
diff --git a/scripts/ci/ci-run-stateful-tests-standalone-minio.sh b/scripts/ci/ci-run-stateful-tests-standalone-minio.sh
index 706d7cb50aea4..d713f3ed5becf 100755
--- a/scripts/ci/ci-run-stateful-tests-standalone-minio.sh
+++ b/scripts/ci/ci-run-stateful-tests-standalone-minio.sh
@@ -20,6 +20,8 @@ export STORAGE_ALLOW_INSECURE=true
 echo "Install dependence"
 python3 -m pip install --quiet mysql-connector-python requests
+sudo apt-get update -yq
+sudo apt-get install -yq dsniff net-tools
 echo "calling test suite"
 echo "Starting standalone DatabendQuery(debug)"
diff --git a/src/common/exception/Cargo.toml b/src/common/exception/Cargo.toml
index 84080a4ebe46c..7a30e6b350e6e 100644
--- a/src/common/exception/Cargo.toml
+++ b/src/common/exception/Cargo.toml
@@ -21,6 +21,7 @@ bincode = { workspace = true }
 geos = { workspace = true }
 geozero = { workspace = true }
 http = { workspace = true }
+hyper = { workspace = true }
 opendal = { workspace = true }
 parquet = { workspace = true }
 paste = { workspace = true }
diff --git a/src/common/exception/src/exception_into.rs b/src/common/exception/src/exception_into.rs
index cb18b0e451ccf..7933ce8e142e0 100644
--- a/src/common/exception/src/exception_into.rs
+++ b/src/common/exception/src/exception_into.rs
@@ -375,6 +375,13 @@ impl From for ErrorCode {
         tonic::Code::Unknown => {
             let details = status.details();
             if details.is_empty() {
+                if status.source().map_or(false, |e| e.is::()) {
+                    return ErrorCode::CannotConnectNode(format!(
+                        "{}, source: {:?}",
+                        status.message(),
+                        status.source()
+                    ));
+                }
                 return ErrorCode::UnknownException(format!(
                     "{}, source: {:?}",
                     status.message(),
diff --git a/src/query/service/src/clusters/cluster.rs b/src/query/service/src/clusters/cluster.rs
index ba9349572d127..7a3feb1544958 100644
--- a/src/query/service/src/clusters/cluster.rs
+++ b/src/query/service/src/clusters/cluster.rs
@@ -50,11 +50,13 @@ use
futures::future::Either; use futures::Future; use futures::StreamExt; use log::error; +use log::info; use log::warn; use rand::thread_rng; use rand::Rng; use serde::Deserialize; use serde::Serialize; +use tokio::time::sleep; use crate::servers::flight::FlightClient; @@ -79,11 +81,11 @@ pub trait ClusterHelper { fn get_nodes(&self) -> Vec>; - async fn do_action Deserialize<'de> + Send>( + async fn do_action Deserialize<'de> + Send>( &self, path: &str, message: HashMap, - timeout: u64, + flight_params: FlightParams, ) -> Result>; } @@ -116,11 +118,11 @@ impl ClusterHelper for Cluster { self.nodes.to_vec() } - async fn do_action Deserialize<'de> + Send>( + async fn do_action Deserialize<'de> + Send>( &self, path: &str, message: HashMap, - timeout: u64, + flight_params: FlightParams, ) -> Result> { fn get_node<'a>(nodes: &'a [Arc], id: &str) -> Result<&'a Arc> { for node in nodes { @@ -145,12 +147,32 @@ impl ClusterHelper for Cluster { let node_secret = node.secret.clone(); async move { - let mut conn = create_client(&config, &flight_address).await?; - Ok::<_, ErrorCode>(( - id, - conn.do_action::<_, Res>(path, node_secret, message, timeout) - .await?, - )) + let mut attempt = 0; + + loop { + let mut conn = create_client(&config, &flight_address).await?; + match conn + .do_action::<_, Res>( + path, + node_secret.clone(), + message.clone(), + flight_params.timeout, + ) + .await + { + Ok(result) => return Ok((id, result)), + Err(e) + if e.code() == ErrorCode::CANNOT_CONNECT_NODE + && attempt < flight_params.retry_times => + { + // only retry when error is network problem + info!("retry do_action, attempt: {}", attempt); + attempt += 1; + sleep(Duration::from_secs(flight_params.retry_interval)).await; + } + Err(e) => return Err(e), + } + } } }); } @@ -505,3 +527,10 @@ pub async fn create_client(config: &InnerConfig, address: &str) -> Result Result { let cluster = self.ctx.get_cluster(); let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; let mut message = HashMap::with_capacity(cluster.nodes.len()); @@ -63,9 +63,14 @@ impl KillInterpreter { message.insert(node_info.id.clone(), self.plan.clone()); } } + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; let res = cluster - .do_action::<_, bool>(KILL_QUERY, message, timeout) + .do_action::<_, bool>(KILL_QUERY, message, flight_params) .await?; match res.values().any(|x| *x) { diff --git a/src/query/service/src/interpreters/interpreter_set_priority.rs b/src/query/service/src/interpreters/interpreter_set_priority.rs index 0dda6b9dd656b..05f64b30af65a 100644 --- a/src/query/service/src/interpreters/interpreter_set_priority.rs +++ b/src/query/service/src/interpreters/interpreter_set_priority.rs @@ -21,6 +21,7 @@ use databend_common_exception::Result; use databend_common_sql::plans::SetPriorityPlan; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::interpreters::Interpreter; use crate::pipelines::PipelineBuildResult; use crate::servers::flight::v1::actions::SET_PRIORITY; @@ -61,9 +62,13 @@ impl SetPriorityInterpreter { } let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; let res = cluster 
- .do_action::<_, bool>(SET_PRIORITY, message, timeout) + .do_action::<_, bool>(SET_PRIORITY, message, flight_params) .await?; match res.values().any(|x| *x) { diff --git a/src/query/service/src/interpreters/interpreter_system_action.rs b/src/query/service/src/interpreters/interpreter_system_action.rs index 86e747e865ee7..c3570923ff9d2 100644 --- a/src/query/service/src/interpreters/interpreter_system_action.rs +++ b/src/query/service/src/interpreters/interpreter_system_action.rs @@ -22,6 +22,7 @@ use databend_common_sql::plans::SystemAction; use databend_common_sql::plans::SystemPlan; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::interpreters::Interpreter; use crate::pipelines::PipelineBuildResult; use crate::servers::flight::v1::actions::SYSTEM_ACTION; @@ -74,9 +75,13 @@ impl Interpreter for SystemActionInterpreter { } let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; cluster - .do_action::<_, ()>(SYSTEM_ACTION, message, timeout) + .do_action::<_, ()>(SYSTEM_ACTION, message, flight_params) .await?; } diff --git a/src/query/service/src/interpreters/interpreter_table_truncate.rs b/src/query/service/src/interpreters/interpreter_table_truncate.rs index 09a19f79cc43d..850ef56f0303c 100644 --- a/src/query/service/src/interpreters/interpreter_table_truncate.rs +++ b/src/query/service/src/interpreters/interpreter_table_truncate.rs @@ -21,6 +21,7 @@ use databend_common_exception::Result; use databend_common_sql::plans::TruncateTablePlan; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::interpreters::Interpreter; use crate::pipelines::PipelineBuildResult; use crate::servers::flight::v1::actions::TRUNCATE_TABLE; @@ -95,9 +96,13 @@ impl Interpreter for TruncateTableInterpreter { } let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; cluster - .do_action::<_, ()>(TRUNCATE_TABLE, message, timeout) + .do_action::<_, ()>(TRUNCATE_TABLE, message, flight_params) .await?; } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs index 98abcec37315d..1c5ec5527a113 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs @@ -225,6 +225,7 @@ where DataPacket::MutationStatus { .. 
} => unreachable!(),
            DataPacket::DataCacheMetrics(_) => unreachable!(),
            DataPacket::FragmentData(v) => Ok(vec![self.recv_data(meta.packet, v)?]),
+           DataPacket::FlightControl(_) => unreachable!(),
        }
    }
}
diff --git a/src/query/service/src/servers/admin/v1/query_profiling.rs b/src/query/service/src/servers/admin/v1/query_profiling.rs
index 649c16baeb3a2..6f1f6dc7e1182 100644
--- a/src/query/service/src/servers/admin/v1/query_profiling.rs
+++ b/src/query/service/src/servers/admin/v1/query_profiling.rs
@@ -30,6 +30,7 @@ use poem::IntoResponse;
 use crate::clusters::ClusterDiscovery;
 use crate::clusters::ClusterHelper;
+use crate::clusters::FlightParams;
 use crate::servers::flight::v1::actions::GET_PROFILE;
 use crate::sessions::SessionManager;
@@ -103,9 +104,13 @@ async fn get_cluster_profile(query_id: &str) -> Result, ErrorCo
             message.insert(node_info.id.clone(), query_id.to_owned());
         }
     }
-
+    let flight_params = FlightParams {
+        timeout: 60,
+        retry_times: 3,
+        retry_interval: 3,
+    };
     let res = cluster
-        .do_action::<_, Option>>(GET_PROFILE, message, 60)
+        .do_action::<_, Option>>(GET_PROFILE, message, flight_params)
         .await?;
     match res.into_values().find(Option::is_some) {
diff --git a/src/query/service/src/servers/flight/codec.rs b/src/query/service/src/servers/flight/codec.rs
new file mode 100644
index 0000000000000..5185e691f5ac8
--- /dev/null
+++ b/src/query/service/src/servers/flight/codec.rs
@@ -0,0 +1,72 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
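The interpreter and admin hunks above all follow the same pattern: bundle the flight timeout and the new retry settings into a `FlightParams` value and let `do_action` retry calls that fail with a connection error. A minimal sketch of that policy, where `FlightParams` mirrors the fields used in the diff but `CallError` and `do_action_with_retry` are illustrative stand-ins (the real method is async and sleeps via `tokio::time::sleep`):

```rust
use std::thread::sleep;
use std::time::Duration;

#[derive(Clone, Copy)]
pub struct FlightParams {
    pub timeout: u64,        // per-request timeout in seconds, passed to every RPC
    pub retry_times: u64,    // how many attempts before giving up
    pub retry_interval: u64, // seconds to wait between attempts
}

#[derive(Debug)]
pub enum CallError {
    CannotConnect(String), // retryable: a network-level failure
    Other(String),         // everything else surfaces immediately
}

pub fn do_action_with_retry<T>(
    params: FlightParams,
    mut call: impl FnMut() -> Result<T, CallError>,
) -> Result<T, CallError> {
    let mut attempt = 0u64;
    loop {
        match call() {
            Ok(value) => return Ok(value),
            // Only connection errors are retried, and only up to retry_times.
            Err(CallError::CannotConnect(_)) if attempt < params.retry_times => {
                attempt += 1;
                sleep(Duration::from_secs(params.retry_interval));
            }
            Err(other) => return Err(other),
        }
    }
}
```

Restricting retries to the connection-error class matches the `// only retry when error is network problem` comment in the diff.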
+ +use std::marker::PhantomData; +use std::sync::Arc; + +use prost::Message; +use tonic::codec::Codec; +use tonic::codec::DecodeBuf; +use tonic::codec::Decoder; +use tonic::codec::EncodeBuf; +use tonic::codec::Encoder; +use tonic::Status; + +#[derive(Default)] +pub struct MessageCodec(PhantomData<(E, D)>); + +impl Codec for MessageCodec { + type Encode = Arc; + type Decode = D; + type Encoder = ArcEncoder; + type Decoder = DefaultDecoder; + + fn encoder(&mut self) -> Self::Encoder { + ArcEncoder(PhantomData) + } + + fn decoder(&mut self) -> Self::Decoder { + DefaultDecoder(PhantomData) + } +} + +pub struct ArcEncoder(PhantomData); + +impl Encoder for ArcEncoder { + type Item = Arc; + + type Error = Status; + + fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { + item.as_ref() + .encode(dst) + .map_err(|e| Status::internal(e.to_string())) + } +} + +pub struct DefaultDecoder(PhantomData); + +impl Decoder for DefaultDecoder { + type Item = T; + + type Error = Status; + + fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result, Self::Error> { + let item = Message::decode(buf) + .map(Some) + .map_err(|e| Status::internal(e.to_string()))?; + + Ok(item) + } +} diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 5c3b1012f050a..d898a9b7da742 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -12,26 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::VecDeque; +use std::pin::Pin; use std::str::FromStr; +use std::sync::atomic::AtomicPtr; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; use async_channel::Receiver; use async_channel::Sender; use databend_common_arrow::arrow_format::flight::data::Action; use databend_common_arrow::arrow_format::flight::data::FlightData; -use databend_common_arrow::arrow_format::flight::data::Ticket; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::tokio::time::Duration; -use databend_common_base::runtime::drop_guard; +use databend_common_base::runtime::GlobalIORuntime; +use databend_common_base::runtime::TrySpawn; +use databend_common_base::JoinHandle; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use fastrace::func_path; use fastrace::future::FutureExt; use fastrace::Span; +use futures::Stream; use futures::StreamExt; use futures_util::future::Either; +use log::info; +use parking_lot::Mutex; use serde::Deserialize; use serde::Serialize; +use tokio::time::sleep; use tonic::metadata::AsciiMetadataKey; use tonic::metadata::AsciiMetadataValue; use tonic::transport::channel::Channel; @@ -39,9 +51,12 @@ use tonic::Request; use tonic::Status; use tonic::Streaming; +use crate::clusters::FlightParams; use crate::pipelines::executor::WatchNotify; use crate::servers::flight::request_builder::RequestBuilder; +use crate::servers::flight::v1::exchange::DataExchangeManager; use crate::servers::flight::v1::packets::DataPacket; +use crate::servers::flight::v1::packets::FlightControlCommand; pub struct FlightClient { inner: FlightServiceClient, @@ -115,45 +130,77 @@ impl FlightClient { } #[async_backtrace::framed] - pub async fn request_server_exchange( + pub async fn request_statistics_exchange( &mut 
self, query_id: &str, target: &str, + source_address: &str, + flight_params: FlightParams, ) -> Result { - let streaming = self - .get_streaming( - RequestBuilder::create(Ticket::default()) - .with_metadata("x-type", "request_server_exchange")? - .with_metadata("x-target", target)? - .with_metadata("x-query-id", query_id)? - .build(), - ) - .await?; + let (server_tx, server_rx) = async_channel::bounded(1); + let req = RequestBuilder::create(Box::pin(server_rx)) + .with_metadata("x-type", "request_server_exchange")? + .with_metadata("x-target", target)? + .with_metadata("x-query-id", query_id)? + .with_metadata("x-continue-from", "0")? + .with_metadata("x-enable-retry", &flight_params.retry_times.to_string())? + .build(); + let streaming = self.get_streaming(req).await?; let (notify, rx) = Self::streaming_receiver(streaming); - Ok(FlightExchange::create_receiver(notify, rx)) + Ok(FlightExchange::create_receiver( + notify, + rx, + Some(ConnectionInfo { + query_id: query_id.to_string(), + target: target.to_string(), + fragment: None, + source_address: source_address.to_string(), + retry_times: flight_params.retry_times, + retry_interval: Duration::from_secs(flight_params.retry_interval), + }), + server_tx, + )) } #[async_backtrace::framed] #[fastrace::trace] - pub async fn do_get( + pub async fn request_fragment_exchange( &mut self, query_id: &str, target: &str, fragment: usize, + source_address: &str, + flight_params: FlightParams, ) -> Result { - let request = RequestBuilder::create(Ticket::default()) + let (server_tx, server_rx) = async_channel::bounded(1); + + let request = RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "exchange_fragment")? .with_metadata("x-target", target)? .with_metadata("x-query-id", query_id)? .with_metadata("x-fragment-id", &fragment.to_string())? + .with_metadata("x-continue-from", "0")? + .with_metadata("x-enable-retry", &flight_params.retry_times.to_string())? 
.build(); let request = databend_common_tracing::inject_span_to_tonic_request(request); let streaming = self.get_streaming(request).await?; let (notify, rx) = Self::streaming_receiver(streaming); - Ok(FlightExchange::create_receiver(notify, rx)) + Ok(FlightExchange::create_receiver( + notify, + rx, + Some(ConnectionInfo { + query_id: query_id.to_string(), + target: target.to_string(), + fragment: Some(fragment), + source_address: source_address.to_string(), + retry_times: flight_params.retry_times, + retry_interval: Duration::from_secs(flight_params.retry_interval), + }), + server_tx, + )) } fn streaming_receiver( @@ -172,6 +219,7 @@ impl FlightClient { Either::Left((_, _)) | Either::Right((None, _)) => { break; } + Either::Right((Some(message), next_notified)) => { notified = next_notified; streaming_next = streaming.next(); @@ -190,7 +238,6 @@ impl FlightClient { } } } - drop(streaming); tx.close(); } @@ -203,32 +250,71 @@ impl FlightClient { } #[async_backtrace::framed] - async fn get_streaming(&mut self, request: Request) -> Result> { - match self.inner.do_get(request).await { + async fn get_streaming( + &mut self, + request: Request>>>, + ) -> Result> { + match self.inner.do_exchange(request).await { Ok(res) => Ok(res.into_inner()), Err(status) => Err(ErrorCode::from(status).add_message_back("(while in query flight)")), } } + + #[async_backtrace::framed] + async fn reconnect(&mut self, info: &ConnectionInfo, seq: usize) -> Result { + let (server_tx, server_rx) = async_channel::bounded(1); + let request = match info.fragment { + Some(fragment_id) => RequestBuilder::create(Box::pin(server_rx)) + .with_metadata("x-type", "exchange_fragment")? + .with_metadata("x-target", &info.target)? + .with_metadata("x-query-id", &info.query_id)? + .with_metadata("x-fragment-id", &fragment_id.to_string())? + .with_metadata("x-continue-from", &seq.to_string())? + .with_metadata("x-enable-retry", &info.retry_times.to_string())? + .build(), + None => RequestBuilder::create(Box::pin(server_rx)) + .with_metadata("x-type", "request_server_exchange")? + .with_metadata("x-target", &info.target)? + .with_metadata("x-query-id", &info.query_id)? + .with_metadata("x-continue-from", &seq.to_string())? + .with_metadata("x-enable-retry", &info.retry_times.to_string())? 
+ .build(), + }; + let request = databend_common_tracing::inject_span_to_tonic_request(request); + + let streaming = self.get_streaming(request).await?; + + let (network_notify, recv) = Self::streaming_receiver(streaming); + Ok(FlightRxInner::create(network_notify, recv, server_tx)) + } } -pub struct FlightReceiver { - notify: Arc, - rx: Receiver>, +#[derive(Clone)] +pub struct ConnectionInfo { + pub query_id: String, + pub target: String, + pub fragment: Option, + pub source_address: String, + pub retry_times: u64, + pub retry_interval: Duration, } -impl Drop for FlightReceiver { - fn drop(&mut self) { - drop_guard(move || { - self.close(); - }) - } +pub struct FlightRxInner { + notify: Arc, + rx: Receiver>, + server_tx: Sender, } -impl FlightReceiver { - pub fn create(rx: Receiver>) -> FlightReceiver { - FlightReceiver { +impl FlightRxInner { + pub fn create( + notify: Arc, + rx: Receiver>, + server_tx: Sender, + ) -> FlightRxInner { + FlightRxInner { rx, - notify: Arc::new(WatchNotify::new()), + notify, + server_tx, } } @@ -245,6 +331,110 @@ impl FlightReceiver { self.rx.close(); self.notify.notify_waiters(); } + + pub fn stop_cluster(&mut self) { + // ignore the error, because we cannot determine the state of server side + let _ = self.server_tx.try_send( + FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)) + .expect("convert to flight data error"), + ); + } +} + +pub struct RetryableFlightReceiver { + seq: Arc, + info: Option, + inner: Arc>, +} + +impl Drop for RetryableFlightReceiver { + fn drop(&mut self) { + self.close(); + } +} + +impl RetryableFlightReceiver { + pub fn dummy() -> RetryableFlightReceiver { + RetryableFlightReceiver { + seq: Arc::new(AtomicUsize::new(0)), + info: None, + inner: Arc::new(Default::default()), + } + } + + #[async_backtrace::framed] + pub async fn recv(&self) -> Result> { + // Non thread safe, we only use atomic to implement mutable. + loop { + let inner = unsafe { &*self.inner.load(Ordering::SeqCst) }; + return match inner.recv().await { + Ok(message) => { + self.seq.fetch_add(1, Ordering::SeqCst); + Ok(message) + } + Err(cause) => { + info!("Error while receiving data from flight : {:?}", cause); + if cause.code() == ErrorCode::CANNOT_CONNECT_NODE { + // only retry when error is network problem + let Err(cause) = self.retry().await else { + info!("Retry flight connection successfully!"); + continue; + }; + + info!("Retry flight connection failure, cause: {:?}", cause); + } + + Err(cause) + } + }; + } + } + + #[async_backtrace::framed] + async fn retry(&self) -> Result<()> { + if let Some(connection_info) = &self.info { + let mut flight_client = + DataExchangeManager::create_client(&connection_info.source_address, true).await?; + + for attempts in 0..connection_info.retry_times { + let Ok(recv) = flight_client + .reconnect(connection_info, self.seq.load(Ordering::Acquire)) + .await + else { + info!("Reconnect attempt {} failed", attempts); + sleep(connection_info.retry_interval).await; + continue; + }; + + let ptr = self + .inner + .swap(Box::into_raw(Box::new(recv)), Ordering::SeqCst); + + unsafe { + // We cannot determine the number of strong ref. so close it. 
+ let broken_connection = Box::from_raw(ptr); + broken_connection.close(); + } + + return Ok(()); + } + + return Err(ErrorCode::Timeout("Exceed max retries time")); + } + + Ok(()) + } + + pub fn close(&self) { + unsafe { + let inner = self.inner.load(Ordering::SeqCst); + + if !inner.is_null() { + (*inner).stop_cluster(); + (*inner).close(); + } + } + } } pub struct FlightSender { @@ -276,43 +466,279 @@ impl FlightSender { } } +pub struct SenderPayload { + pub state: Arc>, + pub sender: Option>>, +} + +pub struct ReceiverPayload { + seq: Arc, + info: Option, + inner: Arc>, +} + pub enum FlightExchange { Dummy, - Receiver { - notify: Arc, - receiver: Receiver>, - }, - Sender(Sender>), + + Sender(SenderPayload), + Receiver(ReceiverPayload), + + MovedSender(SenderPayload), + MovedReceiver(ReceiverPayload), } impl FlightExchange { pub fn create_sender( + state: Arc>, sender: Sender>, ) -> FlightExchange { - FlightExchange::Sender(sender) + FlightExchange::Sender(SenderPayload { + state, + sender: Some(sender), + }) } pub fn create_receiver( notify: Arc, receiver: Receiver>, + connection_info: Option, + server_tx: Sender, ) -> FlightExchange { - FlightExchange::Receiver { notify, receiver } + FlightExchange::Receiver(ReceiverPayload { + seq: Arc::new(AtomicUsize::new(0)), + info: connection_info, + inner: Arc::new(AtomicPtr::new(Box::into_raw(Box::new( + FlightRxInner::create(notify, receiver, server_tx), + )))), + }) } + pub fn take_as_sender(&mut self) -> FlightSender { + let mut flight_sender = FlightExchange::Dummy; + std::mem::swap(self, &mut flight_sender); + + if let FlightExchange::Sender(mut v) = flight_sender { + let flight_sender = FlightSender::create(v.sender.take().unwrap()); + *self = FlightExchange::MovedSender(v); + return flight_sender; + } - pub fn convert_to_sender(self) -> FlightSender { - match self { - FlightExchange::Sender(tx) => FlightSender { tx }, - _ => unreachable!(), + unreachable!("take as sender miss match exchange type") + } + + pub fn take_as_receiver(&mut self) -> RetryableFlightReceiver { + let mut flight_receiver = FlightExchange::Dummy; + std::mem::swap(self, &mut flight_receiver); + + if let FlightExchange::Receiver(v) = flight_receiver { + let flight_receiver = RetryableFlightReceiver { + seq: v.seq.clone(), + info: v.info.clone(), + inner: v.inner.clone(), + }; + + *self = FlightExchange::MovedReceiver(v); + + return flight_receiver; + } + + unreachable!("take as receiver miss match exchange type") + } +} + +pub struct FlightDataAckState { + seq: usize, + finish: bool, + receiver: Receiver>, + ack_window: VecDeque<(usize, std::result::Result, Status>)>, + clean_up_handle: Option>, + window_size: usize, +} + +impl FlightDataAckState { + pub fn create( + receiver: Receiver>, + window_size: usize, + ) -> Arc> { + Arc::new(Mutex::new(FlightDataAckState { + receiver, + seq: 0, + ack_window: VecDeque::with_capacity(window_size), + finish: false, + clean_up_handle: None, + window_size, + })) + } + + fn error_of_stream( + &mut self, + cause: Status, + ) -> Poll, Status>>> { + let message_seq = self.seq; + self.seq += 1; + self.ack_window.push_back((message_seq, Err(cause.clone()))); + Poll::Ready(Some(Err(cause))) + } + + fn end_of_stream(&mut self) -> Poll, Status>>> { + self.seq += 1; + self.finish = true; + Poll::Ready(None) + } + + fn message( + &mut self, + data: FlightData, + ) -> Poll, Status>>> { + let message_seq = self.seq; + self.seq += 1; + let data = Arc::new(data); + let duplicate = data.clone(); + self.ack_window.push_back((message_seq, 
Ok(data))); + Poll::Ready(Some(Ok(duplicate))) + } + + fn check_resend(&mut self) -> Option, Status>> { + let current_seq = self.seq; + + if let Some((seq, _packet)) = self.ack_window.back() { + if seq + 1 == current_seq { + return None; + } + } + // resend case, iterate the queue to find the message to resend + for (id, res) in self.ack_window.iter() { + if *id == current_seq { + self.seq += 1; + return Some(res.clone()); + } + } + + None + } + + pub fn poll_next( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Status>>> { + if self.finish { + return Poll::Ready(None); + } + + // check if seq has been reset, if so, resend last packet + if let Some(res) = self.check_resend() { + return Poll::Ready(Some(res)); + } + + // check if ack window is full, if so, pop the oldest packet + if self.ack_window.len() == self.window_size { + self.ack_window.pop_front(); + } + + match Pin::new(&mut self.receiver).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => self.end_of_stream(), + Poll::Ready(Some(Err(status))) => self.error_of_stream(status), + Poll::Ready(Some(Ok(flight_data))) => self.message(flight_data), + } + } +} + +pub struct FlightDataAckStream { + notify: Arc, + state: Arc>, + enable_retry: bool, +} + +impl FlightDataAckStream { + pub fn create( + state: Arc>, + begin: usize, + client_stream: Streaming, + enable_retry: bool, + ) -> Result { + let notify = Self::streaming_receiver(state.clone(), client_stream); + let mut state_guard = state.lock(); + state_guard.seq = begin; + state_guard.finish = false; + if let Some(handle) = state_guard.clean_up_handle.take() { + handle.abort(); + } + drop(state_guard); + Ok(FlightDataAckStream { + notify, + state, + enable_retry, + }) + } + + fn streaming_receiver( + state: Arc>, + mut streaming: Streaming, + ) -> Arc { + let notify = Arc::new(WatchNotify::new()); + let fut = { + let notify = notify.clone(); + async move { + let notified = Box::pin(notify.notified()); + let streaming_next = streaming.next(); + + match futures::future::select(notified, streaming_next).await { + Either::Left((_, _)) | Either::Right((None, _)) => {} + Either::Right((Some(message), _next_notified)) => { + if let Ok(flight_data) = message { + let packet = DataPacket::try_from(flight_data).unwrap(); + if let DataPacket::FlightControl(FlightControlCommand::Close) = packet { + let mut state_guard = state.lock(); + state_guard.finish = true; + state_guard.receiver.close(); + drop(state_guard); + } + } + } + } + drop(state); + drop(streaming); + } } + .in_span(Span::enter_with_local_parent(func_path!())); + + databend_common_base::runtime::spawn(fut); + + notify } +} - pub fn convert_to_receiver(self) -> FlightReceiver { - match self { - FlightExchange::Receiver { notify, receiver } => FlightReceiver { - notify, - rx: receiver, - }, - _ => unreachable!(), +impl Drop for FlightDataAckStream { + fn drop(&mut self) { + info!("Drop FlightDataAckStream enable: {:?}", self.enable_retry); + let mut state = self.state.lock(); + // if not enable retry, fallback to close immediately + if !self.enable_retry || state.finish { + info!("{:?} {:?}", state.finish, self.enable_retry); + self.notify.notify_waiters(); + state.receiver.close(); + return; } + let weak_state = Arc::downgrade(&self.state); + let notify = Arc::downgrade(&self.notify); + let handle = GlobalIORuntime::instance().spawn(async move { + tokio::time::sleep(Duration::from_secs(60)).await; + if let Some(state) = weak_state.upgrade() { + let state_guard = state.lock(); + state_guard.receiver.close(); + } + 
if let Some(notify) = notify.upgrade() { + notify.notify_waiters(); + } + }); + state.clean_up_handle = Some(handle); + } +} + +impl Stream for FlightDataAckStream { + type Item = std::result::Result, Status>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.state.lock().poll_next(cx) } } diff --git a/src/query/service/src/servers/flight/flight_server.rs b/src/query/service/src/servers/flight/flight_server.rs new file mode 100644 index 0000000000000..ee921ac9a826b --- /dev/null +++ b/src/query/service/src/servers/flight/flight_server.rs @@ -0,0 +1,106 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use databend_common_base::base::tokio; +use databend_common_base::base::tokio::sync::Notify; +use databend_common_config::InnerConfig; +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use log::info; +use tonic::transport::server::TcpIncoming; +use tonic::transport::Identity; +use tonic::transport::Server; +use tonic::transport::ServerTlsConfig; + +use super::v1::DatabendQueryFlightService; +use crate::servers::flight::flight_service::FlightServiceServer; +use crate::servers::Server as DatabendQueryServer; + +pub struct FlightService { + pub config: InnerConfig, + pub abort_notify: Arc, +} + +impl FlightService { + pub fn create(config: InnerConfig) -> Result> { + Ok(Box::new(Self { + config, + abort_notify: Arc::new(Notify::new()), + })) + } + + fn shutdown_notify(&self) -> impl Future + 'static { + let notified = self.abort_notify.clone(); + async move { + notified.notified().await; + } + } + + #[async_backtrace::framed] + async fn server_tls_config(conf: &InnerConfig) -> Result { + let cert = tokio::fs::read(conf.query.rpc_tls_server_cert.as_str()).await?; + let key = tokio::fs::read(conf.query.rpc_tls_server_key.as_str()).await?; + let server_identity = Identity::from_pem(cert, key); + let tls_conf = ServerTlsConfig::new().identity(server_identity); + Ok(tls_conf) + } + + #[async_backtrace::framed] + pub async fn start_with_incoming(&mut self, addr: SocketAddr) -> Result<()> { + let flight_api_service = DatabendQueryFlightService::create(); + let builder = Server::builder(); + let mut builder = if self.config.tls_rpc_server_enabled() { + info!("databend query tls rpc enabled"); + builder + .tls_config(Self::server_tls_config(&self.config).await.map_err(|e| { + ErrorCode::TLSConfigurationFailure(format!( + "failed to load server tls config: {e}", + )) + })?) + .map_err(|e| { + ErrorCode::TLSConfigurationFailure(format!("failed to invoke tls_config: {e}",)) + })? 
+ } else { + builder + }; + + let incoming = TcpIncoming::new(addr, true, None) + .map_err(|e| ErrorCode::CannotListenerPort(format!("{e}")))?; + let server = builder + .add_service(FlightServiceServer::new( + flight_api_service, + usize::MAX, + usize::MAX, + )) + .serve_with_incoming_shutdown(incoming, self.shutdown_notify()); + databend_common_base::runtime::spawn(server); + Ok(()) + } +} + +#[async_trait::async_trait] +impl DatabendQueryServer for FlightService { + #[async_backtrace::framed] + async fn shutdown(&mut self, _graceful: bool) {} + + #[async_backtrace::framed] + async fn start(&mut self, addr: SocketAddr) -> Result { + self.start_with_incoming(addr).await?; + Ok(addr) + } +} diff --git a/src/query/service/src/servers/flight/flight_service.rs b/src/query/service/src/servers/flight/flight_service.rs index 063d61f1ceb1f..3f314700b2e43 100644 --- a/src/query/service/src/servers/flight/flight_service.rs +++ b/src/query/service/src/servers/flight/flight_service.rs @@ -1,3 +1,4 @@ +// Copyright 2016-2019 The Apache Software Foundation // Copyright 2021 Datafuse Labs // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,96 +13,433 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::future::Future; -use std::net::SocketAddr; +use std::convert::Infallible; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; -use databend_common_arrow::arrow_format::flight::service::flight_service_server::FlightServiceServer; -use databend_common_base::base::tokio; -use databend_common_base::base::tokio::sync::Notify; -use databend_common_config::InnerConfig; -use databend_common_exception::ErrorCode; -use databend_common_exception::Result; -use log::info; -use tonic::transport::server::TcpIncoming; -use tonic::transport::Identity; -use tonic::transport::Server; -use tonic::transport::ServerTlsConfig; - -use super::v1::DatabendQueryFlightService; -use crate::servers::Server as DatabendQueryServer; - -pub struct FlightService { - pub config: InnerConfig, - pub abort_notify: Arc, +use databend_common_arrow::arrow_format::flight; +use databend_common_arrow::arrow_format::flight::data::Action; +use databend_common_arrow::arrow_format::flight::data::ActionType; +use databend_common_arrow::arrow_format::flight::data::Criteria; +use databend_common_arrow::arrow_format::flight::data::Empty; +use databend_common_arrow::arrow_format::flight::data::FlightData; +use databend_common_arrow::arrow_format::flight::data::FlightDescriptor; +use databend_common_arrow::arrow_format::flight::data::FlightInfo; +use databend_common_arrow::arrow_format::flight::data::HandshakeRequest; +use databend_common_arrow::arrow_format::flight::data::HandshakeResponse; +use databend_common_arrow::arrow_format::flight::data::PutResult; +use databend_common_arrow::arrow_format::flight::data::SchemaResult; +use databend_common_arrow::arrow_format::flight::data::Ticket; +use tonic::async_trait; +use tonic::body::empty_body; +use tonic::body::BoxBody; +use tonic::codegen::http::request::Request as HTTPRequest; +use tonic::codegen::http::response::Response as HTTPResponse; +use tonic::codegen::Body; +use tonic::codegen::BoxFuture; +use tonic::codegen::Service; +use tonic::codegen::StdError; +use tonic::server::NamedService; +use tonic::server::ServerStreamingService; +use tonic::server::StreamingService; +use tonic::server::UnaryService; +use tonic::Status; + +use crate::servers::flight::codec::MessageCodec; + +/// This trait is derived 
from `FlightService`. +/// It has been modified the `do_get` method signatures to use `Arc` as `DoGetStream` type +#[async_trait] +pub trait FlightOperation: Send + Sync + 'static { + type HandshakeStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn handshake( + &self, + request: tonic::Request>, + ) -> Result, Status>; + + type ListFlightsStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn list_flights( + &self, + request: tonic::Request, + ) -> Result, Status>; + + async fn get_flight_info( + &self, + request: tonic::Request, + ) -> Result>, Status>; + + async fn get_schema( + &self, + request: tonic::Request, + ) -> Result>, Status>; + type DoExchangeStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn do_exchange( + &self, + request: tonic::Request>, + ) -> Result, Status>; + + type DoGetStream: tokio_stream::Stream, Status>> + Send + 'static; + async fn do_get( + &self, + request: tonic::Request, + ) -> Result, Status>; + type DoPutStream: tokio_stream::Stream, Status>> + Send + 'static; + async fn do_put( + &self, + request: tonic::Request>, + ) -> Result, Status>; + type DoActionStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn do_action( + &self, + request: tonic::Request, + ) -> Result, Status>; + type ListActionsStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn list_actions( + &self, + request: tonic::Request, + ) -> Result, Status>; } -impl FlightService { - pub fn create(config: InnerConfig) -> Result> { - Ok(Box::new(Self { - config, - abort_notify: Arc::new(Notify::new()), - })) +struct _Inner(Arc); +impl Clone for _Inner { + fn clone(&self) -> Self { + Self(self.0.clone()) } +} - fn shutdown_notify(&self) -> impl Future + 'static { - let notified = self.abort_notify.clone(); - async move { - notified.notified().await; +pub struct FlightServiceServer { + inner: _Inner, + max_decoding_message_size: Option, + max_encoding_message_size: Option, +} + +impl FlightServiceServer { + pub(crate) fn new( + service: T, + max_decoding_message_size: usize, + max_encoding_message_size: usize, + ) -> Self { + Self { + inner: _Inner(Arc::new(service)), + max_decoding_message_size: Some(max_decoding_message_size), + max_encoding_message_size: Some(max_encoding_message_size), } } +} - #[async_backtrace::framed] - async fn server_tls_config(conf: &InnerConfig) -> Result { - let cert = tokio::fs::read(conf.query.rpc_tls_server_cert.as_str()).await?; - let key = tokio::fs::read(conf.query.rpc_tls_server_key.as_str()).await?; - let server_identity = Identity::from_pem(cert, key); - let tls_conf = ServerTlsConfig::new().identity(server_identity); - Ok(tls_conf) +impl Service> for FlightServiceServer +where + T: FlightOperation, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, +{ + type Response = HTTPResponse; + type Error = Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - #[async_backtrace::framed] - pub async fn start_with_incoming(&mut self, addr: SocketAddr) -> Result<()> { - let flight_api_service = DatabendQueryFlightService::create(); - let builder = Server::builder(); - let mut builder = if self.config.tls_rpc_server_enabled() { - info!("databend query tls rpc enabled"); - builder - .tls_config(Self::server_tls_config(&self.config).await.map_err(|e| { - ErrorCode::TLSConfigurationFailure(format!( - "failed to load server tls config: {e}", - )) - })?) 
- .map_err(|e| { - ErrorCode::TLSConfigurationFailure(format!("failed to invoke tls_config: {e}",)) - })? - } else { - builder - }; - - let incoming = TcpIncoming::new(addr, true, None) - .map_err(|e| ErrorCode::CannotListenerPort(format!("{e}")))?; - let server = builder - .add_service( - FlightServiceServer::new(flight_api_service) - .max_encoding_message_size(usize::MAX) - .max_decoding_message_size(usize::MAX), - ) - .serve_with_incoming_shutdown(incoming, self.shutdown_notify()); - - databend_common_base::runtime::spawn(server); - Ok(()) + fn call(&mut self, req: HTTPRequest) -> Self::Future { + match req.uri().path() { + "/arrow.flight.protocol.FlightService/Handshake" => { + struct HandshakeSvc(pub Arc); + impl StreamingService for HandshakeSvc { + type Response = Arc; + type ResponseStream = T::HandshakeStream; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).handshake(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = HandshakeSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/ListFlights" => { + struct ListFlightsSvc(pub Arc); + impl ServerStreamingService for ListFlightsSvc { + type Response = Arc; + type ResponseStream = T::ListFlightsStream; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).list_flights(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = ListFlightsSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/GetFlightInfo" => { + struct GetFlightInfoSvc(pub Arc); + impl UnaryService for GetFlightInfoSvc { + type Response = Arc; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).get_flight_info(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GetFlightInfoSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/GetSchema" => { + struct GetSchemaSvc(pub Arc); + impl UnaryService 
for GetSchemaSvc { + type Response = Arc; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).get_schema(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GetSchemaSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoGet" => { + struct DoGetSvc(pub Arc); + impl ServerStreamingService for DoGetSvc { + type Response = Arc; + type ResponseStream = T::DoGetStream; + type Future = BoxFuture, Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = + async move { ::do_get(&inner, request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.0.clone(); + let fut = async move { + let method = DoGetSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoPut" => { + struct DoPutSvc(pub Arc); + impl StreamingService for DoPutSvc { + type Response = Arc; + type ResponseStream = T::DoPutStream; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).do_put(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = DoPutSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoExchange" => { + struct DoExchangeSvc(pub Arc); + impl StreamingService for DoExchangeSvc { + type Response = Arc; + type ResponseStream = T::DoExchangeStream; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).do_exchange(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = DoExchangeSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } 
+ "/arrow.flight.protocol.FlightService/DoAction" => { + struct DoActionSvc(pub Arc); + impl ServerStreamingService for DoActionSvc { + type Response = Arc; + + type ResponseStream = T::DoActionStream; + + type Future = BoxFuture, Status>; + + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).do_action(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.0.clone(); + let fut = async move { + let method = DoActionSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/ListActions" => { + struct ListActionsSvc(pub Arc); + impl ServerStreamingService for ListActionsSvc { + type Response = Arc; + type ResponseStream = T::ListActionsStream; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).list_actions(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = ListActionsSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => Box::pin(async move { + Ok(HTTPResponse::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap()) + }), + } } } -#[async_trait::async_trait] -impl DatabendQueryServer for FlightService { - #[async_backtrace::framed] - async fn shutdown(&mut self, _graceful: bool) {} - - #[async_backtrace::framed] - async fn start(&mut self, addr: SocketAddr) -> Result { - self.start_with_incoming(addr).await?; - Ok(addr) +impl Clone for FlightServiceServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } } } + +// The client side still uses FlightClient to find the service by this name, +// so we keep the service name unchanged. +impl NamedService for FlightServiceServer { + const NAME: &'static str = "arrow.flight.protocol.FlightService"; +} diff --git a/src/query/service/src/servers/flight/mod.rs b/src/query/service/src/servers/flight/mod.rs index 177c747d6d8ec..c057bfc09c55b 100644 --- a/src/query/service/src/servers/flight/mod.rs +++ b/src/query/service/src/servers/flight/mod.rs @@ -12,13 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+mod codec; mod flight_client; +mod flight_server; mod flight_service; mod request_builder; pub mod v1; +pub use codec::MessageCodec; pub use flight_client::FlightClient; pub use flight_client::FlightExchange; -pub use flight_client::FlightReceiver; pub use flight_client::FlightSender; -pub use flight_service::FlightService; +pub use flight_client::RetryableFlightReceiver; +pub use flight_server::FlightService; +pub use request_builder::RequestBuilder; diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index b3ff2fd1f4651..e7fadcf7d1d30 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -21,7 +21,6 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; -use async_channel::Receiver; use databend_common_arrow::arrow_format::flight::data::FlightData; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::GlobalInstance; @@ -42,7 +41,7 @@ use parking_lot::Mutex; use parking_lot::ReentrantMutex; use petgraph::prelude::EdgeRef; use petgraph::Direction; -use tonic::Status; +use tonic::Streaming; use super::exchange_params::ExchangeParams; use super::exchange_params::MergeExchangeParams; @@ -52,12 +51,16 @@ use super::exchange_transform::ExchangeTransform; use super::statistics_receiver::StatisticsReceiver; use super::statistics_sender::StatisticsSender; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::pipelines::executor::ExecutorSettings; use crate::pipelines::executor::PipelineCompleteExecutor; use crate::pipelines::PipelineBuildResult; use crate::pipelines::PipelineBuilder; use crate::schedulers::QueryFragmentActions; use crate::schedulers::QueryFragmentsActions; +use crate::servers::flight::flight_client::FlightDataAckState; +use crate::servers::flight::flight_client::FlightDataAckStream; +use crate::servers::flight::flight_client::RetryableFlightReceiver; use crate::servers::flight::v1::actions::init_query_fragments; use crate::servers::flight::v1::actions::INIT_QUERY_FRAGMENTS; use crate::servers::flight::v1::actions::START_PREPARED_QUERY; @@ -70,7 +73,6 @@ use crate::servers::flight::v1::packets::QueryFragment; use crate::servers::flight::v1::packets::QueryFragments; use crate::servers::flight::FlightClient; use crate::servers::flight::FlightExchange; -use crate::servers::flight::FlightReceiver; use crate::servers::flight::FlightSender; use crate::sessions::QueryContext; use crate::sessions::TableContext; @@ -129,7 +131,11 @@ impl DataExchangeManager { let config = GlobalConfig::instance(); let with_cur_rt = env.create_rpc_clint_with_current_rt; - + let flight_params = FlightParams { + timeout: env.settings.get_flight_client_timeout()?, + retry_times: env.settings.get_max_flight_retry_times()?, + retry_interval: env.settings.get_flight_retry_interval()?, + }; let mut request_exchanges = HashMap::new(); let mut targets_exchanges = HashMap::new(); @@ -155,12 +161,25 @@ impl DataExchangeManager { Edge::Fragment(v) => QueryExchange::Fragment { source: source.id.clone(), fragment: v, - exchange: flight_client.do_get(&query_id, &target.id, v).await?, + exchange: flight_client + .request_fragment_exchange( + &query_id, + &target.id, + v, + &address, + flight_params, + ) + .await?, }, Edge::Statistics => QueryExchange::Statistics { source: 
source.id.clone(), exchange: flight_client - .request_server_exchange(&query_id, &target.id) + .request_statistics_exchange( + &query_id, + &target.id, + &address, + flight_params, + ) .await?, }, }) @@ -349,15 +368,28 @@ impl DataExchangeManager { &self, id: String, target: String, - ) -> Result>> { + continue_from: usize, + client_stream: Streaming, + enable_retry: bool, + ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(id) { - Entry::Occupied(mut v) => v.get_mut().add_statistics_exchange(target), - Entry::Vacant(v) => v - .insert(QueryCoordinator::create()) - .add_statistics_exchange(target), + Entry::Occupied(mut v) => v.get_mut().add_statistics_exchange( + target, + continue_from, + client_stream, + enable_retry, + ), + Entry::Vacant(v) => match continue_from == 0 { + true => v + .insert(QueryCoordinator::create()) + .add_statistics_exchange(target, continue_from, client_stream, enable_retry), + false => Err(ErrorCode::Timeout( + "Reconnection timeout, the state has been cleared", + )), + }, } } @@ -367,15 +399,33 @@ impl DataExchangeManager { query: String, target: String, fragment: usize, - ) -> Result>> { + continue_from: usize, + client_stream: Streaming, + enable_retry: bool, + ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(query) { - Entry::Occupied(mut v) => v.get_mut().add_fragment_exchange(target, fragment), - Entry::Vacant(v) => v - .insert(QueryCoordinator::create()) - .add_fragment_exchange(target, fragment), + Entry::Occupied(mut v) => v.get_mut().add_fragment_exchange( + target, + fragment, + continue_from, + client_stream, + enable_retry, + ), + Entry::Vacant(v) => match continue_from == 0 { + true => v.insert(QueryCoordinator::create()).add_fragment_exchange( + target, + fragment, + continue_from, + client_stream, + enable_retry, + ), + false => Err(ErrorCode::Timeout( + "Reconnection timeout, the state has been cleared", + )), + }, } } @@ -410,13 +460,17 @@ impl DataExchangeManager { actions: QueryFragmentsActions, ) -> Result { let settings = ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; let root_actions = actions.get_root_actions()?; let conf = GlobalConfig::instance(); // Initialize query env between cluster nodes let query_env = actions.get_query_env()?; - query_env.init(&ctx, timeout).await?; + query_env.init(&ctx, flight_params).await?; // Submit distributed tasks to all nodes. 
let cluster = ctx.get_cluster(); @@ -425,7 +479,7 @@ impl DataExchangeManager { let local_fragments = query_fragments.remove(&conf.query.node_id); let _: HashMap = cluster - .do_action(INIT_QUERY_FRAGMENTS, query_fragments, timeout) + .do_action(INIT_QUERY_FRAGMENTS, query_fragments, flight_params) .await?; self.set_ctx(&ctx.get_id(), ctx.clone())?; @@ -438,7 +492,7 @@ impl DataExchangeManager { let prepared_query = actions.prepared_query()?; let _: HashMap = cluster - .do_action(START_PREPARED_QUERY, prepared_query, timeout) + .do_action(START_PREPARED_QUERY, prepared_query, flight_params) .await?; Ok(build_res) @@ -458,12 +512,12 @@ impl DataExchangeManager { match queries_coordinator.get_mut(&query_id) { None => Err(ErrorCode::Internal("Query not exists.")), Some(query_coordinator) => { - assert!(query_coordinator.fragment_exchanges.is_empty()); + query_coordinator.assert_leak_fragment_exchanges(); let injector = DefaultExchangeInjector::create(); let mut build_res = query_coordinator.subscribe_fragment(&ctx, fragment_id, injector)?; - let exchanges = std::mem::take(&mut query_coordinator.statistics_exchanges); + let exchanges = query_coordinator.take_statistics_receivers(); let statistics_receiver = StatisticsReceiver::spawn_receiver(&ctx, exchanges)?; let statistics_receiver: Mutex = @@ -507,13 +561,13 @@ impl DataExchangeManager { pub fn get_flight_receiver( &self, params: &ExchangeParams, - ) -> Result> { + ) -> Result> { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.get_mut(¶ms.get_query_id()) { None => Err(ErrorCode::Internal("Query not exists.")), - Some(coordinator) => coordinator.get_flight_receiver(params), + Some(coordinator) => coordinator.take_flight_receiver(params), } } @@ -551,15 +605,34 @@ struct QueryInfo { query_executor: Option>, } -static FLIGHT_SENDER: u8 = 1; -static FLIGHT_RECEIVER: u8 = 2; +#[derive(Hash, Eq, PartialEq)] +pub struct FragmentExchangeIdentifier { + target: String, + fragment: usize, +} + +#[derive(Hash, Eq, PartialEq)] +pub enum ExchangeIdentifier { + Statistics(String), + DataSender(FragmentExchangeIdentifier), + DataReceiver(FragmentExchangeIdentifier), +} + +impl ExchangeIdentifier { + pub fn fragment_sender(target: String, fragment: usize) -> Self { + ExchangeIdentifier::DataSender(FragmentExchangeIdentifier { target, fragment }) + } + + pub fn fragment_receiver(target: String, fragment: usize) -> Self { + ExchangeIdentifier::DataReceiver(FragmentExchangeIdentifier { target, fragment }) + } +} struct QueryCoordinator { info: Option, fragments_coordinator: HashMap>, - statistics_exchanges: HashMap, - fragment_exchanges: HashMap<(String, usize, u8), FlightExchange>, + exchanges: HashMap, } impl QueryCoordinator { @@ -567,24 +640,69 @@ impl QueryCoordinator { QueryCoordinator { info: None, fragments_coordinator: HashMap::new(), - fragment_exchanges: HashMap::new(), - statistics_exchanges: HashMap::new(), + exchanges: HashMap::new(), + } + } + + pub fn take_statistics_senders(&mut self) -> Vec { + let mut statistics_senders = Vec::with_capacity(1); + + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::Statistics(_) = identifier { + statistics_senders.push(exchange.take_as_sender()); + } + } + + statistics_senders + } + + pub fn take_statistics_receivers(&mut self) -> Vec { + let mut statistics_receivers = Vec::with_capacity(self.exchanges.len()); + + for (identifier, exchange) in &mut 
self.exchanges { + if let ExchangeIdentifier::Statistics(_) = identifier { + statistics_receivers.push(exchange.take_as_receiver()); + } + } + + statistics_receivers + } + + pub fn assert_leak_fragment_exchanges(&self) { + for (identifier, exchange) in &self.exchanges { + if !matches!(identifier, ExchangeIdentifier::Statistics(_)) { + assert!(matches!( + exchange, + FlightExchange::MovedSender(_) | FlightExchange::MovedReceiver(_) + )); + } } } pub fn add_statistics_exchange( &mut self, target: String, - ) -> Result>> { + begin: usize, + client_stream: Streaming, + enable_retry: bool, + ) -> Result { let (tx, rx) = async_channel::bounded(8); - match self - .statistics_exchanges - .insert(target, FlightExchange::create_sender(tx)) - { - None => Ok(rx), - Some(_) => Err(ErrorCode::Internal( - "statistics exchanges can only have one", - )), + let identifier = ExchangeIdentifier::Statistics(target); + + match self.exchanges.entry(identifier) { + Entry::Vacant(v) => { + let state = FlightDataAckState::create(rx, 10); + v.insert(FlightExchange::create_sender(state.clone(), tx)); + FlightDataAckStream::create(state, begin, client_stream, enable_retry) + } + Entry::Occupied(mut v) => match v.get_mut() { + FlightExchange::MovedSender(v) => { + FlightDataAckStream::create(v.state.clone(), begin, client_stream, enable_retry) + } + _ => Err(ErrorCode::Internal( + "statistics exchanges can only have one", + )), + }, } } @@ -593,7 +711,8 @@ impl QueryCoordinator { exchanges: HashMap, ) -> Result<()> { for (source, exchange) in exchanges.into_iter() { - if self.statistics_exchanges.insert(source, exchange).is_some() { + let identifier = ExchangeIdentifier::Statistics(source); + if self.exchanges.insert(identifier, exchange).is_some() { return Err(ErrorCode::Internal( "Internal error, statistics exchange can only have one.", )); @@ -607,13 +726,26 @@ impl QueryCoordinator { &mut self, target: String, fragment: usize, - ) -> Result>> { + begin: usize, + client_stream: Streaming, + enable_retry: bool, + ) -> Result { let (tx, rx) = async_channel::bounded(8); - self.fragment_exchanges.insert( - (target, fragment, FLIGHT_SENDER), - FlightExchange::create_sender(tx), - ); - Ok(rx) + let identifier = ExchangeIdentifier::fragment_sender(target, fragment); + + match self.exchanges.entry(identifier) { + Entry::Vacant(v) => { + let state = FlightDataAckState::create(rx, 10); + v.insert(FlightExchange::create_sender(state.clone(), tx)); + FlightDataAckStream::create(state, begin, client_stream, enable_retry) + } + Entry::Occupied(mut v) => match v.get_mut() { + FlightExchange::MovedSender(v) => { + FlightDataAckStream::create(v.state.clone(), begin, client_stream, enable_retry) + } + _ => Err(ErrorCode::Internal("fragment exchange can only have one")), + }, + } } pub fn add_fragment_exchanges( @@ -621,83 +753,101 @@ impl QueryCoordinator { exchanges: HashMap<(String, usize), FlightExchange>, ) -> Result<()> { for ((source, fragment), exchange) in exchanges.into_iter() { - self.fragment_exchanges - .insert((source, fragment, FLIGHT_RECEIVER), exchange); + let identifier = ExchangeIdentifier::fragment_receiver(source, fragment); + + self.exchanges.insert(identifier, exchange); } Ok(()) } pub fn get_flight_senders(&mut self, params: &ExchangeParams) -> Result> { + let mut fragments_exchanges = Vec::with_capacity(self.exchanges.len()); + match params { - ExchangeParams::MergeExchange(params) => Ok(self - .fragment_exchanges - .extract_if(|(_, f, r), _| f == ¶ms.fragment_id && *r == FLIGHT_SENDER) - .map(|(_, v)| 
v.convert_to_sender()) - .collect::>()), - ExchangeParams::ShuffleExchange(params) => { - let mut exchanges = Vec::with_capacity(params.destination_ids.len()); + ExchangeParams::MergeExchange(params) => { + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::DataSender(v) = identifier { + if v.fragment != params.fragment_id { + continue; + } - for destination in ¶ms.destination_ids { - exchanges.push(match destination == ¶ms.executor_id { - true => Ok(FlightSender::create(async_channel::bounded(1).0)), - false => match self.fragment_exchanges.remove(&( - destination.clone(), - params.fragment_id, - FLIGHT_SENDER, - )) { - Some(exchange_channel) => Ok(exchange_channel.convert_to_sender()), - None => Err(ErrorCode::UnknownFragmentExchange(format!( - "Unknown fragment exchange channel, {}, {}", - destination, params.fragment_id - ))), - }, - }?); + fragments_exchanges.push(exchange.take_as_sender()); + } } + } + ExchangeParams::ShuffleExchange(params) => { + for destination in ¶ms.destination_ids { + if destination == ¶ms.executor_id { + let dummy = FlightSender::create(async_channel::bounded(1).0); + fragments_exchanges.push(dummy); + continue; + } + + let target = destination.clone(); + let fragment = params.fragment_id; + let identifier = ExchangeIdentifier::fragment_sender(target, fragment); + if let Some(v) = self.exchanges.get_mut(&identifier) { + fragments_exchanges.push(v.take_as_sender()); + continue; + } - Ok(exchanges) + return Err(ErrorCode::UnknownFragmentExchange(format!( + "Unknown fragment exchange channel, {}, {}", + destination, params.fragment_id + ))); + } } - } + }; + + Ok(fragments_exchanges) } - pub fn get_flight_receiver( + pub fn take_flight_receiver( &mut self, params: &ExchangeParams, - ) -> Result> { + ) -> Result> { + let mut fragments_exchanges = Vec::with_capacity(self.exchanges.len()); + match params { - ExchangeParams::MergeExchange(params) => Ok(self - .fragment_exchanges - .extract_if(|(_, f, r), _| f == ¶ms.fragment_id && *r == FLIGHT_RECEIVER) - .map(|((source, _, _), v)| (source.clone(), v.convert_to_receiver())) - .collect::>()), - ExchangeParams::ShuffleExchange(params) => { - let mut exchanges = Vec::with_capacity(params.destination_ids.len()); + ExchangeParams::MergeExchange(params) => { + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::DataReceiver(v) = identifier { + if v.fragment != params.fragment_id { + continue; + } - for destination in ¶ms.destination_ids { - exchanges.push(( - destination.clone(), - match destination == ¶ms.executor_id { - true => Ok(FlightReceiver::create(async_channel::bounded(1).1)), - false => match self.fragment_exchanges.remove(&( - destination.clone(), - params.fragment_id, - FLIGHT_RECEIVER, - )) { - Some(v) => Ok(v.convert_to_receiver()), - _ => Err(ErrorCode::UnknownFragmentExchange(format!( - "Unknown fragment flight receiver, {}, {}", - destination, params.fragment_id - ))), - }, - }?, - )); + fragments_exchanges.push((v.target.clone(), exchange.take_as_receiver())); + } } + } + ExchangeParams::ShuffleExchange(params) => { + for destination in ¶ms.destination_ids { + if destination == ¶ms.executor_id { + let dummy = RetryableFlightReceiver::dummy(); + fragments_exchanges.push((destination.clone(), dummy)); + continue; + } - Ok(exchanges) + let source = destination.clone(); + let fragment = params.fragment_id; + let identifier = ExchangeIdentifier::fragment_receiver(source, fragment); + if let Some(v) = self.exchanges.get_mut(&identifier) { + let receiver 
= v.take_as_receiver(); + fragments_exchanges.push((destination.clone(), receiver)); + continue; + } + + return Err(ErrorCode::UnknownFragmentExchange(format!( + "Unknown fragment flight receiver, {}, {}", + destination, params.fragment_id + ))); + } } - } - } + }; + Ok(fragments_exchanges) + } pub fn prepare_pipeline(&mut self, fragments: &QueryFragments) -> Result<()> { let query_info = self.info.as_ref().expect("expect query info"); let query_context = query_info.query_ctx.clone(); @@ -840,28 +990,30 @@ impl QueryCoordinator { let settings = ExecutorSettings::try_create(info.query_ctx.clone())?; let executor = PipelineCompleteExecutor::from_pipelines(pipelines, settings)?; - assert!(self.fragment_exchanges.is_empty()); + self.assert_leak_fragment_exchanges(); let info_mut = self.info.as_mut().expect("Query info is None"); info_mut.query_executor = Some(executor.clone()); let query_id = info_mut.query_id.clone(); let query_ctx = info_mut.query_ctx.clone(); - let request_server_exchanges = std::mem::take(&mut self.statistics_exchanges); - if request_server_exchanges.len() != 1 { + let ctx = query_ctx.clone(); + let mut statistics_senders = self.take_statistics_senders(); + + let Some(statistics_sender) = statistics_senders.pop() else { + return Err(ErrorCode::Internal( + "Request server must less than 1 if is not request server.", + )); + }; + + if !statistics_senders.is_empty() { return Err(ErrorCode::Internal( "Request server must less than 1 if is not request server.", )); } - let ctx = query_ctx.clone(); - let (_, request_server_exchange) = request_server_exchanges.into_iter().next().unwrap(); - let mut statistics_sender = StatisticsSender::spawn( - &query_id, - ctx, - request_server_exchange, - executor.get_inner(), - ); + let mut statistics_sender = + StatisticsSender::spawn(&query_id, ctx, statistics_sender, executor.get_inner()); let span = if let Some(parent) = SpanContext::current_local_parent() { Span::root("Distributed-Executor", parent) diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs index 88a1a139c21b4..daa7c4bfec835 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs @@ -27,15 +27,15 @@ use databend_common_pipeline_core::processors::ProcessorPtr; use databend_common_pipeline_core::PipeItem; use log::info; +use crate::servers::flight::flight_client::RetryableFlightReceiver; use crate::servers::flight::v1::exchange::serde::ExchangeDeserializeMeta; use crate::servers::flight::v1::packets::DataPacket; -use crate::servers::flight::FlightReceiver; pub struct ExchangeSourceReader { finished: AtomicBool, output: Arc, output_data: Vec, - flight_receiver: FlightReceiver, + flight_receiver: RetryableFlightReceiver, source: String, destination: String, fragment: usize, @@ -44,7 +44,7 @@ pub struct ExchangeSourceReader { impl ExchangeSourceReader { pub fn create( output: Arc, - flight_receiver: FlightReceiver, + flight_receiver: RetryableFlightReceiver, source: &str, destination: &str, fragment: usize, @@ -156,7 +156,7 @@ impl Processor for ExchangeSourceReader { } pub fn create_reader_item( - flight_receiver: FlightReceiver, + flight_receiver: RetryableFlightReceiver, source: &str, destination: &str, fragment: usize, diff --git a/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs 
b/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs index 999f7d990a5df..33e64787be7a6 100644 --- a/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs +++ b/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs @@ -128,6 +128,7 @@ impl BlockMetaTransform for TransformExchangeDeserializ DataPacket::QueryProfiles(_) => unreachable!(), DataPacket::DataCacheMetrics(_) => unreachable!(), DataPacket::FragmentData(v) => Ok(vec![self.recv_data(meta.packet, v)?]), + DataPacket::FlightControl(_) => unreachable!(), } } } diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs index 84b89ab89b344..02cfbd165f1a4 100644 --- a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs +++ b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; use std::sync::Arc; use databend_common_base::base::tokio::sync::broadcast::channel; @@ -25,8 +24,8 @@ use databend_common_exception::Result; use futures_util::future::select; use futures_util::future::Either; +use crate::servers::flight::flight_client::RetryableFlightReceiver; use crate::servers::flight::v1::packets::DataPacket; -use crate::servers::flight::FlightExchange; use crate::sessions::QueryContext; pub struct StatisticsReceiver { @@ -38,14 +37,13 @@ pub struct StatisticsReceiver { impl StatisticsReceiver { pub fn spawn_receiver( ctx: &Arc, - statistics_exchanges: HashMap, + statistics_exchanges: Vec, ) -> Result { let (shutdown_tx, _shutdown_rx) = channel(2); let mut exchange_handler = Vec::with_capacity(statistics_exchanges.len()); let runtime = Runtime::with_worker_threads(2, Some(String::from("StatisticsReceiver")))?; - for (_source, exchange) in statistics_exchanges.into_iter() { - let rx = exchange.convert_to_receiver(); + for rx in statistics_exchanges.into_iter() { exchange_handler.push(runtime.spawn({ let ctx = ctx.clone(); let shutdown_rx = shutdown_tx.subscribe(); @@ -147,6 +145,7 @@ impl StatisticsReceiver { ctx.get_data_cache_metrics().merge(metrics); Ok(false) } + Ok(Some(DataPacket::FlightControl(_))) => unreachable!(), } } diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs index b26e8f4d2854d..95d0af8a730fa 100644 --- a/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs +++ b/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs @@ -29,7 +29,6 @@ use log::warn; use crate::pipelines::executor::PipelineExecutor; use crate::servers::flight::v1::packets::DataPacket; use crate::servers::flight::v1::packets::ProgressInfo; -use crate::servers::flight::FlightExchange; use crate::servers::flight::FlightSender; use crate::sessions::QueryContext; @@ -43,11 +42,10 @@ impl StatisticsSender { pub fn spawn( query_id: &str, ctx: Arc, - exchange: FlightExchange, + tx: FlightSender, executor: Arc, ) -> Self { let spawner = ctx.clone(); - let tx = exchange.convert_to_sender(); let (shutdown_flag_sender, shutdown_flag_receiver) = async_channel::bounded(1); let handle = spawner.spawn({ @@ -231,8 +229,4 @@ impl StatisticsSender { progress_info } - - // fn fetch_profiling(ctx: &Arc) -> Result> { - // // ctx.get_exchange_manager() - // } } diff --git 
a/src/query/service/src/servers/flight/v1/flight_service.rs b/src/query/service/src/servers/flight/v1/flight_service.rs index 2d8c763de7016..ee4d3382a8537 100644 --- a/src/query/service/src/servers/flight/v1/flight_service.rs +++ b/src/query/service/src/servers/flight/v1/flight_service.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::pin::Pin; +use std::sync::Arc; use databend_common_arrow::arrow_format::flight::data::Action; use databend_common_arrow::arrow_format::flight::data::ActionType; @@ -27,7 +28,6 @@ use databend_common_arrow::arrow_format::flight::data::PutResult; use databend_common_arrow::arrow_format::flight::data::Result as FlightResult; use databend_common_arrow::arrow_format::flight::data::SchemaResult; use databend_common_arrow::arrow_format::flight::data::Ticket; -use databend_common_arrow::arrow_format::flight::service::flight_service_server::FlightService; use databend_common_config::GlobalConfig; use databend_common_exception::ErrorCode; use fastrace::func_path; @@ -39,13 +39,13 @@ use tonic::Response as RawResponse; use tonic::Status; use tonic::Streaming; +use crate::servers::flight::flight_service::FlightOperation; use crate::servers::flight::request_builder::RequestGetter; use crate::servers::flight::v1::actions::flight_actions; use crate::servers::flight::v1::actions::FlightActions; use crate::servers::flight::v1::exchange::DataExchangeManager; -pub type FlightStream = - Pin> + Send + Sync + 'static>>; +pub type FlightStream = Pin> + Send + Sync + 'static>>; pub struct DatabendQueryFlightService { actions: FlightActions, @@ -61,54 +61,70 @@ impl DatabendQueryFlightService { type Response = Result, Status>; type StreamReq = Request>; - #[async_trait::async_trait] -impl FlightService for DatabendQueryFlightService { - type HandshakeStream = FlightStream; +impl FlightOperation for DatabendQueryFlightService { + type HandshakeStream = FlightStream>; #[async_backtrace::framed] async fn handshake(&self, _: StreamReq) -> Response { - Result::Err(Status::unimplemented( + Err(Status::unimplemented( "DatabendQuery does not implement handshake.", )) } - type ListFlightsStream = FlightStream; + type ListFlightsStream = FlightStream>; #[async_backtrace::framed] async fn list_flights(&self, _: Request) -> Response { - Result::Err(Status::unimplemented( + Err(Status::unimplemented( "DatabendQuery does not implement list_flights.", )) } #[async_backtrace::framed] - async fn get_flight_info(&self, _: Request) -> Response { + async fn get_flight_info(&self, _: Request) -> Response> { Err(Status::unimplemented( "DatabendQuery does not implement get_flight_info.", )) } #[async_backtrace::framed] - async fn get_schema(&self, _: Request) -> Response { + async fn get_schema(&self, _: Request) -> Response> { Err(Status::unimplemented( "DatabendQuery does not implement get_schema.", )) } - - type DoGetStream = FlightStream; + type DoExchangeStream = FlightStream>; #[async_backtrace::framed] - async fn do_get(&self, request: Request) -> Response { + async fn do_exchange( + &self, + request: StreamReq, + ) -> Response { let root = databend_common_tracing::start_trace_for_remote_request(func_path!(), &request); let _guard = root.set_local_parent(); - match request.get_metadata("x-type")?.as_str() { "request_server_exchange" => { let target = request.get_metadata("x-target")?; let query_id = request.get_metadata("x-query-id")?; + let continue_from = request + .get_metadata("x-continue-from")? 
+ .parse::() + .unwrap(); + // if x-enable-retry is not 0, set enable_retry to a bool value true, otherwise false + let enable_retry: u64 = request + .get_metadata("x-enable-retry")? + .parse::() + .unwrap_or_default(); + let client_stream = request.into_inner(); Ok(RawResponse::new(Box::pin( - DataExchangeManager::instance().handle_statistics_exchange(query_id, target)?, + DataExchangeManager::instance().handle_statistics_exchange( + query_id, + target, + continue_from, + client_stream, + enable_retry != 0, + )?, ))) } "exchange_fragment" => { @@ -118,12 +134,27 @@ impl FlightService for DatabendQueryFlightService { .get_metadata("x-fragment-id")? .parse::() .unwrap(); - + let continue_from = request + .get_metadata("x-continue-from")? + .parse::() + .unwrap(); + let enable_retry: u64 = request + .get_metadata("x-enable-retry")? + .parse::() + .unwrap_or_default(); + let client_stream = request.into_inner(); Ok(RawResponse::new(Box::pin( - DataExchangeManager::instance() - .handle_exchange_fragment(query_id, target, fragment)?, + DataExchangeManager::instance().handle_exchange_fragment( + query_id, + target, + fragment, + continue_from, + client_stream, + enable_retry != 0, + )?, ))) } + "health" => Ok(RawResponse::new(build_health_response())), exchange_type => Err(Status::unimplemented(format!( "Unimplemented exchange type: {:?}", exchange_type @@ -131,21 +162,20 @@ impl FlightService for DatabendQueryFlightService { } } - type DoPutStream = FlightStream; + type DoGetStream = FlightStream>; #[async_backtrace::framed] - async fn do_put(&self, _req: StreamReq) -> Response { - Err(Status::unimplemented("unimplemented do_put")) + async fn do_get(&self, _request: Request) -> Response { + Err(Status::unimplemented("unimplemented do_exchange")) } - - type DoExchangeStream = FlightStream; + type DoPutStream = FlightStream>; #[async_backtrace::framed] - async fn do_exchange(&self, _: StreamReq) -> Response { - Err(Status::unimplemented("unimplemented do_exchange")) + async fn do_put(&self, _req: StreamReq) -> Response { + Err(Status::unimplemented("unimplemented do_put")) } - type DoActionStream = FlightStream; + type DoActionStream = FlightStream>; #[async_backtrace::framed] async fn do_action(&self, request: Request) -> Response { @@ -169,17 +199,26 @@ impl FlightService for DatabendQueryFlightService { .await { Err(cause) => Err(cause.into()), - Ok(body) => Ok(RawResponse::new( - Box::pin(tokio_stream::once(Ok(FlightResult { body }))) - as FlightStream, - )), + Ok(body) => Ok(RawResponse::new(Box::pin(tokio_stream::once(Ok(Arc::new( + FlightResult { body }, + )))) + as FlightStream>)), } } - type ListActionsStream = FlightStream; + type ListActionsStream = FlightStream>; #[async_backtrace::framed] async fn list_actions(&self, _: Request) -> Response { Ok(RawResponse::new(Box::pin(stream::empty()))) } } + +fn build_health_response() -> FlightStream> { + Box::pin(stream::iter(vec![Ok(Arc::new(FlightData { + flight_descriptor: None, + data_header: vec![], + data_body: Vec::from("ok"), + app_metadata: vec![0x03], + }))])) +} diff --git a/src/query/service/src/servers/flight/v1/packets/mod.rs b/src/query/service/src/servers/flight/v1/packets/mod.rs index ca44d46afde5e..b9ca6e885e46b 100644 --- a/src/query/service/src/servers/flight/v1/packets/mod.rs +++ b/src/query/service/src/servers/flight/v1/packets/mod.rs @@ -19,6 +19,7 @@ mod packet_fragment; mod packet_publisher; pub use packet_data::DataPacket; +pub use packet_data::FlightControlCommand; pub use packet_data::FragmentData; pub use 
packet_data_progressinfo::ProgressInfo; pub use packet_executor::QueryFragments; diff --git a/src/query/service/src/servers/flight/v1/packets/packet_data.rs b/src/query/service/src/servers/flight/v1/packets/packet_data.rs index 2ca07a7a8dcdd..4c9a38b0fe3a4 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_data.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_data.rs @@ -28,6 +28,8 @@ use databend_common_pipeline_core::processors::PlanProfile; use databend_common_storage::CopyStatus; use databend_common_storage::MutationStatus; use log::error; +use serde::Deserialize; +use serde::Serialize; use crate::servers::flight::v1::packets::ProgressInfo; @@ -61,6 +63,12 @@ pub enum DataPacket { CopyStatus(CopyStatus), MutationStatus(MutationStatus), DataCacheMetrics(DataCacheMetricValues), + FlightControl(FlightControlCommand), +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum FlightControlCommand { + Close, } fn calc_size(flight_data: &FlightData) -> usize { @@ -78,6 +86,7 @@ impl DataPacket { DataPacket::FragmentData(v) => calc_size(&v.data) + v.meta.len(), DataPacket::QueryProfiles(_) => 0, DataPacket::DataCacheMetrics(_) => 0, + DataPacket::FlightControl(_) => 0, } } } @@ -136,6 +145,12 @@ impl TryFrom for FlightData { data_header: vec![], flight_descriptor: None, }, + DataPacket::FlightControl(command) => FlightData { + app_metadata: vec![0x09], + data_body: serde_json::to_vec(&command)?, + data_header: vec![], + flight_descriptor: None, + }, }) } } @@ -196,6 +211,11 @@ impl TryFrom for DataPacket { serde_json::from_slice::(&flight_data.data_body)?; Ok(DataPacket::DataCacheMetrics(status)) } + 0x09 => { + let command = + serde_json::from_slice::(&flight_data.data_body)?; + Ok(DataPacket::FlightControl(command)) + } _ => Err(ErrorCode::BadBytes("Unknown flight data packet type.")), } } diff --git a/src/query/service/src/servers/flight/v1/packets/packet_executor.rs b/src/query/service/src/servers/flight/v1/packets/packet_executor.rs index c555184d82912..29999b7727b31 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_executor.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_executor.rs @@ -14,7 +14,7 @@ use crate::servers::flight::v1::packets::QueryFragment; -#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct QueryFragments { pub query_id: String, pub fragments: Vec, diff --git a/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs b/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs index 33c5f20e11b7d..12bb478627759 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs @@ -34,6 +34,7 @@ use serde::Deserialize; use serde::Serialize; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::servers::flight::v1::actions::INIT_QUERY_ENV; use crate::sessions::QueryContext; use crate::sessions::SessionManager; @@ -140,7 +141,7 @@ pub struct QueryEnv { } impl QueryEnv { - pub async fn init(&self, ctx: &Arc, timeout: u64) -> Result<()> { + pub async fn init(&self, ctx: &Arc, flight_params: FlightParams) -> Result<()> { debug!("Dataflow diagram {:?}", self.dataflow_diagram); let cluster = ctx.get_cluster(); @@ -151,7 +152,7 @@ impl QueryEnv { } let _ = cluster - .do_action::<_, ()>(INIT_QUERY_ENV, message, timeout) + .do_action::<_, ()>(INIT_QUERY_ENV, message, flight_params) .await?; Ok(()) diff --git 
a/src/query/service/tests/it/servers/flight/flight_service.rs b/src/query/service/tests/it/servers/flight/flight_service.rs index 015efc53c7de9..c3821401c3351 100644 --- a/src/query/service/tests/it/servers/flight/flight_service.rs +++ b/src/query/service/tests/it/servers/flight/flight_service.rs @@ -17,7 +17,6 @@ use std::net::TcpListener; use std::str::FromStr; use std::sync::Arc; -use databend_common_arrow::arrow_format::flight::data::Empty; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::tokio; use databend_common_exception::ErrorCode; @@ -26,7 +25,9 @@ use databend_common_grpc::ConnectionFactory; use databend_common_grpc::GrpcConnectionError; use databend_common_grpc::RpcClientTlsConfig; use databend_query::servers::flight::FlightService; +use databend_query::servers::flight::RequestBuilder; use databend_query::test_kits::*; +use futures_util::StreamExt; use crate::tests::tls_constants::TEST_CA_CERT; use crate::tests::tls_constants::TEST_CN_NAME; @@ -53,15 +54,31 @@ async fn test_tls_rpc_server() -> Result<()> { // normal case let conn = ConnectionFactory::create_rpc_channel(listener_address, None, tls_conf).await?; let mut f_client = FlightServiceClient::new(conn); - let r = f_client.list_actions(Empty {}).await; + let (server_tx, server_rx) = async_channel::bounded(1); + let r = f_client + .do_exchange( + RequestBuilder::create(Box::pin(server_rx)) + .with_metadata("x-type", "health")? + .build(), + ) + .await; assert!(r.is_ok()); - + server_tx.close(); // client access without tls enabled will be failed // - channel can still be created, but communication will be failed let channel = ConnectionFactory::create_rpc_channel(listener_address, None, None).await?; + let mut f_client = FlightServiceClient::new(channel); - let r = f_client.list_actions(Empty {}).await; + let (server_tx, server_rx) = async_channel::bounded(1); + let r = f_client + .do_exchange( + RequestBuilder::create(Box::pin(server_rx)) + .with_metadata("x-type", "health")? + .build(), + ) + .await; assert!(r.is_err()); + server_tx.close(); Ok(()) } @@ -120,3 +137,27 @@ async fn test_rpc_server_port_used() -> Result<()> { assert!(r.is_err()); Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_rpc_server_do_exchange() -> Result<()> { + let listener_address = SocketAddr::from_str("127.0.0.1:9995")?; + let mut rpc_service = FlightService::create(ConfigBuilder::create().build())?; + rpc_service.start(listener_address).await?; + let conn = ConnectionFactory::create_rpc_channel(listener_address, None, None).await?; + let mut f_client = FlightServiceClient::new(conn); + let (server_tx, server_rx) = async_channel::bounded(1); + let r = f_client + .do_exchange( + RequestBuilder::create(Box::pin(server_rx)) + .with_metadata("x-type", "health")? 
+ .build(), + ) + .await; + assert!(r.is_ok()); + let message = r.unwrap().into_inner().next().await.unwrap()?; + assert_eq!(message.app_metadata.last(), Some(&0x03)); + assert_eq!(message.data_body, Vec::from("ok")); + + server_tx.close(); + Ok(()) +} diff --git a/src/query/settings/src/settings_default.rs b/src/query/settings/src/settings_default.rs index 50f9aeb27e5bf..b26b1aacec250 100644 --- a/src/query/settings/src/settings_default.rs +++ b/src/query/settings/src/settings_default.rs @@ -909,6 +909,18 @@ impl DefaultSettings { mode: SettingMode::Both, range: Some(SettingRange::Numeric(0..=1)), }), + ("max_flight_connection_retry_times", DefaultSettingValue { + value: UserSettingValue::UInt64(3), + desc: "The maximum retry count for cluster flight. Disable if 0.", + mode: SettingMode::Both, + range: Some(SettingRange::Numeric(0..=30)), + }), + ("flight_connection_retry_interval", DefaultSettingValue { + value: UserSettingValue::UInt64(3), + desc: "The retry interval of cluster flight is in seconds.", + mode: SettingMode::Both, + range: Some(SettingRange::Numeric(0..=900)), + }), ("random_function_seed", DefaultSettingValue { value: UserSettingValue::UInt64(0), desc: "Seed for random function", diff --git a/src/query/settings/src/settings_getter_setter.rs b/src/query/settings/src/settings_getter_setter.rs index 55c2e46cd7e57..3df3ea8ae5619 100644 --- a/src/query/settings/src/settings_getter_setter.rs +++ b/src/query/settings/src/settings_getter_setter.rs @@ -749,6 +749,14 @@ impl Settings { Ok(self.try_get_u64("random_function_seed")? == 1) } + pub fn get_flight_retry_interval(&self) -> Result { + self.try_get_u64("flight_connection_retry_interval") + } + + pub fn get_max_flight_retry_times(&self) -> Result { + self.try_get_u64("max_flight_connection_retry_times") + } + pub fn get_dynamic_sample_time_budget_ms(&self) -> Result { self.try_get_u64("dynamic_sample_time_budget_ms") } diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.result b/tests/suites/1_stateful/02_query/02_0009_kill_connection.result new file mode 100644 index 0000000000000..be2d12014cc17 --- /dev/null +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.result @@ -0,0 +1 @@ +Final state: Succeeded diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh new file mode 100755 index 0000000000000..9f632454311ed --- /dev/null +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash + +perform_initial_query() { + local response=$(curl -s -u root: -XPOST "http://localhost:8000/v1/query" -H 'Content-Type: application/json' -d '{"sql": "select avg(number) from numbers(1000000000)"}') + local stats_uri=$(echo "$response" | jq -r '.stats_uri') + local final_uri=$(echo "$response" | jq -r '.final_uri') + echo "$stats_uri|$final_uri" +} +poll_stats_uri() { + local uri=$1 + local state_exists=true + while $state_exists; do + local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") + if ! echo "$response" | jq -e '.state' > /dev/null; then + state_exists=false + else + sleep 2 + fi + done +} +get_final_state() { + local uri=$1 + local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") + echo "$response" | jq -r '.state' +} + +IFS='|' read -r stats_uri final_uri <<< $(perform_initial_query) + +poll_stats_uri "$stats_uri" & +POLL_PID=$! 
+sleep 1 +netstat_output=$(netstat -an | grep '9092') + +# skip standalone mode +if [ -z "$netstat_output" ]; then + echo "Final state: Succeeded" + exit 0 +fi + +port=$(echo "$netstat_output" | awk ' + $NF == "ESTABLISHED" { + if ($4 ~ /:9092$/) { + split($5, a, ":") + port = a[2] + } else if ($5 ~ /:9092$/) { + split($4, a, ":") + port = a[2] + } + } + END { + print port + } +') + +# Start tcpkill in the background +sudo tcpkill -i lo host 127.0.0.1 and port $port > tcpkill_output.txt 2>&1 & +TCPKILL_PID=$! + +# Wait for tcpkill to output at least 3 lines or terminate if done earlier +while [ $(wc -l < tcpkill_output.txt) -lt 3 ]; do + if ! kill -0 $TCPKILL_PID 2> /dev/null; then + break + fi + sleep 1 +done + +# Kill tcpkill after the desired number of lines if it's still running +if kill -0 $TCPKILL_PID 2> /dev/null; then + kill $TCPKILL_PID +fi + +wait $POLL_PID + +final_state=$(get_final_state "$final_uri") +echo "Final state: $final_state" +
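
A minimal, self-contained sketch of the retry loop this diff adds to Cluster::do_action, under stated assumptions: the FlightParams field names mirror the struct introduced above, while call_once and is_connect_error are illustrative stand-ins for the real flight client call and the ErrorCode::CANNOT_CONNECT_NODE check, not the project's actual API. Only connection errors are retried, up to retry_times attempts, sleeping retry_interval seconds between attempts; any other error is returned immediately. Per the new settings, setting max_flight_connection_retry_times to 0 disables retries entirely.

// Hypothetical sketch of the retry-on-connection-error pattern; not Databend code.
use std::time::Duration;

#[derive(Clone, Copy, Debug)]
pub struct FlightParams {
    pub timeout: u64,        // per-request timeout, in seconds (unused in this sketch)
    pub retry_times: u64,    // maximum number of retries on connection errors
    pub retry_interval: u64, // sleep between retries, in seconds
}

// Stand-in for one RPC attempt against a cluster node.
fn call_once(attempt: u64) -> Result<String, String> {
    if attempt < 2 {
        Err("cannot connect node".to_string())
    } else {
        Ok("ok".to_string())
    }
}

// Stand-in for the "is this a network problem?" check (CANNOT_CONNECT_NODE in the diff).
fn is_connect_error(e: &str) -> bool {
    e.contains("cannot connect")
}

fn do_action_with_retry(params: FlightParams) -> Result<String, String> {
    let mut attempt = 0u64;
    loop {
        match call_once(attempt) {
            Ok(v) => return Ok(v),
            // Only retry connection errors, and only while attempts remain.
            Err(e) if is_connect_error(&e) && attempt < params.retry_times => {
                attempt += 1;
                std::thread::sleep(Duration::from_secs(params.retry_interval));
            }
            Err(e) => return Err(e),
        }
    }
}

fn main() {
    let params = FlightParams { timeout: 60, retry_times: 3, retry_interval: 0 };
    println!("{:?}", do_action_with_retry(params));
}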
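
The new DataPacket::FlightControl variant is carried as FlightData whose app_metadata tag byte is 0x09 and whose body is a serde_json-encoded FlightControlCommand. A rough sketch of that tag-byte dispatch follows; 0x09 matches the diff, but the other tag values and the Packet enum are assumptions for illustration only.

// Hypothetical sketch of tag-byte packet dispatch; not the project's real decoder.
#[derive(Debug, PartialEq)]
enum Packet {
    FragmentData(Vec<u8>),
    FlightControlClose,
    Unknown(Option<u8>),
}

fn decode(app_metadata: &[u8], body: Vec<u8>) -> Packet {
    // The first app_metadata byte selects the packet kind; 0x09 is the new
    // FlightControl tag from the diff, the other values are illustrative.
    match app_metadata.first().copied() {
        Some(0x01) => Packet::FragmentData(body),
        Some(0x09) => Packet::FlightControlClose,
        other => Packet::Unknown(other),
    }
}

fn main() {
    // serde_json encodes the unit variant Close as the string "Close".
    assert_eq!(
        decode(&[0x09], b"\"Close\"".to_vec()),
        Packet::FlightControlClose
    );
    assert!(matches!(decode(&[0x01], vec![1, 2, 3]), Packet::FragmentData(_)));
    println!("decoded control packet");
}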