diff --git a/src/distributed_physical_optimizer_rule.rs b/src/distributed_physical_optimizer_rule.rs index 17935bf..2f89f6f 100644 --- a/src/distributed_physical_optimizer_rule.rs +++ b/src/distributed_physical_optimizer_rule.rs @@ -1,5 +1,8 @@ use super::{NetworkShuffleExec, PartitionIsolatorExec, StageExec}; use crate::execution_plans::NetworkCoalesceExec; +use crate::metrics::proto::MetricsSetProto; +use crate::protobuf::StageKey; +use dashmap::DashMap; use datafusion::common::plan_err; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::datasource::source::DataSourceExec; @@ -318,6 +321,20 @@ pub trait NetworkBoundary: ExecutionPlan { } Ok(Arc::clone(children.first().unwrap())) } + + /// metrics_collection is used to collect metrics from child tasks. It is empty when a + /// [NetworkBoundary] is instantiated (deserialized, created via new() etc...). + /// Metrics are populated by executing() the [NetworkBoundary]. It's expected that the + /// collection is complete after the [NetworkBoundary] has been executed. It is undefined + /// what this returns during execution. + /// + /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks + /// in the stage it is reading from. This is because, by convention, the ArrowFlightEndpoint + /// sends metrics for a task to the last [NetworkBoundary] to read from it, which may or may + /// not be this instance. + fn metrics_collection(&self) -> Option>>> { + None + } } /// Error thrown during distributed planning that prompts the planner to change something and diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs index 4a1ff59..529c85b 100644 --- a/src/execution_plans/metrics.rs +++ b/src/execution_plans/metrics.rs @@ -1,18 +1,31 @@ +use crate::distributed_physical_optimizer_rule::NetworkBoundary; use crate::execution_plans::{NetworkCoalesceExec, NetworkShuffleExec, StageExec}; use crate::metrics::proto::{MetricsSetProto, metrics_set_proto_to_df}; +use arrow::ipc::writer::DictionaryTracker; +use arrow::record_batch::RecordBatch; +use arrow_flight::FlightData; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::writer::IpcDataGenerator; +use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::metrics::MetricsSet; +use futures::{Stream, stream}; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::metrics::proto::df_metrics_set_to_proto; use crate::protobuf::StageKey; +use crate::protobuf::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; +use arrow_flight::error::FlightError; use datafusion::common::internal_err; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::SendableRecordBatchStream; -use datafusion::execution::TaskContext; use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::metrics::MetricsSet; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, PlanProperties}; +use prost::Message; use std::any::Any; -use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; /// TaskMetricsCollector is used to collect metrics from a task. It implements [TreeNodeRewriter]. 
/// Note: TaskMetricsCollector is not a [datafusion::physical_plan::ExecutionPlanVisitor] to keep @@ -29,35 +42,33 @@ pub struct TaskMetricsCollector { #[allow(dead_code)] pub struct MetricsCollectorResult { // metrics is a collection of metrics for a task ordered using a pre-order traversal of the task's plan. - task_metrics: Vec<MetricsSet>, + pub(super) task_metrics: Vec<MetricsSet>, // child_task_metrics contains metrics for child tasks if they were collected. - child_task_metrics: HashMap<StageKey, Vec<MetricsSetProto>>, + pub(super) child_task_metrics: HashMap<StageKey, Vec<MetricsSetProto>>, } impl TreeNodeRewriter for TaskMetricsCollector { type Node = Arc<dyn ExecutionPlan>; fn f_down(&mut self, plan: Self::Node) -> Result<Transformed<Self::Node>> { - // If the plan is a NetworkShuffleExec, assume it has collected metrics already + // If the plan is a NetworkBoundary, assume it has collected metrics already // from child tasks. let metrics_collection = if let Some(node) = plan.as_any().downcast_ref::<NetworkShuffleExec>() { - let NetworkShuffleExec::Ready(ready) = node else { - return internal_err!( - "unexpected NetworkShuffleExec::Pending during metrics collection" - ); - }; - Some(Arc::clone(&ready.metrics_collection)) + node.metrics_collection() + .map(Some) + .ok_or(DataFusionError::Internal( + "could not collect metrics from NetworkShuffleExec".to_string(), + )) } else if let Some(node) = plan.as_any().downcast_ref::<NetworkCoalesceExec>() { - let NetworkCoalesceExec::Ready(ready) = node else { - return internal_err!( - "unexpected NetworkCoalesceExec::Pending during metrics collection" - ); - }; - Some(Arc::clone(&ready.metrics_collection)) + node.metrics_collection() + .map(Some) + .ok_or(DataFusionError::Internal( + "could not collect metrics from NetworkCoalesceExec".to_string(), + )) } else { - None - }; + Ok(None) + }?; if let Some(metrics_collection) = metrics_collection { for mut entry in metrics_collection.iter_mut() { @@ -104,13 +115,16 @@ impl TaskMetricsCollector { } } - /// collect metrics from a StageExec plan and any child tasks. + /// collect metrics from an [ExecutionPlan] (usually a [StageExec].plan) and any child tasks. /// Returns /// - a vec representing the metrics for the current task (ordered using a pre-order traversal) /// - a map representing the metrics for some subset of child tasks collected from NetworkShuffleExec leaves #[allow(dead_code)] - pub fn collect(mut self, stage: &StageExec) -> Result<MetricsCollectorResult> { - stage.plan.clone().rewrite(&mut self)?; + pub fn collect( + mut self, + plan: Arc<dyn ExecutionPlan>, + ) -> Result<MetricsCollectorResult> { + plan.rewrite(&mut self)?; Ok(MetricsCollectorResult { task_metrics: self.task_metrics, child_task_metrics: self.child_task_metrics, @@ -270,20 +284,92 @@ impl ExecutionPlan for MetricsWrapperExec { } } +// Collects metrics from the provided stage and encodes them into a stream of flight data using +// the schema of the stage. +pub fn collect_and_create_metrics_flight_data( + stage_key: StageKey, + stage: Arc<StageExec>, +) -> Result<impl Stream<Item = Result<FlightData, FlightError>> + Send + 'static, FlightError> { + // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks. + let mut result = TaskMetricsCollector::new() + .collect(stage.plan.clone()) + .map_err(|err| FlightError::ProtocolError(err.to_string()))?; + + // Add the metrics for this task into the collection of task metrics.
+ // Skip any metrics that can't be converted to proto (unsupported types) + let proto_task_metrics = result + .task_metrics + .iter() + .map(|metrics| { + df_metrics_set_to_proto(metrics) + .map_err(|err| FlightError::ProtocolError(err.to_string())) + }) + .collect::, FlightError>>()?; + result + .child_task_metrics + .insert(stage_key.clone(), proto_task_metrics.clone()); + + // Serialize the metrics for all tasks. + let mut task_metrics_set = vec![]; + for (stage_key, metrics) in result.child_task_metrics.into_iter() { + task_metrics_set.push(TaskMetrics { + stage_key: Some(stage_key), + metrics, + }); + } + + let flight_app_metadata = FlightAppMetadata { + content: Some(AppMetadata::MetricsCollection(MetricsCollection { + tasks: task_metrics_set, + })), + }; + + let metrics_flight_data = + empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?; + Ok(Box::pin(stream::once( + async move { Ok(metrics_flight_data) }, + ))) +} + +/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema. +/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder]) +/// since they skip messages with empty RecordBatch data. +pub fn empty_flight_data_with_app_metadata( + metadata: FlightAppMetadata, + schema: SchemaRef, +) -> Result { + let mut buf = vec![]; + metadata + .encode(&mut buf) + .map_err(|err| FlightError::ProtocolError(err.to_string()))?; + + let empty_batch = RecordBatch::new_empty(schema); + let options = IpcWriteOptions::default(); + let data_gen = IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new(true); + let (_, encoded_data) = data_gen + .encoded_batch(&empty_batch, &mut dictionary_tracker, &options) + .map_err(|e| { + FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}")) + })?; + Ok(FlightData::from(encoded_data).with_app_metadata(buf)) +} + #[cfg(test)] mod tests { use super::*; use datafusion::arrow::array::{Int32Array, StringArray}; use datafusion::arrow::record_batch::RecordBatch; + use futures::StreamExt; use crate::DistributedExt; use crate::DistributedPhysicalOptimizerRule; use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver; use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed; + use crate::test_utils::plans::{count_plan_nodes, get_stages_and_stage_keys}; use crate::test_utils::session_context::register_temp_parquet_table; use datafusion::execution::{SessionStateBuilder, context::SessionContext}; - use datafusion::physical_plan::metrics::MetricValue; use datafusion::prelude::SessionConfig; use datafusion::{ arrow::datatypes::{DataType, Field, Schema}, @@ -291,198 +377,283 @@ mod tests { }; use std::sync::Arc; - /// Creates a stage with the following structure: - /// - /// SortPreservingMergeExec - /// SortExec - /// ProjectionExec - /// AggregateExec - /// CoalesceBatchesExec - /// NetworkShuffleExec - /// - /// ... (for the purposes of these tests, we don't care about child stages). 
- async fn make_test_stage_exec_with_5_nodes() -> (StageExec, SessionContext) { + /// Creates a single node session context + async fn make_test_ctx_single_node() -> SessionContext { + make_test_ctx_helper(false).await + } + + /// Creates a distributed session context with in-memory distributed engine + async fn make_test_ctx() -> SessionContext { + make_test_ctx_helper(true).await + } + + /// Creates a session context and registers two tables: + /// - table1 (id: int, name: string) + /// - table2 (id: int, name: string, phone: string, balance: float64) + async fn make_test_ctx_helper(distributed: bool) -> SessionContext { // Create distributed session state with in-memory channel resolver let config = SessionConfig::new().with_target_partitions(2); - let state = SessionStateBuilder::new() + let mut builder = SessionStateBuilder::new() .with_default_features() - .with_config(config) - .with_distributed_channel_resolver(InMemoryChannelResolver::new()) - .with_physical_optimizer_rule(Arc::new( - DistributedPhysicalOptimizerRule::default() - .with_network_coalesce_tasks(2) - .with_network_shuffle_tasks(2), - )) - .build(); + .with_config(config); + + if distributed { + builder = builder + .with_distributed_channel_resolver(InMemoryChannelResolver::new()) + .with_physical_optimizer_rule(Arc::new( + DistributedPhysicalOptimizerRule::default() + .with_network_coalesce_tasks(2) + .with_network_shuffle_tasks(2), + )) + } + let state = builder.build(); let ctx = SessionContext::from(state); - // Create test data - let schema = Arc::new(Schema::new(vec![ + // Create test data for table1 + let schema1 = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, false), ])); - let batches = vec![ + let batches1 = vec![ RecordBatch::try_new( - schema.clone(), + schema1.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(StringArray::from(vec!["a", "b", "c"])), ], ) .unwrap(), + ]; + + // Create test data for table2 with extended schema + let schema2 = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("phone", DataType::Utf8, false), + Field::new("balance", DataType::Float64, false), + ])); + + let batches2 = vec![ RecordBatch::try_new( - schema.clone(), + schema2.clone(), vec![ - Arc::new(Int32Array::from(vec![4, 5, 6])), - Arc::new(StringArray::from(vec!["d", "e", "f"])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + "customer1", + "customer2", + "customer3", + ])), + Arc::new(StringArray::from(vec![ + "13-123-4567", + "31-456-7890", + "23-789-0123", + ])), + Arc::new(datafusion::arrow::array::Float64Array::from(vec![ + 100.5, 250.0, 50.25, + ])), ], ) .unwrap(), ]; - // Register the test data as a parquet table - let _ = register_temp_parquet_table("test_table", schema.clone(), batches, &ctx) + // Register the test data as parquet tables + let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx) .await .unwrap(); - let df = ctx - .sql("SELECT id, COUNT(*) as count FROM test_table WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10") + let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx) .await .unwrap(); + + ctx + } + + /// runs a sql query and returns the coordinator StageExec + async fn plan_sql(ctx: &SessionContext, sql: &str) -> StageExec { + let df = ctx.sql(sql).await.unwrap(); let physical_distributed = df.create_physical_plan().await.unwrap(); let stage_exec = match 
physical_distributed.as_any().downcast_ref::() { Some(stage_exec) => stage_exec.clone(), None => panic!( - "Expected StageExec from distributed optimization, got: {}", + "expected StageExec from distributed optimization, got: {}", physical_distributed.name() ), }; + stage_exec + } - (stage_exec, ctx) + async fn execute_plan(stage_exec: &StageExec, ctx: &SessionContext) { + let task_ctx = ctx.task_ctx(); + let stream = stage_exec.execute(0, task_ctx).unwrap(); + + let mut stream = stream; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } } - #[tokio::test] - #[ignore] - async fn test_metrics_rewriter() { - let (test_stage, _ctx) = make_test_stage_exec_with_5_nodes().await; - let test_metrics_sets = (0..5) // 5 nodes excluding NetworkShuffleExec - .map(|i| make_test_metrics_set_proto_from_seed(i + 10)) - .collect::>(); + /// Asserts that we can collect metrics from a distributed plan generated from the + /// SQL query. It ensures that metrics are collected for all stages and are propagated + /// through network boundaries. + async fn run_metrics_collection_e2e_test(sql: &str) { + // Plan and execute the query + let ctx = make_test_ctx().await; + let stage_exec = plan_sql(&ctx, sql).await; + execute_plan(&stage_exec, &ctx).await; + + // Assert to ensure the distributed test case is sufficiently complex. + let (stages, expected_stage_keys) = get_stages_and_stage_keys(&stage_exec); + assert!( + expected_stage_keys.len() > 1, + "expected more than 1 stage key in test. the plan was not distributed):\n{}", + DisplayableExecutionPlan::new(&stage_exec).indent(true) + ); + + // Collect metrics for all tasks from the root StageExec. + let collector = TaskMetricsCollector::new(); + let result = collector.collect(stage_exec.plan.clone()).unwrap(); + let mut actual_collected_metrics = result.child_task_metrics; + actual_collected_metrics.insert( + StageKey { + query_id: stage_exec.query_id.to_string(), + stage_id: stage_exec.num as u64, + task_number: 0, + }, + result + .task_metrics + .iter() + .map(|m| df_metrics_set_to_proto(m).unwrap()) + .collect::>(), + ); + + // Ensure that there's metrics for each node for each task for each stage. + for expected_stage_key in expected_stage_keys { + // Get the collected metrics for this task. + let actual_metrics = actual_collected_metrics.get(&expected_stage_key).unwrap(); + + // Assert that there's metrics for each node in this task. + let stage = stages.get(&(expected_stage_key.stage_id as usize)).unwrap(); + assert_eq!(actual_metrics.len(), count_plan_nodes(&stage.plan)); + + // Ensure each node has at least one metric which was collected. + for metrics_set in actual_metrics.iter() { + let metrics_set = metrics_set_proto_to_df(metrics_set).unwrap(); + assert!(metrics_set.iter().count() > 0); + } + } + } - let rewriter = TaskMetricsRewriter::new(test_metrics_sets.clone()); - let plan_with_metrics = rewriter - .enrich_task_with_metrics(test_stage.plan.clone()) + /// Asserts that we successfully re-write the metrics of a plan generated from the provided SQL query. + /// Also asserts that the order which metrics are collected from a plan matches the order which + /// they are re-written (ie. ensures we don't assign metrics to the wrong nodes) + /// + /// Only tests single node plans since the [TaskMetricsRewriter] stops on [NetworkBoundary]. 
+ async fn run_metrics_rewriter_test(sql: &str) { + // Generate the plan + let ctx = make_test_ctx_single_node().await; + let plan = ctx + .sql(sql) + .await + .unwrap() + .create_physical_plan() + .await .unwrap(); - let plan_str = - DisplayableExecutionPlan::with_full_metrics(plan_with_metrics.as_ref()).indent(true); - // Expected distributed plan output with metrics - let expected = [ - r"SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=10, metrics=[output_rows=10, elapsed_compute=10ns, start_timestamp=2025-09-18 13:00:10 UTC, end_timestamp=2025-09-18 13:00:11 UTC]", - r" SortExec: TopK(fetch=10), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true], metrics=[output_rows=11, elapsed_compute=11ns, start_timestamp=2025-09-18 13:00:11 UTC, end_timestamp=2025-09-18 13:00:12 UTC]", - r" ProjectionExec: expr=[id@0 as id, count(Int64(1))@1 as count], metrics=[output_rows=12, elapsed_compute=12ns, start_timestamp=2025-09-18 13:00:12 UTC, end_timestamp=2025-09-18 13:00:13 UTC]", - r" AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(Int64(1))], metrics=[output_rows=13, elapsed_compute=13ns, start_timestamp=2025-09-18 13:00:13 UTC, end_timestamp=2025-09-18 13:00:14 UTC]", - r" CoalesceBatchesExec: target_batch_size=8192, metrics=[output_rows=14, elapsed_compute=14ns, start_timestamp=2025-09-18 13:00:14 UTC, end_timestamp=2025-09-18 13:00:15 UTC]", - r" NetworkShuffleExec, metrics=[]", - "" // trailing newline - ].join("\n"); - assert_eq!(expected, plan_str.to_string()); - } + // Generate metrics for each plan node. + let expected_metrics = (0..count_plan_nodes(&plan)) + .map(|i| make_test_metrics_set_proto_from_seed(i as u64 + 10)) + .collect::>(); - #[tokio::test] - #[ignore] - async fn test_metrics_rewriter_correct_number_of_metrics() { - let test_metrics_set = make_test_metrics_set_proto_from_seed(10); - let (executable_plan, _ctx) = make_test_stage_exec_with_5_nodes().await; - let task_plan = executable_plan - .as_any() - .downcast_ref::() + // Rewrite the metrics. + let rewriter = TaskMetricsRewriter::new(expected_metrics.clone()); + let rewritten_plan = rewriter.enrich_task_with_metrics(plan.clone()).unwrap(); + + // Collect metrics + let actual_metrics = TaskMetricsCollector::new() + .collect(rewritten_plan) .unwrap() - .plan - .clone(); - - // Too few metrics sets. - let rewriter = TaskMetricsRewriter::new(vec![test_metrics_set.clone()]); - let result = rewriter.enrich_task_with_metrics(task_plan.clone()); - assert!(result.is_err()); - - // Too many metrics sets. - let rewriter = TaskMetricsRewriter::new(vec![ - test_metrics_set.clone(), - test_metrics_set.clone(), - test_metrics_set.clone(), - test_metrics_set.clone(), - ]); - let result = rewriter.enrich_task_with_metrics(task_plan.clone()); - assert!(result.is_err()); + .task_metrics; + + // Assert that all the metrics are present and in the same order. + assert_eq!(actual_metrics.len(), expected_metrics.len()); + for (actual_metrics_set, expected_metrics_set) in actual_metrics + .iter() + .map(|m| df_metrics_set_to_proto(m).unwrap()) + .zip(expected_metrics) + { + assert_eq!(actual_metrics_set, expected_metrics_set); + } } #[tokio::test] - #[ignore] - async fn test_metrics_collection() { - let (stage_exec, ctx) = make_test_stage_exec_with_5_nodes().await; + async fn test_metrics_rewriter_1() { + run_metrics_rewriter_test( + "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name", + ) + .await; + } - // Execute the plan to completion. 
- let task_ctx = ctx.task_ctx(); - let stream = stage_exec.execute(0, task_ctx).unwrap(); + #[tokio::test] + async fn test_metrics_rewriter_2() { + run_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await; + } - use futures::StreamExt; - let mut stream = stream; - while let Some(_batch) = stream.next().await {} + #[tokio::test] + async fn test_metrics_rewriter_3() { + run_metrics_rewriter_test( + "SELECT sum(balance) / 7.0 as avg_yearly + FROM table2 + WHERE name LIKE 'customer%' + AND balance < ( + SELECT 0.2 * avg(balance) + FROM table2 t2_inner + WHERE t2_inner.id = table2.id + )", + ) + .await; + } - let collector = TaskMetricsCollector::new(); - let result = collector.collect(&stage_exec).unwrap(); - - // With the distributed optimizer, we get a much more complex plan structure - // The exact number of metrics sets depends on the plan optimization, so be flexible - assert_eq!(result.task_metrics.len(), 5); - - let expected_metrics_count = [4, 10, 8, 16, 8]; - for (node_idx, metrics_set) in result.task_metrics.iter().enumerate() { - let metrics_count = metrics_set.iter().count(); - assert_eq!(metrics_count, expected_metrics_count[node_idx]); - - // Each node should have basic metrics: ElapsedCompute, OutputRows, StartTimestamp, EndTimestamp. - let mut has_start_timestamp = false; - let mut has_end_timestamp = false; - let mut has_elapsed_compute = false; - let mut has_output_rows = false; - - for metric in metrics_set.iter() { - let metric_name = metric.value().name(); - let metric_value = metric.value(); - - match metric_value { - MetricValue::StartTimestamp(_) if metric_name == "start_timestamp" => { - has_start_timestamp = true; - } - MetricValue::EndTimestamp(_) if metric_name == "end_timestamp" => { - has_end_timestamp = true; - } - MetricValue::ElapsedCompute(_) if metric_name == "elapsed_compute" => { - has_elapsed_compute = true; - } - MetricValue::OutputRows(_) if metric_name == "output_rows" => { - has_output_rows = true; - } - _ => { - // Other metrics are fine, we just validate the core ones - } - } - } + #[tokio::test] + async fn test_metrics_collection_e2e_1() { + run_metrics_collection_e2e_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await; + } - // Each node should have the four basic metrics - assert!(has_start_timestamp); - assert!(has_end_timestamp); - assert!(has_elapsed_compute); - assert!(has_output_rows); - } + #[tokio::test] + async fn test_metrics_collection_e2e_2() { + run_metrics_collection_e2e_test( + "SELECT sum(balance) / 7.0 as avg_yearly + FROM table2 + WHERE name LIKE 'customer%' + AND balance < ( + SELECT 0.2 * avg(balance) + FROM table2 t2_inner + WHERE t2_inner.id = table2.id + )", + ) + .await; + } - // TODO: once we propagate metrics from child stages, we can assert this. 
- assert_eq!(0, result.child_task_metrics.len()); + #[tokio::test] + async fn test_metrics_collection_e2e_3() { + run_metrics_collection_e2e_test( + "SELECT + substring(phone, 1, 2) as country_code, + count(*) as num_customers, + sum(balance) as total_balance + FROM table2 + WHERE substring(phone, 1, 2) IN ('13', '31', '23', '29', '30', '18') + AND balance > ( + SELECT avg(balance) + FROM table2 + WHERE balance > 0.00 + ) + GROUP BY substring(phone, 1, 2) + ORDER BY country_code", + ) + .await; } } diff --git a/src/execution_plans/metrics_collecting_stream.rs b/src/execution_plans/metrics_collecting_stream.rs index 1321f63..d8379b6 100644 --- a/src/execution_plans/metrics_collecting_stream.rs +++ b/src/execution_plans/metrics_collecting_stream.rs @@ -67,8 +67,8 @@ where }; metrics_collection.insert(stage_key, task_metrics.metrics); } - flight_data.app_metadata.clear(); + Ok(()) } } @@ -256,7 +256,8 @@ mod tests { // Create a stream that emits an error - should be propagated through let stream_error = FlightError::ProtocolError("stream error from inner stream".to_string()); let error_stream = stream::iter(vec![Err(stream_error)]); - let mut collecting_stream = MetricsCollectingStream::new(error_stream, metrics_collection); + let mut collecting_stream = + MetricsCollectingStream::new(error_stream, metrics_collection.clone()); let result = collecting_stream.next().await.unwrap(); assert_protocol_error(result, "stream error from inner stream"); diff --git a/src/execution_plans/mod.rs b/src/execution_plans/mod.rs index 37e13e3..48fe68b 100644 --- a/src/execution_plans/mod.rs +++ b/src/execution_plans/mod.rs @@ -5,6 +5,7 @@ mod network_shuffle; mod partition_isolator; mod stage; +pub use metrics::collect_and_create_metrics_flight_data; pub use network_coalesce::{NetworkCoalesceExec, NetworkCoalesceReady}; pub use network_shuffle::{NetworkShuffleExec, NetworkShuffleReadyExec}; pub use partition_isolator::PartitionIsolatorExec; diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs index be9a5fe..489d184 100644 --- a/src/execution_plans/network_coalesce.rs +++ b/src/execution_plans/network_coalesce.rs @@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::common::scale_partitioning_props; use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_physical_optimizer_rule::{NetworkBoundary, limit_tasks_err}; +use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; use crate::metrics::proto::MetricsSetProto; @@ -89,15 +90,7 @@ pub struct NetworkCoalesceReady { pub(crate) properties: PlanProperties, pub(crate) stage_num: usize, pub(crate) input_tasks: usize, - /// metrics_collection is used to collect metrics from child tasks. It is empty when an - /// is instantiated (deserialized, created via [NetworkCoalesceExec::new_ready] etc...). - /// Metrics are populated in this map via [NetworkCoalesceExec::execute]. - /// - /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in - /// the stage it is reading from. This is because, by convention, the ArrowFlightEndpoint - /// sends metrics for a task to the last NetworkCoalesceExec to read from it, which may or may - /// not be this instance. 
- pub(crate) metrics_collection: Arc>>, + pub(crate) child_task_metrics: Arc>>, } impl NetworkCoalesceExec { @@ -163,7 +156,7 @@ impl NetworkBoundary for NetworkCoalesceExec { }), stage_num, input_tasks: pending.input_tasks, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }; Ok(Arc::new(Self::Ready(ready))) @@ -184,10 +177,17 @@ impl NetworkBoundary for NetworkCoalesceExec { }), stage_num: ready.stage_num, input_tasks, - metrics_collection: Arc::clone(&ready.metrics_collection), + child_task_metrics: Arc::clone(&ready.child_task_metrics), }), }) } + + fn metrics_collection(&self) -> Option>>> { + match self { + NetworkCoalesceExec::Pending(_) => None, + NetworkCoalesceExec::Ready(v) => Some(v.child_task_metrics.clone()), + } + } } impl DisplayAs for NetworkCoalesceExec { @@ -297,6 +297,7 @@ impl ExecutionPlan for NetworkCoalesceExec { return internal_err!("NetworkCoalesceExec: task is unassigned, cannot proceed"); }; + let metrics_collection_capture = self_ready.child_task_metrics.clone(); let stream = async move { let channel = channel_resolver.get_channel_for_url(&url).await?; let stream = FlightServiceClient::new(channel) @@ -306,8 +307,13 @@ impl ExecutionPlan for NetworkCoalesceExec { .into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(map_flight_to_datafusion_error)) + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + .map_err(map_flight_to_datafusion_error), + ) } .try_flatten_stream(); diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs index eed04b7..71f82d4 100644 --- a/src/execution_plans/network_shuffle.rs +++ b/src/execution_plans/network_shuffle.rs @@ -3,6 +3,7 @@ use crate::channel_resolver_ext::get_distributed_channel_resolver; use crate::common::scale_partitioning; use crate::config_extension_ext::ContextGrpcMetadata; use crate::distributed_physical_optimizer_rule::NetworkBoundary; +use crate::execution_plans::metrics_collecting_stream::MetricsCollectingStream; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::DoGet; use crate::metrics::proto::MetricsSetProto; @@ -135,15 +136,7 @@ pub struct NetworkShuffleReadyExec { /// the properties we advertise for this execution plan pub(crate) properties: PlanProperties, pub(crate) stage_num: usize, - /// metrics_collection is used to collect metrics from child tasks. It is empty when an - /// is instantiated (deserialized, created via [NetworkShuffleExec::new_ready] etc...). - /// Metrics are populated in this map via [NetworkShuffleExec::execute]. - /// - /// An instance may receive metrics for 0 to N child tasks, where N is the number of tasks in - /// the stage it is reading from. This is because, by convention, the ArrowFlightEndpoint - /// sends metrics for a task to the last NetworkShuffleExec to read from it, which may or may - /// not be this instance. 
- pub(crate) metrics_collection: Arc>>, + pub(crate) child_task_metrics: Arc>>, } impl NetworkShuffleExec { @@ -208,7 +201,7 @@ impl NetworkBoundary for NetworkShuffleExec { NetworkShuffleExec::Ready(prev) => NetworkShuffleExec::Ready(NetworkShuffleReadyExec { properties: prev.properties.clone(), stage_num: prev.stage_num, - metrics_collection: Arc::clone(&prev.metrics_collection), + child_task_metrics: Arc::clone(&prev.child_task_metrics), }), }) } @@ -225,11 +218,18 @@ impl NetworkBoundary for NetworkShuffleExec { let ready = NetworkShuffleReadyExec { properties: pending.repartition_exec.properties().clone(), stage_num, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }; Ok(Arc::new(Self::Ready(ready))) } + + fn metrics_collection(&self) -> Option>>> { + match self { + NetworkShuffleExec::Pending(_) => None, + NetworkShuffleExec::Ready(v) => Some(v.child_task_metrics.clone()), + } + } } impl DisplayAs for NetworkShuffleExec { @@ -330,6 +330,7 @@ impl ExecutionPlan for NetworkShuffleExec { }, ); + let metrics_collection_capture = self_ready.child_task_metrics.clone(); async move { let url = task.url.ok_or(internal_datafusion_err!( "NetworkShuffleExec: task is unassigned, cannot proceed" @@ -343,8 +344,13 @@ impl ExecutionPlan for NetworkShuffleExec { .into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); - Ok(FlightRecordBatchStream::new_from_flight_data(stream) - .map_err(map_flight_to_datafusion_error)) + let metrics_collecting_stream = + MetricsCollectingStream::new(stream, metrics_collection_capture); + + Ok( + FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream) + .map_err(map_flight_to_datafusion_error), + ) } .try_flatten_stream() .boxed() diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index b9701ff..47add9b 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,7 +1,10 @@ use crate::config_extension_ext::ContextGrpcMetadata; -use crate::execution_plans::{DistributedTaskContext, StageExec}; +use crate::execution_plans::{ + DistributedTaskContext, StageExec, collect_and_create_metrics_flight_data, +}; use crate::flight_service::service::ArrowFlightEndpoint; use crate::flight_service::session_builder::DistributedSessionBuilderContext; +use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream; use crate::protobuf::{ DistributedCodec, StageExecProto, StageKey, datafusion_error_to_tonic_status, stage_from_proto, }; @@ -94,12 +97,6 @@ impl ArrowFlightEndpoint { }) .await?; let stage = Arc::clone(&stage_data.stage); - let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining); - - // If all the partitions are done, remove the stage from the cache. 
- if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 { - self.task_data_entries.remove(key); - } // Find out which partition group we are executing let cfg = session_state.config_mut(); @@ -126,7 +123,16 @@ impl ArrowFlightEndpoint { .execute(doget.target_partition as usize, session_state.task_ctx()) .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?; - Ok(record_batch_stream_to_response(stream)) + let task_data_capture = self.task_data_entries.clone(); + Ok(flight_stream_from_record_batch_stream( + key.clone(), + stage, + stage_data.clone(), + move || { + task_data_capture.remove(key.clone()); + }, + stream, + )) } } @@ -134,7 +140,13 @@ fn missing(field: &'static str) -> impl FnOnce() -> Status { move || Status::invalid_argument(format!("Missing field '{field}'")) } -fn record_batch_stream_to_response( +// Creates a tonic response from a stream of record batches. Handles +// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics. +fn flight_stream_from_record_batch_stream( + stage_key: StageKey, + stage: Arc, + stage_data: TaskData, + evict_stage: impl FnOnce() + Send + 'static, stream: SendableRecordBatchStream, ) -> Response<::DoGetStream> { let flight_data_stream = @@ -144,7 +156,31 @@ fn record_batch_stream_to_response( FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) })); - Response::new(Box::pin(flight_data_stream.map_err(|err| match err { + let trailing_metrics_stream = TrailingFlightDataStream::new( + move || { + if stage_data + .num_partitions_remaining + .fetch_sub(1, Ordering::SeqCst) + == 1 + { + evict_stage(); + + let metrics_stream = collect_and_create_metrics_flight_data(stage_key, stage) + .map_err(|err| { + Status::internal(format!( + "error collecting metrics in arrow flight endpoint: {err}" + )) + })?; + + return Ok(Some(metrics_stream)); + } + + Ok(None) + }, + flight_data_stream, + ); + + Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err { FlightError::Tonic(status) => *status, _ => Status::internal(format!("Error during flight stream: {err}")), }))) @@ -215,9 +251,9 @@ mod tests { let stage_proto = proto_from_stage(&stage, &DefaultPhysicalExtensionCodec {}).unwrap(); let stage_proto_for_closure = stage_proto.clone(); let endpoint_ref = &endpoint; + let do_get = async move |partition: u64, task_number: u64, stage_key: StageKey| { let stage_proto = stage_proto_for_closure.clone(); - // Create DoGet message let doget = DoGet { stage_proto: Some(stage_proto), target_task_index: task_number, @@ -225,14 +261,17 @@ mod tests { stage_key: Some(stage_key), }; - // Create Flight ticket let ticket = Ticket { ticket: Bytes::from(doget.encode_to_vec()), }; - // Call the actual get() method let request = Request::new(ticket); - endpoint_ref.get(request).await + let response = endpoint_ref.get(request).await?; + let mut stream = response.into_inner(); + + // Consume the stream. + while let Some(_flight_data) = stream.try_next().await? {} + Ok::<(), Status>(()) }; // For each task, call do_get() for each partition except the last. @@ -248,7 +287,7 @@ mod tests { // Run the last partition of task 0. Any partition number works. Verify that the task state // is evicted because all partitions have been processed. 
- let result = do_get(1, 0, task_keys[0].clone()).await; + let result = do_get(2, 0, task_keys[0].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 2); @@ -256,14 +295,14 @@ mod tests { assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of task 1. - let result = do_get(1, 1, task_keys[1].clone()).await; + let result = do_get(2, 1, task_keys[1].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 1); assert!(stored_stage_keys.contains(&task_keys[2])); // Run the last partition of the last task. - let result = do_get(1, 2, task_keys[2].clone()).await; + let result = do_get(2, 2, task_keys[2].clone()).await; assert!(result.is_ok()); let stored_stage_keys = endpoint.task_data_entries.keys().collect::>(); assert_eq!(stored_stage_keys.len(), 0); diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index fdc99c4..db3bd91 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -1,7 +1,7 @@ mod do_get; mod service; mod session_builder; -mod trailing_flight_data_stream; +pub(super) mod trailing_flight_data_stream; pub(crate) use do_get::DoGet; pub use service::ArrowFlightEndpoint; diff --git a/src/flight_service/service.rs b/src/flight_service/service.rs index 675cd51..df01be7 100644 --- a/src/flight_service/service.rs +++ b/src/flight_service/service.rs @@ -17,7 +17,7 @@ use tonic::{Request, Response, Status, Streaming}; pub struct ArrowFlightEndpoint { pub(super) runtime: Arc, - pub(super) task_data_entries: TTLMap>>, + pub(super) task_data_entries: Arc>>>, pub(super) session_builder: Arc, } @@ -28,7 +28,7 @@ impl ArrowFlightEndpoint { let ttl_map = TTLMap::try_new(TTLMapConfig::default())?; Ok(Self { runtime: Arc::new(RuntimeEnv::default()), - task_data_entries: ttl_map, + task_data_entries: Arc::new(ttl_map), session_builder: Arc::new(session_builder), }) } diff --git a/src/flight_service/trailing_flight_data_stream.rs b/src/flight_service/trailing_flight_data_stream.rs index db85903..ebdf53e 100644 --- a/src/flight_service/trailing_flight_data_stream.rs +++ b/src/flight_service/trailing_flight_data_stream.rs @@ -8,22 +8,24 @@ use tokio::pin; /// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished. /// If the closure returns a new stream, it will be appended to the original stream and consumed. 
#[pin_project] -pub struct TrailingFlightDataStream +pub struct TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { #[pin] inner: S, on_complete: Option, #[pin] - trailing_stream: Option, + trailing_stream: Option, } -impl TrailingFlightDataStream +impl TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { // TODO: remove #[allow(dead_code)] @@ -36,10 +38,11 @@ where } } -impl Stream for TrailingFlightDataStream +impl Stream for TrailingFlightDataStream where S: Stream> + Send, - F: FnOnce() -> Result, FlightError>, + T: Stream> + Send, + F: FnOnce() -> Result, FlightError>, { type Item = Result; @@ -74,7 +77,7 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow_flight::FlightData; use arrow_flight::decode::FlightRecordBatchStream; - use arrow_flight::encode::FlightDataEncoderBuilder; + use arrow_flight::encode::{FlightDataEncoder, FlightDataEncoderBuilder}; use futures::stream::{self, StreamExt}; use std::sync::Arc; @@ -186,7 +189,7 @@ mod tests { )))), ]; let inner_stream = stream::iter(data); - let on_complete = || Ok(None); + let on_complete = || Ok(None::); let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream); let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream) .collect::>>() @@ -202,8 +205,7 @@ mod tests { let name_array = StringArray::from(vec!["item1"]); let value_array = Int32Array::from(vec![1]); let inner_stream = create_flight_data_stream(name_array, value_array); - - let on_complete = || -> Result, FlightError> { + let on_complete = || -> Result, FlightError> { Err(FlightError::ExternalError(Box::new(std::io::Error::new( std::io::ErrorKind::Other, "callback error", @@ -225,7 +227,7 @@ mod tests { StringArray::from(vec!["item1"] as Vec<&str>), Int32Array::from(vec![1] as Vec), ); - let on_complete = || Ok(None); + let on_complete = || Ok(None::); let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream); let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream) .collect::>>() diff --git a/src/metrics/proto.rs b/src/metrics/proto.rs index 4222d53..0206bc3 100644 --- a/src/metrics/proto.rs +++ b/src/metrics/proto.rs @@ -380,30 +380,28 @@ pub fn metric_proto_to_df(metric: MetricProto) -> Result, DataFusion labels, ))) } - Some(MetricValueProto::StartTimestamp(start_ts)) => match start_ts.value { - Some(value) => { - let timestamp = Timestamp::new(); + Some(MetricValueProto::StartTimestamp(start_ts)) => { + let timestamp = Timestamp::new(); + if let Some(value) = start_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); - Ok(Arc::new(Metric::new_with_labels( - MetricValue::StartTimestamp(timestamp), - partition, - labels, - ))) } - None => internal_err!("encountered invalid start timestamp metric with no value"), - }, - Some(MetricValueProto::EndTimestamp(end_ts)) => match end_ts.value { - Some(value) => { - let timestamp = Timestamp::new(); + Ok(Arc::new(Metric::new_with_labels( + MetricValue::StartTimestamp(timestamp), + partition, + labels, + ))) + } + Some(MetricValueProto::EndTimestamp(end_ts)) => { + let timestamp = Timestamp::new(); + if let Some(value) = end_ts.value { timestamp.set(DateTime::from_timestamp_nanos(value)); - Ok(Arc::new(Metric::new_with_labels( - MetricValue::EndTimestamp(timestamp), - partition, - labels, - ))) } - None => 
internal_err!("encountered invalid end timestamp metric with no value"), - }, + Ok(Arc::new(Metric::new_with_labels( + MetricValue::EndTimestamp(timestamp), + partition, + labels, + ))) + } None => internal_err!("proto metric is missing the metric field"), } } @@ -853,18 +851,22 @@ mod tests { } #[test] - fn test_invalid_proto_timestamp_error() { - // Create a MetricProto with EndTimestamp that has no value (None) - let invalid_end_timestamp_proto = MetricProto { - metric: Some(MetricValueProto::EndTimestamp(EndTimestamp { value: None })), - labels: vec![], - partition: Some(0), - }; - - let result = metric_proto_to_df(invalid_end_timestamp_proto); + fn test_default_timestamp_roundtrip() { + let default_timestamp = Timestamp::default(); + let metric_with_default_timestamp = + Metric::new(MetricValue::EndTimestamp(default_timestamp), Some(0)); + + let proto_result = df_metric_to_proto(Arc::new(metric_with_default_timestamp)); + assert!( + proto_result.is_ok(), + "should successfully convert default timestamp to proto" + ); + + let proto_metric = proto_result.unwrap(); + let roundtrip_result = metric_proto_to_df(proto_metric); assert!( - result.is_err(), - "should return error for invalid end timestamp with no value" + roundtrip_result.is_ok(), + "should successfully roundtrip default timestamp" ); } } diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index 67c8fbd..ea9449b 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -234,7 +234,7 @@ fn new_network_hash_shuffle_exec( Boundedness::Bounded, ), stage_num, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }) } @@ -268,7 +268,7 @@ fn new_network_coalesce_tasks_exec( ), stage_num, input_tasks, - metrics_collection: Default::default(), + child_task_metrics: Default::default(), }) } diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs index 140bd63..a5b3bec 100644 --- a/src/test_utils/mod.rs +++ b/src/test_utils/mod.rs @@ -4,5 +4,6 @@ pub mod localhost; pub mod metrics; pub mod mock_exec; pub mod parquet; +pub mod plans; pub mod session_context; pub mod tpch; diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs new file mode 100644 index 0000000..5bb7bce --- /dev/null +++ b/src/test_utils/plans.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use datafusion::{ + common::{HashMap, HashSet}, + physical_plan::ExecutionPlan, +}; + +use crate::{ + StageExec, + execution_plans::{NetworkCoalesceExec, NetworkShuffleExec}, + protobuf::StageKey, +}; + +/// count_plan_nodes counts the number of execution plan nodes in a plan using BFS traversal. +/// This does NOT traverse child stages, only the execution plan tree within this stage. +/// Excludes [NetworkBoundary] nodes from the count. +pub fn count_plan_nodes(plan: &Arc) -> usize { + let mut count = 0; + let mut queue = vec![plan]; + + while let Some(plan) = queue.pop() { + // Skip [NetworkBoundary] nodes from the count. 
+ if !plan.as_any().is::() && !plan.as_any().is::() { + count += 1; + } + + // Add children to the queue for BFS traversal + for child in plan.children() { + queue.push(child); + } + } + count +} + +/// Returns +/// - a map of all stages +/// - a set of all the stage keys (one per task) +pub fn get_stages_and_stage_keys( + stage: &StageExec, +) -> (HashMap, HashSet) { + let query_id = stage.query_id; + let mut i = 0; + let mut queue = vec![stage]; + let mut stage_keys = HashSet::new(); + let mut stages_map = HashMap::new(); + + while i < queue.len() { + let stage = queue[i]; + stages_map.insert(stage.num, stage); + i += 1; + + // Add each task. + for j in 0..stage.tasks.len() { + let stage_key = StageKey { + query_id: query_id.to_string(), + stage_id: stage.num as u64, + task_number: j as u64, + }; + stage_keys.insert(stage_key); + } + + // Add any child stages + queue.extend(stage.child_stages_iter()); + } + (stages_map, stage_keys) +} diff --git a/tests/stateful_execution_plan.rs b/tests/stateful_execution_plan.rs new file mode 100644 index 0000000..5372146 --- /dev/null +++ b/tests/stateful_execution_plan.rs @@ -0,0 +1,278 @@ +#[cfg(all(feature = "integration", test))] +mod tests { + use datafusion::arrow::array::Int64Array; + use datafusion::arrow::compute::SortOptions; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::arrow::util::pretty::pretty_format_batches; + use datafusion::common::runtime::SpawnedTask; + use datafusion::error::DataFusionError; + use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, + }; + use datafusion::logical_expr::Operator; + use datafusion::physical_expr::expressions::{BinaryExpr, col, lit}; + use datafusion::physical_expr::{ + EquivalenceProperties, LexOrdering, Partitioning, PhysicalSortExpr, + }; + use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; + use datafusion::physical_plan::filter::FilterExec; + use datafusion::physical_plan::repartition::RepartitionExec; + use datafusion::physical_plan::sorts::sort::SortExec; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, displayable, execute_stream, + }; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{ + DistributedExt, DistributedSessionBuilderContext, PartitionIsolatorExec, assert_snapshot, + }; + use datafusion_distributed::{DistributedPhysicalOptimizerRule, NetworkShuffleExec}; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use datafusion_proto::protobuf::proto_error; + use futures::TryStreamExt; + use prost::Message; + use std::any::Any; + use std::fmt::Formatter; + use std::sync::{Arc, RwLock}; + use std::time::Duration; + use tokio::sync::mpsc; + use tokio_stream::StreamExt; + use tokio_stream::wrappers::ReceiverStream; + + #[tokio::test] + async fn stateful_execution_plan() -> Result<(), Box> { + async fn build_state( + ctx: DistributedSessionBuilderContext, + ) -> Result { + Ok(SessionStateBuilder::new() + .with_runtime_env(ctx.runtime_env) + .with_default_features() + .with_distributed_user_codec(Int64ListExecCodec) + .build()) + } + + let (ctx, _guard) = start_localhost_context(3, build_state).await; + + let distributed_plan = build_plan()?; + let distributed_plan = 
DistributedPhysicalOptimizerRule::distribute_plan(distributed_plan)?; + + assert_snapshot!(displayable(&distributed_plan).indent(true).to_string(), @r" + ┌───── Stage 3 Tasks: t0:[p0] + │ SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] + │ RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=10 + │ NetworkShuffleExec read_from=Stage 2, output_partitions=10, n_tasks=1, input_tasks=10 + └────────────────────────────────────────────────── + ┌───── Stage 2 Tasks: t0:[p0,p1,p2,p3,p4,p5,p6,p7,p8,p9] t1:[p10,p11,p12,p13,p14,p15,p16,p17,p18,p19] t2:[p20,p21,p22,p23,p24,p25,p26,p27,p28,p29] t3:[p30,p31,p32,p33,p34,p35,p36,p37,p38,p39] t4:[p40,p41,p42,p43,p44,p45,p46,p47,p48,p49] t5:[p50,p51,p52,p53,p54,p55,p56,p57,p58,p59] t6:[p60,p61,p62,p63,p64,p65,p66,p67,p68,p69] t7:[p70,p71,p72,p73,p74,p75,p76,p77,p78,p79] t8:[p80,p81,p82,p83,p84,p85,p86,p87,p88,p89] t9:[p90,p91,p92,p93,p94,p95,p96,p97,p98,p99] + │ RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + │ SortExec: expr=[numbers@0 DESC NULLS LAST], preserve_partitioning=[false] + │ NetworkShuffleExec read_from=Stage 1, output_partitions=1, n_tasks=10, input_tasks=1 + └────────────────────────────────────────────────── + ┌───── Stage 1 Tasks: t0:[p0,p1,p2,p3,p4,p5,p6,p7,p8,p9] + │ RepartitionExec: partitioning=Hash([numbers@0], 10), input_partitions=1 + │ FilterExec: numbers@0 > 1 + │ StatefulInt64ListExec: length=6 + └────────────────────────────────────────────────── + "); + + let stream = execute_stream(Arc::new(distributed_plan), ctx.task_ctx())?; + let batches_distributed = stream.try_collect::>().await?; + + assert_snapshot!(pretty_format_batches(&batches_distributed).unwrap(), @r" + +---------+ + | numbers | + +---------+ + | 6 | + | 5 | + | 4 | + | 3 | + | 2 | + +---------+ + "); + Ok(()) + } + + fn build_plan() -> Result, DataFusionError> { + let mut plan: Arc = + Arc::new(StatefulInt64ListExec::new(vec![1, 2, 3, 4, 5, 6])); + + plan = Arc::new(PartitionIsolatorExec::new_pending(plan)); + + plan = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("numbers", &plan.schema())?, + Operator::Gt, + lit(1i64), + )), + plan, + )?); + + plan = Arc::new(NetworkShuffleExec::try_new( + Arc::clone(&plan), + Partitioning::Hash(vec![col("numbers", &plan.schema())?], 1), + 10, + )?); + + plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("numbers", &plan.schema())?, + SortOptions::new(true, false), + )]) + .unwrap(), + plan, + )); + + plan = Arc::new(NetworkShuffleExec::try_new( + plan, + Partitioning::RoundRobinBatch(10), + 10, + )?); + + plan = Arc::new(RepartitionExec::try_new( + plan, + Partitioning::RoundRobinBatch(1), + )?); + + plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("numbers", &plan.schema())?, + SortOptions::new(true, false), + )]) + .unwrap(), + plan, + )); + + Ok(plan) + } + + #[derive(Debug)] + pub struct StatefulInt64ListExec { + plan_properties: PlanProperties, + numbers: Vec, + task: RwLock>>, + tx: RwLock>>, + rx: RwLock>>, + } + + impl StatefulInt64ListExec { + fn new(numbers: Vec) -> Self { + let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]); + let (tx, rx) = mpsc::channel(10); + Self { + numbers, + plan_properties: PlanProperties::new( + EquivalenceProperties::new(Arc::new(schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), + task: RwLock::new(None), + tx: RwLock::new(Some(tx)), + rx: 
RwLock::new(Some(rx)), + } + } + } + + impl DisplayAs for StatefulInt64ListExec { + fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "StatefulInt64ListExec: length={:?}", self.numbers.len()) + } + } + + impl ExecutionPlan for StatefulInt64ListExec { + fn name(&self) -> &str { + "StatefulInt64ListExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + _: usize, + _: Arc, + ) -> datafusion::common::Result { + if let Some(tx) = self.tx.write().unwrap().take() { + let numbers = self.numbers.clone(); + self.task + .write() + .unwrap() + .replace(SpawnedTask::spawn(async move { + for n in numbers { + tx.send(n).await.unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + })); + } + + let rx = self.rx.write().unwrap().take().unwrap(); + let schema = self.schema(); + + let stream = ReceiverStream::new(rx).map(move |v| { + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![v]))]) + .map_err(DataFusionError::from) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema().clone(), + stream, + ))) + } + } + + #[derive(Debug)] + struct Int64ListExecCodec; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Int64ListExecProto { + #[prost(message, repeated, tag = "1")] + numbers: Vec, + } + + impl PhysicalExtensionCodec for Int64ListExecCodec { + fn try_decode( + &self, + buf: &[u8], + _: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + let node = + Int64ListExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?; + Ok(Arc::new(StatefulInt64ListExec::new(node.numbers.clone()))) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + let Some(plan) = node.as_any().downcast_ref::() else { + return Err(proto_error(format!( + "Expected plan to be of type Int64ListExec, but was {}", + node.name() + ))); + }; + Int64ListExecProto { + numbers: plan.numbers.clone(), + } + .encode(buf) + .map_err(|err| proto_error(format!("{err}"))) + } + } +} \ No newline at end of file
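
The hand-off that this patch relies on is spread across several files: the ArrowFlightEndpoint appends one final FlightData (empty RecordBatch) whose app_metadata carries a FlightAppMetadata/MetricsCollection payload, and the NetworkBoundary that happens to be the last reader of a task strips that payload into the DashMap exposed by metrics_collection(). The following is a minimal sketch of the decode-and-store step only, for orientation; it assumes it lives inside the crate (so the crate's protobuf types FlightAppMetadata, AppMetadata, MetricsCollection, TaskMetrics, StageKey and MetricsSetProto are visible), and the helper name extract_task_metrics is hypothetical, not the crate's actual MetricsCollectingStream implementation.

// Illustrative sketch only, assuming it is compiled inside this crate.
use crate::metrics::proto::MetricsSetProto;
use crate::protobuf::{AppMetadata, FlightAppMetadata, StageKey};
use arrow_flight::FlightData;
use arrow_flight::error::FlightError;
use dashmap::DashMap;
use prost::Message;
use std::sync::Arc;

// Hypothetical helper mirroring what MetricsCollectingStream does for each
// incoming FlightData before handing it to the Arrow IPC decoder.
fn extract_task_metrics(
    flight_data: &mut FlightData,
    collection: &Arc<DashMap<StageKey, Vec<MetricsSetProto>>>,
) -> Result<(), FlightError> {
    if flight_data.app_metadata.is_empty() {
        // Regular data message: nothing to collect.
        return Ok(());
    }
    let metadata = FlightAppMetadata::decode(flight_data.app_metadata.as_ref())
        .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
    if let Some(AppMetadata::MetricsCollection(tasks)) = metadata.content {
        for task in tasks.tasks {
            let Some(stage_key) = task.stage_key else {
                return Err(FlightError::ProtocolError(
                    "TaskMetrics is missing its StageKey".into(),
                ));
            };
            // One entry per child task; by convention the last NetworkBoundary
            // to read from a task is the one that receives its metrics.
            collection.insert(stage_key, task.metrics);
        }
    }
    // Clear the metadata so downstream IPC decoding only sees record batch data.
    flight_data.app_metadata.clear();
    Ok(())
}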