From 706c744cf28474a9c856319024a67c1b9cf1c2e2 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Mon, 12 Aug 2024 11:35:54 +0800 Subject: [PATCH 01/18] refactor: allow to encode data with arc --- src/query/service/src/servers/flight/codec.rs | 72 +++ .../src/servers/flight/flight_client.rs | 14 +- .../src/servers/flight/flight_server.rs | 106 ++++ .../src/servers/flight/flight_service.rs | 494 +++++++++++++++--- src/query/service/src/servers/flight/mod.rs | 6 +- .../flight/v1/exchange/exchange_manager.rs | 8 +- .../src/servers/flight/v1/flight_service.rs | 61 ++- .../tests/it/servers/flight/flight_service.rs | 43 +- 8 files changed, 686 insertions(+), 118 deletions(-) create mode 100644 src/query/service/src/servers/flight/codec.rs create mode 100644 src/query/service/src/servers/flight/flight_server.rs diff --git a/src/query/service/src/servers/flight/codec.rs b/src/query/service/src/servers/flight/codec.rs new file mode 100644 index 0000000000000..5185e691f5ac8 --- /dev/null +++ b/src/query/service/src/servers/flight/codec.rs @@ -0,0 +1,72 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::marker::PhantomData; +use std::sync::Arc; + +use prost::Message; +use tonic::codec::Codec; +use tonic::codec::DecodeBuf; +use tonic::codec::Decoder; +use tonic::codec::EncodeBuf; +use tonic::codec::Encoder; +use tonic::Status; + +#[derive(Default)] +pub struct MessageCodec(PhantomData<(E, D)>); + +impl Codec for MessageCodec { + type Encode = Arc; + type Decode = D; + type Encoder = ArcEncoder; + type Decoder = DefaultDecoder; + + fn encoder(&mut self) -> Self::Encoder { + ArcEncoder(PhantomData) + } + + fn decoder(&mut self) -> Self::Decoder { + DefaultDecoder(PhantomData) + } +} + +pub struct ArcEncoder(PhantomData); + +impl Encoder for ArcEncoder { + type Item = Arc; + + type Error = Status; + + fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { + item.as_ref() + .encode(dst) + .map_err(|e| Status::internal(e.to_string())) + } +} + +pub struct DefaultDecoder(PhantomData); + +impl Decoder for DefaultDecoder { + type Item = T; + + type Error = Status; + + fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result, Self::Error> { + let item = Message::decode(buf) + .map(Some) + .map_err(|e| Status::internal(e.to_string()))?; + + Ok(item) + } +} diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 5c3b1012f050a..b82f6f2d41222 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -248,11 +248,11 @@ impl FlightReceiver { } pub struct FlightSender { - tx: Sender>, + tx: Sender, Status>>, } impl FlightSender { - pub fn create(tx: Sender>) -> FlightSender { + pub fn create(tx: Sender, Status>>) -> FlightSender { FlightSender { tx } } @@ -262,7 +262,11 @@ impl FlightSender { #[async_backtrace::framed] pub async fn send(&self, data: DataPacket) -> Result<()> { - if let Err(_cause) = self.tx.send(Ok(FlightData::try_from(data)?)).await { + if let Err(_cause) = self + .tx + .send(Ok(Arc::new(FlightData::try_from(data)?))) + .await + { return Err(ErrorCode::AbortedQuery( "Aborted query, because the remote flight channel is closed.", )); @@ -282,12 +286,12 @@ pub enum FlightExchange { notify: Arc, receiver: Receiver>, }, - Sender(Sender>), + Sender(Sender, Status>>), } impl FlightExchange { pub fn create_sender( - sender: Sender>, + sender: Sender, Status>>, ) -> FlightExchange { FlightExchange::Sender(sender) } diff --git a/src/query/service/src/servers/flight/flight_server.rs b/src/query/service/src/servers/flight/flight_server.rs new file mode 100644 index 0000000000000..ee921ac9a826b --- /dev/null +++ b/src/query/service/src/servers/flight/flight_server.rs @@ -0,0 +1,106 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use databend_common_base::base::tokio; +use databend_common_base::base::tokio::sync::Notify; +use databend_common_config::InnerConfig; +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use log::info; +use tonic::transport::server::TcpIncoming; +use tonic::transport::Identity; +use tonic::transport::Server; +use tonic::transport::ServerTlsConfig; + +use super::v1::DatabendQueryFlightService; +use crate::servers::flight::flight_service::FlightServiceServer; +use crate::servers::Server as DatabendQueryServer; + +pub struct FlightService { + pub config: InnerConfig, + pub abort_notify: Arc, +} + +impl FlightService { + pub fn create(config: InnerConfig) -> Result> { + Ok(Box::new(Self { + config, + abort_notify: Arc::new(Notify::new()), + })) + } + + fn shutdown_notify(&self) -> impl Future + 'static { + let notified = self.abort_notify.clone(); + async move { + notified.notified().await; + } + } + + #[async_backtrace::framed] + async fn server_tls_config(conf: &InnerConfig) -> Result { + let cert = tokio::fs::read(conf.query.rpc_tls_server_cert.as_str()).await?; + let key = tokio::fs::read(conf.query.rpc_tls_server_key.as_str()).await?; + let server_identity = Identity::from_pem(cert, key); + let tls_conf = ServerTlsConfig::new().identity(server_identity); + Ok(tls_conf) + } + + #[async_backtrace::framed] + pub async fn start_with_incoming(&mut self, addr: SocketAddr) -> Result<()> { + let flight_api_service = DatabendQueryFlightService::create(); + let builder = Server::builder(); + let mut builder = if self.config.tls_rpc_server_enabled() { + info!("databend query tls rpc enabled"); + builder + .tls_config(Self::server_tls_config(&self.config).await.map_err(|e| { + ErrorCode::TLSConfigurationFailure(format!( + "failed to load server tls config: {e}", + )) + })?) + .map_err(|e| { + ErrorCode::TLSConfigurationFailure(format!("failed to invoke tls_config: {e}",)) + })? + } else { + builder + }; + + let incoming = TcpIncoming::new(addr, true, None) + .map_err(|e| ErrorCode::CannotListenerPort(format!("{e}")))?; + let server = builder + .add_service(FlightServiceServer::new( + flight_api_service, + usize::MAX, + usize::MAX, + )) + .serve_with_incoming_shutdown(incoming, self.shutdown_notify()); + databend_common_base::runtime::spawn(server); + Ok(()) + } +} + +#[async_trait::async_trait] +impl DatabendQueryServer for FlightService { + #[async_backtrace::framed] + async fn shutdown(&mut self, _graceful: bool) {} + + #[async_backtrace::framed] + async fn start(&mut self, addr: SocketAddr) -> Result { + self.start_with_incoming(addr).await?; + Ok(addr) + } +} diff --git a/src/query/service/src/servers/flight/flight_service.rs b/src/query/service/src/servers/flight/flight_service.rs index 063d61f1ceb1f..c47207fda6a84 100644 --- a/src/query/service/src/servers/flight/flight_service.rs +++ b/src/query/service/src/servers/flight/flight_service.rs @@ -12,96 +12,436 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::future::Future; -use std::net::SocketAddr; +use std::convert::Infallible; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; -use databend_common_arrow::arrow_format::flight::service::flight_service_server::FlightServiceServer; -use databend_common_base::base::tokio; -use databend_common_base::base::tokio::sync::Notify; -use databend_common_config::InnerConfig; -use databend_common_exception::ErrorCode; -use databend_common_exception::Result; -use log::info; -use tonic::transport::server::TcpIncoming; -use tonic::transport::Identity; -use tonic::transport::Server; -use tonic::transport::ServerTlsConfig; - -use super::v1::DatabendQueryFlightService; -use crate::servers::Server as DatabendQueryServer; - -pub struct FlightService { - pub config: InnerConfig, - pub abort_notify: Arc, +use databend_common_arrow::arrow_format::flight; +use databend_common_arrow::arrow_format::flight::data::Action; +use databend_common_arrow::arrow_format::flight::data::ActionType; +use databend_common_arrow::arrow_format::flight::data::Criteria; +use databend_common_arrow::arrow_format::flight::data::Empty; +use databend_common_arrow::arrow_format::flight::data::FlightData; +use databend_common_arrow::arrow_format::flight::data::FlightDescriptor; +use databend_common_arrow::arrow_format::flight::data::FlightInfo; +use databend_common_arrow::arrow_format::flight::data::HandshakeRequest; +use databend_common_arrow::arrow_format::flight::data::HandshakeResponse; +use databend_common_arrow::arrow_format::flight::data::PutResult; +use databend_common_arrow::arrow_format::flight::data::SchemaResult; +use databend_common_arrow::arrow_format::flight::data::Ticket; +use tonic::async_trait; +use tonic::body::empty_body; +use tonic::body::BoxBody; +use tonic::codegen::http::request::Request as HTTPRequest; +use tonic::codegen::http::response::Response as HTTPResponse; +use tonic::codegen::Body; +use tonic::codegen::BoxFuture; +use tonic::codegen::Service; +use tonic::codegen::StdError; +use tonic::server::NamedService; +use tonic::server::ServerStreamingService; +use tonic::server::StreamingService; +use tonic::server::UnaryService; +use tonic::Status; + +use crate::servers::flight::codec::MessageCodec; + +/// This trait is derived from `FlightService`. +/// It has been modified the `do_get` method signatures to use `Arc` as `DoGetStream` type +#[async_trait] +pub trait FlightOperation: Send + Sync + 'static { + type HandshakeStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn handshake( + &self, + request: tonic::Request>, + ) -> Result, Status>; + + type ListFlightsStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn list_flights( + &self, + request: tonic::Request, + ) -> Result, Status>; + + async fn get_flight_info( + &self, + request: tonic::Request, + ) -> Result>, Status>; + + async fn get_schema( + &self, + request: tonic::Request, + ) -> Result>, Status>; + type DoExchangeStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn do_exchange( + &self, + request: tonic::Request>, + ) -> Result, Status>; + + type DoGetStream: tokio_stream::Stream, Status>> + Send + 'static; + async fn do_get( + &self, + request: tonic::Request, + ) -> Result, Status>; + type DoPutStream: tokio_stream::Stream, Status>> + Send + 'static; + async fn do_put( + &self, + request: tonic::Request>, + ) -> Result, Status>; + type DoActionStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn do_action( + &self, + request: tonic::Request, + ) -> Result, Status>; + type ListActionsStream: tokio_stream::Stream, Status>> + + Send + + 'static; + async fn list_actions( + &self, + request: tonic::Request, + ) -> Result, Status>; } -impl FlightService { - pub fn create(config: InnerConfig) -> Result> { - Ok(Box::new(Self { - config, - abort_notify: Arc::new(Notify::new()), - })) +struct _Inner(Arc); +impl Clone for _Inner { + fn clone(&self) -> Self { + Self(self.0.clone()) } +} - fn shutdown_notify(&self) -> impl Future + 'static { - let notified = self.abort_notify.clone(); - async move { - notified.notified().await; +pub struct FlightServiceServer { + inner: _Inner, + max_decoding_message_size: Option, + max_encoding_message_size: Option, +} + +impl FlightServiceServer { + pub(crate) fn new( + service: T, + max_decoding_message_size: usize, + max_encoding_message_size: usize, + ) -> Self { + Self { + inner: _Inner(Arc::new(service)), + max_decoding_message_size: Some(max_decoding_message_size), + max_encoding_message_size: Some(max_encoding_message_size), } } +} - #[async_backtrace::framed] - async fn server_tls_config(conf: &InnerConfig) -> Result { - let cert = tokio::fs::read(conf.query.rpc_tls_server_cert.as_str()).await?; - let key = tokio::fs::read(conf.query.rpc_tls_server_key.as_str()).await?; - let server_identity = Identity::from_pem(cert, key); - let tls_conf = ServerTlsConfig::new().identity(server_identity); - Ok(tls_conf) +impl Service> for FlightServiceServer +where + T: FlightOperation, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, +{ + type Response = HTTPResponse; + type Error = Infallible; + type Future = BoxFuture; + + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) } - #[async_backtrace::framed] - pub async fn start_with_incoming(&mut self, addr: SocketAddr) -> Result<()> { - let flight_api_service = DatabendQueryFlightService::create(); - let builder = Server::builder(); - let mut builder = if self.config.tls_rpc_server_enabled() { - info!("databend query tls rpc enabled"); - builder - .tls_config(Self::server_tls_config(&self.config).await.map_err(|e| { - ErrorCode::TLSConfigurationFailure(format!( - "failed to load server tls config: {e}", - )) - })?) - .map_err(|e| { - ErrorCode::TLSConfigurationFailure(format!("failed to invoke tls_config: {e}",)) - })? - } else { - builder - }; - - let incoming = TcpIncoming::new(addr, true, None) - .map_err(|e| ErrorCode::CannotListenerPort(format!("{e}")))?; - let server = builder - .add_service( - FlightServiceServer::new(flight_api_service) - .max_encoding_message_size(usize::MAX) - .max_decoding_message_size(usize::MAX), - ) - .serve_with_incoming_shutdown(incoming, self.shutdown_notify()); - - databend_common_base::runtime::spawn(server); - Ok(()) + fn call(&mut self, req: HTTPRequest) -> Self::Future { + match req.uri().path() { + "/arrow.flight.protocol.FlightService/Handshake" => { + struct HandshakeSvc(pub Arc); + impl StreamingService for HandshakeSvc { + type Response = Arc; + type ResponseStream = T::HandshakeStream; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).handshake(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = HandshakeSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/ListFlights" => { + struct ListFlightsSvc(pub Arc); + impl ServerStreamingService for ListFlightsSvc { + type Response = Arc; + type ResponseStream = T::ListFlightsStream; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).list_flights(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = ListFlightsSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/GetFlightInfo" => { + struct GetFlightInfoSvc(pub Arc); + impl UnaryService for GetFlightInfoSvc { + type Response = Arc; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).get_flight_info(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GetFlightInfoSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/GetSchema" => { + struct GetSchemaSvc(pub Arc); + impl UnaryService for GetSchemaSvc { + type Response = Arc; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).get_schema(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GetSchemaSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoGet" => { + struct DoGetSvc(pub Arc); + impl ServerStreamingService for DoGetSvc { + type Response = Arc; + type ResponseStream = T::DoGetStream; + type Future = BoxFuture, Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = + async move { ::do_get(&inner, request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.0.clone(); + let fut = async move { + let method = DoGetSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoPut" => { + struct DoPutSvc(pub Arc); + impl StreamingService for DoPutSvc { + type Response = Arc; + type ResponseStream = T::DoPutStream; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).do_put(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = DoPutSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoExchange" => { + struct DoExchangeSvc(pub Arc); + impl StreamingService for DoExchangeSvc { + type Response = Arc; + type ResponseStream = T::DoExchangeStream; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).do_exchange(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = DoExchangeSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/DoAction" => { + struct DoActionSvc(pub Arc); + impl ServerStreamingService for DoActionSvc { + type Response = Arc; + + type ResponseStream = T::DoActionStream; + + type Future = BoxFuture, Status>; + + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).do_action(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.0.clone(); + let fut = async move { + let method = DoActionSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/arrow.flight.protocol.FlightService/ListActions" => { + struct ListActionsSvc(pub Arc); + impl ServerStreamingService for ListActionsSvc { + type Response = Arc; + type ResponseStream = T::ListActionsStream; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = self.0.clone(); + let fut = async move { (*inner).list_actions(request).await }; + Box::pin(fut) + } + } + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = ListActionsSvc(inner); + let codec = MessageCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec).apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => Box::pin(async move { + Ok(HTTPResponse::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap()) + }), + } } } -#[async_trait::async_trait] -impl DatabendQueryServer for FlightService { - #[async_backtrace::framed] - async fn shutdown(&mut self, _graceful: bool) {} - - #[async_backtrace::framed] - async fn start(&mut self, addr: SocketAddr) -> Result { - self.start_with_incoming(addr).await?; - Ok(addr) +impl Clone for FlightServiceServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } } } + +// The client side still uses FlightClient to find the service by this name, +// so we keep the service name unchanged. +impl NamedService for FlightServiceServer { + const NAME: &'static str = "arrow.flight.protocol.FlightService"; +} diff --git a/src/query/service/src/servers/flight/mod.rs b/src/query/service/src/servers/flight/mod.rs index 177c747d6d8ec..ea435893df189 100644 --- a/src/query/service/src/servers/flight/mod.rs +++ b/src/query/service/src/servers/flight/mod.rs @@ -12,13 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod codec; mod flight_client; +mod flight_server; mod flight_service; mod request_builder; pub mod v1; +pub use codec::MessageCodec; pub use flight_client::FlightClient; pub use flight_client::FlightExchange; pub use flight_client::FlightReceiver; pub use flight_client::FlightSender; -pub use flight_service::FlightService; +pub use flight_server::FlightService; +pub use request_builder::RequestBuilder; diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index b3ff2fd1f4651..4af106bf13fde 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -349,7 +349,7 @@ impl DataExchangeManager { &self, id: String, target: String, - ) -> Result>> { + ) -> Result, Status>>> { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; @@ -367,7 +367,7 @@ impl DataExchangeManager { query: String, target: String, fragment: usize, - ) -> Result>> { + ) -> Result, Status>>> { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; @@ -575,7 +575,7 @@ impl QueryCoordinator { pub fn add_statistics_exchange( &mut self, target: String, - ) -> Result>> { + ) -> Result, Status>>> { let (tx, rx) = async_channel::bounded(8); match self .statistics_exchanges @@ -607,7 +607,7 @@ impl QueryCoordinator { &mut self, target: String, fragment: usize, - ) -> Result>> { + ) -> Result, Status>>> { let (tx, rx) = async_channel::bounded(8); self.fragment_exchanges.insert( (target, fragment, FLIGHT_SENDER), diff --git a/src/query/service/src/servers/flight/v1/flight_service.rs b/src/query/service/src/servers/flight/v1/flight_service.rs index 2d8c763de7016..541e8e4b1f10c 100644 --- a/src/query/service/src/servers/flight/v1/flight_service.rs +++ b/src/query/service/src/servers/flight/v1/flight_service.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::pin::Pin; +use std::sync::Arc; use databend_common_arrow::arrow_format::flight::data::Action; use databend_common_arrow::arrow_format::flight::data::ActionType; @@ -27,7 +28,6 @@ use databend_common_arrow::arrow_format::flight::data::PutResult; use databend_common_arrow::arrow_format::flight::data::Result as FlightResult; use databend_common_arrow::arrow_format::flight::data::SchemaResult; use databend_common_arrow::arrow_format::flight::data::Ticket; -use databend_common_arrow::arrow_format::flight::service::flight_service_server::FlightService; use databend_common_config::GlobalConfig; use databend_common_exception::ErrorCode; use fastrace::func_path; @@ -39,13 +39,13 @@ use tonic::Response as RawResponse; use tonic::Status; use tonic::Streaming; +use crate::servers::flight::flight_service::FlightOperation; use crate::servers::flight::request_builder::RequestGetter; use crate::servers::flight::v1::actions::flight_actions; use crate::servers::flight::v1::actions::FlightActions; use crate::servers::flight::v1::exchange::DataExchangeManager; -pub type FlightStream = - Pin> + Send + Sync + 'static>>; +pub type FlightStream = Pin> + Send + Sync + 'static>>; pub struct DatabendQueryFlightService { actions: FlightActions, @@ -61,48 +61,52 @@ impl DatabendQueryFlightService { type Response = Result, Status>; type StreamReq = Request>; - #[async_trait::async_trait] -impl FlightService for DatabendQueryFlightService { - type HandshakeStream = FlightStream; +impl FlightOperation for DatabendQueryFlightService { + type HandshakeStream = FlightStream>; #[async_backtrace::framed] async fn handshake(&self, _: StreamReq) -> Response { - Result::Err(Status::unimplemented( + Err(Status::unimplemented( "DatabendQuery does not implement handshake.", )) } - type ListFlightsStream = FlightStream; + type ListFlightsStream = FlightStream>; #[async_backtrace::framed] async fn list_flights(&self, _: Request) -> Response { - Result::Err(Status::unimplemented( + Err(Status::unimplemented( "DatabendQuery does not implement list_flights.", )) } #[async_backtrace::framed] - async fn get_flight_info(&self, _: Request) -> Response { + async fn get_flight_info(&self, _: Request) -> Response> { Err(Status::unimplemented( "DatabendQuery does not implement get_flight_info.", )) } #[async_backtrace::framed] - async fn get_schema(&self, _: Request) -> Response { + async fn get_schema(&self, _: Request) -> Response> { Err(Status::unimplemented( "DatabendQuery does not implement get_schema.", )) } + type DoExchangeStream = FlightStream>; - type DoGetStream = FlightStream; + #[async_backtrace::framed] + async fn do_exchange(&self, _: StreamReq) -> Response { + Err(Status::unimplemented("unimplemented do_exchange")) + } + + type DoGetStream = FlightStream>; #[async_backtrace::framed] async fn do_get(&self, request: Request) -> Response { let root = databend_common_tracing::start_trace_for_remote_request(func_path!(), &request); let _guard = root.set_local_parent(); - match request.get_metadata("x-type")?.as_str() { "request_server_exchange" => { let target = request.get_metadata("x-target")?; @@ -124,28 +128,21 @@ impl FlightService for DatabendQueryFlightService { .handle_exchange_fragment(query_id, target, fragment)?, ))) } + "health" => Ok(RawResponse::new(build_health_response())), exchange_type => Err(Status::unimplemented(format!( "Unimplemented exchange type: {:?}", exchange_type ))), } } - - type DoPutStream = FlightStream; + type DoPutStream = FlightStream>; #[async_backtrace::framed] async fn do_put(&self, _req: StreamReq) -> Response { Err(Status::unimplemented("unimplemented do_put")) } - type DoExchangeStream = FlightStream; - - #[async_backtrace::framed] - async fn do_exchange(&self, _: StreamReq) -> Response { - Err(Status::unimplemented("unimplemented do_exchange")) - } - - type DoActionStream = FlightStream; + type DoActionStream = FlightStream>; #[async_backtrace::framed] async fn do_action(&self, request: Request) -> Response { @@ -169,17 +166,25 @@ impl FlightService for DatabendQueryFlightService { .await { Err(cause) => Err(cause.into()), - Ok(body) => Ok(RawResponse::new( - Box::pin(tokio_stream::once(Ok(FlightResult { body }))) - as FlightStream, - )), + Ok(body) => Ok(RawResponse::new(Box::pin(tokio_stream::once(Ok(Arc::new( + FlightResult { body }, + )))) + as FlightStream>)), } } - type ListActionsStream = FlightStream; + type ListActionsStream = FlightStream>; #[async_backtrace::framed] async fn list_actions(&self, _: Request) -> Response { Ok(RawResponse::new(Box::pin(stream::empty()))) } } +fn build_health_response() -> FlightStream> { + Box::pin(stream::iter(vec![Ok(Arc::new(FlightData { + flight_descriptor: None, + data_header: vec![], + data_body: Vec::from("ok"), + app_metadata: vec![0x03], + }))])) +} diff --git a/src/query/service/tests/it/servers/flight/flight_service.rs b/src/query/service/tests/it/servers/flight/flight_service.rs index 015efc53c7de9..5d45f8c072acc 100644 --- a/src/query/service/tests/it/servers/flight/flight_service.rs +++ b/src/query/service/tests/it/servers/flight/flight_service.rs @@ -17,7 +17,7 @@ use std::net::TcpListener; use std::str::FromStr; use std::sync::Arc; -use databend_common_arrow::arrow_format::flight::data::Empty; +use databend_common_arrow::arrow_format::flight::data::Ticket; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::tokio; use databend_common_exception::ErrorCode; @@ -26,7 +26,9 @@ use databend_common_grpc::ConnectionFactory; use databend_common_grpc::GrpcConnectionError; use databend_common_grpc::RpcClientTlsConfig; use databend_query::servers::flight::FlightService; +use databend_query::servers::flight::RequestBuilder; use databend_query::test_kits::*; +use futures_util::StreamExt; use crate::tests::tls_constants::TEST_CA_CERT; use crate::tests::tls_constants::TEST_CN_NAME; @@ -53,14 +55,26 @@ async fn test_tls_rpc_server() -> Result<()> { // normal case let conn = ConnectionFactory::create_rpc_channel(listener_address, None, tls_conf).await?; let mut f_client = FlightServiceClient::new(conn); - let r = f_client.list_actions(Empty {}).await; + let r = f_client + .do_get( + RequestBuilder::create(Ticket::default()) + .with_metadata("x-type", "health")? + .build(), + ) + .await; assert!(r.is_ok()); // client access without tls enabled will be failed // - channel can still be created, but communication will be failed let channel = ConnectionFactory::create_rpc_channel(listener_address, None, None).await?; let mut f_client = FlightServiceClient::new(channel); - let r = f_client.list_actions(Empty {}).await; + let r = f_client + .do_get( + RequestBuilder::create(Ticket::default()) + .with_metadata("x-type", "health")? + .build(), + ) + .await; assert!(r.is_err()); Ok(()) @@ -120,3 +134,26 @@ async fn test_rpc_server_port_used() -> Result<()> { assert!(r.is_err()); Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_rpc_server_do_get() -> Result<()> { + let listener_address = SocketAddr::from_str("127.0.0.1:9995")?; + let mut rpc_service = FlightService::create(ConfigBuilder::create().build())?; + rpc_service.start(listener_address).await?; + let conn = ConnectionFactory::create_rpc_channel(listener_address, None, None).await?; + let mut f_client = FlightServiceClient::new(conn); + + let r = f_client + .do_get( + RequestBuilder::create(Ticket::default()) + .with_metadata("x-type", "health")? + .build(), + ) + .await; + assert!(r.is_ok()); + let message = r.unwrap().into_inner().next().await.unwrap()?; + assert_eq!(message.app_metadata.last(), Some(&0x03)); + assert_eq!(message.data_body, Vec::from("ok")); + + Ok(()) +} From b05ad850378337b587a297634d093a34b28ebf96 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 22 Aug 2024 10:40:02 +0800 Subject: [PATCH 02/18] feat: add retry for do_get and do_action --- Cargo.lock | 1 + Cargo.toml | 1 + .../ci/ci-run-stateful-tests-cluster-minio.sh | 2 + .../ci-run-stateful-tests-standalone-minio.sh | 2 + src/common/exception/Cargo.toml | 1 + src/common/exception/src/exception_into.rs | 7 + src/query/service/src/clusters/cluster.rs | 39 +- .../src/servers/flight/flight_client.rs | 478 ++++++++++++++++-- .../src/servers/flight/flight_service.rs | 2 +- src/query/service/src/servers/flight/mod.rs | 2 +- .../flight/v1/exchange/exchange_manager.rs | 346 +++++++++---- .../v1/exchange/exchange_source_reader.rs | 8 +- .../flight/v1/exchange/statistics_receiver.rs | 8 +- .../flight/v1/exchange/statistics_sender.rs | 8 +- .../src/servers/flight/v1/flight_service.rs | 23 +- .../flight/v1/packets/packet_executor.rs | 2 +- src/query/settings/src/settings_default.rs | 12 + .../settings/src/settings_getter_setter.rs | 8 + .../02_query/02_0009_kill_connection.result | 1 + .../02_query/02_0009_kill_connection.sh | 76 +++ 20 files changed, 834 insertions(+), 193 deletions(-) create mode 100644 tests/suites/1_stateful/02_query/02_0009_kill_connection.result create mode 100755 tests/suites/1_stateful/02_query/02_0009_kill_connection.sh diff --git a/Cargo.lock b/Cargo.lock index 8d6b7c67eca85..17de2a57e6d79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3418,6 +3418,7 @@ dependencies = [ "geos", "geozero 0.13.0", "http 1.1.0", + "hyper 0.14.30", "opendal", "parquet", "paste", diff --git a/Cargo.toml b/Cargo.toml index 76b0ba372dbc4..a18671359502a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -244,6 +244,7 @@ geos = { version = "8.3", features = ["static", "geo", "geo-types"] } geozero = { version = "0.13.0", features = ["default", "with-wkb", "with-geos", "with-geojson"] } hashbrown = { version = "0.14.3", default-features = false } http = "1" +hyper = "0.14.20" itertools = "0.10.5" jsonb = "0.4.1" jwt-simple = "0.11.0" diff --git a/scripts/ci/ci-run-stateful-tests-cluster-minio.sh b/scripts/ci/ci-run-stateful-tests-cluster-minio.sh index b27da3cb6d9ff..a122870e3e555 100755 --- a/scripts/ci/ci-run-stateful-tests-cluster-minio.sh +++ b/scripts/ci/ci-run-stateful-tests-cluster-minio.sh @@ -20,6 +20,8 @@ export STORAGE_ALLOW_INSECURE=true echo "Install dependence" python3 -m pip install --quiet mysql-connector-python requests +sudo apt-get update -yq +sudo apt-get install -yq dsniff net-tools echo "calling test suite" echo "Starting Cluster databend-query" diff --git a/scripts/ci/ci-run-stateful-tests-standalone-minio.sh b/scripts/ci/ci-run-stateful-tests-standalone-minio.sh index 706d7cb50aea4..d713f3ed5becf 100755 --- a/scripts/ci/ci-run-stateful-tests-standalone-minio.sh +++ b/scripts/ci/ci-run-stateful-tests-standalone-minio.sh @@ -20,6 +20,8 @@ export STORAGE_ALLOW_INSECURE=true echo "Install dependence" python3 -m pip install --quiet mysql-connector-python requests +sudo apt-get update -yq +sudo apt-get install -yq dsniff net-tools echo "calling test suite" echo "Starting standalone DatabendQuery(debug)" diff --git a/src/common/exception/Cargo.toml b/src/common/exception/Cargo.toml index 033999b74d0d2..bff581ab5fb5a 100644 --- a/src/common/exception/Cargo.toml +++ b/src/common/exception/Cargo.toml @@ -21,6 +21,7 @@ bincode = { workspace = true } geos = { workspace = true } geozero = { workspace = true } http = { workspace = true } +hyper = { workspace = true } opendal = { workspace = true } parquet = { workspace = true } paste = { workspace = true } diff --git a/src/common/exception/src/exception_into.rs b/src/common/exception/src/exception_into.rs index c20736633799f..dc951b98cf23d 100644 --- a/src/common/exception/src/exception_into.rs +++ b/src/common/exception/src/exception_into.rs @@ -375,6 +375,13 @@ impl From for ErrorCode { tonic::Code::Unknown => { let details = status.details(); if details.is_empty() { + if status.source().map_or(false, |e| e.is::()) { + return ErrorCode::CannotConnectNode(format!( + "{}, source: {:?}", + status.message(), + status.source() + )); + } return ErrorCode::UnknownException(format!( "{}, source: {:?}", status.message(), diff --git a/src/query/service/src/clusters/cluster.rs b/src/query/service/src/clusters/cluster.rs index ba9349572d127..b3750c8bd481f 100644 --- a/src/query/service/src/clusters/cluster.rs +++ b/src/query/service/src/clusters/cluster.rs @@ -50,11 +50,13 @@ use futures::future::Either; use futures::Future; use futures::StreamExt; use log::error; +use log::info; use log::warn; use rand::thread_rng; use rand::Rng; use serde::Deserialize; use serde::Serialize; +use tokio::time::sleep; use crate::servers::flight::FlightClient; @@ -79,7 +81,7 @@ pub trait ClusterHelper { fn get_nodes(&self) -> Vec>; - async fn do_action Deserialize<'de> + Send>( + async fn do_action Deserialize<'de> + Send>( &self, path: &str, message: HashMap, @@ -116,7 +118,7 @@ impl ClusterHelper for Cluster { self.nodes.to_vec() } - async fn do_action Deserialize<'de> + Send>( + async fn do_action Deserialize<'de> + Send>( &self, path: &str, message: HashMap, @@ -145,12 +147,33 @@ impl ClusterHelper for Cluster { let node_secret = node.secret.clone(); async move { - let mut conn = create_client(&config, &flight_address).await?; - Ok::<_, ErrorCode>(( - id, - conn.do_action::<_, Res>(path, node_secret, message, timeout) - .await?, - )) + let mut attempt = 0; + let max_attempts = 2; + + loop { + let mut conn = create_client(&config, &flight_address).await?; + match conn + .do_action::<_, Res>( + path, + node_secret.clone(), + message.clone(), + timeout, + ) + .await + { + Ok(result) => return Ok((id, result)), + Err(e) + if e.code() == ErrorCode::CANNOT_CONNECT_NODE + && attempt < max_attempts => + { + // only retry when error is network problem + info!("retry do_action, attempt: {}", attempt); + attempt += 1; + sleep(Duration::from_secs(1)).await; + } + Err(e) => return Err(e), + } + } } }); } diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index b82f6f2d41222..5f514a7811475 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -12,8 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::VecDeque; +use std::pin::Pin; use std::str::FromStr; +use std::sync::atomic::AtomicPtr; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; use async_channel::Receiver; use async_channel::Sender; @@ -22,16 +29,21 @@ use databend_common_arrow::arrow_format::flight::data::FlightData; use databend_common_arrow::arrow_format::flight::data::Ticket; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::tokio::time::Duration; -use databend_common_base::runtime::drop_guard; +use databend_common_base::runtime::GlobalIORuntime; +use databend_common_base::runtime::TrySpawn; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use fastrace::func_path; use fastrace::future::FutureExt; use fastrace::Span; +use futures::Stream; use futures::StreamExt; use futures_util::future::Either; +use log::info; +use parking_lot::Mutex; use serde::Deserialize; use serde::Serialize; +use tokio::time::sleep; use tonic::metadata::AsciiMetadataKey; use tonic::metadata::AsciiMetadataValue; use tonic::transport::channel::Channel; @@ -41,6 +53,7 @@ use tonic::Streaming; use crate::pipelines::executor::WatchNotify; use crate::servers::flight::request_builder::RequestBuilder; +use crate::servers::flight::v1::exchange::DataExchangeManager; use crate::servers::flight::v1::packets::DataPacket; pub struct FlightClient { @@ -119,19 +132,31 @@ impl FlightClient { &mut self, query_id: &str, target: &str, + source_address: &str, + retry_times: usize, + retry_interval: usize, ) -> Result { - let streaming = self - .get_streaming( - RequestBuilder::create(Ticket::default()) - .with_metadata("x-type", "request_server_exchange")? - .with_metadata("x-target", target)? - .with_metadata("x-query-id", query_id)? - .build(), - ) - .await?; + let req = RequestBuilder::create(Ticket::default()) + .with_metadata("x-type", "request_server_exchange")? + .with_metadata("x-target", target)? + .with_metadata("x-query-id", query_id)? + .with_metadata("x-continue-from", "0")? + .build(); + let streaming = self.get_streaming(req).await?; let (notify, rx) = Self::streaming_receiver(streaming); - Ok(FlightExchange::create_receiver(notify, rx)) + Ok(FlightExchange::create_receiver( + notify, + rx, + Some(ConnectionInfo { + query_id: query_id.to_string(), + target: target.to_string(), + fragment: None, + source_address: source_address.to_string(), + retry_times, + retry_interval: Duration::from_secs(retry_interval as u64), + }), + )) } #[async_backtrace::framed] @@ -141,19 +166,34 @@ impl FlightClient { query_id: &str, target: &str, fragment: usize, + source_address: &str, + retry_times: usize, + retry_interval: usize, ) -> Result { let request = RequestBuilder::create(Ticket::default()) .with_metadata("x-type", "exchange_fragment")? .with_metadata("x-target", target)? .with_metadata("x-query-id", query_id)? .with_metadata("x-fragment-id", &fragment.to_string())? + .with_metadata("x-continue-from", "0")? .build(); let request = databend_common_tracing::inject_span_to_tonic_request(request); let streaming = self.get_streaming(request).await?; let (notify, rx) = Self::streaming_receiver(streaming); - Ok(FlightExchange::create_receiver(notify, rx)) + Ok(FlightExchange::create_receiver( + notify, + rx, + Some(ConnectionInfo { + query_id: query_id.to_string(), + target: target.to_string(), + fragment: Some(fragment), + source_address: source_address.to_string(), + retry_times, + retry_interval: Duration::from_secs(retry_interval as u64), + }), + )) } fn streaming_receiver( @@ -209,27 +249,51 @@ impl FlightClient { Err(status) => Err(ErrorCode::from(status).add_message_back("(while in query flight)")), } } + + #[async_backtrace::framed] + async fn reconnect(&mut self, info: &ConnectionInfo, seq: usize) -> Result { + let request = match info.fragment { + Some(fragment_id) => RequestBuilder::create(Ticket::default()) + .with_metadata("x-type", "exchange_fragment")? + .with_metadata("x-target", &info.target)? + .with_metadata("x-query-id", &info.query_id)? + .with_metadata("x-fragment-id", &fragment_id.to_string())? + .with_metadata("x-continue-from", &seq.to_string())? + .build(), + None => RequestBuilder::create(Ticket::default()) + .with_metadata("x-type", "request_server_exchange")? + .with_metadata("x-target", &info.target)? + .with_metadata("x-query-id", &info.query_id)? + .with_metadata("x-continue-from", &seq.to_string())? + .build(), + }; + let request = databend_common_tracing::inject_span_to_tonic_request(request); + + let streaming = self.get_streaming(request).await?; + + let (network_notify, recv) = Self::streaming_receiver(streaming); + Ok(FlightRxInner::create(network_notify, recv)) + } } -pub struct FlightReceiver { - notify: Arc, - rx: Receiver>, +#[derive(Clone)] +pub struct ConnectionInfo { + pub query_id: String, + pub target: String, + pub fragment: Option, + pub source_address: String, + pub retry_times: usize, + pub retry_interval: Duration, } -impl Drop for FlightReceiver { - fn drop(&mut self) { - drop_guard(move || { - self.close(); - }) - } +pub struct FlightRxInner { + notify: Arc, + rx: Receiver>, } -impl FlightReceiver { - pub fn create(rx: Receiver>) -> FlightReceiver { - FlightReceiver { - rx, - notify: Arc::new(WatchNotify::new()), - } +impl FlightRxInner { + pub fn create(notify: Arc, rx: Receiver>) -> FlightRxInner { + FlightRxInner { rx, notify } } #[async_backtrace::framed] @@ -247,12 +311,107 @@ impl FlightReceiver { } } +pub struct RetryableFlightReceiver { + seq: Arc, + info: Option, + inner: Arc>, +} + +impl Drop for RetryableFlightReceiver { + fn drop(&mut self) { + self.close(); + } +} + +impl RetryableFlightReceiver { + pub fn dummy() -> RetryableFlightReceiver { + RetryableFlightReceiver { + seq: Arc::new(AtomicUsize::new(0)), + info: None, + inner: Arc::new(Default::default()), + } + } + + #[async_backtrace::framed] + pub async fn recv(&self) -> Result> { + // Non thread safe, we only use atomic to implement mutable. + loop { + let inner = unsafe { &*self.inner.load(Ordering::SeqCst) }; + return match inner.recv().await { + Ok(message) => { + self.seq.fetch_add(1, Ordering::SeqCst); + Ok(message) + } + Err(cause) => { + info!("Error while receiving data from flight : {:?}", cause); + if cause.code() == ErrorCode::CANNOT_CONNECT_NODE { + // only retry when error is network problem + let Err(cause) = self.retry().await else { + info!("Retry flight connection successfully!"); + continue; + }; + + info!("Retry flight connection failure, cause: {:?}", cause); + } + + Err(cause) + } + }; + } + } + + #[async_backtrace::framed] + async fn retry(&self) -> Result<()> { + if let Some(connection_info) = &self.info { + let mut flight_client = + DataExchangeManager::create_client(&connection_info.source_address, true).await?; + + for attempts in 0..connection_info.retry_times { + let Ok(recv) = flight_client + .reconnect(connection_info, self.seq.load(Ordering::Acquire)) + .await + else { + info!("Reconnect attempt {} failed", attempts); + sleep(connection_info.retry_interval).await; + continue; + }; + + let ptr = self + .inner + .swap(Box::into_raw(Box::new(recv)), Ordering::SeqCst); + + unsafe { + // We cannot determine the number of strong ref. so close it. + let broken_connection = Box::from_raw(ptr); + broken_connection.close(); + } + + return Ok(()); + } + + return Err(ErrorCode::Timeout("Exceed max retries time")); + } + + Ok(()) + } + + pub fn close(&self) { + unsafe { + let inner = self.inner.load(Ordering::SeqCst); + + if !inner.is_null() { + (*inner).close(); + } + } + } +} + pub struct FlightSender { - tx: Sender, Status>>, + tx: Sender>, } impl FlightSender { - pub fn create(tx: Sender, Status>>) -> FlightSender { + pub fn create(tx: Sender>) -> FlightSender { FlightSender { tx } } @@ -262,11 +421,7 @@ impl FlightSender { #[async_backtrace::framed] pub async fn send(&self, data: DataPacket) -> Result<()> { - if let Err(_cause) = self - .tx - .send(Ok(Arc::new(FlightData::try_from(data)?))) - .await - { + if let Err(_cause) = self.tx.send(Ok(FlightData::try_from(data)?)).await { return Err(ErrorCode::AbortedQuery( "Aborted query, because the remote flight channel is closed.", )); @@ -280,43 +435,260 @@ impl FlightSender { } } +pub struct SenderPayload { + pub state: Arc>, + pub sender: Option>>, +} + +pub struct ReceiverPayload { + seq: Arc, + info: Option, + inner: Arc>, +} + pub enum FlightExchange { Dummy, - Receiver { - notify: Arc, - receiver: Receiver>, - }, - Sender(Sender, Status>>), + + Sender(SenderPayload), + Receiver(ReceiverPayload), + + MovedSender(SenderPayload), + MovedReceiver(ReceiverPayload), } impl FlightExchange { pub fn create_sender( - sender: Sender, Status>>, + state: Arc>, + sender: Sender>, ) -> FlightExchange { - FlightExchange::Sender(sender) + FlightExchange::Sender(SenderPayload { + state, + sender: Some(sender), + }) } pub fn create_receiver( notify: Arc, receiver: Receiver>, + connection_info: Option, ) -> FlightExchange { - FlightExchange::Receiver { notify, receiver } + FlightExchange::Receiver(ReceiverPayload { + seq: Arc::new(AtomicUsize::new(0)), + info: connection_info, + inner: Arc::new(AtomicPtr::new(Box::into_raw(Box::new( + FlightRxInner::create(notify, receiver), + )))), + }) + } + pub fn take_as_sender(&mut self) -> FlightSender { + let mut flight_sender = FlightExchange::Dummy; + std::mem::swap(self, &mut flight_sender); + + if let FlightExchange::Sender(mut v) = flight_sender { + let flight_sender = FlightSender::create(v.sender.take().unwrap()); + *self = FlightExchange::MovedSender(v); + return flight_sender; + } + + unreachable!("take as sender miss match exchange type") + } + + pub fn take_as_receiver(&mut self) -> RetryableFlightReceiver { + let mut flight_receiver = FlightExchange::Dummy; + std::mem::swap(self, &mut flight_receiver); + + if let FlightExchange::Receiver(v) = flight_receiver { + let flight_receiver = RetryableFlightReceiver { + seq: v.seq.clone(), + info: v.info.clone(), + inner: v.inner.clone(), + }; + + *self = FlightExchange::MovedReceiver(v); + + return flight_receiver; + } + + unreachable!("take as receiver miss match exchange type") + } +} + +pub struct FlightDataAckState { + seq: AtomicUsize, + auto_ack_window_size: usize, + + may_retry: bool, + receiver: Receiver>, + confirmation_queue: VecDeque<(usize, std::result::Result, Status>)>, +} + +impl FlightDataAckState { + pub fn create( + window_size: usize, + receiver: Receiver>, + ) -> Arc> { + Arc::new(Mutex::new(FlightDataAckState { + receiver, + may_retry: true, + seq: AtomicUsize::new(0), + auto_ack_window_size: window_size, + confirmation_queue: VecDeque::with_capacity(window_size), + })) + } + + fn ack_message(&mut self, seq: usize) { + while let Some((id, _)) = self.confirmation_queue.front() { + if *id <= seq { + self.confirmation_queue.pop_front(); + } else { + break; + } + } + } + + fn end_of_stream(&mut self) -> Poll, Status>>> { + let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); + self.ack_message(message_seq); + + self.may_retry = false; + Poll::Ready(None) } - pub fn convert_to_sender(self) -> FlightSender { - match self { - FlightExchange::Sender(tx) => FlightSender { tx }, - _ => unreachable!(), + fn error_of_stream(&mut self, cause: Status) -> Poll, Status>>> { + let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); + + // Automatically acknowledge messages outside the ACK window. + // A better approach is for the client to send back an ACK. + if message_seq >= self.auto_ack_window_size { + self.ack_message(message_seq - self.auto_ack_window_size); } + + self.confirmation_queue + .push_back((message_seq, Err(cause.clone()))); + Poll::Ready(Some(Err(cause))) } - pub fn convert_to_receiver(self) -> FlightReceiver { - match self { - FlightExchange::Receiver { notify, receiver } => FlightReceiver { - notify, - rx: receiver, - }, - _ => unreachable!(), + fn message(&mut self, data: FlightData) -> Poll, Status>>> { + let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); + let data = Arc::new(data); + let duplicate = data.clone(); + + // Automatically acknowledge messages outside the ACK window. + // A better approach is for the client to send back an ACK. + if message_seq >= self.auto_ack_window_size { + self.ack_message(message_seq - self.auto_ack_window_size); } + + self.confirmation_queue.push_back((message_seq, Ok(data))); + Poll::Ready(Some(Ok(duplicate))) + } + + fn check_resend(&mut self) -> Option, Status>> { + let current_seq = self.seq.load(Ordering::SeqCst); + + // normal case, no resend + if let Some((id, _)) = self.confirmation_queue.back() { + if *id == current_seq - 1 { + return None; + } + } + + // message is ack + if let Some((id, _)) = self.confirmation_queue.front() { + if *id > current_seq { + return Some(Err(Status::aborted( + "Aborted query, because the remote flight channel is closed.", + ))); + } + } + + // resend case, iterate the queue to find the message to resend + for (id, res) in self.confirmation_queue.iter() { + if *id == current_seq { + self.seq.fetch_add(1, Ordering::SeqCst); + return Some(res.clone()); + } + } + + None + } + + pub fn poll_next( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Status>>> { + if let Some(res) = self.check_resend() { + return Poll::Ready(Some(res)); + } + match Pin::new(&mut self.receiver).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => self.end_of_stream(), + Poll::Ready(Some(Err(status))) => self.error_of_stream(status), + Poll::Ready(Some(Ok(flight_data))) => self.message(flight_data), + } + } +} + +pub struct FlightDataAckStream { + state: Arc>, +} + +impl FlightDataAckStream { + pub fn create( + state: Arc>, + begin: usize, + ) -> Result { + // reset begin + info!("Create FlightDataAckStream hold lock"); + let mut state_guard = state.lock(); + state_guard.seq.store(begin, Ordering::SeqCst); + state_guard.may_retry = true; + drop(state_guard); + info!("Create FlightDataAckStream release lock"); + Ok(FlightDataAckStream { state }) + } +} + +impl Drop for FlightDataAckStream { + fn drop(&mut self) { + info!("Drop FlightDataAckStream"); + let state_should_retry = { + info!("Drop stage1 hold lock"); + let mut state = self.state.lock(); + if state.may_retry { + state.may_retry = false; + true + } else { + state.receiver.close(); + false + } + }; + info!("Drop stage1 release lock"); + if state_should_retry { + let weak = Arc::downgrade(&self.state); + GlobalIORuntime::instance().spawn(async move { + info!("Drop stage2 begin, wait for 60"); + tokio::time::sleep(Duration::from_secs(60)).await; + if let Some(ss) = weak.upgrade() { + info!("Drop stage2 hold lock"); + let ss = ss.lock(); + if !ss.may_retry { + ss.receiver.close(); + } + info!("Drop stage2 release lock"); + } + }); + } + } +} + +impl Stream for FlightDataAckStream { + type Item = std::result::Result, Status>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + info!("Poll next hold lock"); + let res = self.state.lock().poll_next(cx); + info!("Poll next release lock"); + res } } diff --git a/src/query/service/src/servers/flight/flight_service.rs b/src/query/service/src/servers/flight/flight_service.rs index c47207fda6a84..64507a2c0843d 100644 --- a/src/query/service/src/servers/flight/flight_service.rs +++ b/src/query/service/src/servers/flight/flight_service.rs @@ -150,7 +150,7 @@ where fn poll_ready( &mut self, _cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { Poll::Ready(Ok(())) } diff --git a/src/query/service/src/servers/flight/mod.rs b/src/query/service/src/servers/flight/mod.rs index ea435893df189..c057bfc09c55b 100644 --- a/src/query/service/src/servers/flight/mod.rs +++ b/src/query/service/src/servers/flight/mod.rs @@ -22,7 +22,7 @@ pub mod v1; pub use codec::MessageCodec; pub use flight_client::FlightClient; pub use flight_client::FlightExchange; -pub use flight_client::FlightReceiver; pub use flight_client::FlightSender; +pub use flight_client::RetryableFlightReceiver; pub use flight_server::FlightService; pub use request_builder::RequestBuilder; diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index 4af106bf13fde..7d3df09e33351 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -21,8 +21,6 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; -use async_channel::Receiver; -use databend_common_arrow::arrow_format::flight::data::FlightData; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::GlobalInstance; use databend_common_base::runtime::GlobalIORuntime; @@ -42,7 +40,6 @@ use parking_lot::Mutex; use parking_lot::ReentrantMutex; use petgraph::prelude::EdgeRef; use petgraph::Direction; -use tonic::Status; use super::exchange_params::ExchangeParams; use super::exchange_params::MergeExchangeParams; @@ -58,6 +55,9 @@ use crate::pipelines::PipelineBuildResult; use crate::pipelines::PipelineBuilder; use crate::schedulers::QueryFragmentActions; use crate::schedulers::QueryFragmentsActions; +use crate::servers::flight::flight_client::FlightDataAckState; +use crate::servers::flight::flight_client::FlightDataAckStream; +use crate::servers::flight::flight_client::RetryableFlightReceiver; use crate::servers::flight::v1::actions::init_query_fragments; use crate::servers::flight::v1::actions::INIT_QUERY_FRAGMENTS; use crate::servers::flight::v1::actions::START_PREPARED_QUERY; @@ -70,7 +70,6 @@ use crate::servers::flight::v1::packets::QueryFragment; use crate::servers::flight::v1::packets::QueryFragments; use crate::servers::flight::FlightClient; use crate::servers::flight::FlightExchange; -use crate::servers::flight::FlightReceiver; use crate::servers::flight::FlightSender; use crate::sessions::QueryContext; use crate::sessions::TableContext; @@ -130,6 +129,9 @@ impl DataExchangeManager { let config = GlobalConfig::instance(); let with_cur_rt = env.create_rpc_clint_with_current_rt; + let flight_retry_times = env.settings.get_max_flight_retry_times()?; + let flight_retry_interval = env.settings.get_flight_retry_interval()?; + let mut request_exchanges = HashMap::new(); let mut targets_exchanges = HashMap::new(); @@ -155,12 +157,27 @@ impl DataExchangeManager { Edge::Fragment(v) => QueryExchange::Fragment { source: source.id.clone(), fragment: v, - exchange: flight_client.do_get(&query_id, &target.id, v).await?, + exchange: flight_client + .do_get( + &query_id, + &target.id, + v, + &address, + flight_retry_times, + flight_retry_interval, + ) + .await?, }, Edge::Statistics => QueryExchange::Statistics { source: source.id.clone(), exchange: flight_client - .request_server_exchange(&query_id, &target.id) + .request_server_exchange( + &query_id, + &target.id, + &address, + flight_retry_times, + flight_retry_interval, + ) .await?, }, }) @@ -349,15 +366,21 @@ impl DataExchangeManager { &self, id: String, target: String, - ) -> Result, Status>>> { + continue_from: usize, + ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(id) { - Entry::Occupied(mut v) => v.get_mut().add_statistics_exchange(target), - Entry::Vacant(v) => v - .insert(QueryCoordinator::create()) - .add_statistics_exchange(target), + Entry::Occupied(mut v) => v.get_mut().add_statistics_exchange(target, continue_from), + Entry::Vacant(v) => match continue_from == 0 { + true => v + .insert(QueryCoordinator::create()) + .add_statistics_exchange(target, continue_from), + false => Err(ErrorCode::Timeout( + "Reconnection timeout, the state has been cleared", + )), + }, } } @@ -367,15 +390,26 @@ impl DataExchangeManager { query: String, target: String, fragment: usize, - ) -> Result, Status>>> { + continue_from: usize, + ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(query) { - Entry::Occupied(mut v) => v.get_mut().add_fragment_exchange(target, fragment), - Entry::Vacant(v) => v - .insert(QueryCoordinator::create()) - .add_fragment_exchange(target, fragment), + Entry::Occupied(mut v) => { + v.get_mut() + .add_fragment_exchange(target, fragment, continue_from) + } + Entry::Vacant(v) => match continue_from == 0 { + true => v.insert(QueryCoordinator::create()).add_fragment_exchange( + target, + fragment, + continue_from, + ), + false => Err(ErrorCode::Timeout( + "Reconnection timeout, the state has been cleared", + )), + }, } } @@ -458,12 +492,12 @@ impl DataExchangeManager { match queries_coordinator.get_mut(&query_id) { None => Err(ErrorCode::Internal("Query not exists.")), Some(query_coordinator) => { - assert!(query_coordinator.fragment_exchanges.is_empty()); + query_coordinator.assert_leak_fragment_exchanges(); let injector = DefaultExchangeInjector::create(); let mut build_res = query_coordinator.subscribe_fragment(&ctx, fragment_id, injector)?; - let exchanges = std::mem::take(&mut query_coordinator.statistics_exchanges); + let exchanges = query_coordinator.take_statistics_receivers(); let statistics_receiver = StatisticsReceiver::spawn_receiver(&ctx, exchanges)?; let statistics_receiver: Mutex = @@ -507,13 +541,13 @@ impl DataExchangeManager { pub fn get_flight_receiver( &self, params: &ExchangeParams, - ) -> Result> { + ) -> Result> { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.get_mut(¶ms.get_query_id()) { None => Err(ErrorCode::Internal("Query not exists.")), - Some(coordinator) => coordinator.get_flight_receiver(params), + Some(coordinator) => coordinator.take_flight_receiver(params), } } @@ -551,15 +585,34 @@ struct QueryInfo { query_executor: Option>, } -static FLIGHT_SENDER: u8 = 1; -static FLIGHT_RECEIVER: u8 = 2; +#[derive(Hash, Eq, PartialEq)] +pub struct FragmentExchangeIdentifier { + target: String, + fragment: usize, +} + +#[derive(Hash, Eq, PartialEq)] +pub enum ExchangeIdentifier { + Statistics(String), + DataSender(FragmentExchangeIdentifier), + DataReceiver(FragmentExchangeIdentifier), +} + +impl ExchangeIdentifier { + pub fn fragment_sender(target: String, fragment: usize) -> Self { + ExchangeIdentifier::DataSender(FragmentExchangeIdentifier { target, fragment }) + } + + pub fn fragment_receiver(target: String, fragment: usize) -> Self { + ExchangeIdentifier::DataReceiver(FragmentExchangeIdentifier { target, fragment }) + } +} struct QueryCoordinator { info: Option, fragments_coordinator: HashMap>, - statistics_exchanges: HashMap, - fragment_exchanges: HashMap<(String, usize, u8), FlightExchange>, + exchanges: HashMap, } impl QueryCoordinator { @@ -567,24 +620,67 @@ impl QueryCoordinator { QueryCoordinator { info: None, fragments_coordinator: HashMap::new(), - fragment_exchanges: HashMap::new(), - statistics_exchanges: HashMap::new(), + exchanges: HashMap::new(), + } + } + + pub fn take_statistics_senders(&mut self) -> Vec { + let mut statistics_senders = Vec::with_capacity(1); + + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::Statistics(_) = identifier { + statistics_senders.push(exchange.take_as_sender()); + } + } + + statistics_senders + } + + pub fn take_statistics_receivers(&mut self) -> Vec { + let mut statistics_receivers = Vec::with_capacity(self.exchanges.len()); + + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::Statistics(_) = identifier { + statistics_receivers.push(exchange.take_as_receiver()); + } + } + + statistics_receivers + } + + pub fn assert_leak_fragment_exchanges(&self) { + for (identifier, exchange) in &self.exchanges { + if !matches!(identifier, ExchangeIdentifier::Statistics(_)) { + assert!(matches!( + exchange, + FlightExchange::MovedSender(_) | FlightExchange::MovedReceiver(_) + )); + } } } pub fn add_statistics_exchange( &mut self, target: String, - ) -> Result, Status>>> { + begin: usize, + ) -> Result { let (tx, rx) = async_channel::bounded(8); - match self - .statistics_exchanges - .insert(target, FlightExchange::create_sender(tx)) - { - None => Ok(rx), - Some(_) => Err(ErrorCode::Internal( - "statistics exchanges can only have one", - )), + let identifier = ExchangeIdentifier::Statistics(target); + + match self.exchanges.entry(identifier) { + Entry::Vacant(v) => { + let state = FlightDataAckState::create(10, rx); + v.insert(FlightExchange::create_sender(state.clone(), tx)); + FlightDataAckStream::create(state, begin) + } + Entry::Occupied(mut v) => match v.get_mut() { + FlightExchange::MovedSender(v) => { + FlightDataAckStream::create(v.state.clone(), begin) + } + _ => Err(ErrorCode::Internal( + "statistics exchanges can only have one", + )), + }, } } @@ -593,7 +689,8 @@ impl QueryCoordinator { exchanges: HashMap, ) -> Result<()> { for (source, exchange) in exchanges.into_iter() { - if self.statistics_exchanges.insert(source, exchange).is_some() { + let identifier = ExchangeIdentifier::Statistics(source); + if self.exchanges.insert(identifier, exchange).is_some() { return Err(ErrorCode::Internal( "Internal error, statistics exchange can only have one.", )); @@ -607,13 +704,24 @@ impl QueryCoordinator { &mut self, target: String, fragment: usize, - ) -> Result, Status>>> { + begin: usize, + ) -> Result { let (tx, rx) = async_channel::bounded(8); - self.fragment_exchanges.insert( - (target, fragment, FLIGHT_SENDER), - FlightExchange::create_sender(tx), - ); - Ok(rx) + let identifier = ExchangeIdentifier::fragment_sender(target, fragment); + + match self.exchanges.entry(identifier) { + Entry::Vacant(v) => { + let state = FlightDataAckState::create(10, rx); + v.insert(FlightExchange::create_sender(state.clone(), tx)); + FlightDataAckStream::create(state, begin) + } + Entry::Occupied(mut v) => match v.get_mut() { + FlightExchange::MovedSender(v) => { + FlightDataAckStream::create(v.state.clone(), begin) + } + _ => Err(ErrorCode::Internal("fragment exchange can only have one")), + }, + } } pub fn add_fragment_exchanges( @@ -621,83 +729,101 @@ impl QueryCoordinator { exchanges: HashMap<(String, usize), FlightExchange>, ) -> Result<()> { for ((source, fragment), exchange) in exchanges.into_iter() { - self.fragment_exchanges - .insert((source, fragment, FLIGHT_RECEIVER), exchange); + let identifier = ExchangeIdentifier::fragment_receiver(source, fragment); + + self.exchanges.insert(identifier, exchange); } Ok(()) } pub fn get_flight_senders(&mut self, params: &ExchangeParams) -> Result> { + let mut fragments_exchanges = Vec::with_capacity(self.exchanges.len()); + match params { - ExchangeParams::MergeExchange(params) => Ok(self - .fragment_exchanges - .extract_if(|(_, f, r), _| f == ¶ms.fragment_id && *r == FLIGHT_SENDER) - .map(|(_, v)| v.convert_to_sender()) - .collect::>()), - ExchangeParams::ShuffleExchange(params) => { - let mut exchanges = Vec::with_capacity(params.destination_ids.len()); + ExchangeParams::MergeExchange(params) => { + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::DataSender(v) = identifier { + if v.fragment != params.fragment_id { + continue; + } - for destination in ¶ms.destination_ids { - exchanges.push(match destination == ¶ms.executor_id { - true => Ok(FlightSender::create(async_channel::bounded(1).0)), - false => match self.fragment_exchanges.remove(&( - destination.clone(), - params.fragment_id, - FLIGHT_SENDER, - )) { - Some(exchange_channel) => Ok(exchange_channel.convert_to_sender()), - None => Err(ErrorCode::UnknownFragmentExchange(format!( - "Unknown fragment exchange channel, {}, {}", - destination, params.fragment_id - ))), - }, - }?); + fragments_exchanges.push(exchange.take_as_sender()); + } } + } + ExchangeParams::ShuffleExchange(params) => { + for destination in ¶ms.destination_ids { + if destination == ¶ms.executor_id { + let dummy = FlightSender::create(async_channel::bounded(1).0); + fragments_exchanges.push(dummy); + continue; + } + + let target = destination.clone(); + let fragment = params.fragment_id; + let identifier = ExchangeIdentifier::fragment_sender(target, fragment); + if let Some(v) = self.exchanges.get_mut(&identifier) { + fragments_exchanges.push(v.take_as_sender()); + continue; + } - Ok(exchanges) + return Err(ErrorCode::UnknownFragmentExchange(format!( + "Unknown fragment exchange channel, {}, {}", + destination, params.fragment_id + ))); + } } - } + }; + + Ok(fragments_exchanges) } - pub fn get_flight_receiver( + pub fn take_flight_receiver( &mut self, params: &ExchangeParams, - ) -> Result> { + ) -> Result> { + let mut fragments_exchanges = Vec::with_capacity(self.exchanges.len()); + match params { - ExchangeParams::MergeExchange(params) => Ok(self - .fragment_exchanges - .extract_if(|(_, f, r), _| f == ¶ms.fragment_id && *r == FLIGHT_RECEIVER) - .map(|((source, _, _), v)| (source.clone(), v.convert_to_receiver())) - .collect::>()), - ExchangeParams::ShuffleExchange(params) => { - let mut exchanges = Vec::with_capacity(params.destination_ids.len()); + ExchangeParams::MergeExchange(params) => { + for (identifier, exchange) in &mut self.exchanges { + if let ExchangeIdentifier::DataReceiver(v) = identifier { + if v.fragment != params.fragment_id { + continue; + } - for destination in ¶ms.destination_ids { - exchanges.push(( - destination.clone(), - match destination == ¶ms.executor_id { - true => Ok(FlightReceiver::create(async_channel::bounded(1).1)), - false => match self.fragment_exchanges.remove(&( - destination.clone(), - params.fragment_id, - FLIGHT_RECEIVER, - )) { - Some(v) => Ok(v.convert_to_receiver()), - _ => Err(ErrorCode::UnknownFragmentExchange(format!( - "Unknown fragment flight receiver, {}, {}", - destination, params.fragment_id - ))), - }, - }?, - )); + fragments_exchanges.push((v.target.clone(), exchange.take_as_receiver())); + } } + } + ExchangeParams::ShuffleExchange(params) => { + for destination in ¶ms.destination_ids { + if destination == ¶ms.executor_id { + let dummy = RetryableFlightReceiver::dummy(); + fragments_exchanges.push((destination.clone(), dummy)); + continue; + } - Ok(exchanges) + let source = destination.clone(); + let fragment = params.fragment_id; + let identifier = ExchangeIdentifier::fragment_receiver(source, fragment); + if let Some(v) = self.exchanges.get_mut(&identifier) { + let receiver = v.take_as_receiver(); + fragments_exchanges.push((destination.clone(), receiver)); + continue; + } + + return Err(ErrorCode::UnknownFragmentExchange(format!( + "Unknown fragment flight receiver, {}, {}", + destination, params.fragment_id + ))); + } } - } - } + }; + Ok(fragments_exchanges) + } pub fn prepare_pipeline(&mut self, fragments: &QueryFragments) -> Result<()> { let query_info = self.info.as_ref().expect("expect query info"); let query_context = query_info.query_ctx.clone(); @@ -840,28 +966,30 @@ impl QueryCoordinator { let settings = ExecutorSettings::try_create(info.query_ctx.clone())?; let executor = PipelineCompleteExecutor::from_pipelines(pipelines, settings)?; - assert!(self.fragment_exchanges.is_empty()); + self.assert_leak_fragment_exchanges(); let info_mut = self.info.as_mut().expect("Query info is None"); info_mut.query_executor = Some(executor.clone()); let query_id = info_mut.query_id.clone(); let query_ctx = info_mut.query_ctx.clone(); - let request_server_exchanges = std::mem::take(&mut self.statistics_exchanges); - if request_server_exchanges.len() != 1 { + let ctx = query_ctx.clone(); + let mut statistics_senders = self.take_statistics_senders(); + + let Some(statistics_sender) = statistics_senders.pop() else { + return Err(ErrorCode::Internal( + "Request server must less than 1 if is not request server.", + )); + }; + + if !statistics_senders.is_empty() { return Err(ErrorCode::Internal( "Request server must less than 1 if is not request server.", )); } - let ctx = query_ctx.clone(); - let (_, request_server_exchange) = request_server_exchanges.into_iter().next().unwrap(); - let mut statistics_sender = StatisticsSender::spawn( - &query_id, - ctx, - request_server_exchange, - executor.get_inner(), - ); + let mut statistics_sender = + StatisticsSender::spawn(&query_id, ctx, statistics_sender, executor.get_inner()); let span = if let Some(parent) = SpanContext::current_local_parent() { Span::root("Distributed-Executor", parent) diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs index 88a1a139c21b4..daa7c4bfec835 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_source_reader.rs @@ -27,15 +27,15 @@ use databend_common_pipeline_core::processors::ProcessorPtr; use databend_common_pipeline_core::PipeItem; use log::info; +use crate::servers::flight::flight_client::RetryableFlightReceiver; use crate::servers::flight::v1::exchange::serde::ExchangeDeserializeMeta; use crate::servers::flight::v1::packets::DataPacket; -use crate::servers::flight::FlightReceiver; pub struct ExchangeSourceReader { finished: AtomicBool, output: Arc, output_data: Vec, - flight_receiver: FlightReceiver, + flight_receiver: RetryableFlightReceiver, source: String, destination: String, fragment: usize, @@ -44,7 +44,7 @@ pub struct ExchangeSourceReader { impl ExchangeSourceReader { pub fn create( output: Arc, - flight_receiver: FlightReceiver, + flight_receiver: RetryableFlightReceiver, source: &str, destination: &str, fragment: usize, @@ -156,7 +156,7 @@ impl Processor for ExchangeSourceReader { } pub fn create_reader_item( - flight_receiver: FlightReceiver, + flight_receiver: RetryableFlightReceiver, source: &str, destination: &str, fragment: usize, diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs index 84b89ab89b344..bfb8b3291e069 100644 --- a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs +++ b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; use std::sync::Arc; use databend_common_base::base::tokio::sync::broadcast::channel; @@ -25,8 +24,8 @@ use databend_common_exception::Result; use futures_util::future::select; use futures_util::future::Either; +use crate::servers::flight::flight_client::RetryableFlightReceiver; use crate::servers::flight::v1::packets::DataPacket; -use crate::servers::flight::FlightExchange; use crate::sessions::QueryContext; pub struct StatisticsReceiver { @@ -38,14 +37,13 @@ pub struct StatisticsReceiver { impl StatisticsReceiver { pub fn spawn_receiver( ctx: &Arc, - statistics_exchanges: HashMap, + statistics_exchanges: Vec, ) -> Result { let (shutdown_tx, _shutdown_rx) = channel(2); let mut exchange_handler = Vec::with_capacity(statistics_exchanges.len()); let runtime = Runtime::with_worker_threads(2, Some(String::from("StatisticsReceiver")))?; - for (_source, exchange) in statistics_exchanges.into_iter() { - let rx = exchange.convert_to_receiver(); + for rx in statistics_exchanges.into_iter() { exchange_handler.push(runtime.spawn({ let ctx = ctx.clone(); let shutdown_rx = shutdown_tx.subscribe(); diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs index b26e8f4d2854d..95d0af8a730fa 100644 --- a/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs +++ b/src/query/service/src/servers/flight/v1/exchange/statistics_sender.rs @@ -29,7 +29,6 @@ use log::warn; use crate::pipelines::executor::PipelineExecutor; use crate::servers::flight::v1::packets::DataPacket; use crate::servers::flight::v1::packets::ProgressInfo; -use crate::servers::flight::FlightExchange; use crate::servers::flight::FlightSender; use crate::sessions::QueryContext; @@ -43,11 +42,10 @@ impl StatisticsSender { pub fn spawn( query_id: &str, ctx: Arc, - exchange: FlightExchange, + tx: FlightSender, executor: Arc, ) -> Self { let spawner = ctx.clone(); - let tx = exchange.convert_to_sender(); let (shutdown_flag_sender, shutdown_flag_receiver) = async_channel::bounded(1); let handle = spawner.spawn({ @@ -231,8 +229,4 @@ impl StatisticsSender { progress_info } - - // fn fetch_profiling(ctx: &Arc) -> Result> { - // // ctx.get_exchange_manager() - // } } diff --git a/src/query/service/src/servers/flight/v1/flight_service.rs b/src/query/service/src/servers/flight/v1/flight_service.rs index 541e8e4b1f10c..80ddcbac209ba 100644 --- a/src/query/service/src/servers/flight/v1/flight_service.rs +++ b/src/query/service/src/servers/flight/v1/flight_service.rs @@ -111,8 +111,16 @@ impl FlightOperation for DatabendQueryFlightService { "request_server_exchange" => { let target = request.get_metadata("x-target")?; let query_id = request.get_metadata("x-query-id")?; + let continue_from = request + .get_metadata("x-continue-from")? + .parse::() + .unwrap(); Ok(RawResponse::new(Box::pin( - DataExchangeManager::instance().handle_statistics_exchange(query_id, target)?, + DataExchangeManager::instance().handle_statistics_exchange( + query_id, + target, + continue_from, + )?, ))) } "exchange_fragment" => { @@ -122,10 +130,17 @@ impl FlightOperation for DatabendQueryFlightService { .get_metadata("x-fragment-id")? .parse::() .unwrap(); - + let continue_from = request + .get_metadata("x-continue-from")? + .parse::() + .unwrap(); Ok(RawResponse::new(Box::pin( - DataExchangeManager::instance() - .handle_exchange_fragment(query_id, target, fragment)?, + DataExchangeManager::instance().handle_exchange_fragment( + query_id, + target, + fragment, + continue_from, + )?, ))) } "health" => Ok(RawResponse::new(build_health_response())), diff --git a/src/query/service/src/servers/flight/v1/packets/packet_executor.rs b/src/query/service/src/servers/flight/v1/packets/packet_executor.rs index c555184d82912..29999b7727b31 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_executor.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_executor.rs @@ -14,7 +14,7 @@ use crate::servers::flight::v1::packets::QueryFragment; -#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct QueryFragments { pub query_id: String, pub fragments: Vec, diff --git a/src/query/settings/src/settings_default.rs b/src/query/settings/src/settings_default.rs index 2c189bb5681f8..424335ad5b2cc 100644 --- a/src/query/settings/src/settings_default.rs +++ b/src/query/settings/src/settings_default.rs @@ -856,6 +856,18 @@ impl DefaultSettings { mode: SettingMode::Both, range: Some(SettingRange::Numeric(0..=1)), }), + ("max_flight_connection_retry_times", DefaultSettingValue { + value: UserSettingValue::UInt64(3), + desc: "The maximum retry count for cluster flight. Disable if 0.", + mode: SettingMode::Both, + range: Some(SettingRange::Numeric(0..=30)), + }), + ("flight_connection_retry_interval", DefaultSettingValue { + value: UserSettingValue::UInt64(5), + desc: "The retry interval of cluster flight is in seconds.", + mode: SettingMode::Both, + range: Some(SettingRange::Numeric(0..=900)), + }), ("random_function_seed", DefaultSettingValue { value: UserSettingValue::UInt64(0), desc: "Seed for random function", diff --git a/src/query/settings/src/settings_getter_setter.rs b/src/query/settings/src/settings_getter_setter.rs index 866e56772072d..84164d8e3afd6 100644 --- a/src/query/settings/src/settings_getter_setter.rs +++ b/src/query/settings/src/settings_getter_setter.rs @@ -706,6 +706,14 @@ impl Settings { Ok(self.try_get_u64("random_function_seed")? == 1) } + pub fn get_flight_retry_interval(&self) -> Result { + Ok(self.try_get_u64("flight_connection_retry_interval")? as usize) + } + + pub fn get_max_flight_retry_times(&self) -> Result { + Ok(self.try_get_u64("max_flight_connection_retry_times")? as usize) + } + pub fn get_dynamic_sample_time_budget_ms(&self) -> Result { self.try_get_u64("dynamic_sample_time_budget_ms") } diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.result b/tests/suites/1_stateful/02_query/02_0009_kill_connection.result new file mode 100644 index 0000000000000..be2d12014cc17 --- /dev/null +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.result @@ -0,0 +1 @@ +Final state: Succeeded diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh new file mode 100755 index 0000000000000..9f632454311ed --- /dev/null +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash + +perform_initial_query() { + local response=$(curl -s -u root: -XPOST "http://localhost:8000/v1/query" -H 'Content-Type: application/json' -d '{"sql": "select avg(number) from numbers(1000000000)"}') + local stats_uri=$(echo "$response" | jq -r '.stats_uri') + local final_uri=$(echo "$response" | jq -r '.final_uri') + echo "$stats_uri|$final_uri" +} +poll_stats_uri() { + local uri=$1 + local state_exists=true + while $state_exists; do + local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") + if ! echo "$response" | jq -e '.state' > /dev/null; then + state_exists=false + else + sleep 2 + fi + done +} +get_final_state() { + local uri=$1 + local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") + echo "$response" | jq -r '.state' +} + +IFS='|' read -r stats_uri final_uri <<< $(perform_initial_query) + +poll_stats_uri "$stats_uri" & +POLL_PID=$! +sleep 1 +netstat_output=$(netstat -an | grep '9092') + +# skip standalone mode +if [ -z "$netstat_output" ]; then + echo "Final state: Succeeded" + exit 0 +fi + +port=$(echo "$netstat_output" | awk ' + $NF == "ESTABLISHED" { + if ($4 ~ /:9092$/) { + split($5, a, ":") + port = a[2] + } else if ($5 ~ /:9092$/) { + split($4, a, ":") + port = a[2] + } + } + END { + print port + } +') + +# Start tcpkill in the background +sudo tcpkill -i lo host 127.0.0.1 and port $port > tcpkill_output.txt 2>&1 & +TCPKILL_PID=$! + +# Wait for tcpkill to output at least 3 lines or terminate if done earlier +while [ $(wc -l < tcpkill_output.txt) -lt 3 ]; do + if ! kill -0 $TCPKILL_PID 2> /dev/null; then + break + fi + sleep 1 +done + +# Kill tcpkill after the desired number of lines if it's still running +if kill -0 $TCPKILL_PID 2> /dev/null; then + kill $TCPKILL_PID +fi + +wait $POLL_PID + +final_state=$(get_final_state "$final_uri") +echo "Final state: $final_state" + From f2e352f85cbd9860add3febdf7e8ead3b9589d80 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Wed, 4 Sep 2024 16:49:28 +0800 Subject: [PATCH 03/18] refactor: change to do_exchange with a request-response flow --- .../serde/transform_deserializer.rs | 1 + .../src/servers/flight/flight_client.rs | 301 +++++++++++------- .../src/servers/flight/flight_service.rs | 5 +- .../flight/v1/exchange/exchange_manager.rs | 36 ++- .../exchange/serde/exchange_deserializer.rs | 1 + .../flight/v1/exchange/statistics_receiver.rs | 1 + .../src/servers/flight/v1/flight_service.rs | 24 +- .../src/servers/flight/v1/packets/mod.rs | 1 + .../servers/flight/v1/packets/packet_data.rs | 21 ++ .../tests/it/servers/flight/flight_service.rs | 24 +- .../02_query/02_0009_kill_connection.sh | 16 +- 11 files changed, 272 insertions(+), 159 deletions(-) diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs index 98abcec37315d..1c5ec5527a113 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs @@ -225,6 +225,7 @@ where DataPacket::MutationStatus { .. } => unreachable!(), DataPacket::DataCacheMetrics(_) => unreachable!(), DataPacket::FragmentData(v) => Ok(vec![self.recv_data(meta.packet, v)?]), + DataPacket::FlightControl(_) => unreachable!(), } } } diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 5f514a7811475..2b6a68291e6ea 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -12,25 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::VecDeque; use std::pin::Pin; use std::str::FromStr; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicPtr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::Arc; use std::task::Context; use std::task::Poll; +use std::task::Waker; use async_channel::Receiver; use async_channel::Sender; use databend_common_arrow::arrow_format::flight::data::Action; use databend_common_arrow::arrow_format::flight::data::FlightData; -use databend_common_arrow::arrow_format::flight::data::Ticket; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::tokio::time::Duration; use databend_common_base::runtime::GlobalIORuntime; use databend_common_base::runtime::TrySpawn; +use databend_common_base::JoinHandle; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use fastrace::func_path; @@ -40,6 +41,7 @@ use futures::Stream; use futures::StreamExt; use futures_util::future::Either; use log::info; +use log::warn; use parking_lot::Mutex; use serde::Deserialize; use serde::Serialize; @@ -55,6 +57,7 @@ use crate::pipelines::executor::WatchNotify; use crate::servers::flight::request_builder::RequestBuilder; use crate::servers::flight::v1::exchange::DataExchangeManager; use crate::servers::flight::v1::packets::DataPacket; +use crate::servers::flight::v1::packets::FlightControlCommand; pub struct FlightClient { inner: FlightServiceClient, @@ -128,7 +131,7 @@ impl FlightClient { } #[async_backtrace::framed] - pub async fn request_server_exchange( + pub async fn request_statistics_exchange( &mut self, query_id: &str, target: &str, @@ -136,7 +139,8 @@ impl FlightClient { retry_times: usize, retry_interval: usize, ) -> Result { - let req = RequestBuilder::create(Ticket::default()) + let (server_tx, server_rx) = async_channel::bounded(1); + let req = RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "request_server_exchange")? .with_metadata("x-target", target)? .with_metadata("x-query-id", query_id)? @@ -156,12 +160,13 @@ impl FlightClient { retry_times, retry_interval: Duration::from_secs(retry_interval as u64), }), + server_tx, )) } #[async_backtrace::framed] #[fastrace::trace] - pub async fn do_get( + pub async fn request_fragment_exchange( &mut self, query_id: &str, target: &str, @@ -170,7 +175,9 @@ impl FlightClient { retry_times: usize, retry_interval: usize, ) -> Result { - let request = RequestBuilder::create(Ticket::default()) + let (server_tx, server_rx) = async_channel::bounded(1); + + let request = RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "exchange_fragment")? .with_metadata("x-target", target)? .with_metadata("x-query-id", query_id)? @@ -193,6 +200,7 @@ impl FlightClient { retry_times, retry_interval: Duration::from_secs(retry_interval as u64), }), + server_tx, )) } @@ -209,7 +217,10 @@ impl FlightClient { loop { match futures::future::select(notified, streaming_next).await { - Either::Left((_, _)) | Either::Right((None, _)) => { + Either::Left((_, _)) => { + break; + } + Either::Right((None, _)) => { break; } Either::Right((Some(message), next_notified)) => { @@ -230,7 +241,6 @@ impl FlightClient { } } } - drop(streaming); tx.close(); } @@ -243,8 +253,11 @@ impl FlightClient { } #[async_backtrace::framed] - async fn get_streaming(&mut self, request: Request) -> Result> { - match self.inner.do_get(request).await { + async fn get_streaming( + &mut self, + request: Request>>>, + ) -> Result> { + match self.inner.do_exchange(request).await { Ok(res) => Ok(res.into_inner()), Err(status) => Err(ErrorCode::from(status).add_message_back("(while in query flight)")), } @@ -252,15 +265,16 @@ impl FlightClient { #[async_backtrace::framed] async fn reconnect(&mut self, info: &ConnectionInfo, seq: usize) -> Result { + let (server_tx, server_rx) = async_channel::bounded(1); let request = match info.fragment { - Some(fragment_id) => RequestBuilder::create(Ticket::default()) + Some(fragment_id) => RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "exchange_fragment")? .with_metadata("x-target", &info.target)? .with_metadata("x-query-id", &info.query_id)? .with_metadata("x-fragment-id", &fragment_id.to_string())? .with_metadata("x-continue-from", &seq.to_string())? .build(), - None => RequestBuilder::create(Ticket::default()) + None => RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "request_server_exchange")? .with_metadata("x-target", &info.target)? .with_metadata("x-query-id", &info.query_id)? @@ -272,7 +286,7 @@ impl FlightClient { let streaming = self.get_streaming(request).await?; let (network_notify, recv) = Self::streaming_receiver(streaming); - Ok(FlightRxInner::create(network_notify, recv)) + Ok(FlightRxInner::create(network_notify, recv, server_tx)) } } @@ -289,11 +303,20 @@ pub struct ConnectionInfo { pub struct FlightRxInner { notify: Arc, rx: Receiver>, + server_tx: Sender, } impl FlightRxInner { - pub fn create(notify: Arc, rx: Receiver>) -> FlightRxInner { - FlightRxInner { rx, notify } + pub fn create( + notify: Arc, + rx: Receiver>, + server_tx: Sender, + ) -> FlightRxInner { + FlightRxInner { + rx, + notify, + server_tx, + } } #[async_backtrace::framed] @@ -308,6 +331,10 @@ impl FlightRxInner { pub fn close(&self) { self.rx.close(); self.notify.notify_waiters(); + let res = self.server_tx.send( + FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)).unwrap(), + ); + info!("Send close signal to flight server, result: {:?}", res); } } @@ -339,7 +366,18 @@ impl RetryableFlightReceiver { let inner = unsafe { &*self.inner.load(Ordering::SeqCst) }; return match inner.recv().await { Ok(message) => { - self.seq.fetch_add(1, Ordering::SeqCst); + let ack_seq = self.seq.fetch_add(1, Ordering::SeqCst); + if message.is_some() { + let error = inner + .server_tx + .send(FlightData::try_from(DataPacket::FlightControl( + FlightControlCommand::Ack(ack_seq), + ))?) + .await; + if error.is_err() { + info!("Error while sending ack to flight : {:?}", error); + } + } Ok(message) } Err(cause) => { @@ -471,12 +509,13 @@ impl FlightExchange { notify: Arc, receiver: Receiver>, connection_info: Option, + server_tx: Sender, ) -> FlightExchange { FlightExchange::Receiver(ReceiverPayload { seq: Arc::new(AtomicUsize::new(0)), info: connection_info, inner: Arc::new(AtomicPtr::new(Box::into_raw(Box::new( - FlightRxInner::create(notify, receiver), + FlightRxInner::create(notify, receiver, server_tx), )))), }) } @@ -515,98 +554,60 @@ impl FlightExchange { pub struct FlightDataAckState { seq: AtomicUsize, - auto_ack_window_size: usize, - - may_retry: bool, + finish: AtomicBool, receiver: Receiver>, - confirmation_queue: VecDeque<(usize, std::result::Result, Status>)>, + last_packet: Option<(usize, std::result::Result, Status>)>, + clean_up_handle: Option>, + waker: Option, } impl FlightDataAckState { pub fn create( - window_size: usize, receiver: Receiver>, ) -> Arc> { Arc::new(Mutex::new(FlightDataAckState { receiver, - may_retry: true, seq: AtomicUsize::new(0), - auto_ack_window_size: window_size, - confirmation_queue: VecDeque::with_capacity(window_size), + last_packet: None, + finish: AtomicBool::new(false), + clean_up_handle: None, + waker: None, })) } - fn ack_message(&mut self, seq: usize) { - while let Some((id, _)) = self.confirmation_queue.front() { - if *id <= seq { - self.confirmation_queue.pop_front(); - } else { - break; - } - } + fn error_of_stream( + &mut self, + cause: Status, + ) -> Poll, Status>>> { + let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); + self.last_packet = Some((message_seq, Err(cause.clone()))); + Poll::Ready(Some(Err(cause))) } fn end_of_stream(&mut self) -> Poll, Status>>> { - let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); - self.ack_message(message_seq); - - self.may_retry = false; + self.seq.fetch_add(1, Ordering::SeqCst); + self.finish.store(true, Ordering::SeqCst); Poll::Ready(None) } - fn error_of_stream(&mut self, cause: Status) -> Poll, Status>>> { - let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); - - // Automatically acknowledge messages outside the ACK window. - // A better approach is for the client to send back an ACK. - if message_seq >= self.auto_ack_window_size { - self.ack_message(message_seq - self.auto_ack_window_size); - } - - self.confirmation_queue - .push_back((message_seq, Err(cause.clone()))); - Poll::Ready(Some(Err(cause))) - } - - fn message(&mut self, data: FlightData) -> Poll, Status>>> { + fn message( + &mut self, + data: FlightData, + ) -> Poll, Status>>> { let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); let data = Arc::new(data); let duplicate = data.clone(); - - // Automatically acknowledge messages outside the ACK window. - // A better approach is for the client to send back an ACK. - if message_seq >= self.auto_ack_window_size { - self.ack_message(message_seq - self.auto_ack_window_size); - } - - self.confirmation_queue.push_back((message_seq, Ok(data))); + self.last_packet = Some((message_seq, Ok(data))); Poll::Ready(Some(Ok(duplicate))) } fn check_resend(&mut self) -> Option, Status>> { let current_seq = self.seq.load(Ordering::SeqCst); - // normal case, no resend - if let Some((id, _)) = self.confirmation_queue.back() { - if *id == current_seq - 1 { - return None; - } - } - - // message is ack - if let Some((id, _)) = self.confirmation_queue.front() { - if *id > current_seq { - return Some(Err(Status::aborted( - "Aborted query, because the remote flight channel is closed.", - ))); - } - } - - // resend case, iterate the queue to find the message to resend - for (id, res) in self.confirmation_queue.iter() { - if *id == current_seq { + if let Some((seq, packet)) = &self.last_packet { + if seq == ¤t_seq { self.seq.fetch_add(1, Ordering::SeqCst); - return Some(res.clone()); + return Some(packet.clone()); } } @@ -617,9 +618,20 @@ impl FlightDataAckState { &mut self, cx: &mut Context<'_>, ) -> Poll, Status>>> { + if self.finish.load(Ordering::SeqCst) { + return Poll::Ready(None); + } + + // check if seq has been reset, if so, resend last packet if let Some(res) = self.check_resend() { return Poll::Ready(Some(res)); } + + // last packet is not acked, need to wait + if self.last_packet.is_some() { + self.waker = Some(cx.waker().clone()); + return Poll::Pending; + } match Pin::new(&mut self.receiver).poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => self.end_of_stream(), @@ -630,6 +642,7 @@ impl FlightDataAckState { } pub struct FlightDataAckStream { + notify: Arc, state: Arc>, } @@ -637,48 +650,107 @@ impl FlightDataAckStream { pub fn create( state: Arc>, begin: usize, + client_stream: Streaming, ) -> Result { - // reset begin - info!("Create FlightDataAckStream hold lock"); + let notify = Self::streaming_receiver(state.clone(), client_stream); let mut state_guard = state.lock(); state_guard.seq.store(begin, Ordering::SeqCst); - state_guard.may_retry = true; + if let Some(handle) = state_guard.clean_up_handle.take() { + handle.abort(); + } drop(state_guard); - info!("Create FlightDataAckStream release lock"); - Ok(FlightDataAckStream { state }) + Ok(FlightDataAckStream { notify, state }) + } + + fn streaming_receiver( + state: Arc>, + mut streaming: Streaming, + ) -> Arc { + let notify = Arc::new(WatchNotify::new()); + let fut = { + let notify = notify.clone(); + async move { + let mut notified = Box::pin(notify.notified()); + let mut streaming_next = streaming.next(); + + loop { + match futures::future::select(notified, streaming_next).await { + Either::Left((_, _)) | Either::Right((None, _)) => { + break; + } + Either::Right((Some(message), next_notified)) => { + notified = next_notified; + streaming_next = streaming.next(); + match message { + Ok(message) => { + let packet = DataPacket::try_from(message); + match packet { + Ok(DataPacket::FlightControl(command)) => match command { + FlightControlCommand::Ack(_seq) => { + let mut state_guard = state.lock(); + state_guard.last_packet = None; + if let Some(waker) = state_guard.waker.take() { + waker.wake(); + } + drop(state_guard); + } + FlightControlCommand::Close => { + state.lock().finish.store(true, Ordering::SeqCst); + info!("Received Command Close"); + break; + } + }, + Ok(_) => { + unreachable!( + "logic error: only FlightControl packet is expected" + ) + } + Err(_) => { + warn!("flight data is broken"); + break; + } + } + } + Err(_) => { + break; + } + } + } + } + } + drop(state); + drop(streaming); + } + } + .in_span(Span::enter_with_local_parent(full_name!())); + + databend_common_base::runtime::spawn(fut); + + notify } } impl Drop for FlightDataAckStream { fn drop(&mut self) { - info!("Drop FlightDataAckStream"); - let state_should_retry = { - info!("Drop stage1 hold lock"); - let mut state = self.state.lock(); - if state.may_retry { - state.may_retry = false; - true - } else { - state.receiver.close(); - false - } - }; - info!("Drop stage1 release lock"); - if state_should_retry { - let weak = Arc::downgrade(&self.state); - GlobalIORuntime::instance().spawn(async move { - info!("Drop stage2 begin, wait for 60"); - tokio::time::sleep(Duration::from_secs(60)).await; - if let Some(ss) = weak.upgrade() { - info!("Drop stage2 hold lock"); - let ss = ss.lock(); - if !ss.may_retry { - ss.receiver.close(); - } - info!("Drop stage2 release lock"); - } - }); + let mut state = self.state.lock(); + if state.finish.load(Ordering::SeqCst) { + self.notify.notify_waiters(); + state.receiver.close(); + return; } + let weak_state = Arc::downgrade(&self.state); + let notify = Arc::downgrade(&self.notify); + let handle = GlobalIORuntime::instance().spawn(async move { + tokio::time::sleep(Duration::from_secs(60)).await; + if let Some(ss) = weak_state.upgrade() { + let ss = ss.lock(); + ss.receiver.close(); + } + if let Some(notify) = notify.upgrade() { + notify.notify_waiters(); + } + }); + state.clean_up_handle = Some(handle); } } @@ -686,9 +758,6 @@ impl Stream for FlightDataAckStream { type Item = std::result::Result, Status>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - info!("Poll next hold lock"); - let res = self.state.lock().poll_next(cx); - info!("Poll next release lock"); - res + self.state.lock().poll_next(cx) } } diff --git a/src/query/service/src/servers/flight/flight_service.rs b/src/query/service/src/servers/flight/flight_service.rs index 64507a2c0843d..73ba5b46906de 100644 --- a/src/query/service/src/servers/flight/flight_service.rs +++ b/src/query/service/src/servers/flight/flight_service.rs @@ -147,10 +147,7 @@ where type Error = Infallible; type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index 7d3df09e33351..014859422dc91 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -21,6 +21,7 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; +use databend_common_arrow::arrow_format::flight::data::FlightData; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::GlobalInstance; use databend_common_base::runtime::GlobalIORuntime; @@ -40,6 +41,7 @@ use parking_lot::Mutex; use parking_lot::ReentrantMutex; use petgraph::prelude::EdgeRef; use petgraph::Direction; +use tonic::Streaming; use super::exchange_params::ExchangeParams; use super::exchange_params::MergeExchangeParams; @@ -158,7 +160,7 @@ impl DataExchangeManager { source: source.id.clone(), fragment: v, exchange: flight_client - .do_get( + .request_fragment_exchange( &query_id, &target.id, v, @@ -171,7 +173,7 @@ impl DataExchangeManager { Edge::Statistics => QueryExchange::Statistics { source: source.id.clone(), exchange: flight_client - .request_server_exchange( + .request_statistics_exchange( &query_id, &target.id, &address, @@ -367,16 +369,20 @@ impl DataExchangeManager { id: String, target: String, continue_from: usize, + client_stream: Streaming, ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(id) { - Entry::Occupied(mut v) => v.get_mut().add_statistics_exchange(target, continue_from), + Entry::Occupied(mut v) => { + v.get_mut() + .add_statistics_exchange(target, continue_from, client_stream) + } Entry::Vacant(v) => match continue_from == 0 { true => v .insert(QueryCoordinator::create()) - .add_statistics_exchange(target, continue_from), + .add_statistics_exchange(target, continue_from, client_stream), false => Err(ErrorCode::Timeout( "Reconnection timeout, the state has been cleared", )), @@ -391,6 +397,7 @@ impl DataExchangeManager { target: String, fragment: usize, continue_from: usize, + client_stream: Streaming, ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; @@ -398,13 +405,14 @@ impl DataExchangeManager { match queries_coordinator.entry(query) { Entry::Occupied(mut v) => { v.get_mut() - .add_fragment_exchange(target, fragment, continue_from) + .add_fragment_exchange(target, fragment, continue_from, client_stream) } Entry::Vacant(v) => match continue_from == 0 { true => v.insert(QueryCoordinator::create()).add_fragment_exchange( target, fragment, continue_from, + client_stream, ), false => Err(ErrorCode::Timeout( "Reconnection timeout, the state has been cleared", @@ -663,19 +671,20 @@ impl QueryCoordinator { &mut self, target: String, begin: usize, + client_stream: Streaming, ) -> Result { - let (tx, rx) = async_channel::bounded(8); + let (tx, rx) = async_channel::unbounded(); let identifier = ExchangeIdentifier::Statistics(target); match self.exchanges.entry(identifier) { Entry::Vacant(v) => { - let state = FlightDataAckState::create(10, rx); + let state = FlightDataAckState::create(rx); v.insert(FlightExchange::create_sender(state.clone(), tx)); - FlightDataAckStream::create(state, begin) + FlightDataAckStream::create(state, begin, client_stream) } Entry::Occupied(mut v) => match v.get_mut() { FlightExchange::MovedSender(v) => { - FlightDataAckStream::create(v.state.clone(), begin) + FlightDataAckStream::create(v.state.clone(), begin, client_stream) } _ => Err(ErrorCode::Internal( "statistics exchanges can only have one", @@ -705,19 +714,20 @@ impl QueryCoordinator { target: String, fragment: usize, begin: usize, + client_stream: Streaming, ) -> Result { - let (tx, rx) = async_channel::bounded(8); + let (tx, rx) = async_channel::unbounded(); let identifier = ExchangeIdentifier::fragment_sender(target, fragment); match self.exchanges.entry(identifier) { Entry::Vacant(v) => { - let state = FlightDataAckState::create(10, rx); + let state = FlightDataAckState::create(rx); v.insert(FlightExchange::create_sender(state.clone(), tx)); - FlightDataAckStream::create(state, begin) + FlightDataAckStream::create(state, begin, client_stream) } Entry::Occupied(mut v) => match v.get_mut() { FlightExchange::MovedSender(v) => { - FlightDataAckStream::create(v.state.clone(), begin) + FlightDataAckStream::create(v.state.clone(), begin, client_stream) } _ => Err(ErrorCode::Internal("fragment exchange can only have one")), }, diff --git a/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs b/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs index 999f7d990a5df..33e64787be7a6 100644 --- a/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs +++ b/src/query/service/src/servers/flight/v1/exchange/serde/exchange_deserializer.rs @@ -128,6 +128,7 @@ impl BlockMetaTransform for TransformExchangeDeserializ DataPacket::QueryProfiles(_) => unreachable!(), DataPacket::DataCacheMetrics(_) => unreachable!(), DataPacket::FragmentData(v) => Ok(vec![self.recv_data(meta.packet, v)?]), + DataPacket::FlightControl(_) => unreachable!(), } } } diff --git a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs index bfb8b3291e069..02cfbd165f1a4 100644 --- a/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs +++ b/src/query/service/src/servers/flight/v1/exchange/statistics_receiver.rs @@ -145,6 +145,7 @@ impl StatisticsReceiver { ctx.get_data_cache_metrics().merge(metrics); Ok(false) } + Ok(Some(DataPacket::FlightControl(_))) => unreachable!(), } } diff --git a/src/query/service/src/servers/flight/v1/flight_service.rs b/src/query/service/src/servers/flight/v1/flight_service.rs index 80ddcbac209ba..fc51067803952 100644 --- a/src/query/service/src/servers/flight/v1/flight_service.rs +++ b/src/query/service/src/servers/flight/v1/flight_service.rs @@ -97,14 +97,10 @@ impl FlightOperation for DatabendQueryFlightService { type DoExchangeStream = FlightStream>; #[async_backtrace::framed] - async fn do_exchange(&self, _: StreamReq) -> Response { - Err(Status::unimplemented("unimplemented do_exchange")) - } - - type DoGetStream = FlightStream>; - - #[async_backtrace::framed] - async fn do_get(&self, request: Request) -> Response { + async fn do_exchange( + &self, + request: StreamReq, + ) -> Response { let root = databend_common_tracing::start_trace_for_remote_request(func_path!(), &request); let _guard = root.set_local_parent(); match request.get_metadata("x-type")?.as_str() { @@ -115,11 +111,13 @@ impl FlightOperation for DatabendQueryFlightService { .get_metadata("x-continue-from")? .parse::() .unwrap(); + let client_stream = request.into_inner(); Ok(RawResponse::new(Box::pin( DataExchangeManager::instance().handle_statistics_exchange( query_id, target, continue_from, + client_stream, )?, ))) } @@ -134,12 +132,14 @@ impl FlightOperation for DatabendQueryFlightService { .get_metadata("x-continue-from")? .parse::() .unwrap(); + let client_stream = request.into_inner(); Ok(RawResponse::new(Box::pin( DataExchangeManager::instance().handle_exchange_fragment( query_id, target, fragment, continue_from, + client_stream, )?, ))) } @@ -150,6 +150,13 @@ impl FlightOperation for DatabendQueryFlightService { ))), } } + + type DoGetStream = FlightStream>; + + #[async_backtrace::framed] + async fn do_get(&self, _request: Request) -> Response { + Err(Status::unimplemented("unimplemented do_exchange")) + } type DoPutStream = FlightStream>; #[async_backtrace::framed] @@ -195,6 +202,7 @@ impl FlightOperation for DatabendQueryFlightService { Ok(RawResponse::new(Box::pin(stream::empty()))) } } + fn build_health_response() -> FlightStream> { Box::pin(stream::iter(vec![Ok(Arc::new(FlightData { flight_descriptor: None, diff --git a/src/query/service/src/servers/flight/v1/packets/mod.rs b/src/query/service/src/servers/flight/v1/packets/mod.rs index ca44d46afde5e..b9ca6e885e46b 100644 --- a/src/query/service/src/servers/flight/v1/packets/mod.rs +++ b/src/query/service/src/servers/flight/v1/packets/mod.rs @@ -19,6 +19,7 @@ mod packet_fragment; mod packet_publisher; pub use packet_data::DataPacket; +pub use packet_data::FlightControlCommand; pub use packet_data::FragmentData; pub use packet_data_progressinfo::ProgressInfo; pub use packet_executor::QueryFragments; diff --git a/src/query/service/src/servers/flight/v1/packets/packet_data.rs b/src/query/service/src/servers/flight/v1/packets/packet_data.rs index 2ca07a7a8dcdd..9a2ab11f2b638 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_data.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_data.rs @@ -28,6 +28,8 @@ use databend_common_pipeline_core::processors::PlanProfile; use databend_common_storage::CopyStatus; use databend_common_storage::MutationStatus; use log::error; +use serde::Deserialize; +use serde::Serialize; use crate::servers::flight::v1::packets::ProgressInfo; @@ -61,6 +63,13 @@ pub enum DataPacket { CopyStatus(CopyStatus), MutationStatus(MutationStatus), DataCacheMetrics(DataCacheMetricValues), + FlightControl(FlightControlCommand), +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum FlightControlCommand { + Ack(usize), + Close, } fn calc_size(flight_data: &FlightData) -> usize { @@ -78,6 +87,7 @@ impl DataPacket { DataPacket::FragmentData(v) => calc_size(&v.data) + v.meta.len(), DataPacket::QueryProfiles(_) => 0, DataPacket::DataCacheMetrics(_) => 0, + DataPacket::FlightControl(_) => 0, } } } @@ -136,6 +146,12 @@ impl TryFrom for FlightData { data_header: vec![], flight_descriptor: None, }, + DataPacket::FlightControl(command) => FlightData { + app_metadata: vec![0x09], + data_body: serde_json::to_vec(&command)?, + data_header: vec![], + flight_descriptor: None, + }, }) } } @@ -196,6 +212,11 @@ impl TryFrom for DataPacket { serde_json::from_slice::(&flight_data.data_body)?; Ok(DataPacket::DataCacheMetrics(status)) } + 0x09 => { + let command = + serde_json::from_slice::(&flight_data.data_body)?; + Ok(DataPacket::FlightControl(command)) + } _ => Err(ErrorCode::BadBytes("Unknown flight data packet type.")), } } diff --git a/src/query/service/tests/it/servers/flight/flight_service.rs b/src/query/service/tests/it/servers/flight/flight_service.rs index 5d45f8c072acc..c3821401c3351 100644 --- a/src/query/service/tests/it/servers/flight/flight_service.rs +++ b/src/query/service/tests/it/servers/flight/flight_service.rs @@ -17,7 +17,6 @@ use std::net::TcpListener; use std::str::FromStr; use std::sync::Arc; -use databend_common_arrow::arrow_format::flight::data::Ticket; use databend_common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient; use databend_common_base::base::tokio; use databend_common_exception::ErrorCode; @@ -55,27 +54,31 @@ async fn test_tls_rpc_server() -> Result<()> { // normal case let conn = ConnectionFactory::create_rpc_channel(listener_address, None, tls_conf).await?; let mut f_client = FlightServiceClient::new(conn); + let (server_tx, server_rx) = async_channel::bounded(1); let r = f_client - .do_get( - RequestBuilder::create(Ticket::default()) + .do_exchange( + RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "health")? .build(), ) .await; assert!(r.is_ok()); - + server_tx.close(); // client access without tls enabled will be failed // - channel can still be created, but communication will be failed let channel = ConnectionFactory::create_rpc_channel(listener_address, None, None).await?; + let mut f_client = FlightServiceClient::new(channel); + let (server_tx, server_rx) = async_channel::bounded(1); let r = f_client - .do_get( - RequestBuilder::create(Ticket::default()) + .do_exchange( + RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "health")? .build(), ) .await; assert!(r.is_err()); + server_tx.close(); Ok(()) } @@ -136,16 +139,16 @@ async fn test_rpc_server_port_used() -> Result<()> { } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn test_rpc_server_do_get() -> Result<()> { +async fn test_rpc_server_do_exchange() -> Result<()> { let listener_address = SocketAddr::from_str("127.0.0.1:9995")?; let mut rpc_service = FlightService::create(ConfigBuilder::create().build())?; rpc_service.start(listener_address).await?; let conn = ConnectionFactory::create_rpc_channel(listener_address, None, None).await?; let mut f_client = FlightServiceClient::new(conn); - + let (server_tx, server_rx) = async_channel::bounded(1); let r = f_client - .do_get( - RequestBuilder::create(Ticket::default()) + .do_exchange( + RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "health")? .build(), ) @@ -155,5 +158,6 @@ async fn test_rpc_server_do_get() -> Result<()> { assert_eq!(message.app_metadata.last(), Some(&0x03)); assert_eq!(message.data_body, Vec::from("ok")); + server_tx.close(); Ok(()) } diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh index 9f632454311ed..85c386f819c1e 100755 --- a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash perform_initial_query() { - local response=$(curl -s -u root: -XPOST "http://localhost:8000/v1/query" -H 'Content-Type: application/json' -d '{"sql": "select avg(number) from numbers(1000000000)"}') + local response=$(curl -s -u root: -XPOST "http://localhost:8000/v1/query" -H 'Content-Type: application/json' -d '{"sql": "select avg(number) from numbers(2000000000)"}') local stats_uri=$(echo "$response" | jq -r '.stats_uri') local final_uri=$(echo "$response" | jq -r '.final_uri') echo "$stats_uri|$final_uri" @@ -28,15 +28,9 @@ IFS='|' read -r stats_uri final_uri <<< $(perform_initial_query) poll_stats_uri "$stats_uri" & POLL_PID=$! -sleep 1 +sleep 2 netstat_output=$(netstat -an | grep '9092') -# skip standalone mode -if [ -z "$netstat_output" ]; then - echo "Final state: Succeeded" - exit 0 -fi - port=$(echo "$netstat_output" | awk ' $NF == "ESTABLISHED" { if ($4 ~ /:9092$/) { @@ -52,6 +46,12 @@ port=$(echo "$netstat_output" | awk ' } ') +# skip standalone mode +if [ -z "$port" ]; then + echo "Final state: Succeeded" + exit 0 +fi + # Start tcpkill in the background sudo tcpkill -i lo host 127.0.0.1 and port $port > tcpkill_output.txt 2>&1 & TCPKILL_PID=$! From 08136be0cab88b3c86da24ad6f906a4150300dfb Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Wed, 11 Sep 2024 21:37:46 +0800 Subject: [PATCH 04/18] debug --- .../1_stateful/02_query/02_0009_kill_connection.sh | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh index 85c386f819c1e..28fa8f85a45bc 100755 --- a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh @@ -8,16 +8,28 @@ perform_initial_query() { } poll_stats_uri() { local uri=$1 + local timeout=30 + local elapsed=0 local state_exists=true + while $state_exists; do local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") + echo "$response" if ! echo "$response" | jq -e '.state' > /dev/null; then state_exists=false else sleep 2 + elapsed=$((elapsed + 2)) + + if [ "$elapsed" -ge "$timeout" ]; then + echo "Polling timed out after $timeout seconds." + kill $$ + exit 1 + fi fi done } + get_final_state() { local uri=$1 local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") @@ -74,3 +86,4 @@ wait $POLL_PID final_state=$(get_final_state "$final_uri") echo "Final state: $final_state" +cat tcpkill_output.txt From 76024d4ca42fb2444a2b9c9664c7d4e28e90565a Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Wed, 11 Sep 2024 22:13:36 +0800 Subject: [PATCH 05/18] fixup --- .../src/servers/flight/flight_client.rs | 1 + .../02_query/02_0009_kill_connection.sh | 29 +++++-------------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 2b6a68291e6ea..b867145189d4b 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -655,6 +655,7 @@ impl FlightDataAckStream { let notify = Self::streaming_receiver(state.clone(), client_stream); let mut state_guard = state.lock(); state_guard.seq.store(begin, Ordering::SeqCst); + state_guard.finish.store(false, Ordering::SeqCst); if let Some(handle) = state_guard.clean_up_handle.take() { handle.abort(); } diff --git a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh index 28fa8f85a45bc..9f632454311ed 100755 --- a/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh +++ b/tests/suites/1_stateful/02_query/02_0009_kill_connection.sh @@ -1,35 +1,23 @@ #!/usr/bin/env bash perform_initial_query() { - local response=$(curl -s -u root: -XPOST "http://localhost:8000/v1/query" -H 'Content-Type: application/json' -d '{"sql": "select avg(number) from numbers(2000000000)"}') + local response=$(curl -s -u root: -XPOST "http://localhost:8000/v1/query" -H 'Content-Type: application/json' -d '{"sql": "select avg(number) from numbers(1000000000)"}') local stats_uri=$(echo "$response" | jq -r '.stats_uri') local final_uri=$(echo "$response" | jq -r '.final_uri') echo "$stats_uri|$final_uri" } poll_stats_uri() { local uri=$1 - local timeout=30 - local elapsed=0 local state_exists=true - while $state_exists; do local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") - echo "$response" if ! echo "$response" | jq -e '.state' > /dev/null; then state_exists=false else sleep 2 - elapsed=$((elapsed + 2)) - - if [ "$elapsed" -ge "$timeout" ]; then - echo "Polling timed out after $timeout seconds." - kill $$ - exit 1 - fi fi done } - get_final_state() { local uri=$1 local response=$(curl -s -u root: -XGET "http://localhost:8000$uri") @@ -40,9 +28,15 @@ IFS='|' read -r stats_uri final_uri <<< $(perform_initial_query) poll_stats_uri "$stats_uri" & POLL_PID=$! -sleep 2 +sleep 1 netstat_output=$(netstat -an | grep '9092') +# skip standalone mode +if [ -z "$netstat_output" ]; then + echo "Final state: Succeeded" + exit 0 +fi + port=$(echo "$netstat_output" | awk ' $NF == "ESTABLISHED" { if ($4 ~ /:9092$/) { @@ -58,12 +52,6 @@ port=$(echo "$netstat_output" | awk ' } ') -# skip standalone mode -if [ -z "$port" ]; then - echo "Final state: Succeeded" - exit 0 -fi - # Start tcpkill in the background sudo tcpkill -i lo host 127.0.0.1 and port $port > tcpkill_output.txt 2>&1 & TCPKILL_PID=$! @@ -86,4 +74,3 @@ wait $POLL_PID final_state=$(get_final_state "$final_uri") echo "Final state: $final_state" -cat tcpkill_output.txt From 07a150a9e57a6502eec92aa15a2480849080d204 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 12 Sep 2024 14:40:35 +0800 Subject: [PATCH 06/18] fixup --- .../src/servers/flight/flight_client.rs | 56 +++++++++++++------ .../flight/v1/exchange/exchange_manager.rs | 4 +- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index b867145189d4b..b1512c93ee088 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::VecDeque; use std::pin::Pin; use std::str::FromStr; use std::sync::atomic::AtomicBool; @@ -328,13 +329,16 @@ impl FlightRxInner { } } + pub fn stop_cluster(&self) { + let _ = self.server_tx.send( + FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)).unwrap(), + ); + } + pub fn close(&self) { self.rx.close(); self.notify.notify_waiters(); - let res = self.server_tx.send( - FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)).unwrap(), - ); - info!("Send close signal to flight server, result: {:?}", res); + self.server_tx.close(); } } @@ -438,6 +442,7 @@ impl RetryableFlightReceiver { let inner = self.inner.load(Ordering::SeqCst); if !inner.is_null() { + (*inner).stop_cluster(); (*inner).close(); } } @@ -556,22 +561,25 @@ pub struct FlightDataAckState { seq: AtomicUsize, finish: AtomicBool, receiver: Receiver>, - last_packet: Option<(usize, std::result::Result, Status>)>, + ack_window: VecDeque<(usize, std::result::Result, Status>)>, clean_up_handle: Option>, waker: Option, + window_size: usize, } impl FlightDataAckState { pub fn create( receiver: Receiver>, + window_size: usize ) -> Arc> { Arc::new(Mutex::new(FlightDataAckState { receiver, seq: AtomicUsize::new(0), - last_packet: None, + ack_window: VecDeque::with_capacity(window_size), finish: AtomicBool::new(false), clean_up_handle: None, waker: None, + window_size })) } @@ -580,7 +588,7 @@ impl FlightDataAckState { cause: Status, ) -> Poll, Status>>> { let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); - self.last_packet = Some((message_seq, Err(cause.clone()))); + self.ack_window.push_back((message_seq, Err(cause.clone()))); Poll::Ready(Some(Err(cause))) } @@ -597,23 +605,38 @@ impl FlightDataAckState { let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); let data = Arc::new(data); let duplicate = data.clone(); - self.last_packet = Some((message_seq, Ok(data))); + self.ack_window.push_back((message_seq, Ok(data))); Poll::Ready(Some(Ok(duplicate))) } fn check_resend(&mut self) -> Option, Status>> { let current_seq = self.seq.load(Ordering::SeqCst); - if let Some((seq, packet)) = &self.last_packet { - if seq == ¤t_seq { + if let Some((seq, _packet)) = self.ack_window.back() { + if seq + 1 == current_seq{ + return None; + } + } + // resend case, iterate the queue to find the message to resend + for (id, res) in self.ack_window.iter() { + if *id == current_seq { self.seq.fetch_add(1, Ordering::SeqCst); - return Some(packet.clone()); + return Some(res.clone()); } } None } + fn check_ack_window(&mut self, cx: &mut Context<'_>)-> Option, Status>>>> { + if self.ack_window.len() == self.window_size { + self.waker = Some(cx.waker().clone()); + return Some(Poll::Pending); + } + + None + } + pub fn poll_next( &mut self, cx: &mut Context<'_>, @@ -627,11 +650,12 @@ impl FlightDataAckState { return Poll::Ready(Some(res)); } - // last packet is not acked, need to wait - if self.last_packet.is_some() { - self.waker = Some(cx.waker().clone()); - return Poll::Pending; + // check if ack window is full, if so, wait for ack + if let Some(res) = self.check_ack_window(cx) { + return res; } + + match Pin::new(&mut self.receiver).poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => self.end_of_stream(), @@ -689,7 +713,7 @@ impl FlightDataAckStream { Ok(DataPacket::FlightControl(command)) => match command { FlightControlCommand::Ack(_seq) => { let mut state_guard = state.lock(); - state_guard.last_packet = None; + state_guard.ack_window.pop_front(); if let Some(waker) = state_guard.waker.take() { waker.wake(); } diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index 014859422dc91..0a01667a84187 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -678,7 +678,7 @@ impl QueryCoordinator { match self.exchanges.entry(identifier) { Entry::Vacant(v) => { - let state = FlightDataAckState::create(rx); + let state = FlightDataAckState::create(rx, 10); v.insert(FlightExchange::create_sender(state.clone(), tx)); FlightDataAckStream::create(state, begin, client_stream) } @@ -721,7 +721,7 @@ impl QueryCoordinator { match self.exchanges.entry(identifier) { Entry::Vacant(v) => { - let state = FlightDataAckState::create(rx); + let state = FlightDataAckState::create(rx, 10); v.insert(FlightExchange::create_sender(state.clone(), tx)); FlightDataAckStream::create(state, begin, client_stream) } From 425512952efdf8b4f4791293f7da049720bd5b90 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 12 Sep 2024 15:12:59 +0800 Subject: [PATCH 07/18] fixup --- .../src/servers/flight/flight_client.rs | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index b1512c93ee088..f4c10da66d2ae 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -330,15 +330,23 @@ impl FlightRxInner { } pub fn stop_cluster(&self) { - let _ = self.server_tx.send( - FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)).unwrap(), - ); + let tx = self.server_tx.clone(); + let fut = async move { + let _ = tx + .send( + FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)) + .unwrap(), + ) + .await; + } + .in_span(Span::enter_with_local_parent(full_name!())); + + databend_common_base::runtime::spawn(fut); } pub fn close(&self) { self.rx.close(); self.notify.notify_waiters(); - self.server_tx.close(); } } @@ -570,7 +578,7 @@ pub struct FlightDataAckState { impl FlightDataAckState { pub fn create( receiver: Receiver>, - window_size: usize + window_size: usize, ) -> Arc> { Arc::new(Mutex::new(FlightDataAckState { receiver, @@ -579,7 +587,7 @@ impl FlightDataAckState { finish: AtomicBool::new(false), clean_up_handle: None, waker: None, - window_size + window_size, })) } @@ -613,7 +621,7 @@ impl FlightDataAckState { let current_seq = self.seq.load(Ordering::SeqCst); if let Some((seq, _packet)) = self.ack_window.back() { - if seq + 1 == current_seq{ + if seq + 1 == current_seq { return None; } } @@ -628,15 +636,6 @@ impl FlightDataAckState { None } - fn check_ack_window(&mut self, cx: &mut Context<'_>)-> Option, Status>>>> { - if self.ack_window.len() == self.window_size { - self.waker = Some(cx.waker().clone()); - return Some(Poll::Pending); - } - - None - } - pub fn poll_next( &mut self, cx: &mut Context<'_>, @@ -651,11 +650,11 @@ impl FlightDataAckState { } // check if ack window is full, if so, wait for ack - if let Some(res) = self.check_ack_window(cx) { - return res; + if self.ack_window.len() == self.window_size { + self.waker = Some(cx.waker().clone()); + return Poll::Pending; } - match Pin::new(&mut self.receiver).poll_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => self.end_of_stream(), From 3fb21c603c311bcc4f6cf0bc6b9829ccdd93397e Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 12 Sep 2024 16:39:21 +0800 Subject: [PATCH 08/18] fixup --- .../src/servers/flight/flight_client.rs | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index f4c10da66d2ae..c3d870c5a5f3e 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -329,25 +329,17 @@ impl FlightRxInner { } } - pub fn stop_cluster(&self) { - let tx = self.server_tx.clone(); - let fut = async move { - let _ = tx - .send( - FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)) - .unwrap(), - ) - .await; - } - .in_span(Span::enter_with_local_parent(full_name!())); - - databend_common_base::runtime::spawn(fut); - } - pub fn close(&self) { self.rx.close(); self.notify.notify_waiters(); } + + pub fn stop_cluster(&mut self) { + let _ = self.server_tx.send_blocking( + FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)) + .expect("convert to flight data error"), + ); + } } pub struct RetryableFlightReceiver { @@ -380,15 +372,12 @@ impl RetryableFlightReceiver { Ok(message) => { let ack_seq = self.seq.fetch_add(1, Ordering::SeqCst); if message.is_some() { - let error = inner + let _ = inner .server_tx .send(FlightData::try_from(DataPacket::FlightControl( FlightControlCommand::Ack(ack_seq), ))?) .await; - if error.is_err() { - info!("Error while sending ack to flight : {:?}", error); - } } Ok(message) } From 96306040028b524bb8f54e7cbb99f227485e433c Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 12 Sep 2024 17:20:06 +0800 Subject: [PATCH 09/18] fixup --- src/query/service/src/servers/flight/flight_client.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index c3d870c5a5f3e..0fca994db4ab5 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -335,10 +335,11 @@ impl FlightRxInner { } pub fn stop_cluster(&mut self) { - let _ = self.server_tx.send_blocking( + let _ = self.server_tx.try_send( FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)) .expect("convert to flight data error"), ); + // ignore the error, because we cannot determine the state of server side } } From cd8d8565f7de0ecfe53d3f1b7a567c7155d90c3d Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 12 Sep 2024 20:10:31 +0800 Subject: [PATCH 10/18] don't need cas, already mutex lock --- .../src/servers/flight/flight_client.rs | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 0fca994db4ab5..9248c022701db 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -15,7 +15,6 @@ use std::collections::VecDeque; use std::pin::Pin; use std::str::FromStr; -use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicPtr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; @@ -556,8 +555,8 @@ impl FlightExchange { } pub struct FlightDataAckState { - seq: AtomicUsize, - finish: AtomicBool, + seq: usize, + finish: bool, receiver: Receiver>, ack_window: VecDeque<(usize, std::result::Result, Status>)>, clean_up_handle: Option>, @@ -572,9 +571,9 @@ impl FlightDataAckState { ) -> Arc> { Arc::new(Mutex::new(FlightDataAckState { receiver, - seq: AtomicUsize::new(0), + seq: 0, ack_window: VecDeque::with_capacity(window_size), - finish: AtomicBool::new(false), + finish: false, clean_up_handle: None, waker: None, window_size, @@ -585,14 +584,15 @@ impl FlightDataAckState { &mut self, cause: Status, ) -> Poll, Status>>> { - let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); + let message_seq = self.seq; + self.seq += 1; self.ack_window.push_back((message_seq, Err(cause.clone()))); Poll::Ready(Some(Err(cause))) } fn end_of_stream(&mut self) -> Poll, Status>>> { - self.seq.fetch_add(1, Ordering::SeqCst); - self.finish.store(true, Ordering::SeqCst); + self.seq += 1; + self.finish = true; Poll::Ready(None) } @@ -600,7 +600,8 @@ impl FlightDataAckState { &mut self, data: FlightData, ) -> Poll, Status>>> { - let message_seq = self.seq.fetch_add(1, Ordering::SeqCst); + let message_seq = self.seq; + self.seq += 1; let data = Arc::new(data); let duplicate = data.clone(); self.ack_window.push_back((message_seq, Ok(data))); @@ -608,7 +609,7 @@ impl FlightDataAckState { } fn check_resend(&mut self) -> Option, Status>> { - let current_seq = self.seq.load(Ordering::SeqCst); + let current_seq = self.seq; if let Some((seq, _packet)) = self.ack_window.back() { if seq + 1 == current_seq { @@ -618,7 +619,7 @@ impl FlightDataAckState { // resend case, iterate the queue to find the message to resend for (id, res) in self.ack_window.iter() { if *id == current_seq { - self.seq.fetch_add(1, Ordering::SeqCst); + self.seq += 1; return Some(res.clone()); } } @@ -630,7 +631,7 @@ impl FlightDataAckState { &mut self, cx: &mut Context<'_>, ) -> Poll, Status>>> { - if self.finish.load(Ordering::SeqCst) { + if self.finish { return Poll::Ready(None); } @@ -667,8 +668,8 @@ impl FlightDataAckStream { ) -> Result { let notify = Self::streaming_receiver(state.clone(), client_stream); let mut state_guard = state.lock(); - state_guard.seq.store(begin, Ordering::SeqCst); - state_guard.finish.store(false, Ordering::SeqCst); + state_guard.seq = begin; + state_guard.finish = false; if let Some(handle) = state_guard.clean_up_handle.take() { handle.abort(); } @@ -709,7 +710,7 @@ impl FlightDataAckStream { drop(state_guard); } FlightControlCommand::Close => { - state.lock().finish.store(true, Ordering::SeqCst); + state.lock().finish = true; info!("Received Command Close"); break; } @@ -747,7 +748,7 @@ impl FlightDataAckStream { impl Drop for FlightDataAckStream { fn drop(&mut self) { let mut state = self.state.lock(); - if state.finish.load(Ordering::SeqCst) { + if state.finish { self.notify.notify_waiters(); state.receiver.close(); return; From 1afe141e475292b1a6632ad222d5ea5fbf73eedd Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 12 Sep 2024 23:06:34 +0800 Subject: [PATCH 11/18] try to improve performance by change ack method --- .../src/servers/flight/flight_client.rs | 76 ++++--------------- 1 file changed, 15 insertions(+), 61 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 9248c022701db..1ce7973849c1f 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -21,7 +21,6 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::task::Context; use std::task::Poll; -use std::task::Waker; use async_channel::Receiver; use async_channel::Sender; @@ -41,7 +40,6 @@ use futures::Stream; use futures::StreamExt; use futures_util::future::Either; use log::info; -use log::warn; use parking_lot::Mutex; use serde::Deserialize; use serde::Serialize; @@ -370,15 +368,7 @@ impl RetryableFlightReceiver { let inner = unsafe { &*self.inner.load(Ordering::SeqCst) }; return match inner.recv().await { Ok(message) => { - let ack_seq = self.seq.fetch_add(1, Ordering::SeqCst); - if message.is_some() { - let _ = inner - .server_tx - .send(FlightData::try_from(DataPacket::FlightControl( - FlightControlCommand::Ack(ack_seq), - ))?) - .await; - } + self.seq.fetch_add(1, Ordering::SeqCst); Ok(message) } Err(cause) => { @@ -560,7 +550,6 @@ pub struct FlightDataAckState { receiver: Receiver>, ack_window: VecDeque<(usize, std::result::Result, Status>)>, clean_up_handle: Option>, - waker: Option, window_size: usize, } @@ -575,7 +564,6 @@ impl FlightDataAckState { ack_window: VecDeque::with_capacity(window_size), finish: false, clean_up_handle: None, - waker: None, window_size, })) } @@ -640,10 +628,9 @@ impl FlightDataAckState { return Poll::Ready(Some(res)); } - // check if ack window is full, if so, wait for ack + // check if ack window is full, if so, pop the oldest packet if self.ack_window.len() == self.window_size { - self.waker = Some(cx.waker().clone()); - return Poll::Pending; + self.ack_window.pop_front(); } match Pin::new(&mut self.receiver).poll_next(cx) { @@ -685,50 +672,17 @@ impl FlightDataAckStream { let fut = { let notify = notify.clone(); async move { - let mut notified = Box::pin(notify.notified()); - let mut streaming_next = streaming.next(); - - loop { - match futures::future::select(notified, streaming_next).await { - Either::Left((_, _)) | Either::Right((None, _)) => { - break; - } - Either::Right((Some(message), next_notified)) => { - notified = next_notified; - streaming_next = streaming.next(); - match message { - Ok(message) => { - let packet = DataPacket::try_from(message); - match packet { - Ok(DataPacket::FlightControl(command)) => match command { - FlightControlCommand::Ack(_seq) => { - let mut state_guard = state.lock(); - state_guard.ack_window.pop_front(); - if let Some(waker) = state_guard.waker.take() { - waker.wake(); - } - drop(state_guard); - } - FlightControlCommand::Close => { - state.lock().finish = true; - info!("Received Command Close"); - break; - } - }, - Ok(_) => { - unreachable!( - "logic error: only FlightControl packet is expected" - ) - } - Err(_) => { - warn!("flight data is broken"); - break; - } - } - } - Err(_) => { - break; - } + let notified = Box::pin(notify.notified()); + let streaming_next = streaming.next(); + + match futures::future::select(notified, streaming_next).await { + Either::Left((_, _)) | Either::Right((None, _)) => {} + Either::Right((Some(message), _next_notified)) => { + if let Ok(flight_data) = message { + let packet = DataPacket::try_from(flight_data).unwrap(); + if let DataPacket::FlightControl(FlightControlCommand::Close) = packet { + state.lock().finish = true; + info!("Receive close command from remote, close the flight data ack stream."); } } } @@ -737,7 +691,7 @@ impl FlightDataAckStream { drop(streaming); } } - .in_span(Span::enter_with_local_parent(full_name!())); + .in_span(Span::enter_with_local_parent(func_path!())); databend_common_base::runtime::spawn(fut); From bab36acf8bf688038f0cdb5a15b636a9e222df88 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Sat, 14 Sep 2024 10:57:37 +0800 Subject: [PATCH 12/18] change window size --- src/query/service/src/servers/flight/flight_client.rs | 9 +++------ src/query/service/src/servers/flight/flight_service.rs | 1 + src/query/settings/src/settings_default.rs | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 1ce7973849c1f..981c9b0d49fc0 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -215,12 +215,10 @@ impl FlightClient { loop { match futures::future::select(notified, streaming_next).await { - Either::Left((_, _)) => { - break; - } - Either::Right((None, _)) => { + Either::Left((_, _)) | Either::Right((None, _)) => { break; } + Either::Right((Some(message), next_notified)) => { notified = next_notified; streaming_next = streaming.next(); @@ -332,11 +330,11 @@ impl FlightRxInner { } pub fn stop_cluster(&mut self) { + // ignore the error, because we cannot determine the state of server side let _ = self.server_tx.try_send( FlightData::try_from(DataPacket::FlightControl(FlightControlCommand::Close)) .expect("convert to flight data error"), ); - // ignore the error, because we cannot determine the state of server side } } @@ -682,7 +680,6 @@ impl FlightDataAckStream { let packet = DataPacket::try_from(flight_data).unwrap(); if let DataPacket::FlightControl(FlightControlCommand::Close) = packet { state.lock().finish = true; - info!("Receive close command from remote, close the flight data ack stream."); } } } diff --git a/src/query/service/src/servers/flight/flight_service.rs b/src/query/service/src/servers/flight/flight_service.rs index 73ba5b46906de..3f314700b2e43 100644 --- a/src/query/service/src/servers/flight/flight_service.rs +++ b/src/query/service/src/servers/flight/flight_service.rs @@ -1,3 +1,4 @@ +// Copyright 2016-2019 The Apache Software Foundation // Copyright 2021 Datafuse Labs // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/query/settings/src/settings_default.rs b/src/query/settings/src/settings_default.rs index 424335ad5b2cc..3e62634d51280 100644 --- a/src/query/settings/src/settings_default.rs +++ b/src/query/settings/src/settings_default.rs @@ -863,7 +863,7 @@ impl DefaultSettings { range: Some(SettingRange::Numeric(0..=30)), }), ("flight_connection_retry_interval", DefaultSettingValue { - value: UserSettingValue::UInt64(5), + value: UserSettingValue::UInt64(3), desc: "The retry interval of cluster flight is in seconds.", mode: SettingMode::Both, range: Some(SettingRange::Numeric(0..=900)), From f1160339a66c7456fa77ff7d3e0ac9ebceca3ea3 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Sat, 14 Sep 2024 11:42:01 +0800 Subject: [PATCH 13/18] remove ack --- src/query/service/src/servers/flight/v1/packets/packet_data.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/query/service/src/servers/flight/v1/packets/packet_data.rs b/src/query/service/src/servers/flight/v1/packets/packet_data.rs index 9a2ab11f2b638..4c9a38b0fe3a4 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_data.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_data.rs @@ -68,7 +68,6 @@ pub enum DataPacket { #[derive(Serialize, Deserialize, Debug)] pub enum FlightControlCommand { - Ack(usize), Close, } From ab5afc2673d1cc4b08f6eda3336a24bc5be0c677 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 19 Sep 2024 10:50:52 +0800 Subject: [PATCH 14/18] apply review suggestion --- src/query/config/src/config.rs | 12 ++++++++++++ src/query/config/src/inner.rs | 4 ++++ src/query/service/src/clusters/cluster.rs | 5 +++-- .../service/src/servers/flight/flight_client.rs | 6 +++--- .../servers/flight/v1/exchange/exchange_manager.rs | 8 ++++---- .../it/storages/testdata/configs_table_basic.txt | 2 ++ src/query/settings/src/settings_default.rs | 12 ------------ 7 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/query/config/src/config.rs b/src/query/config/src/config.rs index 767bdb255b226..d6f3e64f0a2dc 100644 --- a/src/query/config/src/config.rs +++ b/src/query/config/src/config.rs @@ -1681,6 +1681,14 @@ pub struct QueryConfig { #[clap(long, value_name = "VALUE", default_value = "50")] pub max_cached_queries_profiles: usize, + /// The maximum retry count for cluster flight. Disable if 0. + #[clap(long, value_name = "VALUE", default_value = "3")] + pub max_flight_connection_retry_times: u64, + + /// The retry interval of cluster flight is in seconds. + #[clap(long, value_name = "VALUE", default_value = "3")] + pub flight_connection_retry_interval: u64, + #[clap(skip)] pub settings: HashMap, } @@ -1770,6 +1778,8 @@ impl TryInto for QueryConfig { cloud_control_grpc_server_address: self.cloud_control_grpc_server_address, cloud_control_grpc_timeout: self.cloud_control_grpc_timeout, max_cached_queries_profiles: self.max_cached_queries_profiles, + max_flight_connection_retry_times: self.max_flight_connection_retry_times, + flight_connection_retry_interval: self.flight_connection_retry_interval, settings: self .settings .into_iter() @@ -1872,6 +1882,8 @@ impl From for QueryConfig { cloud_control_grpc_server_address: inner.cloud_control_grpc_server_address, cloud_control_grpc_timeout: inner.cloud_control_grpc_timeout, max_cached_queries_profiles: inner.max_cached_queries_profiles, + max_flight_connection_retry_times: inner.max_flight_connection_retry_times, + flight_connection_retry_interval: inner.flight_connection_retry_interval, settings: HashMap::new(), } } diff --git a/src/query/config/src/inner.rs b/src/query/config/src/inner.rs index cf6dc8fb847f0..515c20715ce83 100644 --- a/src/query/config/src/inner.rs +++ b/src/query/config/src/inner.rs @@ -240,6 +240,8 @@ pub struct QueryConfig { pub cloud_control_grpc_server_address: Option, pub cloud_control_grpc_timeout: u64, pub max_cached_queries_profiles: usize, + pub max_flight_connection_retry_times: u64, + pub flight_connection_retry_interval: u64, pub settings: HashMap, } @@ -316,6 +318,8 @@ impl Default for QueryConfig { cloud_control_grpc_timeout: 0, data_retention_time_in_days_max: 90, max_cached_queries_profiles: 50, + max_flight_connection_retry_times: 3, + flight_connection_retry_interval: 3, settings: HashMap::new(), } } diff --git a/src/query/service/src/clusters/cluster.rs b/src/query/service/src/clusters/cluster.rs index b3750c8bd481f..d9eca12b21911 100644 --- a/src/query/service/src/clusters/cluster.rs +++ b/src/query/service/src/clusters/cluster.rs @@ -148,7 +148,8 @@ impl ClusterHelper for Cluster { async move { let mut attempt = 0; - let max_attempts = 2; + let max_attempts = config.query.max_flight_connection_retry_times; + let retry_interval = config.query.flight_connection_retry_interval; loop { let mut conn = create_client(&config, &flight_address).await?; @@ -169,7 +170,7 @@ impl ClusterHelper for Cluster { // only retry when error is network problem info!("retry do_action, attempt: {}", attempt); attempt += 1; - sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(retry_interval)).await; } Err(e) => return Err(e), } diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 981c9b0d49fc0..40ac8d1041aa2 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -708,9 +708,9 @@ impl Drop for FlightDataAckStream { let notify = Arc::downgrade(&self.notify); let handle = GlobalIORuntime::instance().spawn(async move { tokio::time::sleep(Duration::from_secs(60)).await; - if let Some(ss) = weak_state.upgrade() { - let ss = ss.lock(); - ss.receiver.close(); + if let Some(state) = weak_state.upgrade() { + let state_guard = state.lock(); + state_guard.receiver.close(); } if let Some(notify) = notify.upgrade() { notify.notify_waiters(); diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index 0a01667a84187..92ce77e045ba8 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -131,8 +131,8 @@ impl DataExchangeManager { let config = GlobalConfig::instance(); let with_cur_rt = env.create_rpc_clint_with_current_rt; - let flight_retry_times = env.settings.get_max_flight_retry_times()?; - let flight_retry_interval = env.settings.get_flight_retry_interval()?; + let flight_retry_times = config.query.max_flight_connection_retry_times as usize; + let flight_retry_interval = config.query.flight_connection_retry_interval as usize; let mut request_exchanges = HashMap::new(); let mut targets_exchanges = HashMap::new(); @@ -673,7 +673,7 @@ impl QueryCoordinator { begin: usize, client_stream: Streaming, ) -> Result { - let (tx, rx) = async_channel::unbounded(); + let (tx, rx) = async_channel::bounded(8); let identifier = ExchangeIdentifier::Statistics(target); match self.exchanges.entry(identifier) { @@ -716,7 +716,7 @@ impl QueryCoordinator { begin: usize, client_stream: Streaming, ) -> Result { - let (tx, rx) = async_channel::unbounded(); + let (tx, rx) = async_channel::bounded(8); let identifier = ExchangeIdentifier::fragment_sender(target, fragment); match self.exchanges.entry(identifier) { diff --git a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt index 7e20e16c1e523..731914688cf0d 100644 --- a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt +++ b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt @@ -88,6 +88,7 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'enable_meta_data_upgrade_json_to_pb_from_v307' | 'false' | '' | | 'query' | 'enable_udf_server' | 'false' | '' | | 'query' | 'flight_api_address' | '127.0.0.1:9090' | '' | +| 'query' | 'flight_connection_retry_interval' | '3' | '' | | 'query' | 'flight_sql_handler_host' | '127.0.0.1' | '' | | 'query' | 'flight_sql_handler_port' | '8900' | '' | | 'query' | 'flight_sql_tls_server_cert' | '' | '' | @@ -105,6 +106,7 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'management_mode' | 'false' | '' | | 'query' | 'max_active_sessions' | '256' | '' | | 'query' | 'max_cached_queries_profiles' | '50' | '' | +| 'query' | 'max_flight_connection_retry_times' | '3' | '' | | 'query' | 'max_memory_limit_enabled' | 'false' | '' | | 'query' | 'max_query_log_size' | '10000' | '' | | 'query' | 'max_running_queries' | '8' | '' | diff --git a/src/query/settings/src/settings_default.rs b/src/query/settings/src/settings_default.rs index 893826032a6ce..26944906f790e 100644 --- a/src/query/settings/src/settings_default.rs +++ b/src/query/settings/src/settings_default.rs @@ -865,18 +865,6 @@ impl DefaultSettings { mode: SettingMode::Both, range: Some(SettingRange::Numeric(0..=1)), }), - ("max_flight_connection_retry_times", DefaultSettingValue { - value: UserSettingValue::UInt64(3), - desc: "The maximum retry count for cluster flight. Disable if 0.", - mode: SettingMode::Both, - range: Some(SettingRange::Numeric(0..=30)), - }), - ("flight_connection_retry_interval", DefaultSettingValue { - value: UserSettingValue::UInt64(3), - desc: "The retry interval of cluster flight is in seconds.", - mode: SettingMode::Both, - range: Some(SettingRange::Numeric(0..=900)), - }), ("random_function_seed", DefaultSettingValue { value: UserSettingValue::UInt64(0), desc: "Seed for random function", From 2d56d33181c096c35eb26bb18fbd489b6394dd3b Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Thu, 19 Sep 2024 14:45:43 +0800 Subject: [PATCH 15/18] fixup --- src/query/service/src/servers/flight/flight_client.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 40ac8d1041aa2..353f2d32c53af 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -679,7 +679,10 @@ impl FlightDataAckStream { if let Ok(flight_data) = message { let packet = DataPacket::try_from(flight_data).unwrap(); if let DataPacket::FlightControl(FlightControlCommand::Close) = packet { - state.lock().finish = true; + let mut state_guard = state.lock(); + state_guard.finish = true; + state_guard.receiver.close(); + drop(state_guard); } } } From e37a1268d0365da84ff6192efc815ce71e40f556 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Fri, 20 Sep 2024 12:24:30 +0800 Subject: [PATCH 16/18] chore: use settings to support switch the feature on client side --- src/query/config/src/config.rs | 12 --------- src/query/config/src/inner.rs | 4 --- src/query/service/src/clusters/cluster.rs | 19 +++++++++----- src/query/service/src/clusters/mod.rs | 1 + .../src/interpreters/interpreter_kill.rs | 9 +++++-- .../interpreters/interpreter_set_priority.rs | 9 +++++-- .../interpreters/interpreter_system_action.rs | 9 +++++-- .../interpreter_table_truncate.rs | 9 +++++-- .../src/servers/admin/v1/query_profiling.rs | 9 +++++-- .../src/servers/flight/flight_client.rs | 17 ++++++------ .../flight/v1/exchange/exchange_manager.rs | 26 ++++++++++++------- .../flight/v1/packets/packet_publisher.rs | 5 ++-- .../storages/testdata/configs_table_basic.txt | 2 -- .../settings/src/settings_getter_setter.rs | 8 +++--- 14 files changed, 79 insertions(+), 60 deletions(-) diff --git a/src/query/config/src/config.rs b/src/query/config/src/config.rs index d6f3e64f0a2dc..767bdb255b226 100644 --- a/src/query/config/src/config.rs +++ b/src/query/config/src/config.rs @@ -1681,14 +1681,6 @@ pub struct QueryConfig { #[clap(long, value_name = "VALUE", default_value = "50")] pub max_cached_queries_profiles: usize, - /// The maximum retry count for cluster flight. Disable if 0. - #[clap(long, value_name = "VALUE", default_value = "3")] - pub max_flight_connection_retry_times: u64, - - /// The retry interval of cluster flight is in seconds. - #[clap(long, value_name = "VALUE", default_value = "3")] - pub flight_connection_retry_interval: u64, - #[clap(skip)] pub settings: HashMap, } @@ -1778,8 +1770,6 @@ impl TryInto for QueryConfig { cloud_control_grpc_server_address: self.cloud_control_grpc_server_address, cloud_control_grpc_timeout: self.cloud_control_grpc_timeout, max_cached_queries_profiles: self.max_cached_queries_profiles, - max_flight_connection_retry_times: self.max_flight_connection_retry_times, - flight_connection_retry_interval: self.flight_connection_retry_interval, settings: self .settings .into_iter() @@ -1882,8 +1872,6 @@ impl From for QueryConfig { cloud_control_grpc_server_address: inner.cloud_control_grpc_server_address, cloud_control_grpc_timeout: inner.cloud_control_grpc_timeout, max_cached_queries_profiles: inner.max_cached_queries_profiles, - max_flight_connection_retry_times: inner.max_flight_connection_retry_times, - flight_connection_retry_interval: inner.flight_connection_retry_interval, settings: HashMap::new(), } } diff --git a/src/query/config/src/inner.rs b/src/query/config/src/inner.rs index 515c20715ce83..cf6dc8fb847f0 100644 --- a/src/query/config/src/inner.rs +++ b/src/query/config/src/inner.rs @@ -240,8 +240,6 @@ pub struct QueryConfig { pub cloud_control_grpc_server_address: Option, pub cloud_control_grpc_timeout: u64, pub max_cached_queries_profiles: usize, - pub max_flight_connection_retry_times: u64, - pub flight_connection_retry_interval: u64, pub settings: HashMap, } @@ -318,8 +316,6 @@ impl Default for QueryConfig { cloud_control_grpc_timeout: 0, data_retention_time_in_days_max: 90, max_cached_queries_profiles: 50, - max_flight_connection_retry_times: 3, - flight_connection_retry_interval: 3, settings: HashMap::new(), } } diff --git a/src/query/service/src/clusters/cluster.rs b/src/query/service/src/clusters/cluster.rs index d9eca12b21911..246015ca0670f 100644 --- a/src/query/service/src/clusters/cluster.rs +++ b/src/query/service/src/clusters/cluster.rs @@ -85,7 +85,7 @@ pub trait ClusterHelper { &self, path: &str, message: HashMap, - timeout: u64, + flight_params: FlightParams, ) -> Result>; } @@ -122,7 +122,7 @@ impl ClusterHelper for Cluster { &self, path: &str, message: HashMap, - timeout: u64, + flight_params: FlightParams, ) -> Result> { fn get_node<'a>(nodes: &'a [Arc], id: &str) -> Result<&'a Arc> { for node in nodes { @@ -148,8 +148,6 @@ impl ClusterHelper for Cluster { async move { let mut attempt = 0; - let max_attempts = config.query.max_flight_connection_retry_times; - let retry_interval = config.query.flight_connection_retry_interval; loop { let mut conn = create_client(&config, &flight_address).await?; @@ -158,19 +156,19 @@ impl ClusterHelper for Cluster { path, node_secret.clone(), message.clone(), - timeout, + flight_params.timeout, ) .await { Ok(result) => return Ok((id, result)), Err(e) if e.code() == ErrorCode::CANNOT_CONNECT_NODE - && attempt < max_attempts => + && attempt < flight_params.retry_times => { // only retry when error is network problem info!("retry do_action, attempt: {}", attempt); attempt += 1; - sleep(Duration::from_secs(retry_interval)).await; + sleep(Duration::from_secs(flight_params.retry_interval)).await; } Err(e) => return Err(e), } @@ -529,3 +527,10 @@ pub async fn create_client(config: &InnerConfig, address: &str) -> Result Result { let cluster = self.ctx.get_cluster(); let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; let mut message = HashMap::with_capacity(cluster.nodes.len()); @@ -63,9 +63,14 @@ impl KillInterpreter { message.insert(node_info.id.clone(), self.plan.clone()); } } + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; let res = cluster - .do_action::<_, bool>(KILL_QUERY, message, timeout) + .do_action::<_, bool>(KILL_QUERY, message, flight_params) .await?; match res.values().any(|x| *x) { diff --git a/src/query/service/src/interpreters/interpreter_set_priority.rs b/src/query/service/src/interpreters/interpreter_set_priority.rs index 0dda6b9dd656b..05f64b30af65a 100644 --- a/src/query/service/src/interpreters/interpreter_set_priority.rs +++ b/src/query/service/src/interpreters/interpreter_set_priority.rs @@ -21,6 +21,7 @@ use databend_common_exception::Result; use databend_common_sql::plans::SetPriorityPlan; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::interpreters::Interpreter; use crate::pipelines::PipelineBuildResult; use crate::servers::flight::v1::actions::SET_PRIORITY; @@ -61,9 +62,13 @@ impl SetPriorityInterpreter { } let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; let res = cluster - .do_action::<_, bool>(SET_PRIORITY, message, timeout) + .do_action::<_, bool>(SET_PRIORITY, message, flight_params) .await?; match res.values().any(|x| *x) { diff --git a/src/query/service/src/interpreters/interpreter_system_action.rs b/src/query/service/src/interpreters/interpreter_system_action.rs index 86e747e865ee7..c3570923ff9d2 100644 --- a/src/query/service/src/interpreters/interpreter_system_action.rs +++ b/src/query/service/src/interpreters/interpreter_system_action.rs @@ -22,6 +22,7 @@ use databend_common_sql::plans::SystemAction; use databend_common_sql::plans::SystemPlan; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::interpreters::Interpreter; use crate::pipelines::PipelineBuildResult; use crate::servers::flight::v1::actions::SYSTEM_ACTION; @@ -74,9 +75,13 @@ impl Interpreter for SystemActionInterpreter { } let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; cluster - .do_action::<_, ()>(SYSTEM_ACTION, message, timeout) + .do_action::<_, ()>(SYSTEM_ACTION, message, flight_params) .await?; } diff --git a/src/query/service/src/interpreters/interpreter_table_truncate.rs b/src/query/service/src/interpreters/interpreter_table_truncate.rs index 09a19f79cc43d..850ef56f0303c 100644 --- a/src/query/service/src/interpreters/interpreter_table_truncate.rs +++ b/src/query/service/src/interpreters/interpreter_table_truncate.rs @@ -21,6 +21,7 @@ use databend_common_exception::Result; use databend_common_sql::plans::TruncateTablePlan; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::interpreters::Interpreter; use crate::pipelines::PipelineBuildResult; use crate::servers::flight::v1::actions::TRUNCATE_TABLE; @@ -95,9 +96,13 @@ impl Interpreter for TruncateTableInterpreter { } let settings = self.ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; cluster - .do_action::<_, ()>(TRUNCATE_TABLE, message, timeout) + .do_action::<_, ()>(TRUNCATE_TABLE, message, flight_params) .await?; } diff --git a/src/query/service/src/servers/admin/v1/query_profiling.rs b/src/query/service/src/servers/admin/v1/query_profiling.rs index 649c16baeb3a2..6f1f6dc7e1182 100644 --- a/src/query/service/src/servers/admin/v1/query_profiling.rs +++ b/src/query/service/src/servers/admin/v1/query_profiling.rs @@ -30,6 +30,7 @@ use poem::IntoResponse; use crate::clusters::ClusterDiscovery; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::servers::flight::v1::actions::GET_PROFILE; use crate::sessions::SessionManager; @@ -103,9 +104,13 @@ async fn get_cluster_profile(query_id: &str) -> Result, ErrorCo message.insert(node_info.id.clone(), query_id.to_owned()); } } - + let flight_params = FlightParams { + timeout: 60, + retry_times: 3, + retry_interval: 3, + }; let res = cluster - .do_action::<_, Option>>(GET_PROFILE, message, 60) + .do_action::<_, Option>>(GET_PROFILE, message, flight_params) .await?; match res.into_values().find(Option::is_some) { diff --git a/src/query/service/src/servers/flight/flight_client.rs b/src/query/service/src/servers/flight/flight_client.rs index 353f2d32c53af..93d9de28a5fb4 100644 --- a/src/query/service/src/servers/flight/flight_client.rs +++ b/src/query/service/src/servers/flight/flight_client.rs @@ -51,6 +51,7 @@ use tonic::Request; use tonic::Status; use tonic::Streaming; +use crate::clusters::FlightParams; use crate::pipelines::executor::WatchNotify; use crate::servers::flight::request_builder::RequestBuilder; use crate::servers::flight::v1::exchange::DataExchangeManager; @@ -134,8 +135,7 @@ impl FlightClient { query_id: &str, target: &str, source_address: &str, - retry_times: usize, - retry_interval: usize, + flight_params: FlightParams, ) -> Result { let (server_tx, server_rx) = async_channel::bounded(1); let req = RequestBuilder::create(Box::pin(server_rx)) @@ -155,8 +155,8 @@ impl FlightClient { target: target.to_string(), fragment: None, source_address: source_address.to_string(), - retry_times, - retry_interval: Duration::from_secs(retry_interval as u64), + retry_times: flight_params.retry_times, + retry_interval: Duration::from_secs(flight_params.retry_interval), }), server_tx, )) @@ -170,8 +170,7 @@ impl FlightClient { target: &str, fragment: usize, source_address: &str, - retry_times: usize, - retry_interval: usize, + flight_params: FlightParams, ) -> Result { let (server_tx, server_rx) = async_channel::bounded(1); @@ -195,8 +194,8 @@ impl FlightClient { target: target.to_string(), fragment: Some(fragment), source_address: source_address.to_string(), - retry_times, - retry_interval: Duration::from_secs(retry_interval as u64), + retry_times: flight_params.retry_times, + retry_interval: Duration::from_secs(flight_params.retry_interval), }), server_tx, )) @@ -292,7 +291,7 @@ pub struct ConnectionInfo { pub target: String, pub fragment: Option, pub source_address: String, - pub retry_times: usize, + pub retry_times: u64, pub retry_interval: Duration, } diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index 92ce77e045ba8..b55f55b0f2676 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -51,6 +51,7 @@ use super::exchange_transform::ExchangeTransform; use super::statistics_receiver::StatisticsReceiver; use super::statistics_sender::StatisticsSender; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::pipelines::executor::ExecutorSettings; use crate::pipelines::executor::PipelineCompleteExecutor; use crate::pipelines::PipelineBuildResult; @@ -131,8 +132,11 @@ impl DataExchangeManager { let config = GlobalConfig::instance(); let with_cur_rt = env.create_rpc_clint_with_current_rt; - let flight_retry_times = config.query.max_flight_connection_retry_times as usize; - let flight_retry_interval = config.query.flight_connection_retry_interval as usize; + let flight_params = FlightParams { + timeout: env.settings.get_flight_client_timeout()?, + retry_times: env.settings.get_max_flight_retry_times()?, + retry_interval: env.settings.get_flight_retry_interval()?, + }; let mut request_exchanges = HashMap::new(); let mut targets_exchanges = HashMap::new(); @@ -165,8 +169,7 @@ impl DataExchangeManager { &target.id, v, &address, - flight_retry_times, - flight_retry_interval, + flight_params, ) .await?, }, @@ -177,8 +180,7 @@ impl DataExchangeManager { &query_id, &target.id, &address, - flight_retry_times, - flight_retry_interval, + flight_params, ) .await?, }, @@ -452,13 +454,17 @@ impl DataExchangeManager { actions: QueryFragmentsActions, ) -> Result { let settings = ctx.get_settings(); - let timeout = settings.get_flight_client_timeout()?; + let flight_params = FlightParams { + timeout: settings.get_flight_client_timeout()?, + retry_times: settings.get_max_flight_retry_times()?, + retry_interval: settings.get_flight_retry_interval()?, + }; let root_actions = actions.get_root_actions()?; let conf = GlobalConfig::instance(); // Initialize query env between cluster nodes let query_env = actions.get_query_env()?; - query_env.init(&ctx, timeout).await?; + query_env.init(&ctx, flight_params).await?; // Submit distributed tasks to all nodes. let cluster = ctx.get_cluster(); @@ -467,7 +473,7 @@ impl DataExchangeManager { let local_fragments = query_fragments.remove(&conf.query.node_id); let _: HashMap = cluster - .do_action(INIT_QUERY_FRAGMENTS, query_fragments, timeout) + .do_action(INIT_QUERY_FRAGMENTS, query_fragments, flight_params) .await?; self.set_ctx(&ctx.get_id(), ctx.clone())?; @@ -480,7 +486,7 @@ impl DataExchangeManager { let prepared_query = actions.prepared_query()?; let _: HashMap = cluster - .do_action(START_PREPARED_QUERY, prepared_query, timeout) + .do_action(START_PREPARED_QUERY, prepared_query, flight_params) .await?; Ok(build_res) diff --git a/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs b/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs index 33c5f20e11b7d..12bb478627759 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_publisher.rs @@ -34,6 +34,7 @@ use serde::Deserialize; use serde::Serialize; use crate::clusters::ClusterHelper; +use crate::clusters::FlightParams; use crate::servers::flight::v1::actions::INIT_QUERY_ENV; use crate::sessions::QueryContext; use crate::sessions::SessionManager; @@ -140,7 +141,7 @@ pub struct QueryEnv { } impl QueryEnv { - pub async fn init(&self, ctx: &Arc, timeout: u64) -> Result<()> { + pub async fn init(&self, ctx: &Arc, flight_params: FlightParams) -> Result<()> { debug!("Dataflow diagram {:?}", self.dataflow_diagram); let cluster = ctx.get_cluster(); @@ -151,7 +152,7 @@ impl QueryEnv { } let _ = cluster - .do_action::<_, ()>(INIT_QUERY_ENV, message, timeout) + .do_action::<_, ()>(INIT_QUERY_ENV, message, flight_params) .await?; Ok(()) diff --git a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt index 731914688cf0d..7e20e16c1e523 100644 --- a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt +++ b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt @@ -88,7 +88,6 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'enable_meta_data_upgrade_json_to_pb_from_v307' | 'false' | '' | | 'query' | 'enable_udf_server' | 'false' | '' | | 'query' | 'flight_api_address' | '127.0.0.1:9090' | '' | -| 'query' | 'flight_connection_retry_interval' | '3' | '' | | 'query' | 'flight_sql_handler_host' | '127.0.0.1' | '' | | 'query' | 'flight_sql_handler_port' | '8900' | '' | | 'query' | 'flight_sql_tls_server_cert' | '' | '' | @@ -106,7 +105,6 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'management_mode' | 'false' | '' | | 'query' | 'max_active_sessions' | '256' | '' | | 'query' | 'max_cached_queries_profiles' | '50' | '' | -| 'query' | 'max_flight_connection_retry_times' | '3' | '' | | 'query' | 'max_memory_limit_enabled' | 'false' | '' | | 'query' | 'max_query_log_size' | '10000' | '' | | 'query' | 'max_running_queries' | '8' | '' | diff --git a/src/query/settings/src/settings_getter_setter.rs b/src/query/settings/src/settings_getter_setter.rs index b7f3917034d17..41266d8d96237 100644 --- a/src/query/settings/src/settings_getter_setter.rs +++ b/src/query/settings/src/settings_getter_setter.rs @@ -722,12 +722,12 @@ impl Settings { Ok(self.try_get_u64("random_function_seed")? == 1) } - pub fn get_flight_retry_interval(&self) -> Result { - Ok(self.try_get_u64("flight_connection_retry_interval")? as usize) + pub fn get_flight_retry_interval(&self) -> Result { + Ok(self.try_get_u64("flight_connection_retry_interval")?) } - pub fn get_max_flight_retry_times(&self) -> Result { - Ok(self.try_get_u64("max_flight_connection_retry_times")? as usize) + pub fn get_max_flight_retry_times(&self) -> Result { + Ok(self.try_get_u64("max_flight_connection_retry_times")?) } pub fn get_dynamic_sample_time_budget_ms(&self) -> Result { From 2d50a4dd37a4cf8d3ccfa0eaba4b0c12e9fa6a38 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Fri, 20 Sep 2024 14:52:15 +0800 Subject: [PATCH 17/18] chore: use settings to support switch the feature on server side --- src/query/service/src/clusters/cluster.rs | 2 +- .../src/servers/flight/flight_client.rs | 17 ++++++++- .../flight/v1/exchange/exchange_manager.rs | 38 +++++++++++-------- .../src/servers/flight/v1/flight_service.rs | 11 ++++++ src/query/settings/src/settings_default.rs | 12 ++++++ 5 files changed, 62 insertions(+), 18 deletions(-) diff --git a/src/query/service/src/clusters/cluster.rs b/src/query/service/src/clusters/cluster.rs index 246015ca0670f..7a3feb1544958 100644 --- a/src/query/service/src/clusters/cluster.rs +++ b/src/query/service/src/clusters/cluster.rs @@ -528,7 +528,7 @@ pub async fn create_client(config: &InnerConfig, address: &str) -> Result RequestBuilder::create(Box::pin(server_rx)) .with_metadata("x-type", "request_server_exchange")? .with_metadata("x-target", &info.target)? .with_metadata("x-query-id", &info.query_id)? .with_metadata("x-continue-from", &seq.to_string())? + .with_metadata("x-enable-retry", &info.retry_times.to_string())? .build(), }; let request = databend_common_tracing::inject_span_to_tonic_request(request); @@ -642,6 +646,7 @@ impl FlightDataAckState { pub struct FlightDataAckStream { notify: Arc, state: Arc>, + enable_retry: bool, } impl FlightDataAckStream { @@ -649,6 +654,7 @@ impl FlightDataAckStream { state: Arc>, begin: usize, client_stream: Streaming, + enable_retry: bool, ) -> Result { let notify = Self::streaming_receiver(state.clone(), client_stream); let mut state_guard = state.lock(); @@ -658,7 +664,11 @@ impl FlightDataAckStream { handle.abort(); } drop(state_guard); - Ok(FlightDataAckStream { notify, state }) + Ok(FlightDataAckStream { + notify, + state, + enable_retry, + }) } fn streaming_receiver( @@ -700,8 +710,11 @@ impl FlightDataAckStream { impl Drop for FlightDataAckStream { fn drop(&mut self) { + info!("Drop FlightDataAckStream enable: {:?}", self.enable_retry); let mut state = self.state.lock(); - if state.finish { + // if not enable retry, fallback to close immediately + if !self.enable_retry || state.finish { + info!("{:?} {:?}", state.finish, self.enable_retry); self.notify.notify_waiters(); state.receiver.close(); return; diff --git a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs index b55f55b0f2676..e7fadcf7d1d30 100644 --- a/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs +++ b/src/query/service/src/servers/flight/v1/exchange/exchange_manager.rs @@ -131,13 +131,11 @@ impl DataExchangeManager { let config = GlobalConfig::instance(); let with_cur_rt = env.create_rpc_clint_with_current_rt; - let flight_params = FlightParams { timeout: env.settings.get_flight_client_timeout()?, retry_times: env.settings.get_max_flight_retry_times()?, retry_interval: env.settings.get_flight_retry_interval()?, }; - let mut request_exchanges = HashMap::new(); let mut targets_exchanges = HashMap::new(); @@ -372,19 +370,22 @@ impl DataExchangeManager { target: String, continue_from: usize, client_stream: Streaming, + enable_retry: bool, ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(id) { - Entry::Occupied(mut v) => { - v.get_mut() - .add_statistics_exchange(target, continue_from, client_stream) - } + Entry::Occupied(mut v) => v.get_mut().add_statistics_exchange( + target, + continue_from, + client_stream, + enable_retry, + ), Entry::Vacant(v) => match continue_from == 0 { true => v .insert(QueryCoordinator::create()) - .add_statistics_exchange(target, continue_from, client_stream), + .add_statistics_exchange(target, continue_from, client_stream, enable_retry), false => Err(ErrorCode::Timeout( "Reconnection timeout, the state has been cleared", )), @@ -400,21 +401,26 @@ impl DataExchangeManager { fragment: usize, continue_from: usize, client_stream: Streaming, + enable_retry: bool, ) -> Result { let queries_coordinator_guard = self.queries_coordinator.lock(); let queries_coordinator = unsafe { &mut *queries_coordinator_guard.deref().get() }; match queries_coordinator.entry(query) { - Entry::Occupied(mut v) => { - v.get_mut() - .add_fragment_exchange(target, fragment, continue_from, client_stream) - } + Entry::Occupied(mut v) => v.get_mut().add_fragment_exchange( + target, + fragment, + continue_from, + client_stream, + enable_retry, + ), Entry::Vacant(v) => match continue_from == 0 { true => v.insert(QueryCoordinator::create()).add_fragment_exchange( target, fragment, continue_from, client_stream, + enable_retry, ), false => Err(ErrorCode::Timeout( "Reconnection timeout, the state has been cleared", @@ -678,6 +684,7 @@ impl QueryCoordinator { target: String, begin: usize, client_stream: Streaming, + enable_retry: bool, ) -> Result { let (tx, rx) = async_channel::bounded(8); let identifier = ExchangeIdentifier::Statistics(target); @@ -686,11 +693,11 @@ impl QueryCoordinator { Entry::Vacant(v) => { let state = FlightDataAckState::create(rx, 10); v.insert(FlightExchange::create_sender(state.clone(), tx)); - FlightDataAckStream::create(state, begin, client_stream) + FlightDataAckStream::create(state, begin, client_stream, enable_retry) } Entry::Occupied(mut v) => match v.get_mut() { FlightExchange::MovedSender(v) => { - FlightDataAckStream::create(v.state.clone(), begin, client_stream) + FlightDataAckStream::create(v.state.clone(), begin, client_stream, enable_retry) } _ => Err(ErrorCode::Internal( "statistics exchanges can only have one", @@ -721,6 +728,7 @@ impl QueryCoordinator { fragment: usize, begin: usize, client_stream: Streaming, + enable_retry: bool, ) -> Result { let (tx, rx) = async_channel::bounded(8); let identifier = ExchangeIdentifier::fragment_sender(target, fragment); @@ -729,11 +737,11 @@ impl QueryCoordinator { Entry::Vacant(v) => { let state = FlightDataAckState::create(rx, 10); v.insert(FlightExchange::create_sender(state.clone(), tx)); - FlightDataAckStream::create(state, begin, client_stream) + FlightDataAckStream::create(state, begin, client_stream, enable_retry) } Entry::Occupied(mut v) => match v.get_mut() { FlightExchange::MovedSender(v) => { - FlightDataAckStream::create(v.state.clone(), begin, client_stream) + FlightDataAckStream::create(v.state.clone(), begin, client_stream, enable_retry) } _ => Err(ErrorCode::Internal("fragment exchange can only have one")), }, diff --git a/src/query/service/src/servers/flight/v1/flight_service.rs b/src/query/service/src/servers/flight/v1/flight_service.rs index fc51067803952..ee4d3382a8537 100644 --- a/src/query/service/src/servers/flight/v1/flight_service.rs +++ b/src/query/service/src/servers/flight/v1/flight_service.rs @@ -111,6 +111,11 @@ impl FlightOperation for DatabendQueryFlightService { .get_metadata("x-continue-from")? .parse::() .unwrap(); + // if x-enable-retry is not 0, set enable_retry to a bool value true, otherwise false + let enable_retry: u64 = request + .get_metadata("x-enable-retry")? + .parse::() + .unwrap_or_default(); let client_stream = request.into_inner(); Ok(RawResponse::new(Box::pin( DataExchangeManager::instance().handle_statistics_exchange( @@ -118,6 +123,7 @@ impl FlightOperation for DatabendQueryFlightService { target, continue_from, client_stream, + enable_retry != 0, )?, ))) } @@ -132,6 +138,10 @@ impl FlightOperation for DatabendQueryFlightService { .get_metadata("x-continue-from")? .parse::() .unwrap(); + let enable_retry: u64 = request + .get_metadata("x-enable-retry")? + .parse::() + .unwrap_or_default(); let client_stream = request.into_inner(); Ok(RawResponse::new(Box::pin( DataExchangeManager::instance().handle_exchange_fragment( @@ -140,6 +150,7 @@ impl FlightOperation for DatabendQueryFlightService { fragment, continue_from, client_stream, + enable_retry != 0, )?, ))) } diff --git a/src/query/settings/src/settings_default.rs b/src/query/settings/src/settings_default.rs index deee0343df8ce..9587d7bc9f1b2 100644 --- a/src/query/settings/src/settings_default.rs +++ b/src/query/settings/src/settings_default.rs @@ -882,6 +882,18 @@ impl DefaultSettings { mode: SettingMode::Both, range: Some(SettingRange::Numeric(0..=1)), }), + ("max_flight_connection_retry_times", DefaultSettingValue { + value: UserSettingValue::UInt64(3), + desc: "The maximum retry count for cluster flight. Disable if 0.", + mode: SettingMode::Both, + range: Some(SettingRange::Numeric(0..=30)), + }), + ("flight_connection_retry_interval", DefaultSettingValue { + value: UserSettingValue::UInt64(3), + desc: "The retry interval of cluster flight is in seconds.", + mode: SettingMode::Both, + range: Some(SettingRange::Numeric(0..=900)), + }), ("random_function_seed", DefaultSettingValue { value: UserSettingValue::UInt64(0), desc: "Seed for random function", From f8aa71e068b17d75330bd49d7fc6aa20de249c43 Mon Sep 17 00:00:00 2001 From: Liuqing Yue Date: Fri, 20 Sep 2024 15:31:09 +0800 Subject: [PATCH 18/18] chore: refine to make clippy happy --- src/query/settings/src/settings_getter_setter.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/query/settings/src/settings_getter_setter.rs b/src/query/settings/src/settings_getter_setter.rs index 41266d8d96237..b52254e2e017a 100644 --- a/src/query/settings/src/settings_getter_setter.rs +++ b/src/query/settings/src/settings_getter_setter.rs @@ -723,11 +723,11 @@ impl Settings { } pub fn get_flight_retry_interval(&self) -> Result { - Ok(self.try_get_u64("flight_connection_retry_interval")?) + self.try_get_u64("flight_connection_retry_interval") } pub fn get_max_flight_retry_times(&self) -> Result { - Ok(self.try_get_u64("max_flight_connection_retry_times")?) + self.try_get_u64("max_flight_connection_retry_times") } pub fn get_dynamic_sample_time_budget_ms(&self) -> Result {